tilelang.jit.adapter.cutedsl.wrapper
====================================
.. py:module:: tilelang.jit.adapter.cutedsl.wrapper
.. autoapi-nested-parse::
CuTeDSL Source Wrapper for TileLang.
This module provides C++ kernel launcher generation for the CuTeDSL backend.
Key features:
- Automatic C++ launcher generation with CUDA Driver API
- TMA descriptors on HOST memory, passed via __grid_constant__ (no device copy needed)
- cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space
- Support for single and multiple kernel launches
- Complete cache system integration
Attributes
----------
.. autoapisummary::
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_DESC_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_INIT_FUNC_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_LAUNCH_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_LAUNCH_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_LAUNCHER_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_TMA_ATOM_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_KERNEL_LAUNCH_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_FAKE_TENSOR_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_GEN_CODE_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.PYTHON_HOST_FUNC_TEMPLATE
Classes
-------
.. autoapisummary::
tilelang.jit.adapter.cutedsl.wrapper.TLCuTeDSLSourceWrapper
Module Contents
---------------
.. py:data:: CPP_TMA_DESC_INIT_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
""" // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name})
{{
uint64_t globalDim[{rank}] = {{{global_dim_values}}};
uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
uint32_t boxDim[{rank}] = {{{box_dim_values}}};
uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
result = cuTensorMapEncodeTiled(
&tma_descs[{desc_idx}],
static_cast<CUtensorMapDataType>({dtype}),
{rank},
reinterpret_cast<void*>({tensor_name}_ptr),
globalDim,
globalStrides,
boxDim,
elemStrides,
static_cast<CUtensorMapInterleave>({interleave}),
static_cast<CUtensorMapSwizzle>({swizzle}),
static_cast<CUtensorMapL2promotion>({l2_promotion}),
static_cast<CUtensorMapFloatOOBfill>({oob_fill})
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\n";
return result;
}}
}}
"""
.. raw:: html

   </details>
.. py:data:: CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
""" // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col]
{{
uint64_t globalDim[{rank}] = {{{global_dim_values}}};
uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};
int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};
result = cuTensorMapEncodeIm2col(
&tma_descs[{desc_idx}],
static_cast<CUtensorMapDataType>({dtype}),
{rank},
reinterpret_cast<void*>({tensor_name}_ptr),
globalDim,
globalStrides,
lowerCorner,
upperCorner,
static_cast<cuuint32_t>({channels_per_pixel}),
static_cast<cuuint32_t>({pixels_per_column}),
elemStrides,
static_cast<CUtensorMapInterleave>({interleave}),
static_cast<CUtensorMapSwizzle>({swizzle}),
static_cast<CUtensorMapL2promotion>({l2_promotion}),
static_cast<CUtensorMapFloatOOBfill>({oob_fill})
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\n";
return result;
}}
}}
"""
.. raw:: html

   </details>
.. py:data:: CPP_TMA_INIT_FUNC_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
"""CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{
// Initialize {num_descs} TMA descriptor(s) in caller-provided host array
// cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically
CUresult result;
{desc_init_code}
return CUDA_SUCCESS;
}}
"""
.. raw:: html

   </details>
.. py:data:: CPP_KERNEL_INIT_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
""" // Find and configure kernel {kernel_idx}: {kernel_name}
result = find_kernel_by_pattern(module, "{kernel_name}", &kernels[{kernel_idx}]);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to find kernel {kernel_name} on device " << device_id << ": " << result << "\n";
return result;
}}
if ({smem_size} > 0) {{
result = cuFuncSetAttribute(kernels[{kernel_idx}],
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
{smem_size});
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to set smem for {kernel_name} on device " << device_id << ": " << result << "\n";
return result;
}}
}}
"""
.. raw:: html

   </details>
