tilelang.language.builtin
=========================

.. py:module:: tilelang.language.builtin

.. autoapi-nested-parse::

   The language interface for tl programs.


Functions
---------

.. autoapisummary::

   tilelang.language.builtin.create_list_of_mbarrier
   tilelang.language.builtin.get_mbarrier
   tilelang.language.builtin.create_tma_descriptor
   tilelang.language.builtin.tma_load
   tilelang.language.builtin.fence_proxy_async
   tilelang.language.builtin.tma_store_arrive
   tilelang.language.builtin.tma_store_wait
   tilelang.language.builtin.set_max_nreg
   tilelang.language.builtin.inc_max_nreg
   tilelang.language.builtin.dec_max_nreg
   tilelang.language.builtin.annotate_producer_reg_dealloc
   tilelang.language.builtin.annotate_consumer_reg_alloc
   tilelang.language.builtin.no_set_max_nreg
   tilelang.language.builtin.disable_warp_group_reg_alloc
   tilelang.language.builtin.mbarrier_wait_parity
   tilelang.language.builtin.mbarrier_arrive
   tilelang.language.builtin.mbarrier_expect_tx
   tilelang.language.builtin.warpgroup_arrive
   tilelang.language.builtin.warpgroup_commit_batch
   tilelang.language.builtin.warpgroup_wait
   tilelang.language.builtin.get_lane_idx
   tilelang.language.builtin.get_warp_idx_sync
   tilelang.language.builtin.get_warp_idx
   tilelang.language.builtin.get_warp_group_idx
   tilelang.language.builtin.shuffle_elect
   tilelang.language.builtin.warpgroup_fence_operand
   tilelang.language.builtin.wait_wgmma
   tilelang.language.builtin.barrier_wait
   tilelang.language.builtin.barrier_arrive
   tilelang.language.builtin.shfl_xor
   tilelang.language.builtin.shfl_down
   tilelang.language.builtin.shfl_up
   tilelang.language.builtin.sync_threads
   tilelang.language.builtin.sync_global
   tilelang.language.builtin.sync_grid
   tilelang.language.builtin.initialize_wgmma_descriptor
   tilelang.language.builtin.initialize_tcgen05_descriptor
   tilelang.language.builtin.increase_descriptor_offset
   tilelang.language.builtin.loop_break
   tilelang.language.builtin.cp_async_barrier_noinc
   tilelang.language.builtin.tcgen05_mma_arrive
   tilelang.language.builtin.ptx_mma_sm70


Module Contents
---------------

.. py:function:: create_list_of_mbarrier(*args)

   Create a list of memory barrier handles.

   :param \*args: Either a single list of arguments, or multiple arguments directly.
   :type \*args: list or Any
   :returns: Handle to the created list of memory barriers.
   :rtype: tvm.tir.Call
   :raises TypeError: If the input is not a list or variadic arguments.

   .. rubric:: Examples

   >>> create_list_of_mbarrier([128, 128])
   >>> create_list_of_mbarrier(128, 128)


.. py:function:: get_mbarrier(*args)

   Retrieve a memory barrier operation.

   :param \*args: Variable arguments to specify which memory barrier to retrieve.
   :returns: A handle to the requested memory barrier.
   :rtype: tir.Call


.. py:function:: create_tma_descriptor(*args)

   Create a Tensor Memory Access (TMA) descriptor.

   :param \*args: Variable arguments defining the TMA descriptor configuration.
   :returns: A handle to the created TMA descriptor.
   :rtype: tir.Call


.. py:function:: tma_load(*args)

   Perform a Tensor Memory Access (TMA) load operation.

   :param \*args: Variable arguments specifying the TMA load parameters.
   :returns: A handle to the TMA load operation.
   :rtype: tir.Call


.. py:function:: fence_proxy_async(*args)

   Create a fence for asynchronous proxy operations.

   :param \*args: Variable arguments for fence configuration.
   :returns: A handle to the fence operation.
   :rtype: tir.Call
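   .. rubric:: Example

   A minimal sketch, assuming shared memory is first written through
   ordinary (generic-proxy) stores and then consumed by an async-proxy
   operation such as a TMA store; the buffer names are illustrative, not
   part of this API:

   .. code-block:: python

      # Stage results into shared memory with ordinary stores
      T.copy(C_local, C_shared)
      # Make those writes visible to the async proxy before TMA reads them
      T.fence_proxy_async()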
.. py:function:: tma_store_arrive(*args)

   Signal the arrival of a TMA store operation.

   :param \*args: Variable arguments for the store arrival operation.
   :returns: A handle to the store arrive operation.
   :rtype: tir.Call


.. py:function:: tma_store_wait(*args)

   Wait for completion of TMA store operations.

   :param \*args: Variable arguments specifying which store operations to wait for.
   :returns: A handle to the store wait operation.
   :rtype: tir.Call


.. py:function:: set_max_nreg(reg_count, is_inc)

   Set the maximum number of registers to use.

   Detailed documentation:
   https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

   :param reg_count: The number of registers to allocate.
   :type reg_count: int
   :param is_inc: Whether to increment or decrement the register count:
                  0 to decrement, 1 to increment.
   :type is_inc: int
   :returns: A handle to the register setting operation.
   :rtype: tir.Call


.. py:function:: inc_max_nreg(reg_count)

   Increment the maximum number of registers to use.


.. py:function:: dec_max_nreg(reg_count)

   Decrement the maximum number of registers to use.


.. py:function:: annotate_producer_reg_dealloc(reg_count = 24)

   Annotate the register deallocation count for producer warp groups.


.. py:function:: annotate_consumer_reg_alloc(reg_count = 240)

   Annotate the register allocation count for consumer warp groups.


.. py:function:: no_set_max_nreg()

   Disable the maximum register limit setting.


.. py:function:: disable_warp_group_reg_alloc()

   Disable warp-group register allocation.


.. py:function:: mbarrier_wait_parity(mbarrier, parity)

   Wait for a memory barrier parity condition.

   :param mbarrier: The memory barrier to wait on.
   :type mbarrier: Union[int, PrimExpr]
   :param parity: The parity value to wait for.
   :type parity: Union[int, Var]

   .. rubric:: Examples

   .. code-block:: python

      # Wait for parity 0 on barrier 0
      T.mbarrier_wait_parity(0, 0)

      # Wait for parity value in variable ko on barrier 1
      T.mbarrier_wait_parity(1, ko)

      # Wait using a barrier handle
      barrier = T.get_mbarrier(0)
      T.mbarrier_wait_parity(barrier, 1)

      # Common usage in pipelined kernels:
      for ko in range(num_stages):
          # Producer waits for consumer to finish the previous iteration
          T.mbarrier_wait_parity(1, ko ^ 1)
          # Producer copies data
          T.copy(A_global, A_shared)
          # Producer signals data ready
          T.mbarrier_arrive(0)

          # Consumer waits for producer data
          T.mbarrier_wait_parity(0, ko)
          # Consumer computes
          T.gemm(A_shared, B_shared, C_local)
          # Consumer signals completion
          T.mbarrier_arrive(1)

   :returns: A handle to the barrier wait operation.
   :rtype: tir.Call


.. py:function:: mbarrier_arrive(mbarrier)

   Arrive at a memory barrier.

   :param mbarrier: The memory barrier to arrive at.
   :type mbarrier: Union[int, PrimExpr]


.. py:function:: mbarrier_expect_tx(*args)

   Set the expected transaction count for a memory barrier.

   :param \*args: Variable arguments specifying the expected transaction count.
   :returns: A handle to the barrier expectation operation.
   :rtype: tir.Call


.. py:function:: warpgroup_arrive()

   Signal warpgroup readiness for subsequent WGMMA operations.

   :returns: A handle to the warpgroup arrive operation.
   :rtype: tir.Call


.. py:function:: warpgroup_commit_batch()

   Commit the current warpgroup batch for WGMMA operations.

   :returns: A handle to the warpgroup commit batch operation.
   :rtype: tir.Call


.. py:function:: warpgroup_wait(num_mma)

   Wait for completion of the specified warpgroup batch.

   :param num_mma: Identifier of the warpgroup MMA batch to wait on.
   :type num_mma: int
   :returns: A handle to the warpgroup wait operation.
   :rtype: tir.Call
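   .. rubric:: Example

   A minimal sketch of the usual arrive/commit/wait sequence, assuming the
   CUTLASS-style convention in which waiting on batch 0 drains all pending
   WGMMA batches; the buffer names are illustrative:

   .. code-block:: python

      T.warpgroup_arrive()
      # Issue the MMA; on Hopper this lowers to WGMMA instructions
      T.gemm(A_shared, B_shared, C_local)
      T.warpgroup_commit_batch()
      # Block until the committed batch has completed
      T.warpgroup_wait(0)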
.. py:function:: get_lane_idx(warp_size = None)

   Return the logical lane index of the calling thread within a warp.

   :param warp_size: Logical warp (or wavefront) size. Defaults to 32 on NVIDIA
                     and 64 on AMD.
   :type warp_size: Union[int, PrimExpr], optional

   .. rubric:: Example

   >>> lane = T.get_lane_idx()
   >>> custom_lane = T.get_lane_idx(64)  # override warp size explicitly

   .. rubric:: Implementation Notes

   Lowers to the CUDA helper ``tl::get_lane_idx(warp_size)`` defined in
   ``src/tl_templates/cuda/intrin.h``, which computes the lane index from the
   linear thread id using the provided ``warp_size``.


.. py:function:: get_warp_idx_sync(warp_size = None)

   Return the canonical warp index, assuming the warp's threads are converged.

   :param warp_size: Logical warp size used for the index calculation.
   :type warp_size: Union[int, PrimExpr], optional

   .. rubric:: Example

   >>> warp = T.get_warp_idx_sync()
   >>> custom_warp = T.get_warp_idx_sync(64)

   .. rubric:: Implementation Notes

   Emits ``tl::get_warp_idx_sync(warp_size)``, which divides the block-linear
   thread id by ``warp_size``, matching the semantics of CUTLASS' canonical
   helpers.


.. py:function:: get_warp_idx(warp_size = None)

   Return the canonical warp index without synchronizing the warp.

   :param warp_size: Logical warp size used for the index calculation.
   :type warp_size: Union[int, PrimExpr], optional

   .. rubric:: Example

   >>> warp = T.get_warp_idx()
   >>> custom_warp = T.get_warp_idx(64)

   .. rubric:: Implementation Notes

   Lowers to ``tl::get_warp_idx(warp_size)``, which divides the block-linear
   thread id by the provided ``warp_size`` without requiring warp convergence.


.. py:function:: get_warp_group_idx(warp_size = None, warps_per_group = None)

   Return the canonical warp group index for the calling thread.

   :param warp_size: Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).
   :type warp_size: Union[int, PrimExpr], optional
   :param warps_per_group: Number of warps per warp group. Defaults to 4 on
                           NVIDIA architectures.
   :type warps_per_group: Union[int, PrimExpr], optional

   .. rubric:: Example

   >>> group = T.get_warp_group_idx()
   >>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group

   .. rubric:: Implementation Notes

   Generates ``tl::get_warp_group_idx(warp_size, warps_per_group)``, which
   divides the block-linear thread id by ``warp_size * warps_per_group``,
   matching the canonical ordering while allowing architecture-specific
   overrides.


.. py:function:: shuffle_elect(thread_extent)

   Elect exactly one lane within a logical thread group.

   :param thread_extent: Size (in threads) of the group in which a single lane
                         should be elected. Passing 0 elects a single lane in
                         the entire thread block.
   :type thread_extent: int

   .. rubric:: Example

   >>> is_leader = T.shuffle_elect(64)
   >>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))

   .. rubric:: Implementation Notes

   Lowered to the CUDA helper ``tl::tl_shuffle_elect()`` defined in
   ``src/tl_templates/cuda/intrin.h``, which relies on
   ``cutlass::canonical_warp_idx_sync()`` and ``cute::elect_one_sync()``
   (or ``__shfl_sync``) to pick one lane per group.


.. py:function:: warpgroup_fence_operand(buffer_or_ptr, offset = 0, num_regs = None, dtype = None)

   Insert a warpgroup fence for the destination accumulator registers.

   This prevents NVCC from sinking uses of accumulator fragments past the
   corresponding WGMMA operations by issuing an empty inline assembly barrier
   on every register.

   :param buffer_or_ptr: A buffer representing the accumulator fragment, a
                         buffer load/region that identifies a starting element
                         within the fragment, or a pointer expression (e.g.,
                         tvm_access_ptr/address_of/typed Var).
   :type buffer_or_ptr: Buffer | BufferLoad | BufferRegion | PrimExpr
   :param offset: Element offset from the start of the accumulator fragment.
   :type offset: int | PrimExpr
   :param num_regs: Number of 32-bit registers to fence. If None and a Buffer
                    is provided, it is derived from the buffer shape and dtype.
   :type num_regs: int | PrimExpr | None
   :param dtype: Data type string of the accumulator elements. When passing a
                 buffer or buffer-derived expression, the dtype is inferred; it
                 is required only for a raw pointer expression whose dtype
                 cannot be inferred.
   :type dtype: str | None
   :returns: A handle to the warpgroup fence operation.
   :rtype: tir.Call


.. py:function:: wait_wgmma(id)

   Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

   :param id: The id of the WGMMA operation to wait for.
   :type id: int
   :returns: A handle to the WGMMA wait operation.
   :rtype: tir.Call


.. py:function:: barrier_wait(barrier_id, parity = None)

   Wait for a memory barrier to complete.

   :param barrier_id: The memory barrier to wait on.
   :type barrier_id: Union[int, PrimExpr]
   :param parity: The parity value to wait for.
   :type parity: Union[int, Var], optional
   :returns: A handle to the barrier wait operation.
   :rtype: tir.Call

   The current implementation is syntactic sugar for ``mbarrier_wait_parity``,
   as only parity values 0 and 1 are supported.


.. py:function:: barrier_arrive(barrier_id)

   Arrive at a memory barrier.

   :param barrier_id: The memory barrier to arrive at.
   :type barrier_id: Union[int, PrimExpr]
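   .. rubric:: Example

   A minimal sketch pairing ``barrier_arrive`` with ``barrier_wait`` in a
   pipelined exchange; in a warp-specialized kernel the producer and consumer
   sides run in different warp groups, shown inline here for brevity, and the
   barrier id, stage count, and buffer names are illustrative:

   .. code-block:: python

      for ko in range(num_stages):
          # Producer publishes a tile, then signals barrier 0
          T.copy(A_global, A_shared)
          T.barrier_arrive(0)

          # Consumer waits on barrier 0 with the iteration's parity (0 or 1)
          T.barrier_wait(0, ko % 2)
          T.gemm(A_shared, B_shared, C_local)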
.. py:function:: shfl_xor(value, offset)

   Perform a warp shuffle with an XOR lane offset.

   :param value: The value to shuffle.
   :type value: Union[int, PrimExpr]
   :param offset: The offset for the shuffle operation.
   :type offset: Union[int, PrimExpr]
   :returns: A handle to the shuffle operation.
   :rtype: tir.Call


.. py:function:: shfl_down(value, offset)

   Perform a warp shuffle with a downward lane offset.

   :param value: The value to shuffle.
   :type value: Union[int, PrimExpr]
   :param offset: The offset for the shuffle operation.
   :type offset: Union[int, PrimExpr]
   :returns: A handle to the shuffle operation.
   :rtype: tir.Call


.. py:function:: shfl_up(value, offset)

   Perform a warp shuffle with an upward lane offset.

   :param value: The value to shuffle.
   :type value: Union[int, PrimExpr]
   :param offset: The offset for the shuffle operation.
   :type offset: Union[int, PrimExpr]
   :returns: A handle to the shuffle operation.
   :rtype: tir.Call
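   .. rubric:: Example

   A sketch of a warp-level sum reduction built from these shuffle
   intrinsics, assuming a 32-lane warp; ``local_sum`` is an illustrative
   one-element fragment, not part of this API:

   .. code-block:: python

      # Tree reduction: after the loop, lane 0 holds the warp-wide sum
      for offset in [16, 8, 4, 2, 1]:
          local_sum[0] += T.shfl_down(local_sum[0], offset)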
.. py:function:: sync_threads(barrier_id = None, arrive_count = None)

   Synchronize all threads in a block.


.. py:function:: sync_global()

   Synchronize all threads in the entire grid.


.. py:function:: sync_grid()

   Synchronize all threads in a grid.


.. py:function:: initialize_wgmma_descriptor(descriptor, start_address, layout_type_ = 0, leading_byte_offset = 0, stride_byte_offset = 0)

   Initialize a WGMMA/UTCMMA shared-memory descriptor.


.. py:function:: initialize_tcgen05_descriptor(descriptor, start_address, leading_byte_offset, stride_byte_offset, base_offset = 0, leading_is_absolute = False, swizzle_mode = 0)

   Initialize a TCGEN05 shared-memory descriptor.


.. py:function:: increase_descriptor_offset(descriptor, offset)

   Increase the offset of a memory descriptor.

   :param descriptor: The memory descriptor to modify.
   :type descriptor: PrimExpr
   :param offset: The offset value to add.
   :type offset: PrimExpr
   :returns: A handle representing the modified descriptor.
   :rtype: PrimExpr


.. py:function:: loop_break()

   Break out of the innermost loop.


.. py:function:: cp_async_barrier_noinc(barrier_id)

   Perform a PTX async-copy barrier arrival using ``cp.async.mbarrier.arrive.noinc``.


.. py:function:: tcgen05_mma_arrive(mbar_ptr)

   Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

   :param mbar_ptr: Pointer to the mbarrier object in shared memory (e.g., ``Barrier*``).
   :type mbar_ptr: PrimExpr


.. py:function:: ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)

   TVM intrinsic for PTX tensor core MMA instructions on SM70 (Volta).

   This intrinsic provides SM70-specific MMA operations that support the
   m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation.

   :param shape: The shape of the mma fragment (e.g., "m16n16k4").
   :type shape: str
   :param A_layout: The layout of multiplicand fragment A ("row" or "col").
   :type A_layout: str
   :param B_layout: The layout of multiplicand fragment B ("row" or "col").
   :type B_layout: str
   :param A_dtype: The data type of multiplicand fragment A (typically "fp16").
   :type A_dtype: str
   :param B_dtype: The data type of multiplicand fragment B (typically "fp16").
   :type B_dtype: str
   :param C_dtype: The data type of accumulator fragment C ("fp16" or "fp32").
   :type C_dtype: str
   :param multiplicand_a: The multiplicand fragment A variable.
   :type multiplicand_a: Var
   :param a_index: The index of multiplicand fragment A.
   :type a_index: Expr
   :param multiplicand_b: The multiplicand fragment B variable.
   :type multiplicand_b: Var
   :param b_index: The index of multiplicand fragment B.
   :type b_index: Expr
   :param accumulator: The accumulator fragment C variable.
   :type accumulator: Var
   :param c_index: The index of accumulator fragment C.
   :type c_index: Expr
   :returns: **call** -- The call expression.
   :rtype: PrimExpr

   .. rubric:: Examples

   Following the TVM intrinsic convention, the call takes its result dtype as
   the leading argument, before the parameters documented above:

   >>> T.ptx_mma_sm70(
   ...     "float16",
   ...     "m16n16k4",
   ...     "row",
   ...     "col",
   ...     "fp16",
   ...     "fp16",
   ...     "fp16",
   ...     A_local.data,
   ...     0,
   ...     B_local.data,
   ...     0,
   ...     C_local.data,
   ...     0,
   ... )