tilelang.tileop.base

Classes

GemmWarpPolicy

Enumeration for GEMM Warp Partitioning Policies.

Module Contents

class tilelang.tileop.base.GemmWarpPolicy

Bases: enum.IntEnum

Enumeration for GEMM Warp Partitioning Policies.

Square = 0
FullRow = 1
FullCol = 2

is_square()

Check if the policy is a square partitioning.

Returns:

True if the policy is square, False otherwise.

Return type:

bool

is_full_row()

Check if the policy is a full row partitioning.

Returns:

True if the policy is full row, False otherwise.

Return type:

bool

is_full_col()

Check if the policy is a full column partitioning.

Returns:

True if the policy is full column, False otherwise.

Return type:

bool
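
The three predicates above amount to identity checks against the enum members. A minimal standalone sketch, assuming only the member values documented here; `WarpPolicySketch` is a hypothetical stand-in and not the tilelang class itself:

```python
from enum import IntEnum


class WarpPolicySketch(IntEnum):
    """Hypothetical stand-in that mirrors the documented members."""

    Square = 0
    FullRow = 1
    FullCol = 2

    def is_square(self) -> bool:
        return self is WarpPolicySketch.Square

    def is_full_row(self) -> bool:
        return self is WarpPolicySketch.FullRow

    def is_full_col(self) -> bool:
        return self is WarpPolicySketch.FullCol


# IntEnum members can be looked up by their integer value.
policy = WarpPolicySketch(1)
print(policy.name, policy.is_full_row(), policy.is_square())
```

Because the base class is `enum.IntEnum`, members also compare equal to plain integers, which is convenient when the policy is passed across a language boundary as an int.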

static to_prime_factors(num)

Compute the prime factorization of a given number.

Parameters:

num (int) -- The number to factorize.

Returns:

A list of prime factors of the number.

Return type:

list
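
Prime factorization of small integers such as warp counts is commonly done by trial division. The sketch below is one way to do it, assuming factors are returned in ascending order with repeats; the function name `prime_factors_sketch` is hypothetical, and the library's own implementation may differ:

```python
def prime_factors_sketch(num: int) -> list:
    """Return the prime factors of num by trial division (illustrative only)."""
    factors = []
    divisor = 2
    # Only divisors up to sqrt(num) need to be tried.
    while divisor * divisor <= num:
        # Divide out each prime as many times as it occurs.
        while num % divisor == 0:
            factors.append(divisor)
            num //= divisor
        divisor += 1
    # Whatever remains above 1 is itself a prime factor.
    if num > 1:
        factors.append(num)
    return factors


print(prime_factors_sketch(12))
```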

compute_warp_partition(M, N, num_warps)

Compute the warp partition (m_warp, n_warp) based on the given policy.

Parameters:
  • M (int) -- The number of rows in the GEMM workload.

  • N (int) -- The number of columns in the GEMM workload.

  • num_warps (int) -- The total number of warps available.

Returns:

A tuple (m_warp, n_warp) representing the partitioning of warps.

Return type:

tuple

Raises:
  • ValueError -- If the policy is invalid or the partitioning fails.

  • AssertionError -- If M or N is not divisible by the required factor for FullRow or FullCol policies.
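
To make the contract above concrete, here is a rough standalone sketch. It rests on several assumptions that this page does not confirm: that FullRow and FullCol place all warps along one dimension, that the "required factor" is `num_warps`, and that Square picks the factor pair giving the most square per-warp tile. The function `sketch_warp_partition` and its string-valued `policy` argument are hypothetical and do not reproduce the library's algorithm:

```python
def sketch_warp_partition(policy: str, M: int, N: int, num_warps: int) -> tuple:
    """Illustrative only; not the tilelang implementation."""
    if policy == "FullRow":
        # Assumption: every warp is placed along the row (M) dimension.
        assert M % num_warps == 0, "M must be divisible by num_warps"
        return num_warps, 1
    if policy == "FullCol":
        # Assumption: every warp is placed along the column (N) dimension.
        assert N % num_warps == 0, "N must be divisible by num_warps"
        return 1, num_warps
    if policy == "Square":
        # Hypothetical heuristic: try every factor pair of num_warps and keep
        # the one whose per-warp tile (M / m_warp by N / n_warp) is closest
        # to square.
        pairs = [(m, num_warps // m) for m in range(1, num_warps + 1)
                 if num_warps % m == 0]
        return min(pairs, key=lambda p: abs(M / p[0] - N / p[1]))
    raise ValueError(f"Unknown policy: {policy}")


print(sketch_warp_partition("FullRow", 128, 128, 4))
print(sketch_warp_partition("Square", 128, 64, 8))
```

In every branch the product `m_warp * n_warp` equals `num_warps`, which is the invariant a caller would typically rely on.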

classmethod from_warp_partition(m_warp, n_warp)

Determine the warp policy based on the given warp partitioning.

Parameters:
  • m_warp (int) -- Number of warps in the row dimension.

  • n_warp (int) -- Number of warps in the column dimension.

Returns:

The corresponding warp policy.

Return type:

GemmWarpPolicy

Examples

>>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
GemmWarpPolicy.FullRow
>>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
GemmWarpPolicy.FullCol
>>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
GemmWarpPolicy.Square
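
The mapping shown in these examples can be sketched as a small standalone function. This is an approximation inferred from the three documented cases: the helper `policy_name_from_partition` is hypothetical, returns the policy name as a string rather than an enum member, and treats every shape not covered by the examples (including `(1, 1)`) as Square, which is an assumption:

```python
def policy_name_from_partition(m_warp: int, n_warp: int) -> str:
    """Map a warp partition to a policy name (illustrative only)."""
    if n_warp == 1 and m_warp > 1:
        return "FullRow"  # all warps along the row dimension
    if m_warp == 1 and n_warp > 1:
        return "FullCol"  # all warps along the column dimension
    # Assumption: any other shape is treated as Square.
    return "Square"


for shape in [(4, 1), (1, 4), (2, 2)]:
    print(shape, policy_name_from_partition(*shape))
```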