tilelang.transform¶
Wrapping transformations.
Submodules¶
Functions¶
get_pass_context | Get the current pass context
ClusterPlanning | ClusterPlanning
PipelinePlanning | PipelinePlanning
InstructionAnnotation | Annotate tile operations with coarse-grained instruction kind.
LayoutInference | LayoutInference
LowerTileOp | LowerTileOp
InjectSoftwarePipeline | InjectSoftwarePipeline
LegalizeNegativeIndex | Legalize negative indices in buffer loads.
InjectAssumes | Inject Assumes for natural shape boundary conditions, and convert Assumes in Evaluate(Call(...)) form to AttrNode form.
VerifyParallelLoop | VerifyParallelLoop
LowerHopperIntrin | LowerHopperIntrin
ThreadSync | Insert sync between parallel read/write of shared buffers.
IfStmtBinding | IfStmtBinding
MergeIfStmt | MergeIfStmt
LoopUnswitching | LoopUnswitching: Hoist loop-invariant if statements out of loops.
ProducerConsumerWarpSpecialized | Producer-consumer warp specialization at the tile-op level.
ProducerConsumerWarpSpecializedTiled | Compatibility alias for ProducerConsumerWarpSpecialized.
AnnotateWarpGroupRegAlloc | Inject set_max_nreg calls into warp-specialized functions.
FuseMBarrierArriveExpectTx | Fuse simple expect_tx -> TMA issue -> arrive back into arrive_and_expect_tx.
InjectFenceProxy | InjectFenceProxy
InjectTcgen05Fence | Inject tcgen05.fence::before_thread_sync / after_thread_sync at conservative TCGEN05/TMEM synchronization boundaries on Blackwell (SM100+) targets.
LegalizeVectorizedLoop | LegalizeLoopVectorize
LegalizeSafeMemoryAccess | LegalizeSafeMemoryAccess
LowerAccessPtr | Lower TileLang frontend tl.access_ptr to tir.builtin.tvm_access_ptr.
MakePackedAPI | MakePackedAPI
AnnotateDeviceRegions | AnnotateDeviceRegions
SplitHostDevice | Split host/device functions even for empty kernels.
AnnotateReadOnlyParams | Annotate read-only handle parameters for PrimFuncs.
VectorizeLoop | VectorizeLoop
LowerPTXAsyncCopy | Lower eligible global->shared copies into PTX cp.async on CUDA.
InjectPTXAsyncCopy | Deprecated alias of LowerPTXAsyncCopy.
LowerDeviceStorageAccessInfo | Lower attached storage access information on device.
ConfigIndexBitwidth | Config index bitwidth.
FlattenBuffer | FlattenBuffer
MergeSharedMemoryAllocations | MergeSharedMemoryAllocations
LowerL2Persistent | LowerL2Persistent
MarkCudaSyncCalls | MarkCudaSyncCalls
PersistThreadblock | PersistThreadblock
LowerSharedBarrier | LowerSharedBarrier
PlanAndUpdateBufferAllocationLocation | Plan and update buffer allocation locations within PrimFuncs.
HoistGlobalBufferAllocations | Hoist global buffer allocations to the top of the block (host side).
StorageRewrite | StorageRewrite
LowerOpaqueBlock | LowerOpaqueBlock
LowerThreadAllreduce | LowerThreadAllreduce
LowerIntrin | LowerIntrin
LowerDeviceKernelLaunch | Create and return a transform pass that lowers device kernel launch constructs to target-specific IR.
LowerSharedTmem | LowerSharedTmem
LayoutReducer | Return a TVM transform pass that performs layout reduction/normalization.
UnrollLoop | Unroll loops as in Halide pipeline.
LowerLDGSTG | Lower Ramp-based global memory load/store to ldg/stg intrinsics.
LowerBlackwell2SM | Lower 2SM TCGEN5MMA and related on Blackwell target
Package Contents¶
- tilelang.transform.get_pass_context()¶
Get the current pass context
- tilelang.transform.ClusterPlanning()¶
ClusterPlanning
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.PipelinePlanning()¶
PipelinePlanning
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.InstructionAnnotation()¶
Annotate tile operations with coarse-grained instruction kind.
This pass runs before LayoutInference and LowerTileOp. It adds a tl_instruction_kind annotation to each tile-op Call node indicating the instruction category (“tma”, “cp_async”, “sync”, “wgmma”, etc.) that will be selected during lowering.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LayoutInference()¶
LayoutInference
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LowerTileOp()¶
LowerTileOp
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.InjectSoftwarePipeline()¶
InjectSoftwarePipeline
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LegalizeNegativeIndex()¶
Legalize negative indices in buffer loads.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
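As a rough illustration of what legalizing a negative index can mean, the sketch below wraps a negative constant index into the valid range by adding the buffer extent. It is plain Python, not TIR: the helper name `legalize_index` and the Python-style wrap-around convention are assumptions made for this example, and the pass's exact rewrite rules are not documented here.

```python
def legalize_index(index: int, extent: int) -> int:
    """Map a possibly negative index into [0, extent) by wrapping once.

    Hypothetical helper (not part of tilelang). Assumes that -1 refers
    to the last element, as in Python indexing.
    """
    if index < 0:
        index += extent
    if not 0 <= index < extent:
        raise IndexError(f"index out of range for extent {extent}")
    return index


buffer = [10, 20, 30, 40]
# A load such as buffer[-1] would be rewritten to buffer[3].
last = buffer[legalize_index(-1, len(buffer))]
```

Non-negative indices pass through unchanged, so the rewrite only affects loads that actually use a negative index.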
- tilelang.transform.InjectAssumes()¶
Inject Assumes for natural shape boundary conditions, and convert Assumes in Evaluate(Call(…)) form (tvm builtin assume call) to AttrNode form.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.VerifyParallelLoop()¶
VerifyParallelLoop
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LowerHopperIntrin()¶
LowerHopperIntrin
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.ThreadSync(storage_scope)¶
Insert sync between parallel read/write of shared buffers.
- Parameters:
storage_scope (str) – The target storage scope.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
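To make the idea of inserting a sync between conflicting accesses concrete, here is a toy model in plain Python. It works on a flat list of `(kind, buffer, scope)` tuples rather than on TIR, and the function name `insert_syncs` and the tuple encoding are assumptions for illustration only; the real pass analyzes the statement tree and may place barriers differently.

```python
def insert_syncs(ops, storage_scope="shared"):
    """Insert a ("sync", scope) marker between conflicting accesses.

    Toy model, not the TIR pass: `ops` is a list of (kind, buffer, scope)
    tuples where kind is "read" or "write".
    """
    out = []
    pending = {}  # buffer -> kinds of access seen since the last sync
    for kind, buf, scope in ops:
        if scope == storage_scope:
            seen = pending.get(buf, set())
            # Read-after-write, write-after-read and write-after-write
            # on the same buffer all need a barrier in between.
            conflict = ("write" in seen) or (kind == "write" and bool(seen))
            if conflict:
                out.append(("sync", storage_scope))
                pending.clear()  # the barrier orders all earlier accesses
            pending.setdefault(buf, set()).add(kind)
        out.append((kind, buf, scope))
    return out


ops = [("write", "A_shared", "shared"), ("read", "A_shared", "shared")]
synced = insert_syncs(ops, "shared")
```

Accesses in other scopes are left untouched, which mirrors the role of the `storage_scope` argument.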
- tilelang.transform.IfStmtBinding()¶
IfStmtBinding
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.MergeIfStmt()¶
MergeIfStmt
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LoopUnswitching()¶
LoopUnswitching: Hoist loop-invariant if statements out of loops.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
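The effect of loop unswitching can be shown with two equivalent plain-Python functions. This is only an illustration of the transformation's shape; the function names `before` and `after` are made up for the example, and the pass itself rewrites TIR loops rather than Python code.

```python
def before(xs, scale):
    out = []
    for x in xs:
        # `scale` never changes inside the loop, so this test is loop-invariant.
        if scale:
            out.append(x * 2)
        else:
            out.append(x)
    return out


def after(xs, scale):
    # The invariant test is hoisted; each branch holds a specialized loop.
    if scale:
        return [x * 2 for x in xs]
    return [x for x in xs]
```

Both versions compute the same result, but the second evaluates the condition once instead of once per iteration.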
- tilelang.transform.ProducerConsumerWarpSpecialized()¶
Producer-consumer warp specialization at the tile-op level.
This pass runs before LayoutInference and LowerTileOp. It rewrites eligible pipelined tile-op loops into warp-specialized producer and consumer branches with explicit barrier synchronization.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
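The producer and consumer roles described above can be pictured with a CPU-side analogy. The sketch below uses Python threads and a bounded queue in place of warps, staged shared memory, and barriers; the name `run_pipeline` and the `num_stages` parameter are hypothetical, and nothing here reflects how the pass actually emits code.

```python
import queue
import threading


def run_pipeline(tiles, num_stages=2):
    """Producer fills a bounded buffer while a consumer drains it."""
    slots = queue.Queue(maxsize=num_stages)  # stands in for staged buffers
    results = []

    def producer():
        for tile in tiles:
            slots.put(tile)  # blocks while every stage is still in use
        slots.put(None)      # sentinel: no more tiles

    def consumer():
        while True:
            tile = slots.get()  # blocks until a tile has been produced
            if tile is None:
                break
            results.append(sum(tile))

    threads = [threading.Thread(target=producer),
               threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

The blocking `put` and `get` calls play the part of the explicit barrier synchronization: the producer cannot overwrite a stage the consumer has not finished with, and the consumer cannot read a stage before it is filled.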
- tilelang.transform.ProducerConsumerWarpSpecializedTiled()¶
Compatibility alias for ProducerConsumerWarpSpecialized.
The tiled tile-op implementation is now the canonical ProducerConsumerWarpSpecialized pass.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.AnnotateWarpGroupRegAlloc()¶
Inject set_max_nreg calls into warp-specialized functions.
This pass analyzes the function to collect register hints from set_max_nreg and no_set_max_nreg calls, then injects appropriate set_max_nreg calls into producer and consumer branches of warp-specialized code.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.FuseMBarrierArriveExpectTx()¶
Fuse simple expect_tx -> TMA issue -> arrive back into arrive_and_expect_tx.
- tilelang.transform.InjectFenceProxy()¶
InjectFenceProxy
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.InjectTcgen05Fence()¶
Inject tcgen05.fence::before_thread_sync / after_thread_sync at conservative TCGEN05/TMEM synchronization boundaries on Blackwell (SM100+) targets.
The current pass wraps CTA-wide shared-memory syncs and also inserts fences around linear mbarrier wait/use and use/arrive handoff patterns. It is intentionally conservative and does not try to infer arbitrary barrier protocols.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LegalizeVectorizedLoop()¶
LegalizeLoopVectorize
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LegalizeSafeMemoryAccess()¶
LegalizeSafeMemoryAccess
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LowerAccessPtr()¶
Lower TileLang frontend tl.access_ptr to tir.builtin.tvm_access_ptr.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.MakePackedAPI()¶
MakePackedAPI
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.AnnotateDeviceRegions()¶
AnnotateDeviceRegions
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.SplitHostDevice()¶
Split host/device functions even for empty kernels.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.AnnotateReadOnlyParams()¶
Annotate read-only handle parameters for PrimFuncs.
Adds attribute tl.readonly_param_indices listing param indices that are never written, enabling CUDA codegen to emit const qualifiers to unlock read-only cache loads.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
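The attribute described above is a list of parameter positions. A minimal sketch of how such a list could be derived, assuming the set of written buffers is already known, is shown below; `readonly_param_indices` is a hypothetical helper, and the real pass derives the written set by analyzing the PrimFunc body.

```python
def readonly_param_indices(params, written):
    """Return the positions of parameters that are never written."""
    return [i for i, name in enumerate(params) if name not in written]


params = ["A", "B", "C"]   # handle parameters, in declaration order
written = {"C"}            # buffers that appear as a store target
attrs = {"tl.readonly_param_indices": readonly_param_indices(params, written)}
```

Here `A` and `B` are only read, so their indices are recorded and could later be given const qualifiers by the code generator.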
- tilelang.transform.VectorizeLoop(enable_vectorize=True)¶
VectorizeLoop
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- Parameters:
enable_vectorize (bool)
- tilelang.transform.LowerPTXAsyncCopy()¶
Lower eligible global->shared copies into PTX cp.async on CUDA.
When enabled (pass config tl.enable_async_copy, default True), this pass may rewrite plain user-written global->shared BufferStore patterns (e.g. SIMT copies in T.Parallel) into tir.ptx_cp_async, and insert tir.ptx_commit_group + tir.ptx_wait_group(0) to preserve synchronous semantics for normal stores. If explicit commit/wait intrinsics already exist, the pass avoids duplicating them (and may insert a missing commit immediately before an existing wait to cover injected cp.async).
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
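The commit/wait bookkeeping described above can be modeled on a flat list of operation names. The following is a toy model in plain Python, not the pass itself: the string op names (`store_g2s`, `cp_async`, `commit_group`, `wait_group(0)`) and the function `lower_async_copy` are assumptions chosen to mirror the description.

```python
def lower_async_copy(ops):
    """Rewrite global->shared stores to cp_async and keep them synchronous."""
    out = []
    uncommitted = False  # True while injected cp_async ops await a commit
    for op in ops:
        if op == "store_g2s":
            out.append("cp_async")
            uncommitted = True
            continue
        if op == "commit_group":
            uncommitted = False  # an explicit commit already covers them
        elif op.startswith("wait_group") and uncommitted:
            # Existing wait without a commit: add the missing commit first.
            out.append("commit_group")
            uncommitted = False
        elif uncommitted:
            # Any other op: close the group to preserve synchronous semantics.
            out += ["commit_group", "wait_group(0)"]
            uncommitted = False
        out.append(op)
    if uncommitted:
        out += ["commit_group", "wait_group(0)"]
    return out
```

In this model, existing commit and wait operations are never duplicated; a commit is only added when an injected copy would otherwise reach a wait without one.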
- tilelang.transform.InjectPTXAsyncCopy()¶
Deprecated alias of LowerPTXAsyncCopy.
- tilelang.transform.LowerDeviceStorageAccessInfo()¶
Lower attached storage access information on device.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
Note
Run this pass after all storage access analyses have finished.
- tilelang.transform.ConfigIndexBitwidth()¶
Configure the index bitwidth.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.FlattenBuffer()¶
FlattenBuffer
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.MergeSharedMemoryAllocations(…)¶
MergeSharedMemoryAllocations
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- Parameters:
enable_aggressive_merge (bool)
align_bytes (int)
- tilelang.transform.LowerL2Persistent()¶
LowerL2Persistent
- tilelang.transform.PersistThreadblock()¶
PersistThreadblock
- tilelang.transform.LowerSharedBarrier(…)¶
LowerSharedBarrier
- tilelang.transform.PlanAndUpdateBufferAllocationLocation()¶
Plan and update buffer allocation locations within PrimFuncs.
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.HoistGlobalBufferAllocations()¶
Hoist global buffer allocations to the top of the block (host side).
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.HoistNonRestrictParams()¶
- tilelang.transform.StorageRewrite()¶
StorageRewrite
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
- tilelang.transform.LowerOpaqueBlock()¶
LowerOpaqueBlock
- tilelang.transform.LowerThreadAllreduce()¶
LowerThreadAllreduce
- tilelang.transform.LowerIntrin()¶
LowerIntrin
- tilelang.transform.LowerDeviceKernelLaunch()¶
Create and return a transform pass that lowers device kernel launch constructs to target-specific IR.
This pass transforms high-level device kernel launch and related intrinsics into lower-level IR suitable for backend code generation and device-side lowering.
- Returns:
The transform pass that performs device kernel launch lowering.
- Return type:
tvm.transform.Pass
- tilelang.transform.LowerSharedTmem(…)¶
LowerSharedTmem
- tilelang.transform.LayoutReducer()¶
Return a TVM transform pass that performs layout reduction/normalization.
This wrapper delegates to the underlying FFI implementation and returns a pass object suitable for use in a PassContext or pass pipeline. The pass is intended to simplify or reduce tensor/layout-related representations during relay/tile transformations.
- Returns:
The transform pass object produced by the FFI backend.
- tilelang.transform.UnrollLoop()¶
Unroll loops as in Halide pipeline.
This pass unrolls loops based on configuration options including:
- auto_max_step: Threshold of number of steps to be automatically unrolled
- auto_max_depth: Maximum nested level of loops that can be automatically unrolled
- auto_max_extent: Maximum extent of loop that will be unrolled
- explicit_unroll: Whether to explicitly unroll instead of setting a pragma
- unroll_local_access: Whether to always unroll local access
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
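As a simplified picture of explicit unrolling under an extent threshold, the sketch below expands a loop body into straight-line statements when the loop is small enough. It considers only an `auto_max_extent`-style limit; the helper `maybe_unroll`, its default of 8, and the string-template representation are assumptions for illustration and do not reflect the pass's actual decision logic or defaults.

```python
def maybe_unroll(body, extent, auto_max_extent=8):
    """Return explicit statements for a small loop, else keep the loop.

    `body` is a template with an `{i}` placeholder for the loop variable.
    The threshold default is hypothetical.
    """
    if extent <= auto_max_extent:
        # Explicit unroll: one copy of the body per iteration.
        return [body.format(i=i) for i in range(extent)]
    # Too large: leave the loop in place.
    return [f"for i in range({extent}):", "    " + body.format(i="i")]


small = maybe_unroll("C[{i}] = A[{i}]", 2)
large = maybe_unroll("C[{i}] = A[{i}]", 100)
```

The small loop becomes two independent statements, while the large one is returned unchanged as a loop.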
- tilelang.transform.LowerLDGSTG()¶
Lower Ramp-based global memory load/store to ldg/stg intrinsics.
This pass transforms vectorized global memory loads and stores (using Ramp indices) into explicit ldg32/64/128/256 and stg32/64/128/256 intrinsics for better codegen.
Key behaviors:
- Converts Ramp-based global BufferLoad to ldg intrinsics
- Converts Ramp-based global BufferStore to stg intrinsics
- Supports predicated loads (if_then_else with else=0)
- Supports predicated stores (if in then case)
- Skips loads in async scope (will be lowered to cp.async)
- Only enabled for CUDA targets
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass
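The ldg32/64/128/256 and stg32/64/128/256 names suggest that the intrinsic is chosen by the total width of the vector access. A small sketch of that arithmetic follows, assuming the width is simply element bits times lane count; `pick_intrinsic` is a hypothetical helper and the pass's real selection rules (alignment, predication, async scope) are not modeled.

```python
SUPPORTED_BITS = (32, 64, 128, 256)


def pick_intrinsic(kind, dtype_bits, lanes):
    """Return a name such as "ldg128" for a vector access, or None.

    Assumes total width = dtype_bits * lanes; None means the access
    would be left as an ordinary load or store in this toy model.
    """
    if kind not in ("ldg", "stg"):
        raise ValueError("kind must be 'ldg' or 'stg'")
    total = dtype_bits * lanes
    return f"{kind}{total}" if total in SUPPORTED_BITS else None


load_name = pick_intrinsic("ldg", 32, 4)    # four float32 lanes
store_name = pick_intrinsic("stg", 16, 4)   # four float16 lanes
```

Under this assumption, a four-lane float32 load is 128 bits wide and a four-lane float16 store is 64 bits wide.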
- tilelang.transform.LowerBlackwell2SM()¶
Lower 2SM TCGEN5MMA and related on Blackwell target
- Returns:
fpass – The result pass
- Return type:
tvm.transform.Pass