.. py:data:: CPP_TMA_LAUNCH_INIT_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
""" // Declare stack-local TMA descriptor array (eliminates concurrency race)
CUtensorMap tma_descs[{num_tma_descs}];
// Initialize TMA descriptors (HOST memory - passed via __grid_constant__)
// NOTE: We intentionally do NOT reuse/cached descriptors across launches.
// Pointer-only reuse is a correctness trap (shape/stride may change with same ptr),
// and correctness beats micro-optimizations.
result = tma_init(tma_descs, {tma_tensor_args});
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to initialize TMA descriptors: " << result << "\n";
return result;
}}
"""
.. raw:: html

   </details>
.. py:data:: CPP_KERNEL_LAUNCH_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
""" // Launch kernel {kernel_idx}: {kernel_name}
{{
// Get the kernel for current device
auto kernels_it = g_device_kernels.find(device_id);
if (kernels_it == g_device_kernels.end()) {{
std::cerr << "Kernels not initialized for device " << device_id << "\n";
return CUDA_ERROR_NOT_INITIALIZED;
}}
const std::vector<CUfunction>& kernels = kernels_it->second;
void* args[] = {{{kernel_args}}};
result = cuLaunchKernel(
kernels[{kernel_idx}],
{grid_x}, {grid_y}, {grid_z},
{block_x}, {block_y}, {block_z},
{smem_size},
stream,
args,
nullptr
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to launch kernel {kernel_name} on device " << device_id << ": " << result << "\n";
return result;
}}
}}
"""
.. raw:: html

   </details>
