mlx.core.random.categorical
categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) → array
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by the unnormalized values in logits. Note that at most one of shape or num_samples can be specified. If both are None, the output has the same shape as logits with the axis dimension removed.
- Parameters:
logits (array) -- The unnormalized categorical distribution(s).
axis (int, optional) -- The axis which specifies the distribution. Default: -1.
shape (list(int), optional) -- The shape of the output. This must be broadcast compatible with logits.shape with the axis dimension removed. Default: None.
num_samples (int, optional) -- The number of samples to draw from each of the categorical distributions in logits. The output will have num_samples in the last dimension. Default: None.
key (array, optional) -- A PRNG key. Default: None.
- Returns:
The shape-sized output array with type uint32.
- Return type:
array
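The shape rules for the default case, shape, and num_samples are easier to see with concrete arrays. The following is a minimal NumPy sketch that mirrors the documented output shapes and uint32 dtype; it is not MLX's implementation. The function categorical_sketch is hypothetical, a NumPy Generator named rng stands in for the MLX PRNG key, and sampling uses the Gumbel-max trick (the argmax of logits plus standard Gumbel noise selects index i with probability softmax(logits)[i]). It also assumes that when shape is given, the batch shape must broadcast to exactly that shape.

```python
import numpy as np


def categorical_sketch(logits, axis=-1, shape=None, num_samples=None, rng=None):
    """Hypothetical NumPy illustration of the documented shape rules.

    Samples via the Gumbel-max trick; this is NOT MLX's implementation,
    and `rng` (a NumPy Generator) stands in for the MLX PRNG `key`.
    """
    if shape is not None and num_samples is not None:
        raise ValueError("at most one of shape or num_samples can be specified")
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=np.float64)
    # Move the distribution axis last; the remaining dims are the batch shape.
    logits = np.moveaxis(logits, axis, -1)
    batch_shape = logits.shape[:-1]
    num_classes = logits.shape[-1]

    if num_samples is not None:
        # Extra trailing dimension of size num_samples.
        out_shape = batch_shape + (num_samples,)
        logits = logits[..., None, :]  # broadcast each distribution over samples
    elif shape is not None:
        out_shape = tuple(shape)
        # Assumption: the batch shape must broadcast *to* the requested shape.
        if np.broadcast_shapes(out_shape, batch_shape) != out_shape:
            raise ValueError("shape does not broadcast with logits")
    else:
        # Default: logits.shape with the axis dimension removed.
        out_shape = batch_shape

    # Gumbel-max: argmax(logits + Gumbel noise) picks i with prob softmax(logits)[i].
    noise = rng.gumbel(size=out_shape + (num_classes,))
    return np.argmax(logits + noise, axis=-1).astype(np.uint32)


rng = np.random.default_rng(0)
logits = np.log([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])  # two distributions, 3 classes
print(categorical_sketch(logits, rng=rng).shape)                 # (2,)
print(categorical_sketch(logits, num_samples=4, rng=rng).shape)  # (2, 4)
print(categorical_sketch(logits, shape=(5, 2), rng=rng).shape)   # (5, 2)
```

With MLX itself (import mlx.core as mx), the analogous calls would be mx.random.categorical(logits), mx.random.categorical(logits, num_samples=4), and mx.random.categorical(logits, shape=(5, 2)), where logits is an mx.array; the sketch above only illustrates which output shapes those calls are documented to produce.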