tilelang.tileop.gemm.gemm_wgmma

Classes

Module Contents

class tilelang.tileop.gemm.gemm_wgmma.GemmWGMMA

Bases: tilelang.tileop.gemm.gemm_base.GemmBase

infer_shared_layout(continuity)

Infer the swizzle layout for shared memory based on continuity.

WGMMA can read its operands directly from shared memory, so the shared-memory swizzle layout must match the tensor cores' access pattern. The swizzle granularity is determined by the size of the contiguous dimension (continuity):

  • 128B swizzle (Full): continuity % (vectorized_size * 8) == 0

  • 64B swizzle (Half): continuity % (vectorized_size * 4) == 0

  • 32B swizzle (Quarter): continuity % (vectorized_size * 2) == 0

  • Linear (no swizzle): otherwise
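The selection rule above can be sketched in plain Python. This is a hypothetical illustration, not tilelang's implementation: the helper names `vectorized_size` and `select_swizzle` are invented here, and the sketch assumes vectorized_size is the number of elements in one 16-byte (128-bit) vector, i.e. `128 // dtype_bits`, so the three cases correspond to 8, 4 and 2 such vectors (128 B, 64 B and 32 B).

```python
# Hypothetical sketch of the swizzle-selection rule; not tilelang's actual code.
# Assumption: vectorized_size = number of elements per 16-byte (128-bit) vector.


def vectorized_size(dtype_bits: int) -> int:
    """Elements that fit in one 128-bit vector for an element type of `dtype_bits` bits."""
    return 128 // dtype_bits


def select_swizzle(continuity: int, dtype_bits: int) -> str:
    """Pick a swizzle mode for a contiguous dimension of `continuity` elements."""
    vec = vectorized_size(dtype_bits)
    # Check the widest granularity first: a multiple of vec * 8 is also a
    # multiple of vec * 4 and vec * 2.
    if continuity % (vec * 8) == 0:
        return "128B"    # Full: 8 vectors x 16 B
    if continuity % (vec * 4) == 0:
        return "64B"     # Half: 4 vectors x 16 B
    if continuity % (vec * 2) == 0:
        return "32B"     # Quarter: 2 vectors x 16 B
    return "linear"      # no swizzle


if __name__ == "__main__":
    # 16-bit elements (e.g. fp16): vec = 8, so 128B swizzle needs continuity % 64 == 0.
    for cont in (64, 32, 16, 8):
        print(cont, select_swizzle(cont, 16))
    # prints: 64 128B / 32 64B / 16 32B / 8 linear (one pair per line)
```

Checking the widest swizzle first matters: if the 32B test ran first, every continuity that qualifies for 128B or 64B would also pass it, so the wider modes would never be selected.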

See: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

Parameters:

continuity (int) – size of the contiguous dimension

Return type:

Callable[[tvm.tir.Buffer], tilelang.layout.Layout]

infer_layout(target, thread_nums)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

lower(layout_map, target, thread_bounds, thread_var)
Parameters:
  • layout_map (dict)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)

  • thread_var (tvm.tir.Var)

is_gemm_ss()
Return type:

bool

is_gemm_sr()
Return type:

bool

is_gemm_rs()
Return type:

bool

is_gemm_rr()
Return type:

bool
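The four predicates above classify a GEMM by where its operands live. A minimal sketch of that classification, assuming the naming convention in which the first letter refers to A, the second to B, `s` means shared memory and `r` means registers. `Scope` and `OperandPlacement` are hypothetical names introduced only for this illustration and are not part of tilelang's API.

```python
# Hypothetical sketch of the is_gemm_ss / sr / rs / rr classification.
# Assumption: first letter = operand A, second letter = operand B;
# "s" = shared memory, "r" = registers. Not tilelang's actual code.
from enum import Enum


class Scope(Enum):
    SHARED = "shared"
    REGISTER = "register"


class OperandPlacement:
    """Hypothetical holder for the memory scopes of the A and B operands."""

    def __init__(self, a: Scope, b: Scope) -> None:
        self.a = a
        self.b = b

    def is_gemm_ss(self) -> bool:
        return self.a is Scope.SHARED and self.b is Scope.SHARED

    def is_gemm_sr(self) -> bool:
        return self.a is Scope.SHARED and self.b is Scope.REGISTER

    def is_gemm_rs(self) -> bool:
        return self.a is Scope.REGISTER and self.b is Scope.SHARED

    def is_gemm_rr(self) -> bool:
        return self.a is Scope.REGISTER and self.b is Scope.REGISTER

    def variant(self) -> str:
        """Return the name of the one predicate that holds for this placement."""
        for name in ("ss", "sr", "rs", "rr"):
            if getattr(self, f"is_gemm_{name}")():
                return name
        raise ValueError("unreachable: every placement matches exactly one variant")


if __name__ == "__main__":
    print(OperandPlacement(Scope.REGISTER, Scope.SHARED).variant())  # prints: rs
```

In PTX, `wgmma.mma_async` reads B from shared memory through a matrix descriptor, while A may come from either registers or shared memory, so the SS and RS cases are the ones that map directly onto the instruction.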