編譯#
MLX 提供 compile() 函式變換,用於編譯計算圖。函式編譯會透過合併共用工作與融合某些運算來得到更小的圖。在許多情況下,這能大幅改善執行時間與記憶體使用量。
開始使用 compile() 很簡單,但對於更複雜的圖與進階用法,有一些邊界情況值得注意。
Compile 的基礎#
先從一個簡單範例開始:
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
一般函式與編譯後函式的輸出在數值精度上是一致的。
第一次呼叫已編譯的函式時,MLX 會建構計算圖、最佳化,並產生與編譯程式碼。這可能相對較慢。不過 MLX 會快取已編譯的函式,因此多次呼叫不會再次觸發編譯。這表示你通常應該只編譯會重複使用的函式。
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
有幾個重要情況會導致函式被重新編譯:
變更形狀或維度數量
變更任何輸入的型別
變更函式輸入的數量
在某些情況下只會重新執行部分編譯流程(例如變更形狀),而在其他情況下會重新執行完整編譯流程(例如變更型別)。一般而言,應避免太頻繁地編譯函式。
另一種需要留意的寫法是編譯會頻繁建立與銷毀的函式。例如在迴圈中編譯匿名函式:
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
加速示例#
The mlx.nn.gelu() is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
若使用小陣列,這個函式會受限於開銷;若使用大陣列,則會受限於記憶體頻寬。不過 gelu 中所有運算都可透過 compile() 融合為單一核心,能在兩種情況下都大幅加速。
讓我們比較一般函式與編譯後函式的執行時間。我們會使用以下計時輔助函式,包含暖身並處理同步:
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
現在建立一個陣列並對兩個函式進行基準測試:
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)
timeit(mx.compile(gelu), x)
在 M1 Max 上,時間分別為 15.5 與 3.1 毫秒,編譯後的 gelu 快了五倍。
除錯#
編譯函式第一次被呼叫時會以占位輸入進行追蹤。這表示你無法在編譯函式內評估陣列(例如列印其內容)。
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
為了除錯,檢視陣列很有幫助。一種作法是用 disable_compile() 或 MLX_DISABLE_COMPILE 旗標全域停用編譯。例如以下即使 fun 已編譯也沒問題:
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
純函式#
編譯函式預期是*純函式*,也就是不應有副作用。例如:
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
在第一次呼叫 fun 之後,state 清單會包含一個占位陣列。這個占位陣列沒有任何資料,只用來建構計算圖。列印這種陣列會造成崩潰。
你有兩個做法可處理此問題。第一個是直接把 state 當作輸出回傳:
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
有時回傳更新後的狀態很不方便。因此 compile() 提供參數來捕捉隱式輸出:
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
當你編譯的函式會更新一個陣列容器時(例如訓練 mlx.nn.Module 的參數時常見),這尤其有用。
編譯函式也會把不在參數列表中的任何輸入視為常數。例如:
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
若要讓狀態變更反映在 fun 的輸出上,你同樣有兩個選項。第一個是直接把 state 作為函式輸入。
state = [mx.array(1.0)]
@mx.compile
def fun(x, state):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0), state))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0), state))
在某些情況下這會很不方便。因此 compile() 也提供參數來捕捉隱式輸入:
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
編譯訓練圖#
本節將以一個常見設定的簡單範例說明如何使用 compile():使用 mlx.nn.Module 訓練模型,並搭配具有狀態的 mlx.optimizers.Optimizer。我們會示範如何用 compile() 編譯完整的前向、反向與更新。
首先,以下是不使用任何編譯的簡單範例:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
要編譯更新,我們可以把所有步驟包成一個函式,並用適當的輸入與輸出捕捉進行編譯。以下是相同範例的編譯版本:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
備註
如果你使用會進行隨機取樣的模組(例如 mlx.nn.Dropout()),請確保也把 mx.random.state 納入 compile() 捕捉的 state,即 state = [model.state, optimizer.state, mx.random.state]。
備註
更多完整訓練圖編譯範例請參考 MLX Examples GitHub 專案。
與 Compile 搭配的變換#
在 MLX 中,函式變換可以組合使用。你可以把任何函式變換套用到其他函式變換的輸出。詳情請參考 function transforms 的文件。
編譯變換後的函式會如預期般運作:
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
備註
為了盡可能多地編譯,對已編譯函式的變換,預設不會再被編譯。若要編譯變換後的函式,只需再通過 compile()。
你也可以編譯會呼叫已編譯函式的函式。最佳做法是編譯最外層函式,讓 compile() 有最多機會最佳化計算圖:
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)
無形狀編譯#
當已編譯函式的輸入形狀改變時,函式會重新編譯。你可以在 compile() 指定 shapeless=True,只編譯一次並可用於可變形狀的輸入。在此情況下,輸入形狀變更不會觸發重新編譯。
def fun(x, y):
return mx.abs(x + y)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.array(1.0)
y = mx.array(-2.0)
# Firt call compiles the function
print(compiled_fun(x, y))
# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
請謹慎使用無形狀編譯。因為形狀變更不會觸發編譯,任何依賴輸入形狀的圖都可能無法如預期運作。形狀相依的計算很常見,有時也不易察覺。例如:
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
第二次呼叫 compiled_fun 會失敗,因為 reshape() 使用了第一次呼叫時 x 的靜態形狀。我們可以改用 flatten() 來避免硬編碼 x 的形狀:
def fun(x):
return x.flatten(0, 1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)