tilelang.contrib.cutedsl.reduce
===============================

.. py:module:: tilelang.contrib.cutedsl.reduce

.. autoapi-nested-parse::

   Reduce operations for the CuTeDSL backend.
   Based on tl_templates/cuda/reduce.h

Classes
-------

.. autoapisummary::

   tilelang.contrib.cutedsl.reduce.SumOp
   tilelang.contrib.cutedsl.reduce.MaxOp
   tilelang.contrib.cutedsl.reduce.MinOp
   tilelang.contrib.cutedsl.reduce.BitAndOp
   tilelang.contrib.cutedsl.reduce.BitOrOp
   tilelang.contrib.cutedsl.reduce.BitXorOp
   tilelang.contrib.cutedsl.reduce.CumSum1D
   tilelang.contrib.cutedsl.reduce.CumSum2D
   tilelang.contrib.cutedsl.reduce.NamedBarrier

Functions
---------

.. autoapisummary::

   tilelang.contrib.cutedsl.reduce.min
   tilelang.contrib.cutedsl.reduce.max
   tilelang.contrib.cutedsl.reduce.bar_sync
   tilelang.contrib.cutedsl.reduce.bar_sync_ptx
   tilelang.contrib.cutedsl.reduce.AllReduce

Module Contents
---------------

.. py:function:: min(a, b, c=None)

   Type-aware min: uses ``arith.minsi`` for integers and ``nvvm.fmin`` for floats.
   Falls back to the integer path if float conversion fails (signless int types).

.. py:function:: max(a, b, c=None)

   Type-aware max: uses ``arith.maxsi`` for integers and ``nvvm.fmax`` for floats.
   Falls back to the integer path if float conversion fails (signless int types).

.. py:class:: SumOp

   Sum reduction operator.

   .. py:method:: __call__(x, y)
      :staticmethod:

.. py:class:: MaxOp

   Max reduction operator.

   .. py:method:: __call__(x, y)
      :staticmethod:

.. py:class:: MinOp

   Min reduction operator.

   .. py:method:: __call__(x, y)
      :staticmethod:

.. py:class:: BitAndOp

   Bitwise AND reduction operator.

   .. py:method:: __call__(x, y)
      :staticmethod:

.. py:class:: BitOrOp

   Bitwise OR reduction operator.

   .. py:method:: __call__(x, y)
      :staticmethod:

.. py:class:: BitXorOp

   Bitwise XOR reduction operator.

   .. py:method:: __call__(x, y)
      :staticmethod:

.. py:function:: bar_sync(barrier_id, number_of_threads)

.. py:function:: bar_sync_ptx(barrier_id, number_of_threads)

.. py:class:: CumSum1D(threads, reverse)

   1D cumulative sum operation.
   Based on ``tl::CumSum1D`` from reduce.h.

   Template params:
       threads: Number of threads
       reverse: Whether to cumsum in reverse order

   .. py:attribute:: threads

   .. py:attribute:: reverse

   .. py:attribute:: SEG
      :value: 32

   .. py:method:: run(src, dst, N)

      Perform a 1D cumulative sum.

      :param src: Source pointer
      :param dst: Destination pointer
      :param N: Number of elements (must be a compile-time constant or small)

.. py:class:: CumSum2D(threads, dim, reverse)

   2D cumulative sum operation.

   Based on ``tl::CumSum2D`` from reduce.h.

   Template params:
       threads: Number of threads (must be a power of 2, 32-1024)
       dim: Axis along which to cumsum (0 or 1)
       reverse: Whether to cumsum in reverse order

   .. py:attribute:: threads

   .. py:attribute:: dim

   .. py:attribute:: reverse

   .. py:attribute:: SEG
      :value: 32

   .. py:attribute:: TILE_H

   .. py:method:: run(src, dst, H, W)

      Perform a 2D cumulative sum.

      :param src: Source pointer
      :param dst: Destination pointer
      :param H: Number of rows
      :param W: Number of columns (should be <= 32 for the single-segment case)

.. py:class:: NamedBarrier(all_threads)

   Named barrier policy for AllReduce; uses ``bar.sync`` instead of ``__syncthreads``.

   Based on ``tl::NamedBarrier`` from reduce.h.

   .. py:attribute:: all_threads

.. py:function:: AllReduce(reducer, threads, scale, thread_offset, all_threads=None)

   AllReduce operation implementing warp/block-level reduction.

   Based on ``tl::AllReduce`` from reduce.h.

   :param reducer: Reducer operator class (SumOp, MaxOp, etc.)
   :param threads: Number of threads participating in the reduction
   :param scale: Reduction scale factor
   :param thread_offset: Thread ID offset
   :param all_threads: Total number of threads in the block
   :returns: A callable object with ``run()`` and ``run_hopper()`` methods
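As a rough illustration of what AllReduce computes, here is a CPU-only Python sketch. All names in it are hypothetical stand-ins: the real ``SumOp``/``MaxOp``/... classes emit CuTeDSL/MLIR IR rather than Python values. The sketch simulates the XOR-butterfly exchange pattern commonly used for warp-level all-reduce, so after log2(threads) steps every lane holds the reduction of all per-lane values:

```python
# CPU-only reference sketch (hypothetical helpers, not part of the module):
# the stand-in reducers below only reproduce the arithmetic of the
# SumOp/MaxOp/... operator classes so the AllReduce result can be checked.
import operator

REDUCERS = {
    "SumOp": operator.add,
    "MaxOp": max,
    "MinOp": min,
    "BitAndOp": operator.and_,
    "BitOrOp": operator.or_,
    "BitXorOp": operator.xor,
}

def butterfly_all_reduce(reducer, values):
    """Simulate the XOR-butterfly exchange used by warp-level all-reduce:
    at each step with stride d, lane i combines its value with lane i ^ d.
    After log2(n) steps every lane holds the reduction of all n values."""
    n = len(values)
    assert n & (n - 1) == 0, "lane count must be a power of two"
    vals = list(values)
    d = n // 2
    while d >= 1:
        vals = [reducer(vals[i], vals[i ^ d]) for i in range(n)]
        d //= 2
    return vals
```

For example, ``butterfly_all_reduce(REDUCERS["SumOp"], [3, 1, 4, 1])`` leaves every "lane" holding ``9``, mirroring the all-lanes-get-the-result property of the GPU primitive.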
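The ``CumSum1D``/``CumSum2D`` classes above compute inclusive prefix sums, optionally in reverse order and, in 2D, along either axis. A minimal pure-Python reference for that semantics (hypothetical helper names, unrelated to the real IR-emitting ``run()`` methods) is:

```python
# CPU-only reference for the cumulative-sum semantics (hypothetical
# helpers, not part of the module): inclusive prefix sums, optionally
# right-to-left, along the chosen axis.
def cumsum1d_ref(src, reverse=False):
    out = list(src)
    if reverse:
        # Accumulate right-to-left: out[i] = src[i] + src[i+1] + ...
        for i in range(len(out) - 2, -1, -1):
            out[i] += out[i + 1]
    else:
        # Accumulate left-to-right: out[i] = src[0] + ... + src[i]
        for i in range(1, len(out)):
            out[i] += out[i - 1]
    return out

def cumsum2d_ref(src, dim, reverse=False):
    if dim == 1:
        # Cumsum along each row independently.
        return [cumsum1d_ref(row, reverse) for row in src]
    # dim == 0: cumsum down each column, via transpose and back.
    cols = [cumsum1d_ref(list(c), reverse) for c in zip(*src)]
    return [list(r) for r in zip(*cols)]
```

For instance, ``cumsum1d_ref([1, 2, 3], reverse=True)`` yields ``[6, 5, 3]``, matching a reverse-order inclusive scan.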