tilelang.jit.kernel
Attributes
- logger
Classes
- JITKernel: A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
Module Contents
- tilelang.jit.kernel.logger
- class tilelang.jit.kernel.JITKernel(func=None, out_idx=None, execution_backend='tvm_ffi', target='auto', target_host=None, verbose=False, pass_configs=None, from_database=False, compile_flags=None)
Bases: Generic[_P, _T]
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
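- Parameters:
func (tvm.tir.PrimFunc) -- The TileLang (TVM TIR) function to compile.
out_idx (list[int] | int) -- Index or indices of the function parameters treated as outputs.
execution_backend (Literal['tvm_ffi', 'cython', 'nvrtc', 'torch', 'cutedsl']) -- Backend used to execute the compiled kernel (default: 'tvm_ffi').
target (str | tvm.target.Target) -- Compilation target (default: 'auto').
target_host (str | tvm.target.Target) -- Target for the host-side code.
verbose (bool) -- Whether to print verbose output (default: False).
pass_configs (dict[str, Any] | None)
from_database (bool)
compile_flags (list[str] | None)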
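The out_idx parameter accepts either a single index or a list of indices. The following minimal sketch assumes Python-style negative indexing, so that -1 marks the last kernel parameter as an output. normalize_out_idx is a hypothetical helper written for illustration and is not a TileLang function.

```python
# Hypothetical helper: normalize_out_idx is NOT part of TileLang's API.
# Assumption: out_idx may be None, an int, or a list of ints, and negative
# indices count from the end of the parameter list (as in Python sequences).
from typing import Optional, Union


def normalize_out_idx(out_idx: Optional[Union[int, list[int]]], num_params: int) -> list[int]:
    """Resolve out_idx to non-negative parameter positions, keeping order."""
    if out_idx is None:
        return []  # no outputs: the caller supplies every tensor
    if isinstance(out_idx, int):
        out_idx = [out_idx]  # a single index is promoted to a one-element list
    resolved = []
    for idx in out_idx:
        pos = idx + num_params if idx < 0 else idx  # -1 -> last parameter
        if not 0 <= pos < num_params:
            raise IndexError(f"out_idx {idx} out of range for {num_params} params")
        resolved.append(pos)
    return resolved


# A GEMM-style kernel with parameters (A, B, C), where C is the output.
print(normalize_out_idx(-1, 3))       # [2]
print(normalize_out_idx([0, -1], 3))  # [0, 2]
```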
- artifact
The compiled artifact containing the runtime module and parameters.
- Type:
tilelang.engine.param.CompiledArtifact
- adapter
The adapter for the compiled function.
- Type:
tilelang.jit.adapter.BaseKernelAdapter
- torch_function
The compiled function that can be invoked as a PyTorch-compatible function.
- Type:
Callable
- prim_func: tvm.tir.PrimFunc = None
- artifact: tilelang.engine.param.CompiledArtifact = None
- adapter: tilelang.jit.adapter.BaseKernelAdapter = None
- torch_function: Callable = None
- latency: float = None
- config: dict[str, Any] = None
- ref_latency: float = None
- execution_backend = 'tvm_ffi'
- target_host = None
- verbose = False
- pass_configs = None
- compile_flags
- target
- classmethod from_database(func, host_kernel_source, device_kernel_source, kernel_lib_path, params, target, target_host, out_idx, execution_backend, pass_configs=None, compile_flags=None)
Alternative constructor that creates a JITKernel directly from a database.
- Parameters:
func (tvm.tir.PrimFunc)
host_kernel_source (str)
device_kernel_source (str)
kernel_lib_path (str)
params (list[tilelang.engine.param.KernelParam])
target (str | tvm.target.Target)
target_host (str | tvm.target.Target)
out_idx (list[int] | int)
execution_backend (Literal['tvm_ffi', 'cython', 'nvrtc', 'torch'])
pass_configs (dict[str, Any] | None)
compile_flags (list[str] | None)
- __call__(*args, **kwds)
Invokes the compiled function with the given arguments.
- Parameters:
*args (Any) -- Positional arguments for the function.
**kwds (Any) -- Keyword arguments for the function.
- Returns:
The result of the function execution.
- Return type:
Any
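A minimal sketch of how a call wrapper can interact with out_idx follows. It assumes, and this page does not state it, that parameters listed in out_idx are allocated by the wrapper instead of being passed by the caller, and that a single output is returned directly while several outputs come back as a list. OutputAllocatingCall, add_kernel, and the list-based "tensors" are hypothetical stand-ins, not TileLang's adapter.

```python
# Hypothetical sketch: OutputAllocatingCall is NOT TileLang's adapter.
# Assumptions: output slots (out_idx) are allocated by the wrapper, the kernel
# writes results in place, and outputs are returned after the launch.
from typing import Any, Callable


class OutputAllocatingCall:
    def __init__(self, kernel: Callable[..., Any], num_params: int,
                 out_idx: list[int], alloc: Callable[[int], Any]):
        self.kernel = kernel          # launches in place, writing into outputs
        self.num_params = num_params
        self.out_idx = out_idx        # assumed already non-negative
        self.alloc = alloc            # builds an empty output for a param index

    def __call__(self, *inputs: Any) -> Any:
        expected = self.num_params - len(self.out_idx)
        if len(inputs) != expected:
            raise TypeError(f"expected {expected} inputs, got {len(inputs)}")
        args, it = [], iter(inputs)
        for i in range(self.num_params):
            # Output slots get fresh buffers; input slots consume caller args in order.
            args.append(self.alloc(i) if i in self.out_idx else next(it))
        self.kernel(*args)
        outs = [args[i] for i in self.out_idx]
        return outs[0] if len(outs) == 1 else outs


# Toy "kernel" on Python lists: C[k] = A[k] + B[k].
def add_kernel(a, b, c):
    for k in range(len(c)):
        c[k] = a[k] + b[k]


add = OutputAllocatingCall(add_kernel, num_params=3, out_idx=[2], alloc=lambda i: [0] * 4)
print(add([1, 2, 3, 4], [10, 20, 30, 40]))  # [11, 22, 33, 44]
```

Under this design the caller passes only the inputs, which is why a GEMM kernel compiled with out_idx=-1 would be called with A and B alone.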
- classmethod from_tilelang_function(tilelang_func, **kwargs)
Alternative constructor that creates a JITKernel directly from a TileLang PrimFunc.
- Parameters:
tilelang_func (tvm.tir.PrimFunc) -- The TileLang (TVM TIR) function to compile.
**kwargs (dict) -- Additional keyword arguments to pass to the constructor.
- Returns:
A JITKernel instance wrapping the compiled function.
- Return type:
JITKernel
- get_profiler(tensor_supply_type=TensorSupplyType.Auto)
Creates a profiler to benchmark the compiled runtime module.
- Parameters:
tensor_supply_type (TensorSupplyType, optional) -- The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).
- Returns:
A Profiler instance for benchmarking the runtime module.
- Return type:
Profiler
- get_kernel_source(kernel_only=True)
Returns the source code of the compiled kernel function.
- Parameters:
kernel_only (bool)
- Returns:
The source code of the compiled kernel function.
- Return type:
str
- get_host_source()
Returns the source code of the host function.
- Return type:
str
- run_once(func=None)
- Parameters:
func (Callable | None)
- Return type:
None
- show_source(which='kernel')
Print the generated source code to stdout.
- Parameters:
which (Literal["kernel", "host", "both"], optional) -- Selects which source to print. Defaults to "kernel".
- Return type:
None
Examples
>>> jit_kernel.show_source()  # print kernel source
>>> jit_kernel.show_source("host")  # print host source
>>> jit_kernel.show_source("both")  # print both sources
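The which argument selects one of three outputs. The following minimal sketch of that dispatch validates the argument and builds the printed text. SourceHolder, render, and the "// kernel" / "// host" labels used for "both" are hypothetical choices for illustration, not TileLang's actual output format.

```python
# Hypothetical sketch of the `which` dispatch described for show_source().
# SourceHolder, render(), and the section labels are illustrative assumptions.
from typing import Literal

Which = Literal["kernel", "host", "both"]


class SourceHolder:
    def __init__(self, kernel_source: str, host_source: str):
        self.kernel_source = kernel_source
        self.host_source = host_source

    def render(self, which: Which = "kernel") -> str:
        """Return the text that show_source(which) would print (assumed layout)."""
        if which == "kernel":
            return self.kernel_source
        if which == "host":
            return self.host_source
        if which == "both":
            # Label each part so the two sources can be told apart on stdout.
            return f"// kernel\n{self.kernel_source}\n// host\n{self.host_source}"
        raise ValueError(f"which must be 'kernel', 'host' or 'both', got {which!r}")

    def show_source(self, which: Which = "kernel") -> None:
        print(self.render(which))


holder = SourceHolder("__global__ void k() {}", "int main() { return 0; }")
holder.show_source("host")  # int main() { return 0; }
```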
- export_sources(kernel_path=None, host_path=None)
Export the generated source code to files.
- Parameters:
kernel_path (Optional[str]) -- Destination file path for the kernel source. If None, the kernel source is not written.
host_path (Optional[str]) -- Destination file path for the host source. If None, the host source is not written.
- Return type:
None
Examples
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
...     kernel_path="/tmp/kernel.cu",
...     host_path="/tmp/host.cc",
... )
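The documented contract is that each source is written to its path and any path left as None is skipped. The following minimal sketch implements that contract as a standalone function that takes the sources as strings. Creating missing parent directories is an assumption of this sketch and is not stated on this page.

```python
# Hypothetical sketch of the export_sources() contract: write each source to
# its path and skip any path that is None. Standalone; not TileLang's code.
import tempfile
from pathlib import Path
from typing import Optional


def export_sources(kernel_source: str, host_source: str,
                   kernel_path: Optional[str] = None,
                   host_path: Optional[str] = None) -> None:
    for source, path in ((kernel_source, kernel_path), (host_source, host_path)):
        if path is None:
            continue  # None means "do not write this source"
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)  # assumption: create missing dirs
        out.write_text(source)


with tempfile.TemporaryDirectory() as tmp:
    export_sources("__global__ void k() {}", "int main() {}", kernel_path=f"{tmp}/kernel.cu")
    print(Path(f"{tmp}/kernel.cu").read_text())  # __global__ void k() {}
    print(Path(f"{tmp}/host.cc").exists())       # False
```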
- print_source_code(which='kernel', file=None)
Deprecated: use show_source() or export_sources() instead.
- Parameters:
which (Literal["kernel", "host", "both"], optional) -- Kept for backward compatibility with the old printing behavior.
file (Optional[str]) -- If provided, behaves like export_sources(kernel_path=file).
- Return type:
None
Examples
>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
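The following is a minimal sketch of a backward-compatible deprecation shim in the style described for print_source_code(). It maps file to export_sources(kernel_path=file), as documented. Forwarding which to show_source() when no file is given, and emitting a DeprecationWarning, are assumptions of this sketch. The stand-in functions only record calls so that the forwarding is visible.

```python
# Hypothetical deprecation shim; the stand-ins record calls instead of doing I/O.
# Assumptions: a DeprecationWarning is emitted and `which` is forwarded to show_source().
import warnings
from typing import Optional

calls: list[tuple] = []  # records which replacement API the shim forwarded to


def show_source(which: str = "kernel") -> None:
    calls.append(("show_source", which))  # stand-in for printing to stdout


def export_sources(kernel_path: Optional[str] = None, host_path: Optional[str] = None) -> None:
    calls.append(("export_sources", kernel_path, host_path))  # stand-in for file writes


def print_source_code(which: str = "kernel", file: Optional[str] = None) -> None:
    """Deprecated shim that forwards to the newer API."""
    warnings.warn(
        "print_source_code() is deprecated; use show_source() or export_sources()",
        DeprecationWarning,
        stacklevel=2,
    )
    if file is not None:
        export_sources(kernel_path=file)  # documented: file acts as kernel_path
    else:
        show_source(which)                # assumed: keep the old printing behavior


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure DeprecationWarning is not filtered out
    print_source_code(file="/tmp/kernel.cu")
    print_source_code("both")

print(calls)  # [('export_sources', '/tmp/kernel.cu', None), ('show_source', 'both')]
```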
- update_tuner_result(latency, config, ref_latency)
Updates the tuning results for this kernel.
- Parameters:
latency (float) -- The measured latency of this kernel configuration.
config (Dict[str, Any]) -- The configuration parameters used for this kernel.
ref_latency (float) -- The reference latency to compare against.
- Return type:
None
- get_tuner_result()
Gets the tuning results for this kernel.
- Returns:
A dictionary containing:
- latency: The measured latency of this kernel
- config: The configuration parameters used
- ref_latency: The reference latency for comparison
- Return type:
Dict[str, Any]
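update_tuner_result() stores three values and get_tuner_result() returns them as a dictionary. The following minimal sketch mirrors that round trip and derives a speedup ratio as ref_latency / latency. TunerResultHolder, the block_M/block_N keys, and the latency numbers are hypothetical. This page does not specify latency units, and reading the ratio as a speedup is only one way to use the reference value.

```python
# Hypothetical stand-in for the tuner fields on JITKernel; not TileLang's class.
from typing import Any, Optional


class TunerResultHolder:
    def __init__(self) -> None:
        self.latency: Optional[float] = None
        self.config: Optional[dict[str, Any]] = None
        self.ref_latency: Optional[float] = None

    def update_tuner_result(self, latency: float, config: dict[str, Any],
                            ref_latency: float) -> None:
        self.latency = latency
        self.config = config
        self.ref_latency = ref_latency

    def get_tuner_result(self) -> dict[str, Any]:
        # Same three keys as documented for get_tuner_result().
        return {"latency": self.latency, "config": self.config, "ref_latency": self.ref_latency}


holder = TunerResultHolder()
holder.update_tuner_result(0.42, {"block_M": 128, "block_N": 128}, 0.50)  # illustrative values
result = holder.get_tuner_result()
speedup = result["ref_latency"] / result["latency"]  # > 1 means faster than the reference
print(f"{speedup:.2f}x")  # 1.19x
```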
- property out_idx: list[int]
- Return type:
list[int]
- property params: list[tilelang.engine.param.KernelParam]
- Return type:
list[tilelang.engine.param.KernelParam]
- property kernel_source: str
- Return type:
str
- property host_source: str
- Return type:
str
- export_library(kernel_file)
Exports the compiled kernel function to a shared library file.
- Parameters:
kernel_file (str) -- The path of the shared library file to create.
- Return type:
None
- show_ptx()
Print the compiled PTX for the kernel (CUDA only).
- Return type:
None
Examples
>>> jit_kernel.show_ptx()
- export_ptx(path)
Export the compiled PTX to a file (CUDA only).
- Parameters:
path (str) -- Destination file path for the PTX.
- Return type:
None
Examples
>>> jit_kernel.export_ptx("/tmp/kernel.ptx")
- show_sass()
Print the disassembled SASS for the kernel (CUDA only).
- Return type:
None
Examples
>>> jit_kernel.show_sass()
- export_sass(path)
Export the disassembled SASS to a file (CUDA only).
- Parameters:
path (str) -- Destination file path for the SASS.
- Return type:
None
Examples
>>> jit_kernel.export_sass("/tmp/kernel.sass")