Data Parallelism#

MLX enables efficient data parallel distributed training through its distributed communication primitives.
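To make the idea concrete before touching any MLX APIs, here is a toy, plain-Python sketch of what data parallelism means (the linear model, the shards, and `local_gradient` are illustrative inventions, not MLX functions): each worker computes gradients on its own shard of the batch, and the per-worker gradients are averaged so every replica applies the same update.

```python
# Conceptual sketch of data parallelism in plain Python (no MLX required).
# Each "worker" computes a gradient on its own shard of the batch; the
# gradients are then averaged, which is what the distributed primitives below
# accomplish across real machines.

def local_gradient(w, shard):
    # Toy "gradient" for a linear model y = w * x: mean residual over the shard.
    return sum(w * x - y for x, y in shard) / len(shard)

weights = 1.0
shards = [
    [(1.0, 2.0), (2.0, 4.0)],  # worker 0's shard of the batch
    [(3.0, 6.0), (4.0, 8.0)],  # worker 1's shard of the batch
]

grads = [local_gradient(weights, s) for s in shards]
avg_grad = sum(grads) / len(grads)  # what all_sum(...) / size() achieves
```

Every worker then updates with the same `avg_grad`, keeping the replicas in sync.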

Training Example#

In this section we adapt an MLX training loop to support data parallel distributed training. Namely, we average the gradients across a set of hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model, dataset, and optimizer initialization.

import mlx.core as mx

model = ...
optimizer = ...
dataset = ...
loss_grad_fn = ...  # e.g. nn.value_and_grad(model, loss_fn)

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())

Averaging the gradients across the machines is as simple as performing an all_sum() and dividing by the size of the Group. Namely, we can perform an mlx.utils.tree_map() on the gradients with the following function:

def all_avg(x):
    return mx.distributed.all_sum(x) / mx.distributed.init().size()
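Since tree_map() applies a function to every leaf of a nested structure of arrays, a single leaf function like all_avg() is enough to average an arbitrarily nested gradient tree. The following plain-Python stand-ins (so the sketch runs without MLX; `tree_map`, `world_size`, and `all_avg` here are simplified imitations, not the real MLX implementations) illustrate why:

```python
# Plain-Python stand-in for mlx.utils.tree_map: recursively apply a function
# to every leaf of a nested dict/list structure.
def tree_map(fn, tree):
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, v) for v in tree]
    return fn(tree)  # leaf value

world_size = 4  # stands in for mx.distributed.init().size()

def all_avg(x):
    # Pretend every worker holds the same leaf value, so the all_sum across
    # workers is x * world_size; dividing by world_size recovers the average.
    return (x * world_size) / world_size

grads = {"layers": [{"w": 2.0, "b": 0.5}, {"w": -1.0, "b": 0.0}]}
averaged = tree_map(all_avg, grads)
```

The point is structural: the leaf function never needs to know the shape of the gradient tree, tree_map() handles the traversal.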

Putting everything together, the training step, with everything else remaining the same, becomes:

from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss

Using nn.average_gradients#

Although the code above works correctly, it performs one communication per gradient. It is significantly more efficient to aggregate several gradients together and communicate them at once.
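The efficiency gain comes from fusing many small collectives into a few large ones. A rough plain-Python sketch of the bucketing idea (the `flatten`/`unflatten` helpers and the pair-list gradient representation are hypothetical, purely for illustration):

```python
# Sketch of gradient bucketing: instead of one collective per gradient,
# concatenate all leaves into one flat buffer, communicate once, then split
# the buffer back into the original shapes.

def flatten(grads):
    # grads: list of (name, list-of-floats) pairs -> flat buffer plus sizes
    flat, sizes = [], []
    for _, g in grads:
        flat.extend(g)
        sizes.append(len(g))
    return flat, sizes

def unflatten(flat, grads, sizes):
    out, i = [], 0
    for (name, _), n in zip(grads, sizes):
        out.append((name, flat[i:i + n]))
        i += n
    return out

grads = [("w1", [1.0, 2.0]), ("b1", [3.0]), ("w2", [4.0, 5.0, 6.0])]
flat, sizes = flatten(grads)      # one buffer -> one collective call
flat = [x / 2 for x in flat]      # stand-in for all_sum(flat) / N with N = 2
restored = unflatten(flat, grads, sizes)
```

Fewer, larger messages amortize the per-message latency of the communication layer.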

This is precisely what mlx.nn.average_gradients() does. The final code looks almost identical to the example above:

import mlx.nn

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = mlx.nn.average_gradients(grads)  # <---- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
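To actually run the script on multiple hosts, MLX provides the mlx.launch helper. One common invocation (hostnames here are placeholders; check the distributed documentation of your installed MLX version for the exact flags it supports):

```shell
# Launch train.py on two hosts; mlx.launch sets up the distributed group
# so mx.distributed.init() sees both processes. Hostnames are placeholders.
mlx.launch --hosts host1,host2 train.py
```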