tilelang.transform
==================

.. py:module:: tilelang.transform

.. autoapi-nested-parse::

   Wrapping transformations.


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/tilelang/transform/add_bufstore_wrapper/index
   /autoapi/tilelang/transform/decouple_type_cast/index
   /autoapi/tilelang/transform/hoist_broadcast_values/index
   /autoapi/tilelang/transform/metal/index
   /autoapi/tilelang/transform/pass_config/index
   /autoapi/tilelang/transform/simplify/index


Functions
---------

.. autoapisummary::

   tilelang.transform.get_pass_context
   tilelang.transform.ClusterPlanning
   tilelang.transform.PipelinePlanning
   tilelang.transform.LayoutInference
   tilelang.transform.LowerTileOp
   tilelang.transform.InjectSoftwarePipeline
   tilelang.transform.FrontendLegalize
   tilelang.transform.LegalizeNegativeIndex
   tilelang.transform.InjectAssumes
   tilelang.transform.VerifyParallelLoop
   tilelang.transform.LowerHopperIntrin
   tilelang.transform.WarpSpecializedPipeline
   tilelang.transform.RewriteWgmmaSync
   tilelang.transform.ThreadSync
   tilelang.transform.ThreadPartialSync
   tilelang.transform.IfStmtBinding
   tilelang.transform.MergeIfStmt
   tilelang.transform.LoopUnswitching
   tilelang.transform.MultiVersionBuffer
   tilelang.transform.WarpSpecialized
   tilelang.transform.AnnotateWarpGroupRegAlloc
   tilelang.transform.InjectTmaBarrier
   tilelang.transform.InjectFenceProxy
   tilelang.transform.LegalizeVectorizedLoop
   tilelang.transform.LegalizeSafeMemoryAccess
   tilelang.transform.LowerAccessPtr
   tilelang.transform.MakePackedAPI
   tilelang.transform.AnnotateDeviceRegions
   tilelang.transform.SplitHostDevice
   tilelang.transform.AnnotateReadOnlyParams
   tilelang.transform.VectorizeLoop
   tilelang.transform.InjectPTXAsyncCopy
   tilelang.transform.LowerDeviceStorageAccessInfo
   tilelang.transform.ConfigIndexBitwidth
   tilelang.transform.FlattenBuffer
   tilelang.transform.EliminateStorageSyncForMBarrier
   tilelang.transform.MergeSharedMemoryAllocations
   tilelang.transform.LowerL2Persistent
   tilelang.transform.MarkCudaSyncCalls
   tilelang.transform.PersistThreadblock
   tilelang.transform.AlignDynamicSharedMemoryAllocations
   tilelang.transform.LowerSharedBarrier
   tilelang.transform.PlanAndUpdateBufferAllocationLocation
   tilelang.transform.HoistNonRestrictParams
   tilelang.transform.StorageRewrite
   tilelang.transform.LowerOpaqueBlock
   tilelang.transform.LowerThreadAllreduce
   tilelang.transform.LowerIntrin
   tilelang.transform.LowerDeviceKernelLaunch
   tilelang.transform.LowerSharedTmem
   tilelang.transform.LayoutReducer
   tilelang.transform.UnrollLoop
   tilelang.transform.LowerLDGSTG


Package Contents
----------------

.. py:function:: get_pass_context()

   Get the current pass context.


.. py:function:: ClusterPlanning()

   ClusterPlanning

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: PipelinePlanning()

   Plan software pipelining for pipelined loops (assign each statement a pipeline stage and order).

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LayoutInference()

   Infer the fragment/shared memory layout.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LowerTileOp()

   LowerTileOp

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: InjectSoftwarePipeline()

   InjectSoftwarePipeline

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: FrontendLegalize()

   FrontendLegalize

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LegalizeNegativeIndex()

   Legalize negative indices in buffer loads.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: InjectAssumes()

   Inject assumptions for natural shape boundary conditions, and convert assumes in
   ``Evaluate(Call(...))`` form (the TVM builtin ``assume`` call) to ``AttrNode`` form.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: VerifyParallelLoop()

   VerifyParallelLoop

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass
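
Assuming ``InjectSoftwarePipeline`` follows the standard software-pipelining scheme, it rewrites a pipelined loop so that loads for later iterations overlap with compute for earlier ones. The pure-Python sketch below models the resulting prologue, steady state, and epilogue for a loop body with one load and one compute; it does not call TileLang, and the helper name ``pipeline_schedule`` and the stage assignment (load in stage 0, compute in the last stage) are assumptions made for illustration.

.. code-block:: python

   from typing import List, Tuple

   Op = Tuple[str, int]  # ("load" | "compute", iteration index)


   def pipeline_schedule(num_iters: int, num_stages: int) -> List[List[Op]]:
       """Return the ops issued at each time step of a pipelined loop.

       Hypothetical helper: the load of iteration i runs in stage 0 and its
       compute runs in stage num_stages - 1, so the compute of iteration i
       is issued num_stages - 1 steps after its load.
       """
       if num_iters < 1 or num_stages < 1:
           raise ValueError("num_iters and num_stages must be positive")
       lag = num_stages - 1
       steps: List[List[Op]] = []
       # Prologue, steady state and epilogue all fall out of one sweep over time steps.
       for t in range(num_iters + lag):
           step: List[Op] = []
           if t < num_iters:
               step.append(("load", t))  # prefetch data for a future iteration
           if t >= lag:
               step.append(("compute", t - lag))  # consume data loaded lag steps ago
           steps.append(step)
       return steps

For example, ``pipeline_schedule(4, 3)`` yields two load-only prologue steps, two steps that both load and compute, and two compute-only epilogue steps; raising ``num_stages`` deepens the prefetch distance at the cost of keeping more buffers live.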

.. py:function:: LowerHopperIntrin()

   LowerHopperIntrin

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: WarpSpecializedPipeline()

   WarpSpecializedPipeline

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: RewriteWgmmaSync()

   RewriteWgmmaSync

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: ThreadSync(storage_scope)

   Insert sync between parallel read/write of shared buffers.

   :param storage_scope: The target storage scope.
   :type storage_scope: str

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: ThreadPartialSync(storage_scope)

   Insert partial sync.

   :param storage_scope: The target storage scope.
   :type storage_scope: str

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: IfStmtBinding()

   IfStmtBinding

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: MergeIfStmt()

   MergeIfStmt

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LoopUnswitching()

   LoopUnswitching: Hoist loop-invariant if statements out of loops.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: MultiVersionBuffer()

   MultiVersionBuffer

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: WarpSpecialized()

   WarpSpecialized

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: AnnotateWarpGroupRegAlloc()

   Inject ``set_max_nreg`` calls into warp-specialized functions.

   This pass analyzes the function to collect register hints from ``set_max_nreg`` and
   ``no_set_max_nreg`` calls, then injects appropriate ``set_max_nreg`` calls into the
   producer and consumer branches of warp-specialized code.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: InjectTmaBarrier()

   InjectTmaBarrier

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass
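
The rewrite that ``LoopUnswitching`` describes can be illustrated in plain Python. In the sketch below the branch condition ``scale`` never changes inside the loop, so the test can be made once and each arm gets its own branch-free copy of the loop. The two functions are hypothetical before/after illustrations of the shape of the transformation, not TileLang IR.

.. code-block:: python

   from typing import List


   def before_unswitching(xs: List[float], scale: bool) -> List[float]:
       # `scale` does not depend on the loop variable, yet it is
       # re-tested on every iteration.
       out: List[float] = []
       for x in xs:
           if scale:
               out.append(x * 2.0)
           else:
               out.append(x + 1.0)
       return out


   def after_unswitching(xs: List[float], scale: bool) -> List[float]:
       # The invariant test is hoisted: it runs once, and each arm owns
       # a copy of the loop whose body no longer branches.
       out: List[float] = []
       if scale:
           for x in xs:
               out.append(x * 2.0)
       else:
           for x in xs:
               out.append(x + 1.0)
       return out

Both functions return the same result for every input; the trade-off of unswitching is code size, since the loop body is duplicated once per branch arm.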

.. py:function:: InjectFenceProxy()

   InjectFenceProxy

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LegalizeVectorizedLoop()

   LegalizeVectorizedLoop

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LegalizeSafeMemoryAccess()

   LegalizeSafeMemoryAccess

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LowerAccessPtr()

   Lower TileLang frontend ``tl.access_ptr`` to ``tir.builtin.tvm_access_ptr``.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: MakePackedAPI()

   MakePackedAPI

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: AnnotateDeviceRegions()

   AnnotateDeviceRegions

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: SplitHostDevice()

   Split host/device functions even for empty kernels.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: AnnotateReadOnlyParams()

   Annotate read-only handle parameters for PrimFuncs.

   Adds the attribute ``tl.readonly_param_indices``, listing the indices of parameters that
   are never written, so that CUDA codegen can emit ``const`` qualifiers and unlock
   read-only cache loads.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: VectorizeLoop(enable_vectorize = True)

   VectorizeLoop

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: InjectPTXAsyncCopy()

   Rewrite global-to-shared memory copies on CUDA into asynchronous copies.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LowerDeviceStorageAccessInfo()

   Lower attached storage access information on device.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass

   .. note:: Run this pass after all storage access analysis finishes.


.. py:function:: ConfigIndexBitwidth()

   Config index bitwidth.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass
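
The analysis behind ``AnnotateReadOnlyParams`` can be sketched over a toy, flattened statement list. In this minimal sketch the ``Param`` and ``Stmt`` types and the function names are hypothetical, and a real pass would walk TIR rather than a Python list; only the attribute name ``tl.readonly_param_indices`` comes from the description above.

.. code-block:: python

   from dataclasses import dataclass
   from typing import Dict, List


   @dataclass(frozen=True)
   class Param:
       name: str
       is_handle: bool  # only pointer/handle parameters are candidates


   @dataclass(frozen=True)
   class Stmt:
       kind: str    # "load" or "store"
       buffer: str  # name of the parameter the statement touches


   def readonly_param_indices(params: List[Param], body: List[Stmt]) -> List[int]:
       # Collect every buffer that appears as a store target.
       written = {s.buffer for s in body if s.kind == "store"}
       # A handle parameter is read-only if no statement stores to it.
       return [i for i, p in enumerate(params) if p.is_handle and p.name not in written]


   def annotate(params: List[Param], body: List[Stmt]) -> Dict[str, List[int]]:
       # Record the result under the attribute name, on a plain dict.
       return {"tl.readonly_param_indices": readonly_param_indices(params, body)}

For a kernel that loads ``A`` and ``B``, stores ``C``, and also takes a scalar ``n``, the sketch reports indices ``[0, 1]``; codegen could then qualify ``A`` and ``B`` as ``const`` while ``C`` stays mutable and the non-handle ``n`` is ignored.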

.. py:function:: FlattenBuffer()

   FlattenBuffer

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: EliminateStorageSyncForMBarrier()

   EliminateStorageSyncForMBarrier


.. py:function:: MergeSharedMemoryAllocations(enable_aggressive_merge = False, align_bytes = 16)

   MergeSharedMemoryAllocations

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LowerL2Persistent()

   LowerL2Persistent


.. py:function:: MarkCudaSyncCalls(have_pdl = False)

   MarkCudaSyncCalls


.. py:function:: PersistThreadblock()

   PersistThreadblock


.. py:function:: AlignDynamicSharedMemoryAllocations(align_bytes = 16)

   AlignDynamicSharedMemoryAllocations

   :param align_bytes: The alignment, in bytes.
   :type align_bytes: int


.. py:function:: LowerSharedBarrier()

   LowerSharedBarrier


.. py:function:: PlanAndUpdateBufferAllocationLocation()

   Plan and update buffer allocation locations within PrimFuncs.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: HoistNonRestrictParams()


.. py:function:: StorageRewrite()

   StorageRewrite

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LowerOpaqueBlock()

   LowerOpaqueBlock


.. py:function:: LowerThreadAllreduce()

   LowerThreadAllreduce


.. py:function:: LowerIntrin()

   LowerIntrin


.. py:function:: LowerDeviceKernelLaunch()

   Create and return a transform pass that lowers device kernel launch constructs to
   target-specific IR.

   This pass transforms high-level device kernel launches and related intrinsics into
   lower-level IR suitable for backend code generation and device-side lowering.

   :returns: The transform pass that performs device kernel launch lowering.
   :rtype: tvm.transform.Pass


.. py:function:: LowerSharedTmem()

   LowerSharedTmem


.. py:function:: LayoutReducer()

   Return a TVM transform pass that performs layout reduction/normalization.

   This wrapper delegates to the underlying FFI implementation and returns a pass object
   suitable for use in a PassContext or pass pipeline.
   The pass is intended to simplify or reduce tensor/layout-related representations
   during relay/tile transformations.

   :returns: The transform pass object produced by the FFI backend.


.. py:function:: UnrollLoop()

   Unroll loops, as in the Halide pipeline.

   This pass unrolls loops based on the following configuration options:

   - ``auto_max_step``: Threshold on the number of steps to be automatically unrolled.
   - ``auto_max_depth``: Maximum nesting level of loops that can be automatically unrolled.
   - ``auto_max_extent``: Maximum extent of a loop that will be unrolled.
   - ``explicit_unroll``: Whether to explicitly unroll instead of setting a pragma.
   - ``unroll_local_access``: Whether to always unroll local access.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass


.. py:function:: LowerLDGSTG()

   Lower Ramp-based global memory loads/stores to ldg/stg intrinsics.

   This pass transforms vectorized global memory loads and stores (using ``Ramp`` indices)
   into explicit ``ldg32/64/128/256`` and ``stg32/64/128/256`` intrinsics for better codegen.

   Key behaviors:

   - Converts Ramp-based global ``BufferLoad`` nodes to ldg intrinsics.
   - Converts Ramp-based global ``BufferStore`` nodes to stg intrinsics.
   - Supports predicated loads (``if_then_else`` with ``else=0``).
   - Supports predicated stores (a store in the then branch of an ``if``).
   - Skips loads in an async scope (these are lowered to ``cp.async`` instead).
   - Only enabled for CUDA targets.

   :returns: **fpass** -- The result pass
   :rtype: tvm.transform.Pass
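
The intrinsic choice in ``LowerLDGSTG`` can be pictured with the small decision function below. It is a minimal sketch under stated assumptions: the intrinsic width is taken to be lanes times element bits, accesses whose width is not 32, 64, 128, or 256 bits are assumed to be left unchanged, and the name ``select_ldg_stg`` is hypothetical. The real pass operates on TIR ``Ramp`` indices rather than plain integers.

.. code-block:: python

   from typing import Optional

   SUPPORTED_BITS = (32, 64, 128, 256)  # widths named in the pass description


   def select_ldg_stg(kind: str, lanes: int, elem_bits: int, scope: str,
                      in_async_scope: bool = False, target: str = "cuda") -> Optional[str]:
       """Return e.g. "ldg128" or "stg64", or None when the access is left as-is."""
       if kind not in ("load", "store"):
           raise ValueError("kind must be 'load' or 'store'")
       if target != "cuda" or scope != "global":
           return None  # only global-memory accesses on CUDA are considered
       if kind == "load" and in_async_scope:
           return None  # async-scope loads are left for the cp.async lowering
       bits = lanes * elem_bits  # assumed: one intrinsic covers the whole vector
       if bits not in SUPPORTED_BITS:
           return None  # no matching ldg/stg width
       prefix = "ldg" if kind == "load" else "stg"
       return f"{prefix}{bits}"

Under these assumptions a ``float32x4`` global load maps to ``ldg128``, a ``float16x8`` global store to ``stg128``, and a three-lane ``float32`` access (96 bits) is not rewritten.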