tilelang.intrinsics.mma_macro_generator

Attributes

Classes

TensorCoreIntrinEmitter

To eliminate Python syntax within TIR Macro.

TensorCoreIntrinEmitterWithLadderTransform

To eliminate Python syntax within TIR Macro.

INT4TensorCoreIntrinEmitter

To eliminate Python syntax within TIR Macro.

INT4TensorCoreIntrinEmitterWithLadderTransform

To eliminate Python syntax within TIR Macro.

Module Contents

tilelang.intrinsics.mma_macro_generator.lift
class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)

To eliminate Python syntax within TIR Macro.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • thread_var (tvm.tir.Var | None)

M_DIM = 16
n_dim = 16
WARP_SIZE = 32
dtype_abbrv
is_m_first: bool = False
warp_rows: int = 1
warp_cols: int = 1
a_dtype
b_dtype
accum_dtype
a_transposed = False
b_transposed = False
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 16
reduce_k = 1
threads = 128
num_elems_per_byte = 1
thread_var = None
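
The attribute values above are consistent with a few simple relations between the constructor arguments. The sketch below is only an illustration under stated assumptions, not the library's code: it assumes `threads` equals `WARP_SIZE * block_row_warps * block_col_warps` (which matches the documented 128 for the defaults), and that `warp_rows` and `warp_cols` count how many `M_DIM` x `n_dim` MMA tiles fit in one warp tile. The helper `derive_tile_config` and the keys `block_m` and `block_n` are hypothetical names.

```python
M_DIM = 16      # documented class constant
N_DIM = 16      # documented as n_dim
WARP_SIZE = 32  # documented class constant


def derive_tile_config(block_row_warps=2, block_col_warps=2,
                       warp_row_tiles=32, warp_col_tiles=32):
    """Hypothetical helper: derive per-block and per-warp quantities."""
    # Assumption: one warp per (row, col) warp slot in the thread block.
    threads = WARP_SIZE * block_row_warps * block_col_warps
    # Assumption: a warp tile of warp_row_tiles x warp_col_tiles elements
    # is covered by MMA tiles of M_DIM x N_DIM elements.
    warp_rows = warp_row_tiles // M_DIM
    warp_cols = warp_col_tiles // N_DIM
    return {
        "threads": threads,
        "warp_rows": warp_rows,
        "warp_cols": warp_cols,
        # Assumption: the block tile is the warp tile repeated per warp slot.
        "block_m": block_row_warps * warp_row_tiles,
        "block_n": block_col_warps * warp_col_tiles,
    }


cfg = derive_tile_config()
print(cfg)
```

The example uses 32 x 32 warp tiles rather than the documented defaults of 8 so that the integer divisions give a non-zero tile count.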
get_thread_binding()
get_store_index_map(inverse=False)
Parameters:

inverse (bool)

Return type:

tvm.tir.IndexMap

extract_thread_binding(thread_id, is_m_first=None)

is_m_first: True if the thread binding has the form (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it has the form [warp_size, block_col_warps (split m), block_row_warps (split n)].

Parameters:
  • thread_id (tvm.tir.PrimExpr)

  • is_m_first (bool | None)

Return type:

tuple[tvm.tir.PrimExpr, tvm.tir.PrimExpr, tvm.tir.PrimExpr]
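
The decomposition described for `extract_thread_binding` can be pictured as a mixed-radix split of a linear thread id. The sketch below uses plain integers instead of `tvm.tir.PrimExpr` and is an assumption about the arithmetic, not the library's implementation. The function `split_thread_id` and the extents `n_warps` and `m_warps` are hypothetical names; which of `block_row_warps` and `block_col_warps` supplies each extent is left to the library.

```python
def split_thread_id(thread_id, warp_size=32, n_warps=2, m_warps=2,
                    is_m_first=False):
    """Hypothetical sketch: split a linear thread id into (lane, warp_n, warp_m)."""
    lane = thread_id % warp_size   # position inside the warp
    warp = thread_id // warp_size  # linear warp index inside the block
    if is_m_first:
        # Binding ordered as (tx, warp_n, warp_m): warp_n varies fastest.
        warp_n = warp % n_warps
        warp_m = (warp // n_warps) % m_warps
    else:
        # Binding ordered as (tx, warp_m, warp_n): warp_m varies fastest.
        warp_m = warp % m_warps
        warp_n = (warp // m_warps) % n_warps
    # The result is returned in the same order in both cases.
    return lane, warp_n, warp_m


print(split_thread_id(37))                   # lane 5 of warp 1
print(split_thread_id(37, is_m_first=True))
```

With two warps in each direction, thread 37 is lane 5 of warp 1, so the two orderings assign that warp to different (warp_n, warp_m) coordinates.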

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • A_shared_buf (tvm.tir.Buffer | tvm.tir.BufferRegion)

  • ki (tvm.tir.PrimExpr)

  • rk (tvm.tir.PrimExpr | None)

ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
Parameters:
  • B_local_buf (tvm.tir.Buffer)

  • B_shared_buf (tvm.tir.Buffer | tvm.tir.BufferRegion)

  • ki (tvm.tir.PrimExpr)

  • rk (tvm.tir.PrimExpr | None)

mma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • B_local_buf (tvm.tir.Buffer)

  • C_local_buf (tvm.tir.Buffer)

  • k_inner (tvm.tir.PrimExpr | None)
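
As a reading aid for `mma`, the sketch below shows the arithmetic a matrix-multiply-accumulate step performs over a warp's fragments, using nested Python lists in place of `tvm.tir.Buffer` objects. It is a reference model under assumptions, not the emitted intrinsic: it assumes one A tile per warp row, one B tile per warp column, B stored as (k, n), and one accumulator tile per (row, col) pair. The name `mma_accumulate` is hypothetical.

```python
def mma_accumulate(a_frags, b_frags, c_frags):
    """Hypothetical reference model: c_frags[i][j] += a_frags[i] @ b_frags[j]."""
    for i, a in enumerate(a_frags):        # one A tile per warp row
        for j, b in enumerate(b_frags):    # one B tile per warp column
            c = c_frags[i][j]
            for r in range(len(a)):            # rows of the tile (M)
                for n in range(len(b[0])):     # columns of the tile (N)
                    # Reduce over the shared K dimension and accumulate in place.
                    c[r][n] += sum(a[r][k] * b[k][n] for k in range(len(b)))
    return c_frags


a_frags = [[[1, 2], [3, 4]]]           # a single 2x2 A tile
b_frags = [[[1, 0], [0, 1]]]           # a single 2x2 identity B tile
c_frags = [[[[0, 0], [0, 0]]]]         # a single zeroed accumulator tile
mma_accumulate(a_frags, b_frags, c_frags)
print(c_frags[0][0])
```

Calling the function again with the same inputs adds to the existing accumulator values rather than overwriting them, which is the behaviour a k-loop relies on.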

stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
make_mma_load_layout(local_buf, matrix='A')

Create a layout function for loading an MMA operand (matrix A or B) into a fragment buffer. The layout maps fragment indices to threads and local indices.

Parameters:
  • local_buf (tir.Buffer) -- The local buffer representing a fragment of a matrix.

  • matrix (Literal['A', 'B'])

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError -- If local_buf is not detected to be a fragment buffer.

make_mma_store_layout(local_buf)

Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.

Parameters:

local_buf (tir.Buffer) -- The local buffer representing a fragment of a matrix.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError -- If local_buf is not detected to be a fragment buffer.
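
To make "map fragment indices to threads and local indices" concrete, the sketch below maps a (thread, local) pair to a (row, col) position in a 16 x 16 accumulator tile. It is an assumption, not tilelang's API: it follows the PTX m16n8 accumulator fragment convention (four threads per row group, two adjacent columns per thread) and places two n8 halves side by side. The name `store_fragment_to_tile` is hypothetical.

```python
def store_fragment_to_tile(thread_id, local_id):
    """Hypothetical sketch: (lane 0..31, local 0..7) -> (row, col) in a 16x16 tile."""
    # Rows: lanes are grouped in fours; locals 2 and 3 of each half sit 8 rows lower.
    row = 8 * ((local_id % 4) // 2) + thread_id // 4
    # Columns: locals 4..7 belong to the second n8 half; each lane owns 2 adjacent columns.
    col = 8 * (local_id // 4) + (thread_id % 4) * 2 + local_id % 2
    return row, col


# Every element of the 16x16 tile should be owned by exactly one (thread, local) pair.
cells = {store_fragment_to_tile(t, l) for t in range(32) for l in range(8)}
print(len(cells))  # 256 distinct cells
```

Because the mapping is one-to-one over 32 threads and 8 local elements, it can be inverted, which is the property an inverse store layout depends on.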

class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitterWithLadderTransform(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitter

To eliminate Python syntax within TIR Macro. With Ladder Transform Plugin.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • transform_kind_a (int | tilelang.common.TransformKind)

  • transform_kind_b (int | tilelang.common.TransformKind)

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
mma(A_local_buf, B_local_buf, C_local_buf)
class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)

Bases: TensorCoreIntrinEmitter

To eliminate Python syntax within TIR Macro.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • thread_var (tvm.tir.Var | None)

mma(A_local_buf, B_local_buf, C_local_buf)
class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitterWithLadderTransform(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitterWithLadderTransform

To eliminate Python syntax within TIR Macro. With Ladder Transform Plugin.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • transform_kind_a (int | tilelang.common.TransformKind)

  • transform_kind_b (int | tilelang.common.TransformKind)

mma(A_local_buf, B_local_buf, C_local_buf)