tilelang.intrinsics.mma_macro_generator

Attributes

Classes

TensorCoreIntrinEmitter

Emits Tensor Core MMA intrinsics (ldmatrix, mma, stmatrix) as TIR macros.

TensorCoreIntrinEmitterWithLadderTransform

TensorCoreIntrinEmitter with the Ladder layout transform applied to A and/or B.

INT4TensorCoreIntrinEmitter

INT4 variant of TensorCoreIntrinEmitter.

INT4TensorCoreIntrinEmitterWithLadderTransform

INT4 variant of TensorCoreIntrinEmitterWithLadderTransform.

Module Contents

tilelang.intrinsics.mma_macro_generator.lift
class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)

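Emits Tensor Core MMA intrinsics (ldmatrix_a, ldmatrix_b, mma, stmatrix) as TIR macros, so that Python-level logic stays out of the TIR macro bodies.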

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • thread_var (tvm.tir.Var | None)

M_DIM = 16
n_dim = 16
WARP_SIZE = 32
dtype_abbrv
is_m_first: bool = False
warp_rows: int = 1
warp_cols: int = 1
a_dtype
b_dtype
accum_dtype
a_transposed = False
b_transposed = False
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 16
reduce_k = 1
threads = 128
num_elems_per_byte = 1
thread_var = None
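Several of these attributes are derived from the constructor arguments. The sketch below shows how block-level sizes follow from the warp parameters, assuming the convention used in TileLang's GEMM examples as we understand it: a block tile spans block_row_warps x warp_row_tiles rows and block_col_warps x warp_col_tiles columns, chunk is the K extent per step, reduce_k = 1, and n_dim = 16. derive_tile_config is a hypothetical helper, not part of this module. With block_row_warps = block_col_warps = 2 the thread count is 32 x 2 x 2 = 128, matching threads = 128 above.

```python
WARP_SIZE = 32  # threads per warp (class attribute WARP_SIZE)
M_DIM = 16      # rows of one MMA micro-tile (class attribute M_DIM)


def derive_tile_config(block_row_warps, block_col_warps,
                       warp_row_tiles, warp_col_tiles, chunk, n_dim=16):
    """Hypothetical helper: block-level sizes implied by the emitter's warp parameters."""
    # Checks added for this sketch: each warp tile must hold whole micro-tiles.
    if warp_row_tiles <= 0 or warp_row_tiles % M_DIM:
        raise ValueError("warp_row_tiles must be a positive multiple of M_DIM")
    if warp_col_tiles <= 0 or warp_col_tiles % n_dim:
        raise ValueError("warp_col_tiles must be a positive multiple of n_dim")
    return {
        "block_M": block_row_warps * warp_row_tiles,   # rows covered by one thread block
        "block_N": block_col_warps * warp_col_tiles,   # columns covered by one thread block
        "block_K": chunk,                              # K extent staged per iteration (assumed)
        "threads": WARP_SIZE * block_row_warps * block_col_warps,  # assumes reduce_k == 1
        "warp_rows": warp_row_tiles // M_DIM,          # micro-tiles per warp along M
        "warp_cols": warp_col_tiles // n_dim,          # micro-tiles per warp along N
    }
```

For example, derive_tile_config(2, 2, 64, 64, 32) describes a 128 x 128 block tile with K steps of 32, computed by 128 threads, where each warp handles a 4 x 4 grid of 16 x 16 micro-tiles.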
get_thread_binding()
get_store_index_map(inverse=False)
Parameters:
  • inverse (bool)

Return type:

tvm.tir.IndexMap
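A store index map relates positions in the output tile to the (thread, local index) pair that holds them, and inverse selects the opposite direction. As an illustration only, the sketch below encodes both directions for the accumulator fragment layout that the PTX ISA documents for a single mma.m16n8k16 (a 16 x 8 tile, four values per lane). The IndexMap returned by get_store_index_map covers the emitter's own tile shape and may order local indices differently; the function names are hypothetical.

```python
def c_fragment_coord(lane, i):
    """(lane, local index i in 0..3) -> (row, col) in a 16x8 accumulator tile."""
    group, tig = lane >> 2, lane % 4   # PTX groupID and threadID_in_group
    row = group + 8 * (i // 2)         # c0, c1 in the top half; c2, c3 in the bottom half
    col = tig * 2 + (i & 1)            # each lane holds two adjacent columns
    return row, col


def c_fragment_owner(row, col):
    """Inverse direction: (row, col) -> (lane, local index)."""
    lane = (row % 8) * 4 + col // 2
    i = (row // 8) * 2 + (col % 2)
    return lane, i
```

Because the two functions are exact inverses over 32 lanes x 4 values = 128 = 16 x 8 elements, every accumulator element is owned by exactly one (lane, local index) pair.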

extract_thread_binding(thread_id, is_m_first=None)

is_m_first: If True, the thread binding has the form (tx, warp_n, warp_m), i.e. [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it has the form [warp_size, block_col_warps (split m), block_row_warps (split n)]. If None, the instance's is_m_first attribute is used.

Parameters:
  • thread_id (tvm.tir.PrimExpr)

  • is_m_first (bool | None)

Return type:

tuple[tvm.tir.PrimExpr, tvm.tir.PrimExpr, tvm.tir.PrimExpr]
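The decomposition selected by is_m_first can be illustrated in plain Python. This is a sketch under an assumed reading of the docstring above: the lane id varies fastest, and is_m_first decides whether warp_n or warp_m varies next across consecutive warps. split_thread_id is a hypothetical name, and returning (lane_id, warp_n, warp_m) is a choice made for this example; the emitter works on tvm.tir.PrimExpr values rather than Python integers.

```python
WARP_SIZE = 32  # threads per warp (class attribute WARP_SIZE)


def split_thread_id(thread_id, block_row_warps, block_col_warps, is_m_first):
    """Hypothetical helper: decompose a flat thread index into (lane_id, warp_n, warp_m)."""
    lane_id = thread_id % WARP_SIZE
    warp_id = thread_id // WARP_SIZE
    if is_m_first:
        # (tx, warp_n, warp_m): warp_n varies fastest across consecutive warps
        warp_n = warp_id % block_col_warps
        warp_m = (warp_id // block_col_warps) % block_row_warps
    else:
        # warp_m varies fastest across consecutive warps
        warp_m = warp_id % block_row_warps
        warp_n = (warp_id // block_row_warps) % block_col_warps
    return lane_id, warp_n, warp_m
```

With 2 x 2 warps, thread 33 is lane 1 of the second warp: split_thread_id(33, 2, 2, False) gives (1, 0, 1), while split_thread_id(33, 2, 2, True) gives (1, 1, 0).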

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • A_shared_buf (tvm.tir.Buffer | tvm.tir.BufferRegion)

  • ki (tvm.tir.PrimExpr)

  • rk (tvm.tir.PrimExpr | None)

ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
Parameters:
  • B_local_buf (tvm.tir.Buffer)

  • B_shared_buf (tvm.tir.Buffer | tvm.tir.BufferRegion)

  • ki (tvm.tir.PrimExpr)

  • rk (tvm.tir.PrimExpr | None)

mma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • B_local_buf (tvm.tir.Buffer)

  • C_local_buf (tvm.tir.Buffer)

  • k_inner (tvm.tir.PrimExpr | None)
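As a point of reference for what one warp accumulates, the sketch below spells out the arithmetic of a single K step in plain Python: every A micro-tile i is multiplied with every B micro-tile j and added into accumulator tile (i, j). It models only the math, not the emitted ptx instructions, register layouts, or the k_inner offset; the 16 x 16 x 16 tile extents and the function name warp_mma_reference are assumptions of this example.

```python
M_DIM = N_DIM = K_DIM = 16  # assumed micro-tile extents for this sketch


def warp_mma_reference(A, B, C, b_transposed=False):
    """Hypothetical reference: C[i][j] += A[i] @ B[j] for every micro-tile pair of one warp.

    A: list of warp_rows tiles, each M_DIM x K_DIM (row-major nested lists)
    B: list of warp_cols tiles, each N_DIM x K_DIM if b_transposed else K_DIM x N_DIM
    C: warp_rows x warp_cols grid of M_DIM x N_DIM accumulator tiles, updated in place
    """
    for i, a in enumerate(A):          # i ranges over warp_rows
        for j, b in enumerate(B):      # j ranges over warp_cols
            c = C[i][j]
            for m in range(M_DIM):
                for n in range(N_DIM):
                    acc = 0
                    for k in range(K_DIM):
                        b_kn = b[n][k] if b_transposed else b[k][n]
                        acc += a[m][k] * b_kn
                    c[m][n] += acc     # accumulate rather than overwrite
```

The accumulate-in-place design mirrors how C_local_buf is reused across K steps: calling the function once per K slice sums the partial products.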

stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
make_mma_load_layout(local_buf, matrix='A')

Create a layout function for an MMA operand (matrix A or B) held in a fragment buffer. The returned fragment maps indices of local_buf to threads and local indices.

Parameters:
  • local_buf (tir.Buffer) -- The local buffer representing a fragment of a matrix.

  • matrix (Literal['A', 'B']) -- Which operand local_buf holds.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError -- If local_buf is not detected to be a fragment buffer.
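To make "how threads and indices are laid out" concrete, the sketch below encodes the A-operand fragment layout that the PTX ISA documents for mma.m16n8k16 with .f16 operands: a 16 x 16 tile spread over 32 lanes with eight values each. It is an illustration of the kind of mapping such a fragment describes, not the object make_mma_load_layout returns; a_fragment_coord is a hypothetical name.

```python
def a_fragment_coord(lane, i):
    """(lane, register index i in 0..7) -> (row, col) in a 16x16 A tile (PTX m16n8k16, .f16)."""
    group = lane >> 2                  # PTX groupID: selects the row within each 8-row half
    tig = lane % 4                     # PTX threadID_in_group: selects a column pair
    row = group + 8 * ((i // 2) % 2)   # a0,a1,a4,a5 -> top half; a2,a3,a6,a7 -> bottom half
    col = tig * 2 + (i & 1) + 8 * (i // 4)  # a4..a7 cover the right 8 columns
    return row, col
```

For example, lane 5 (groupID 1, threadID_in_group 1) holds a3 at (9, 3) and a4 at (1, 10); across all lanes the 256 (lane, register) pairs cover the tile exactly once.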

make_mma_store_layout(local_buf)

Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.

Parameters:
  • local_buf (tir.Buffer) -- The local buffer representing a fragment of a matrix.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError -- If local_buf is not detected to be a fragment buffer.

class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitterWithLadderTransform(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitter

Variant of TensorCoreIntrinEmitter with the Ladder Transform plugin: transform_kind_a and transform_kind_b select the layout transform applied to A and B.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • transform_kind_a (int | tilelang.common.TransformKind)

  • transform_kind_b (int | tilelang.common.TransformKind)

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
mma(A_local_buf, B_local_buf, C_local_buf)
class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)

Bases: TensorCoreIntrinEmitter

INT4 variant of TensorCoreIntrinEmitter; overrides mma for INT4 operands.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • thread_var (tvm.tir.Var | None)

mma(A_local_buf, B_local_buf, C_local_buf)
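INT4 operands store two values per byte, presumably the case the num_elems_per_byte parameter (set to 2) is meant to express. The sketch below packs and unpacks signed 4-bit integers; putting the first element in the low nibble is an assumption of this example, not a statement about the layout this emitter expects, and pack_int4/unpack_int4 are hypothetical helpers.

```python
def pack_int4(values):
    """Pack signed 4-bit integers two per byte, first element in the low nibble (assumed)."""
    if len(values) % 2:
        raise ValueError("need an even number of int4 values")
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        for v in (lo, hi):
            if not -8 <= v <= 7:
                raise ValueError(f"{v} does not fit in int4")
        # & 0xF keeps the two's-complement low 4 bits of each value
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(data):
    """Inverse of pack_int4: sign-extend each nibble back to a Python int."""
    vals = []
    for byte in data:
        for nibble in (byte & 0xF, byte >> 4):
            vals.append(nibble - 16 if nibble >= 8 else nibble)
    return vals
```

For instance, pack_int4([1, -1]) yields the single byte 0xF1, and unpacking it restores [1, -1].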
class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitterWithLadderTransform(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitterWithLadderTransform

INT4 variant of TensorCoreIntrinEmitterWithLadderTransform; overrides mma for INT4 operands.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • transform_kind_a (int | tilelang.common.TransformKind)

  • transform_kind_b (int | tilelang.common.TransformKind)

mma(A_local_buf, B_local_buf, C_local_buf)