mlx.nn.value_and_grad
- value_and_grad(model: Module, fn: Callable)
Transform the passed function fn into a function that computes both the value of fn and its gradients with respect to the model's trainable parameters.
- Parameters:
model (Module) -- The model whose trainable parameters to compute gradients for
fn (Callable) -- The scalar function to compute gradients for
- Returns:
A callable that returns the value of fn and the gradients of fn with respect to the trainable parameters of model.