tilelang.tileop.base
Classes
GemmWarpPolicy | Enumeration for GEMM Warp Partitioning Policies.
Module Contents
- class tilelang.tileop.base.GemmWarpPolicy
Bases: enum.IntEnum
Enumeration for GEMM Warp Partitioning Policies.
- Square = 0
- FullRow = 1
- FullCol = 2
- is_square()
Check whether the policy is a square partitioning.
- Returns:
True if the policy is square, False otherwise.
- Return type:
bool
- is_full_row()
Check whether the policy is a full-row partitioning.
- Returns:
True if the policy is full row, False otherwise.
- Return type:
bool
- is_full_col()
Check whether the policy is a full-column partitioning.
- Returns:
True if the policy is full column, False otherwise.
- Return type:
bool
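The enum and its three predicates can be mirrored in plain Python to show the behavior described above. The sketch below uses a hypothetical stand-in class, `WarpPolicy`, instead of importing tilelang, and assumes each predicate is a simple equality check against one member. Because the base class is `enum.IntEnum`, members also compare equal to their integer values and can be constructed from them.

```python
from enum import IntEnum


class WarpPolicy(IntEnum):
    """Hypothetical stand-in mirroring GemmWarpPolicy's members and predicates."""

    Square = 0
    FullRow = 1
    FullCol = 2

    def is_square(self) -> bool:
        # True only for the Square member (assumed equality check)
        return self == WarpPolicy.Square

    def is_full_row(self) -> bool:
        # True only for the FullRow member
        return self == WarpPolicy.FullRow

    def is_full_col(self) -> bool:
        # True only for the FullCol member
        return self == WarpPolicy.FullCol


# IntEnum members can be looked up by value and compare equal to plain ints.
policy = WarpPolicy(1)
print(policy.name, policy.is_full_row(), policy == 1)  # FullRow True True
```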
- static to_prime_factors(num)
Compute the prime factorization of a given number.
- Parameters:
num (int) -- The number to factorize.
- Returns:
A list of prime factors of the number.
- Return type:
list
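A minimal trial-division sketch of prime factorization follows, assuming factors are returned in ascending order with repetition (for example, 12 gives [2, 2, 3]). The ordering, the empty result for 1, and the `ValueError` on non-positive input are assumptions of this sketch rather than documented behavior of `to_prime_factors`.

```python
def to_prime_factors(num: int) -> list[int]:
    """Return the prime factors of num in ascending order, with multiplicity."""
    if num < 1:
        # Assumption: non-positive inputs are rejected
        raise ValueError("num must be a positive integer")
    factors = []
    divisor = 2
    # Only divisors up to sqrt(num) need to be tried
    while divisor * divisor <= num:
        while num % divisor == 0:
            factors.append(divisor)
            num //= divisor
        divisor += 1
    if num > 1:
        # Whatever remains after trial division is itself prime
        factors.append(num)
    return factors


print(to_prime_factors(12))  # [2, 2, 3]
print(to_prime_factors(8))   # [2, 2, 2]
```

Trial division is enough here because warp counts are small integers, so the loop runs only a handful of iterations.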
- compute_warp_partition(M, N, num_warps)
Compute the warp partition (m_warp, n_warp) for this policy.
- Parameters:
M (int) -- The number of rows in the GEMM workload.
N (int) -- The number of columns in the GEMM workload.
num_warps (int) -- The total number of warps available.
- Returns:
A tuple (m_warp, n_warp) describing how the warps are partitioned.
- Return type:
tuple
- Raises:
ValueError -- If the policy is invalid or the partitioning fails.
AssertionError -- If M or N is not divisible by the required factor for the FullRow or FullCol policy.
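The sketch below illustrates the contract described above: the result always satisfies `m_warp * n_warp == num_warps`. It uses a hypothetical `WarpPolicy` stand-in and a free function rather than tilelang's method. Two details are assumptions: the divisibility factor checked for FullRow and FullCol is taken to be `num_warps`, and the Square branch uses a simple heuristic that picks the divisor pair whose per-warp tile (`M / m_warp` by `N / n_warp`) is closest to square. The real method may impose further constraints, such as minimum per-warp tile sizes, that this sketch does not model.

```python
from enum import IntEnum


class WarpPolicy(IntEnum):
    """Hypothetical stand-in for GemmWarpPolicy."""

    Square = 0
    FullRow = 1
    FullCol = 2


def compute_warp_partition(policy, M: int, N: int, num_warps: int) -> tuple[int, int]:
    """Illustrative split of num_warps into (m_warp, n_warp)."""
    if num_warps < 1:
        raise ValueError("num_warps must be positive")
    if policy == WarpPolicy.FullRow:
        # All warps stacked along M; the divisibility factor is an assumption
        assert M % num_warps == 0, "M must be divisible by num_warps for FullRow"
        return num_warps, 1
    if policy == WarpPolicy.FullCol:
        # All warps placed along N
        assert N % num_warps == 0, "N must be divisible by num_warps for FullCol"
        return 1, num_warps
    if policy == WarpPolicy.Square:
        # Hypothetical heuristic: try every divisor pair and keep the one whose
        # per-warp tile dimensions are closest to each other
        best = None
        for m_warp in range(1, num_warps + 1):
            if num_warps % m_warp:
                continue
            n_warp = num_warps // m_warp
            imbalance = abs(M / m_warp - N / n_warp)
            if best is None or imbalance < best[0]:
                best = (imbalance, m_warp, n_warp)
        return best[1], best[2]
    raise ValueError(f"unknown policy: {policy!r}")


print(compute_warp_partition(WarpPolicy.Square, 128, 128, 4))  # (2, 2)
print(compute_warp_partition(WarpPolicy.Square, 256, 64, 4))   # (4, 1)
```

Under this heuristic a 128 x 128 tile with 4 warps gives each warp a 64 x 64 sub-tile, while a tall 256 x 64 tile puts all four warps along M so each warp again covers 64 x 64.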
- classmethod from_warp_partition(m_warp, n_warp)
Determine the warp policy that corresponds to the given warp partitioning.
- Parameters:
m_warp (int) -- Number of warps in the row dimension.
n_warp (int) -- Number of warps in the column dimension.
- Returns:
The corresponding warp policy.
- Return type:
GemmWarpPolicy
Examples
>>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
GemmWarpPolicy.FullRow
>>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
GemmWarpPolicy.FullCol
>>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
GemmWarpPolicy.Square
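The examples above suggest a simple mapping from a partition back to a policy. The following sketch reproduces them with a hypothetical `WarpPolicy` stand-in. The documentation does not say how the 1 x 1 case or uneven mixed splits such as (2, 4) are classified, so mapping them to `Square` is an assumption of this sketch.

```python
from enum import IntEnum


class WarpPolicy(IntEnum):
    """Hypothetical stand-in for GemmWarpPolicy."""

    Square = 0
    FullRow = 1
    FullCol = 2

    @classmethod
    def from_warp_partition(cls, m_warp: int, n_warp: int) -> "WarpPolicy":
        # Every warp along the row (M) dimension
        if n_warp == 1 and m_warp > 1:
            return cls.FullRow
        # Every warp along the column (N) dimension
        if m_warp == 1 and n_warp > 1:
            return cls.FullCol
        # Mixed splits and the 1 x 1 case fall back to Square (assumption)
        return cls.Square


print(WarpPolicy.from_warp_partition(4, 1).name)  # FullRow
print(WarpPolicy.from_warp_partition(1, 4).name)  # FullCol
print(WarpPolicy.from_warp_partition(2, 2).name)  # Square
```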