tilelang.transform.hoist_broadcast_values

Classes

HoistBroadcastValuesMutator

Functions

HoistBroadcastValues()

TVM Pass: HoistBroadcastValues.

Module Contents

class tilelang.transform.hoist_broadcast_values.HoistBroadcastValuesMutator

Bases: tvm.tir.PyStmtExprMutator

pending_defs = []
hoist_enabled = False
visit_broadcast_(op)
visit_buffer_store_(op)
Parameters:

op (tvm.tir.BufferStore)

visit_let_stmt_(op)
Parameters:

op (tvm.tir.LetStmt)

tilelang.transform.hoist_broadcast_values.HoistBroadcastValues()

TVM Pass: HoistBroadcastValues.

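This pass scans the TIR for Broadcast operations whose value is an immediate constant (IntImm or FloatImm). It hoists each such constant into a variable bound by a LetStmt that directly wraps the statement containing the broadcast, and replaces the constant inside the Broadcast with that variable. Each occurrence gets its own variable, even when the same constant appears more than once in a statement.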

Example Transformation:

Before:

A[i] = B[i] + T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4)

After:

bv_3_14 = 3.14
bv_3_14_1 = 3.14
A[i] = B[i] + T.Broadcast(bv_3_14, 4) + T.Broadcast(bv_3_14_1, 4)
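To make the transformation above concrete, here is a minimal, self-contained Python sketch that does not use TVM. Const, Var, Load, Broadcast, Add, Store, Let, ToyHoister, fresh_name, render_expr and render are all hypothetical stand-ins for TIR nodes and for HoistBroadcastValuesMutator. ToyHoister reuses the attribute names pending_defs and hoist_enabled, but how it uses them is an assumption. The naming scheme (bv_ followed by the value with "." replaced by "_", plus _1, _2, ... on reuse) is inferred from the example and is not taken from the pass's source.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

# Hypothetical toy IR nodes standing in for TIR; these are NOT tvm.tir classes.
@dataclass(frozen=True)
class Const:  # role of IntImm / FloatImm
    value: Union[int, float]

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Load:  # role of BufferLoad, e.g. B[i]
    buffer: str
    index: str

@dataclass(frozen=True)
class Broadcast:
    value: object
    lanes: int

@dataclass(frozen=True)
class Add:
    a: object
    b: object

@dataclass(frozen=True)
class Store:  # role of BufferStore, e.g. A[i] = ...
    buffer: str
    index: str
    value: object

@dataclass(frozen=True)
class Let:  # role of LetStmt: bind var = value, then run body
    var: Var
    value: Const
    body: object

class ToyHoister:
    """Hypothetical mutator; reuses the real class's attribute names, not its code."""
    def __init__(self):
        self.pending_defs: List[Tuple[Var, Const]] = []  # hoisted (var, const) pairs
        self.hoist_enabled = False  # assumed: only hoist while inside a Store
        self.used_names = set()

    def fresh_name(self, value):
        # Assumed naming scheme: bv_<value, '.' -> '_'>, then _1, _2, ... on reuse.
        base = "bv_" + str(value).replace(".", "_").replace("-", "neg")
        name, k = base, 0
        while name in self.used_names:
            k += 1
            name = f"{base}_{k}"
        self.used_names.add(name)
        return name

    def visit_expr(self, e):
        if isinstance(e, Broadcast):
            if self.hoist_enabled and isinstance(e.value, Const):
                var = Var(self.fresh_name(e.value.value))
                self.pending_defs.append((var, e.value))
                return Broadcast(var, e.lanes)
            return Broadcast(self.visit_expr(e.value), e.lanes)
        if isinstance(e, Add):
            return Add(self.visit_expr(e.a), self.visit_expr(e.b))
        return e  # Const, Var and Load are left unchanged

    def visit_stmt(self, s):
        if isinstance(s, Let):  # assumed: recurse into the body only
            return Let(s.var, s.value, self.visit_stmt(s.body))
        if isinstance(s, Store):
            self.pending_defs, self.hoist_enabled = [], True
            result = Store(s.buffer, s.index, self.visit_expr(s.value))
            # Wrap in reverse so the first hoisted definition ends up outermost.
            for var, const in reversed(self.pending_defs):
                result = Let(var, const, result)
            self.pending_defs, self.hoist_enabled = [], False
            return result
        return s

def render_expr(e):
    if isinstance(e, Broadcast):
        return f"T.Broadcast({render_expr(e.value)}, {e.lanes})"
    if isinstance(e, Add):
        return f"{render_expr(e.a)} + {render_expr(e.b)}"
    if isinstance(e, Load):
        return f"{e.buffer}[{e.index}]"
    return e.name if isinstance(e, Var) else repr(e.value)

def render(s):
    lines = []
    while isinstance(s, Let):
        lines.append(f"{s.var.name} = {render_expr(s.value)}")
        s = s.body
    lines.append(f"{s.buffer}[{s.index}] = {render_expr(s.value)}")
    return "\n".join(lines)

before = Store("A", "i", Add(Add(Load("B", "i"), Broadcast(Const(3.14), 4)),
                             Broadcast(Const(3.14), 4)))
print(render(ToyHoister().visit_stmt(before)))
```

Run on the Before statement, the sketch prints the three lines of the After listing. Each definition is recorded in pending_defs while the store's value is visited, and the LetStmt-style wrappers are added only after the whole expression has been rewritten. This keeps the bindings directly around the statement that uses them.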