tilelang.carver.template.flashattention

Classes

FlashAttentionTemplate

Template for generating hardware-aware configurations for flash attention.

Module Contents

class tilelang.carver.template.flashattention.FlashAttentionTemplate

Bases: tilelang.carver.template.base.BaseTemplate

Flash attention template for hardware-aware configurations. It inherits from BaseTemplate, an abstract base class (ABC) that defines the structure for subclasses implementing hardware-specific optimizations.

batch_size: int = 1
num_heads: int = 1
head_dim: int = 1
seq_length: int = 1
seq_kv_length: int = 1
is_causal: bool = False
in_dtype: str = 'float16'
out_dtype: str = 'float16'
accum_dtype: str = 'float16'
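
To make the role of these parameters concrete, the following is a minimal pure-Python sketch of scaled dot-product attention for a single (batch, head) slice. It is not the library's implementation: the helper name `reference_attention`, the row-major list layout, and the causal-mask alignment (query i sees keys 0..i) are all assumptions chosen for illustration.

```python
import math


def reference_attention(q, k, v, is_causal=False):
    """Scaled dot-product attention for one (batch, head) slice.

    q:    seq_length rows, each with head_dim floats
    k, v: seq_kv_length rows, each with head_dim floats
    (hypothetical helper, not part of tilelang)
    """
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for i, q_row in enumerate(q):
        # Assumption: with is_causal, query i may only attend to keys 0..i.
        visible = range(min(i + 1, len(k))) if is_causal else range(len(k))
        scores = [scale * sum(a * b for a, b in zip(q_row, k[j])) for j in visible]
        peak = max(scores)  # subtract the max before exp for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][d] for w, j in zip(weights, visible)) / total
            for d in range(head_dim)
        ])
    return out


# seq_length = 2, seq_kv_length = 2, head_dim = 2
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = reference_attention(q, k, v, is_causal=True)
full = reference_attention(q, k, v, is_causal=False)
```

In the causal case the first query sees only the first key, so its output row equals the first row of `v`. The last query sees every key, so its output is the same with or without masking.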

get_hardware_aware_configs(arch=None, topk=10)

Retrieves optimized hardware-aware configurations.

Parameters:
  • arch (TileDevice, optional) -- The target hardware architecture.

  • topk (int, optional) -- Number of top configurations to consider.

Returns:

A list of optimization hints for hardware acceleration.

Return type:

List[Hint]

initialize_function()

Defines and initializes the flash attention computation.

This method sets up placeholders for the input tensors and expresses the attention computation using TVM's compute API.

Raises:

AssertionError -- If any dimension parameter (batch_size, num_heads, head_dim, seq_length, or seq_kv_length) is not a positive integer.

Return type:

None

params_as_dict()

Returns the template parameters as a dictionary.

Returns:

Dictionary containing template parameter values.

Return type:

dict
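
As a rough illustration of what a parameter dictionary for this template might look like, the sketch below uses a stand-in dataclass that mirrors the documented fields and defaults. The class name `FlashAttentionParams` is hypothetical, and collecting values via `dataclasses.fields` is an assumption about one reasonable approach, not a statement of how tilelang implements `params_as_dict()`.

```python
from dataclasses import dataclass, fields


@dataclass
class FlashAttentionParams:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"

    def params_as_dict(self):
        # Collect every declared field, in declaration order, into a plain dict.
        return {f.name: getattr(self, f.name) for f in fields(self)}


params = FlashAttentionParams(
    batch_size=4,
    num_heads=32,
    head_dim=128,
    seq_length=2048,
    seq_kv_length=2048,
    is_causal=True,
)
as_dict = params.params_as_dict()
```

Fields that are not passed explicitly, such as the three dtype strings here, keep their defaults in the resulting dictionary.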

property class_attributes

Returns the class attributes in dictionary form.

Returns:

Dictionary of class attributes.

Return type:

dict

__repr__()

Returns a string representation of the class instance.

Returns:

A formatted string representation of the class.

Return type:

str