神經網路#
在 MLX 中撰寫任意複雜的神經網路,只需使用 mlx.core.array 與 mlx.core.value_and_grad() 即可。然而,這需要使用者一再撰寫相同的簡單神經網路操作,並且手動且明確地處理所有參數狀態與初始化。
mlx.nn 模組提供直覺的方式來組合神經網路層、初始化其參數、在微調時凍結參數等,從而解決此問題。
神經網路快速開始#
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int):
super().__init__()
self.layers = [
nn.Linear(in_dims, 128),
nn.Linear(128, 128),
nn.Linear(128, out_dims),
]
def __call__(self, x):
for i, l in enumerate(self.layers):
x = mx.maximum(x, 0) if i > 0 else x
x = l(x)
return x
# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)
# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)
# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])
# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())
# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
# it from the local scope. It could be a positional argument or a
# keyword argument.
def l2_loss(x, y):
y_hat = mlp(x)
return (y_hat - y).square().mean()
# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
Module 類別#
任何神經網路程式庫的主力都是 Module 類別。在 MLX 中,Module 類別是 mlx.core.array 或 Module 實例的容器。它的主要功能是提供遞迴 存取 與 更新 其參數及其子模組參數的方法。
參數#
模組的參數是任何型別為 mlx.core.array 的公開成員(名稱不應以 _ 開頭)。它可以任意地巢狀在其他 Module 實例或清單與字典中。
可使用 Module.parameters() 提取包含模組及其子模組所有參數的巢狀字典。
Module 也可以追蹤「凍結」的參數。詳情請見 Module.freeze() 方法。mlx.nn.value_and_grad() 回傳的梯度會以這些可訓練參數為對象。
更新參數#
MLX 模組允許存取與更新個別參數。然而,多數時候需要更新模組參數的大型子集合。此動作由 Module.update() 進行。
檢視模組#
最簡單的檢視模型架構方式是直接列印。延續上方範例,你可以用以下方式列印 MLP:
print(mlp)
輸出會是:
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
若要取得 Module 中陣列的更詳細資訊,可以在參數上使用 mlx.utils.tree_map()。例如,查看 Module 中所有參數的形狀可使用:
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
另一個例子,你可以用以下方式計算 Module 的參數數量:
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value 與 Grad#
使用 Module 並不妨礙使用 MLX 的高階函式變換(mlx.core.value_and_grad()、mlx.core.grad() 等)。然而,這些函式變換假設函式是純函式,也就是應將參數作為被變換函式的引數傳入。
使用 MLX 模組有個簡單的模式可以做到這點
model = ...
def f(params, other_inputs):
model.update(params) # <---- Necessary to make the model use the passed parameters
return model(other_inputs)
f(model.trainable_parameters(), mx.zeros((10,)))
然而,mlx.nn.value_and_grad() 正好提供了這個模式,且只會計算模型可訓練參數的梯度。
詳細來說:
它會用一個函式包裹傳入的函式,該函式會呼叫
Module.update(),確保模型使用所提供的參數。它會呼叫
mlx.core.value_and_grad(),將函式轉換為同時計算對傳入參數的梯度的函式。它會再以一個函式包裹回傳的函式,將可訓練參數作為第一個引數傳給
mlx.core.value_and_grad()回傳的函式。
|
將傳入的函式 |
|
依據判斷式對模組的子模組進行量化。 |
|
對傳入群組中的分散式程序進行梯度平均。 |
|
Perform a distributed optimizer step by sharding gradients and optimizer states across ranks. |
- Module
Module- mlx.nn.Module.training
- mlx.nn.Module.state
- mlx.nn.Module.apply
- mlx.nn.Module.apply_to_modules
- mlx.nn.Module.children
- mlx.nn.Module.eval
- mlx.nn.Module.filter_and_map
- mlx.nn.Module.freeze
- mlx.nn.Module.leaf_modules
- mlx.nn.Module.load_weights
- mlx.nn.Module.modules
- mlx.nn.Module.named_modules
- mlx.nn.Module.parameters
- mlx.nn.Module.save_weights
- mlx.nn.Module.set_dtype
- mlx.nn.Module.train
- mlx.nn.Module.trainable_parameters
- mlx.nn.Module.unfreeze
- mlx.nn.Module.update
- mlx.nn.Module.update_modules
- Layers
- mlx.nn.ALiBi
- mlx.nn.AllToShardedLinear
- mlx.nn.AvgPool1d
- mlx.nn.AvgPool2d
- mlx.nn.AvgPool3d
- mlx.nn.BatchNorm
- mlx.nn.CELU
- mlx.nn.Conv1d
- mlx.nn.Conv2d
- mlx.nn.Conv3d
- mlx.nn.ConvTranspose1d
- mlx.nn.ConvTranspose2d
- mlx.nn.ConvTranspose3d
- mlx.nn.Dropout
- mlx.nn.Dropout2d
- mlx.nn.Dropout3d
- mlx.nn.Embedding
- mlx.nn.ELU
- mlx.nn.GELU
- mlx.nn.GLU
- mlx.nn.GroupNorm
- mlx.nn.GRU
- mlx.nn.HardShrink
- mlx.nn.HardTanh
- mlx.nn.Hardswish
- mlx.nn.InstanceNorm
- mlx.nn.LayerNorm
- mlx.nn.LeakyReLU
- mlx.nn.Linear
- mlx.nn.LogSigmoid
- mlx.nn.LogSoftmax
- mlx.nn.LSTM
- mlx.nn.MaxPool1d
- mlx.nn.MaxPool2d
- mlx.nn.MaxPool3d
- mlx.nn.Mish
- mlx.nn.MultiHeadAttention
- mlx.nn.PReLU
- mlx.nn.QuantizedAllToShardedLinear
- mlx.nn.QuantizedEmbedding
- mlx.nn.QuantizedLinear
- mlx.nn.QuantizedShardedToAllLinear
- mlx.nn.RMSNorm
- mlx.nn.ReLU
- mlx.nn.ReLU2
- mlx.nn.ReLU6
- mlx.nn.RNN
- mlx.nn.RoPE
- mlx.nn.SELU
- mlx.nn.Sequential
- mlx.nn.ShardedToAllLinear
- mlx.nn.Sigmoid
- mlx.nn.SiLU
- mlx.nn.SinusoidalPositionalEncoding
- mlx.nn.Softmin
- mlx.nn.Softshrink
- mlx.nn.Softsign
- mlx.nn.Softmax
- mlx.nn.Softplus
- mlx.nn.Step
- mlx.nn.Tanh
- mlx.nn.Transformer
- mlx.nn.Upsample
- Functions
- mlx.nn.elu
- mlx.nn.celu
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
- mlx.nn.glu
- mlx.nn.hard_shrink
- mlx.nn.hard_tanh
- mlx.nn.hardswish
- mlx.nn.leaky_relu
- mlx.nn.log_sigmoid
- mlx.nn.log_softmax
- mlx.nn.mish
- mlx.nn.prelu
- mlx.nn.relu
- mlx.nn.relu2
- mlx.nn.relu6
- mlx.nn.selu
- mlx.nn.sigmoid
- mlx.nn.silu
- mlx.nn.softmax
- mlx.nn.softmin
- mlx.nn.softplus
- mlx.nn.softshrink
- mlx.nn.step
- mlx.nn.tanh
- 損失函數
- mlx.nn.losses.binary_cross_entropy
- mlx.nn.losses.cosine_similarity_loss
- mlx.nn.losses.cross_entropy
- mlx.nn.losses.gaussian_nll_loss
- mlx.nn.losses.hinge_loss
- mlx.nn.losses.huber_loss
- mlx.nn.losses.kl_div_loss
- mlx.nn.losses.l1_loss
- mlx.nn.losses.log_cosh_loss
- mlx.nn.losses.margin_ranking_loss
- mlx.nn.losses.mse_loss
- mlx.nn.losses.nll_loss
- mlx.nn.losses.smooth_l1_loss
- mlx.nn.losses.triplet_loss
- Initializers
- Distributed