tilelang.tileop.gemm_sp

Submodules

Classes

GemmSPPy

Functions

gemm_sp_py_infer_layout(gemm_sp_py, target, thread_bounds)

gemm_sp_py_lower(gemm_sp_py, target, thread_bounds, ...)

Package Contents

tilelang.tileop.gemm_sp.gemm_sp_py_infer_layout(gemm_sp_py, target, thread_bounds)
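Parameters:
  • gemm_sp_py (gemm_sp_mma.GemmSPMMA)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)
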
tilelang.tileop.gemm_sp.gemm_sp_py_lower(gemm_sp_py, target, thread_bounds, thread_var)
Parameters:
  • gemm_sp_py (gemm_sp_mma.GemmSPMMA)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)

  • thread_var (tvm.tir.Var)
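Judging from the signatures, the two module-level functions look like thin adapters. Each takes a thread range (thread_bounds) and forwards to the gemm_sp_py object's infer_layout or lower method, which take a thread count (thread_nums) instead. The following is a minimal sketch that assumes the thread count is thread_bounds.extent. RangeStub, RecordingGemm, sketch_infer_layout and sketch_lower are hypothetical stand-ins that let the example run without TVM or tilelang. They are not tilelang APIs.

```python
from dataclasses import dataclass


@dataclass
class RangeStub:
    """Hypothetical stand-in for tvm.ir.Range (only min and extent)."""
    min: int
    extent: int


class RecordingGemm:
    """Hypothetical stand-in for a GemmSPMMA object; records every call it receives."""

    def __init__(self):
        self.calls = []

    def infer_layout(self, target, thread_nums):
        self.calls.append(("infer_layout", target, thread_nums))
        return {"thread_nums": thread_nums}

    def lower(self, target, thread_nums, thread_var):
        self.calls.append(("lower", target, thread_nums, thread_var))
        return f"lowered({thread_var}, {thread_nums} threads)"


def sketch_infer_layout(gemm_sp_py, target, thread_bounds):
    # Assumption: the thread count passed to the method is the extent of the range.
    thread_nums = thread_bounds.extent
    return gemm_sp_py.infer_layout(target, thread_nums)


def sketch_lower(gemm_sp_py, target, thread_bounds, thread_var):
    # Same conversion; thread_var is forwarded unchanged.
    thread_nums = thread_bounds.extent
    return gemm_sp_py.lower(target, thread_nums, thread_var)


gemm = RecordingGemm()
bounds = RangeStub(min=0, extent=128)
print(sketch_infer_layout(gemm, "cuda", bounds))  # {'thread_nums': 128}
print(sketch_lower(gemm, "cuda", bounds, "tx"))   # lowered(tx, 128 threads)
```

If this reading is right, the methods only ever see a scalar thread count. The range's starting offset is not part of what the methods receive.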

class tilelang.tileop.gemm_sp.GemmSPPy

Bases: tvm.ir.base.Node, tvm.runtime.Scriptable

A: tvm.tir.Buffer
E: tvm.tir.Buffer
B: tvm.tir.Buffer
C: tvm.tir.Buffer
APtr: tvm.tir.PrimExpr
EPtr: tvm.tir.PrimExpr
BPtr: tvm.tir.PrimExpr
CPtr: tvm.tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: tilelang.tileop.base.GemmWarpPolicy
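The attribute list describes a sparse GEMM in which A appears to be stored in compressed form next to a metadata buffer E. This page does not document the encoding. The pure-Python model below assumes NVIDIA-style 2:4 structured sparsity: every group of 4 consecutive K elements of A holds at most 2 nonzeros, only those 2 values are stored, and E records their offsets within the group. It also models trans_B and clear_accum under the usual reading, which accumulates into C unless clear_accum is set. trans_A, the strides and offsets, k_pack, wg_wait and policy are not modeled. Real hardware metadata is bit-packed, whereas here E stays as plain ints. compress_2to4 and gemm_sp_reference are hypothetical helpers, not tilelang APIs.

```python
# Hypothetical reference model, NOT the tilelang implementation.
# Assumption: A is a logical M x K matrix with 2:4 structured sparsity.
# The compressed A keeps 2 values per group of 4; E keeps their offsets (0..3).


def compress_2to4(a_dense):
    """Split a 2:4-sparse M x K matrix into values and metadata, each M x K/2."""
    values, meta = [], []
    for row in a_dense:
        assert len(row) % 4 == 0, "K must be a multiple of 4"
        v_row, e_row = [], []
        for g in range(0, len(row), 4):
            group = row[g:g + 4]
            nz = [i for i, x in enumerate(group) if x != 0]
            assert len(nz) <= 2, "group violates 2:4 sparsity"
            # Pad with zero-valued slots so every group stores exactly 2 entries.
            for i in range(4):
                if len(nz) == 2:
                    break
                if i not in nz:
                    nz.append(i)
            nz.sort()
            v_row.extend(group[i] for i in nz)
            e_row.extend(nz)
        values.append(v_row)
        meta.append(e_row)
    return values, meta


def gemm_sp_reference(a_vals, e, b, c, trans_B=False, clear_accum=False):
    """Return (0 if clear_accum else C) + A @ op(B), reading A via (a_vals, e).

    op(B) is K x N; when trans_B is True, b is stored as N x K.
    """
    M = len(a_vals)
    N = len(b) if trans_B else len(b[0])

    def b_at(k, n):
        return b[n][k] if trans_B else b[k][n]

    out = [[0 if clear_accum else c[m][n] for n in range(N)] for m in range(M)]
    for m in range(M):
        for j, (val, off) in enumerate(zip(a_vals[m], e[m])):
            k = (j // 2) * 4 + off  # logical K index of the j-th stored value
            for n in range(N):
                out[m][n] += val * b_at(k, n)
    return out


A = [[1, 0, 2, 0, 0, 3, 0, 0],
     [0, 0, 4, 5, 6, 0, 0, 7]]          # M x K = 2 x 8, 2:4-sparse
B = [[k, 1] for k in range(8)]          # K x N = 8 x 2
C = [[1, 1], [1, 1]]                    # M x N accumulator
a_vals, E = compress_2to4(A)
print(a_vals, E)  # [[1, 2, 0, 3], [4, 5, 6, 7]] [[0, 2, 0, 1], [2, 3, 0, 3]]
print(gemm_sp_reference(a_vals, E, B, C, clear_accum=True))  # [[19, 6], [96, 22]]
```

Under this model the compressed A and E each hold K/2 entries per row. That halving is the storage saving structured sparsity is meant to buy, while the result matches a dense A @ B.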
infer_layout(target, thread_nums)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

lower(target, thread_nums, thread_var)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

  • thread_var (tvm.tir.Var)