tilelang.carver.template.flashattention

Classes

FlashAttentionTemplate

Carver template that describes a FlashAttention workload and produces hardware-aware configurations for it.

Module Contents

class tilelang.carver.template.flashattention.FlashAttentionTemplate

Bases: tilelang.carver.template.base.BaseTemplate

Template describing a FlashAttention workload for hardware-aware configuration search. It subclasses BaseTemplate, the abstract base class (ABC) that defines the structure for templates implementing hardware-specific optimizations.

batch_size: int = 1
num_heads: int = 1
head_dim: int = 1
seq_length: int = 1
seq_kv_length: int = 1
is_causal: bool = False
in_dtype: str = 'float16'
out_dtype: str = 'float16'
accum_dtype: str = 'float16'
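
The attributes above can be read as a description of one attention problem. The sketch below is a hypothetical stand-in dataclass (AttentionShapeParams is not part of tilelang) that mirrors the documented fields and defaults and derives tensor shapes from them. The [batch, heads, sequence, head_dim] layout is an assumption made for illustration, not something this page states.

```python
from dataclasses import dataclass


@dataclass
class AttentionShapeParams:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1      # number of query positions
    seq_kv_length: int = 1   # number of key/value positions
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"

    def tensor_shapes(self):
        # Assumed layout: [batch, heads, sequence, head_dim].
        q = (self.batch_size, self.num_heads, self.seq_length, self.head_dim)
        kv = (self.batch_size, self.num_heads, self.seq_kv_length, self.head_dim)
        return {"Q": q, "K": kv, "V": kv, "O": q}


params = AttentionShapeParams(
    batch_size=2, num_heads=8, head_dim=64,
    seq_length=128, seq_kv_length=256, is_causal=True,
)
shapes = params.tensor_shapes()
print(shapes["Q"], shapes["K"])
```

Keeping seq_length and seq_kv_length separate lets the same description cover cases where the query and key/value sequences differ in length.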

get_hardware_aware_configs(arch=None, topk=10)

Retrieves optimized hardware-aware configurations.

Parameters:
  • arch (TileDevice, optional) -- The target hardware architecture. Defaults to None.

  • topk (int, optional) -- Number of top configurations to consider. Defaults to 10.

Returns:

A list of optimization hints for hardware acceleration.

Return type:

List[Hint]

initialize_function()

Defines and initializes the attention computation.

This method sets up placeholders for the input tensors and expresses the computation, including its matrix multiplications and any type casting, using TVM's compute API.

Raises:

AssertionError -- If any shape parameter (batch_size, num_heads, head_dim, seq_length, or seq_kv_length) is not a positive integer.

Return type:

None
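
For a template named after FlashAttention, the computation being set up is presumably scaled dot-product attention. As a rough reference for what the shape parameters and is_causal mean, here is a plain-Python sketch of that operation for a single batch element and head. It is not the TVM definition used by the library. reference_attention is a hypothetical helper, and the causal convention (query i sees keys j <= i) is an assumption.

```python
import math


def reference_attention(q, k, v, is_causal=False):
    """q: [seq_length][head_dim]; k, v: [seq_kv_length][head_dim]."""
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for i, q_row in enumerate(q):
        # Scaled dot-product scores against every key the mask allows.
        scores = []
        for j, k_row in enumerate(k):
            if is_causal and j > i:
                continue  # assumed convention: query i ignores later keys
            scores.append((j, scale * sum(a * b for a, b in zip(q_row, k_row))))
        # Numerically stable softmax over the surviving scores.
        peak = max(s for _, s in scores)
        weights = [(j, math.exp(s - peak)) for j, s in scores]
        total = sum(w for _, w in weights)
        # Weighted average of the value rows.
        out.append([
            sum(w * v[j][d] for j, w in weights) / total
            for d in range(len(v[0]))
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = reference_attention(q, k, v, is_causal=True)
full = reference_attention(q, k, v, is_causal=False)
print(causal[0], full[0])
```

With the causal mask, the first query can only attend to the first key, so its output equals the first value row. Without the mask, each output row is a weighted average of both value rows.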

params_as_dict()

Returns the template parameters as a dictionary.

Returns:

Dictionary containing template parameter values.

Return type:

dict
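
To illustrate the kind of result a parameters-as-dictionary method yields, the sketch below uses a hypothetical stand-in dataclass (ParamsSketch, with only a few of the documented fields) and the standard library's dataclasses.asdict. How the library builds its dictionary, and exactly which fields it includes, is not stated on this page, so treat this as an assumption.

```python
from dataclasses import asdict, dataclass


@dataclass
class ParamsSketch:
    """Hypothetical stand-in with a subset of the documented fields."""

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    is_causal: bool = False

    def params_as_dict(self):
        # Assumption: a field-name -> value mapping of the instance.
        return asdict(self)


params = ParamsSketch(num_heads=8, head_dim=64)
as_dict = params.params_as_dict()
print(as_dict)
print(repr(params))
```

A dictionary like this is convenient for logging a configuration or using it as a cache key component, while the dataclass-generated repr gives a readable one-line summary of the same values.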

property class_attributes

Returns the class attributes in dictionary form.

Returns:

Dictionary of class attributes.

Return type:

dict

__repr__()

Returns a string representation of the class instance.

Returns:

A formatted string representation of the class.

Return type:

str