.. py:data:: CPP_LAUNCHER_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
"""#include
#include
#include
#include
#include
#include
#include
#include
#include
// TVM Headers
#include
#include
#include
// Per-device module and kernel storage for multi-GPU support
// Each device needs its own CUmodule because modules are tied to CUDA contexts
static std::unordered_map<int, CUmodule> g_device_modules;
static std::unordered_map<int, std::vector<CUfunction>> g_device_kernels;
static std::unordered_map<int, CUcontext> g_device_contexts; // Track retained contexts for cleanup
static std::mutex g_devices_mutex;
// Find kernel by pattern (substring match, prefer base name over _N variants)
CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{
CUresult result;
unsigned int num_funcs = 0;
result = cuModuleGetFunctionCount(&num_funcs, module);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to get function count: " << result << "\n";
return result;
}}
std::vector<CUfunction> func_list(num_funcs);
result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to enumerate functions: " << result << "\n";
return result;
}}
// Collect substring matches, separating base name from _N variants
std::vector<std::pair<std::string, CUfunction>> base_matches; // pattern not followed by _digit
std::vector<std::pair<std::string, CUfunction>> variant_matches; // pattern followed by _digit
size_t pattern_len = std::strlen(pattern);
for (unsigned int i = 0; i < num_funcs; i++) {{
const char* func_name = nullptr;
result = cuFuncGetName(&func_name, func_list[i]);
if (result != CUDA_SUCCESS || func_name == nullptr) {{
std::cerr << "Failed to get function name: " << result << "\n";
return result;
}}
std::string name_str(func_name);
size_t pos = name_str.find(pattern);
if (pos != std::string::npos) {{
// Found substring match
size_t after_pattern = pos + pattern_len;
// Check what follows the pattern
if (after_pattern < name_str.length() &&
name_str[after_pattern] == '_' &&
after_pattern + 1 < name_str.length() &&
std::isdigit(name_str[after_pattern + 1])) {{
// Pattern followed by _digit (e.g., "main_kernel_1")
variant_matches.push_back({{name_str, func_list[i]}});
}} else {{
// Pattern not followed by _digit (e.g., "main_kernel" itself)
base_matches.push_back({{name_str, func_list[i]}});
}}
}}
}}
// Decision logic: prefer base matches over variant matches
if (!base_matches.empty()) {{
if (base_matches.size() == 1) {{
*out_func = base_matches[0].second;
return CUDA_SUCCESS;
}}
// Multiple base matches - ambiguous
std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size()
<< " base kernels (ambiguous). Matches found:\n";
for (const auto& match : base_matches) {{
std::cerr << " - " << match.first << "\n";
}}
std::cerr << "Please use a more specific pattern.\n";
return CUDA_ERROR_NOT_FOUND;
}}
// No base matches, try variant matches
if (!variant_matches.empty()) {{
if (variant_matches.size() == 1) {{
*out_func = variant_matches[0].second;
return CUDA_SUCCESS;
}}
// Multiple variant matches - ambiguous
std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size()
<< " variant kernels (ambiguous). Matches found:\n";
for (const auto& match : variant_matches) {{
std::cerr << " - " << match.first << "\n";
}}
std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\n";
return CUDA_ERROR_NOT_FOUND;
}}
// No matches at all
std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\n";
return CUDA_ERROR_NOT_FOUND;
}}
// Initialize CUDA module for a specific device (called once per device)
// Thread-safe and supports multi-GPU by tracking modules per device
// device_id: PyTorch CUDA device ID (e.g., 0, 1, 2...)
static CUresult tilelang_init_cuda_module(const std::string& cubin_path, int device_id) {{
std::lock_guard<std::mutex> lock(g_devices_mutex);
// Fast path: module already initialized for this device
if (g_device_modules.find(device_id) != g_device_modules.end()) {{
return CUDA_SUCCESS;
}}
CUresult result;
result = cuInit(0);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to initialize CUDA: " << result << "\n";
return result;
}}
// Get device handle for the requested device_id
CUdevice device;
result = cuDeviceGet(&device, device_id);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to get CUDA device " << device_id << ": " << result << "\n";
return result;
}}
// Retain and set the primary context for this device
// PyTorch (Runtime API) creates and activates the primary context
// We need to explicitly acquire it via Driver API and set it as current
CUcontext ctx;
result = cuDevicePrimaryCtxRetain(&ctx, device);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to retain primary context for device " << device_id << ": " << result << "\n";
return result;
}}
result = cuCtxSetCurrent(ctx);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to set current context for device " << device_id << ": " << result << "\n";
return result;
}}
// Store the retained context for cleanup
g_device_contexts[device_id] = ctx;
// Read cubin file
std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary);
if (!cubin_file) {{
std::cerr << "Failed to open cubin file: " << cubin_path << "\n";
return CUDA_ERROR_FILE_NOT_FOUND;
}}
std::vector<char> cubin_data((std::istreambuf_iterator<char>(cubin_file)),
std::istreambuf_iterator<char>());
cubin_file.close();
if (cubin_data.empty()) {{
std::cerr << "Empty cubin file: " << cubin_path << "\n";
return CUDA_ERROR_INVALID_IMAGE;
}}
// Load module for this specific device
CUmodule module;
result = cuModuleLoadData(&module, cubin_data.data());
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to load CUDA module on device " << device_id << ": " << result << "\n";
return result;
}}
// Store module for this device
g_device_modules[device_id] = module;
return CUDA_SUCCESS;
}}
// Initialize kernel functions for a specific device (called once per device)
// Thread-safe and supports multi-GPU by tracking kernels per device
static CUresult tilelang_init_kernels(int device_id) {{
std::lock_guard<std::mutex> lock(g_devices_mutex);
// Fast path: kernels already initialized for this device
if (g_device_kernels.find(device_id) != g_device_kernels.end()) {{
return CUDA_SUCCESS;
}}
// Get the module for this device
auto module_it = g_device_modules.find(device_id);
if (module_it == g_device_modules.end()) {{
std::cerr << "Module not initialized for device " << device_id << "\n";
return CUDA_ERROR_NOT_INITIALIZED;
}}
CUmodule module = module_it->second;
// Initialize kernel storage for this device
std::vector<CUfunction> kernels({num_kernels});
CUresult result;
{kernel_inits}
// Store kernels for this device
g_device_kernels[device_id] = kernels;
return CUDA_SUCCESS;
}}
// TMA descriptor initialization (host-side)
{tma_init_func}
// Main kernel launcher
extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, int device_id, tvm::ffi::Bytes cubin_path) {{
CUresult result;
std::string cubin_path_str(reinterpret_cast<const char*>(cubin_path.data()), cubin_path.size());
result = tilelang_init_cuda_module(cubin_path_str, device_id);
if (result != CUDA_SUCCESS) return result;
result = tilelang_init_kernels(device_id);
if (result != CUDA_SUCCESS) return result;
{get_ptr_code}
CUstream stream = (CUstream)_stream;
{tma_init_in_launch}
{kernel_launches}
return CUDA_SUCCESS;
}}
// Cleanup function
extern "C" CUresult cleanup_module() {{
std::lock_guard<std::mutex> lock(g_devices_mutex);
CUresult last_error = CUDA_SUCCESS;
// Step 1: Unload modules for all devices
for (auto& pair : g_device_modules) {{
if (pair.second != nullptr) {{
CUresult result = cuModuleUnload(pair.second);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to unload module for device " << pair.first
<< ": " << result << "\n";
last_error = result;
// Continue cleanup even if unload fails
}}
}}
}}
// Step 2: Release primary contexts (must execute even if module unload failed)
// This ensures the reference count is decremented for every cuDevicePrimaryCtxRetain
for (auto& pair : g_device_contexts) {{
int device_id = pair.first;
CUcontext ctx = pair.second;
if (ctx != nullptr) {{
CUdevice device;
CUresult result = cuDeviceGet(&device, device_id);
if (result == CUDA_SUCCESS) {{
result = cuDevicePrimaryCtxRelease(device);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to release primary context for device "
<< device_id << ": " << result << "\n";
last_error = result;
}}
}} else {{
std::cerr << "Failed to get device " << device_id
<< " for context release: " << result << "\n";
last_error = result;
}}
}}
}}
// Step 3: Clear all maps
g_device_modules.clear();
g_device_kernels.clear();
g_device_contexts.clear();
return last_error;
}}
TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module);
"""
.. raw:: html

   </details>
