轉換成 NumPy 和其他框架

轉換成 NumPy 和其他框架#

MLX 陣列支援透過以下方式在其他框架之間進行轉換:

讓我們將一個陣列轉換為 NumPy,然後再轉換回來。

import mlx.core as mx
import numpy as np

a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b

備註

由於 NumPy 不支援 bfloat16 陣列,你需要先轉成 float16float32np.array(a.astype(mx.float32))。否則會出現例如 Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0. 的錯誤。

預設情況下,NumPy 會把資料複製到新陣列。可以透過建立陣列檢視來避免複製:

a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1

備註

型別為 float64 的 NumPy 陣列會預設轉換為 float32 的 MLX 陣列。

NumPy 陣列檢視本質上仍是一般的 NumPy 陣列,只是它不擁有記憶體。因此對檢視的寫入會反映到原始陣列。

雖然這能有效避免陣列複製,但要注意,對陣列記憶體的外部修改不會反映在梯度中。

讓我們用範例示範:

def f(x):
    x_view = np.array(x, copy=False)
    x_view[:] *= x_view # modify memory without telling mx
    return x.sum()

x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0

函式 f 透過記憶體檢視間接修改了陣列 x。然而這個修改不會反映在梯度中,最後一行輸出 1.0 即可看出,它只代表求和運算的梯度。x 的平方是在 MLX 外部進行的,因此不會納入梯度。也請注意,類似問題也會在陣列轉換與複製時出現。例如定義 mx.array(np.array(x)**2).sum() 的函式也會產生不正確的梯度,即使沒有對 MLX 記憶體做就地操作。

PyTorch#

警告

PyTorch 對 memoryview 的支援屬於實驗性,對多維陣列可能失效。建議目前先轉為 NumPy。

PyTorch 支援緩衝區協定,但需要明確的 memoryview

import mlx.core as mx
import torch

a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b)

JAX#

JAX 完整支援緩衝區協定。

import mlx.core as mx
import jax.numpy as jnp

a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)

TensorFlow#

TensorFlow 支援緩衝區協定,但需要明確的 memoryview

import mlx.core as mx
import tensorflow as tf

a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)