隨機

隨機#

MLX 的隨機取樣函式預設使用隱式的全域 PRNG 狀態。不過,所有函式都提供可選的 key 關鍵字引數,以便在需要更細緻的控制或明確的狀態管理時使用。

例如,你可以用以下方式產生隨機數:

for _ in range(3):
  print(mx.random.uniform())

這會列印一串不重複的偽隨機數。你也可以明確設定 key:

key = mx.random.key(0)
for _ in range(3):
  print(mx.random.uniform(key=key))

這樣每次迭代都會產生相同的偽隨機數。

我們遵循 JAX 的 PRNG 設計,採用可分割的 Threefry(以計數器為基礎的 PRNG)。

bernoulli([p, shape, key, stream])

產生伯努利隨機值。

categorical(logits[, axis, shape, ...])

從類別分佈取樣。

gumbel([shape, dtype, key, stream])

從標準 Gumbel 分佈取樣。

key(seed)

由種子取得 PRNG key。

normal([shape, dtype, loc, scale, key, stream])

產生常態分佈的隨機數。

multivariate_normal(mean, cov[, shape, ...])

在給定平均與共變異數下產生多變量常態隨機樣本。

randint(low, high[, shape, dtype, key, stream])

從指定區間產生隨機整數。

seed(seed)

為全域 PRNG 設定種子。

split(key[, num, stream])

將 PRNG key 拆分為子 key。

truncated_normal(lower, upper[, shape, ...])

從截斷常態分佈產生數值。

uniform([low, high, shape, dtype, key, stream])

產生均勻分佈的隨機數。

laplace([shape, dtype, loc, scale, key, stream])

從拉普拉斯分佈取樣。

permutation(x[, axis, key, stream])

產生隨機排列,或打亂陣列的元素順序。