mlx.core.grad
- grad(fun: Callable[P, R], argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) → Callable[P, Any]
Returns a function which computes the gradient of fun.
- Parameters:
  - fun (Callable) -- A function that takes a variable number of arrays or trees of arrays and returns a scalar output array.
  - argnums (int or list(int), optional) -- Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun's first argument.
  - argnames (str or list(str), optional) -- Specify the keyword arguments of fun to compute gradients with respect to. It defaults to [], so no gradients are computed for keyword arguments by default.
- Returns:
A function which has the same input arguments as fun and returns the gradient(s).
- Return type:
Callable