tilelang.jit.kernel

Attributes

Classes

JITKernel

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

Module Contents

tilelang.jit.kernel.logger
class tilelang.jit.kernel.JITKernel(func=None, out_idx=None, execution_backend='tvm_ffi', target='auto', target_host=None, verbose=False, pass_configs=None, from_database=False, compile_flags=None)

Bases: Generic[_P, _T]

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

Parameters:
  • func (tvm.tir.PrimFunc)

  • out_idx (list[int] | int)

  • execution_backend (Literal['tvm_ffi', 'cython', 'nvrtc', 'torch', 'cutedsl'])

  • target (str | tvm.target.Target)

  • target_host (str | tvm.target.Target)

  • verbose (bool)

  • pass_configs (dict[str, Any] | None)

  • from_database (bool)

  • compile_flags (list[str] | None)
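`out_idx` accepts either a single int or a list of ints that mark which kernel parameters are outputs. The sketch below shows one plausible way to normalize such a value into concrete parameter positions. The helper name `normalize_out_idx` and the Python-style handling of negative indices (`-1` meaning the last parameter) are assumptions for illustration, not TileLang's actual implementation.

```python
from typing import List, Optional, Union


def normalize_out_idx(out_idx: Optional[Union[List[int], int]], num_params: int) -> List[int]:
    """Hypothetical helper: map an out_idx value to sorted, non-negative parameter positions.

    Assumes Python-style negative indexing, so -1 refers to the last parameter.
    """
    if out_idx is None:
        return []
    if isinstance(out_idx, int):
        out_idx = [out_idx]  # a single int is treated as a one-element list
    positions = set()
    for idx in out_idx:
        # Shift negative indices so they count back from the end of the parameter list.
        pos = idx + num_params if idx < 0 else idx
        if not 0 <= pos < num_params:
            raise IndexError(f"out_idx {idx} is out of range for {num_params} parameters")
        positions.add(pos)
    return sorted(positions)
```

For a kernel with parameters `(A, B, C)`, `normalize_out_idx(-1, 3)` returns `[2]`, i.e. `C` is the output.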

artifact

The compiled artifact containing the runtime module and parameters.

Type:

CompiledArtifact

adapter

The adapter for the compiled function.

Type:

BaseKernelAdapter

torch_function

The compiled function that can be invoked as a PyTorch-compatible function.

Type:

Callable

prim_func: tvm.tir.PrimFunc = None
artifact: tilelang.engine.param.CompiledArtifact = None
adapter: tilelang.jit.adapter.BaseKernelAdapter = None
torch_function: Callable = None
latency: float = None
config: dict[str, Any] = None
ref_latency: float = None
execution_backend = 'tvm_ffi'
target_host = None
verbose = False
pass_configs = None
compile_flags
target
classmethod from_database(func, host_kernel_source, device_kernel_source, kernel_lib_path, params, target, target_host, out_idx, execution_backend, pass_configs=None, compile_flags=None)

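Alternative constructor to create a JITKernel directly from artifacts stored in a database (host and device kernel sources, the compiled kernel library, and the kernel parameters).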

Parameters:
  • func (tvm.tir.PrimFunc)

  • host_kernel_source (str)

  • device_kernel_source (str)

  • kernel_lib_path (str)

  • params (list[tilelang.engine.param.KernelParam])

  • target (str | tvm.target.Target)

  • target_host (str | tvm.target.Target)

  • out_idx (list[int] | int)

  • execution_backend (Literal['tvm_ffi', 'cython', 'nvrtc', 'torch'])

  • pass_configs (dict[str, Any] | None)

  • compile_flags (list[str] | None)

__call__(*args, **kwds)

Invokes the compiled function with the given arguments.

Parameters:
  • *args (Any) -- Positional arguments for the function.

  • **kwds (Any) -- Keyword arguments for the function.

Returns:

The result of the function execution.

Return type:

Any
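When `out_idx` is set, the compiled kernel is typically called with the input tensors only and returns the output tensors it allocates. The sketch below models that calling convention with plain Python lists in place of tensors. `make_caller` and the `alloc` callback are hypothetical, and returning a single object for one output (a list otherwise) is an assumption of this sketch.

```python
from typing import Any, Callable, List, Sequence


def make_caller(
    kernel: Callable[..., None],
    num_params: int,
    out_idx: List[int],
    alloc: Callable[[int, Sequence[Any]], Any],
) -> Callable[..., Any]:
    """Hypothetical wrapper: callers pass inputs only; output buffers are allocated and returned.

    out_idx is assumed to contain non-negative parameter positions.
    """
    in_positions = [i for i in range(num_params) if i not in out_idx]

    def call(*inputs: Any) -> Any:
        if len(inputs) != len(in_positions):
            raise TypeError(f"expected {len(in_positions)} inputs, got {len(inputs)}")
        args: List[Any] = [None] * num_params
        for pos, value in zip(in_positions, inputs):
            args[pos] = value
        for pos in out_idx:
            args[pos] = alloc(pos, inputs)  # allocate an output buffer for this slot
        kernel(*args)  # the kernel writes its results into the output buffers in place
        outputs = [args[pos] for pos in out_idx]
        return outputs[0] if len(outputs) == 1 else outputs

    return call
```

For example, wrapping an element-wise `vec_add(a, b, c)` with `out_idx=[2]` lets you write `c = add(a, b)`.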

classmethod from_tilelang_function(tilelang_func, **kwargs)

Alternative constructor to create a JITKernel directly from a TileLang PrimFunc.

Parameters:
  • tilelang_func (tvm.tir.PrimFunc) -- The TileLang (TVM TIR) function to compile.

  • **kwargs (dict) -- Additional keyword arguments to pass to the constructor.

Returns:

A JITKernel instance wrapping the compiled function.

Return type:

JITKernel

get_profiler(tensor_supply_type=TensorSupplyType.Auto)

Creates a profiler to benchmark the compiled runtime module.

Parameters:

tensor_supply_type (TensorSupplyType, optional) -- The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).

Returns:

A Profiler instance for benchmarking the runtime module.

Return type:

Profiler
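The Profiler returned here handles benchmarking for you. To illustrate what a latency benchmark does in general, the sketch below is a plain host-side timing loop with warm-up runs and a median over repeats. `measure_latency_ms` is hypothetical and is not how TileLang's Profiler is implemented.

```python
import statistics
import time
from typing import Callable


def measure_latency_ms(fn: Callable[[], object], warmup: int = 3, repeat: int = 10) -> float:
    """Return the median wall-clock latency of fn() in milliseconds (repeat must be >= 1)."""
    for _ in range(warmup):
        fn()  # warm-up runs are excluded from the measurement
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)
```

GPU launches are asynchronous, so a callable timed this way must synchronize the device before returning; otherwise only the launch overhead is measured.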

get_kernel_source(kernel_only=True)

Returns the source code of the compiled kernel function.

Parameters:

kernel_only (bool)

Returns:

The source code of the compiled kernel function.

Return type:

str

get_host_source()

Returns the source code of the host function.

Return type:

str

run_once(func=None)
Parameters:

func (Callable | None)

Return type:

None

show_source(which='kernel')

Print generated source code to stdout.

Parameters:

which (Literal["kernel", "host", "both"], optional) -- Select which source to print. Defaults to "kernel".

Return type:

None

Examples

>>> jit_kernel.show_source()            # print kernel source
>>> jit_kernel.show_source("host")      # print host source
>>> jit_kernel.show_source("both")      # print both sources
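The `which` argument selects between the kernel source, the host source, or both. The sketch below shows that selection logic as a pure function, which is easy to test. `render_sources` and its section-separator format are hypothetical; the real method prints to stdout and its exact layout for `"both"` may differ.

```python
from typing import Literal

Which = Literal["kernel", "host", "both"]


def render_sources(kernel_src: str, host_src: str, which: Which = "kernel") -> str:
    """Hypothetical helper that builds the text a show_source()-style method would print."""
    if which == "kernel":
        return kernel_src
    if which == "host":
        return host_src
    if which == "both":
        # Label each section so the two sources are distinguishable in one listing.
        return f"// ---- kernel ----\n{kernel_src}\n// ---- host ----\n{host_src}"
    raise ValueError(f"which must be 'kernel', 'host', or 'both', got {which!r}")
```

Printing is then just `print(render_sources(kernel_src, host_src, "both"))`.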
export_sources(kernel_path=None, host_path=None)

Export generated source code to files.

Parameters:
  • kernel_path (Optional[str]) -- Destination file path to write the kernel source. If None, skips writing kernel code.

  • host_path (Optional[str]) -- Destination file path to write the host source. If None, skips writing host code.

Return type:

None

Examples

>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
...     kernel_path="/tmp/kernel.cu",
...     host_path="/tmp/host.cc",
... )
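Each path is optional, and passing `None` skips that file. The sketch below reproduces that documented contract with `pathlib`. `export_sources_sketch` is a hypothetical stand-in that takes the source strings directly; it does not create missing parent directories, which is a choice of this sketch.

```python
from pathlib import Path
from typing import Optional


def export_sources_sketch(
    kernel_src: str,
    host_src: str,
    kernel_path: Optional[str] = None,
    host_path: Optional[str] = None,
) -> None:
    """Hypothetical stand-in: write each source only when its destination path is given."""
    for src, path in ((kernel_src, kernel_path), (host_src, host_path)):
        if path is None:
            continue  # None means "skip writing this source"
        Path(path).write_text(src, encoding="utf-8")
```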
print_source_code(which='kernel', file=None)

Deprecated: use show_source() or export_sources() instead.

Parameters:
  • which (Literal["kernel", "host", "both"], optional) -- Kept for backward compatibility with printing behavior.

  • file (Optional[str]) -- If provided, behaves like export_sources(kernel_path=file).

Return type:

None

Examples

>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
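A deprecated method like this is usually a thin shim over the new API. The sketch below wires one up with Python's `warnings` module: it forwards `file` to an `export_sources`-style callable and otherwise delegates to a `show_source`-style callable. The factory `make_print_source_code` is hypothetical, and whether TileLang emits a `DeprecationWarning` is an assumption.

```python
import warnings
from typing import Callable, Optional


def make_print_source_code(
    show_source: Callable[[str], None],
    export_sources: Callable[..., None],
) -> Callable[..., None]:
    """Hypothetical factory for a deprecated method that delegates to the new API."""

    def print_source_code(which: str = "kernel", file: Optional[str] = None) -> None:
        # stacklevel=2 attributes the warning to the caller, not to this shim.
        warnings.warn(
            "print_source_code() is deprecated; use show_source() or export_sources()",
            DeprecationWarning,
            stacklevel=2,
        )
        if file is not None:
            export_sources(kernel_path=file)  # same as export_sources(kernel_path=file)
        else:
            show_source(which)

    return print_source_code
```

Python hides `DeprecationWarning` by default outside `__main__`, so run with `-W default::DeprecationWarning` to see it.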
update_tuner_result(latency, config, ref_latency)

Updates the tuning results for this kernel.

Parameters:
  • latency (float) -- The measured latency of this kernel configuration.

  • config (Dict[str, Any]) -- The configuration parameters used for this kernel.

  • ref_latency (float) -- The reference latency to compare against.

Return type:

None

get_tuner_result()

Gets the tuning results for this kernel.

Returns:

A dictionary containing:
  • latency: The measured latency of this kernel

  • config: The configuration parameters used

  • ref_latency: The reference latency for comparison

Return type:

Dict[str, Any]
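`update_tuner_result` and `get_tuner_result` store and return three values: `latency`, `config`, and `ref_latency`. The dataclass below is a hypothetical model of that round trip. Its `speedup()` method (reference latency divided by measured latency) is an extra convenience added for illustration and is not part of the JITKernel API.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class TunerRecord:
    """Hypothetical container mirroring the documented tuner-result fields."""

    latency: Optional[float] = None
    config: Optional[Dict[str, Any]] = None
    ref_latency: Optional[float] = None

    def update(self, latency: float, config: Dict[str, Any], ref_latency: float) -> None:
        self.latency, self.config, self.ref_latency = latency, config, ref_latency

    def as_dict(self) -> Dict[str, Any]:
        # Same three keys that get_tuner_result() documents.
        return {"latency": self.latency, "config": self.config, "ref_latency": self.ref_latency}

    def speedup(self) -> float:
        """Reference latency / measured latency; a value above 1 means faster than the reference."""
        if not self.latency or self.ref_latency is None:
            raise ValueError("latency (non-zero) and ref_latency must be set")
        return self.ref_latency / self.latency
```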

property out_idx: list[int]
Return type:

list[int]

property params: list[tilelang.engine.param.KernelParam]
Return type:

list[tilelang.engine.param.KernelParam]

property kernel_source: str
Return type:

str

property host_source: str
Return type:

str

export_library(kernel_file)

Exports the compiled kernel function to a shared library file.

Parameters:

kernel_file (str) -- The path to the shared library file to create.

Return type:

None

show_ptx()

Print compiled PTX for the kernel (CUDA only).

Examples

>>> jit_kernel.show_ptx()
Return type:

None

export_ptx(path)

Export compiled PTX to a file (CUDA only).

Parameters:

path (str) -- Destination file path to write PTX.

Return type:

None

Examples

>>> jit_kernel.export_ptx("/tmp/kernel.ptx")
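Once PTX has been exported, a quick way to inspect it is to list the kernel entry points, which PTX declares with the `.entry` directive (for example `.visible .entry main_kernel(`). The helper `ptx_entry_names` below is a hypothetical regex-based sketch, not a full PTX parser. Device functions declared with `.func` are intentionally not matched.

```python
import re
from typing import List

# Captures the identifier that follows a PTX `.entry` directive.
_ENTRY_RE = re.compile(r"\.entry\s+([A-Za-z_$%][\w$]*)")


def ptx_entry_names(ptx_text: str) -> List[str]:
    """Return kernel entry-point names declared in a PTX listing, in order of appearance."""
    return _ENTRY_RE.findall(ptx_text)
```

Typical use after an export: `ptx_entry_names(Path("/tmp/kernel.ptx").read_text())`.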
show_sass()

Print disassembled SASS for the kernel (CUDA only).

Examples

>>> jit_kernel.show_sass()
Return type:

None

export_sass(path)

Export disassembled SASS to a file (CUDA only).

Parameters:

path (str) -- Destination file path to write SASS.

Return type:

None

Examples

>>> jit_kernel.export_sass("/tmp/kernel.sass")
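If you have a compiled CUDA binary (a cubin or a shared library with embedded device code), you can also disassemble SASS yourself with the CUDA toolkit's `cuobjdump --dump-sass`. This section does not say whether `show_sass()` uses `cuobjdump` internally; the helpers `sass_command` and `dump_sass` below are a hypothetical manual alternative.

```python
import shutil
import subprocess
from typing import List


def sass_command(binary_path: str) -> List[str]:
    """Build a cuobjdump invocation that dumps SASS from a CUDA binary."""
    return ["cuobjdump", "--dump-sass", binary_path]


def dump_sass(binary_path: str) -> str:
    """Run cuobjdump and return its SASS listing as text."""
    if shutil.which("cuobjdump") is None:
        raise RuntimeError("cuobjdump not found on PATH; install the CUDA toolkit")
    # check=True raises CalledProcessError if cuobjdump exits with an error.
    result = subprocess.run(sass_command(binary_path), capture_output=True, text=True, check=True)
    return result.stdout
```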