tilelang.jit.adapter.wrapper

Attributes

Classes

BaseWrapper

Helper class that provides a standard way to create an ABC using inheritance.

TLCUDASourceWrapper

TLHIPSourceWrapper

A wrapper class for the TileLang HIP backend.

TLCPUSourceWrapper

TLMetalSourceWrapper

TLWrapper

A wrapper class for the TileLang backend.

TLPyWrapper

A wrapper class for the TileLang backend that wraps Python source.

Module Contents

tilelang.jit.adapter.wrapper.PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = Multiline-String
Show Value
"""
    cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
    if (result_{0} != cudaSuccess) {{
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
        return -1;
    }}
"""
tilelang.jit.adapter.wrapper.PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP = Multiline-String
Show Value
"""
    int device_{0} = 0;
    hipError_t dev_res_{0} = hipGetDevice(&device_{0});
    if (dev_res_{0} != hipSuccess) {{
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to get HIP device for {0}: %s", hipGetErrorString(dev_res_{0}));
        return -1;
    }}
    int max_smem_{0} = 0;
    hipError_t attr_res_{0} = hipDeviceGetAttribute(&max_smem_{0}, hipDeviceAttributeMaxSharedMemoryPerBlock, device_{0});
    if (attr_res_{0} != hipSuccess || max_smem_{0} <= 0) {{
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to query HIP max shared memory for {0}: %s", hipGetErrorString(attr_res_{0}));
        return -1;
    }}
    if ({1} > max_smem_{0}) {{
        snprintf(
            error_buf,
            ERROR_BUF_SIZE,
            "Requested dynamic shared memory %d exceeds device limit %d for {0}",
            {1},
            max_smem_{0}
        );
        return -1;
    }}
    return 0;
"""
tilelang.jit.adapter.wrapper.PREDEF_INIT_FUNC = Multiline-String
Show Value
"""
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {{
    return error_buf;
}}

extern "C" int init() {{
    error_buf[0] = '\0';
    {0}
    return 0;
}}
"""
tilelang.jit.adapter.wrapper.PREDEF_HOST_FUNC = Multiline-String
Show Value
"""
extern "C" int call({}) {{
{}
  return 0;
}}
"""
tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_CREATE_HANDLE = Multiline-String
Show Value
"""
  cudaStreamAttrValue stream_attribute;
  size_t init_persisting_l2_cache_size;
  cudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize);
"""
tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_INIT_FUNC = Multiline-String
Show Value
"""
  stream_attribute.accessPolicyWindow.hitRatio = {1};
  stream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
  stream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
  cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {2});
  stream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
  stream_attribute.accessPolicyWindow.num_bytes = {2};
  cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
"""
tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_RESET_HANDLE = Multiline-String
Show Value
"""
  stream_attribute.accessPolicyWindow.num_bytes = 0;
  cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
  cudaCtxResetPersistingL2Cache();
  cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size);
"""
tilelang.jit.adapter.wrapper.TMA_DESC_INIT_FUNC = Multiline-String
Show Value
"""
  CUtensorMap {0};
  CUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
  cuuint32_t {0}_tensorRank= {2};
  void *{0}_globalAddress= {3};
  cuuint64_t {0}_globalDim[{2}]= {{{4}}};
  cuuint64_t {0}_globalStride[{2}]= {{{5}}};
  cuuint32_t {0}_boxDim[{2}]= {{{6}}};
  cuuint32_t {0}_elementStrides[{2}]= {{{7}}};
  CUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){8};
  CUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9};
  CUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10};
  CUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11};

  CUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
    &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);

  if ({0}_result != CUDA_SUCCESS) {{
          std::stringstream ss;
          ss << "Error: Failed to initialize the TMA descriptor {0}";
          snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
          return -1;
  }}
"""
tilelang.jit.adapter.wrapper.TMA_IM2COL_DESC_INIT_FUNC = Multiline-String
Show Value
"""
  CUtensorMap {0};
  CUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
  cuuint32_t {0}_tensorRank= {2};
  void *{0}_globalAddress= {3};
  cuuint64_t {0}_globalDim[{2}]= {{{4}}};
  cuuint64_t {0}_globalStride[{2}]= {{{5}}};
  cuuint32_t {0}_elementStrides[{2}]= {{{6}}};
  int {0}_lowerCorner[{2} - 2]= {{{7}}};
  int {0}_upperCorner[{2} - 2]= {{{8}}};
  cuuint32_t {0}_channelsPerPixel= {9};
  cuuint32_t {0}_pixelsPerColumn= {10};
  CUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11};
  CUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12};
  CUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13};
  CUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14};

  CUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
    &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1,
    {0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);

  if ({0}_result != CUDA_SUCCESS) {{
          std::stringstream ss;
          ss << "Error: Failed to initialize the TMA descriptor {0}";
          snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
          return -1;
  }}
"""
tilelang.jit.adapter.wrapper.KERNEL_LAUNCH_FUNC_CODE = Multiline-String
Show Value
"""
  {{
          cudaLaunchConfig_t config;
          cudaLaunchAttribute attribute[1];
          attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
          attribute[0].val.programmaticStreamSerializationAllowed = 1;
          config.attrs = attribute;
          config.numAttrs = 1;
          config.stream = stream;
          config.gridDim = {0};
          config.blockDim = {1};
          config.dynamicSmemBytes = {2};
          cudaLaunchKernelEx(&config, {4}, {3});
  }}
"""
tilelang.jit.adapter.wrapper.KERNEL_CLUSTER_LAUNCH_FUNC_CODE = Multiline-String
Show Value
"""
  {{
          cudaLaunchConfig_t config;
          cudaLaunchAttribute attribute[2];
          attribute[0].id = cudaLaunchAttributeClusterDimension;
          attribute[0].val.clusterDim = {{{5}, {6}, {7}}};
          attribute[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
          attribute[1].val.programmaticStreamSerializationAllowed = 1;
          config.attrs = attribute;
          config.numAttrs = 2;
          config.stream = stream;
          config.gridDim = {0};
          config.blockDim = {1};
          config.dynamicSmemBytes = {2};
          cudaError_t cluster_attr_result = cudaFuncSetAttribute({4}, cudaFuncAttributeNonPortableClusterSizeAllowed, 1);
          if (cluster_attr_result != cudaSuccess) {{
                  snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set cluster attribute for {4}: %s", cudaGetErrorString(cluster_attr_result));
                  return -1;
          }}
          cudaLaunchKernelEx(&config, {4}, {3});
  }}
"""
class tilelang.jit.adapter.wrapper.BaseWrapper

Bases: abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.

abstractmethod wrap(*args, **kwargs)
tilelang.jit.adapter.wrapper.logger
class tilelang.jit.adapter.wrapper.TLCUDASourceWrapper(scheduled_ir_module, source, target, device_mod=None, host_mod=None, pass_configs=None)
Parameters:
  • scheduled_ir_module (tvm.IRModule)

  • source (str)

  • target (tvm.target.Target)

  • device_mod (tvm.IRModule | None)

  • host_mod (tvm.IRModule | None)

  • pass_configs (dict[str, Any] | None)

backend = 'tl'
device_mod: tvm.IRModule | None = None
host_mod: tvm.IRModule | None = None
pass_configs: dict[str, Any] | None = None
mod
target
source
function_names: str | None = None
dynamic_smem_buf: int | None = None
block_info: list[int] | dict = [1, 1, 1]
grid_info: list[int] | dict = [1, 1, 1]
tma_descriptor_args: dict | None = None
l2_persistent_map: dict[str, dict] | None
pdl_sync_map: dict[str, int] | None
srcpath: str | None = None
libpath: str | None = None
lib_code: str | None
is_tma_descriptor_arg(arg_name)
Parameters:

arg_name (str)

Return type:

bool

create_dispatch_func(code, function_informations)
get_declaration(declare_kernel_code)
Parameters:

declare_kernel_code (str)

Return type:

str

generate_l2_persistent_map(function_name)
Parameters:

