mlx.core.vjp
- vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]
Compute the vector-Jacobian product.
Computes the product of the cotangents with the Jacobian of a function fun evaluated at primals.
- Parameters:
  - fun (Callable) -- A function which takes a variable number of arrays and returns a single array or a list of arrays.
  - primals (list(array)) -- A list of arrays at which to evaluate the Jacobian.
  - cotangents (list(array)) -- A list of arrays which are the "vector" in the vector-Jacobian product. The cotangents should be the same in number, shape, and type as the outputs of fun.
- Returns:
  A tuple with the outputs of fun in the first position and the vector-Jacobian products in the second position.
- Return type:
  tuple[list[array], list[array]]
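Example
```python
import mlx.core as mx
# vjp of sin at x = 1.0 with cotangent 1.0: outs[0] is sin(1.0), vjps[0] is cos(1.0)
outs, vjps = mx.vjp(mx.sin, [mx.array(1.0)], [mx.array(1.0)])
```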