Data Parallelism#
MLX enables efficient data parallel distributed training through its distributed communication primitives.
Training Example#
In this section we adapt an MLX training loop to support data parallel distributed training. Namely, we average the gradients across a set of hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model, dataset, and optimizer initialization.
model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
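Note that for this loop to actually be data parallel, each process should also iterate over a different shard of the data. The exact mechanism depends on your data pipeline; the following is a minimal sketch which assumes that dataset is an indexable sequence of (x, y) pairs and strides over it using the rank and size reported by mx.distributed.init():

import mlx.core as mx

world = mx.distributed.init()

def shard(dataset):
    # Each process visits every world.size()-th example starting at its
    # rank, so the processes collectively cover the dataset without overlap.
    for i in range(world.rank(), len(dataset), world.size()):
        yield dataset[i]

for x, y in shard(dataset):
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())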
To average the gradients across machines, all we have to do is perform an all_sum() and divide by the size of the Group. In other words, we mlx.utils.tree_map() the gradients with the following function:
def all_avg(x):
    return mx.distributed.all_sum(x) / mx.distributed.init().size()
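For example, applying it to the gradient tree returned by loss_grad_fn is a one-liner:

from mlx.utils import tree_map

# Average every leaf of the gradient tree across all processes.
grads = tree_map(all_avg, grads)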
Putting everything together, the training step looks as follows, with everything else remaining the same:
from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads,
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss
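As a quick sanity check (a sketch using a made-up gradient tree rather than real model gradients), averaging identical values across ranks should leave them unchanged no matter how many processes are running:

import mlx.core as mx

dummy_grads = {"w": mx.ones((4, 4)), "b": mx.ones((4,))}
averaged = all_reduce_grads(dummy_grads)
mx.eval(averaged)
# Every entry is still 1.0 on every rank, since sum / N == 1.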
Using nn.average_gradients#
Although the code above works correctly, it performs one communication per gradient. It is significantly more efficient to aggregate several gradients together and perform fewer communication steps.
This is the purpose of mlx.nn.average_gradients(). The final code is almost identical to the example above:
model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = mlx.nn.average_gradients(grads)  # <---- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
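One detail that averaging gradients does not cover is the loss itself: the value returned by step() is still the per-rank loss on the local batch. If you want to log a number that is consistent across hosts, you can average it in the same way; a minimal sketch of the training loop with this added:

import mlx.core as mx

world = mx.distributed.init()

for x, y in dataset:
    loss = step(model, x, y)
    # Average the local losses so every rank logs the same value.
    loss = mx.distributed.all_sum(loss) / world.size()
    mx.eval(loss, model.parameters())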