tilelang.language.builtin

Builtin operations exposed on the TileLang language surface.

Functions

access_ptr(base[, access_type, offset, extent, ...])

Create a TileLang tl.access_ptr from a buffer-like base location.

deallocate_tmem(tmem)

Explicitly deallocate a TMEM buffer allocated by T.alloc_tmem.

create_tma_descriptor(*args)

Create a Tensor Memory Access (TMA) descriptor.

tma_load(*args)

Perform a Tensor Memory Access (TMA) load operation.

tma_load_2sm(*args)

Perform a TMA load with 2SM (two Streaming Multiprocessors) on Blackwell.

fence_proxy_async()

Issue a shared memory fence for asynchronous proxy operations.

tma_store_arrive()

Signal the arrival of a TMA store operation.

tma_store_wait([count])

Wait for completion of TMA store operations.

set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use.

inc_max_nreg(reg_count)

Increment the maximum number of registers to use.

dec_max_nreg(reg_count)

Decrement the maximum number of registers to use.

annotate_producer_reg_dealloc([reg_count])

Annotate the register-deallocation count for producer warps.

annotate_consumer_reg_alloc([reg_count])

Annotate the register-allocation count for consumer warps.

no_set_max_nreg()

Disable the maximum register limit setting.

disable_warp_group_reg_alloc()

Disable warp-group register allocation.

ptx_arrive_cluster_barrier(mbarrier, cta_id)

Arrive at a shared barrier in cluster.

mbarrier_wait_parity(mbarrier, parity)

Wait for memory barrier parity condition.

mbarrier_arrive(mbarrier[, cta_id])

Arrive at memory barrier.

mbarrier_expect_tx(mbarrier, tx)

Set expected transaction count for memory barrier.

mbarrier_arrive_expect_tx(mbarrier, tx)

Arrive at a memory barrier and expect completion of async transactions.

warpgroup_arrive()

Signal warpgroup readiness for subsequent WGMMA operations.

warpgroup_commit_batch()

Commit the current warpgroup batch for WGMMA operations.

warpgroup_wait(num_mma)

Wait for completion of the specified warpgroup batch.

get_lane_idx([warp_size])

Return the logical lane index of the calling thread within a warp.

get_warp_idx_sync([warp_size])

Return the canonical warp index, assuming the warp's threads are converged.

get_warp_idx([warp_size])

Return the canonical warp index without synchronizing the warp.

get_warp_group_idx([warp_size, warps_per_group])

Return the canonical warp group index for the calling thread.

shuffle_elect(thread_extent)

Elect exactly one lane within a logical thread group.

warpgroup_fence_operand(buffer_or_ptr[, offset, ...])

Insert a warpgroup fence for the destination accumulator registers.

wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

barrier_wait(mbarrier, parity)

Wait for a memory barrier to complete.

barrier_arrive(mbarrier)

Arrive at a memory barrier.

shfl_xor(value, delta[, width, mask])

XOR-swap value across lanes (__shfl_xor_sync on CUDA, __shfl_xor on HIP).

shfl_down(value, delta[, width, mask])

Shift value down by delta lanes (__shfl_down_sync on CUDA, __shfl_down on HIP).

shfl_up(value, delta[, width, mask])

Shift value up by delta lanes (__shfl_up_sync on CUDA, __shfl_up on HIP).

sync_threads([barrier_id, arrive_count])

Synchronize all threads in a block.

sync_warp([mask])

Synchronize all threads in a warp.

shfl_sync(value, srcLane[, width, mask])

Broadcast value from srcLane to all lanes in the subgroup of width lanes.

any_sync(predicate[, mask])

Non-zero if ANY active lane in mask has a non-zero predicate.

all_sync(predicate[, mask])

Non-zero only if ALL active lanes in mask have a non-zero predicate.

ballot_sync(predicate[, mask])

Return a uint64 bitmask of lanes in mask whose predicate is set.

ballot(predicate)

Full-warp / full-wavefront ballot; equivalent to ballot_sync(predicate) with the default full mask.

activemask()

Return a uint64 bitmask of currently active (non-exited) lanes.

syncthreads_count(predicate)

Block barrier that returns the number of threads whose predicate is non-zero.

syncthreads_and(predicate)

Block barrier that returns non-zero only if ALL threads have a non-zero predicate.

syncthreads_or(predicate)

Block barrier that returns non-zero if ANY thread has a non-zero predicate.

match_any_sync(value[, mask])

Return a uint32 bitmask of lanes in mask whose value equals the calling lane’s value.

match_all_sync(value[, mask])

Return mask if all lanes in mask agree on value, else 0.

sync_global()

Synchronize all threads in the entire grid.

sync_grid()

Synchronize all threads in a grid.

initialize_wgmma_descriptor(descriptor, start_address)

Initialize a WGMMA/UTCMMA shared-memory descriptor.

initialize_tcgen05_descriptor(descriptor, ...[, ...])

Initialize a TCGEN05 shared-memory descriptor.

increase_descriptor_offset(descriptor, offset)

Increase the offset of a memory descriptor.

loop_break()

Break out of the innermost loop.

cp_async_barrier_noinc(barrier)

Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.

tcgen05_mma_arrive(mbar[, arrive_2cta])

Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

tcgen05_before_thread_sync()

tcgen05_after_thread_sync()

tcgen05_cp_warpx4(smem_src, tmem_dst[, ...])

Copy one or more packed scale-factor chunks from shared memory to tensor memory.

tcgen05_sf_warp_transpose(smem_src)

Warp-level transpose for one or more packed scale-factor chunks in shared memory.

ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, ...)

TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).

ds_read_tr16_b64(src)

LDS transpose read, 64-bit, 16-element transpose (gfx950 only).

ds_read_tr8_b64(src)

LDS transpose read, 64-bit, 8-element transpose (gfx950 only).

ldg32(src[, pred])

Load 32 bits (4 bytes) from global memory using explicit PTX instructions.

ldg64(src[, pred])

Load 64 bits (8 bytes) from global memory using explicit PTX instructions.

ldg128(src[, pred])

Load 128 bits (16 bytes) from global memory using explicit PTX instructions.

ldg256(src[, pred])

Load 256 bits (32 bytes) from global memory using explicit PTX instructions.

stg32(dst, value[, pred])

Store 32 bits (4 bytes) to global memory using explicit PTX instructions.

stg64(dst, value[, pred])

Store 64 bits (8 bytes) to global memory using explicit PTX instructions.

stg128(dst, value[, pred])

Store 128 bits (16 bytes) to global memory using explicit PTX instructions.

stg256(dst, value[, pred])

Store 256 bits (32 bytes) to global memory using explicit PTX instructions.

Module Contents

tilelang.language.builtin.access_ptr(base, access_type='r', *extents, offset=0, extent=None, ignore_last_ndim=0)

Create a TileLang tl.access_ptr from a buffer-like base location.

This is a frontend convenience wrapper that keeps a BufferLoad argument in the resulting call so downstream passes can recover the referenced tir.Buffer (including strides/storage scope) and the rw_mask (read/write intent) required by synchronization and safety checks.

The returned tl.access_ptr is expected to be lowered to tir.builtin.tvm_access_ptr later in the TileLang compilation pipeline.

Parameters:
  • base (BufferLikeType) – The base location to take the address of. Supported:

    - tir.BufferLoad (e.g. A[i, j]): pointer to that element
    - tir.BufferRegion: pointer to the region minima
    - tir.Buffer: pointer to the beginning of the buffer
    - tir.Var let-bound to one of the above (inside a TileLang frame)

  • access_type (str | int) – Access mask for the pointer. Common string forms: “r”, “w”, “rw”. Integer bitmask is also accepted (1=read, 2=write, 3=read-write).

  • *extents (PrimExpr | int) –

    Optional per-axis extents. When provided and extent is not specified, the 1D extent passed to tvm_access_ptr is computed as the product of the provided extents (padding leading dimensions with 1 if needed).

    For example:

    - T.access_ptr(A[i], "r") -> extent defaults to 1 (element pointer)
    - T.access_ptr(A[i], "r", 16) -> extent=16
    - T.access_ptr(A[i, j], "r", m, n) -> extent=m*n

  • offset (PrimExpr | int) – Additional element offset from the base location.

  • extent (PrimExpr | int | None) – Optional explicit 1D extent override (in elements). If provided, it takes precedence over *extents.

  • ignore_last_ndim (int) – If non-zero, the base linear offset is computed only over the leading dimensions, ignoring the last ignore_last_ndim axes. This is useful when treating an N-D buffer as a view of its trailing sub-tensor.

Returns:

ptr – A handle-typed tir.Call to tl.access_ptr.

Return type:

PrimExpr
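The extent and access-mask resolution described above can be restated as a small plain-Python sketch. This is illustrative only, not TileLang code; the function names rw_mask and resolved_extent are invented for the example.

```python
from math import prod

def rw_mask(access_type):
    """Map "r"/"w"/"rw" (or an int bitmask) to the tvm_access_ptr mask."""
    if isinstance(access_type, int):
        return access_type
    mask = 0
    if "r" in access_type:
        mask |= 1  # read bit
    if "w" in access_type:
        mask |= 2  # write bit
    return mask

def resolved_extent(*extents, extent=None):
    """An explicit `extent` takes precedence; otherwise the per-axis
    extents are multiplied; with neither, the pointer covers one element."""
    if extent is not None:
        return extent
    if extents:
        return prod(extents)
    return 1
```

For instance, rw_mask("rw") yields 3 and resolved_extent(16, 8) yields 128, mirroring the T.access_ptr(A[i, j], "r", m, n) case above.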

tilelang.language.builtin.deallocate_tmem(tmem)

Explicitly deallocate a TMEM buffer allocated by T.alloc_tmem.

By default, TileLang inserts a TMEM deallocation automatically at the end of the allocation block. Calling T.deallocate_tmem(buf) suppresses that automatic tail deallocation for buf and lowers an explicit deallocation at the call site instead.

Notes:

  • The deallocation must obey the hardware TMEM rules: it should be issued by the same warp that performed the allocation.

  • Once this API is used, the buffer lifetime is user-managed for the current block; deallocating too early or conditionally is the user’s responsibility.

Parameters:

tmem (tvm.tir.Buffer) – A TMEM buffer previously returned by T.alloc_tmem.

Return type:

None

tilelang.language.builtin.create_tma_descriptor(*args)

Create a Tensor Memory Access (TMA) descriptor.

This is an internal API used by copy lowering. The argument list depends on the tensor rank and encodes the full TMA descriptor configuration:

create_tma_descriptor(data_type, rank, global_addr,
                      global_shape…, global_stride…,
                      smem_box…, smem_stride…,
                      interleave, swizzle, l2_promotion, oob_fill)

Total arguments: 7 + 4 * rank.

Returns:

A handle to the created TMA descriptor

Return type:

tir.Call
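The argument count can be sanity-checked with a one-line helper (a sketch for illustration; the helper name is not part of the API):

```python
def tma_descriptor_arg_count(rank):
    """The call packs 7 scalar fields (data_type, rank, global_addr,
    interleave, swizzle, l2_promotion, oob_fill) plus 4 arrays of
    length `rank` (global_shape, global_stride, smem_box, smem_stride)."""
    return 7 + 4 * rank
```

So a rank-2 tensor needs 15 arguments and a rank-1 tensor needs 11.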

tilelang.language.builtin.tma_load(*args)

Perform a Tensor Memory Access (TMA) load operation.

This is an internal API used by copy lowering. Arguments:

tma_load(descriptor, mbarrier, smem_addr, coord_0, …, coord_n, eviction_policy)

Returns:

A handle to the TMA load operation

Return type:

tir.Call

tilelang.language.builtin.tma_load_2sm(*args)

Perform a TMA load with 2SM (two Streaming Multiprocessors) on Blackwell.

This is an internal API. Same arguments as tma_load(), but with the use_2cta annotation enabled for 2-CTA cooperative loading.

Returns:

A handle to the TMA load operation

Return type:

tir.Call

tilelang.language.builtin.fence_proxy_async()

Issue a shared memory fence for asynchronous proxy operations.

Ensures that prior asynchronous operations (e.g. TMA stores) are visible to subsequent memory accesses. Maps to fence.proxy.async.shared::cta.

Returns:

A handle to the fence operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_arrive()

Signal the arrival of a TMA store operation.

Commits the current group of outstanding TMA store operations. Maps to cp.async.bulk.commit_group.

Returns:

A handle to the store arrive operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_wait(count=0)

Wait for completion of TMA store operations.

Waits until the number of outstanding TMA store groups is at most count. Maps to the PTX instruction cp.async.bulk.wait_group.read <count>.

Parameters:

count (int) – The maximum number of outstanding store groups allowed to remain in flight. Defaults to 0 (wait for all stores to complete).

Returns:

A handle to the store wait operation

Return type:

tir.Call
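The commit/wait pairing of tma_store_arrive and tma_store_wait can be modeled with a toy counter (plain Python, purely illustrative; the class name is invented): arrive commits the in-flight stores as one group, and wait(count) blocks until at most count committed groups remain outstanding.

```python
class StoreGroupTracker:
    """Toy model of the cp.async.bulk store-group counters."""

    def __init__(self):
        self.outstanding = 0

    def arrive(self):
        # Models tma_store_arrive: commit current stores as one group.
        self.outstanding += 1

    def complete_one(self):
        # Models the hardware retiring one committed group.
        self.outstanding = max(0, self.outstanding - 1)

    def wait_satisfied(self, count=0):
        # Models tma_store_wait(count): proceed once at most `count`
        # groups are still in flight.
        return self.outstanding <= count
```

With two groups committed, wait_satisfied(0) is false until both retire, while wait_satisfied(2) passes immediately.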

tilelang.language.builtin.set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use. Detailed Documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

Parameters:
  • reg_count (int) – The number of registers to allocate

  • is_inc (int) – Whether to increment or decrement the register count: 0 to decrement, 1 to increment

Returns:

A handle to the register setting operation

Return type:

tir.Call

tilelang.language.builtin.inc_max_nreg(reg_count)

Increment the maximum number of registers to use.

Parameters:

reg_count (int)

tilelang.language.builtin.dec_max_nreg(reg_count)

Decrement the maximum number of registers to use.

Parameters:

reg_count (int)

tilelang.language.builtin.annotate_producer_reg_dealloc(reg_count=24)

Annotate the register-deallocation count for producer warps.

Parameters:

reg_count (int)

tilelang.language.builtin.annotate_consumer_reg_alloc(reg_count=240)

Annotate the register-allocation count for consumer warps.

Parameters:

reg_count (int)

tilelang.language.builtin.no_set_max_nreg()

Disable the maximum register limit setting.

tilelang.language.builtin.disable_warp_group_reg_alloc()

Disable warp-group register allocation.

tilelang.language.builtin.ptx_arrive_cluster_barrier(mbarrier, cta_id)

Arrive at a shared barrier in cluster.

Parameters:
  • mbarrier (tilelang._typing.BarrierType) – The memory barrier to arrive at

  • cta_id (int | tvm.tir.Var) – The peer CTA rank in the cluster to arrive at

tilelang.language.builtin.mbarrier_wait_parity(mbarrier, parity)

Wait for memory barrier parity condition.

Parameters:
  • mbarrier (tilelang._typing.BarrierType) – The memory barrier to wait on

  • parity (int | tvm.tir.Var) – The parity value to wait for

Examples

mbar = T.alloc_barrier(1)
# Wait for parity 0 on a single mbarrier
T.mbarrier_wait_parity(mbar, 0)

mbars = T.alloc_barrier([128] * n)
# Wait for parity value on one of the mbarriers
T.mbarrier_wait_parity(mbars[ko], ko)

# Common usage in pipelined kernels:
for ko in range(num_stages):
    # Producer waits for consumer to finish previous iteration
    T.mbarrier_wait_parity(mbars[1], ko ^ 1)
    # Producer copies data
    T.copy(A_global, A_shared)
    # Producer signals data ready
    T.mbarrier_arrive(mbars[0])

    # Consumer waits for producer data
    T.mbarrier_wait_parity(mbars[0], ko)
    # Consumer computes
    T.gemm(A_shared, B_shared, C_local)
    # Consumer signals completion
    T.mbarrier_arrive(mbars[1])
Returns:

A handle to the barrier wait operation

Return type:

tir.Call

tilelang.language.builtin.mbarrier_arrive(mbarrier, cta_id=None)

Arrive at memory barrier.

Parameters:
  • mbarrier (tilelang._typing.BarrierType) – The memory barrier to arrive at

  • cta_id (int | tvm.tir.Var | None) – The peer CTA rank in the cluster to arrive at (only valid for cluster barriers). If not provided, arrives on the current CTA’s barrier.

tilelang.language.builtin.mbarrier_expect_tx(mbarrier, tx)

Set expected transaction count for memory barrier.

Parameters:
  • mbarrier (tilelang._typing.BarrierType) – The memory barrier to set the expected transaction count for

  • tx (int) – The expected transaction count

Returns:

A handle to the barrier expectation operation

Return type:

tir.Call

tilelang.language.builtin.mbarrier_arrive_expect_tx(mbarrier, tx)

Arrive at a memory barrier and expect completion of async transactions.

Parameters:
  • mbarrier (tilelang._typing.BarrierType)

  • tx (int)

tilelang.language.builtin.warpgroup_arrive()

Signal warpgroup readiness for subsequent WGMMA operations.

Returns:

A handle to the warpgroup arrive operation.

Return type:

tir.Call

tilelang.language.builtin.warpgroup_commit_batch()

Commit the current warpgroup batch for WGMMA operations.

Returns:

A handle to the warpgroup commit batch operation.

Return type:

tir.Call

tilelang.language.builtin.warpgroup_wait(num_mma)

Wait for completion of the specified warpgroup batch.

Parameters:

num_mma (int) – Identifier of the warpgroup MMA batch to wait on.

Returns:

A handle to the warpgroup wait operation.

Return type:

tir.Call

tilelang.language.builtin.get_lane_idx(warp_size=None)

Return the logical lane index of the calling thread within a warp.

Parameters:

warp_size (int | PrimExpr | None) – Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.

Return type:

tvm.tir.PrimExpr

Example

>>> lane = T.get_lane_idx()
>>> custom_lane = T.get_lane_idx(64)  # override warp size explicitly

Implementation Notes

Lowers to the CUDA helper tl::get_lane_idx(warp_size) defined in src/tl_templates/cuda/intrin.h, which computes the lane index from the linear thread id using the provided warp_size.

tilelang.language.builtin.get_warp_idx_sync(warp_size=None)

Return the canonical warp index, assuming the warp’s threads are converged.

Parameters:

warp_size (int | PrimExpr | None) – Logical warp size used for the index calculation.

Return type:

tvm.tir.PrimExpr

Example

>>> warp = T.get_warp_idx_sync()
>>> custom_warp = T.get_warp_idx_sync(64)

Implementation Notes

Emits tl::get_warp_idx_sync(warp_size) which divides the block-linear thread id by warp_size, matching the semantics of CUTLASS’ canonical helpers.

tilelang.language.builtin.get_warp_idx(warp_size=None)

Return the canonical warp index without synchronizing the warp.

Parameters:

warp_size (int | PrimExpr | None) – Logical warp size used for the index calculation.

Return type:

tvm.tir.PrimExpr

Example

>>> warp = T.get_warp_idx()
>>> custom_warp = T.get_warp_idx(64)

Implementation Notes

Lowers to tl::get_warp_idx(warp_size) which divides the block-linear thread id by the provided warp_size without requiring warp convergence.

tilelang.language.builtin.get_warp_group_idx(warp_size=None, warps_per_group=None)

Return the canonical warp group index for the calling thread.

Parameters:
  • warp_size (int | PrimExpr | None) – Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).

  • warps_per_group (int | PrimExpr | None) – Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.

Return type:

tvm.tir.PrimExpr

Example

>>> group = T.get_warp_group_idx()
>>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group

Implementation Notes

Generates tl::get_warp_group_idx(warp_size, warps_per_group) which divides the block-linear thread id by warp_size * warps_per_group, matching the canonical ordering while allowing architecture-specific overrides.
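All three index helpers above derive from the block-linear thread id by the same division/modulo arithmetic. A plain-Python restatement (illustrative only; these names are not TileLang APIs):

```python
def lane_idx(tid, warp_size=32):
    # Models tl::get_lane_idx: position within the warp.
    return tid % warp_size

def warp_idx(tid, warp_size=32):
    # Models tl::get_warp_idx / get_warp_idx_sync: warp within the block.
    return tid // warp_size

def warp_group_idx(tid, warp_size=32, warps_per_group=4):
    # Models tl::get_warp_group_idx: warp-group within the block.
    return tid // (warp_size * warps_per_group)
```

For example, thread 70 of a block is lane 6 of warp 2, and thread 130 falls in warp group 1 with the default 4-warp groups.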

tilelang.language.builtin.shuffle_elect(thread_extent)

Elect exactly one lane within a logical thread group.

Parameters:

thread_extent (int) – Size (in threads) of the group in which a single lane should be elected. Passing 0 elects a single lane in the entire thread block.

Return type:

tvm.tir.PrimExpr

Example

>>> is_leader = T.shuffle_elect(64)
>>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))

Implementation Notes

Lowered to the CUDA helper tl::tl_shuffle_elect<thread_extent>() defined in src/tl_templates/cuda/intrin.h, which relies on cutlass::canonical_warp_idx_sync() and cute::elect_one_sync() (or __shfl_sync) to pick one lane per group.
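The election semantics can be sketched in plain Python. Note the hardware primitive only guarantees that exactly one lane per group is elected, not which one; this toy model (with an invented function name) picks the first lane of each group for concreteness.

```python
def elected(tid, thread_extent, block_size):
    """Model of shuffle_elect: True for exactly one thread per group of
    `thread_extent` threads; thread_extent == 0 means one per block."""
    group = block_size if thread_extent == 0 else thread_extent
    return tid % group == 0
```

Over a 128-thread block, thread_extent=64 elects two leaders and thread_extent=0 elects one.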

tilelang.language.builtin.warpgroup_fence_operand(buffer_or_ptr, offset=0, num_regs=None, dtype=None)

Insert a warpgroup fence for the destination accumulator registers.

This prevents NVCC from sinking uses of accumulator fragments past the corresponding WGMMA operations by issuing an empty inline assembly barrier on every register.

Parameters:
  • buffer_or_ptr (tilelang._typing.BufferLikeType | tvm.tir.PrimExpr) – A buffer representing the accumulator fragment, a buffer load/region that identifies a starting element within the fragment, or a pointer expression (e.g., tvm_access_ptr/address_of/typed Var).

  • offset (int | tvm.tir.PrimExpr) – Element offset from the start of the accumulator fragment.

  • num_regs (int | tvm.tir.PrimExpr | None) – Number of 32-bit registers to fence. If None and a Buffer is provided, it will be derived from the buffer shape and dtype.

  • dtype (tilelang._typing.DType | None) – Data type string of the accumulator elements. When passing a buffer or buffer-derived expression, dtype is inferred; it is required only when passing a raw pointer expression whose dtype cannot be inferred.

Returns:

A handle to the warpgroup fence operation.

Return type:

tir.Call

tilelang.language.builtin.wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

Parameters:

id (int) – The id of the WGMMA operation to wait for

Returns:

A handle to the WGMMA wait operation

Return type:

tir.Call

tilelang.language.builtin.barrier_wait(mbarrier, parity)

Wait for a memory barrier to complete.

Parameters:
  • mbarrier (tilelang._typing.BarrierType) – The memory barrier to wait on

  • parity (int | tvm.tir.Var) – The parity value to wait for

Returns:

A handle to the barrier wait operation

Return type:

tir.Call

The current implementation is syntactic sugar for mbarrier_wait_parity; only parity values 0 and 1 are supported.

tilelang.language.builtin.barrier_arrive(mbarrier)

Arrive at a memory barrier.

Parameters:

mbarrier (tilelang._typing.BarrierType) – The memory barrier to arrive at

tilelang.language.builtin.shfl_xor(value, delta, width=_DEFAULT_SHFL_WIDTH, mask=_FULL_WARP_MASK)

XOR-swap value across lanes (__shfl_xor_sync on CUDA, __shfl_xor on HIP — mask ignored on HIP).

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call)

  • delta (int | tvm.tir.PrimExpr | tvm.tir.Call)

  • width (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)
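The exchange pattern of shfl_xor — lane i receives the value held by lane i ^ delta within its width-wide subgroup — is the building block of butterfly reductions. A plain-Python model (illustrative; these helper names are not part of the API):

```python
def shfl_xor_model(values, delta, width):
    """Each lane reads the value of its XOR partner inside its subgroup."""
    out = []
    for lane in range(len(values)):
        base = (lane // width) * width        # start of this lane's subgroup
        peer = base + ((lane - base) ^ delta)
        out.append(values[peer])
    return out

def butterfly_sum(values, width):
    """Halving-delta butterfly: leaves the subgroup sum in every lane."""
    vals = list(values)
    delta = width // 2
    while delta >= 1:
        peers = shfl_xor_model(vals, delta, width)
        vals = [v + p for v, p in zip(vals, peers)]
        delta //= 2
    return vals
```

With width 8 and values 0..7, every lane ends up holding the total 28 after log2(8) = 3 exchange steps.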

tilelang.language.builtin.shfl_down(value, delta, width=_DEFAULT_SHFL_WIDTH, mask=_FULL_WARP_MASK)

Shift value down by delta lanes (__shfl_down_sync on CUDA, __shfl_down on HIP).

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call)

  • delta (int | tvm.tir.PrimExpr | tvm.tir.Call)

  • width (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)

tilelang.language.builtin.shfl_up(value, delta, width=_DEFAULT_SHFL_WIDTH, mask=_FULL_WARP_MASK)

Shift value up by delta lanes (__shfl_up_sync on CUDA, __shfl_up on HIP).

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call)

  • delta (int | tvm.tir.PrimExpr | tvm.tir.Call)

  • width (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)

tilelang.language.builtin.sync_threads(barrier_id=None, arrive_count=None)

Synchronize all threads in a block.

Parameters:
  • barrier_id (int)

  • arrive_count (int)

tilelang.language.builtin.sync_warp(mask=None)

Synchronize all threads in a warp.

Parameters:

mask (int)

tilelang.language.builtin.shfl_sync(value, srcLane, width=_DEFAULT_SHFL_WIDTH, mask=_FULL_WARP_MASK)

Broadcast value from srcLane to all lanes in the subgroup of width lanes (__shfl_sync on CUDA, __shfl on HIP — mask ignored on HIP).

Parameters:
  • value (int | tvm.tir.PrimExpr)

  • srcLane (int | tvm.tir.PrimExpr)

  • width (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)

tilelang.language.builtin.any_sync(predicate, mask=_FULL_WARP_MASK)

Non-zero if ANY active lane in mask has a non-zero predicate.

Lowers to __any_sync(mask, predicate) on CUDA and __any(predicate) on HIP (the mask is ignored on HIP because the full wavefront is always convergent).

Parameters:
  • predicate (int | tvm.tir.PrimExpr) – Integer condition to test.

  • mask (int | tvm.tir.PrimExpr) – Warp lane mask (defaults to 0xFFFFFFFF, i.e. all 32 lanes).

Returns:

Non-zero if any thread in the mask has a non-zero predicate.

Return type:

int32

tilelang.language.builtin.all_sync(predicate, mask=_FULL_WARP_MASK)

Non-zero only if ALL active lanes in mask have a non-zero predicate.

Lowers to __all_sync(mask, predicate) on CUDA and __all(predicate) on HIP.

Parameters:
  • predicate (int | tvm.tir.PrimExpr) – Integer condition to test.

  • mask (int | tvm.tir.PrimExpr) – Warp lane mask (defaults to 0xFFFFFFFF, i.e. all 32 lanes).

Returns:

Non-zero if all threads in the mask have a non-zero predicate.

Return type:

int32

tilelang.language.builtin.ballot_sync(predicate, mask=_FULL_WARP_MASK)

Return a uint64 bitmask of lanes in mask whose predicate is set.

CUDA: __ballot_sync(mask, predicate) returns unsigned int; codegen zero-extends it to uint64 (upper 32 bits always zero for 32-wide warps). HIP: __ballot(predicate) returns uint64 natively, covering all 64 wavefront lanes. The mask argument is ignored on HIP.

Returns:

Bitmask with bit N set if lane N’s predicate is non-zero.

Return type:

uint64

Parameters:
  • predicate (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)
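The returned bitmask has bit N set exactly when lane N is selected by mask and its predicate holds. A plain-Python sketch of that rule (the function name is invented for illustration):

```python
def ballot_model(predicates, mask=0xFFFFFFFF):
    """Bit N of the result is set when lane N is in `mask` and its
    predicate is non-zero."""
    result = 0
    for lane, pred in enumerate(predicates):
        if (mask >> lane) & 1 and pred:
            result |= 1 << lane
    return result
```

For lanes with predicates [1, 0, 1, 1] the ballot is 0b1101; masking a lane out clears its bit regardless of its predicate.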

tilelang.language.builtin.ballot(predicate)

Full-warp / full-wavefront ballot. Equivalent to ballot_sync(predicate) (i.e. with the default full warp mask).

Returns:

Bitmask with bit N set if lane N’s predicate is non-zero.

Return type:

uint64

Parameters:

predicate (int | tvm.tir.PrimExpr)

tilelang.language.builtin.activemask()

Return a uint64 bitmask of currently active (non-exited) lanes.

Lowers to __activemask() (zero-extended to uint64) on CUDA and __ballot(1) on HIP.

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.syncthreads_count(predicate)

Block barrier that returns the number of threads whose predicate evaluates to non-zero (__syncthreads_count on CUDA and HIP).

Parameters:

predicate (int | tvm.tir.PrimExpr)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.syncthreads_and(predicate)

Block barrier that returns non-zero only if ALL threads have a non-zero predicate (__syncthreads_and on CUDA and HIP).

Parameters:

predicate (int | tvm.tir.PrimExpr)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.syncthreads_or(predicate)

Block barrier that returns non-zero if ANY thread has a non-zero predicate (__syncthreads_or on CUDA and HIP).

Parameters:

predicate (int | tvm.tir.PrimExpr)

Return type:

tvm.tir.PrimExpr
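After the barrier, every thread observes the same value reduced over all threads' predicates. The three reductions amount to count, AND, and OR — sketched here in plain Python (names are illustrative):

```python
def syncthreads_count_model(preds):
    # __syncthreads_count: number of threads with a non-zero predicate.
    return sum(1 for p in preds if p)

def syncthreads_and_model(preds):
    # __syncthreads_and: non-zero only if ALL predicates are non-zero.
    return int(all(preds))

def syncthreads_or_model(preds):
    # __syncthreads_or: non-zero if ANY predicate is non-zero.
    return int(any(preds))
```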

tilelang.language.builtin.match_any_sync(value, mask=_FULL_WARP_MASK)

Return a uint32 bitmask of lanes in mask whose value equals the calling lane’s value. Lowers to __match_any_sync on CUDA (compute capability >= 7.0). Not supported on HIP.

Parameters:
  • value (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.match_all_sync(value, mask=_FULL_WARP_MASK)

Return mask if all lanes in mask agree on value, else 0.

Lowers to __match_all_sync on CUDA (compute capability >= 7.0); the trailing int* predicate output is hidden in codegen and discarded. Callers can reconstruct the predicate as result != 0. Not supported on HIP.

Parameters:
  • value (int | tvm.tir.PrimExpr)

  • mask (int | tvm.tir.PrimExpr)

Return type:

tvm.tir.PrimExpr
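The two match primitives can be modeled over the lanes selected by mask: match_any returns, per lane, the set of lanes holding an equal value, while match_all returns mask only when every selected lane agrees. A plain-Python sketch (function names are invented for the example):

```python
def match_any_model(values, lane, mask=0xFFFFFFFF):
    """Bitmask of selected lanes whose value equals `values[lane]`."""
    result = 0
    for other, v in enumerate(values):
        if (mask >> other) & 1 and v == values[lane]:
            result |= 1 << other
    return result

def match_all_model(values, mask=0xFFFFFFFF):
    """`mask` if all selected lanes agree on the value, else 0."""
    selected = [v for lane, v in enumerate(values) if (mask >> lane) & 1]
    return mask if len(set(selected)) <= 1 else 0
```

The hidden predicate of __match_all_sync is then simply result != 0.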

tilelang.language.builtin.sync_global()

Synchronize all threads in the entire grid.

tilelang.language.builtin.sync_grid()

Synchronize all threads in a grid.

tilelang.language.builtin.initialize_wgmma_descriptor(descriptor, start_address, layout_type_=0, leading_byte_offset=0, stride_byte_offset=0)

Initialize a WGMMA/UTCMMA shared-memory descriptor.

Parameters:
  • descriptor (tvm.tir.Buffer)

  • start_address (tvm.tir.PrimExpr)

  • layout_type_ (int)

  • leading_byte_offset (int)

  • stride_byte_offset (int)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.initialize_tcgen05_descriptor(descriptor, start_address, leading_byte_offset, stride_byte_offset, base_offset=0, leading_is_absolute=False, swizzle_mode=0)

Initialize a TCGEN05 shared-memory descriptor.

Parameters:
  • descriptor (tvm.tir.Buffer)

  • start_address (tvm.tir.PrimExpr)

  • leading_byte_offset (int)

  • stride_byte_offset (int)

  • base_offset (int)

  • leading_is_absolute (bool)

  • swizzle_mode (int)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.increase_descriptor_offset(descriptor, offset)

Increase the offset of a memory descriptor.

Parameters:
  • descriptor (PrimExpr) – The memory descriptor to modify.

  • offset (PrimExpr) – The offset value to increase.

Returns:

A handle representing the modified descriptor.

Return type:

PrimExpr

tilelang.language.builtin.loop_break()

Break out of the innermost loop.

tilelang.language.builtin.cp_async_barrier_noinc(barrier)

Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.

Parameters:

barrier (tilelang._typing.BarrierType)

tilelang.language.builtin.tcgen05_mma_arrive(mbar, arrive_2cta=False)

Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

Parameters:
  • mbar (tir.Buffer | BufferLoad | PrimExpr) – The mbarrier object in shared memory (e.g., Barrier*) or its address.

  • arrive_2cta (bool) – Whether to also arrive at the peer CTA’s barrier. If set, will be lowered to umma_arrive_multicast_2x1SM.

tilelang.language.builtin.tcgen05_before_thread_sync()
tilelang.language.builtin.tcgen05_after_thread_sync()
tilelang.language.builtin.tcgen05_cp_warpx4(smem_src, tmem_dst, tmem_col_offset=0, *, use_2cta=False)

Copy one or more packed scale-factor chunks from shared memory to tensor memory.

The helper lowers to one or more tcgen05.cp.cta_group::{1,2}.32x128b.warpx4 instructions. For 1D packed uint32 scale buffers, each 128-word chunk maps to 4 TMEM columns and the column offset is advanced automatically.

Parameters:

use_2cta (bool)

tilelang.language.builtin.tcgen05_sf_warp_transpose(smem_src)

Warp-level transpose for one or more packed scale-factor chunks in shared memory.

For 1D packed uint32 scale buffers, the helper automatically applies the transpose to each 128-word chunk in order.

tilelang.language.builtin.ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)

TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).

This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation.

Parameters:
  • shape (str) – The shape of mma fragment (e.g., “m16n16k4”).

  • A_layout (str) – The layout of multiplicand fragment A (“row” or “col”).

  • B_layout (str) – The layout of multiplicand fragment B (“row” or “col”).

  • A_dtype (str) – The data type of multiplicand fragment A (typically “fp16”).

  • B_dtype (str) – The data type of multiplicand fragment B (typically “fp16”).

  • C_dtype (str) – The data type of accumulator fragment C (“fp16” or “fp32”).

  • multiplicand_a (Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

Returns:

call – The call expression.

Return type:

PrimExpr

Examples

>>> T.ptx_mma_sm70(
...     "m16n16k4",
...     "row",
...     "col",
...     "fp16",
...     "fp16",
...     "fp16",
...     A_local.data,
...     0,
...     B_local.data,
...     0,
...     C_local.data,
...     0,
... )
tilelang.language.builtin.ds_read_tr16_b64(src)

LDS transpose read, 64-bit, 16-element transpose (gfx950 only).

Reads 8 bytes from LDS (__shared__ memory) with a 16-element transpose. Used for FP16/BF16 MFMA matrix B-loads on MI350/MI355X (gfx950).

Parameters:

src (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad in shared memory.

Returns:

The loaded 64-bit value as uint32x2.

Return type:

PrimExpr

Example

>>> val = T.ds_read_tr16_b64(smem[i])
tilelang.language.builtin.ds_read_tr8_b64(src)

LDS transpose read, 64-bit, 8-element transpose (gfx950 only).

Reads 8 bytes from LDS (__shared__ memory) with an 8-element transpose. Used for FP32 MFMA matrix B-loads on MI350/MI355X (gfx950).

Parameters:

src (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad in shared memory.

Returns:

The loaded 64-bit value as uint32x2.

Return type:

PrimExpr

Example

>>> val = T.ds_read_tr8_b64(smem[i])
tilelang.language.builtin.ldg32(src, pred=None)

Load 32 bits (4 bytes) from global memory using explicit PTX instructions.

Usage: T.ldg32(x[i]) or T.ldg32(x[i:i+2]) emits tl::ldg32(ptr).

Parameters:
  • src (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad.

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the load is skipped.

Returns:

The loaded 32-bit value.

Return type:

PrimExpr

Examples

>>> val = T.ldg32(x[i])
>>> val = T.ldg32(x[i:i+2])  # load 2 x fp16
>>> val = T.ldg32(x[i], pred=i < N)  # predicated load
tilelang.language.builtin.ldg64(src, pred=None)

Load 64 bits (8 bytes) from global memory using explicit PTX instructions.

Usage: T.ldg64(x[i]) or T.ldg64(x[i:i+4]) emits tl::ldg64(ptr).

Parameters:
  • src (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad.

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the load is skipped.

Returns:

The loaded 64-bit value.

Return type:

PrimExpr

Examples

>>> val = T.ldg64(x[i])
>>> val = T.ldg64(x[i:i+4])  # load 4 x fp16
>>> val = T.ldg64(x[i], pred=i < N)  # predicated load
tilelang.language.builtin.ldg128(src, pred=None)

Load 128 bits (16 bytes) from global memory using explicit PTX instructions.

Usage: T.ldg128(x[i]) or T.ldg128(x[i:i+8]) emits tl::ldg128(ptr).

Parameters:
  • src (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad.

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the load is skipped.

Returns:

The loaded 128-bit value.

Return type:

PrimExpr

Examples

>>> val = T.ldg128(x[i])
>>> val = T.ldg128(x[i:i+8])  # load 8 x fp16
>>> val = T.ldg128(x[i], pred=i < N)  # predicated load
tilelang.language.builtin.ldg256(src, pred=None)

Load 256 bits (32 bytes) from global memory using explicit PTX instructions.

Usage: T.ldg256(x[i]) or T.ldg256(x[i:i+16]) emits tl::ldg256(ptr).

Parameters:
  • src (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad.

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the load is skipped.

Returns:

The loaded 256-bit value.

Return type:

PrimExpr

Examples

>>> val = T.ldg256(x[i])
>>> val = T.ldg256(x[i:i+16])  # load 16 x fp16
>>> val = T.ldg256(x[i], pred=i < N)  # predicated load
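The four load widths differ only in the number of bytes moved per instruction, so the element counts in the examples above follow directly from the element size. A small helper capturing that arithmetic (illustrative only, not part of the TileLang API):

```python
# Bytes per element for a few common dtypes (illustrative table).
DTYPE_BYTES = {"fp8": 1, "fp16": 2, "bf16": 2, "fp32": 4}

def elements_per_load(bits: int, dtype: str) -> int:
    """Number of contiguous elements one ldg<bits> covers."""
    nbytes = bits // 8
    assert nbytes % DTYPE_BYTES[dtype] == 0, "width must be a multiple of element size"
    return nbytes // DTYPE_BYTES[dtype]

# Matches the doc examples: ldg32 -> 2 x fp16, ldg64 -> 4, ldg128 -> 8, ldg256 -> 16.
assert [elements_per_load(b, "fp16") for b in (32, 64, 128, 256)] == [2, 4, 8, 16]
```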
tilelang.language.builtin.stg32(dst, value, pred=None)

Store 32 bits (4 bytes) to global memory using explicit PTX instructions.

Usage: T.stg32(y[i], value) emits tl::stg32(ptr, value).

Parameters:
  • dst (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad indicating the destination.

  • value (tvm.tir.PrimExpr) – The 32-bit value to store.

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the store is skipped.

Return type:

None

Examples

>>> T.stg32(y[i], val)
>>> T.stg32(y[i], val, pred=i < N)  # predicated store
tilelang.language.builtin.stg64(dst, value, pred=None)

Store 64 bits (8 bytes) to global memory using explicit PTX instructions.

Usage: T.stg64(y[i:i+2], value) emits tl::stg64(ptr, value).

Parameters:
  • dst (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad indicating the destination.

  • value (tvm.tir.PrimExpr) – The 64-bit value to store (e.g., uint2).

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the store is skipped.

Return type:

None

Examples

>>> T.stg64(y[i:i+2], val)
>>> T.stg64(y[i:i+2], val, pred=i < N)  # predicated store
tilelang.language.builtin.stg128(dst, value, pred=None)

Store 128 bits (16 bytes) to global memory using explicit PTX instructions.

Usage: T.stg128(y[i:i+4], value) emits tl::stg128(ptr, value).

Parameters:
  • dst (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad indicating the destination.

  • value (tvm.tir.PrimExpr) – The 128-bit value to store (e.g., uint4).

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the store is skipped.

Return type:

None

Examples

>>> T.stg128(y[i:i+4], val)
>>> T.stg128(y[i:i+4], val, pred=i < N)  # predicated store
tilelang.language.builtin.stg256(dst, value, pred=None)

Store 256 bits (32 bytes) to global memory using explicit PTX instructions.

Usage: T.stg256(y[i:i+8], value) emits tl::stg256(ptr, value).

Parameters:
  • dst (tilelang._typing.BufferLikeType) – A Buffer, BufferRegion, or BufferLoad indicating the destination.

  • value (tvm.tir.PrimExpr) – The 256-bit value to store (e.g., ulonglong4).

  • pred (tvm.tir.PrimExpr) – Optional predicate condition. If False, the store is skipped.

Return type:

None

Examples

>>> T.stg256(y[i:i+8], val)
>>> T.stg256(y[i:i+8], val, pred=i < N)  # predicated store
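Across both the ldg and stg families, pred gives guard semantics: when the predicate evaluates to False, the memory access simply does not happen, so a guarded destination keeps its old contents. A NumPy sketch of that behavior for a tail-guarded store loop (illustrative only; a real kernel would issue T.stg32 with pred=i < N inside TileLang):

```python
import numpy as np

N = 5  # number of valid elements; the rest are out of bounds
src = np.arange(8, dtype=np.float32)
dst = np.full(8, -1.0, dtype=np.float32)  # sentinel contents

for i in range(8):
    pred = i < N      # same guard as pred=i < N in the examples
    if pred:          # predicated store: skipped when pred is False
        dst[i] = src[i]

# Guarded positions are written; the tail keeps its old contents.
assert dst.tolist() == [0.0, 1.0, 2.0, 3.0, 4.0, -1.0, -1.0, -1.0]
```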