mlx.core.grad



grad(fun: Callable[P, R], argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) -> Callable[P, Any]

Returns a function which computes the gradient of fun.

Parameters:
  • fun (Callable) -- A function which takes a variable number of arrays or trees of arrays and returns a scalar output array.

  • argnums (int or list(int), optional) -- The index (or indices) of the positional arguments of fun with respect to which the gradient is computed. If neither argnums nor argnames is provided, argnums defaults to 0, indicating fun's first argument.

  • argnames (str or list(str), optional) -- The keyword argument(s) of fun with respect to which gradients are computed. Defaults to [], so by default no gradients are computed for keyword arguments.

Returns:

A function which has the same input arguments as fun and returns the gradient(s).

Return type:

Callable