mlx.optimizers.clip_grad_norm

clip_grad_norm(grads, max_norm)

Clips the global norm of the gradients.

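This function ensures that the global norm of the gradients does not exceed max_norm. If the norm is greater than max_norm, every gradient is scaled down by the same factor so that their relative proportions are preserved; otherwise the gradients are returned unchanged.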

Example

>>> import mlx.core as mx
>>> from mlx.optimizers import clip_grad_norm
>>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
>>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
>>> print(clipped_grads)
{"w1": mx.array([...]), "w2": mx.array([...])}

Parameters:
  • grads (dict) -- A dictionary containing the gradient arrays.

  • max_norm (float) -- The maximum allowed global norm of the gradients.

Returns:

The possibly rescaled gradients and the original gradient norm.

Return type:

(dict, float)
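
The clipping rule described above can be illustrated with a minimal pure-Python sketch. This is not the MLX implementation: the helper name `clip_by_global_norm` is hypothetical, gradients are plain lists of floats rather than `mx.array` objects, and details such as numerical-stability terms are omitted. It only shows the arithmetic of computing one norm over all gradients and rescaling them by a common factor.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: clip a dict of float lists by their global L2 norm."""
    # The global norm treats every entry of every gradient as one long vector.
    total_norm = math.sqrt(
        sum(g * g for values in grads.values() for g in values)
    )
    # Gradients already within the limit are returned unchanged.
    if total_norm <= max_norm:
        return grads, total_norm
    # Otherwise scale all gradients by the same factor.
    scale = max_norm / total_norm
    clipped = {
        name: [g * scale for g in values] for name, values in grads.items()
    }
    return clipped, total_norm


# Same values as the example above: the global norm is sqrt(2^2 + 3^2 + 1^2).
grads = {"w1": [2.0, 3.0], "w2": [1.0]}
clipped, total_norm = clip_by_global_norm(grads, max_norm=2.0)
print(total_norm)
print(clipped)
```

As in the function documented here, the second return value is the norm before clipping, which is useful for logging how often and how strongly clipping takes effect.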