神經網路#

在 MLX 中撰寫任意複雜的神經網路,只需使用 mlx.core.arraymlx.core.value_and_grad() 即可。然而,這需要使用者一再撰寫相同的簡單神經網路操作,並且手動且明確地處理所有參數狀態與初始化。

mlx.nn 模組提供直覺的方式來組合神經網路層、初始化其參數、在微調時凍結參數等,從而解決此問題。

神經網路快速開始#

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x

# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)

# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)

# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])

# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())

# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
#       it from the local scope. It could be a positional argument or a
#       keyword argument.
def l2_loss(x, y):
    y_hat = mlp(x)
    return (y_hat - y).square().mean()

# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)

Module 類別#

任何神經網路程式庫的主力都是 Module 類別。在 MLX 中,Module 類別是 mlx.core.arrayModule 實例的容器。它的主要功能是提供遞迴 存取更新 其參數及其子模組參數的方法。

參數#

模組的參數是任何型別為 mlx.core.array 的公開成員(名稱不應以 _ 開頭)。它可以任意地巢狀在其他 Module 實例或清單與字典中。

可使用 Module.parameters() 提取包含模組及其子模組所有參數的巢狀字典。

Module 也可以追蹤「凍結」的參數。詳情請見 Module.freeze() 方法。mlx.nn.value_and_grad() 回傳的梯度會以這些可訓練參數為對象。

更新參數#

MLX 模組允許存取與更新個別參數。然而,多數時候需要更新模組參數的大型子集合。此動作由 Module.update() 進行。

檢視模組#

最簡單的檢視模型架構方式是直接列印。延續上方範例,你可以用以下方式列印 MLP

print(mlp)

輸出會是:

MLP(
  (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
  (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
  (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)

若要取得 Module 中陣列的更詳細資訊,可以在參數上使用 mlx.utils.tree_map()。例如,查看 Module 中所有參數的形狀可使用:

from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())

另一個例子,你可以用以下方式計算 Module 的參數數量:

from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

Value 與 Grad#

使用 Module 並不妨礙使用 MLX 的高階函式變換(mlx.core.value_and_grad()mlx.core.grad() 等)。然而,這些函式變換假設函式是純函式,也就是應將參數作為被變換函式的引數傳入。

使用 MLX 模組有個簡單的模式可以做到這點

model = ...

def f(params, other_inputs):
    model.update(params)  # <---- Necessary to make the model use the passed parameters
    return model(other_inputs)

f(model.trainable_parameters(), mx.zeros((10,)))

然而,mlx.nn.value_and_grad() 正好提供了這個模式,且只會計算模型可訓練參數的梯度。

詳細來說:

  • 它會用一個函式包裹傳入的函式,該函式會呼叫 Module.update(),確保模型使用所提供的參數。

  • 它會呼叫 mlx.core.value_and_grad(),將函式轉換為同時計算對傳入參數的梯度的函式。

  • 它會再以一個函式包裹回傳的函式,將可訓練參數作為第一個引數傳給 mlx.core.value_and_grad() 回傳的函式。

value_and_grad(model, fn)

將傳入的函式 fn 轉換為同時計算 fn 對模型可訓練參數的梯度與其值的函式。

quantize(model[, group_size, bits, mode, ...])

依據判斷式對模組的子模組進行量化。

average_gradients(gradients[, group, ...])

對傳入群組中的分散式程序進行梯度平均。

fsdp_apply_gradients(gradients, parameters, ...)

Perform a distributed optimizer step by sharding gradients and optimizer states across ranks.