tilelang.language.reduce_op

Reduce operations exposed on the TileLang language surface.

Attributes

ReduceKind

Functions

reduce(buffer, out, reduce_type, dim, clear)

Perform a reduction operation on a buffer along a specified dimension.

reduce_max(buffer, out[, dim, clear])

Perform a max reduction on the input buffer and store the result in the output buffer.

reduce_min(buffer, out[, dim, clear])

Perform a min reduction on the input buffer and store the result in the output buffer.

reduce_sum(buffer, out[, dim, clear])

Perform a sum reduction on the input buffer and store the result in the output buffer.

reduce_abssum(buffer, out[, dim])

Perform an absolute-sum reduction on the input buffer and store the result in the output buffer.

reduce_absmax(buffer, out[, dim, clear])

Perform an absolute-max reduction on the input buffer and store the result in the output buffer.

reduce_bitand(buffer, out[, dim, clear])

Perform a bitwise-and reduction on the input buffer and store the result in the output buffer.

reduce_bitor(buffer, out[, dim, clear])

Perform a bitwise-or reduction on the input buffer and store the result in the output buffer.

reduce_bitxor(buffer, out[, dim, clear])

Perform a bitwise-xor reduction on the input buffer and store the result in the output buffer.

cumsum_fragment(src, dst, dim, reverse)

Compute cumulative sum for fragment buffers by copying to shared memory first.

cumsum(src[, dst, dim, reverse])

Compute the cumulative sum of src along dim, writing results to dst.

finalize_reducer(reducer)

Finalize a reducer buffer by emitting the tl.tileop.finalize_reducer intrinsic.

warp_reduce_sum(value)

Perform a warp-level sum reduction on a register value.

warp_reduce_max(value)

Perform a warp-level max reduction on a register value.

warp_reduce_min(value)

Perform a warp-level min reduction on a register value.

warp_reduce_bitand(value)

Perform a warp-level bitwise-and reduction on a register value.

warp_reduce_bitor(value)

Perform a warp-level bitwise-or reduction on a register value.

Module Contents

tilelang.language.reduce_op.ReduceKind
tilelang.language.reduce_op.reduce(buffer, out, reduce_type, dim, clear)

Perform a reduction operation on a buffer along a specified dimension.

Parameters:
  • buffer (tir.Buffer) -- Input buffer to reduce.

  • out (tir.Buffer) -- Output buffer to store the results.

  • reduce_type (str) -- Type of reduction, e.g. 'max', 'min', 'sum', or 'abssum'.

  • dim (int) -- Dimension along which to perform the reduction.

  • clear (bool) -- Whether to initialize the output buffer before the reduction.

Return type:

None
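The semantics of reduce can be modeled in plain Python to make the roles of dim and clear concrete. The sketch below is a hypothetical reference model, not TileLang code: reference_reduce works on a 2D nested list, starts from the reduction's identity when clear=True, and folds into the existing contents of out when clear=False. The identity values and the clear=False combining rule are assumptions chosen to match the parameter descriptions, and the bitwise reduce types are included only because the module exposes reduce_bitand, reduce_bitor, and reduce_bitxor.

```python
import math
from functools import reduce as fold

# Hypothetical table: reduce_type -> (start value when clear=True, per-element transform, combiner).
_REDUCERS = {
    "sum":    (0,         lambda x: x, lambda a, b: a + b),
    "abssum": (0,         abs,         lambda a, b: a + b),
    "max":    (-math.inf, lambda x: x, max),
    "min":    (math.inf,  lambda x: x, min),
    "absmax": (0,         abs,         max),
    "bitand": (-1,        lambda x: x, lambda a, b: a & b),  # -1 has every bit set
    "bitor":  (0,         lambda x: x, lambda a, b: a | b),
    "bitxor": (0,         lambda x: x, lambda a, b: a ^ b),
}


def reference_reduce(buffer, out, reduce_type, dim=-1, clear=True):
    """Reduce a 2D nested list along dim, writing one value per kept index into out."""
    if not -2 <= dim < 2:
        raise ValueError(f"dim {dim} is out of bounds for a 2-D buffer")
    identity, transform, combine = _REDUCERS[reduce_type]
    rows, cols = len(buffer), len(buffer[0])
    dim %= 2  # Python-style normalization of negative dims
    # Lines along the reduced dimension: the rows when dim == 1, the columns when dim == 0.
    lines = buffer if dim == 1 else [[buffer[r][c] for r in range(rows)] for c in range(cols)]
    for i, line in enumerate(lines):
        start = identity if clear else out[i]  # clear=False keeps the existing value in out
        out[i] = fold(combine, (transform(x) for x in line), start)
    return None
```

out must already have one slot per kept index (len(buffer) for dim=-1, len(buffer[0]) for dim=0), mirroring the requirement that the output buffer be allocated by the caller.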

tilelang.language.reduce_op.reduce_max(buffer, out, dim=-1, clear=True)

Perform a max reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (Buffer) -- The input buffer.

  • out (Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool) -- If True, the output buffer is first initialized to -inf.

Returns:

Handle to the reduction operation.

Return type:

PrimExpr
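A common use of the clear flag is reducing a long row tile by tile into one output buffer: the first tile uses clear=True so out starts from -inf, and later tiles use clear=False so each new maximum is merged with the one already in out. The plain-Python sketch below models that pattern; row_max_into and tiled_row_max are hypothetical helpers illustrating the documented clear semantics, not TileLang code.

```python
import math


def row_max_into(tile, out, clear):
    """Hypothetical model of reduce_max(tile, out, dim=1, clear=clear) on a 2D list."""
    for i, row in enumerate(tile):
        start = -math.inf if clear else out[i]  # clear=True initializes out to -inf
        out[i] = max([start, *row])
    return out


def tiled_row_max(matrix, tile_width):
    """Reduce each row of matrix in column tiles of tile_width, reusing one out buffer."""
    out = [None] * len(matrix)
    for k, col in enumerate(range(0, len(matrix[0]), tile_width)):
        tile = [row[col:col + tile_width] for row in matrix]
        row_max_into(tile, out, clear=(k == 0))  # only the first tile clears out
    return out
```

The result matches a single reduction over the whole row, which is why later tiles must not clear out.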

tilelang.language.reduce_op.reduce_min(buffer, out, dim=-1, clear=True)

Perform a min reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool, optional) -- If True, the output buffer is first initialized to inf. Defaults to True.

Returns:

Handle to the reduction operation.

Return type:

tir.Call

tilelang.language.reduce_op.reduce_sum(buffer, out, dim=-1, clear=True)

Perform a sum reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool, optional) -- If True, the output buffer is cleared before the reduction. If False, the results are accumulated onto the existing values. Defaults to True.

Note: When clear=False, reduce_sum does not accumulate directly into the output buffer. During a warp reduction, the existing value in out would otherwise be added once per thread in the warp instead of once. The implementation with clear=False therefore follows these steps:

  1. Create a temporary buffer with the same shape and dtype as out.

  2. Copy out into the temporary buffer.

  3. Call reduce_sum with the temporary buffer and out.

  4. Add the temporary buffer to out.

Returns:

Handle to the reduction operation.

Return type:

tir.Call
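The double-counting problem described in the note can be reproduced with a small simulation. The plain-Python sketch below assumes that each of WARP lanes holds a replicated copy of the current out value plus its own partial sum, and that the warp combines lanes with an all-reduce sum; the helper names are hypothetical and the step mapping in the comments is an interpretation of the four listed steps, not TileLang's lowering.

```python
WARP = 32  # assumed number of lanes participating in the warp reduction


def allreduce_sum(lane_values):
    """Model of a warp all-reduce: every lane ends up holding the total."""
    total = sum(lane_values)
    return [total] * len(lane_values)


def naive_accumulate(partials, out_value):
    """Accumulate directly into out: every lane's replicated copy of out enters the sum."""
    return allreduce_sum([out_value + p for p in partials])[0]


def accumulate_with_temp(partials, out_value):
    """Save out first, reduce from zero, then add the saved value back exactly once."""
    temp = out_value                      # steps 1-2: keep the existing out value aside
    reduced = allreduce_sum(partials)[0]  # step 3: reduction that starts from zero
    return reduced + temp                 # step 4: add the saved value back once
```

With 32 partials of 1 and an existing out value of 10, the naive path yields 32 * 10 + 32 = 352, while the temporary-buffer path yields the intended 42.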

tilelang.language.reduce_op.reduce_abssum(buffer, out, dim=-1)

Perform an absolute-sum reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

Returns:

Handle to the reduction operation.

Return type:

tir.Call
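Along the last dimension, an absolute-sum reduction computes the L1 norm of each row. The plain-Python sketch below shows that computation and one typical follow-up, L1 normalization; row_abssum and l1_normalize are hypothetical helpers, not TileLang code.

```python
def row_abssum(matrix):
    """Model of reduce_abssum(buffer, out, dim=-1): sum of absolute values per row."""
    return [sum(abs(x) for x in row) for row in matrix]


def l1_normalize(matrix):
    """Scale each row so its absolute values sum to 1 (all-zero rows map to zeros)."""
    norms = row_abssum(matrix)
    return [[x / n if n else 0.0 for x in row] for row, n in zip(matrix, norms)]
```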

tilelang.language.reduce_op.reduce_absmax(buffer, out, dim=-1, clear=True)

Perform an absolute-max reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool) -- Whether to initialize the output buffer before the reduction. Defaults to True.

Returns:

Handle to the reduction operation.

Return type:

tir.Call
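A typical use of an absolute-max reduction is computing a symmetric per-row quantization scale, where the row's largest magnitude is mapped to the top of the integer range. The plain-Python sketch below illustrates that arithmetic, assuming symmetric int8 quantization with the range [-127, 127]; row_absmax and quantize_rows_int8 are hypothetical helpers, not TileLang code.

```python
def row_absmax(matrix):
    """Model of reduce_absmax(buffer, out, dim=-1): largest magnitude per row."""
    return [max(abs(x) for x in row) for row in matrix]


def quantize_rows_int8(matrix):
    """Symmetric per-row int8 quantization driven by the per-row absolute maximum."""
    # An all-zero row gets scale 1.0 to avoid dividing by zero.
    scales = [m / 127 if m else 1.0 for m in row_absmax(matrix)]
    q = [[max(-127, min(127, round(x / s))) for x in row] for row, s in zip(matrix, scales)]
    return q, scales
```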

tilelang.language.reduce_op.reduce_bitand(buffer, out, dim=-1, clear=True)

Perform a bitwise-and reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool) -- Whether to initialize the output buffer before the reduction. Defaults to True.

Returns:

Handle to the reduction operation.

Return type:

tir.Call

tilelang.language.reduce_op.reduce_bitor(buffer, out, dim=-1, clear=True)

Perform a bitwise-or reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool) -- Whether to initialize the output buffer before the reduction. Defaults to True.

Returns:

Handle to the reduction operation.

Return type:

tir.Call

tilelang.language.reduce_op.reduce_bitxor(buffer, out, dim=-1, clear=True)

Perform a bitwise-xor reduction on the input buffer and store the result in the output buffer.

Parameters:
  • buffer (tir.Buffer) -- The input buffer.

  • out (tir.Buffer) -- The output buffer.

  • dim (int) -- The dimension along which to reduce.

  • clear (bool) -- Whether to initialize the output buffer before the reduction. Defaults to True.

Returns:

Handle to the reduction operation.

Return type:

tir.Call
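The three bitwise reductions differ mainly in their identity element, the natural starting value when clear=True: all bits set for and, zero for or and xor. The plain-Python sketch below models per-row bitwise reductions under that assumption (the exact initial values TileLang uses are not stated here) and treats values as 8-bit unsigned flags; row_bitwise is a hypothetical helper, not TileLang code.

```python
from functools import reduce

MASK = 0xFF  # assume 8-bit unsigned values for this illustration

# Identity elements: x & MASK == x, x | 0 == x, x ^ 0 == x for any 8-bit x.
BITWISE = {
    "bitand": (MASK, lambda a, b: a & b),
    "bitor":  (0,    lambda a, b: a | b),
    "bitxor": (0,    lambda a, b: a ^ b),
}


def row_bitwise(matrix, kind, out=None, clear=True):
    """Model of reduce_<kind>(buffer, out, dim=-1, clear=clear) on a 2D list of 8-bit ints."""
    identity, op = BITWISE[kind]
    prev_out = list(out) if out is not None else [identity] * len(matrix)
    # clear=True starts each row from the identity; clear=False folds into the existing value.
    return [reduce(op, row, identity if clear else prev) for row, prev in zip(matrix, prev_out)]
```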

tilelang.language.reduce_op.cumsum_fragment(src, dst, dim, reverse)

Compute cumulative sum for fragment buffers by copying to shared memory first.

This macro handles cumulative sum operations on fragment buffers by first copying the data to shared memory, performing the cumsum operation, and then copying back.

Parameters:
  • src (tilelang._typing.BufferLikeType) -- Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data.

  • dst (tilelang._typing.BufferLikeType) -- Destination buffer (Buffer, BufferRegion, or BufferLoad) for output data.

  • dim (int) -- Dimension along which to compute cumulative sum.

  • reverse (bool) -- If True, compute cumulative sum in reverse order.

Return type:

None

tilelang.language.reduce_op.cumsum(src, dst=None, dim=0, reverse=False)

Compute the cumulative sum of src along dim, writing results to dst.

Negative dim indices are normalized Python-style. If dst is None, the operation is performed in place on src. Raises ValueError when dim is out of bounds for src.shape. When src.scope() == "local.fragment", this delegates to cumsum_fragment; otherwise it emits the tl.cumsum intrinsic.

Supports Buffer, BufferRegion, and BufferLoad inputs, allowing operations on buffer slices/regions.
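The dim normalization, in-place, and reverse rules can be modeled in plain Python. The sketch below is a hypothetical reference for a 2D nested list only (it does not model BufferRegion or BufferLoad inputs, and it is not the TileLang implementation): it rejects an out-of-range dim with ValueError, normalizes a negative dim, and performs an inclusive scan, starting from the end of the axis when reverse=True.

```python
def reference_cumsum(src, dst=None, dim=0, reverse=False):
    """Inclusive prefix sum of a 2D nested list along dim; writes into dst (or src if None)."""
    ndim = 2
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} is out of bounds for a {ndim}-D buffer")
    dim %= ndim                        # Python-style normalization of negative dims
    dst = src if dst is None else dst  # dst=None means operate in place on src
    rows, cols = len(src), len(src[0])
    outer, inner = (cols, rows) if dim == 0 else (rows, cols)
    for o in range(outer):
        order = range(inner - 1, -1, -1) if reverse else range(inner)
        running = 0
        for i in order:
            r, c = (i, o) if dim == 0 else (o, i)
            running += src[r][c]  # read before the same position is overwritten
            dst[r][c] = running
    return dst
```

Each element is read once, immediately before its own position is written, so the in-place case gives the same result as writing to a separate dst.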

Examples

A 1D inclusive scan computed in place in shared memory (src and dst are the same buffer) and then copied out to B:

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
...     with T.Kernel(1, threads=128):
...         A_shared = T.alloc_shared((128,), "float32")
...         T.copy(A, A_shared)
...         T.cumsum(src=A_shared, dst=A_shared, dim=0)
...         T.copy(A_shared, B)

A 2D prefix sum along the last dimension with reverse accumulation:

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
...     with T.Kernel(1, 1, threads=256):
...         tile = T.alloc_shared((64, 64), "float16")
...         T.copy(A, tile)
...         T.cumsum(src=tile, dim=1, reverse=True)
...         T.copy(tile, B)

Operating on a buffer region (slice):

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel_region(InputG_fragment: T.Tensor((128,), "float32"), chunk_size: T.int32):
...     with T.Kernel(1, threads=128):
...         i = T.int32(0)
...         T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0)

Returns:

A handle to the emitted cumulative-sum operation.

Return type:

tir.Call

Parameters:
  • src (tilelang._typing.BufferLikeType)

  • dst (tilelang._typing.BufferLikeType | None)

  • dim (int)

  • reverse (bool)

tilelang.language.reduce_op.finalize_reducer(reducer)

Finalize a reducer buffer by emitting the tl.tileop.finalize_reducer intrinsic.

This returns a TVM tir.Call handle that finalizes the given reducer using its writable pointer. The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.

Parameters:

reducer (tir.Buffer) -- Reducer buffer whose writable pointer will be finalized.

Returns:

Handle to the finalize reducer intrinsic call.

Return type:

tir.Call

tilelang.language.reduce_op.warp_reduce_sum(value)

Perform a warp-level sum reduction on a register value.

This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register value, and after the reduction every thread holds the sum of the values across the warp.

Parameters:

value (tir.PrimExpr) -- The input register value to reduce.

Returns:

The reduced sum (identical on all threads in the warp).

Return type:

tir.PrimExpr
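Shuffle-based warp reductions are commonly built as a butterfly (XOR) all-reduce: in each of log2(32) = 5 rounds, every lane combines its value with the lane whose index differs in one bit, and after the last round every lane holds the full result. The plain-Python sketch below simulates that pattern for a 32-lane warp; the butterfly schedule, the warp size, and the helper names are assumptions for illustration, not a statement about the code TileLang generates.

```python
import operator

WARP_SIZE = 32  # assumed warp size (32 on NVIDIA GPUs)


def simulate_shfl_xor(values, lane_mask):
    """Model of a shuffle-xor: lane i reads the value held by lane i ^ lane_mask."""
    return [values[i ^ lane_mask] for i in range(len(values))]


def butterfly_allreduce(values, op=operator.add):
    """Every lane ends with op applied over all lanes (op must be associative and commutative)."""
    assert len(values) == WARP_SIZE
    offset = WARP_SIZE // 2
    while offset > 0:  # offsets 16, 8, 4, 2, 1: five rounds
        partner = simulate_shfl_xor(values, offset)
        values = [op(a, b) for a, b in zip(values, partner)]
        offset //= 2
    return values
```

Passing max, min, operator.and_, or operator.or_ as op models the other warp reductions in this module with the same schedule.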

tilelang.language.reduce_op.warp_reduce_max(value)

Perform a warp-level max reduction on a register value.

This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register value, and after the reduction every thread holds the maximum of the values across the warp.

Parameters:

value (tir.PrimExpr) -- The input register value to reduce.

Returns:

The reduced maximum (identical on all threads in the warp).

Return type:

tir.PrimExpr
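Because every lane receives the same maximum, a warp max followed by a warp sum is enough for a numerically stable softmax across a warp, with one score per lane. The plain-Python sketch below models warp_reduce_max and warp_reduce_sum as functions that return the same result on every lane; warp_max, warp_sum, and warp_softmax are hypothetical helpers, not TileLang code.

```python
import math

WARP_SIZE = 32  # assumed warp size


def warp_max(lane_values):
    """Model of warp_reduce_max: every lane receives the maximum over the warp."""
    m = max(lane_values)
    return [m] * len(lane_values)


def warp_sum(lane_values):
    """Model of warp_reduce_sum: every lane receives the sum over the warp."""
    s = sum(lane_values)
    return [s] * len(lane_values)


def warp_softmax(scores):
    """Numerically stable softmax where lane i holds scores[i]."""
    row_max = warp_max(scores)  # same maximum on every lane
    # Subtracting the max keeps every exponent <= 0, so math.exp cannot overflow.
    exps = [math.exp(x - m) for x, m in zip(scores, row_max)]
    denom = warp_sum(exps)
    return [e / d for e, d in zip(exps, denom)]
```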

tilelang.language.reduce_op.warp_reduce_min(value)

Perform a warp-level min reduction on a register value.

This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register value, and after the reduction every thread holds the minimum of the values across the warp.

Parameters:

value (tir.PrimExpr) -- The input register value to reduce.

Returns:

The reduced minimum (identical on all threads in the warp).

Return type:

tir.PrimExpr

tilelang.language.reduce_op.warp_reduce_bitand(value)

Perform a warp-level bitwise-and reduction on a register value.

This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register value, and after the reduction every thread holds the bitwise-and of the values across the warp.

Parameters:

value (tir.PrimExpr) -- The input register value to reduce.

Returns:

The reduced bitwise-and value (identical on all threads in the warp).

Return type:

tir.PrimExpr

tilelang.language.reduce_op.warp_reduce_bitor(value)

Perform a warp-level bitwise-or reduction on a register value.

This function reduces a value across all threads in a warp using shuffle operations. Each thread provides a register value, and after the reduction every thread holds the bitwise-or of the values across the warp.

Parameters:

value (tir.PrimExpr) -- The input register value to reduce.

Returns:

The reduced bitwise-or value (identical on all threads in the warp).

Return type:

tir.PrimExpr
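With 0/1 predicates, the bitwise warp reductions behave like warp votes: a bitwise-or is 1 if any lane's predicate is set, and a bitwise-and is 1 only if every lane's is. OR-ing one bit per lane also builds a ballot mask. The plain-Python sketch below models this, treating each reduction as returning the same value on every lane; it mirrors the idea behind CUDA's __any_sync, __all_sync, and __ballot_sync, but all helper names here are hypothetical and it is not TileLang code.

```python
from functools import reduce


def warp_bitand(lane_values):
    """Model of warp_reduce_bitand: every lane receives the AND over all lanes."""
    r = reduce(lambda a, b: a & b, lane_values)
    return [r] * len(lane_values)


def warp_bitor(lane_values):
    """Model of warp_reduce_bitor: every lane receives the OR over all lanes."""
    r = reduce(lambda a, b: a | b, lane_values)
    return [r] * len(lane_values)


def vote_all(preds):
    """True only if every lane's predicate is true."""
    return warp_bitand([1 if p else 0 for p in preds])[0] == 1


def vote_any(preds):
    """True if at least one lane's predicate is true."""
    return warp_bitor([1 if p else 0 for p in preds])[0] == 1


def ballot(preds):
    """Bit i of the result is set when lane i's predicate is true."""
    return warp_bitor([(1 << lane) if p else 0 for lane, p in enumerate(preds)])[0]
```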