mlx.nn.init.identity

identity(dtype: Dtype = mlx.core.float32) -> Callable[[array], array]

An initializer that returns an identity matrix.

Parameters:

dtype (Dtype, optional) -- The data type of the returned array. Default: float32.

Returns:

An initializer that returns an identity matrix with the same shape as the input.

Return type:

Callable[[array], array]

Example

>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> init_fn = nn.init.identity()
>>> init_fn(mx.zeros((2, 2)))
array([[1, 0],
       [0, 1]], dtype=float32)