Control Flow¶
This guide covers the control‑flow primitives in TileLang and how they lower to efficient GPU code. You will use these to structure loops, handle boundaries, and express pipelined compute.
Overview¶
Conditionals:
if/elif/else, ternary (x if c else y)Loops:
T.serial,T.unroll,T.Parallel,T.PipelinedWhile loops:
whilewith a TIR conditionFlow control: Python
break/continueSafety: automatic OOB guards via the LegalizeSafeMemoryAccess pass
The examples assume import tilelang.language as T.
Conditionals¶
Standard Python if/elif/else is supported inside @T.prim_func kernels.
Conditions should be TIR expressions (e.g., i < N). Python plain booleans are
treated as compile‑time constants and will be folded.
for i in T.serial(N):
if i < N: # TIR condition
C[i] = A[i] + B[i]
else:
pass
# Ternary
x = (A[i] if i < N else 0)
Short‑circuit boolean ops are supported. For multi‑dimensional bounds, use
T.any_of / T.all_of for clarity:
if T.all_of(i < M, j < N):
C[i, j] = A[i, j] + B[i, j]
Boundary handling note
The LegalizeSafeMemoryAccess pass automatically inserts guards when an access may be out‑of‑bounds, and elides them when proven safe. You can often omit explicit
ifchecks for simple edge handling, but keep them when you need custom logic or clarity.
Loops¶
Serial¶
T.serial creates a plain for‑loop. Common forms:
for i in T.serial(N):
... # 0..N-1
for i in T.serial(0, N, 2):
... # 0, 2, 4, ...
Unroll¶
T.unroll requests loop unrolling for small trip counts.
for k in T.unroll(K_TILE):
acc += a[k] * b[k]
Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are available for expert tuning.
Parallel (elementwise)¶
T.Parallel(ext0, ext1, ...) builds nested loops that map well to elementwise
operations. The body receives all indices in one for header:
for i, j in T.Parallel(M, N):
C[i, j] = A[i, j] + B[i, j]
Optional: coalesced_width= can hint memory coalescing for the innermost loop.
Pipelined (software pipelining)¶
T.Pipelined(iters, num_stages=...) overlaps producer/consumer stages (e.g.,
Global→Shared copies with compute). This is the backbone of GEMM/attention
pipelines.
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile
T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile
T.gemm(A_s, B_s, C_f) # stage: compute
Persistent (advanced)¶
T.Persistent(domain, wave_size, index, group_size=...) exposes persistent
thread‑block style looping. It is an advanced construct that TileLang lowers in
later passes and is typically used by specialized templates.
While Loops¶
while is supported when the condition is a TIR expression. Avoid infinite
loops; TileLang will error if it detects a constant‑true condition.
i = 0
while i < N:
...
if done:
break
i += 1
Break and Continue¶
Use Python break/continue to exit or skip within T.serial/T.unroll/
T.Parallel/while loops. Keep the body clean after a break/continue for
readability; the compiler will ignore the dead path.
Putting It Together: Residual Tile Handling¶
Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess, the explicit guard can be omitted when you don’t need a custom edge path.
for i, j in T.Parallel(M, N):
gi = by * BM + i
gj = bx * BN + j
if T.all_of(gi < M, gj < N): # optional in many cases
C[gi, gj] = A[gi, gj] + B[gi, gj]
Debugging Conditions¶
Use T.print to inspect values under predicates. For buffers, TileLang prints
from a single thread to avoid duplicate outputs.
if i == 0:
T.print(C, msg='C tile:')