mlx.core.fast.rope

rope(a: array, dims: int, *, traditional: bool, base: float | None, scale: float, offset: int | array, freqs: array | None = None, stream: None | Stream | Device = None) -> array

Apply rotary positional encoding to the input.

The input is expected to be at least 3D with shape (B, *, T, D), where * stands for any number of additional dimensions (possibly none) and:
  • B is the batch size.

  • T is the sequence length.

  • D is the feature dimension.

Parameters:
  • a (array) -- The input array.

  • dims (int) -- The number of feature dimensions to rotate. If the input feature dimension is larger than dims, the remaining features are left unchanged.

  • traditional (bool) -- If True, use the traditional implementation, which rotates consecutive dimensions.

  • base (float, optional) -- The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None.

  • scale (float) -- The scale used to scale the positions.

  • offset (int or array) -- The position offset to start at. If an array is given, it can be either a scalar or a vector of B offsets, one for each example in the batch.

  • freqs (array, optional) -- Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None.
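
To make the roles of dims, traditional, base, scale and offset concrete, the following is a minimal pure-Python sketch of the usual RoPE formulation applied to a single sequence (a T x D list of lists). It is not the MLX implementation. The helper name `rope_reference`, the frequency formula `base ** (-2 * i / dims)`, the half-split pairing used when traditional is False, the way scale multiplies the offset position, and the requirement that dims be even are all assumptions based on the common formulation of rotary embeddings rather than on this page.

```python
import math


def rope_reference(x, dims, *, traditional, base, scale, offset):
    """Rotate the first `dims` features of each row of x (T x D list of lists).

    Hypothetical reference helper for illustration only.
    """
    if dims % 2 != 0:
        raise ValueError("this sketch assumes an even number of rotated dims")
    half = dims // 2
    out = []
    for t, row in enumerate(x):
        pos = scale * (offset + t)  # assumed: scale multiplies the shifted position
        new_row = list(row)  # features at index >= dims are copied unchanged
        for i in range(half):
            theta = pos * base ** (-2.0 * i / dims)  # assumed angular frequency
            cos_t, sin_t = math.cos(theta), math.sin(theta)
            # traditional: rotate neighbours (2i, 2i + 1);
            # otherwise (assumed): rotate (i, i + dims / 2)
            j, k = (2 * i, 2 * i + 1) if traditional else (i, i + half)
            new_row[j] = row[j] * cos_t - row[k] * sin_t
            new_row[k] = row[j] * sin_t + row[k] * cos_t
        out.append(new_row)
    return out


# Two tokens with five features each; only the first four are rotated.
x = [[1.0, 0.0, 0.0, 0.0, 5.0], [1.0, 0.0, 0.0, 0.0, 5.0]]
y = rope_reference(x, 4, traditional=False, base=10000.0, scale=1.0, offset=0)
z = rope_reference(x, 4, traditional=True, base=10000.0, scale=1.0, offset=0)
```

In this sketch the token at position 0 is left as-is because every rotation angle is zero, while the token at position 1 has its first pair rotated by one radian. The two layouts differ only in which feature indices form a pair, and each rotation preserves the norm of the pair it acts on.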

Returns:

The output array.

Return type:

array