tilelang.intrinsics.mfma_macro_generator

Attributes

Classes

MatrixCoreIntrinEmitter

To eliminate Python syntax within TIR Macro.

MatrixCorePreshuffleIntrinEmitter

To eliminate Python syntax within TIR Macro.

Module Contents

tilelang.intrinsics.mfma_macro_generator.lift
class tilelang.intrinsics.mfma_macro_generator.MatrixCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, k_pack=None, is_m_first=False, b_preshuffle=False, thread_var=None, target=None)

To eliminate Python syntax within TIR Macro.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • k_pack (int | None)

  • is_m_first (bool | None)

  • b_preshuffle (bool | None)

  • thread_var (tvm.tir.Var | None)

  • target (tvm.target.Target | None)

M_DIM = 16
N_DIM = 16
WARP_SIZE = 64
dtype_abbrv
k_pack = 1
is_m_first = False
a_dtype
b_dtype
accum_dtype
a_transposed = False
b_transposed = False
target = None
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 16
warp_rows = 0
warp_cols = 0
reduce_k = 1
threads = 256
num_elems_per_byte = 1
thread_var = None
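
The derived attributes above (threads, warp_rows, warp_cols) appear to follow from the constants and constructor defaults. The sketch below is a plain-Python check of that arithmetic. The three formulas and the helper name derived_sizes are assumptions made for illustration, not taken from this reference; they are chosen because they reproduce the listed values.

```python
# Constants copied from the attribute list above.
M_DIM, N_DIM, WARP_SIZE = 16, 16, 64


def derived_sizes(block_row_warps=2, block_col_warps=2,
                  warp_row_tiles=8, warp_col_tiles=8):
    """Hypothetical helper: recompute (threads, warp_rows, warp_cols).

    The formulas are assumptions that match the documented defaults;
    they are not quoted from the tilelang source.
    """
    threads = WARP_SIZE * block_row_warps * block_col_warps
    warp_rows = warp_row_tiles // M_DIM
    warp_cols = warp_col_tiles // N_DIM
    return threads, warp_rows, warp_cols


print(derived_sizes())                                       # (256, 0, 0)
print(derived_sizes(warp_row_tiles=32, warp_col_tiles=32))   # (256, 2, 2)
```

Under these assumed formulas the default warp_row_tiles=8 and warp_col_tiles=8 give zero MFMA tiles per warp, so a practical configuration would presumably pass tile sizes that are multiples of 16.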
get_ldmatrix_index_map(is_b=False)
get_store_index_map(inverse=False)
Parameters:

inverse (bool)

Return type:

tvm.tir.IndexMap

get_thread_binding()
extract_thread_binding(thread_id, is_m_first=None)

is_m_first: True if the thread binding has the form (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it has the form [warp_size, block_col_warps (split m), block_row_warps (split n)]. If None, the value stored on the instance (is_m_first) is presumably used.

Return type:

tuple[tvm.tir.PrimExpr, tvm.tir.PrimExpr, tvm.tir.PrimExpr]
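
The description above amounts to splitting a flat thread index into a lane index plus two warp coordinates. A minimal sketch of that decomposition on plain integers follows. The helper split_thread_id and its parameter names are hypothetical; which warp axis (m or n) is the inner, faster-varying one is what is_m_first selects, and the sketch deliberately leaves that choice to the caller rather than asserting the library's ordering.

```python
WARP_SIZE = 64  # matches the WARP_SIZE attribute documented above


def split_thread_id(thread_id, inner_warps, outer_warps, warp_size=WARP_SIZE):
    """Hypothetical helper: split a flat thread id into (lane, inner, outer).

    lane  : position of the thread inside its warp
    inner : warp coordinate that varies fastest across warps
    outer : warp coordinate that varies slowest across warps
    """
    lane = thread_id % warp_size
    inner = (thread_id // warp_size) % inner_warps
    outer = (thread_id // (warp_size * inner_warps)) % outer_warps
    return lane, inner, outer


# With 2 x 2 warps (the documented defaults) a block has 256 threads.
print(split_thread_id(0, 2, 2))    # (0, 0, 0)
print(split_thread_id(65, 2, 2))   # (1, 1, 0)
print(split_thread_id(200, 2, 2))  # (8, 1, 1)
```

The real method returns tvm.tir.PrimExpr values rather than Python integers, but the same modulo and floor-division structure would apply to symbolic thread indices.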

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
Parameters:

A_shared_buf (tvm.tir.Buffer | tvm.tir.BufferRegion)

ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
Parameters:

B_shared_buf (tvm.tir.Buffer | tvm.tir.BufferRegion)

mfma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • B_local_buf (tvm.tir.Buffer)

  • C_local_buf (tvm.tir.Buffer)

  • k_inner (tvm.tir.PrimExpr | None)

stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
make_mfma_load_layout(local_buf, matrix='A')

Create a layout function for loading an MFMA operand (matrix A or B) into a fragment buffer.

Parameters:
  • local_buf (tir.Buffer) -- The local buffer representing a fragment of a matrix.

  • matrix (Literal['A', 'B'])

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError -- If local_buf is not detected to be a fragment buffer.

make_mfma_store_layout(local_buf)

Create a layout function for storing MFMA results into a fragment buffer.

Parameters:

local_buf (tir.Buffer) -- The local buffer representing a fragment of a matrix.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError -- If local_buf is not detected to be a fragment buffer.
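
To make "how threads and indices in local_buf are laid out" concrete, the sketch below maps each element of a 16 x 16 tile (M_DIM x N_DIM) to a (lane, local index) pair across 64 lanes (WARP_SIZE), together with its inverse. The specific mapping and the function names tile_to_lane and lane_to_tile are assumptions chosen for illustration; the layout the emitter actually produces may differ by data type and transposition flags and is not specified in this reference.

```python
def tile_to_lane(i, j):
    """Illustrative forward map: tile element (i, j) -> (lane, local_index).

    Each of the 64 lanes owns 4 consecutive rows of one column.
    """
    return j + (i // 4) * 16, i % 4


def lane_to_tile(lane, local):
    """Illustrative inverse map: (lane, local_index) -> tile element (i, j)."""
    return 4 * (lane // 16) + local, lane % 16


# The 256 tile elements fill 64 lanes x 4 slots exactly once.
slots = {tile_to_lane(i, j) for i in range(16) for j in range(16)}
assert len(slots) == 64 * 4

# Every element round-trips through the forward and inverse maps.
assert all(lane_to_tile(*tile_to_lane(i, j)) == (i, j)
           for i in range(16) for j in range(16))

print(tile_to_lane(5, 3))   # (19, 1)
print(lane_to_tile(19, 1))  # (5, 3)
```

A forward and inverse pair of this kind is the sort of relationship the inverse flag of get_store_index_map would select between, assuming the flag has its usual meaning.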

class tilelang.intrinsics.mfma_macro_generator.MatrixCorePreshuffleIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, k_pack=None, is_m_first=False, a_preshuffle=False, b_preshuffle=False, thread_var=None, target=None)

Bases: MatrixCoreIntrinEmitter

To eliminate Python syntax within TIR Macro.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • k_pack (int | None)

  • is_m_first (bool | None)

  • a_preshuffle (bool | None)

  • b_preshuffle (bool | None)

  • thread_var (tvm.tir.Var | None)

  • target (tvm.target.Target | None)

ldmatrix_a(A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None)
ldmatrix_b(B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None)