tilelang.carver.template.matmul

Classes

MatmulTemplate

A template for matrix multiplication (MatMul).

Module Contents

class tilelang.carver.template.matmul.MatmulTemplate

Bases: tilelang.carver.template.base.BaseTemplate

A template for matrix multiplication (MatMul).

This class defines the computation for a matrix-matrix multiplication with configurable parameters such as transposition, data types, and bias addition.
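The plain-Python sketch below shows the computation described above. The function name `reference_matmul` is hypothetical and not part of tilelang. The sketch ignores data types. It also assumes the common convention that a transposed operand is stored with its two dimensions swapped: A is K x M when trans_A is True, and B is N x K when trans_B is True.

```python
from typing import List, Optional

Matrix = List[List[float]]


def reference_matmul(
    A: Matrix,
    B: Matrix,
    trans_A: bool = False,
    trans_B: bool = True,
    bias: Optional[List[float]] = None,
) -> Matrix:
    """Hypothetical reference for C = op(A) @ op(B) (+ bias), ignoring dtypes.

    Assumed storage layout: A is (M, K), or (K, M) if trans_A;
    B is (K, N), or (N, K) if trans_B; bias has one entry per column of C.
    """
    M = len(A[0]) if trans_A else len(A)
    K = len(A) if trans_A else len(A[0])
    N = len(B) if trans_B else len(B[0])

    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                a = A[k][i] if trans_A else A[i][k]  # op(A)[i][k]
                b = B[j][k] if trans_B else B[k][j]  # op(B)[k][j]
                acc += a * b
            if bias is not None:
                acc += bias[j]  # bias broadcast along the rows of C
            C[i][j] = acc
    return C


# Usage with the default layout (trans_A=False, trans_B=True):
A = [[1.0, 2.0], [3.0, 4.0]]  # M=2, K=2
B = [[5.0, 6.0], [7.0, 8.0]]  # stored as N x K, N=2
print(reference_matmul(A, B))  # [[17.0, 23.0], [39.0, 53.0]]
```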

M

Number of rows in matrix A (as used in the product, i.e. after any transposition) and in the output matrix C.

Type:

int

N

Number of columns in matrix B (after any transposition) and in the output matrix C.

Type:

int

K

Reduction dimension: the number of columns in matrix A and rows in matrix B (after any transposition).

Type:

int

trans_A

Whether to transpose matrix A before multiplication.

Type:

bool

trans_B

Whether to transpose matrix B before multiplication.

Type:

bool

in_dtype

Data type of the input matrices A and B.

Type:

str

out_dtype

Data type of the output matrix C.

Type:

str

accum_dtype

Data type used for accumulation.

Type:

str

with_bias

Whether to add a bias term.

Type:

bool
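The split between in_dtype, accum_dtype, and out_dtype matters most for long reductions. The sketch below emulates rounding to half and single precision with Python's struct module (format codes 'e' and 'f'). The names `cast` and `dot` are hypothetical. They model the idea of accumulating in one type and writing the result in another. They are not tilelang's generated code. With a float16 accumulator, a sum of K = 4096 ones stops growing at 2048, because 2049 is not representable in float16 and rounds back to 2048.

```python
import struct
from typing import List

# struct format codes used to round a Python float to a given precision.
_FORMATS = {"float16": "e", "float32": "f", "float64": "d"}


def cast(value: float, dtype: str) -> float:
    """Round `value` to the nearest value representable in `dtype`."""
    fmt = _FORMATS[dtype]
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def dot(
    a: List[float],
    b: List[float],
    in_dtype: str = "float16",
    accum_dtype: str = "float16",
    out_dtype: str = "float16",
) -> float:
    """Dot product: inputs in in_dtype, running sum in accum_dtype, result in out_dtype."""
    acc = 0.0
    for x, y in zip(a, b):
        prod = cast(cast(x, in_dtype) * cast(y, in_dtype), accum_dtype)
        acc = cast(acc + prod, accum_dtype)  # every partial sum is rounded
    return cast(acc, out_dtype)


ones = [1.0] * 4096
print(dot(ones, ones, accum_dtype="float16"))  # 2048.0
print(dot(ones, ones, accum_dtype="float32"))  # 4096.0
```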

M: int = None
N: int = None
K: int = None
trans_A: bool = False
trans_B: bool = True
in_dtype: str = 'float16'
out_dtype: str = 'float16'
accum_dtype: str = 'float16'
with_bias: bool = False
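The hypothetical helper below derives the stored operand shapes from these fields. It assumes that a transposed operand is stored with its dimensions swapped and that the bias holds one value per output column. It is an illustration, not a tilelang function. Under that assumption, the defaults (trans_A=False, trans_B=True) mean B is laid out as N x K.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def operand_shapes(
    M: int,
    N: int,
    K: int,
    trans_A: bool = False,
    trans_B: bool = True,
    with_bias: bool = False,
) -> Dict[str, Optional[Shape]]:
    """Hypothetical helper: stored shapes of A, B, Bias, and C.

    Assumes a transposed operand is stored with its two dims swapped.
    """
    return {
        "A": (K, M) if trans_A else (M, K),
        "B": (N, K) if trans_B else (K, N),
        "Bias": (N,) if with_bias else None,  # one value per output column
        "C": (M, N),
    }


print(operand_shapes(M=1024, N=512, K=256))
# {'A': (1024, 256), 'B': (512, 256), 'Bias': None, 'C': (1024, 512)}
```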
get_hardware_aware_configs(arch=None, topk=10)

Retrieves optimized hardware-aware configurations.

Parameters:
  • arch (TileDevice, optional) -- The target hardware architecture.

  • topk (int, optional) -- Number of top configurations to consider.

Returns:

A list of optimization hints for hardware acceleration.

Return type:

List[Hint]
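The topk parameter caps how many candidates are returned. This reference does not describe how tilelang scores candidates. The sketch below only illustrates top-k selection with heapq.nsmallest. Each candidate carries a made-up estimated latency. `Candidate` and `select_topk` are illustrative names and do not stand in for the library's `Hint` type.

```python
import heapq
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Candidate:
    """Hypothetical stand-in for a tiling hint: a block shape plus a cost estimate."""
    block: Tuple[int, int]
    est_latency_us: float


def select_topk(candidates: List[Candidate], topk: int = 10) -> List[Candidate]:
    """Return the `topk` candidates with the lowest estimated latency, best first."""
    return heapq.nsmallest(topk, candidates, key=lambda c: c.est_latency_us)


# Illustrative, made-up candidates.
pool = [
    Candidate((128, 128), 41.0),
    Candidate((64, 256), 37.5),
    Candidate((256, 64), 52.0),
    Candidate((64, 64), 60.2),
]
print([c.block for c in select_topk(pool, topk=2)])  # [(64, 256), (128, 128)]
```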

initialize_function()

Defines and initializes the matrix multiplication computation.

This method sets up placeholders for input matrices, computes the matrix multiplication using TVM's compute API, and optionally applies bias and type casting.

Raises:

AssertionError -- If M, N, or K is not a positive integer.

Return type:

None
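The hypothetical helper below models the documented AssertionError and the optional bias and cast steps. It checks the dimensions as the docstring describes, then lists the stages a computation with these settings would contain. The stage order is an assumption for illustration, since this reference does not specify it. Here, the cast to out_dtype happens only when out_dtype differs from accum_dtype, and the bias is added after that. Because M, N, and K default to None, the check fails until they are set.

```python
from typing import List


def plan_stages(
    M: int,
    N: int,
    K: int,
    accum_dtype: str = "float16",
    out_dtype: str = "float16",
    with_bias: bool = False,
) -> List[str]:
    """Hypothetical: validate dims, then list the compute stages that would be built."""
    for name, value in (("M", M), ("N", N), ("K", K)):
        # Mirrors the documented check: each dimension must be a positive integer.
        assert isinstance(value, int) and value > 0, (
            f"{name} must be a positive integer, got {value!r}"
        )
    stages = ["matmul"]  # reduction over K, accumulated in accum_dtype
    if accum_dtype != out_dtype:
        stages.append("cast")  # assumed: convert to out_dtype only if needed
    if with_bias:
        stages.append("bias_add")  # assumed: one bias value per output column
    return stages


print(plan_stages(128, 128, 64, accum_dtype="float32", with_bias=True))
# ['matmul', 'cast', 'bias_add']
```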

params_as_dict()

Returns the template parameters as a dictionary.

Returns:

Dictionary containing template parameter values.

Return type:

dict
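A dataclass with the same fields and defaults as the attribute listing shows what such a dictionary would contain. `MatmulParams` is a hypothetical mirror of the documented fields, and `dataclasses.asdict` stands in for params_as_dict. The exact keys returned by the real method are not confirmed here.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class MatmulParams:
    """Hypothetical mirror of MatmulTemplate's documented fields and defaults."""
    M: Optional[int] = None
    N: Optional[int] = None
    K: Optional[int] = None
    trans_A: bool = False
    trans_B: bool = True
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"
    with_bias: bool = False


params = MatmulParams(M=1024, N=1024, K=512, accum_dtype="float32")
print(asdict(params))
# {'M': 1024, 'N': 1024, 'K': 512, 'trans_A': False, 'trans_B': True,
#  'in_dtype': 'float16', 'out_dtype': 'float16', 'accum_dtype': 'float32',
#  'with_bias': False}
```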

property class_attributes

Returns the class attributes in dictionary form.

Returns:

Dictionary of class attributes.

回傳型別:

dict

__repr__()

Returns a string representation of the class instance.

Returns:

A formatted string representation of the class.

回傳型別:

str
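The exact repr format is not documented here. One common pattern prints the class name followed by its parameters, and the hypothetical sketch below shows it. `ReprSketch` is an illustrative class, not part of tilelang.

```python
class ReprSketch:
    """Hypothetical class showing a repr built from a parameter dictionary."""

    def __init__(self, **params):
        self._params = dict(params)

    def params_as_dict(self) -> dict:
        return dict(self._params)

    def __repr__(self) -> str:
        # Render each parameter as name=repr(value), in insertion order.
        fields = ", ".join(f"{k}={v!r}" for k, v in self.params_as_dict().items())
        return f"{type(self).__name__}({fields})"


print(repr(ReprSketch(M=1024, N=1024, K=512, in_dtype="float16")))
# ReprSketch(M=1024, N=1024, K=512, in_dtype='float16')
```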