mlx.core.tensordot

tensordot(a: array, b: array, /, axes: int | list[Sequence[int]] = 2, *, stream: None | Stream | Device = None) -> array

Compute the tensor dot product along the specified axes.
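For the integer form of axes, the contraction can be written out explicitly. Here a has shape (d_1, ..., d_p), b has shape (e_1, ..., e_q), and n = axes. The last n dimensions of a must match the first n dimensions of b (d_{p-n+m} = e_m for m = 1, ..., n), and the result has shape (d_1, ..., d_{p-n}, e_{n+1}, ..., e_q). The symbols A, B, C, i, j, k are introduced only for this illustration.

```latex
C_{i_1 \dots i_{p-n}\, j_1 \dots j_{q-n}}
  = \sum_{k_1, \dots, k_n}
    A_{i_1 \dots i_{p-n}\, k_1 \dots k_n}\,
    B_{k_1 \dots k_n\, j_1 \dots j_{q-n}}
```

With n = 1 and two matrices this is ordinary matrix multiplication; with n = 0 the sum is empty and the result is the outer product.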

Parameters:
  • a (array) -- First input array.

  • b (array) -- Second input array.

  • axes (int or list(list(int)), optional) -- The dimensions to sum over. If an integer n is provided, sum over the last n dimensions of a and the first n dimensions of b. If a list of two lists is provided, sum over the dimensions of a given in the first list and the dimensions of b given in the second list, pairing them in order. Default: 2.
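To make the pairing rule concrete, here is a minimal pure-Python sketch that applies the same contraction to nested lists. The helpers tensordot_ref, shape_of, get, and build are hypothetical names written for illustration; they are not part of MLX, and the sketch assumes MLX pairs axes exactly as the axes description states (the same convention as NumPy's tensordot).

```python
from itertools import product
from typing import Sequence, Union

# A scalar or an arbitrarily nested (rectangular, non-empty) list of numbers.
Nested = Union[float, list]


def shape_of(x: Nested) -> tuple:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]  # assumes no empty sub-lists
    return tuple(shape)


def get(x: Nested, index: Sequence[int]) -> float:
    """Read the element of x at a multi-dimensional index."""
    for i in index:
        x = x[i]
    return x


def build(shape: tuple, fn, prefix: tuple = ()) -> Nested:
    """Nested list of the given shape whose entry at index idx is fn(idx)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def tensordot_ref(a: Nested, b: Nested, axes=2) -> Nested:
    """Hypothetical loop-based reference for the tensordot contraction rule."""
    sa, sb = shape_of(a), shape_of(b)
    if isinstance(axes, int):
        # Integer n: contract the last n axes of a with the first n axes of b.
        ax_a = list(range(len(sa) - axes, len(sa)))
        ax_b = list(range(axes))
    else:
        # Two lists: axis axes[0][m] of a is paired with axis axes[1][m] of b.
        # Negative axes are normalized to their positive positions.
        ax_a = [i % len(sa) for i in axes[0]]
        ax_b = [i % len(sb) for i in axes[1]]
    summed_sizes = [sa[i] for i in ax_a]
    if summed_sizes != [sb[i] for i in ax_b]:
        raise ValueError("contracted dimensions must have equal sizes")
    # Non-contracted axes keep their order: free axes of a, then free axes of b.
    free_a = [i for i in range(len(sa)) if i not in ax_a]
    free_b = [i for i in range(len(sb)) if i not in ax_b]
    out_shape = tuple(sa[i] for i in free_a) + tuple(sb[i] for i in free_b)

    def entry(out_idx: tuple) -> float:
        total = 0
        # Sum over every combination of the shared (contracted) indices.
        for k in product(*(range(n) for n in summed_sizes)):
            ia = [0] * len(sa)
            ib = [0] * len(sb)
            for pos, ax in zip(out_idx[: len(free_a)], free_a):
                ia[ax] = pos
            for pos, ax in zip(out_idx[len(free_a):], free_b):
                ib[ax] = pos
            for pos, ax in zip(k, ax_a):
                ia[ax] = pos
            for pos, ax in zip(k, ax_b):
                ib[ax] = pos
            total += get(a, ia) * get(b, ib)
        return total

    return build(out_shape, entry)
```

For two 2x2 matrices A and B, tensordot_ref(A, B, axes=1) equals their matrix product, the default axes=2 sums every elementwise product into a scalar, and axes=[[0], [0]] computes the transpose of A times B. In real code you would call mx.tensordot(a, b, axes=...) on mlx arrays; the explicit loops here only show which indices are summed.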

Returns:

The tensor dot product. Its shape is the non-contracted dimensions of a followed by the non-contracted dimensions of b.

Return type:

array