mlx.nn.value_and_grad

value_and_grad(model: Module, fn: Callable)

Transform the passed function fn into a function that computes both the value of fn and the gradients of fn with respect to the model's trainable parameters.

Parameters:
  • model (Module) -- The model whose trainable parameters the gradients are computed with respect to.

  • fn (Callable) -- The scalar-valued function to differentiate.

Returns:

A callable that returns the value of fn and the gradients of fn with respect to the trainable parameters of model.