mlx.core.distributed.sum_scatter
- sum_scatter(x: array, *, group: Group | None = None, stream: None | Stream | Device = None) → array
Sum x across all processes in the group and shard the result along the first axis across ranks. x.shape[0] must be divisible by the group size.
The result is equivalent to all_sum(x)[rank*chunk_size:(rank+1)*chunk_size], where chunk_size = x.shape[0] // group.size() and rank is the rank of this process in the group. Note: all_sum is mentioned only for illustration; the actual implementation does not perform all_sum and uses a single reduce-scatter collective instead. Currently supported only for the NCCL backend.
- Parameters:
x (array): Input array.
group (Group | None): The group of processes that take part in the reduction. Default: None.
stream (None | Stream | Device): The stream or device on which to run the operation. Default: None.
- Returns:
The output array with shape [x.shape[0] // group.size(), *x.shape[1:]].
- Return type:
array
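The equivalence described for sum_scatter can be illustrated with a small single-process model. This is only a sketch of the documented result, written in plain Python with nested lists standing in for arrays. It does not call MLX, performs no communication, and says nothing about how the reduce-scatter collective is actually implemented. The names simulate_sum_scatter and per_rank_inputs are hypothetical and are not part of the MLX API.

```python
def simulate_sum_scatter(per_rank_inputs):
    """Pure-Python model of the documented result; no communication happens.

    per_rank_inputs[r] stands in for the array `x` held by rank r, given as
    a list of rows (axis 0), where each row is a list of numbers.
    Returns a list whose entry r is the shard that rank r would receive.
    """
    group_size = len(per_rank_inputs)
    num_rows = len(per_rank_inputs[0])
    if num_rows % group_size != 0:
        raise ValueError("x.shape[0] must be divisible by the group size")
    chunk_size = num_rows // group_size

    # Reference result: element-wise sum over all ranks (what all_sum would give).
    total = [
        [sum(values) for values in zip(*rows)]
        for rows in zip(*per_rank_inputs)
    ]

    # Rank r keeps rows [r * chunk_size, (r + 1) * chunk_size) of the sum.
    return [
        total[rank * chunk_size:(rank + 1) * chunk_size]
        for rank in range(group_size)
    ]


# Two simulated ranks, each holding a 4 x 2 input.
rank0 = [[1, 2], [3, 4], [5, 6], [7, 8]]
rank1 = [[10, 20], [30, 40], [50, 60], [70, 80]]
shards = simulate_sum_scatter([rank0, rank1])
print(shards[0])  # [[11, 22], [33, 44]]
print(shards[1])  # [[55, 66], [77, 88]]
```

With a group size of 2 and x.shape == (4, 2), each simulated rank ends up with a shard of shape (2, 2), which matches the documented output shape [x.shape[0] // group.size(), *x.shape[1:]].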