tilelang.intrinsics.wmma_macro_generator

WMMA intrinsic emitter for AMD RDNA architectures (gfx11 / gfx12).

Only the f16 -> f32, 16x16x16 variant with a warp size of 32 (wave32) is supported.

Thread-data mapping (per AMDGPU ISA):
gfx11:
  • A/B: duplicated across the two half-waves, so each logical input fragment is distributed over an effective wave size of 16 lanes.

  • C/D: distributed over the full wave32 output layout.

gfx12:
  • A/B: distributed over the full wave32 input layout.

  • C/D: distributed over the full wave32 output layout.
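The per-lane register footprint implied by this mapping can be checked with simple arithmetic. The sketch below is illustrative only; the helper name `elems_per_lane` is hypothetical and is not part of tilelang.

```python
# Hypothetical helper, not part of tilelang: per-lane element counts
# implied by the thread-data mapping above for one 16x16 tile.
TILE_ELEMS = 16 * 16   # elements in one A, B, or C/D tile
WARP_SIZE = 32

def elems_per_lane(effective_lanes: int) -> int:
    """Elements each lane holds when a tile is spread over `effective_lanes`."""
    return TILE_ELEMS // effective_lanes

gfx11_ab = elems_per_lane(WARP_SIZE // 2)  # A/B duplicated across the two half-waves
gfx12_ab = elems_per_lane(WARP_SIZE)       # A/B spread over the full wave32
cd = elems_per_lane(WARP_SIZE)             # C/D spread over the full wave32 on both
print(gfx11_ab, gfx12_ab, cd)              # 16 8 8
```

These counts agree with the "32x16" (gfx11) and "32x8" (gfx12) layout naming used later in this module.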

Attributes

Classes

WMMAIntrinEmitter

Intrinsic emitter for AMD RDNA WMMA (16x16x16, warp-size=32).

Module Contents

tilelang.intrinsics.wmma_macro_generator.lift
class tilelang.intrinsics.wmma_macro_generator.WMMAIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float32', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=16, warp_col_tiles=16, chunk=16, k_pack=1, thread_var=None, target=None)

Intrinsic emitter for AMD RDNA WMMA (16x16x16, warp-size=32).

Supports:
  • fp16 -> fp32 (f32_16x16x16_f16_w32, with _gfx12 codegen suffix on gfx12)

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • k_pack (int)

  • thread_var (tvm.tir.Var | None)

  • target (tvm.target.Target | None)

M_DIM = 16
N_DIM = 16
K_DIM = 16
WARP_SIZE = 32
a_dtype = 'float16'
b_dtype = 'float16'
accum_dtype = 'float32'
a_transposed = False
b_transposed = False
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 16
warp_col_tiles = 16
chunk = 16
k_pack = 1
thread_var = None
target = None
rdna_gen
micro_size_x = 16
micro_size_y = 16
micro_size_k = 16
local_size_a = 16
local_size_b = 16
local_size_out = 8
warp_rows = 1
warp_cols = 1
threads = 128
a_fragment_forward_fn
b_fragment_forward_fn
fragment_replicate = 2
store_index_map_fn
wmma_shape = 'f32_16x16x16_f16_w32'
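The default values of `warp_rows`, `warp_cols`, and `threads` listed above are consistent with the relationships sketched below. The formulas are inferred from the listed defaults rather than taken from the emitter's source, so treat them as an assumption.

```python
# Inferred relationships between the default attribute values listed above.
# These formulas are an assumption; the emitter may compute them differently.
M_DIM = N_DIM = 16
WARP_SIZE = 32
block_row_warps, block_col_warps = 2, 2
warp_row_tiles, warp_col_tiles = 16, 16

warp_rows = warp_row_tiles // M_DIM   # 16x16 tiles handled per warp along M
warp_cols = warp_col_tiles // N_DIM   # 16x16 tiles handled per warp along N
threads = WARP_SIZE * block_row_warps * block_col_warps  # threads per block

print(warp_rows, warp_cols, threads)  # 1 1 128
```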
get_thread_binding()
Return type:

tvm.tir.PrimExpr

extract_thread_binding(thread_id)

Return (lane_id, warp_n, warp_m).
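As a rough illustration of what such a decomposition can look like, the sketch below splits a flat thread index into the three components. The function name `split_thread_id` is hypothetical, and the ordering of `warp_n` before `warp_m` is an assumption that may not match the emitter.

```python
# Assumed decomposition of a flat thread index; the relative ordering of
# warp_n and warp_m is a guess and may not match the emitter.
WARP_SIZE = 32
block_row_warps, block_col_warps = 2, 2

def split_thread_id(thread_id: int):
    """Return (lane_id, warp_n, warp_m) for a flat thread index."""
    lane_id = thread_id % WARP_SIZE
    warp_n = (thread_id // WARP_SIZE) % block_col_warps
    warp_m = (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps
    return lane_id, warp_n, warp_m

print(split_thread_id(0))    # (0, 0, 0)
print(split_thread_id(45))   # (13, 1, 0)
print(split_thread_id(127))  # (31, 1, 1)
```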

get_ldmatrix_index_map(is_b=False)

Return (forward, reverse) index maps for shared→local loading.

The actual layout functions are chosen during __init__ based on rdna_gen:
  • gfx11 uses half-wave duplicated A/B input layouts (32x16 naming).

  • gfx12 uses full wave32 A/B input layouts (32x8 naming).

Parameters:

is_b (bool)

get_store_index_map(inverse=False)

Return the store index map.

The forward map is (thread_id, local_id) -> (i, j), which is affine. The inverse map is (i, j) -> (thread_id, local_id).

Parameters:

inverse (bool)

Return type:

tvm.tir.IndexMap
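To make the forward/inverse relationship concrete, here is a plain-Python sketch of one possible store map for a 16x16 tile held as 32 lanes x 8 values. The specific index arithmetic is a hypothetical example chosen for illustration and is not claimed to be the map this class builds; the point is only that the two directions round-trip.

```python
# Illustrative only: a hypothetical store map for a 16x16 output tile held as
# 32 lanes x 8 values, paired with its inverse. Not tilelang's actual map.
WARP_SIZE, LOCAL_SIZE_OUT = 32, 8

def forward(thread_id: int, local_id: int):
    """(thread_id, local_id) -> (i, j)."""
    i = 2 * local_id + thread_id // 16
    j = thread_id % 16
    return i, j

def inverse(i: int, j: int):
    """(i, j) -> (thread_id, local_id)."""
    thread_id = 16 * (i % 2) + j
    local_id = i // 2
    return thread_id, local_id

# Every (thread, local) slot maps to a distinct tile element and round-trips.
cells = {forward(t, l) for t in range(WARP_SIZE) for l in range(LOCAL_SIZE_OUT)}
assert len(cells) == 16 * 16
assert all(inverse(*forward(t, l)) == (t, l)
           for t in range(WARP_SIZE) for l in range(LOCAL_SIZE_OUT))
```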

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
wmma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • B_local_buf (tvm.tir.Buffer)

  • C_local_buf (tvm.tir.Buffer)

  • k_inner (tvm.tir.PrimExpr | None)
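For orientation, the arithmetic performed by a single 16x16x16 WMMA step is a tile multiply-accumulate. The plain-Python reference below assumes untransposed operands (A as M x K, B as K x N) and does not model fp16 input or fp32 accumulation precision; `wmma_reference` is a hypothetical name, not a tilelang function.

```python
# Reference semantics of one 16x16x16 multiply-accumulate step in plain Python.
# Assumes untransposed A (M x K) and B (K x N); precision is not modelled.
M = N = K = 16

def wmma_reference(a, b, c):
    """Return d with d[i][j] = c[i][j] + sum_k a[i][k] * b[k][j]."""
    return [[c[i][j] + sum(a[i][k] * b[k][j] for k in range(K))
             for j in range(N)] for i in range(M)]

a = [[1.0 if i == k else 0.0 for k in range(K)] for i in range(M)]  # identity
b = [[float(k + j) for j in range(N)] for k in range(K)]
c = [[0.5] * N for _ in range(M)]
d = wmma_reference(a, b, c)
print(d[3][4])  # 7.5
```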

stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
make_wmma_load_layout(local_buf, matrix='A')
Parameters:
  • local_buf (tvm.tir.Buffer)

  • matrix (Literal['A', 'B'])

Return type:

tilelang.language.Fragment

make_wmma_store_layout(local_buf)
Parameters:

local_buf (tvm.tir.Buffer)

Return type:

tilelang.language.Fragment