mlx.nn.losses.kl_div_loss
- class kl_div_loss(inputs: array, targets: array, axis: int = -1, reduction: Literal['none', 'mean', 'sum'] = 'none')
Computes the Kullback-Leibler divergence loss.
Computes the following when reduction == 'none':
(mx.exp(targets) * (targets - inputs)).sum(axis)
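To make the formula concrete, here is a minimal pure-Python sketch. It reads the expression as the sum over the distribution axis of exp(targets) * (targets - inputs). The helper name kl_div_reference is hypothetical and is not part of MLX. It assumes 2-D nested lists with the distribution on the last axis, and it assumes 'mean' averages over the per-row losses.

```python
import math


def kl_div_reference(inputs, targets, reduction="none"):
    """Hypothetical reference helper, not the MLX implementation.

    inputs, targets: 2-D nested lists of log probabilities,
    with the distribution along the last axis.
    """
    # One loss value per row: sum_i exp(t_i) * (t_i - x_i)
    per_row = [
        sum(math.exp(t) * (t - x) for x, t in zip(row_in, row_tg))
        for row_in, row_tg in zip(inputs, targets)
    ]
    if reduction == "none":
        return per_row
    if reduction == "sum":
        return sum(per_row)
    if reduction == "mean":
        return sum(per_row) / len(per_row)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Target and predicted distributions, given as probabilities.
p = [[0.5, 0.5], [0.9, 0.1]]  # target
q = [[0.5, 0.5], [0.5, 0.5]]  # predicted

# The loss expects log probabilities for both arguments.
log_p = [[math.log(v) for v in row] for row in p]
log_q = [[math.log(v) for v in row] for row in q]

losses = kl_div_reference(log_q, log_p)  # one value per row
total = kl_div_reference(log_q, log_p, reduction="sum")
average = kl_div_reference(log_q, log_p, reduction="mean")
```

The first row has identical distributions, so its loss is zero. The second row is positive because the prediction differs from the target. With MLX installed, the corresponding call would presumably be kl_div_loss(mx.array(log_q), mx.array(log_p)), with the same argument order (inputs first, targets second).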
- Parameters:
inputs (array) -- Log probabilities for the predicted distribution.
targets (array) -- Log probabilities for the target distribution.
axis (int, optional) -- The distribution axis. Default: -1.
reduction (str, optional) -- Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.
- Returns:
The computed Kullback-Leibler divergence loss.
- Return type: