tilelang.language.builtin
=========================

.. py:module:: tilelang.language.builtin

.. autoapi-nested-parse::

   Builtin operations exposed on the TileLang language surface.

Functions
---------

.. autoapisummary::

   tilelang.language.builtin.access_ptr
   tilelang.language.builtin.create_tma_descriptor
   tilelang.language.builtin.tma_load
   tilelang.language.builtin.fence_proxy_async
   tilelang.language.builtin.tma_store_arrive
   tilelang.language.builtin.tma_store_wait
   tilelang.language.builtin.set_max_nreg
   tilelang.language.builtin.inc_max_nreg
   tilelang.language.builtin.dec_max_nreg
   tilelang.language.builtin.annotate_producer_reg_dealloc
   tilelang.language.builtin.annotate_consumer_reg_alloc
   tilelang.language.builtin.no_set_max_nreg
   tilelang.language.builtin.disable_warp_group_reg_alloc
   tilelang.language.builtin.ptx_arrive_cluster_barrier
   tilelang.language.builtin.mbarrier_wait_parity
   tilelang.language.builtin.mbarrier_arrive
   tilelang.language.builtin.mbarrier_expect_tx
   tilelang.language.builtin.warpgroup_arrive
   tilelang.language.builtin.warpgroup_commit_batch
   tilelang.language.builtin.warpgroup_wait
   tilelang.language.builtin.get_lane_idx
   tilelang.language.builtin.get_warp_idx_sync
   tilelang.language.builtin.get_warp_idx
   tilelang.language.builtin.get_warp_group_idx
   tilelang.language.builtin.shuffle_elect
   tilelang.language.builtin.warpgroup_fence_operand
   tilelang.language.builtin.wait_wgmma
   tilelang.language.builtin.barrier_wait
   tilelang.language.builtin.barrier_arrive
   tilelang.language.builtin.shfl_xor
   tilelang.language.builtin.shfl_down
   tilelang.language.builtin.shfl_up
   tilelang.language.builtin.sync_threads
   tilelang.language.builtin.sync_warp
   tilelang.language.builtin.shfl_sync
   tilelang.language.builtin.sync_global
   tilelang.language.builtin.sync_grid
   tilelang.language.builtin.initialize_wgmma_descriptor
   tilelang.language.builtin.initialize_tcgen05_descriptor
   tilelang.language.builtin.increase_descriptor_offset
   tilelang.language.builtin.loop_break
   tilelang.language.builtin.cp_async_barrier_noinc
   tilelang.language.builtin.tcgen05_mma_arrive
   tilelang.language.builtin.ptx_mma_sm70
   tilelang.language.builtin.ldg32
   tilelang.language.builtin.ldg64
   tilelang.language.builtin.ldg128
   tilelang.language.builtin.ldg256
   tilelang.language.builtin.stg32
   tilelang.language.builtin.stg64
   tilelang.language.builtin.stg128
   tilelang.language.builtin.stg256

Module Contents
---------------

.. py:function:: access_ptr(base, access_type = 'r', *extents, offset = 0, extent = None, ignore_last_ndim = 0)

   Create a TileLang `tl.access_ptr` from a buffer-like base location.

   This is a frontend convenience wrapper that keeps a `BufferLoad` argument in
   the resulting call so downstream passes can recover the referenced
   `tir.Buffer` (including strides/storage scope) *and* the `rw_mask`
   (read/write intent) required by synchronization and safety checks. The
   returned `tl.access_ptr` is expected to be lowered to
   `tir.builtin.tvm_access_ptr` later in the TileLang compilation pipeline.

   :param base: The base location to take the address of. Supported:

      - `tir.BufferLoad` (e.g. `A[i, j]`): pointer to that element
      - `tir.BufferRegion`: pointer to the region minima
      - `tir.Buffer`: pointer to the beginning of the buffer
      - `tir.Var` with let-binding to one of the above (inside TileLang frame)
   :type base: BufferLikeType
   :param access_type: Access mask for the pointer. Common string forms: `"r"`,
      `"w"`, `"rw"`. Integer bitmask is also accepted (1=read, 2=write,
      3=read-write).
   :type access_type: str | int
   :param \*extents: Optional per-axis extents. When provided and `extent` is
      not specified, the 1D `extent` passed to `tvm_access_ptr` is computed as
      the product of the provided extents (padding leading dimensions with 1
      if needed).
      For example:

      - `T.access_ptr(A[i], "r")` -> extent defaults to 1 (element pointer)
      - `T.access_ptr(A[i], "r", 16)` -> extent=16
      - `T.access_ptr(A[i, j], "r", m, n)` -> extent=m*n
   :type \*extents: PrimExpr | int
   :param offset: Additional element offset from the base location.
   :type offset: PrimExpr | int
   :param extent: Optional explicit 1D extent override (in elements). If
      provided, it takes precedence over `*extents`.
   :type extent: PrimExpr | int | None
   :param ignore_last_ndim: If non-zero, the base linear offset is computed
      only over the leading dimensions, ignoring the last `ignore_last_ndim`
      axes. This is useful when treating an N-D buffer as a view of its
      trailing sub-tensor.
   :type ignore_last_ndim: int
   :returns: **ptr** -- A handle-typed `tir.Call` to `tl.access_ptr`.
   :rtype: PrimExpr

.. py:function:: create_tma_descriptor(*args)

   Create a Tensor Memory Access (TMA) descriptor.

   :param \*args: Variable arguments defining the TMA descriptor configuration
   :returns: A handle to the created TMA descriptor
   :rtype: tir.Call

.. py:function:: tma_load(*args)

   Perform a Tensor Memory Access (TMA) load operation.

   :param \*args: Variable arguments specifying the TMA load parameters
   :returns: A handle to the TMA load operation
   :rtype: tir.Call

.. py:function:: fence_proxy_async(*args)

   Create a fence for asynchronous proxy operations.

   :param \*args: Variable arguments for fence configuration
   :returns: A handle to the fence operation
   :rtype: tir.Call

.. py:function:: tma_store_arrive(*args)

   Signal the arrival of a TMA store operation.

   :param \*args: Variable arguments for the store arrival operation
   :returns: A handle to the store arrive operation
   :rtype: tir.Call

.. py:function:: tma_store_wait(*args)

   Wait for completion of TMA store operations.

   :param \*args: Variable arguments specifying which store operations to wait for
   :returns: A handle to the store wait operation
   :rtype: tir.Call
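The extent-defaulting rules described for `access_ptr` above can be modeled in plain Python. This is an illustrative host-side sketch only; `compute_extent` is a hypothetical helper, not part of the TileLang API:

```python
import math

def compute_extent(*extents, extent=None):
    """Model of how access_ptr derives its 1D extent (in elements)."""
    if extent is not None:       # explicit `extent` override takes precedence
        return extent
    if not extents:              # no extents given -> element pointer
        return 1
    return math.prod(extents)    # product of the per-axis extents

print(compute_extent())             # -> 1   (T.access_ptr(A[i], "r"))
print(compute_extent(16))           # -> 16  (T.access_ptr(A[i], "r", 16))
print(compute_extent(4, 8))         # -> 32  (m*n with m=4, n=8)
print(compute_extent(4, 8, extent=5))  # -> 5 (explicit override wins)
```
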
.. py:function:: set_max_nreg(reg_count, is_inc)

   Set the maximum number of registers to use.

   Detailed documentation:
   https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

   :param reg_count: The number of registers to allocate.
   :type reg_count: int
   :param is_inc: Whether to increment or decrement the register count:
      0 to decrement, 1 to increment.
   :type is_inc: int
   :returns: A handle to the register setting operation
   :rtype: tir.Call

.. py:function:: inc_max_nreg(reg_count)

   Increment the maximum number of registers to use.

.. py:function:: dec_max_nreg(reg_count)

   Decrement the maximum number of registers to use.

.. py:function:: annotate_producer_reg_dealloc(reg_count = 24)

   Annotate the producer register deallocation count.

.. py:function:: annotate_consumer_reg_alloc(reg_count = 240)

   Annotate the consumer register allocation count.

.. py:function:: no_set_max_nreg()

   Disable the maximum register limit setting.

.. py:function:: disable_warp_group_reg_alloc()

   Disable warp group register allocation.

.. py:function:: ptx_arrive_cluster_barrier(mbarrier, cta_id)

   Arrive at a shared barrier in a cluster.

   :param mbarrier: The memory barrier to arrive at.
   :type mbarrier: BarrierType
   :param cta_id: The peer CTA rank in the cluster to arrive at.
   :type cta_id: int | Var

.. py:function:: mbarrier_wait_parity(mbarrier, parity)

   Wait for a memory barrier parity condition.

   :param mbarrier: The memory barrier to wait on.
   :type mbarrier: BarrierType
   :param parity: The parity value to wait for.
   :type parity: int | Var

   .. rubric:: Examples

   .. code-block:: python

      mbar = T.alloc_barrier(1)
      # Wait for parity 0 on a single mbarrier
      T.mbarrier_wait_parity(mbar, 0)

      mbars = T.alloc_barrier([128] * n)
      # Wait for parity value on one of the mbarriers
      T.mbarrier_wait_parity(mbars[ko], ko)

      # Common usage in pipelined kernels:
      for ko in range(num_stages):
          # Producer waits for consumer to finish previous iteration
          T.mbarrier_wait_parity(mbars[1], ko ^ 1)
          # Producer copies data
          T.copy(A_global, A_shared)
          # Producer signals data ready
          T.mbarrier_arrive(mbars[0])

          # Consumer waits for producer data
          T.mbarrier_wait_parity(mbars[0], ko)
          # Consumer computes
          T.gemm(A_shared, B_shared, C_local)
          # Consumer signals completion
          T.mbarrier_arrive(mbars[1])

   :returns: A handle to the barrier wait operation
   :rtype: tir.Call

.. py:function:: mbarrier_arrive(mbarrier, cta_id = None)

   Arrive at a memory barrier.

   :param mbarrier: The memory barrier to arrive at.
   :type mbarrier: BarrierType
   :param cta_id: The peer CTA rank in the cluster to arrive at (only valid
      for cluster barriers). If not provided, arrives on the current CTA's
      barrier.
   :type cta_id: int | Var | None

.. py:function:: mbarrier_expect_tx(mbarrier, tx)

   Set the expected transaction count for a memory barrier.

   :param mbarrier: The memory barrier to set the expected transaction count for.
   :type mbarrier: BarrierType
   :param tx: The expected transaction count.
   :type tx: int
   :returns: A handle to the barrier expectation operation
   :rtype: tir.Call

.. py:function:: warpgroup_arrive()

   Signal warpgroup readiness for subsequent WGMMA operations.

   :returns: A handle to the warpgroup arrive operation.
   :rtype: tir.Call

.. py:function:: warpgroup_commit_batch()

   Commit the current warpgroup batch for WGMMA operations.

   :returns: A handle to the warpgroup commit batch operation.
   :rtype: tir.Call

.. py:function:: warpgroup_wait(num_mma)

   Wait for completion of the specified warpgroup batch.

   :param num_mma: Identifier of the warpgroup MMA batch to wait on.
   :type num_mma: int
   :returns: A handle to the warpgroup wait operation.
   :rtype: tir.Call
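The parity argument in the pipelined pattern above alternates between 0 and 1 from one iteration to the next. A plain-Python sketch of that bookkeeping (a host-side model for illustration, not the device intrinsic; the helper names are hypothetical):

```python
# Host-side model of mbarrier phase-parity bookkeeping (illustrative only).
# A phase-parity barrier flips between 0 and 1 each time it completes, so a
# waiter in pipeline iteration ko passes parity ko & 1.

def expected_parity(iteration: int) -> int:
    """Parity a waiter passes for the given pipeline iteration."""
    return iteration & 1

def producer_parity(iteration: int) -> int:
    """The producer waits on the consumer's *previous* signal: parity ko ^ 1."""
    return expected_parity(iteration) ^ 1

parities = [(expected_parity(ko), producer_parity(ko)) for ko in range(4)]
print(parities)  # -> [(0, 1), (1, 0), (0, 1), (1, 0)]
```
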
.. py:function:: get_lane_idx(warp_size = None)

   Return the logical lane index of the calling thread within a warp.

   :param warp_size: Logical warp (or wavefront) size. Defaults to 32 on
      NVIDIA and 64 on AMD.
   :type warp_size: Optional[int, PrimExpr]

   .. rubric:: Examples

   >>> lane = T.get_lane_idx()
   >>> custom_lane = T.get_lane_idx(64)  # override warp size explicitly

   Implementation Notes
   --------------------

   Lowers to the CUDA helper `tl::get_lane_idx(warp_size)` defined in
   `src/tl_templates/cuda/intrin.h`, which computes the lane index from the
   linear thread id using the provided `warp_size`.

.. py:function:: get_warp_idx_sync(warp_size = None)

   Return the canonical warp index, assuming the warp's threads are converged.

   :param warp_size: Logical warp size used for the index calculation.
   :type warp_size: Optional[int, PrimExpr]

   .. rubric:: Examples

   >>> warp = T.get_warp_idx_sync()
   >>> custom_warp = T.get_warp_idx_sync(64)

   Implementation Notes
   --------------------

   Emits `tl::get_warp_idx_sync(warp_size)`, which divides the block-linear
   thread id by `warp_size`, matching the semantics of CUTLASS' canonical
   helpers.

.. py:function:: get_warp_idx(warp_size = None)

   Return the canonical warp index without synchronizing the warp.

   :param warp_size: Logical warp size used for the index calculation.
   :type warp_size: Optional[int, PrimExpr]

   .. rubric:: Examples

   >>> warp = T.get_warp_idx()
   >>> custom_warp = T.get_warp_idx(64)

   Implementation Notes
   --------------------

   Lowers to `tl::get_warp_idx(warp_size)`, which divides the block-linear
   thread id by the provided `warp_size` without requiring warp convergence.

.. py:function:: get_warp_group_idx(warp_size = None, warps_per_group = None)

   Return the canonical warp group index for the calling thread.

   :param warp_size: Logical warp size to use (defaults to 32 on NVIDIA / 64
      on AMD).
   :type warp_size: Optional[int, PrimExpr]
   :param warps_per_group: Number of warps per warp-group. Defaults to 4 on
      NVIDIA architectures.
   :type warps_per_group: Optional[int, PrimExpr]

   .. rubric:: Examples

   >>> group = T.get_warp_group_idx()
   >>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group

   Implementation Notes
   --------------------

   Generates `tl::get_warp_group_idx(warp_size, warps_per_group)`, which
   divides the block-linear thread id by `warp_size * warps_per_group`,
   matching the canonical ordering while allowing architecture-specific
   overrides.

.. py:function:: shuffle_elect(thread_extent)

   Elect exactly one lane within a logical thread group.

   :param thread_extent: Size (in threads) of the group in which a single lane
      should be elected. Passing 0 elects a single lane in the entire thread
      block.
   :type thread_extent: int

   .. rubric:: Examples

   >>> is_leader = T.shuffle_elect(64)
   >>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))

   Implementation Notes
   --------------------

   Lowered to the CUDA helper `tl::tl_shuffle_elect()` defined in
   `src/tl_templates/cuda/intrin.h`, which relies on
   `cutlass::canonical_warp_idx_sync()` and `cute::elect_one_sync()` (or
   `__shfl_sync`) to pick one lane per group.

.. py:function:: warpgroup_fence_operand(buffer_or_ptr, offset = 0, num_regs = None, dtype = None)

   Insert a warpgroup fence for the destination accumulator registers.

   This prevents NVCC from sinking uses of accumulator fragments past the
   corresponding WGMMA operations by issuing an empty inline assembly barrier
   on every register.

   :param buffer_or_ptr: A buffer representing the accumulator fragment, a
      buffer load/region that identifies a starting element within the
      fragment, or a pointer expression (e.g., tvm_access_ptr/address_of/typed
      Var).
   :type buffer_or_ptr: BufferLikeType | PrimExpr
   :param offset: Element offset from the start of the accumulator fragment.
   :type offset: int | PrimExpr
   :param num_regs: Number of 32-bit registers to fence. If None and a Buffer
      is provided, it will be derived from the buffer shape and dtype.
   :type num_regs: int | PrimExpr | None
   :param dtype: Data type string of the accumulator elements.
      When passing a buffer or buffer-derived expression, dtype is inferred.
      It is required only when passing a raw pointer expression that cannot be
      inferred.
   :type dtype: DType | None
   :returns: A handle to the warpgroup fence operation.
   :rtype: tir.Call

.. py:function:: wait_wgmma(id)

   Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to
   complete.

   :param id: The id of the WGMMA operation to wait for.
   :type id: int
   :returns: A handle to the WGMMA wait operation
   :rtype: tir.Call

.. py:function:: barrier_wait(mbarrier, parity)

   Wait for a memory barrier to complete.

   :param mbarrier: The memory barrier to wait on.
   :type mbarrier: BarrierType
   :param parity: The parity value to wait for.
   :type parity: int | Var
   :returns: A handle to the barrier wait operation
   :rtype: tir.Call

   The current implementation is syntactic sugar for `mbarrier_wait_parity`,
   as only parity values 0 and 1 are supported.

.. py:function:: barrier_arrive(mbarrier)

   Arrive at a memory barrier.

   :param mbarrier: The memory barrier to arrive at.
   :type mbarrier: BarrierType

.. py:function:: shfl_xor(value, offset)

   Perform a shuffle operation with an XOR offset.

   :param value: The value to shuffle.
   :type value: Optional[int, PrimExpr]
   :param offset: The offset for the shuffle operation.
   :type offset: Optional[int, PrimExpr]
   :returns: A handle to the shuffle operation
   :rtype: tir.Call

.. py:function:: shfl_down(value, offset)

   Perform a shuffle operation with a down offset.

   :param value: The value to shuffle.
   :type value: Optional[int, PrimExpr]
   :param offset: The offset for the shuffle operation.
   :type offset: Optional[int, PrimExpr]

.. py:function:: shfl_up(value, offset)

   Perform a shuffle operation with an up offset.

   :param value: The value to shuffle.
   :type value: Optional[int, PrimExpr]
   :param offset: The offset for the shuffle operation.
   :type offset: Optional[int, PrimExpr]

.. py:function:: sync_threads(barrier_id = None, arrive_count = None)

   Synchronize all threads in a block.

.. py:function:: sync_warp(mask = None)

   Synchronize all threads in a warp.

.. py:function:: shfl_sync(mask, value, srcLane, width = None)

   Receive data from a thread in the same warp.

.. py:function:: sync_global()

   Synchronize all threads in the entire grid.

.. py:function:: sync_grid()

   Synchronize all threads in a grid.
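The lane pairing behind `shfl_xor`, `shfl_down`, and `shfl_up` can be modeled on the host. The following plain-Python sketch is illustrative only (it models source-lane selection, assuming a 32-lane NVIDIA warp and CUDA's keep-own-value behavior for out-of-range lanes, not the device intrinsics themselves), and finishes with the classic butterfly (XOR) all-reduce:

```python
WARP_SIZE = 32  # assumed NVIDIA warp width

def shfl_xor_src(lane: int, offset: int) -> int:
    """Source lane read by `lane` in a butterfly (XOR) shuffle."""
    return lane ^ offset

def shfl_down_src(lane: int, offset: int) -> int:
    """Source lane for a down shuffle; lanes past the end keep their own value."""
    src = lane + offset
    return src if src < WARP_SIZE else lane

def shfl_up_src(lane: int, offset: int) -> int:
    """Source lane for an up shuffle; low lanes keep their own value."""
    src = lane - offset
    return src if src >= 0 else lane

# Butterfly all-reduce: XOR shuffles with offsets 16, 8, 4, 2, 1 combine all
# 32 lanes' values in log2(32) = 5 steps, leaving the sum in every lane.
vals = list(range(WARP_SIZE))
for offset in (16, 8, 4, 2, 1):
    vals = [vals[i] + vals[shfl_xor_src(i, offset)] for i in range(WARP_SIZE)]
print(vals[0])  # -> 496, i.e. sum(range(32)), in every lane
```
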
.. py:function:: initialize_wgmma_descriptor(descriptor, start_address, layout_type_ = 0, leading_byte_offset = 0, stride_byte_offset = 0)

   Initialize a WGMMA/UTCMMA shared-memory descriptor.

.. py:function:: initialize_tcgen05_descriptor(descriptor, start_address, leading_byte_offset, stride_byte_offset, base_offset = 0, leading_is_absolute = False, swizzle_mode = 0)

   Initialize a TCGEN05 shared-memory descriptor.

.. py:function:: increase_descriptor_offset(descriptor, offset)

   Increase the offset of a memory descriptor.

   :param descriptor: The memory descriptor to modify.
   :type descriptor: PrimExpr
   :param offset: The offset value to increase.
   :type offset: PrimExpr
   :returns: A handle representing the modified descriptor.
   :rtype: PrimExpr

.. py:function:: loop_break()

   Break out of the innermost loop.

.. py:function:: cp_async_barrier_noinc(barrier)

   Perform a PTX async copy barrier using `cp.async.mbarrier.arrive.noinc`.

.. py:function:: tcgen05_mma_arrive(mbar)

   Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

   :param mbar: The mbarrier object in shared memory (e.g., Barrier*) or its
      address.
   :type mbar: tir.Buffer | BufferLoad | PrimExpr

.. py:function:: ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)

   TVM intrinsic for PTX tensor core MMA instructions on SM70 (Volta).

   This intrinsic provides SM70-specific MMA operations that support the
   m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation.

   :param shape: The shape of the mma fragment (e.g., "m16n16k4").
   :type shape: str
   :param A_layout: The layout of multiplicand fragment A ("row" or "col").
   :type A_layout: str
   :param B_layout: The layout of multiplicand fragment B ("row" or "col").
   :type B_layout: str
   :param A_dtype: The data type of multiplicand fragment A (typically "fp16").
   :type A_dtype: str
   :param B_dtype: The data type of multiplicand fragment B (typically "fp16").
   :type B_dtype: str
   :param C_dtype: The data type of accumulator fragment C ("fp16" or "fp32").
   :type C_dtype: str
   :param multiplicand_a: The multiplicand fragment A variable.
   :type multiplicand_a: Var
   :param a_index: The index of multiplicand fragment A.
   :type a_index: Expr
   :param multiplicand_b: The multiplicand fragment B variable.
   :type multiplicand_b: Var
   :param b_index: The index of multiplicand fragment B.
   :type b_index: Expr
   :param accumulator: The accumulator fragment C variable.
   :type accumulator: Var
   :param c_index: The index of accumulator fragment C.
   :type c_index: Expr
   :returns: **call** -- The call expression.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> T.ptx_mma_sm70(
   ...     "float16",
   ...     "m16n16k4",
   ...     "row",
   ...     "col",
   ...     "fp16",
   ...     "fp16",
   ...     "fp16",
   ...     A_local.data,
   ...     0,
   ...     B_local.data,
   ...     0,
   ...     C_local.data,
   ...     0,
   ... )

.. py:function:: ldg32(src, pred = None)

   Load 32 bits (4 bytes) from global memory using explicit PTX instructions.

   Usage: `T.ldg32(x[i])` or `T.ldg32(x[i:i+2])` emits `tl::ldg32(ptr)`.

   :param src: A `Buffer`, `BufferRegion`, or `BufferLoad`.
   :param pred: Optional predicate condition. If False, the load is skipped.
   :returns: The loaded 32-bit value.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> val = T.ldg32(x[i])
   >>> val = T.ldg32(x[i:i+2])  # load 2 x fp16
   >>> val = T.ldg32(x[i], pred=i < N)  # predicated load

.. py:function:: ldg64(src, pred = None)

   Load 64 bits (8 bytes) from global memory using explicit PTX instructions.

   Usage: `T.ldg64(x[i])` or `T.ldg64(x[i:i+4])` emits `tl::ldg64(ptr)`.

   :param src: A `Buffer`, `BufferRegion`, or `BufferLoad`.
   :param pred: Optional predicate condition. If False, the load is skipped.
   :returns: The loaded 64-bit value.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> val = T.ldg64(x[i])
   >>> val = T.ldg64(x[i:i+4])  # load 4 x fp16
   >>> val = T.ldg64(x[i], pred=i < N)  # predicated load

.. py:function:: ldg128(src, pred = None)

   Load 128 bits (16 bytes) from global memory using explicit PTX instructions.
   Usage: `T.ldg128(x[i])` or `T.ldg128(x[i:i+8])` emits `tl::ldg128(ptr)`.

   :param src: A `Buffer`, `BufferRegion`, or `BufferLoad`.
   :param pred: Optional predicate condition. If False, the load is skipped.
   :returns: The loaded 128-bit value.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> val = T.ldg128(x[i])
   >>> val = T.ldg128(x[i:i+8])  # load 8 x fp16
   >>> val = T.ldg128(x[i], pred=i < N)  # predicated load

.. py:function:: ldg256(src, pred = None)

   Load 256 bits (32 bytes) from global memory using explicit PTX instructions.

   Usage: `T.ldg256(x[i])` or `T.ldg256(x[i:i+16])` emits `tl::ldg256(ptr)`.

   :param src: A `Buffer`, `BufferRegion`, or `BufferLoad`.
   :param pred: Optional predicate condition. If False, the load is skipped.
   :returns: The loaded 256-bit value.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> val = T.ldg256(x[i])
   >>> val = T.ldg256(x[i:i+16])  # load 16 x fp16
   >>> val = T.ldg256(x[i], pred=i < N)  # predicated load

.. py:function:: stg32(dst, value, pred = None)

   Store 32 bits (4 bytes) to global memory using explicit PTX instructions.

   Usage: `T.stg32(y[i], value)` emits `tl::stg32(ptr, value)`.

   :param dst: A `Buffer`, `BufferRegion`, or `BufferLoad` indicating the destination.
   :param value: The 32-bit value to store.
   :param pred: Optional predicate condition. If False, the store is skipped.

   .. rubric:: Examples

   >>> T.stg32(y[i], val)
   >>> T.stg32(y[i], val, pred=i < N)  # predicated store

.. py:function:: stg64(dst, value, pred = None)

   Store 64 bits (8 bytes) to global memory using explicit PTX instructions.

   Usage: `T.stg64(y[i:i+2], value)` emits `tl::stg64(ptr, value)`.

   :param dst: A `Buffer`, `BufferRegion`, or `BufferLoad` indicating the destination.
   :param value: The 64-bit value to store (e.g., uint2).
   :param pred: Optional predicate condition. If False, the store is skipped.

   .. rubric:: Examples

   >>> T.stg64(y[i:i+2], val)
   >>> T.stg64(y[i:i+2], val, pred=i < N)  # predicated store

.. py:function:: stg128(dst, value, pred = None)

   Store 128 bits (16 bytes) to global memory using explicit PTX instructions.

   Usage: `T.stg128(y[i:i+4], value)` emits `tl::stg128(ptr, value)`.

   :param dst: A `Buffer`, `BufferRegion`, or `BufferLoad` indicating the destination.
   :param value: The 128-bit value to store (e.g., uint4).
   :param pred: Optional predicate condition. If False, the store is skipped.

   .. rubric:: Examples

   >>> T.stg128(y[i:i+4], val)
   >>> T.stg128(y[i:i+4], val, pred=i < N)  # predicated store

.. py:function:: stg256(dst, value, pred = None)

   Store 256 bits (32 bytes) to global memory using explicit PTX instructions.

   Usage: `T.stg256(y[i:i+8], value)` emits `tl::stg256(ptr, value)`.

   :param dst: A `Buffer`, `BufferRegion`, or `BufferLoad` indicating the destination.
   :param value: The 256-bit value to store (e.g., ulonglong4).
   :param pred: Optional predicate condition. If False, the store is skipped.

   .. rubric:: Examples

   >>> T.stg256(y[i:i+8], val)
   >>> T.stg256(y[i:i+8], val, pred=i < N)  # predicated store
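The numbers in the `ldg*`/`stg*` names are bit widths, so the element count per access follows from the element dtype width. A quick host-side sanity check of that bookkeeping (plain Python, illustrative; `elements_per_access` is a hypothetical helper, not part of the API):

```python
def elements_per_access(access_bits: int, dtype_bits: int) -> int:
    """How many elements of a given dtype one ldgN/stgN access covers."""
    assert access_bits % dtype_bits == 0, "access width must be a whole number of elements"
    return access_bits // dtype_bits

# ldg128 moves 16 bytes: 8 fp16 values or 4 fp32 values per access.
print(elements_per_access(128, 16))  # -> 8   (matches `x[i:i+8]` for fp16)
print(elements_per_access(128, 32))  # -> 4
print(elements_per_access(256, 16))  # -> 16  (matches `x[i:i+16]` for fp16)
```
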