tilelang.jit.kernel
===================

.. py:module:: tilelang.jit.kernel


Attributes
----------

.. autoapisummary::

   tilelang.jit.kernel.logger


Classes
-------

.. autoapisummary::

   tilelang.jit.kernel.JITKernel


Module Contents
---------------

.. py:data:: logger

.. py:class:: JITKernel(func = None, out_idx = None, execution_backend = 'tvm_ffi', target = 'auto', target_host = None, verbose = False, pass_configs = None, from_database = False, compile_flags = None)

   Bases: :py:obj:`Generic`\ [\ :py:obj:`_P`\ , :py:obj:`_T`\ ]


   A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

   .. attribute:: artifact

      The compiled artifact containing the runtime module and parameters.

      :type: CompiledArtifact

   .. attribute:: adapter

      The adapter for the compiled function.

      :type: BaseKernelAdapter

   .. attribute:: torch_function

      The compiled function that can be invoked as a PyTorch-compatible function.

      :type: Callable


   .. py:attribute:: prim_func
      :type: tvm.tir.PrimFunc
      :value: None


   .. py:attribute:: artifact
      :type: tilelang.engine.param.CompiledArtifact
      :value: None


   .. py:attribute:: adapter
      :type: tilelang.jit.adapter.BaseKernelAdapter
      :value: None


   .. py:attribute:: torch_function
      :type: Callable
      :value: None


   .. py:attribute:: latency
      :type: float
      :value: None


   .. py:attribute:: config
      :type: dict[str, Any]
      :value: None


   .. py:attribute:: ref_latency
      :type: float
      :value: None


   .. py:attribute:: execution_backend
      :value: 'tvm_ffi'


   .. py:attribute:: target_host
      :value: None


   .. py:attribute:: verbose
      :value: False


   .. py:attribute:: pass_configs
      :value: None


   .. py:attribute:: compile_flags
      :value: None


   .. py:attribute:: target


   .. py:method:: from_database(func, host_kernel_source, device_kernel_source, kernel_lib_path, params, target, target_host, out_idx, execution_backend, pass_configs = None, compile_flags = None)
      :classmethod:


      Alternative constructor to create a TorchFunction directly from a database.


   .. py:method:: __call__(*args, **kwds)

      Invokes the compiled function with the given arguments.

      :param \*args: Positional arguments for the function.
      :type \*args: Any
      :param \*\*kwds: Keyword arguments for the function.
      :type \*\*kwds: Any

      :returns: The result of the function execution.
      :rtype: Any


   .. py:method:: from_tilelang_function(tilelang_func, **kwargs)
      :classmethod:


      Alternative constructor to create a TorchFunction directly from a TileLang PrimFunc.

      :param tilelang_func: The TileLang (TVM TIR) function to compile.
      :type tilelang_func: tvm.tir.PrimFunc
      :param \*\*kwargs: Additional keyword arguments to pass to the constructor.
      :type \*\*kwargs: dict

      :returns: An instance of TorchFunction wrapping the compiled function.
      :rtype: TorchFunction


   .. py:method:: get_profiler(tensor_supply_type = TensorSupplyType.Auto)

      Creates a profiler to benchmark the compiled runtime module.

      :param tensor_supply_type: The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).
      :type tensor_supply_type: TensorSupplyType, optional

      :returns: A Profiler instance for benchmarking the runtime module.
      :rtype: Profiler


   .. py:method:: get_kernel_source(kernel_only = True)

      Returns the source code of the compiled kernel function.

      :returns: The source code of the compiled kernel function.
      :rtype: str


   .. py:method:: get_host_source()

      Returns the source code of the host function.


   .. py:method:: run_once(func = None)


   .. py:method:: show_source(which = 'kernel')

      Print generated source code to stdout.

      :param which: Select which source to print. Defaults to "kernel".
      :type which: Literal["kernel", "host", "both"], optional

      .. rubric:: Examples

      >>> jit_kernel.show_source()        # print kernel source
      >>> jit_kernel.show_source("host")  # print host source
      >>> jit_kernel.show_source("both")  # print both sources


   .. py:method:: export_sources(kernel_path = None, host_path = None)

      Export generated source code to files.

      :param kernel_path: Destination file path to write the kernel source.
                          If None, skips writing kernel code.
      :type kernel_path: Optional[str]
      :param host_path: Destination file path to write the host source.
                        If None, skips writing host code.
      :type host_path: Optional[str]

      .. rubric:: Examples

      >>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
      >>> jit_kernel.export_sources(host_path="/tmp/host.cc")
      >>> jit_kernel.export_sources(
      ...     kernel_path="/tmp/kernel.cu",
      ...     host_path="/tmp/host.cc",
      ... )


   .. py:method:: print_source_code(which = 'kernel', file = None)

      Deprecated: use show_source() or export_sources() instead.

      :param which: Kept for backward compatibility with printing behavior.
      :type which: Literal["kernel", "host", "both"], optional
      :param file: If provided, behaves like export_sources(kernel_path=file).
      :type file: Optional[str]

      .. rubric:: Examples

      >>> # New API (preferred)
      >>> jit_kernel.show_source("both")
      >>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")

      >>> # Old API (still works but deprecated)
      >>> jit_kernel.print_source_code(file="/tmp/kernel.cu")


   .. py:method:: update_tuner_result(latency, config, ref_latency)

      Updates the tuning results for this kernel.

      :param latency: The measured latency of this kernel configuration.
      :type latency: float
      :param config: The configuration parameters used for this kernel.
      :type config: Dict[str, Any]
      :param ref_latency: The reference latency to compare against.
      :type ref_latency: float

      :rtype: None


   .. py:method:: get_tuner_result()

      Gets the tuning results for this kernel.

      :returns: A dictionary containing:

                - latency: The measured latency of this kernel
                - config: The configuration parameters used
                - ref_latency: The reference latency for comparison
      :rtype: Dict[str, Any]


   .. py:property:: out_idx
      :type: list[int]


   .. py:property:: params
      :type: list[tilelang.engine.param.KernelParam]


   .. py:property:: kernel_source
      :type: str


   .. py:property:: host_source
      :type: str


   .. py:method:: export_library(kernel_file)

      Exports the compiled kernel function to a shared library file.

      :param kernel_file: The path to the shared library file to create.
      :type kernel_file: str


   .. py:method:: show_ptx()

      Print compiled PTX for the kernel (CUDA only).

      .. rubric:: Examples

      >>> jit_kernel.show_ptx()


   .. py:method:: export_ptx(path)

      Export compiled PTX to a file (CUDA only).

      :param path: Destination file path to write PTX.
      :type path: str

      .. rubric:: Examples

      >>> jit_kernel.export_ptx("/tmp/kernel.ptx")


   .. py:method:: show_sass()

      Print disassembled SASS for the kernel (CUDA only).

      .. rubric:: Examples

      >>> jit_kernel.show_sass()


   .. py:method:: export_sass(path)

      Export disassembled SASS to a file (CUDA only).

      :param path: Destination file path to write SASS.
      :type path: str

      .. rubric:: Examples

      >>> jit_kernel.export_sass("/tmp/kernel.sass")
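
The ``update_tuner_result`` and ``get_tuner_result`` entries describe a small store-and-read-back contract: three values go in, and a dictionary with the keys ``latency``, ``config`` and ``ref_latency`` comes out. The sketch below illustrates only that contract as documented here. ``TunerResultSketch`` is a hypothetical stand-in class, not part of TileLang and not its implementation, and the latency numbers and the ``block_M`` config key are made-up illustration data.

```python
from typing import Any, Dict, Optional


class TunerResultSketch:
    """Hypothetical stand-in for the tuner-result contract (NOT tilelang's JITKernel)."""

    def __init__(self) -> None:
        # The documented attributes all default to None until tuning results are stored.
        self.latency: Optional[float] = None
        self.config: Optional[Dict[str, Any]] = None
        self.ref_latency: Optional[float] = None

    def update_tuner_result(
        self, latency: float, config: Dict[str, Any], ref_latency: float
    ) -> None:
        # Store the three tuning values on the instance.
        self.latency = latency
        self.config = config
        self.ref_latency = ref_latency

    def get_tuner_result(self) -> Dict[str, Any]:
        # Return the stored values under the three documented keys.
        return {
            "latency": self.latency,
            "config": self.config,
            "ref_latency": self.ref_latency,
        }


# Usage with made-up numbers: record one tuning run, then read it back.
sketch = TunerResultSketch()
sketch.update_tuner_result(latency=0.12, config={"block_M": 128}, ref_latency=0.20)
result = sketch.get_tuner_result()
```

With a real ``JITKernel`` instance the same two calls would be made on the compiled kernel object. Whether the real method validates its inputs or returns anything beyond what is documented above is not stated in this reference.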