mlx.nn.average_gradients

average_gradients(gradients: Any, group: Group | None = None, all_reduce_size: int = 33554432, communication_stream: Stream | None = None)

Average the gradients across the distributed processes in the passed group.

This helper concatenates the gradients of several small arrays into one large all-reduce call for better networking performance.
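The concatenate-then-reduce idea can be illustrated without MLX or a real network. The following is a minimal pure-Python sketch, not MLX's implementation: gradients are modeled as plain lists of floats keyed by name, several processes are simulated as a list of such dicts, and the hypothetical helper `simulated_all_sum` stands in for a real all-reduce. All small arrays travel in one buffer, so there is a single communication step instead of one per array.

```python
from typing import Dict, List

# Hypothetical model: each gradient is a flat list of floats, keyed by name.
Grads = Dict[str, List[float]]


def simulated_all_sum(buffers: List[List[float]]) -> List[float]:
    """Stand-in for an all-reduce: element-wise sum over every process's buffer."""
    return [sum(vals) for vals in zip(*buffers)]


def average_fused(per_process: List[Grads]) -> Grads:
    """Average gradients across simulated processes using ONE fused reduction."""
    n = len(per_process)
    keys = list(per_process[0].keys())  # same tree structure on every process
    sizes = [len(per_process[0][k]) for k in keys]

    # 1. Concatenate: each process flattens all of its small arrays into one buffer.
    buffers = [[x for k in keys for x in grads[k]] for grads in per_process]

    # 2. Reduce: a single communication step over the big buffer, then divide by N.
    mean = [v / n for v in simulated_all_sum(buffers)]

    # 3. Split: cut the averaged buffer back into arrays of the original sizes.
    out: Grads = {}
    offset = 0
    for k, size in zip(keys, sizes):
        out[k] = mean[offset:offset + size]
        offset += size
    return out


p0 = {"bias": [1.0, 2.0], "weight": [0.0, 4.0, 8.0]}
p1 = {"bias": [3.0, 4.0], "weight": [2.0, 0.0, 0.0]}
print(average_fused([p0, p1]))  # → {'bias': [2.0, 3.0], 'weight': [1.0, 2.0, 4.0]}
```

The result matches averaging each array separately; only the number of reductions changes.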

Parameters:
  • gradients (Any) -- The Python tree containing the gradients. It should have the same structure across processes.

  • group (Optional[Group]) -- The group of processes across which to average the gradients. If set to None, the global group is used. Default: None.

  • all_reduce_size (int) -- Group arrays until their size in bytes exceeds this number, and perform one communication step per group of arrays. If less than or equal to 0, array grouping is disabled. Default: 32MiB.

  • communication_stream (Optional[Stream]) -- The stream to use for the communication. If unspecified, the default communication stream is used, which can vary by back-end. Default: None.
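The `all_reduce_size` grouping rule can be sketched the same way. The hypothetical helper `group_by_bytes` below takes the byte size of each flattened gradient array (element count times item size) and returns which arrays would share one communication step. It assumes arrays are grouped in tree order and that a group is closed as soon as its running total reaches the threshold; the exact boundary condition MLX uses is an assumption here.

```python
from typing import List

MiB = 1024 * 1024


def group_by_bytes(nbytes: List[int], all_reduce_size: int) -> List[List[int]]:
    """Return lists of array indices; each inner list would be one all-reduce call."""
    if all_reduce_size <= 0:
        # Grouping disabled: one communication step per array.
        return [[i] for i in range(len(nbytes))]

    groups: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for i, size in enumerate(nbytes):
        current.append(i)
        current_bytes += size
        # Close the group once its byte total has reached the budget.
        if current_bytes >= all_reduce_size:
            groups.append(current)
            current, current_bytes = [], 0
    if current:
        groups.append(current)  # leftover arrays form a final, smaller group
    return groups


sizes = [10 * MiB, 10 * MiB, 20 * MiB, 4 * MiB, 1 * MiB]
print(group_by_bytes(sizes, 32 * MiB))  # → [[0, 1, 2], [3, 4]]
print(group_by_bytes(sizes, 0))         # → [[0], [1], [2], [3], [4]]
```

Larger groups mean fewer, bigger messages, which usually amortizes per-call overhead; the cost is the extra memory needed for each concatenated buffer.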