tilelang.jit
This module provides the just-in-time (JIT) compilation infrastructure for TileLang (tl) programs, including the parallel-compilation helpers used in auto-tuning workflows. It JIT-compiles TileLang programs into runnable kernel adapters using TVM.
Submodules
Attributes
logger
ExecutionBackend
Classes
JITImpl | Detailed Just-In-Time wrapper for TileLang programs.
Functions
compile | Compile the given TileLang PrimFunc with TVM and build a JITKernel.
par_compile | Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.
jit | Just-In-Time (JIT) compiler decorator for TileLang functions.
lazy_jit
Package Contents
- tilelang.jit.logger
- tilelang.jit.compile(func=None, out_idx=None, execution_backend='auto', target='auto', target_host=None, verbose=False, pass_configs=None, compile_flags=None)
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
- Parameters:
func (tilelang.language.v2.PrimFunc[_KP, _T], optional) – The TileLang TIR function to compile and wrap.
out_idx (list[int] | int | None) – Index(es) of the output tensors to return (default: None).
execution_backend (Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional) – Execution backend to use for kernel execution. Use “auto” to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython).
target (Union[str, Target], optional) – Compilation target, either as a string or a TVM Target object (default: “auto”).
target_host (Union[str, Target], optional) – Target host for cross-compilation (default: None).
verbose (bool, optional) – Whether to enable verbose output (default: False).
pass_configs (dict, optional) – Additional configuration options forwarded to the compiler’s PassContext. Refer to tilelang.transform.PassConfigKey for supported options.
compile_flags (list[str] | str | None) – Additional flags passed to the compiler. A single string is converted to a single-element list.
- Return type:
kernel.JITKernel[_KP, _T]
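The `execution_backend="auto"` rule described above, together with the `compile_flags` normalization noted for `JITImpl.compile_flags` (a single string becomes a single-element list), can be sketched in plain Python. This is a minimal illustration, not TileLang's implementation: `resolve_backend` and `normalize_compile_flags` are hypothetical helpers, and the sketch assumes the target kind is already available as a string such as "cuda" or "metal".

```python
from typing import Optional, Union

# Documented defaults for execution_backend="auto":
# cuda -> tvm_ffi, metal -> torch, every other target -> cython.
_AUTO_BACKENDS = {"cuda": "tvm_ffi", "metal": "torch"}


def resolve_backend(execution_backend: str, target_kind: str) -> str:
    """Hypothetical helper: map "auto" to a per-target default backend."""
    if execution_backend != "auto":
        return execution_backend  # an explicit choice is kept unchanged
    return _AUTO_BACKENDS.get(target_kind, "cython")


def normalize_compile_flags(
    flags: Optional[Union[list[str], str]],
) -> Optional[list[str]]:
    """Hypothetical helper: wrap a single flag string in a one-element list."""
    if isinstance(flags, str):
        return [flags]
    return flags  # lists and None pass through untouched


print(resolve_backend("auto", "cuda"))    # tvm_ffi
print(resolve_backend("auto", "llvm"))    # cython
print(resolve_backend("nvrtc", "cuda"))   # nvrtc
print(normalize_compile_flags("-O3"))     # ['-O3']
```

Keeping an explicit backend untouched means "auto" is the only value that depends on the target.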
- tilelang.jit.par_compile(funcs, out_idx=None, execution_backend='auto', target='auto', target_host=None, verbose=False, pass_configs=None, compile_flags=None, num_workers=None, ignore_error=False)
Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.
- Parameters:
funcs (collections.abc.Iterable[tilelang.language.v2.PrimFunc[_KP, _T]]) – The TileLang TIR functions to compile and wrap.
out_idx (list[int] | int | None) – Index(es) of the output tensors to return (default: None).
execution_backend (Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional) – Execution backend to use for kernel execution. Use “auto” to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython).
target (Union[str, Target], optional) – Compilation target, either as a string or a TVM Target object (default: “auto”).
target_host (Union[str, Target], optional) – Target host for cross-compilation (default: None).
verbose (bool, optional) – Whether to enable verbose output (default: False).
pass_configs (dict, optional) – Additional configuration options forwarded to the compiler’s PassContext. Refer to tilelang.transform.PassConfigKey for supported options.
compile_flags (list[str] | str | None) – Additional flags passed to the compiler. A single string is converted to a single-element list.
num_workers (int, optional) – Number of parallel workers to use for compilation (default: None, which lets the system decide).
ignore_error (bool, optional) – If True, compilation errors for individual functions are logged as warnings and the corresponding result is None. If False, any compilation error raises an exception (default: False).
- Return type:
list[kernel.JITKernel[_KP, _T]]
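The ordering and error-handling contract of parallel compilation can be sketched with the standard library's `concurrent.futures`. This is a minimal sketch, not TileLang's implementation: it assumes the module-level helper follows the `ignore_error` contract documented for `JITImpl.par_compile`, and `par_compile_sketch`, `compile_one`, and `fake_compile` are hypothetical names standing in for compiling one PrimFunc.

```python
import logging
from collections.abc import Callable, Iterable
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

logger = logging.getLogger("par_compile_sketch")


def par_compile_sketch(
    funcs: Iterable[Any],
    compile_one: Callable[[Any], Any],
    num_workers: Optional[int] = None,
    ignore_error: bool = False,
) -> list[Any]:
    """Compile every item in parallel and return results in input order.

    With ignore_error=True a failing item is logged and yields None;
    otherwise the first failure (in input order) is re-raised.
    """
    funcs = list(funcs)
    # max_workers=None lets ThreadPoolExecutor choose its default pool size.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(compile_one, f) for f in funcs]
        results: list[Any] = []
        # Iterate futures in submission order, so output order matches input.
        for func, fut in zip(funcs, futures):
            try:
                results.append(fut.result())
            except Exception as err:
                if not ignore_error:
                    raise
                logger.warning("compilation failed for %r: %s", func, err)
                results.append(None)
    return results


def fake_compile(n):
    # Hypothetical stand-in for compiling one PrimFunc.
    if n < 0:
        raise ValueError("negative size")
    return f"kernel_{n}"


print(par_compile_sketch([1, -2, 3], fake_compile, ignore_error=True))
# ['kernel_1', None, 'kernel_3']
```

Collecting results by walking the futures list, rather than with `as_completed`, is what keeps the returned list aligned with the inputs.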
- class tilelang.jit.JITImpl
Bases: Generic[_P, _KP, _T, _Ret]
Detailed Just-In-Time wrapper for TileLang programs.
This dataclass encapsulates the configuration and runtime helpers used by the top-level jit and jit2 decorators. It represents a configured JIT “factory” that can (a) elaborate TileLang/PrimFunc creators into concrete TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via the TVM bridge, (c) cache compiled kernels keyed by call-site arguments (and optional tuning parameters), and (d) provide parallel compilation helpers for batch autotuning workflows.
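Point (c) above, caching compiled kernels keyed by call-site arguments, can be sketched as a small memoizing wrapper. This is an illustration under stated assumptions, not the JITImpl implementation: `CachedJIT` and `build_kernel` are hypothetical, and normalizing arguments with `inspect.signature(...).bind` plus `apply_defaults` is an assumed keying strategy; `JITImpl.parse_cache_key` may key arguments differently.

```python
import inspect
from collections.abc import Callable, Hashable
from typing import Any


class CachedJIT:
    """Hypothetical sketch: memoize built kernels keyed by call-site arguments."""

    def __init__(self, build: Callable[..., Any]) -> None:
        self.build = build  # stands in for elaborate + compile
        self.signature = inspect.signature(build)
        self._cache: dict[Hashable, Any] = {}

    def cache_key(self, *args: Any, **kwargs: Any) -> Hashable:
        # Bind to the signature and fill in defaults so that positional,
        # keyword, and omitted-default spellings of one call share a key.
        bound = self.signature.bind(*args, **kwargs)
        bound.apply_defaults()
        return tuple(bound.arguments.items())

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        key = self.cache_key(*args, **kwargs)
        if key not in self._cache:
            self._cache[key] = self.build(*args, **kwargs)  # miss: build once
        return self._cache[key]


builds = []


def build_kernel(M, N, block_M=128):
    builds.append((M, N, block_M))  # record each real build
    return f"kernel<{M}x{N}, block_M={block_M}>"


jit_matmul = CachedJIT(build_kernel)
jit_matmul(1024, 1024)               # miss: builds
jit_matmul(1024, 1024)               # hit
jit_matmul(1024, 1024, block_M=128)  # hit: same key once defaults are applied
jit_matmul(1024, 1024, block_M=64)   # miss: builds
print(len(builds))                   # 2
```

Binding to the signature before hashing avoids rebuilding the same kernel just because a caller spelled an argument differently.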
- out_idx
Which output tensor(s) of the compiled kernel should be returned to the caller. Accepts a single index, a list of indices, or None to return all.
- Type:
list[int] | int | None
- execution_backend
Backend used for exchanging arguments and executing the generated kernel.
- Type:
Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
- target
TVM compilation target (e.g. “cuda”, “llvm”, or “auto”).
- Type:
str | tvm.target.Target
- target_host
Host target used for cross-compilation, or None to infer a default.
- Type:
str | tvm.target.Target | None
- pass_configs
Extra TVM pass configuration options forwarded to the compiler’s PassContext.
- Type:
dict[str, Any] | None
- debug_root_path
If provided, compiled kernel source and the elaborated Python program are written to this directory to ease debugging and inspection.
- Type:
str | None
- compile_flags
Additional flags passed to the compiler. A single string is converted to a single-element list.
- Type:
list[str] | str | None
- func_source
Original Python source string from which the PrimFunc or creator was derived. Used for diagnostics and debug dumps.
- Type:
str
- signature
Function signature of the original Python function (useful for tooling).
- Type:
inspect.Signature
- v2
Indicates whether the object wraps a “v2” PrimFunc creator (True) or a plain callable / PrimFunc (False). v2 mode enables argument conversion hooks and a distinct cache-keying strategy.
- Type:
bool
- func
The underlying object: either a user function that returns a PrimFunc (creator), a PrimFuncCreater, or an already-constructed PrimFunc. For presentation and readability, the function is stored last in the dataclass.
- Type:
Callable | PrimFunc | PrimFuncCreater
Behavioral summary
- get_tir(*args, **kwargs)
Converts the provided call-site arguments into a concrete PrimFunc. If the wrapped object is a PrimFuncCreater or a user callable, it is invoked with the given arguments; if it is already a PrimFunc, it is returned as-is.
- compile(...)
A convenience wrapper that elaborates a single PrimFunc and immediately compiles it into a JITKernel using the module-level compile function. When debug_root_path is set, the compiled C kernel and the source Python program are saved for inspection.
- par_compile(configs, ...)
Accepts an iterable of configs, each either a dict of keyword arguments or a tuple of positional arguments. Each config is elaborated into a PrimFunc, and the resulting set is compiled in parallel via the module-level par_compile helper. Returns a list of JITKernel objects in the same order as the provided configs.
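The get_tir dispatch and the config handling of par_compile summarized above can be sketched as follows. This is a hypothetical illustration that does not use TileLang: `PrimFuncStub`, `get_tir_sketch`, `elaborate_configs`, and `make_gemm` are invented names, and raising `TypeError` for other config types is a choice of this sketch rather than documented behavior.

```python
from collections.abc import Iterable
from typing import Any, Union

# A config is either keyword arguments (dict) or positional arguments (tuple).
Config = Union[dict[str, Any], tuple[Any, ...]]


class PrimFuncStub:
    """Hypothetical stand-in for an already-constructed PrimFunc."""

    def __init__(self, name: str) -> None:
        self.name = name


def get_tir_sketch(wrapped: Any, *args: Any, **kwargs: Any) -> PrimFuncStub:
    # An already-built PrimFunc is returned as-is; a creator is called.
    if isinstance(wrapped, PrimFuncStub):
        return wrapped
    return wrapped(*args, **kwargs)


def elaborate_configs(wrapped: Any, configs: Iterable[Config]) -> list[PrimFuncStub]:
    """Turn each config into a PrimFunc, preserving the input order."""
    funcs = []
    for cfg in configs:
        if isinstance(cfg, dict):
            funcs.append(get_tir_sketch(wrapped, **cfg))  # dict -> keyword args
        elif isinstance(cfg, tuple):
            funcs.append(get_tir_sketch(wrapped, *cfg))   # tuple -> positional args
        else:
            raise TypeError(f"config must be a dict or tuple, not {type(cfg).__name__}")
    return funcs


def make_gemm(block_M, block_N=64):
    # Hypothetical creator: returns a "PrimFunc" for one tile configuration.
    return PrimFuncStub(f"gemm_{block_M}x{block_N}")


funcs = elaborate_configs(make_gemm, [{"block_M": 128, "block_N": 128}, (64,)])
print([f.name for f in funcs])  # ['gemm_128x128', 'gemm_64x64']
```

The elaborated list would then be handed to a parallel compile step, which is why the order of `configs` is preserved here.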
- out_idx: list[int] | int | None
- execution_backend: Literal['auto', 'dlpack', 'tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch']
- target: str | tvm.target.Target
- target_host: str | tvm.target.Target
- pass_configs: dict[str, Any] | None
- debug_root_path: str | None
- compile_flags: list[str] | str | None
- func_source: str
- signature: inspect.Signature
- func: Callable[_P, _T] | tilelang.language.v2.PrimFunc[_KP, _T]
- property annot: dict[str, tilelang.language.v2.annot.Annot]
- Return type:
dict[str, tilelang.language.v2.annot.Annot]
- __post_init__()
- get_tir(*args, **kwargs)
Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.
- Parameters:
args (_P)
kwargs (_P)
- Return type:
tilelang.language.v2.PrimFunc[_KP, _T]
- par_compile(configs, num_workers=None, ignore_error=False)
Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.
- Parameters:
configs (Iterable[Union[dict[str, Any], tuple[Any, ...]]]) – The configurations to elaborate and compile. Each config can be either a dictionary mapping keyword arguments to values, or a tuple of positional arguments.
num_workers (int, optional) – Number of parallel workers to use for compilation. Defaults to None, which lets the system decide.
ignore_error (bool, optional) – If True, compilation errors for individual configs are logged as warnings and the corresponding result is None. If False, any compilation error raises an exception. Defaults to False.
- Returns:
A list of compiled JITKernel objects corresponding to the provided configs.
- Return type:
List[JITKernel]
- compile(*args, **kwargs)
- Parameters:
args (_P)
kwargs (_P)
- Return type:
_Ret
- parse_cache_key(*args, **kwargs)
- Parameters:
args (_P)
kwargs (_P)
- convert_kernel_args(*args, **kwargs)
- Parameters:
args (_P)
kwargs (_P)
- __call__(*args, **kwargs)
- Parameters:
args (_P)
kwargs (_P)
- Return type:
_Ret
- tilelang.jit.ExecutionBackend
- tilelang.jit.jit(func: Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]) → JITImpl[_P, _KP, _T, kernel.JITKernel[_KP, _T]]
- tilelang.jit.jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: ExecutionBackend = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) → Callable[[Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, kernel.JITKernel[_KP, _T]]]
Just-In-Time (JIT) compiler decorator for TileLang functions.
- This decorator can be used without arguments (e.g., @tilelang.jit), which applies JIT compilation with default settings, or called with keyword arguments (e.g., @tilelang.jit(target="cuda")) to configure compilation.
- Parameters:
func (Callable, optional) – The function to decorate when the decorator is applied directly as @tilelang.jit. When the decorator is called with keyword arguments, func is omitted and a configured decorator is returned.
out_idx (Any, optional) – Index(es) of the output tensors to return. Defaults to None.
target (Union[str, Target], optional) – Compilation target for TVM (e.g., “cuda”, “llvm”). Defaults to “auto”.
target_host (Union[str, Target], optional) – Target host for cross-compilation. Defaults to None.
execution_backend (Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional) – Backend for kernel execution and argument passing. Use “auto” to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython).
verbose (bool, optional) – Enables verbose logging during compilation. Defaults to False.
pass_configs (Optional[Dict[str, Any]], optional) – Configurations for TVM’s pass context. Defaults to None.
debug_root_path (Optional[str], optional) – Directory to save compiled kernel source for debugging. Defaults to None.
compile_flags (list[str] | str | None, optional) – Additional flags passed to the compiler. A single string is converted to a single-element list. Defaults to None.
- Returns:
Either a JIT-compiled wrapper around the input function, or a configured decorator instance that can then be applied to a function.
- Return type:
Callable
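The two overloads of tilelang.jit shown above follow the common "bare or configured" decorator pattern. The sketch below reproduces only that pattern in plain Python; `jit_sketch`, `add_kernel`, and `mul_kernel` are hypothetical, and the wrapper merely reports the configuration instead of elaborating and compiling a PrimFunc.

```python
import functools
from collections.abc import Callable
from typing import Any, Optional


def jit_sketch(
    func: Optional[Callable[..., Any]] = None,
    *,
    out_idx: Any = None,
    target: str = "auto",
) -> Any:
    """Hypothetical decorator usable bare (@jit_sketch) or configured (@jit_sketch(...))."""

    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # A real JIT would elaborate and compile here; the sketch only
            # reports the configuration it would compile with.
            return {"kernel": f(*args, **kwargs), "out_idx": out_idx, "target": target}

        return wrapper

    if func is not None:
        return decorate(func)  # used as @jit_sketch
    return decorate            # used as @jit_sketch(out_idx=..., target=...)


@jit_sketch
def add_kernel(n):
    return f"add[{n}]"


@jit_sketch(out_idx=[-1], target="cuda")
def mul_kernel(n):
    return f"mul[{n}]"


print(add_kernel(8))  # {'kernel': 'add[8]', 'out_idx': None, 'target': 'auto'}
print(mul_kernel(8))  # {'kernel': 'mul[8]', 'out_idx': [-1], 'target': 'cuda'}
```

Making every option keyword-only, as the second overload does, is what lets the decorator tell a bare function apart from a configuration call.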
- tilelang.jit.lazy_jit(func: Callable[_KP, _T]) → JITImpl[_KP, _KP, _T, _T]
- tilelang.jit.lazy_jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: ExecutionBackend = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) → Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]