mlx.core.segmented_mm

segmented_mm(a: array, b: array, /, segments: array, *, stream: None | Stream | Device = None) -> array

Perform a matrix multiplication with the inner dimension (K) split into segments, computing and returning the product for each segment separately.
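To make the semantics concrete, here is a minimal NumPy reference sketch. It assumes that `segments` is an (S, 2) array whose rows are [start, end) offsets into K; that layout is an assumption for illustration, and `segmented_mm_reference` is a hypothetical helper, not part of MLX.

```python
import numpy as np


def segmented_mm_reference(a, b, segments):
    """Hypothetical NumPy reference for segmented_mm.

    ASSUMPTION: `segments` has shape (S, 2); each row is a [start, end)
    pair of offsets into the shared inner dimension K.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    segments = np.asarray(segments)
    m, k = a.shape
    k2, n = b.shape
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    out = np.empty((len(segments), m, n), dtype=np.result_type(a, b))
    for i, (start, end) in enumerate(segments):
        # Multiply only the slice of K that belongs to this segment.
        out[i] = a[:, start:end] @ b[start:end, :]
    return out


a = np.arange(6, dtype=np.float32).reshape(2, 3)  # M=2, K=3
b = np.ones((3, 2), dtype=np.float32)             # K=3, N=2
segments = np.array([[0, 1], [1, 3]])             # two segments covering K
result = segmented_mm_reference(a, b, segments)
print(result.shape)  # (2, 2, 2)
```

When the segments partition K, summing the per-segment products recovers the ordinary `a @ b`. If MLX's `segments` layout matches the assumption above, `mlx.core.segmented_mm(a, b, segments)` should agree with this sketch.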

Parameters:
  • a (array) – Input array of shape MxK.

  • b (array) – Input array of shape KxN.

  • segments (array) – The offsets into the inner dimension (K) that delimit each segment.

Returns:

The product for each segment, each of shape MxN.

Return type:

array
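Since each segment only restricts which indices of K contribute, the per-segment results can also be expressed as a single masked contraction that directly produces one MxN matrix per segment. This is a hypothetical sketch (`segmented_mm_masked` is not an MLX function), assuming `segments` is an (S, 2) array of [start, end) offsets into K.

```python
import numpy as np


def segmented_mm_masked(a, b, segments):
    """Hypothetical vectorized sketch of segmented_mm using a K-mask.

    ASSUMPTION: `segments` is an (S, 2) array of [start, end) offsets into K.
    Returns an array of shape (S, M, N): one MxN product per segment.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    segments = np.asarray(segments)
    idx = np.arange(a.shape[1])
    # mask[s, j] is True when index j of K lies inside segment s.
    mask = (idx >= segments[:, :1]) & (idx < segments[:, 1:])
    # Contract over K, keeping only the indices selected by each segment's mask.
    return np.einsum("mk,sk,kn->smn", a, mask.astype(a.dtype), b)


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8))  # M=4, K=8
b = rng.standard_normal((8, 5))  # K=8, N=5
segments = np.array([[0, 3], [3, 8]])
out = segmented_mm_masked(a, b, segments)
print(out.shape)  # (2, 4, 5)
```

The mask form is easy to read but performs a full-K contraction for every segment, so it does more arithmetic than slicing each segment as in a loop; it is meant to illustrate the output shape, not to be efficient.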