tilelang.carver.roller.hint

Hint definitions for scheduling.

Classes

TensorCoreExtraConfig

Stores extra Tensor Core information.

Stride

Manages stride information for a given axis of a tensor.

TileDict

Manages tiling information and configurations for computational tasks.

IntrinInfo

Information about Tensor Core intrinsics.

Hint

Central configuration class for managing various parameters of computational tasks.

Module Contents

class tilelang.carver.roller.hint.TensorCoreExtraConfig(AS_shape, BS_shape, AF_shape, BF_shape, tc_axis)

Stores extra Tensor Core information.

Parameters:
  • AS_shape (tuple[int])

  • BS_shape (tuple[int])

  • AF_shape (tuple[int])

  • BF_shape (tuple[int])

  • tc_axis (tuple[int])

AS_shape: tuple[int]
BS_shape: tuple[int]
AF_shape: tuple[int]
BF_shape: tuple[int]
tc_axis: tuple[int]

class tilelang.carver.roller.hint.Stride(stride=1, ax=-1)

Manages stride information for a given axis of a tensor.

Parameters:
  • stride (int)

  • ax (int)

property ax: int
Return type:

int

property stride: int
Return type:

int

compute_strides_from_shape(shape)
Parameters:

shape (list[int])

Return type:

list[int]

compute_elements_from_shape(shape)
Parameters:

shape (list[int])

Return type:

int
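
The two shape helpers above carry no description in this reference, so the following is only a minimal standalone sketch of one plausible reading. It assumes a row-major layout in which the stride of axis `ax` is overridden by `stride` (for example, to pad a row), and that a negative `ax` means "no override". The function names `strides_from_shape` and `elements_from_shape` are hypothetical and are not part of tilelang.

```python
from math import prod


def strides_from_shape(shape, stride=1, ax=-1):
    """Row-major strides; the stride at axis `ax` is replaced by `stride`.

    Assumption: a negative `ax` means no axis is overridden.
    """
    strides = [1] * len(shape)
    # Walk from the second-to-last axis toward the first.
    for i in range(len(shape) - 2, -1, -1):
        if i == ax:
            strides[i] = stride
        else:
            strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def elements_from_shape(shape, stride=1, ax=-1):
    """Number of elements spanned once the stride override is applied.

    Assumption: without an override this is simply the product of `shape`.
    """
    if ax < 0:
        return prod(shape)
    return prod(shape[: ax + 1]) * stride


plain = strides_from_shape([4, 8, 16])
padded = strides_from_shape([4, 8, 16], stride=24, ax=1)
padded_elems = elements_from_shape([4, 8, 16], stride=24, ax=1)
```

Under these assumptions, padding each 16-element row to 24 elements changes the outer stride as well, and the padded buffer spans more elements than the dense one.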

is_valid()
Return type:

bool

__repr__()
Return type:

str

class tilelang.carver.roller.hint.TileDict(output_tile)

Manages tiling information and configurations for computational tasks.

output_tile
tile_map
rstep_map
cached_tensors_map
output_strides_map
tensor_strides_map
traffic = -1
smem_cost = -1
block_per_SM = -1
num_wave = -1
grid_size = -1
valid = True

get_tile(func)
Return type:

list[int]

get_rstep(node)
Return type:

dict[str, int]

__hash__()
Return type:

int

class tilelang.carver.roller.hint.IntrinInfo(in_dtype, out_dtype, trans_b, input_transform_kind=0, weight_transform_kind=0)

Information about Tensor Core intrinsics.

Parameters:
  • in_dtype (str)

  • out_dtype (str)

  • trans_b (bool)

  • input_transform_kind (int)

  • weight_transform_kind (int)

in_dtype
out_dtype
trans_a = False
trans_b
input_transform_kind = 0
weight_transform_kind = 0

__repr__()
Return type:

str

is_input_8bit()
Return type:

bool

property smooth_a: bool
Return type:

bool

property smooth_b: bool
Return type:

bool

property inter_transform_a: bool
Return type:

bool

property inter_transform_b: bool
Return type:

bool

class tilelang.carver.roller.hint.Hint

Central configuration class for managing various parameters of computational tasks.

arch = None
use_tc = None
block = []
thread = []
warp = []
rstep = []
reduce_thread = []
rasterization_plan
cached_tensors = []
output_strides
schedule_stages = None
block_reduction_depth: int = None
split_k_factor: int = 1
vectorize: dict[str, int]
pipeline_stage = 1
use_async = False
opt_shapes: dict[str, int]
intrin_info
shared_scope: str = 'shared'
pass_context: dict

to_dict()
Return type:

dict

classmethod from_dict(dic)
Parameters:

dic (dict)

Return type:

Hint
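
The `to_dict()` / `from_dict(dic)` pair suggests a dictionary round trip. The sketch below illustrates that pattern with a toy stand-in class, not with the real `Hint`. `MiniHint` is hypothetical, it copies only a few of the attributes and defaults listed above, and the real class may serialize differently (for instance by omitting defaults or encoding nested objects).

```python
from dataclasses import asdict, dataclass, field


@dataclass
class MiniHint:
    """Hypothetical stand-in mirroring a few Hint attributes and defaults."""

    block: list = field(default_factory=list)
    thread: list = field(default_factory=list)
    rstep: list = field(default_factory=list)
    split_k_factor: int = 1
    pipeline_stage: int = 1
    use_async: bool = False
    shared_scope: str = "shared"

    def to_dict(self) -> dict:
        # Plain-dict snapshot of every field.
        return asdict(self)

    @classmethod
    def from_dict(cls, dic: dict) -> "MiniHint":
        # Start from defaults, then overwrite with whatever the dict provides.
        hint = cls()
        for key, value in dic.items():
            setattr(hint, key, value)
        return hint


original = MiniHint(block=[128, 128], thread=[16, 16], rstep=[32], pipeline_stage=2)
restored = MiniHint.from_dict(original.to_dict())
```

Starting from defaults in `from_dict` keeps partial dictionaries usable, which is a design choice of this sketch rather than a documented property of `Hint`.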

tensorcore_legalization()

property raxis_order: list[int]
Return type:

list[int]

property step: list[int]
Return type:

list[int]

__repr__()
Return type:

str

complete_config(node)
Parameters:

node (tilelang.carver.roller.PrimFuncNode)