mlx.core.vjp


vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]

Compute the vector-Jacobian product.

Computes the product of the cotangents with the Jacobian of a function fun evaluated at primals.

Parameters:
  • fun (Callable) -- A function which takes a variable number of arrays and returns a single array or a list of arrays.

  • primals (list(array)) -- A list of arrays at which to evaluate the Jacobian.

  • cotangents (list(array)) -- A list of arrays which are the "vector" in the vector-Jacobian product. The cotangents should be the same in number, shape, and type as the outputs of fun.

Returns:

A tuple whose first element is the list of outputs of fun and whose second element is the list of vector-Jacobian products, one per primal.

Return type:

tuple(list(array), list(array))

Example

import mlx.core as mx

# VJP of sin at x = 1.0 with cotangent 1.0; vjps[0] holds cos(1.0).
outs, vjps = mx.vjp(mx.sin, [mx.array(1.0)], [mx.array(1.0)])