tilelang.jit
============

.. py:module:: tilelang.jit

.. autoapi-nested-parse::

   This module provides an auto-tuning infrastructure for TileLang (tl) programs.
   It includes functionality to JIT-compile TileLang programs into a runnable
   kernel adapter using TVM.


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/tilelang/jit/adapter/index
   /autoapi/tilelang/jit/env/index
   /autoapi/tilelang/jit/execution_backend/index
   /autoapi/tilelang/jit/kernel/index
   /autoapi/tilelang/jit/param/index


Attributes
----------

.. autoapisummary::

   tilelang.jit.logger
   tilelang.jit.ExecutionBackend


Classes
-------

.. autoapisummary::

   tilelang.jit.JITImpl


Functions
---------

.. autoapisummary::

   tilelang.jit.compile
   tilelang.jit.par_compile
   tilelang.jit.jit
   tilelang.jit.lazy_jit


Package Contents
----------------

.. py:data:: logger

.. py:function:: compile(func = None, out_idx = None, execution_backend = 'auto', target = 'auto', target_host = None, verbose = False, pass_configs = None, compile_flags = None)

   Compile the given TileLang PrimFunc with TVM and build a JITKernel.

   :param func: The TileLang TIR function to compile and wrap.
   :type func: tvm.tir.PrimFunc, optional
   :param out_idx: Index(es) of the output tensors to return (default: None).
   :type out_idx: Union[List[int], int], optional
   :param execution_backend: Execution backend to use for kernel execution. Use "auto" to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython).
   :type execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
   :param target: Compilation target, either as a string or a TVM Target object (default: "auto").
   :type target: Union[str, Target], optional
   :param target_host: Target host for cross-compilation (default: None).
   :type target_host: Union[str, Target], optional
   :param verbose: Whether to enable verbose output (default: False).
   :type verbose: bool, optional
   :param pass_configs: Additional keyword arguments to pass to the Compiler PassContext. Refer to `tilelang.transform.PassConfigKey` for supported options.
   :type pass_configs: dict, optional


.. py:function:: par_compile(funcs, out_idx = None, execution_backend = 'auto', target = 'auto', target_host = None, verbose = False, pass_configs = None, compile_flags = None, num_workers = None, ignore_error = False)

   Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.

   :param funcs: The TileLang TIR functions to compile and wrap.
   :type funcs: Iterable[tvm.tir.PrimFunc]
   :param out_idx: Index(es) of the output tensors to return (default: None).
   :type out_idx: Union[List[int], int], optional
   :param execution_backend: Execution backend to use for kernel execution. Use "auto" to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython).
   :type execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
   :param target: Compilation target, either as a string or a TVM Target object (default: "auto").
   :type target: Union[str, Target], optional
   :param target_host: Target host for cross-compilation (default: None).
   :type target_host: Union[str, Target], optional
   :param verbose: Whether to enable verbose output (default: False).
   :type verbose: bool, optional
   :param pass_configs: Additional keyword arguments to pass to the Compiler PassContext. Refer to `tilelang.transform.PassConfigKey` for supported options.
   :type pass_configs: dict, optional


.. py:class:: JITImpl

   Bases: :py:obj:`Generic`\ [\ :py:obj:`_P`\ , :py:obj:`_KP`\ , :py:obj:`_T`\ , :py:obj:`_Ret`\ ]

   Detailed Just-In-Time wrapper for TileLang programs.

   This dataclass encapsulates the configuration and runtime helpers used by the top-level `jit` and `lazy_jit` decorators.
   It represents a configured JIT "factory" that can (a) elaborate TileLang/PrimFunc creators into concrete TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via the TVM bridge, (c) cache compiled kernels keyed by call-site arguments (and optional tuning parameters), and (d) provide parallel compilation helpers for batch autotuning workflows.

   .. attribute:: out_idx

      Which output tensor(s) of the compiled kernel should be returned to the caller. Accepts a single index, a list of indices, or None to return all.

      :type: list[int] | int | None

   .. attribute:: execution_backend

      Backend used for exchanging arguments and executing the generated kernel.

      :type: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]

   .. attribute:: target

      TVM compilation target (e.g. "cuda", "llvm", or "auto").

      :type: str | tvm.target.Target

   .. attribute:: target_host

      Host target used for cross-compilation, or None to infer/default.

      :type: str | tvm.target.Target | None

   .. attribute:: verbose

      Enable verbose messages during compilation/build.

      :type: bool

   .. attribute:: pass_configs

      Extra TVM pass configuration options forwarded to the compiler's PassContext.

      :type: dict[str, Any] | None

   .. attribute:: debug_root_path

      If provided, compiled kernel source and the elaborated Python program are written to this directory to ease debugging and inspection.

      :type: str | None

   .. attribute:: compile_flags

      Additional flags passed to the compiler. A single string will be converted to a single-element list.

      :type: list[str] | str | None

   .. attribute:: func_source

      Original Python source string from which the PrimFunc or creator was derived. Used for diagnostics and debug dumps.

      :type: str

   .. attribute:: signature

      Function signature of the original Python function (useful for tooling).

      :type: inspect.Signature

   .. attribute:: v2

      Indicates whether the object wraps a "v2" PrimFunc creator (True) or a plain callable / PrimFunc (False). v2-mode enables argument conversion hooks and a distinct cache keying strategy.

      :type: bool

   .. attribute:: func

      The underlying object: either a user function that returns a PrimFunc (creator), a PrimFuncCreater, or an already-constructed PrimFunc. For presentation/readability the function is stored last in the dataclass.

      :type: Callable | PrimFunc | PrimFuncCreater

   .. rubric:: Behavioral summary

   - get_tir(\*args, \*\*kwargs): Converts provided call-site arguments into a concrete PrimFunc. If the wrapped object is a PrimFuncCreater or a user callable, it is invoked with the given arguments. If the wrapped object is already a PrimFunc, it is returned as-is.
   - compile(...): A convenience wrapper that elaborates and immediately compiles a single PrimFunc into a JITKernel using the module-level `compile` function. When `debug_root_path` is set, the compiled C kernel and the source Python program are saved for inspection.
   - par_compile(configs, ...): Accepts an iterable of configs (either dicts mapping keyword args or tuples mapping to positional args). Each config is elaborated to a PrimFunc and the resulting set is compiled in parallel via the module-level `par_compile` helper. Returns a list of JITKernel objects in the same order as the provided configs.


   .. py:attribute:: out_idx
      :type: list[int] | int | None

   .. py:attribute:: execution_backend
      :type: Literal['auto', 'dlpack', 'tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch']

   .. py:attribute:: target
      :type: str | tvm.target.Target

   .. py:attribute:: target_host
      :type: str | tvm.target.Target

   .. py:attribute:: verbose
      :type: bool

   .. py:attribute:: pass_configs
      :type: dict[str, Any] | None

   .. py:attribute:: debug_root_path
      :type: str | None

   .. py:attribute:: compile_flags
      :type: list[str] | str | None

   .. py:attribute:: func_source
      :type: str

   .. py:attribute:: signature
      :type: inspect.Signature

   .. py:attribute:: lazy_jit
      :type: bool

   .. py:attribute:: func
      :type: Callable[_P, _T] | tilelang.language.v2.PrimFunc[_KP, _T]

   .. py:property:: annot
      :type: dict[str, tilelang.language.v2.annot.Annot]

   .. py:method:: __post_init__()

   .. py:method:: get_tir(*args, **kwargs)

      Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.

   .. py:method:: par_compile(configs, num_workers = None, ignore_error = False)

      Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.

      :param configs: The configurations to elaborate and compile. Each config can be either a dictionary mapping keyword arguments to values, or a tuple of positional arguments.
      :type configs: Iterable[Union[dict[str, Any], tuple[Any, ...]]]
      :param num_workers: Number of parallel workers to use for compilation. Defaults to None, which lets the system decide.
      :type num_workers: int, optional
      :param ignore_error: If True, compilation errors for individual configs will be logged as warnings and the corresponding result will be None. If False, any compilation error will raise an exception. Defaults to False.
      :type ignore_error: bool, optional

      :returns: A list of compiled JITKernel objects corresponding to the provided configs.
      :rtype: List[JITKernel]

   .. py:method:: compile(*args, **kwargs)

   .. py:method:: parse_cache_key(*args, **kwargs)

   .. py:method:: convert_kernel_args(*args, **kwargs)

   .. py:method:: __call__(*args, **kwargs)


.. py:data:: ExecutionBackend

.. py:function:: jit(func: Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, kernel.JITKernel[_KP, _T]]
                 jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: ExecutionBackend = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) -> Callable[[Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, kernel.JITKernel[_KP, _T]]]

   Just-In-Time (JIT) compiler decorator for TileLang functions.
   This decorator can be used without arguments (e.g., `@tilelang.jit`), which applies JIT compilation with default settings, or with keyword arguments (e.g., `@tilelang.jit(...)`) to configure it.

   :param func_or_out_idx: If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. If using `@tilelang.jit` directly on a function, this argument is implicitly the function to be decorated (and `out_idx` will be `None`).
   :type func_or_out_idx: Any, optional
   :param target: Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
   :type target: Union[str, Target], optional
   :param target_host: Target host for cross-compilation. Defaults to None.
   :type target_host: Union[str, Target], optional
   :param execution_backend: Backend for kernel execution and argument passing. Use "auto" to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython).
   :type execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
   :param verbose: Enables verbose logging during compilation. Defaults to False.
   :type verbose: bool, optional
   :param pass_configs: Configurations for TVM's pass context. Defaults to None.
   :type pass_configs: Optional[Dict[str, Any]], optional
   :param debug_root_path: Directory to save compiled kernel source for debugging. Defaults to None.
   :type debug_root_path: Optional[str], optional

   :returns: Either a JIT-compiled wrapper around the input function, or a configured decorator instance that can then be applied to a function.
   :rtype: Callable


.. py:function:: lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]
                 lazy_jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: ExecutionBackend = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]
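
The two overloads of `jit` and `lazy_jit` listed above describe a decorator that can be applied bare or configured with keyword arguments, and that returns a wrapper object which caches results by call-site arguments. The pattern can be illustrated with a minimal, library-free sketch. The names `toy_jit`, `ToyJIT`, and `builds` are hypothetical and exist only for this illustration; the "compile" step is simulated by calling the wrapped function. This is not how `tilelang.jit` is implemented.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class ToyJIT:
    """Hypothetical stand-in for a configured JIT wrapper (not tilelang.jit.JITImpl)."""

    func: Callable[..., Any]
    out_idx: Any = None
    target: str = "auto"
    _cache: dict = field(default_factory=dict)
    builds: int = 0  # counts simulated compilations

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Key the cache on the call-site arguments (they must be hashable here).
        key = (args, tuple(sorted(kwargs.items())))
        if key not in self._cache:
            # Stand-in for "elaborate and compile": simply call the function once.
            self._cache[key] = self.func(*args, **kwargs)
            self.builds += 1
        return self._cache[key]


def toy_jit(
    func: Optional[Callable[..., Any]] = None,
    *,
    out_idx: Any = None,
    target: str = "auto",
):
    """Usable as `@toy_jit` or as `@toy_jit(out_idx=..., target=...)`."""

    def decorate(f: Callable[..., Any]) -> ToyJIT:
        return ToyJIT(func=f, out_idx=out_idx, target=target)

    # Bare use passes the function positionally; configured use passes only keywords.
    return decorate(func) if func is not None else decorate


@toy_jit
def make_add(n: int) -> str:
    return f"add_kernel[{n}]"


@toy_jit(out_idx=-1, target="cuda")
def make_mul(n: int) -> str:
    return f"mul_kernel[{n}]"
```

In this sketch, calling `make_add(4)` twice runs the wrapped function only once; the second call is served from the cache. Making the configuration keyword-only, as in the signatures documented above, is what lets the decorator tell bare use from configured use.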