.. py:data:: CUBIN_TMA_ATOM_INIT_TEMPLATE
:value: ' {desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))'
.. py:data:: CUBIN_KERNEL_LAUNCH_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
""" {function_name}({call_args}).launch(
grid=[{grid_x}, {grid_y}, {grid_z}],
block=[{block_x}, {block_y}, {block_z}],
smem={smem_size},
stream=stream,
)"""
.. raw:: html

   </details>
.. py:data:: CUBIN_FAKE_TENSOR_TEMPLATE
:value: ' __fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)],...
.. py:data:: CUBIN_GEN_CODE_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
"""{lib_code}
@cute.jit
def kernel_wrapper({wrapper_args}):
{tma_init_code}{kernel_launches}
# Compile kernels to generate cubin
{fake_tensor_code}{fake_tma_tensor_code} __fake_stream__ = make_fake_stream()
# Always generate cubin under a unique staging directory to avoid concurrent
# processes clobbering each other's intermediate artifacts.
_staging_dir = Path(tempfile.mkdtemp(
prefix=Path(__file__).stem + ".cubin.staging.",
dir=_module_dir,
))
try:
_kernel_wrapper = cute.compile(
kernel_wrapper,
{compile_args},
options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}",
)
# CuTeDSL generates a long, mangled cubin filename that includes argument/type info,
# e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin.
_cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime)
if len(_cubin_files) != 1:
raise RuntimeError(
f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}"
)
os.replace(_cubin_files[0], _cubin_path)
finally:
shutil.rmtree(_staging_dir, ignore_errors=True)"""
.. raw:: html

   </details>
.. py:data:: PYTHON_HOST_FUNC_TEMPLATE
:value: Multiline-String
.. raw:: html

   <details><summary><a>Show Value</a></summary>

