tilelang.contrib.cutedsl.gemm_v1

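Classes

Gemm_SM80(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)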

Functions

make_aligned_tensor(ptr, layout, align_bytes[, swizzle])

gemm_ss(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with both A and B from shared memory

gemm_rs(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with A from register/fragment and B from shared memory

gemm_sr(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with A from shared memory and B from register/fragment

gemm_rr(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with both A and B from register/fragment

Module Contents

tilelang.contrib.cutedsl.gemm_v1.make_aligned_tensor(ptr, layout, align_bytes, swizzle=False)
Parameters:
  • ptr (cutlass.cute.Pointer)

  • layout (cutlass.cute.Layout)

  • align_bytes (int)
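The alignment contract behind align_bytes can be illustrated without CuTe. The plain-Python sketch below is not the CuTe DSL API: is_aligned and elements_per_aligned_access are hypothetical helpers that show the arithmetic such a guarantee usually implies, assuming align_bytes is a power of two such as 16 (one 128-bit access).

```python
def is_aligned(address: int, align_bytes: int) -> bool:
    """Return True if a byte address is a multiple of align_bytes."""
    # Assumption: align_bytes is a positive power of two (e.g. 16 for 128-bit accesses).
    if align_bytes <= 0 or align_bytes & (align_bytes - 1):
        raise ValueError("align_bytes must be a positive power of two")
    return address % align_bytes == 0


def elements_per_aligned_access(align_bytes: int, elem_bytes: int) -> int:
    """Number of elements covered by one access of align_bytes bytes."""
    if align_bytes % elem_bytes:
        raise ValueError("align_bytes must be a multiple of the element size")
    return align_bytes // elem_bytes
```

For example, with float16 elements (2 bytes) a 16-byte alignment lets one access move 8 elements, whereas an address that is only 8-byte aligned (such as 0x1008) fails the 16-byte check.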

tilelang.contrib.cutedsl.gemm_v1.gemm_ss(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with both A and B from shared memory

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)
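All four free functions share the M, N, K, trans_A, trans_B and clear_accum parameters. The pure-Python reference below is a sketch of the math those parameters are assumed to describe, following the usual TileLang convention: A is M×K (stored K×M when trans_A), B is K×N (stored N×K when trans_B), C is M×N, and clear_accum=True overwrites C instead of accumulating into it. It ignores warps, strides, offsets and memory spaces entirely; reference_gemm is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def reference_gemm(A: Matrix, B: Matrix, C: Matrix, M: int, N: int, K: int,
                   trans_A: bool = False, trans_B: bool = False,
                   clear_accum: bool = False) -> Matrix:
    """C (M x N) = op(A) @ op(B), plus the old C unless clear_accum is set."""
    for i in range(M):
        for j in range(N):
            # Start from zero when clearing the accumulator, otherwise from C.
            acc = 0.0 if clear_accum else C[i][j]
            for k in range(K):
                a = A[k][i] if trans_A else A[i][k]  # A stored K x M when trans_A
                b = B[j][k] if trans_B else B[k][j]  # B stored N x K when trans_B
                acc += a * b
            C[i][j] = acc
    return C
```

With A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]], clearing the accumulator yields [[19, 22], [43, 50]]; passing the transposed storage of both operands with trans_A=True and trans_B=True yields the same result.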

tilelang.contrib.cutedsl.gemm_v1.gemm_rs(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with A from register/fragment and B from shared memory

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)
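The warp_m and warp_n parameters split the M×N output tile among warps. The sketch below assumes warp_m warps tile the M dimension and warp_n warps tile the N dimension, so each warp owns a contiguous (M // warp_m) × (N // warp_n) block of C; in a register-sourced variant such as gemm_rs, a warp then needs only the M // warp_m rows of A that feed its block. The row-major warp numbering is an illustrative choice and may differ from the real kernel; warp_tile is a hypothetical helper.

```python
from typing import Tuple


def warp_tile(warp_id: int, M: int, N: int, warp_m: int, warp_n: int
              ) -> Tuple[range, range]:
    """Return the (rows, cols) of C owned by warp_id.

    Assumes warps are numbered row-major over a warp_m x warp_n grid.
    """
    if M % warp_m or N % warp_n:
        raise ValueError("M and N must be divisible by warp_m and warp_n")
    if not 0 <= warp_id < warp_m * warp_n:
        raise ValueError("warp_id out of range")
    tile_m, tile_n = M // warp_m, N // warp_n
    wm, wn = divmod(warp_id, warp_n)  # this warp's position in the grid
    rows = range(wm * tile_m, (wm + 1) * tile_m)
    cols = range(wn * tile_n, (wn + 1) * tile_n)
    return rows, cols
```

For a 128×128 tile with warp_m = warp_n = 2, warp 1 owns rows 0 to 63 and columns 64 to 127 under this numbering, and the four warps together cover every element of C exactly once.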

tilelang.contrib.cutedsl.gemm_v1.gemm_sr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with A from shared memory and B from register/fragment

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_v1.gemm_rr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with both A and B from register/fragment

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)

class tilelang.contrib.cutedsl.gemm_v1.Gemm_SM80(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)
__call__(sA_ptr, sB_ptr, rC_ptr)

GEMM body: both A and B from shared memory

Parameters:
  • sA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

body_rs(rA_ptr, sB_ptr, rC_ptr)

GEMM body_rs: A from registers (fragment), B from shared memory

Parameters:
  • rA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

body_sr(sA_ptr, rB_ptr, rC_ptr)

GEMM body_sr: A from shared memory, B from registers (fragment)

Parameters:
  • sA_ptr (cutlass.cute.Pointer)

  • rB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)
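Gemm_SM80 exposes one entry point per operand combination: __call__ for shared/shared, body_rs for register/shared and body_sr for shared/register. The table-driven dispatch below is a hypothetical sketch of how a caller might choose among them from the operand memory spaces; it uses only the method names documented here, and since no register/register method is listed for this class, that case is rejected.

```python
from typing import Dict, Tuple

# Operand memory spaces: "s" = shared memory, "r" = register/fragment.
SM80_METHODS: Dict[Tuple[str, str], str] = {
    ("s", "s"): "__call__",  # A and B both in shared memory
    ("r", "s"): "body_rs",   # A in registers, B in shared memory
    ("s", "r"): "body_sr",   # A in shared memory, B in registers
}


def sm80_method_for(a_space: str, b_space: str) -> str:
    """Return the Gemm_SM80 method name for the given operand spaces."""
    try:
        return SM80_METHODS[(a_space, b_space)]
    except KeyError:
        raise ValueError(
            f"no Gemm_SM80 method documented for A={a_space!r}, B={b_space!r}"
        ) from None
```

A lookup table keeps the mapping in one place, so adding a method later only means adding a row.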

class tilelang.contrib.cutedsl.gemm_v1.Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)
static make_tma_atom(tensor, smem_layout_staged, smem_tile, mcast_dim)
static get_tma_atom(tensor, tiler_mk, stages=1)
static make_smem_layout_AB(dtype, major_mode, tiler_mk, stages=1)
Parameters:
  • major_mode (cutlass.utils.LayoutEnum)
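make_smem_layout_AB presumably lays out one operand tile of shape tiler_mk, repeated stages times for multi-buffering. Whatever the exact CuTe layout, a swizzle permutes addresses within the tile rather than adding padding, so the storage to budget is tile_m × tile_k × stages elements. The plain-Python sketch below does that arithmetic; smem_bytes_AB and the DTYPE_BYTES table are hypothetical and use type-name strings instead of CUTLASS dtypes.

```python
from typing import Tuple

# Element widths in bytes for a few common GEMM input types (illustrative).
DTYPE_BYTES = {"float32": 4, "tfloat32": 4, "float16": 2, "bfloat16": 2,
               "float8_e4m3": 1, "float8_e5m2": 1, "int8": 1}


def smem_bytes_AB(dtype: str, tiler_mk: Tuple[int, int], stages: int = 1) -> int:
    """Shared-memory bytes for one operand tile of shape tiler_mk, times stages."""
    tile_mn, tile_k = tiler_mk  # (M, K) for A, (N, K) for B
    return tile_mn * tile_k * stages * DTYPE_BYTES[dtype]
```

For instance, a bfloat16 A tile of (128, 64) with 4 stages needs 65536 bytes (64 KiB), and a matching B tile doubles that, which is the kind of budget to check before raising stages.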

__call__(sA_ptr, sB_ptr, rC_ptr, wg_wait=0, clear_accum=False)
Parameters:
  • sA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

  • wg_wait (cutlass.Constexpr)

  • clear_accum (cutlass.Constexpr)
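On SM90 the multiply is issued as asynchronous WGMMA instructions, and wg_wait controls how many may still be in flight when the call returns. Assuming wg_wait maps onto PTX wgmma.wait_group N, which blocks until at most N of the most recently committed wgmma-groups are pending, the toy model below reproduces that counting rule: wg_wait=0 waits for everything, while larger values let later work overlap the outstanding MMAs. WgmmaQueue is hypothetical and models only the counting, not the math.

```python
from collections import deque
from typing import Deque, List


class WgmmaQueue:
    """Toy model of wgmma.commit_group / wgmma.wait_group N counting."""

    def __init__(self) -> None:
        self.pending: Deque[str] = deque()  # oldest group on the left
        self.completed: List[str] = []

    def commit_group(self, name: str) -> None:
        # Each commit closes one wgmma-group that then runs asynchronously.
        self.pending.append(name)

    def wait_group(self, n: int) -> None:
        # Block until at most n of the most recent groups remain pending;
        # every older group is guaranteed complete, so retire oldest first.
        while len(self.pending) > n:
            self.completed.append(self.pending.popleft())
```

After committing g0, g1 and g2, wait_group(1) retires g0 and g1 and leaves g2 in flight; a subsequent wait_group(0) drains it.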

body_rs(rA_ptr, sB_ptr, rC_ptr, wg_wait=0, clear_accum=False)

GEMM body_rs for SM90 (Hopper): A from registers (fragment), B from shared memory. Ported from cute::tl_wgmma::GemmTensorOp::body_rs in gemm_sm90.h.

Parameters:
  • rA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

  • wg_wait (cutlass.Constexpr)

  • clear_accum (cutlass.Constexpr)
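body_rs walks the K dimension in k-blocks and accumulates into rC. One way to honor clear_accum without a separate zero-fill pass, and the mechanism CUTLASS's SM90 GMMA atoms expose as a ScaleOut (Zero/One) accumulate flag, is to issue the first k-block with the WGMMA scale-d input off (D = A*B) and every later k-block with it on (D = A*B + D). Whether this port does exactly that is an assumption; the plain-Python sketch below, with the hypothetical kblock_accumulate, shows the resulting values. Partial products per k-block are passed in precomputed.

```python
from typing import List

Tile = List[List[float]]


def kblock_accumulate(C: Tile, partials: List[Tile], clear_accum: bool) -> Tile:
    """Fold per-k-block partial products into C, WGMMA scale-d style.

    Assumes at least one k-block; with clear_accum the old C is never read.
    """
    accumulate = not clear_accum  # scale-d for the first k-block
    for P in partials:
        for i, row in enumerate(P):
            for j, p in enumerate(row):
                # scale-d off: D = A*B (old C ignored); on: D = A*B + D.
                C[i][j] = p + C[i][j] if accumulate else p
        accumulate = True  # every later k-block accumulates
    return C
```

Because the first k-block ignores the old accumulator instead of multiplying it by zero, stale NaN or Inf values in rC cannot leak into the result when clear_accum is set.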