快速開始指南

快速開始指南#

基礎#

匯入 mlx.core 並建立 array

>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32

MLX 的運算是惰性的。MLX 運算的輸出只有在需要時才會被計算。若要強制計算陣列,請使用 eval()。在某些情況下陣列會自動被計算,例如使用 array.item() 檢視純量、列印陣列,或將 array 轉換為 numpy.ndarray,都會自動計算該陣列。

>> c = a + b    # c not yet evaluated
>> mx.eval(c)  # evaluates c
>> c = a + b
>> print(c)     # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c)   # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)

更多細節請參閱 Lazy Evaluation 頁面。

函式與圖形轉換#

MLX 具備標準的函式轉換,例如 grad()vmap()。轉換可以任意組合,例如允許 ``grad(vmap(grad(fn)))``(或其他任意組合)。

>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)

其他梯度轉換包含用於向量-雅可比乘積的 vjp(),以及用於雅可比-向量乘積的 jvp()

使用 value_and_grad() 可有效率地同時計算函式的輸出,以及相對於輸入的梯度。