編譯#

MLX 提供 compile() 函式變換,用於編譯計算圖。函式編譯會透過合併共用工作與融合某些運算來得到更小的圖。在許多情況下,這能大幅改善執行時間與記憶體使用量。

開始使用 compile() 很簡單,但對於更複雜的圖與進階用法,有一些邊界情況值得注意。

Compile 的基礎#

先從一個簡單範例開始:

def fun(x, y):
    return mx.exp(-x) + y

x = mx.array(1.0)
y = mx.array(2.0)

# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))

# Compile the function
compiled_fun = mx.compile(fun)

# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))

一般函式與編譯後函式的輸出在數值精度上是一致的。

第一次呼叫已編譯的函式時,MLX 會建構計算圖、最佳化,並產生與編譯程式碼。這可能相對較慢。不過 MLX 會快取已編譯的函式,因此多次呼叫不會再次觸發編譯。這表示你通常應該只編譯會重複使用的函式。

def fun(x, y):
    return mx.exp(-x) + y

x = mx.array(1.0)
y = mx.array(2.0)

compiled_fun = mx.compile(fun)

# Compiled here
compiled_fun(x, y)

# Not compiled again
compiled_fun(x, y)

# Not compiled again
mx.compile(fun)(x, y)

有幾個重要情況會導致函式被重新編譯:

  • 變更形狀或維度數量

  • 變更任何輸入的型別

  • 變更函式輸入的數量

在某些情況下只會重新執行部分編譯流程(例如變更形狀),而在其他情況下會重新執行完整編譯流程(例如變更型別)。一般而言,應避免太頻繁地編譯函式。

另一種需要留意的寫法是編譯會頻繁建立與銷毀的函式。例如在迴圈中編譯匿名函式:

a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
    mx.compile(lambda x: mx.exp(mx.abs(x)))(a)

加速示例#

The mlx.nn.gelu() is a nonlinear activation function commonly used with Transformer-based models. The implementation involves several unary and binary element-wise operations:

def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

若使用小陣列,這個函式會受限於開銷;若使用大陣列,則會受限於記憶體頻寬。不過 gelu 中所有運算都可透過 compile() 融合為單一核心,能在兩種情況下都大幅加速。

讓我們比較一般函式與編譯後函式的執行時間。我們會使用以下計時輔助函式,包含暖身並處理同步:

import time

def timeit(fun, x):
    # warm up
    for _ in range(10):
        mx.eval(fun(x))

    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(fun(x))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic) / 100
    print(f"Time per iteration {tpi:.3f} (ms)")

現在建立一個陣列並對兩個函式進行基準測試:

x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)
timeit(mx.compile(gelu), x)

在 M1 Max 上,時間分別為 15.5 與 3.1 毫秒,編譯後的 gelu 快了五倍。

除錯#

編譯函式第一次被呼叫時會以占位輸入進行追蹤。這表示你無法在編譯函式內評估陣列(例如列印其內容)。

@mx.compile
def fun(x):
    z = -x
    print(z)  # Crash
    return mx.exp(z)

fun(mx.array(5.0))

為了除錯,檢視陣列很有幫助。一種作法是用 disable_compile()MLX_DISABLE_COMPILE 旗標全域停用編譯。例如以下即使 fun 已編譯也沒問題:

@mx.compile
def fun(x):
    z = -x
    print(z) # Okay
    return mx.exp(z)

mx.disable_compile()
fun(mx.array(5.0))

純函式#

編譯函式預期是*純函式*,也就是不應有副作用。例如:

state = []

@mx.compile
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z)

fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)

在第一次呼叫 fun 之後,state 清單會包含一個占位陣列。這個占位陣列沒有任何資料,只用來建構計算圖。列印這種陣列會造成崩潰。

你有兩個做法可處理此問題。第一個是直接把 state 當作輸出回傳:

state = []

@mx.compile
def fun(x, y):
   z = x + y
   state.append(z)
   return mx.exp(z), state

 _, state = fun(mx.array(1.0), mx.array(2.0))
 # Prints [array(3, dtype=float32)]
 print(state)

有時回傳更新後的狀態很不方便。因此 compile() 提供參數來捕捉隱式輸出:

from functools import partial

state = []

# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z)

fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)

當你編譯的函式會更新一個陣列容器時(例如訓練 mlx.nn.Module 的參數時常見),這尤其有用。

編譯函式也會把不在參數列表中的任何輸入視為常數。例如:

state = [mx.array(1.0)]

@mx.compile
def fun(x):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

# Update state
state[0] = mx.array(5.0)

# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

若要讓狀態變更反映在 fun 的輸出上,你同樣有兩個選項。第一個是直接把 state 作為函式輸入。

state = [mx.array(1.0)]

@mx.compile
def fun(x, state):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0), state))

# Update state
state[0] = mx.array(5.0)

# Prints array(6, dtype=float32)
print(fun(mx.array(1.0), state))

在某些情況下這會很不方便。因此 compile() 也提供參數來捕捉隱式輸入:

from functools import partial
state = [mx.array(1.0)]

# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
    return x + state[0]

# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))

# Update state
state[0] = mx.array(5.0)

# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))

編譯訓練圖#

本節將以一個常見設定的簡單範例說明如何使用 compile():使用 mlx.nn.Module 訓練模型,並搭配具有狀態的 mlx.optimizers.Optimizer。我們會示範如何用 compile() 編譯完整的前向、反向與更新。

首先,以下是不使用任何編譯的簡單範例:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn.Linear(10, 1)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Perform 10 steps of gradient descent
for it in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

要編譯更新,我們可以把所有步驟包成一個函式,並用適當的輸入與輸出捕捉進行編譯。以下是相同範例的編譯版本:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))

# 0, 1 targets
y = mx.array([0, 1, 0, 1])

# Simple linear model
model = nn.Linear(10, 1)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)

# The state that will be captured as input and output
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

# Perform 10 steps of gradient descent
for it in range(10):
    loss = step(x, y)
    # Evaluate the model and optimizer state
    mx.eval(state)
    print(loss)

備註

如果你使用會進行隨機取樣的模組(例如 mlx.nn.Dropout()),請確保也把 mx.random.state 納入 compile() 捕捉的 state,即 state = [model.state, optimizer.state, mx.random.state]

備註

更多完整訓練圖編譯範例請參考 MLX Examples GitHub 專案。

與 Compile 搭配的變換#

在 MLX 中,函式變換可以組合使用。你可以把任何函式變換套用到其他函式變換的輸出。詳情請參考 function transforms 的文件。

編譯變換後的函式會如預期般運作:

grad_fn = mx.grad(mx.exp)

compiled_grad_fn = mx.compile(grad_fn)

# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))

# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))

備註

為了盡可能多地編譯,對已編譯函式的變換,預設不會再被編譯。若要編譯變換後的函式,只需再通過 compile()

你也可以編譯會呼叫已編譯函式的函式。最佳做法是編譯最外層函式,讓 compile() 有最多機會最佳化計算圖:

@mx.compile
def inner(x):
    return mx.exp(-mx.abs(x))

def outer(x):
    inner(inner(x))

# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

無形狀編譯#

當已編譯函式的輸入形狀改變時,函式會重新編譯。你可以在 compile() 指定 shapeless=True,只編譯一次並可用於可變形狀的輸入。在此情況下,輸入形狀變更不會觸發重新編譯。

def fun(x, y):
    return mx.abs(x + y)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.array(1.0)
y = mx.array(-2.0)

# Firt call compiles the function
print(compiled_fun(x, y))

# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))

請謹慎使用無形狀編譯。因為形狀變更不會觸發編譯,任何依賴輸入形狀的圖都可能無法如預期運作。形狀相依的計算很常見,有時也不易察覺。例如:

def fun(x):
    return x.reshape(x.shape[0] * x.shape[1], -1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))

out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))

# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)

第二次呼叫 compiled_fun 會失敗,因為 reshape() 使用了第一次呼叫時 x 的靜態形狀。我們可以改用 flatten() 來避免硬編碼 x 的形狀:

def fun(x):
    return x.flatten(0, 1)

compiled_fun = mx.compile(fun, shapeless=True)

x = mx.random.uniform(shape=(2, 3, 4))

out = compiled_fun(x)

x = mx.random.uniform(shape=(5, 5, 3))

# Ok
out = compiled_fun(x)