function_name (str)

Return type:

str

generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
Parameters:
  • desc_name_map (dict[str, str])

  • desc_name_var_map (dict[str, tilelang.tvm.tir.Var])

Return type:

str

parse_source_information()
get_dynamic_symbolic_set(prim_func)
get_init_func()
update_lib_code(code)
Parameters:

code (str)

get_stream_type()
Return type:

dict[str, str]

property prim_func
property device_func
property host_func
class tilelang.jit.adapter.wrapper.TLHIPSourceWrapper(scheduled_ir_module, source, target, device_mod=None, host_mod=None, pass_configs=None)

Bases: TLCUDASourceWrapper

A wrapper class for the TileLang HIP backend.

Parameters:
  • scheduled_ir_module (tvm.IRModule)

  • source (str)

  • target (tvm.target.Target)

  • device_mod (tvm.IRModule | None)

  • host_mod (tvm.IRModule | None)

  • pass_configs (dict[str, Any] | None)

get_declaration(declare_kernel_code)
Parameters:

declare_kernel_code (str)

Return type:

str

get_init_func()
get_stream_type()
Return type:

dict[str, str]

class tilelang.jit.adapter.wrapper.TLCPUSourceWrapper(scheduled_ir_module, source, target, device_mod=None, host_mod=None, pass_configs=None)
Parameters:
  • scheduled_ir_module (tvm.IRModule)

  • source (str)

  • target (tvm.target.Target)

  • device_mod (tvm.IRModule | None)

  • host_mod (tvm.IRModule | None)

  • pass_configs (dict[str, Any] | None)

INIT_FUNC = Multiline-String
Show Value
"""
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {
    return error_buf;
}

extern "C" int init() {
    error_buf[0] = '\0';

    return 0;
}
"""
CALL_PREFIX
backend = 'tl'
device_mod: tvm.IRModule | None = None
host_mod: tvm.IRModule | None = None
pass_configs: dict[str, Any] | None = None
mod
target
source
function_names: str | None = None
dynamic_smem_buf: int | None = None
srcpath: str | None = None
libpath: str | None = None
lib_code: str | None
create_call_func(code, function_informations)
parse_source_information()
get_dynamic_symbolic_set(prim_func)
get_cpu_init_func()
update_lib_code(code)
Parameters:

code (str)

property prim_func
class tilelang.jit.adapter.wrapper.TLMetalSourceWrapper(scheduled_ir_module, source, target, device_mod=None, host_mod=None, pass_configs=None)
Parameters:
  • scheduled_ir_module (tvm.IRModule)

  • source (str)

  • target (tvm.target.Target)

  • device_mod (tvm.IRModule | None)

  • host_mod (tvm.IRModule | None)

  • pass_configs (dict[str, Any] | None)

mod
target
source
pass_configs = None
device_mod = None
host_mod = None
lib_code
update_lib_code(code)
Parameters:

code (str)

class tilelang.jit.adapter.wrapper.TLWrapper(target)

Bases: BaseWrapper

A wrapper class for the TileLang backend.

Parameters:

target (tvm.target.Target)

device_mod: tvm.IRModule | None = None
host_mod: tvm.IRModule | None = None
pass_configs: dict[str, Any] | None = None
target: tvm.target.Target | None = None
lib: object | None = None
scheduled_ir_module = None
assign_optimized_module(scheduled_ir_module)
Parameters:

scheduled_ir_module (tvm.IRModule)

assign_pass_configs(pass_configs)
Parameters:

pass_configs (dict[str, Any])

assign_host_module(host_mod)
Parameters:

host_mod (tvm.IRModule)

assign_device_module(device_mod)
Parameters:

device_mod (tvm.IRModule)

wrap(c_source)
Parameters:

c_source (str)

class tilelang.jit.adapter.wrapper.TLPyWrapper(target)

Bases: TLWrapper

A wrapper class for the TileLang backend that wraps Python source.

Parameters:

target (tvm.target.Target)

wrap(py_source)
Parameters:

py_source (str)