陣列索引#

大多數情況下,MLX array 的索引方式與 NumPy numpy.ndarray 相同。更多細節請參考 NumPy 文件

例如,你可以使用一般整數與切片(slice)來索引陣列:

>>> arr = mx.arange(10)
>>> arr[3]
array(3, dtype=int32)
>>> arr[-2]  # negative indexing works
array(8, dtype=int32)
>>> arr[2:8:2] # start, stop, stride
array([2, 4, 6], dtype=int32)

對多維陣列而言,...Ellipsis 的語法與 NumPy 相同:

>>> arr = mx.arange(8).reshape(2, 2, 2)
>>> arr[:, :, 0]
array(3, dtype=int32)
array([[0, 2],
       [4, 6]], dtype=int32
>>> arr[..., 0]
array([[0, 2],
       [4, 6]], dtype=int32

你可以用 None 來建立新軸:

>>> arr = mx.arange(8)
>>> arr.shape
[8]
>>> arr[None].shape
[1, 8]

你也可以用一個 array 來索引另一個 array

>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)

整數、slice...array 的索引可混合使用,與 NumPy 相同。

其他可能有助於索引陣列的函式還有 take()take_along_axis()

與 NumPy 的差異#

備註

MLX 的索引與 NumPy 有兩個重要差異:

  • 索引不會進行邊界檢查。超出範圍的索引屬於未定義行為。

  • 布林遮罩索引僅支援賦值(參見 布林遮罩賦值)。

之所以不做邊界檢查,是因為例外無法從 GPU 傳回。在啟動核心前先對陣列索引進行邊界檢查會非常低效。

以布林遮罩做索引可能會在未來被 MLX 支援。一般而言,MLX 對輸出*形狀*依賴輸入*資料*的運算支援有限。MLX 尚未支援的其他例子包含 numpy.nonzero()numpy.where() 的單輸入版本。

就地更新#

MLX 支援對索引陣列進行就地更新。例如:

>>> a = mx.array([1, 2, 3])
>>> a[2] = 0
>>> a
array([1, 2, 0], dtype=int32)

與 NumPy 相同,就地更新會反映在同一陣列的所有參照中:

>>> a = mx.array([1, 2, 3])
>>> b = a
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 0], dtype=int32)

注意,與 NumPy 不同,切片會建立副本而非檢視,因此修改它不會改變原陣列:

>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)

與 NumPy 不同,對同一位置的更新是非決定性的:

>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])

a 的第一個元素可能是 45

允許對使用就地更新的函式進行轉換,且可如預期運作。例如:

def fun(x, idx):
    x[idx] = 2.0
    return x.sum()

dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

上述 dfdx 會得到正確的梯度,也就是在 idx 為 0,其餘位置為 1。

布林遮罩賦值#

MLX 支援使用 NumPy 語法的布林索引。遮罩必須是 bool_ 的 MLX arraydtype=bool 的 NumPy ndarray。其他索引類型會走標準的 scatter 程式碼路徑。

>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)

標量賦值會廣播到 mask 中所有 True 的位置。對於非標量賦值,updates 必須至少提供與 maskTrue 數量相同的元素。

>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
                     [False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
       [0.0, 0.0, 1.0]], dtype=float32)

布林遮罩遵循 NumPy 語意:

  • 遮罩的形狀必須與其索引的軸形狀完全一致。唯一的例外是標量布林遮罩,可廣播到整個陣列。

  • 未被遮罩覆蓋的軸會完整保留。

>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1

形狀為 (10, 10) 的遮罩套用到前兩個軸,因此 a[mask] 會選出 mask[i, j]True 的一維切片 a[i, j, :]。像 (1, 10, 10)(10, 10, 1) 的形狀與被索引的軸不匹配,因此會拋出錯誤。