tilelang.engine.param

Utilities for describing kernel parameters, used by the profiler and for conversion to PyTorch.

Classes

KernelParam

Represents parameters for a kernel operation, storing dtype and shape information.

CompiledArtifact

Represents a compiled kernel artifact containing both host and device code.

Module Contents

class tilelang.engine.param.KernelParam

Represents parameters for a kernel operation, storing dtype and shape information. Used to describe tensor or scalar parameters in TVM/PyTorch interop.

dtype: tilelang.tvm.DataType
shape: list[int | tvm.tir.Var]
classmethod from_buffer(buffer)

Creates a KernelParam instance from a TVM Buffer object.

Parameters:

buffer (tvm.tir.Buffer) -- TVM Buffer object containing dtype and shape information

Returns:

A KernelParam instance whose dtype is taken directly from the buffer and whose shape is derived from the buffer's dimensions

Raises:

ValueError -- If a dimension has an unsupported type (neither IntImm nor Var)
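The dimension handling described above can be sketched in plain Python. This is a minimal sketch, not TileLang's implementation: `IntImm` and `Var` below are hypothetical stand-ins for `tvm.tir.IntImm` and `tvm.tir.Var`, and `convert_dims` is a hypothetical helper. It assumes constant dimensions become plain `int`s and symbolic dimensions are kept as variables, which matches the `shape: list[int | tvm.tir.Var]` annotation.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class IntImm:
    """Hypothetical stand-in for tvm.tir.IntImm (a constant integer)."""
    value: int


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for tvm.tir.Var (a symbolic dimension)."""
    name: str


def convert_dims(dims: list) -> list[Union[int, Var]]:
    """Convert buffer dimensions following the documented from_buffer rules."""
    shape: list[Union[int, Var]] = []
    for dim in dims:
        if isinstance(dim, IntImm):
            # Constant dimensions are stored as plain Python ints.
            shape.append(dim.value)
        elif isinstance(dim, Var):
            # Symbolic dimensions are kept as variables.
            shape.append(dim)
        else:
            # Anything else is rejected, mirroring the documented ValueError.
            raise ValueError(f"Unsupported dimension type: {type(dim).__name__}")
    return shape


print(convert_dims([Var("n"), IntImm(128)]))  # [Var(name='n'), 128]
```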

classmethod from_var(var)

Creates a KernelParam instance from a TVM Variable object. Used for scalar parameters.

Parameters:

var (tvm.tir.Var) -- TVM Variable object containing dtype information

Returns:

KernelParam instance representing a scalar (empty shape)
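A scalar parameter is described by its dtype alone. The sketch below illustrates this with hypothetical, simplified stand-ins (`ScalarVar`, `ParamSketch`, `from_var_sketch`) that represent dtypes as strings; the real KernelParam stores a TVM DataType instead.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ScalarVar:
    """Hypothetical stand-in for a scalar tvm.tir.Var."""
    name: str
    dtype: str


@dataclass
class ParamSketch:
    """Hypothetical, simplified stand-in for KernelParam (dtype as a string)."""
    dtype: str
    shape: list = field(default_factory=list)


def from_var_sketch(var: ScalarVar) -> ParamSketch:
    # A scalar variable carries only a dtype, so the shape stays empty.
    return ParamSketch(dtype=var.dtype, shape=[])


print(from_var_sketch(ScalarVar("m", "int32")))  # ParamSketch(dtype='int32', shape=[])
```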

is_scalar()

Checks if the parameter represents a scalar value.

Returns:

True if the parameter has no dimensions (empty shape), False otherwise

Return type:

bool
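Because the check looks only at the number of dimensions, a one-element tensor is not a scalar. A minimal sketch of the documented rule, where `is_scalar` is a hypothetical free function over a plain shape list and the string `"n"` stands in for a symbolic `tvm.tir.Var`:

```python
def is_scalar(shape: list) -> bool:
    """A parameter is scalar exactly when its shape is empty."""
    return len(shape) == 0


print(is_scalar([]))     # True: a scalar parameter, e.g. one built by from_var
print(is_scalar([1]))    # False: a one-element tensor still has one dimension
print(is_scalar(["n"]))  # False: a symbolic dimension is still a dimension
```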

is_unsigned()

Checks if the parameter represents an unsigned integer type.

Returns:

True if the parameter is an unsigned integer type, False otherwise

Return type:

bool

is_float8()

Checks if the parameter represents a float8 type.

Returns:

True if the parameter is a float8 type, False otherwise

Return type:

bool

is_float4()

Checks if the parameter represents a float4 type.

Returns:

True if the parameter is a float4 type, False otherwise

Return type:

bool

is_boolean()

Checks if the parameter represents a boolean type.

Returns:

True if the parameter is a boolean type, False otherwise

Return type:

bool
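The four dtype predicates above amount to simple classifications of the dtype. The sketch below assumes the dtype is available as a TVM-style string such as `uint8`, `float8_e5m2`, `float4_e2m1fn`, or `bool`. These spellings are illustrative, and the functions are hypothetical free-function versions, not the KernelParam methods themselves.

```python
def is_unsigned(dtype: str) -> bool:
    # "uint8", "uint32", ... are unsigned; "int8" is signed.
    return dtype.startswith("uint")


def is_float8(dtype: str) -> bool:
    # Matches every float8 variant regardless of exponent/mantissa encoding.
    return dtype.startswith("float8")


def is_float4(dtype: str) -> bool:
    # Matches float4 variants; "float16"/"float32"/"float64" do not match.
    return dtype.startswith("float4")


def is_boolean(dtype: str) -> bool:
    return dtype == "bool"


predicates = (is_unsigned, is_float8, is_float4, is_boolean)
for dt in ["uint8", "int8", "float8_e5m2", "float4_e2m1fn", "float16", "bool"]:
    print(dt, [f.__name__ for f in predicates if f(dt)])
```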

torch_dtype()

Converts the TVM DataType to the corresponding PyTorch dtype.

Use this method when creating PyTorch tensors from a KernelParam, because PyTorch's tensor-creation functions require a torch.dtype.

Returns:

Corresponding PyTorch dtype

Return type:

torch.dtype

Example

>>> param = KernelParam.from_buffer(buffer)
>>> tensor = torch.empty(param.shape, dtype=param.torch_dtype())  # requires a fully static shape
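
The example above needs a fully concrete shape, so symbolic dimensions in `param.shape` must be bound to integers first. A minimal sketch with a hypothetical `resolve_shape` helper, where symbolic dimensions are represented by their names as strings rather than real `tvm.tir.Var` objects:

```python
def resolve_shape(shape: list, bindings: dict[str, int]) -> list[int]:
    """Replace symbolic dimensions (given here as names) with concrete sizes."""
    concrete = []
    for dim in shape:
        if isinstance(dim, int):
            # Static dimensions pass through unchanged.
            concrete.append(dim)
        elif dim in bindings:
            # Symbolic dimensions are looked up in the caller's bindings.
            concrete.append(bindings[dim])
        else:
            raise KeyError(f"No binding for symbolic dimension {dim!r}")
    return concrete


print(resolve_shape(["batch", 128], {"batch": 4}))  # [4, 128]
```

The resulting list can be passed to `torch.empty` together with `param.torch_dtype()`. With real TVM variables, the bindings would be keyed by the variable (or its name) instead; that adaptation is an assumption of this sketch.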
tilelang_dtype()

Converts the TVM DataType to the corresponding TileLang dtype.

Returns:

Corresponding TileLang dtype

Return type:

T.dtype

class tilelang.engine.param.CompiledArtifact

Represents a compiled kernel artifact containing both host and device code. Stores all necessary components for kernel execution in the TVM runtime.

host_mod: tilelang.tvm.IRModule
device_mod: tilelang.tvm.IRModule
params: list[KernelParam]
kernel_source: str
rt_mod: tilelang.tvm.runtime.Module | None = None
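
Because `rt_mod` is typed `... | None` with a default of `None`, code that consumes an artifact should not assume it is set. The sketch below mirrors the listed fields in a hypothetical `CompiledArtifactSketch` dataclass, with the TVM module types replaced by `Any`. The `has_runtime_module` helper is also hypothetical and is not part of the documented class.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CompiledArtifactSketch:
    """Hypothetical mirror of CompiledArtifact's fields (TVM types replaced by Any)."""
    host_mod: Any
    device_mod: Any
    params: list
    kernel_source: str
    rt_mod: Optional[Any] = None  # only field with a default

    def has_runtime_module(self) -> bool:
        # rt_mod is optional, so check it before using it.
        return self.rt_mod is not None


artifact = CompiledArtifactSketch(
    host_mod="<host IRModule>",
    device_mod="<device IRModule>",
    params=[],
    kernel_source="// generated kernel source",
)
print(artifact.has_runtime_module())  # False
```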