惰性求值#
為什麼採用惰性求值#
在 MLX 中執行運算時,實際上不會進行計算,而是記錄計算圖。只有在執行 eval() 時才會進行真正的計算。
MLX 採用惰性求值,因為它有一些不錯的特性,下面會說明其中幾項。
轉換計算圖#
惰性求值讓我們可以在不實際計算的情況下記錄計算圖。這對 grad()、vmap() 等函式轉換以及圖最佳化很有幫助。
目前 MLX 不會編譯並重新執行計算圖,計算圖都是動態產生的。不過,惰性求值讓未來要整合編譯以提升效能變得容易許多。
只計算你會使用的部分#
在 MLX 中,你不必太擔心計算從未被使用的輸出。例如:
def fun(x):
a = fun1(x)
b = expensive_fun(a)
return a, b
y, _ = fun(x)
在這裡,我們其實不會計算 expensive_fun 的輸出。不過這種用法仍需小心,因為 expensive_fun 的圖仍會被建立,而這會有一定的成本。
同樣地,惰性求值也能在保持程式簡潔的同時節省記憶體。假設你有一個非常大的模型 Model,它繼承自 mlx.nn.Module。你可以用 model = Model() 來建立模型。通常這會以 float32 初始化所有權重,但在執行 eval() 前,初始化實際上不會做任何計算。若你改以 float16 權重更新模型,最大記憶體用量會是使用即時運算時的一半。
由於惰性計算,這種模式在 MLX 中很容易做到:
model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")
何時該求值#
常見的問題是何時該使用 eval()。這在「讓圖變得過大」與「無法批次化足夠的有用工作」之間需要取捨。
例如:
for _ in range(100):
a = a + b
mx.eval(a)
b = b * 2
mx.eval(b)
這不是好主意,因為每次評估計算圖都有固定的開銷。另一方面,也有些開銷會隨著計算圖大小而成長,因此過大的計算圖(雖然計算上正確)可能會很昂貴。
幸運的是,MLX 對計算圖大小的容忍範圍很廣:每次評估從幾十個運算到數千個運算通常都沒問題。
多數數值計算都有一個反覆的外層迴圈(例如隨機梯度下降的迭代)。通常在這個外層迴圈的每次迭代使用 eval() 是自然且有效率的作法。
以下是具體範例:
for batch in dataset:
# Nothing has been evaluated yet
loss, grad = value_and_grad_fn(model, batch)
# Still nothing has been evaluated
optimizer.update(model, grad)
# Evaluate the loss and the new parameters which will
# run the full gradient computation and optimizer update
mx.eval(loss, model.parameters())
需要注意的一個重要行為是圖何時會被隱式求值。只要你 print 陣列、將其轉成 numpy.ndarray,或透過 memoryview 存取其記憶體,圖就會被求值。透過 :func:`save`(或其他 MLX 儲存函式)儲存陣列也會觸發求值。
對純量陣列呼叫 array.item() 也會觸發求值。在上面的例子中,列印 loss(print(loss))或將 loss 純量加入清單(losses.append(loss.item()))都會造成計算圖被求值。如果這些語句在 mx.eval(loss, model.parameters()) 之前,則只會進行部分求值,僅計算前向傳播。
另外,多次對同一個陣列或一組陣列呼叫 eval() 完全沒有問題,實際上等同於不做任何事。
警告
以純量陣列作為流程控制條件會觸發求值。
以下為範例:
def fun(x):
h, y = first_layer(x)
if y > 0: # An evaluation is done here!
z = second_layer_a(h)
else:
z = second_layer_b(h)
return z
使用陣列作為流程控制需要謹慎。上面的例子可以運作,甚至可搭配梯度轉換使用。然而,若求值過於頻繁,效率會非常差。