tilelang.language.reduce
========================

.. py:module:: tilelang.language.reduce

.. autoapi-nested-parse::

   The language interface for tl programs.



Functions
---------

.. autoapisummary::

   tilelang.language.reduce.reduce
   tilelang.language.reduce.reduce_max
   tilelang.language.reduce.reduce_min
   tilelang.language.reduce.reduce_sum
   tilelang.language.reduce.reduce_abssum
   tilelang.language.reduce.reduce_absmax
   tilelang.language.reduce.reduce_bitand
   tilelang.language.reduce.reduce_bitor
   tilelang.language.reduce.reduce_bitxor
   tilelang.language.reduce.cumsum_fragment
   tilelang.language.reduce.cumsum
   tilelang.language.reduce.finalize_reducer
   tilelang.language.reduce.warp_reduce_sum
   tilelang.language.reduce.warp_reduce_max
   tilelang.language.reduce.warp_reduce_min
   tilelang.language.reduce.warp_reduce_bitand
   tilelang.language.reduce.warp_reduce_bitor


Module Contents
---------------

.. py:function:: reduce(buffer, out, reduce_type, dim, clear)

   Perform a reduction operation on a buffer along a specified dimension.

   :param buffer: Input buffer to reduce.
   :type buffer: tir.Buffer
   :param out: Output buffer that stores the results.
   :type out: tir.Buffer
   :param reduce_type: Type of reduction, for example 'max', 'min', 'sum', or 'abssum'.
   :type reduce_type: str
   :param dim: Dimension along which to perform the reduction.
   :type dim: int
   :param clear: Whether to initialize the output buffer before the reduction.
   :type clear: bool

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_max(buffer, out, dim = -1, clear = True)

   Perform a max reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: Buffer
   :param out: The output buffer.
   :type out: Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: If True, the output buffer is first initialized to -inf.
   :type clear: bool

   :returns: Handle to the reduction operation.
   :rtype: PrimExpr


.. py:function:: reduce_min(buffer, out, dim = -1, clear = True)

   Perform a min reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: If True, the output buffer is first initialized to inf. Defaults to True.
   :type clear: bool, optional

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_sum(buffer, out, dim = -1, clear = True)

   Perform a sum reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: If True, the output buffer is cleared before the reduction. If False, the results are accumulated onto the existing values. Defaults to True.
   :type clear: bool, optional

   Note: When clear=False, reduce_sum does not compute directly on the output buffer. During a warp reduction, the existing value of out would otherwise be accumulated multiple times (once per thread in the warp). The implementation with clear=False therefore follows these steps:

   1. Create a temp buffer with the same shape and dtype as out.
   2. Copy out to the temp buffer.
   3. Call reduce_sum with the temp buffer and out.
   4. Add the temp buffer to out.

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_abssum(buffer, out, dim = -1)

   Perform an absolute-sum reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_absmax(buffer, out, dim = -1, clear = True)

   Perform an absolute-max reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: Whether to initialize the output buffer before the reduction.
   :type clear: bool

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_bitand(buffer, out, dim = -1, clear = True)

   Perform a bitwise-and reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: Whether to initialize the output buffer before the reduction.
   :type clear: bool

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_bitor(buffer, out, dim = -1, clear = True)

   Perform a bitwise-or reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: Whether to initialize the output buffer before the reduction.
   :type clear: bool

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: reduce_bitxor(buffer, out, dim = -1, clear = True)

   Perform a bitwise-xor reduction on the input buffer and store the result in the output buffer.

   :param buffer: The input buffer.
   :type buffer: tir.Buffer
   :param out: The output buffer.
   :type out: tir.Buffer
   :param dim: The dimension to reduce over.
   :type dim: int
   :param clear: Whether to initialize the output buffer before the reduction.
   :type clear: bool

   :returns: Handle to the reduction operation.
   :rtype: tir.Call


.. py:function:: cumsum_fragment(src, dst, dim, reverse)


.. py:function:: cumsum(src, dst = None, dim = 0, reverse = False)

   Compute the cumulative sum of `src` along `dim` and write the results to `dst`.

   Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in place on `src`. Raises ValueError when `dim` is out of bounds for `src.shape`.

   When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.

   .. rubric:: Examples

   A 1D inclusive scan over a shared-memory buffer. The result is written back into the same buffer, which is passed explicitly as `dst`:

   >>> import tilelang.language as T
   >>> @T.prim_func
   ... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
   ...     with T.Kernel(1, threads=128):
   ...         A_shared = T.alloc_shared((128,), "float32")
   ...         T.copy(A, A_shared)
   ...         T.cumsum(src=A_shared, dst=A_shared, dim=0)
   ...         T.copy(A_shared, B)

   A 2D prefix sum along the last dimension with reverse accumulation. The scan runs in place because `dst` is omitted:

   >>> import tilelang.language as T
   >>> @T.prim_func
   ... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
   ...     with T.Kernel(1, 1, threads=256):
   ...         tile = T.alloc_shared((64, 64), "float16")
   ...         T.copy(A, tile)
   ...         T.cumsum(src=tile, dim=1, reverse=True)
   ...         T.copy(tile, B)

   :returns: A handle to the emitted cumulative-sum operation.
   :rtype: tir.Call


.. py:function:: finalize_reducer(reducer)

   Finalize a reducer buffer by emitting the `tl.tileop.finalize_reducer` intrinsic.

   This returns a TVM `tir.Call` handle that finalizes the given reducer through its writable pointer. The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.

   :param reducer: Reducer buffer whose writable pointer will be finalized.
   :type reducer: tir.Buffer

   :returns: Handle to the finalize-reducer intrinsic call.
   :rtype: tir.Call


.. py:function:: warp_reduce_sum(value)

   Perform a warp-level sum reduction on a register value.

   This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register `value`; after the reduction, every thread holds the sum of the values across the warp.

   :param value: The input register value to reduce.
   :type value: tir.PrimExpr

   :returns: The reduced sum (identical on all threads in the warp).
   :rtype: tir.PrimExpr


.. py:function:: warp_reduce_max(value)

   Perform a warp-level max reduction on a register value.

   This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register `value`; after the reduction, every thread holds the maximum of the values across the warp.

   :param value: The input register value to reduce.
   :type value: tir.PrimExpr

   :returns: The reduced maximum (identical on all threads in the warp).
   :rtype: tir.PrimExpr


.. py:function:: warp_reduce_min(value)

   Perform a warp-level min reduction on a register value.

   This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register `value`; after the reduction, every thread holds the minimum of the values across the warp.

   :param value: The input register value to reduce.
   :type value: tir.PrimExpr

   :returns: The reduced minimum (identical on all threads in the warp).
   :rtype: tir.PrimExpr


.. py:function:: warp_reduce_bitand(value)

   Perform a warp-level bitwise-and reduction on a register value.

   This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register `value`; after the reduction, every thread holds the bitwise-and of the values across the warp.

   :param value: The input register value to reduce.
   :type value: tir.PrimExpr

   :returns: The reduced bitwise-and value (identical on all threads in the warp).
   :rtype: tir.PrimExpr


.. py:function:: warp_reduce_bitor(value)

   Perform a warp-level bitwise-or reduction on a register value.

   This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register `value`; after the reduction, every thread holds the bitwise-or of the values across the warp.

   :param value: The input register value to reduce.
   :type value: tir.PrimExpr

   :returns: The reduced bitwise-or value (identical on all threads in the warp).
   :rtype: tir.PrimExpr
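The sketch below is a plain-Python reference model of the `reduce` family (`reduce_max`, `reduce_min`, `reduce_sum`, `reduce_abssum`, `reduce_absmax`, and the bitwise variants). It shows how `dim` selects the collapsed axis and how `clear` decides the starting value. It is a hypothetical model, not the tilelang API: `reference_reduce` and its lookup tables are illustrative names. The -inf/inf starting values for max/min follow the docstrings above. The starting values for the other reduction types are assumed identities of each operation.

```python
import math
from functools import reduce as fold

# Hypothetical reference model of the reduce_* semantics; NOT the tilelang API.
_COMBINE = {
    "max": max,
    "min": min,
    "sum": lambda acc, x: acc + x,
    "abssum": lambda acc, x: acc + abs(x),
    "absmax": lambda acc, x: max(acc, abs(x)),
    "bitand": lambda acc, x: acc & x,
    "bitor": lambda acc, x: acc | x,
    "bitxor": lambda acc, x: acc ^ x,
}

# Starting values used when clear=True. -inf/inf for max/min follow the
# docstrings; the remaining entries are assumed identities of each operation.
_INIT = {
    "max": -math.inf, "min": math.inf, "sum": 0, "abssum": 0,
    "absmax": 0, "bitand": -1, "bitor": 0, "bitxor": 0,
}


def reference_reduce(buffer, out, reduce_type, dim=-1, clear=True):
    """Reduce the 2D list `buffer` along `dim`, writing into the 1D list `out`."""
    if not -2 <= dim < 2:
        raise ValueError(f"dim {dim} is out of bounds for a 2D buffer")
    dim %= 2  # normalize negative dims Python-style
    rows, cols = len(buffer), len(buffer[0])
    expected = cols if dim == 0 else rows
    if len(out) != expected:
        raise ValueError(f"out must have length {expected}")
    combine = _COMBINE[reduce_type]
    for i in range(expected):
        # dim=0 collapses rows (one result per column); dim=1 collapses columns.
        xs = [buffer[r][i] for r in range(rows)] if dim == 0 else buffer[i]
        # clear=True starts from the init value; clear=False folds into out[i].
        start = _INIT[reduce_type] if clear else out[i]
        out[i] = fold(combine, xs, start)
    return out


A = [[1, -5, 3], [4, 2, -6]]
print(reference_reduce(A, [0, 0], "max"))                 # [3, 4]
print(reference_reduce(A, [0, 0, 0], "sum", dim=0))       # [5, -3, -3]
print(reference_reduce(A, [10, 10], "sum", clear=False))  # [9, 10]
```

With `clear=False`, the previous contents of `out` act as the starting value. This is why the last call adds each row sum (-1 and 0) onto 10.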
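The note on `reduce_sum` can be made concrete with a small arithmetic model. This sketch is a simplification under stated assumptions, not tilelang code. It assumes each of 32 lanes holds one partial sum and that the warp all-reduce gives every lane the total. `WARP_SIZE`, `allreduce_sum`, `accumulate_naive`, and `accumulate_with_temp` are hypothetical names.

```python
WARP_SIZE = 32  # assumed warp width (32 on NVIDIA GPUs)


def allreduce_sum(lane_values):
    """Idealized warp all-reduce: every lane ends up holding the total."""
    total = sum(lane_values)
    return [total] * len(lane_values)


def accumulate_naive(partials, old_out):
    # Every lane seeds its register with the existing output value, so
    # old_out is counted once per lane instead of once overall.
    return allreduce_sum([old_out + p for p in partials])[0]


def accumulate_with_temp(partials, old_out):
    temp = old_out                        # save the existing output value
    reduced = allreduce_sum(partials)[0]  # reduce the partials on their own
    return reduced + temp                 # add the saved value back exactly once


partials = [1] * WARP_SIZE
print(accumulate_naive(partials, 100))      # 3232 == 32 * (100 + 1)
print(accumulate_with_temp(partials, 100))  # 132 == 32 + 100
```

Only the second result matches the intended accumulate semantics: the sum of the new partials plus the previous contents of `out`.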
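The following plain-Python reference model for 2D nested lists pins down the conventions described for `cumsum`:

* an inclusive scan, as the docstring's own example wording says;
* Python-style negative `dim`;
* a ValueError for an out-of-range `dim`;
* `reverse`;
* in-place operation when `dst` is None.

`reference_cumsum` is a hypothetical helper for illustration only and does not call tilelang.

```python
def reference_cumsum(src, dst=None, dim=0, reverse=False):
    """Inclusive prefix sum of a 2D list along `dim` (hypothetical model)."""
    ndim = 2
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} is out of bounds for a {ndim}-D buffer")
    dim %= ndim  # Python-style normalization of negative dims
    if dst is None:
        dst = src  # operate in place, mirroring the dst=None behavior
    rows, cols = len(src), len(src[0])
    # `outer` enumerates the independent scans, `inner` walks along `dim`.
    outer, inner = (cols, rows) if dim == 0 else (rows, cols)
    for o in range(outer):
        order = range(inner - 1, -1, -1) if reverse else range(inner)
        running = 0
        for i in order:
            r, c = (i, o) if dim == 0 else (o, i)
            running += src[r][c]  # read before writing, so in place is safe
            dst[r][c] = running
    return dst


print(reference_cumsum([[1, 2, 3], [4, 5, 6]], dim=1))
# [[1, 3, 6], [4, 9, 15]]
print(reference_cumsum([[1, 2, 3], [4, 5, 6]], dim=-1, reverse=True))
# [[6, 5, 3], [15, 11, 6]]
```

Each element is read once before it is overwritten, so the in-place path (`dst` omitted) gives the same result as writing into a separate buffer.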
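The warp-level reductions (`warp_reduce_sum`, `warp_reduce_max`, `warp_reduce_min`, `warp_reduce_bitand`, `warp_reduce_bitor`) share one contract: every lane receives the same reduced value. A common way to build this from shuffles is an XOR butterfly; in CUDA this is a loop of `__shfl_xor_sync` calls. Whether tilelang lowers to exactly this pattern is an assumption. The sketch below simulates the butterfly in plain Python; `warp_allreduce` and `WARP_SIZE` are hypothetical names.

```python
import operator

WARP_SIZE = 32  # assumed warp width (32 on NVIDIA GPUs)


def warp_allreduce(lane_values, op):
    """Simulate an XOR-butterfly shuffle all-reduce over one warp."""
    if len(lane_values) != WARP_SIZE:
        raise ValueError(f"expected {WARP_SIZE} lane values")
    vals = list(lane_values)
    offset = WARP_SIZE // 2
    while offset > 0:
        # Lane i combines its value with the value held by lane i XOR offset,
        # which is what one shuffle-xor step exchanges.
        vals = [op(vals[i], vals[i ^ offset]) for i in range(WARP_SIZE)]
        offset //= 2
    return vals


lanes = [(7 * i + 3) % WARP_SIZE for i in range(WARP_SIZE)]  # a permutation of 0..31
print(warp_allreduce(lanes, operator.add)[0])  # 496
print(warp_allreduce(lanes, max)[0])           # 31
print(warp_allreduce(lanes, operator.or_)[0])  # 31
```

After log2(32) = 5 exchange steps, every lane holds the same value, which matches the "same on all threads in the warp" wording above. The scheme relies on the operator being associative and commutative. That holds for sum, max, min, bitwise-and, and bitwise-or, although floating-point sums may round differently from a sequential sum.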