mlx.nn.layers.distributed.shard_linear

shard_linear(module: Module, sharding: str, *, segments: int | list = 1, group: Group | None = None)

Create a new linear layer whose parameters are sharded and which performs distributed communication in either the forward or the backward pass.

Note

Unlike shard_inplace, this function leaves the original layer unchanged and returns a new layer.

Parameters:
  • module (Module) -- The linear layer to be sharded.

  • sharding (str) -- Either "all-to-sharded" or "sharded-to-all"; defines the type of sharding to perform.

  • segments (int or list) -- The segments to use. Default: 1.

  • group (Group) -- The distributed group to shard across. If not set, the global group will be used. Default: None.