tilelang.tileop.gemm.gemm_base

Classes

GemmBase

Base class for GEMM tile operators.

Module Contents

class tilelang.tileop.gemm.gemm_base.GemmBase

Base class for GEMM tile operators.

Classifies the GEMM variant by the memory scopes of operands A and B (SS, SR, RS, TS, RR) and provides common property accessors for the underlying gemm_node IR node.

gemm_node: tvm.ir.base.Node
abstractmethod infer_layout(target, thread_nums)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

abstractmethod lower(layout_map, target, thread_bounds, thread_var)
Parameters:
  • layout_map (dict)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)

  • thread_var (tvm.tir.Var)
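The two abstract hooks mean a concrete GEMM backend must override both `infer_layout` and `lower` before it can be instantiated. The sketch below illustrates that contract with Python's `abc` module. `GemmBaseSketch` and `ToyGemm` are hypothetical stand-ins, not tilelang classes, and plain strings and tuples replace the TVM `Target`, `Range`, and `Var` objects.

```python
from abc import ABC, abstractmethod


class GemmBaseSketch(ABC):
    """Hypothetical stand-in for GemmBase: only the two abstract hooks, no TVM types."""

    def __init__(self, gemm_node):
        self.gemm_node = gemm_node

    @abstractmethod
    def infer_layout(self, target, thread_nums):
        """Return a mapping from buffer names to layouts."""

    @abstractmethod
    def lower(self, layout_map, target, thread_bounds, thread_var):
        """Return the lowered form of this GEMM."""


class ToyGemm(GemmBaseSketch):
    # Hypothetical concrete backend; strings stand in for real layouts and IR.
    def infer_layout(self, target, thread_nums):
        return {"C": f"fragment({thread_nums} threads, {target})"}

    def lower(self, layout_map, target, thread_bounds, thread_var):
        return f"call_gemm({self.gemm_node}, layout={layout_map['C']})"


# The abstract base cannot be instantiated directly.
try:
    GemmBaseSketch(gemm_node="g")
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)

op = ToyGemm(gemm_node="g")
layouts = op.infer_layout("cuda", 128)
print(op.lower(layouts, "cuda", (0, 128), "tx"))
```

Layout inference runs first and its result is passed into `lower`, mirroring the `layout_map` parameter above.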

is_gemm_ss()

Return True if both A and B are in shared memory (SS variant).

Return type:

bool

is_gemm_sr()

Return True if A is in shared memory and B is in registers (SR variant).

Return type:

bool

is_gemm_rs()

Return True if A is in registers and B is in shared memory (RS variant).

Return type:

bool

is_gemm_ts()

Return True if A is in tensor memory and B is in shared memory (TS variant).

Return type:

bool

is_gemm_rr()

Return True if both A and B are in registers (RR variant).

Return type:

bool
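The five predicates above partition GEMMs by where A and B live. A minimal sketch of that classification folded into one function follows. The scope strings (`"shared"`/`"shared.dyn"` for shared memory, `"local.fragment"` for register fragments, `"shared.tmem"` for tensor memory) are assumptions about tilelang's naming, and `classify_gemm` is a hypothetical helper rather than part of the API.

```python
# Assumed scope names; check tilelang's own scope helpers before relying on them.
SHARED_SCOPES = {"shared", "shared.dyn"}
FRAGMENT_SCOPES = {"local.fragment"}  # register-resident fragments
TMEM_SCOPES = {"shared.tmem"}         # tensor memory


def classify_gemm(a_scope: str, b_scope: str) -> str:
    """Return the GEMM variant name (SS, SR, RS, TS, RR) for the operand scopes."""
    a_shared = a_scope in SHARED_SCOPES
    b_shared = b_scope in SHARED_SCOPES
    a_reg = a_scope in FRAGMENT_SCOPES
    b_reg = b_scope in FRAGMENT_SCOPES
    if a_shared and b_shared:
        return "SS"
    if a_shared and b_reg:
        return "SR"
    if a_reg and b_shared:
        return "RS"
    if a_scope in TMEM_SCOPES and b_shared:
        return "TS"
    if a_reg and b_reg:
        return "RR"
    raise ValueError(f"unsupported scopes: A={a_scope!r}, B={b_scope!r}")


print(classify_gemm("shared.dyn", "shared.dyn"))  # SS
print(classify_gemm("local.fragment", "shared"))  # RS
```

In the class itself each variant is exposed as a separate boolean method, so callers typically branch on `is_gemm_ss()`, `is_gemm_rs()`, and so on rather than on a string.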

property M: int
Return type:

int

property N: int
Return type:

int

property K: int
Return type:

int

property trans_A: bool
Return type:

bool

property trans_B: bool
Return type:

bool

property in_dtype: str

Input data type for the multiplication.

For the TS variant, A resides in TMEM with the accumulator dtype, so the actual input dtype is derived from B.

Return type:

str
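The TS rule for `in_dtype` can be sketched as a one-line selection. `input_dtype` is a hypothetical helper that takes the variant name and the two operand dtypes as plain strings; it only illustrates the rule stated above and is not the property's actual implementation.

```python
def input_dtype(variant: str, a_dtype: str, b_dtype: str) -> str:
    """Pick the dtype that actually feeds the multiplication (hypothetical helper)."""
    # In the TS variant, A sits in TMEM with the accumulator dtype,
    # so B's dtype is the real input dtype.
    return b_dtype if variant == "TS" else a_dtype


print(input_dtype("SS", "float16", "float16"))   # float16
print(input_dtype("TS", "float32", "bfloat16"))  # bfloat16
```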

property accum_dtype: str
Return type:

str

property chunk: int
Return type:

int

property A: tvm.tir.Buffer
Return type:

tvm.tir.Buffer

property B: tvm.tir.Buffer
Return type:

tvm.tir.Buffer

property C: tvm.tir.Buffer
Return type:

tvm.tir.Buffer

property ARegion
property BRegion
property CRegion
property stride_A: int
Return type:

int

property stride_B: int
Return type:

int

property offset_A: int
Return type:

int

property offset_B: int
Return type:

int

property clear_accum: tvm.ir.PrimExpr
Return type:

tvm.ir.PrimExpr
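To see how `M`, `N`, `K`, `trans_A`, `trans_B`, and `clear_accum` fit together, here is a pure-Python reference GEMM on nested lists. It assumes the usual BLAS-style convention: A is M×K (stored K×M when `trans_A` is set), B is K×N (stored N×K when `trans_B` is set), C is M×N, and `clear_accum` zeroes C before accumulating instead of adding onto it. `reference_gemm` is a hypothetical helper for checking results, not a tilelang API; in the class, `clear_accum` is a `PrimExpr` that may be decided at run time, whereas the sketch uses a plain bool.

```python
def reference_gemm(A, B, C, trans_A=False, trans_B=False, clear_accum=False):
    """Compute C = op(A) @ op(B), adding onto C unless clear_accum is set.

    op(X) is X transposed when the matching trans flag is set.
    """
    def transpose(X):
        return [list(row) for row in zip(*X)]

    a = transpose(A) if trans_A else A  # M x K after op()
    b = transpose(B) if trans_B else B  # K x N after op()
    M, K = len(a), len(a[0])
    K2, N = len(b), len(b[0])
    assert K == K2, "inner dimensions must agree"
    assert len(C) == M and len(C[0]) == N, "C must be M x N"
    for i in range(M):
        for j in range(N):
            acc = 0 if clear_accum else C[i][j]
            for k in range(K):
                acc += a[i][k] * b[k][j]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(reference_gemm(A, B, [[1, 1], [1, 1]]))                    # [[20, 23], [44, 51]]
print(reference_gemm(A, B, [[1, 1], [1, 1]], clear_accum=True))  # [[19, 22], [43, 50]]
```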

property k_pack: int
Return type:

int

property wg_wait: int
Return type:

int

property policy: tilelang.tileop.base.GemmWarpPolicy
Return type:

tilelang.tileop.base.GemmWarpPolicy

property mbarptr: tvm.ir.PrimExpr
Return type:

tvm.ir.PrimExpr

property mbar: tvm.tir.BufferLoad | None
Return type:

tvm.tir.BufferLoad | None

property C_coords
get_region_base_offsets(region)

Get the base offset (start index) for each dimension from a BufferRegion.

For example, if region is A_shared[ko % 2, 0:128, 0:64], this returns [ko % 2, 0, 0].

Parameters:

region -- BufferRegion object

Returns:

List of PrimExpr representing the base offset for each dimension

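The base offset of each dimension is simply the start (`min`) of that dimension's range. The sketch below reproduces the `A_shared[ko % 2, 0:128, 0:64]` example with small dataclass stand-ins modeled on TVM's `Range` (which has `min` and `extent`) and `BufferRegion` (which has `buffer` and `region`). The stand-ins use plain ints instead of `PrimExpr` and a `name` string instead of a buffer, so this is an illustration of the idea rather than the method's actual code.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Range:
    # Stand-in for tvm.ir.Range: start index and length of one dimension.
    min: int
    extent: int


@dataclass
class BufferRegion:
    # Stand-in for tvm.tir.BufferRegion; `name` replaces the real buffer object.
    name: str
    region: List[Range]


def get_region_base_offsets(region: BufferRegion) -> List[int]:
    # The base offset of each dimension is the start (min) of its range.
    return [r.min for r in region.region]


ko = 3
a_region = BufferRegion("A_shared", [Range(ko % 2, 1), Range(0, 128), Range(0, 64)])
print(get_region_base_offsets(a_region))  # [1, 0, 0]
```

With real TVM objects, the first entry would remain the symbolic expression `ko % 2` rather than being evaluated to an integer.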
property A_base_offsets

Get the base offsets for each dimension of the A region.

property B_base_offsets

Get the base offsets for each dimension of the B region.

property C_base_offsets

Get the base offsets for each dimension of the C region.