mlx.nn.losses.cosine_similarity_loss

cosine_similarity_loss(x1: array, x2: array, axis: int = 1, eps: float = 1e-08, reduction: Literal['none', 'mean', 'sum'] = 'none')

Computes the cosine similarity between the two inputs.

The cosine similarity loss is given by

\[\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}\]
Parameters:
  • x1 (mx.array) -- The first set of inputs.

  • x2 (mx.array) -- The second set of inputs.

  • axis (int, optional) -- The embedding axis. Default: 1.

  • eps (float, optional) -- The minimum value of the denominator used for numerical stability. Default: 1e-8.

  • reduction (str, optional) -- Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.
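
The formula and the parameters above can be sketched in plain NumPy. This is an illustrative reference only, not the MLX implementation; the helper name `cosine_similarity_reference` is hypothetical and exists only for this example.

```python
import numpy as np


def cosine_similarity_reference(x1, x2, axis=1, eps=1e-8, reduction="none"):
    """Hypothetical NumPy reference for the documented formula."""
    x1 = np.asarray(x1, dtype=np.float64)
    x2 = np.asarray(x2, dtype=np.float64)

    # Numerator: dot product along the embedding axis.
    dot = np.sum(x1 * x2, axis=axis)

    # Denominator: product of the norms, clamped from below by eps.
    norms = np.linalg.norm(x1, axis=axis) * np.linalg.norm(x2, axis=axis)
    sim = dot / np.maximum(norms, eps)

    if reduction == "none":
        return sim
    if reduction == "mean":
        return sim.mean()
    if reduction == "sum":
        return sim.sum()
    raise ValueError(f"Unknown reduction: {reduction!r}")


# Rows: identical direction, opposite direction, and a zero vector.
a = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
b = np.array([[1.0, 0.0], [-1.0, -1.0], [3.0, 4.0]])

per_row = cosine_similarity_reference(a, b)                 # one value per row
total = cosine_similarity_reference(a, b, reduction="sum")  # scalar
```

The zero-vector row shows why `eps` is there: the product of norms is 0, so the denominator is clamped to `eps` and the result is 0 rather than a division by zero.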

Returns:

The computed cosine similarity loss.

Return type:

mx.array