惰性求值#

為什麼採用惰性求值#

在 MLX 中執行運算時,實際上不會進行計算,而是記錄計算圖。只有在執行 eval() 時才會進行真正的計算。

MLX 採用惰性求值,因為它有一些不錯的特性,下面會說明其中幾項。

轉換計算圖#

惰性求值讓我們可以在不實際計算的情況下記錄計算圖。這對 grad()vmap() 等函式轉換以及圖最佳化很有幫助。

目前 MLX 不會編譯並重新執行計算圖,計算圖都是動態產生的。不過,惰性求值讓未來要整合編譯以提升效能變得容易許多。

只計算你會使用的部分#

在 MLX 中,你不必太擔心計算從未被使用的輸出。例如:

def fun(x):
    a = fun1(x)
    b = expensive_fun(a)
    return a, b

y, _ = fun(x)

在這裡,我們其實不會計算 expensive_fun 的輸出。不過這種用法仍需小心,因為 expensive_fun 的圖仍會被建立,而這會有一定的成本。

同樣地,惰性求值也能在保持程式簡潔的同時節省記憶體。假設你有一個非常大的模型 Model,它繼承自 mlx.nn.Module。你可以用 model = Model() 來建立模型。通常這會以 float32 初始化所有權重,但在執行 eval() 前,初始化實際上不會做任何計算。若你改以 float16 權重更新模型,最大記憶體用量會是使用即時運算時的一半。

由於惰性計算,這種模式在 MLX 中很容易做到:

model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")

何時該求值#

常見的問題是何時該使用 eval()。這在「讓圖變得過大」與「無法批次化足夠的有用工作」之間需要取捨。

例如:

for _ in range(100):
     a = a + b
     mx.eval(a)
     b = b * 2
     mx.eval(b)

這不是好主意,因為每次評估計算圖都有固定的開銷。另一方面,也有些開銷會隨著計算圖大小而成長,因此過大的計算圖(雖然計算上正確)可能會很昂貴。

幸運的是,MLX 對計算圖大小的容忍範圍很廣:每次評估從幾十個運算到數千個運算通常都沒問題。

多數數值計算都有一個反覆的外層迴圈(例如隨機梯度下降的迭代)。通常在這個外層迴圈的每次迭代使用 eval() 是自然且有效率的作法。

以下是具體範例:

for batch in dataset:

    # Nothing has been evaluated yet
    loss, grad = value_and_grad_fn(model, batch)

    # Still nothing has been evaluated
    optimizer.update(model, grad)

    # Evaluate the loss and the new parameters which will
    # run the full gradient computation and optimizer update
    mx.eval(loss, model.parameters())

需要注意的一個重要行為是圖何時會被隱式求值。只要你 print 陣列、將其轉成 numpy.ndarray,或透過 memoryview 存取其記憶體,圖就會被求值。透過 :func:`save`(或其他 MLX 儲存函式)儲存陣列也會觸發求值。

對純量陣列呼叫 array.item() 也會觸發求值。在上面的例子中,列印 loss(print(loss))或將 loss 純量加入清單(losses.append(loss.item()))都會造成計算圖被求值。如果這些語句在 mx.eval(loss, model.parameters()) 之前,則只會進行部分求值,僅計算前向傳播。

另外,多次對同一個陣列或一組陣列呼叫 eval() 完全沒有問題,實際上等同於不做任何事。

警告

以純量陣列作為流程控制條件會觸發求值。

以下為範例:

def fun(x):
    h, y = first_layer(x)
    if y > 0:  # An evaluation is done here!
        z  = second_layer_a(h)
    else:
        z  = second_layer_b(h)
    return z

使用陣列作為流程控制需要謹慎。上面的例子可以運作,甚至可搭配梯度轉換使用。然而,若求值過於頻繁,效率會非常差。