tilelang.tileop.gemm.gemm_wgmma

Classes

GemmWGMMA

GEMM tile operator that targets WGMMA (warpgroup-level matrix multiply-accumulate) instructions on Hopper GPUs.

Module Contents

class tilelang.tileop.gemm.gemm_wgmma.GemmWGMMA

Bases: tilelang.tileop.gemm.gemm_base.GemmBase

GEMM tile operator that targets WGMMA (warpgroup-level matrix multiply-accumulate) instructions on Hopper GPUs.

Classifies the GEMM variant by the memory scopes of operands A and B (SS, SR, RS, TS, RR) and provides common property accessors for the underlying gemm_node IR node.
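The variant tag is formed from one letter per operand, A first, then B. A minimal sketch of that classification, assuming hypothetical scope strings ("shared" and "shared.dyn" for shared memory, "local.fragment" for registers, "shared.tmem" for tensor memory as the "T" in TS). The real GemmBase reads the scopes from gemm_node and may use different names; `SCOPE_LETTER` and `classify_gemm_variant` are illustrative only.

```python
# Assumed scope names; the real GemmBase may use different strings.
SCOPE_LETTER = {
    "shared": "S",          # shared memory
    "shared.dyn": "S",      # dynamically allocated shared memory
    "local.fragment": "R",  # register fragment
    "shared.tmem": "T",     # tensor memory (assumed meaning of "T")
}

# The variants named in the GemmBase docstring.
KNOWN_VARIANTS = {"SS", "SR", "RS", "TS", "RR"}


def classify_gemm_variant(a_scope: str, b_scope: str) -> str:
    """Return the variant tag: A's location letter followed by B's."""
    if a_scope not in SCOPE_LETTER or b_scope not in SCOPE_LETTER:
        raise ValueError(f"unknown scope: {a_scope!r}, {b_scope!r}")
    tag = SCOPE_LETTER[a_scope] + SCOPE_LETTER[b_scope]
    if tag not in KNOWN_VARIANTS:
        raise ValueError(f"unsupported GEMM variant: {tag}")
    return tag


print(classify_gemm_variant("shared.dyn", "shared"))      # SS
print(classify_gemm_variant("local.fragment", "shared"))  # RS
```

Combinations outside the five listed variants (for example "ST") are rejected rather than silently tagged.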

infer_shared_layout(continuity)

Infer the swizzle layout for shared memory based on continuity.

WGMMA can read its operands directly from shared memory, so the shared-memory swizzle layout must match the tensor core's access pattern. The swizzle granularity is determined by the size of the contiguous dimension:

  • 128B swizzle (Full): continuity % (vectorized_size * 8) == 0

  • 64B swizzle (Half): continuity % (vectorized_size * 4) == 0

  • 32B swizzle (Quarter): continuity % (vectorized_size * 2) == 0

  • Linear (no swizzle): otherwise
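The rules above pick the widest swizzle whose unit evenly divides `continuity`, checking 128B first. A minimal sketch of that selection, assuming `vectorized_size` is the number of elements in one 16-byte vector (so the multipliers 8, 4 and 2 correspond to 128, 64 and 32 bytes). `select_swizzle` and `vectorized_size_for` are hypothetical helpers, not the library's code.

```python
def select_swizzle(continuity: int, vectorized_size: int) -> str:
    """Pick the widest swizzle mode whose unit divides `continuity`.

    Both arguments are counted in elements.
    """
    # Widest first: a dimension divisible by the 128B unit also satisfies the
    # narrower checks, so order matters.
    for multiplier, mode in ((8, "128B"), (4, "64B"), (2, "32B")):
        if continuity % (vectorized_size * multiplier) == 0:
            return mode
    return "linear"


def vectorized_size_for(dtype_bits: int) -> int:
    # Assumption: one vector is 16 bytes (128 bits), so fp16 gives 8 elements.
    return 128 // dtype_bits


vec = vectorized_size_for(16)  # fp16 -> 8 elements per 16-byte vector
for continuity in (64, 32, 16, 8):
    print(continuity, select_swizzle(continuity, vec))
# prints: 64 128B / 32 64B / 16 32B / 8 linear
```

For fp16, a contiguous dimension of 64 elements spans exactly 128 bytes, which is why it qualifies for the full 128B swizzle.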

See: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

Parameters:

continuity (int) – size of the contiguous dimension, in elements.

Return type:

Callable[[tvm.tir.Buffer], tilelang.layout.Layout]

infer_layout(target, thread_nums)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

lower(layout_map, target, thread_bounds, thread_var, mbar_phase_expr=None)
Parameters:
  • layout_map (dict)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)

  • thread_var (tvm.tir.Var)

  • mbar_phase_expr (tvm.tir.PrimExpr | None)

is_gemm_ss()

Return True if both A and B are in shared memory (SS variant).

Return type:

bool

is_gemm_sr()

Return True if A is in shared memory and B is in registers (SR variant).

Return type:

bool

is_gemm_rs()

Return True if A is in registers and B is in shared memory (RS variant).

Return type:

bool

is_gemm_rr()

Return True if both A and B are in registers (RR variant).

Return type:

bool
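WGMMA itself limits which of these variants can map onto the hardware instruction. In PTX `wgmma.mma_async`, operand B is always read from shared memory through a matrix descriptor, while A may come from shared memory or registers, so only the SS and RS cases fit directly. The sketch below shows a caller branching on the predicates with that constraint in mind. `StubGemm` and `wgmma_operand_sources` are hypothetical stand-ins, not part of TileLang, and this is not the actual body of `lower()`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class StubGemm:
    """Hypothetical stand-in exposing the four is_gemm_* predicates."""
    a_in_shared: bool
    b_in_shared: bool

    def is_gemm_ss(self) -> bool:
        return self.a_in_shared and self.b_in_shared

    def is_gemm_sr(self) -> bool:
        return self.a_in_shared and not self.b_in_shared

    def is_gemm_rs(self) -> bool:
        return not self.a_in_shared and self.b_in_shared

    def is_gemm_rr(self) -> bool:
        return not self.a_in_shared and not self.b_in_shared


def wgmma_operand_sources(op) -> Tuple[str, str]:
    """Return where WGMMA would read A and B from, or raise if unsupported."""
    if op.is_gemm_ss():
        return ("shared", "shared")     # both operands via smem descriptors
    if op.is_gemm_rs():
        return ("registers", "shared")  # A from registers, B via descriptor
    # SR and RR would need B in registers, which wgmma.mma_async cannot take.
    raise ValueError("WGMMA requires B in shared memory (SS or RS variant)")


print(wgmma_operand_sources(StubGemm(a_in_shared=False, b_in_shared=True)))
# prints: ('registers', 'shared')
```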