.. code-block:: python
"""import os
from pathlib import Path
# Minimal imports for runtime (no cutlass/cute - only needed for cubin generation)
import tvm.runtime as runtime
_cpp_launcher = None
_cpp_launcher_lib = None
_cubin_generated = False
# Pre-compute paths - cubin is stored alongside the launcher .so
# Use module basename to avoid conflicts when multiple kernels run concurrently
# e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin"
# "kernel.py" (in cache) -> "kernel.cubin"
_module_dir = Path(os.path.dirname(__file__))
_cubin_path = _module_dir / (Path(__file__).stem + ".cubin")
_cubin_path_bytes = _cubin_path.as_posix().encode('utf-8')
_cubin_needs_generation = not _cubin_path.exists()
def _generate_cubin_if_needed({cubin_gen_params}):
"""Generate cubin file on first call.
All CuTeDSL imports are inside this function to avoid slow
module-level initialization when loading from cache.
"""
global _cubin_generated, _cubin_path
# Lazy import CuTeDSL only when cubin generation is needed
from cuda.bindings.driver import CUstream
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor
import tilelang.contrib.cutedsl as tl
# We rely on CuTeDSL's keep-cubin artifact rather than custom extraction.
import tempfile
import shutil
_DTYPE_MAP = {{
"torch.float32": cutlass.Float32,
"torch.float16": cutlass.Float16,
"torch.bfloat16": cutlass.BFloat16,
"torch.float8_e4m3fnuz": cutlass.Float8E4M3FN,
"torch.float8_e4m3fn": cutlass.Float8E4M3FN,
"torch.float8_e5m2": cutlass.Float8E5M2,
"torch.float64": cutlass.Float64,
"torch.int64": cutlass.Int64,
"torch.int32": cutlass.Int32,
"torch.uint32": cutlass.Uint32,
"torch.bool": cutlass.Boolean,
"torch.int8": cutlass.Int8,
"torch.uint8": cutlass.Uint8,
"torch.int16": cutlass.Int16,
"torch.uint16": cutlass.Uint16,
"torch.uchar": cutlass.Uint8,
}}
{cubin_gen_code}
_cubin_generated = True
def _load_cpp_launcher():
"""Load C++ kernel launcher."""
global _cpp_launcher, _cpp_launcher_lib
if _cpp_launcher is not None:
return _cpp_launcher
lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}")
if not os.path.exists(lib_path):
raise FileNotFoundError(f"Launcher not found: {{lib_path}}")
_cpp_launcher_lib = runtime.load_module(lib_path)
_cpp_launcher = _cpp_launcher_lib["launch_kernel"]
return _cpp_launcher
def call({call_func_params}, stream, device_id=0):
"""Kernel dispatch function.
Args:
stream: CUDA stream handle
device_id: CUDA device ID (should be passed from caller, defaults to 0 for backward compatibility)
"""
global _cubin_path_bytes, _cubin_needs_generation
if _cubin_needs_generation:
_generate_cubin_if_needed({cubin_gen_call_args})
_cubin_needs_generation = False
{arg_prep_code}
launcher = _load_cpp_launcher()
result = launcher({launcher_call_args}, stream, device_id, _cubin_path_bytes)
if result != 0:
raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}")
"""
.. raw:: html

   </details>
.. py:class:: TLCuTeDSLSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)
Bases: :py:obj:`tilelang.jit.adapter.wrapper.TLCUDASourceWrapper`
Wrapper class for TileLang CuTe DSL backend with C++ launcher.
Generates optimized C++ launcher code that:
- Loads cubin via CUDA Driver API
- Passes TMA descriptors by value (host-side, no device copy)
- Launches kernels with minimal Python overhead
- Supports both single and multiple kernel scenarios
.. py:property:: host_func
Override parent's host_func to return generated Python code.
.. py:method:: generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map)
Generate TMA descriptor information for C++ code generation.
:returns: List of descriptor variable names in the order they were processed.
.. py:method:: create_dispatch_func(code, function_informations)
Create dispatch function - always use C++ launcher.
.. py:method:: create_dispatch_func_cpp_launcher(code, function_informations)
Create dispatch function using C++ launcher.
.. py:method:: get_launcher_cpp_code()
Get the generated C++ launcher code.
.. py:method:: update_lib_code(code)
Update the library code with the given code string.