函式轉換#
MLX 使用可組合的函式轉換來進行自動微分、向量化與計算圖最佳化。要查看完整的函式轉換清單,請參考 API 文件。
可組合函式轉換的關鍵概念是:每個轉換都會回傳一個可再被進一步轉換的函式。
以下是一個簡單範例:
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
grad() 作用於 sin() 的輸出本身就是另一個函式。在此情況下,它是正弦函式的梯度,也就是餘弦函式。若要取得二階導數,可以這樣做:
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
對 grad() 的輸出再使用 grad() 一直都是可行的,會持續得到更高階導數。
MLX 的任何函式轉換都可以以任意順序、任意深度組合。關於 自動微分 與 自動向量化 的更多資訊,請參閱以下章節。至於 compile() 的更多內容,請見 compile 文件。
自動微分#
MLX 的自動微分是針對函式運作,而非隱式計算圖。
備註
如果你是從 PyTorch 轉到 MLX,就不再需要像 backward、zero_grad、detach 這類函式,或是 requires_grad 這類屬性。
最基本的例子就是像上面那樣對標量值函式取梯度。你可以使用 grad() 與 value_and_grad() 來計算更複雜函式的梯度。預設情況下,這些函式會對第一個引數計算梯度:
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
取得損失與梯度的一種方式是先呼叫 loss_fn 再呼叫 grad_fn,但這可能造成大量重複計算。相反地,你應該使用 value_and_grad()。沿用上例:
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
你也可以對任意巢狀的 Python 容器中的陣列取梯度(具體來說是 list、tuple 或 dict)。
假設在上例中我們想加入權重與偏置參數,一種不錯的寫法如下:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
請注意參數的樹狀結構會在梯度中被保留。
在某些情況下,你可能希望梯度不要穿過函式的某一部分,此時可以使用 stop_gradient()。
自動向量化#
使用 vmap() 來自動向量化複雜函式。為了清楚起見,這裡會用一個基本且刻意簡化的例子,但對於難以手動最佳化的複雜函式,vmap() 其實非常強大。
警告
有些運算尚未支援 vmap()。若你遇到像 ValueError: Primitive's vmap not implemented. 這樣的錯誤,請在 issue 中回報並附上你的函式,我們會優先納入支援。
將兩組向量逐一相加的一種樸素作法是用迴圈:
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
相對地,你可以使用 vmap() 自動向量化相加:
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
in_axes 參數可用來指定要對應輸入的哪些維度進行向量化。同樣地,使用 out_axes 來指定輸出中向量化軸的位置。
讓我們測量這兩種版本的耗時:
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
在 M1 Max 上,樸素版本總共需要 5.639 秒,而向量化版本只需 0.024 秒,速度快了 200 倍以上。
當然,這個運算相當刻意。更好的方式是直接做 xs + ys.T,但對更複雜的函式來說,vmap() 仍然非常實用。