mlx.core.jvp

jvp(fun: Callable, primals: list[array], tangents: list[array]) → tuple[list[array], list[array]]

Compute the Jacobian-vector product.

This computes the product of the Jacobian of the function fun, evaluated at primals, with the tangents.
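
To make the quantity concrete, the following standard-library sketch approximates a Jacobian-vector product with a central finite difference and compares it to the analytic result. It illustrates only what the product means and is not how MLX computes it; the function f, the helper jvp_finite_difference, and the step size eps are all hypothetical names chosen for this illustration.

```python
import math


def f(x, y):
    # Hypothetical example function with two inputs and two outputs:
    # f(x, y) = (x * y, sin(x))
    return [x * y, math.sin(x)]


def jvp_finite_difference(fun, primals, tangents, eps=1e-6):
    """Approximate J(primals) @ tangents with a central difference.

    This is a numerical stand-in for illustration, not MLX's method.
    """
    plus = fun(*[p + eps * t for p, t in zip(primals, tangents)])
    minus = fun(*[p - eps * t for p, t in zip(primals, tangents)])
    return [(a - b) / (2 * eps) for a, b in zip(plus, minus)]


primals = [2.0, 3.0]
tangents = [1.0, 0.5]
approx = jvp_finite_difference(f, primals, tangents)

# Analytic Jacobian at (x, y) is [[y, x], [cos(x), 0]],
# so J @ (tx, ty) = (y * tx + x * ty, cos(x) * tx).
x, y = primals
tx, ty = tangents
exact = [y * tx + x * ty, math.cos(x) * tx]
```

The result has one entry per output of the function, each being the directional derivative of that output along the tangents.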

Parameters:
  • fun (Callable) – A function which takes a variable number of arrays and returns a single array or a list of arrays.

  • primals (list(array)) – A list of arrays at which to evaluate the Jacobian.

  • tangents (list(array)) – A list of arrays which are the “vector” in the Jacobian-vector product. The tangents must match the inputs of fun (i.e. the primals) in number, shape, and type.

Returns:

A tuple with the outputs of fun in the first position and the Jacobian-vector products in the second position.

Return type:

tuple(list(array), list(array))

Example

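import mlx.core as mx

# Differentiate sin at x = 1.0 along the tangent direction 1.0.
# outs[0] holds sin(1.0); jvps[0] holds cos(1.0) * 1.0.
outs, jvps = mx.jvp(mx.sin, [mx.array(1.0)], [mx.array(1.0)])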