tilelang.language.gemm_op

GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface.

Functions

gemm_v1(A, B, C[, transpose_A, transpose_B, policy, ...])

GEMM v1: uses the tl.gemm op.

gemm_v2(A, B, C[, transpose_A, transpose_B, policy, ...])

GEMM v2: uses the tl.gemm_py op.

gemm(A, B, C[, transpose_A, transpose_B, policy, ...])

TileLang GEMM operator.

Module Contents

tilelang.language.gemm_op.gemm_v1(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)

GEMM v1: uses the tl.gemm op.

Parameters:
  • A (tilelang._typing.BufferLikeType)

  • B (tilelang._typing.BufferLikeType)

  • C (tilelang._typing.BufferLikeType)

  • transpose_A (bool)

  • transpose_B (bool)

  • policy (tilelang.tileop.base.GemmWarpPolicy)

  • clear_accum (bool)

  • k_pack (int)

  • wg_wait (int)

  • mbar (tilelang._typing.BarrierType | None)

Return type:

tvm.tir.PrimExpr

tilelang.language.gemm_op.gemm_v2(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)

GEMM v2: uses the tl.gemm_py op.

Parameters:
  • A (tilelang._typing.BufferLikeType)

  • B (tilelang._typing.BufferLikeType)

  • C (tilelang._typing.BufferLikeType)

  • transpose_A (bool)

  • transpose_B (bool)

  • policy (tilelang.tileop.base.GemmWarpPolicy)

  • clear_accum (bool)

  • k_pack (int)

  • wg_wait (int)

  • mbar (tilelang._typing.BarrierType | None)

Return type:

tvm.tir.PrimExpr

tilelang.language.gemm_op.gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)

TileLang GEMM operator.

Parameters:
  • A (BufferLikeType, i.e. Buffer | BufferLoad | BufferRegion, or Var) -- Input buffer A.

  • B (BufferLikeType) -- Input buffer B.

  • C (BufferLikeType) -- Output buffer C.

  • transpose_A (bool) -- Whether to transpose A. Defaults to False.

  • transpose_B (bool) -- Whether to transpose B. Defaults to False.

  • policy (GemmWarpPolicy) -- GEMM warp partition policy. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool) -- Whether to clear the accumulator. Defaults to False.

  • k_pack (int) -- Number of packed matrix cores, for ROCm only. Defaults to 1.

  • wg_wait (int) -- Integer identifier of the warpgroup MMA batch to wait on. Defaults to 0.

  • mbar (BarrierType, i.e. Buffer | BufferLoad, or Var, optional) -- Mbarrier used on Blackwell GPUs. Defaults to None.

Returns:

A handle to the GEMM operation.

Return type:

tir.Call
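
The effect of transpose_A, transpose_B, and clear_accum can be illustrated with a small pure-Python model. This is only a sketch of the intended arithmetic, not TileLang code: gemm_reference is a hypothetical helper, it works on nested lists rather than buffers, and it assumes that C is accumulated into unless clear_accum is set, which is inferred from the parameter description rather than stated in this reference.

```python
def gemm_reference(A, B, C, transpose_A=False, transpose_B=False, clear_accum=False):
    """Hypothetical model of the GEMM arithmetic (not part of TileLang).

    Assumed semantics: C = (0 if clear_accum else C) + op(A) @ op(B),
    where op(X) is X transposed when the matching flag is True.
    """
    def op(M, transposed):
        # Transpose a list-of-rows matrix when requested.
        return [list(col) for col in zip(*M)] if transposed else M

    a = op(A, transpose_A)
    b = op(B, transpose_B)
    m, k, n = len(a), len(b), len(b[0])
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions of op(A) and op(B) do not match")

    for i in range(m):
        for j in range(n):
            # Start from zero when clearing, otherwise keep the old value of C.
            acc = 0 if clear_accum else C[i][j]
            for p in range(k):
                acc += a[i][p] * b[p][j]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

accumulated = gemm_reference(A, B, [[1, 1], [1, 1]])
cleared = gemm_reference(A, B, [[1, 1], [1, 1]], clear_accum=True)
with_bt = gemm_reference(A, B, [[0, 0], [0, 0]], transpose_B=True)
```

Under these assumptions, the first call adds the product to the existing contents of C, the second overwrites them, and the third multiplies A by the transpose of B.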