mlx.core.random.truncated_normal
- truncated_normal(lower: scalar | array, upper: scalar | array, shape: Sequence[int] | None = None, dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) → array
Generate values from a truncated normal distribution.
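To make the idea concrete, the following is a minimal pure-Python sketch of one standard way to sample a truncated normal: inverse-CDF sampling. It maps the bounds through the normal CDF, draws uniformly between the two resulting probabilities, and maps that draw back through the inverse CDF. The sketch assumes a standard normal parent distribution (mean 0, standard deviation 1) and scalar bounds. The helper name truncated_normal_sample is hypothetical, and the code is not a description of how MLX implements this function. It uses only statistics.NormalDist from the standard library.

```python
import math
import random
from statistics import NormalDist

# Assumption of this sketch: the parent distribution is a standard normal.
_STD_NORMAL = NormalDist(mu=0.0, sigma=1.0)


def truncated_normal_sample(lower, upper, n, rng=None):
    """Hypothetical helper: draw n samples from a standard normal restricted to (lower, upper).

    Inverse-CDF method: uniform draws between CDF(lower) and CDF(upper)
    are mapped back to the real line with the quantile function.
    """
    if not lower < upper:
        raise ValueError("lower must be strictly less than upper")
    rng = rng if rng is not None else random.Random()
    a = _STD_NORMAL.cdf(lower)  # probability mass below the lower bound
    b = _STD_NORMAL.cdf(upper)  # probability mass below the upper bound
    tiny = math.nextafter(0.0, 1.0)       # smallest positive float
    below_one = math.nextafter(1.0, 0.0)  # largest float below 1.0
    samples = []
    for _ in range(n):
        p = rng.uniform(a, b)
        # NormalDist.inv_cdf requires 0 < p < 1, so keep p strictly inside that interval.
        p = min(max(p, tiny), below_one)
        x = _STD_NORMAL.inv_cdf(p)
        # Clip to absorb floating-point round-off right at the bounds.
        samples.append(min(max(x, lower), upper))
    return samples
```

For example, truncated_normal_sample(-1.0, 2.0, 5, random.Random(0)) returns five values inside [-1.0, 2.0]. Because the method works in probability space, it loses precision when both bounds lie far in a tail, where the CDF values round to 0 or 1.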
The values are sampled from the truncated normal distribution on the domain (lower, upper). The bounds lower and upper can be scalars or arrays and must be broadcastable to shape.
- Parameters:
lower (scalar or array) -- Lower bound of the domain.
upper (scalar or array) -- Upper bound of the domain.
shape (list(int), optional) -- The shape of the output. Default: ().
dtype (Dtype, optional) -- The data type of the output. Default: float32.
key (array, optional) -- A PRNG key. Default: None.
- Returns:
The output array of random values.
- Return type: