轉換成 NumPy 和其他框架#
MLX 陣列支援透過以下方式在其他框架之間進行轉換:
讓我們將一個陣列轉換為 NumPy,然後再轉換回來。
import mlx.core as mx
import numpy as np
a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b
備註
由於 NumPy 不支援 bfloat16 陣列,你需要先轉成 float16 或 float32:np.array(a.astype(mx.float32))。否則會出現例如 Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0. 的錯誤。
預設情況下,NumPy 會把資料複製到新陣列。可以透過建立陣列檢視來避免複製:
a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1
備註
型別為 float64 的 NumPy 陣列會預設轉換為 float32 的 MLX 陣列。
NumPy 陣列檢視本質上仍是一般的 NumPy 陣列,只是它不擁有記憶體。因此對檢視的寫入會反映到原始陣列。
雖然這能有效避免陣列複製,但要注意,對陣列記憶體的外部修改不會反映在梯度中。
讓我們用範例示範:
def f(x):
x_view = np.array(x, copy=False)
x_view[:] *= x_view # modify memory without telling mx
return x.sum()
x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0
函式 f 透過記憶體檢視間接修改了陣列 x。然而這個修改不會反映在梯度中,最後一行輸出 1.0 即可看出,它只代表求和運算的梯度。x 的平方是在 MLX 外部進行的,因此不會納入梯度。也請注意,類似問題也會在陣列轉換與複製時出現。例如定義 mx.array(np.array(x)**2).sum() 的函式也會產生不正確的梯度,即使沒有對 MLX 記憶體做就地操作。
PyTorch#
警告
PyTorch 對 memoryview 的支援屬於實驗性,對多維陣列可能失效。建議目前先轉為 NumPy。
PyTorch 支援緩衝區協定,但需要明確的 memoryview。
import mlx.core as mx
import torch
a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b)
JAX#
JAX 完整支援緩衝區協定。
import mlx.core as mx
import jax.numpy as jnp
a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)
TensorFlow#
TensorFlow 支援緩衝區協定,但需要明確的 memoryview。
import mlx.core as mx
import tensorflow as tf
a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)