mlx.nn.losses.gaussian_nll_loss

class gaussian_nll_loss(inputs: array, targets: array, vars: array, full: bool = False, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'mean')

Computes the negative log likelihood loss for a Gaussian distribution.

The loss is given by:

\[\frac{1}{2}\left(\log\left(\max\left(\text{vars}, \ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2} {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}\]

where inputs are the predicted means and vars are the predicted variances.
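As a minimal sketch of the formula above for a single element, the following pure-Python helper evaluates the loss using only the `math` module. The helper name `gaussian_nll_elem` is hypothetical, and the constant added when `full` is true is assumed to be the usual Gaussian term 0.5 * log(2 * pi); this illustrates the formula and is not the library's implementation.

```python
import math

def gaussian_nll_elem(inp, target, var, full=False, eps=1e-6):
    """Per-element Gaussian NLL following the formula above (illustrative only)."""
    # Clamp the variance from below so the log and the division stay finite.
    var = max(var, eps)
    loss = 0.5 * (math.log(var) + (inp - target) ** 2 / var)
    if full:
        # Assumed constant: the usual 0.5 * log(2 * pi) of the Gaussian density.
        loss += 0.5 * math.log(2 * math.pi)
    return loss

# With unit variance the log term vanishes and the loss is half the squared error.
print(gaussian_nll_elem(1.0, 3.0, 1.0))  # 0.5 * (0 + 4 / 1) = 2.0
```

Clamping with `eps` matters when a model predicts a variance at or near zero: without it, both the logarithm and the division would blow up.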

Parameters:
  • inputs (array) -- The predicted expectation of the Gaussian distribution.

  • targets (array) -- The target values (samples from the Gaussian distribution).

  • vars (array) -- The predicted variance of the Gaussian distribution.

  • full (bool, optional) -- Whether to include the constant term in the loss calculation. Default: False.

  • eps (float, optional) -- Small positive constant for numerical stability. Default: 1e-6.

  • reduction (str, optional) -- Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.

Returns:

The Gaussian NLL loss.

Return type:

array