自訂 Metal 內核#

MLX 支援透過 Python 與 C++ API 撰寫自訂 Metal 內核。

簡單範例#

讓我們撰寫一個自訂內核,逐元素計算 exp

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

def exp_elementwise(a: mx.array):
    """Apply the custom `myexp` kernel to `a`, one thread per element."""
    launch_args = dict(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    # The kernel returns one array per entry in `output_names`.
    (result,) = kernel(**launch_args)
    return result

# Sanity check: the custom kernel should match mx.exp elementwise.
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))

每次建立內核時,都會建立新的 Metal 程式庫並可能進行 JIT 編譯。為了降低這樣的額外負擔,請用 fast.metal_kernel() 建立一次內核後重複使用。

備註

source 中只需提供 Metal 內核的主體即可。函式簽名會自動產生。

完整的函式簽名會根據以下內容自動產生:

  • inputs 的形狀/資料型別

    在上述例子中,a 是型別為 mx.float16 的 mx.array,並以鍵 inp 傳入,因此我們會在簽名中加入 const device float16_t* inp。若 source 中出現 inp_shape、inp_strides 或 inp_ndim,也會為了方便而一併加入。

  • output_dtypes 的列表

    在上述例子中,out 是型別為 mx.float16 的 mx.array,因此會加入 device float16_t* out。

  • 透過 template 傳入的範本參數

    在上述例子中,template=[("T", mx.float32)] 會為函式加入 template <typename T>,並以 custom_kernel_myexp_float<float> 實例化此範本。範本參數可以是 mx.core.Dtype、int 或 bool。

  • source 中使用的 Metal 屬性,例如 [[thread_position_in_grid]]

    這些會被加入為函式參數。支援 Metal Shading Language Specification 表 5.8 中定義的所有屬性。

綜合以上內容,myexp 產生的函式簽名如下:

template <typename T>
[[kernel]] void custom_kernel_myexp_float(
  const device float16_t* inp [[buffer(0)]],
  device float16_t* out [[buffer(1)]],
  uint3 thread_position_in_grid [[thread_position_in_grid]]) {

        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);

}

template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

注意:grid 與 threadgroup 是 Metal dispatchThreads 函式的參數。這代表我們會啟動 mx.prod(grid) 個執行緒,並依 threadgroup 大小分成多個執行緒群組。為了最佳效能,每個執行緒群組的各維度應小於或等於對應的 grid 維度。

在 fast.metal_kernel.__call__() 中傳入 verbose=True 會列印產生的程式碼以便除錯。

使用形狀/步幅#

fast.metal_kernel() 支援 ensure_row_contiguous 參數,預設為 True。這會在啟動內核前視需要複製陣列輸入,以確保記憶體配置為列連續。一般而言,這會讓撰寫內核更容易,因為索引時不用擔心間隔或維度順序。

若想避免這個複製,當 source 中出現這些名稱時,fast.metal_kernel() 會自動為每個輸入陣列 a 傳入 a_shape、a_strides 與 a_ndim。接著我們可以使用 MLX 內建的索引工具為每個執行緒取出正確的元素。

讓我們把上面的 myexp 改成支援任意步幅的陣列,而不依賴 ensure_row_contiguous 的複製:

source = """
    uint elem = thread_position_in_grid.x;
    // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
    uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
    T tmp = inp[loc];
    // Output arrays are always row contiguous
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp_strided",
    input_names=["inp"],
    output_names=["out"],
    source=source,
    ensure_row_contiguous=False,
)

def exp_elementwise(a: mx.array):
    """Run `myexp_strided` on `a`, which may be non-row-contiguous."""
    launch_args = dict(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    # Exactly one output ("out"), so unpack it directly.
    (result,) = kernel(**launch_args)
    return result

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
# (a strided slice is no longer row contiguous, exercising elem_to_loc)
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))

複雜範例#

讓我們實作更複雜的例子:"bilinear" 模式的 grid_sample

我們從使用標準運算的 MLX 實作開始:

def grid_sample_ref(x, grid):
    """Reference bilinear ``grid_sample`` built from standard MLX ops.

    ``x`` is (N, H, W, C); ``grid`` holds normalized (x, y) sampling
    coordinates in [-1, 1] with shape (N, gH, gW, 2). Out-of-bounds
    corners contribute zero.
    """
    N, H_in, W_in, _ = x.shape

    # Normalized [-1, 1] coordinates -> continuous pixel coordinates.
    ix = ((grid[..., 0] + 1) * W_in - 1) / 2
    iy = ((grid[..., 1] + 1) * H_in - 1) / 2

    # Integer north-west (top-left) corner of each sample point.
    ix_nw = mx.floor(ix).astype(mx.int32)
    iy_nw = mx.floor(iy).astype(mx.int32)

    # The four surrounding corners as (row, col, bilinear weight),
    # listed in NW, NE, SW, SE order (matching the accumulation below).
    corners = (
        (iy_nw, ix_nw, (ix_nw + 1 - ix) * (iy_nw + 1 - iy)),
        (iy_nw, ix_nw + 1, (ix - ix_nw) * (iy_nw + 1 - iy)),
        (iy_nw + 1, ix_nw, (ix_nw + 1 - ix) * (iy - iy_nw)),
        (iy_nw + 1, ix_nw + 1, (ix - ix_nw) * (iy - iy_nw)),
    )

    batch = mx.arange(N)[:, None, None]
    output = None
    for row, col, weight in corners:
        sample = x[batch, row, col, :]
        # Zero out contributions from corners outside the image.
        inside = (row >= 0) & (row <= H_in - 1) & (col >= 0) & (col <= W_in - 1)
        sample = sample * inside[..., None]
        term = weight[..., None] * sample
        output = term if output is None else output + term

    return output

現在讓我們將 custom_function() 與 fast.metal_kernel() 搭配使用,為正向與反向傳播寫出快速的 GPU 內核。

首先我們將正向傳播實作為融合內核:

source = """
    uint elem = thread_position_in_grid.x;
    int H = x_shape[1];
    int W = x_shape[2];
    int C = x_shape[3];
    int gH = grid_shape[1];
    int gW = grid_shape[2];

    int w_stride = C;
    int h_stride = W * w_stride;
    int b_stride = H * h_stride;

    uint grid_idx = elem / C * 2;
    float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
    float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

    int ix_nw = floor(ix);
    int iy_nw = floor(iy);

    int ix_ne = ix_nw + 1;
    int iy_ne = iy_nw;

    int ix_sw = ix_nw;
    int iy_sw = iy_nw + 1;

    int ix_se = ix_nw + 1;
    int iy_se = iy_nw + 1;

    T nw = (ix_se - ix)    * (iy_se - iy);
    T ne = (ix    - ix_sw) * (iy_sw - iy);
    T sw = (ix_ne - ix)    * (iy    - iy_ne);
    T se = (ix    - ix_nw) * (iy    - iy_nw);

    int batch_idx = elem / C / gH / gW * b_stride;
    int channel_idx = elem % C;
    int base_idx = batch_idx + channel_idx;

    T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
    T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
    T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
    T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

    I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
    I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
    I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
    I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

    out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""

kernel = mx.fast.metal_kernel(
    name="grid_sample",
    input_names=["x", "grid"],
    output_names=["out"],
    source=source,
)

@mx.custom_function
def grid_sample(x, grid):
    """Forward pass of bilinear grid_sample via the fused Metal kernel.

    Launches one thread per output element (B * gN * gM * C).
    """
    assert x.ndim == 4, "`x` must be 4D."
    assert grid.ndim == 4, "`grid` must be 4D."

    n_batch, _, _, n_chan = x.shape
    _, out_h, out_w, coord_dim = grid.shape
    assert coord_dim == 2, "Last dim of `grid` must be size 2."

    out_shape = (n_batch, out_h, out_w, n_chan)
    (result,) = kernel(
        inputs=[x, grid],
        template=[("T", x.dtype)],
        output_shapes=[out_shape],
        output_dtypes=[x.dtype],
        grid=(np.prod(out_shape), 1, 1),
        threadgroup=(256, 1, 1),
    )
    return result

對於像是以下合理大小的輸入:

x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)

在 M1 Max 上,我們看到顯著的效能提升:

55.7ms -> 6.7ms => 8x speed up

Grid Sample VJP#

由於我們用 custom_function() 裝飾 grid_sample,現在可以定義自訂的 vjp 轉換,讓 MLX 能夠對它做微分。

反向傳播需要以原子方式更新 x_grad/grid_grad,因此需要一些額外的 fast.metal_kernel() 功能:

  • init_value=0

    在內核執行前,將所有輸出初始化為此值。這讓我們可以只更新輸出陣列的一部分。

  • atomic_outputs=True

    在函式簽名中將所有內核輸出指定為 atomic。這表示我們可以使用 Metal 的 atomic 功能,從多個執行緒群組同時更新 x_gradgrid_grad 陣列。詳情請參見 Metal Shading Language Specification 第 6.15 節。

接著我們可以如下實作反向傳播:

source = """
    uint elem = thread_position_in_grid.x;
    int H = x_shape[1];
    int W = x_shape[2];
    int C = x_shape[3];
    // Pad C to the nearest larger simdgroup size multiple
    int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

    int gH = grid_shape[1];
    int gW = grid_shape[2];

    int w_stride = C;
    int h_stride = W * w_stride;
    int b_stride = H * h_stride;

    uint grid_idx = elem / C_padded * 2;
    float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
    float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

    int ix_nw = floor(ix);
    int iy_nw = floor(iy);

    int ix_ne = ix_nw + 1;
    int iy_ne = iy_nw;

    int ix_sw = ix_nw;
    int iy_sw = iy_nw + 1;

    int ix_se = ix_nw + 1;
    int iy_se = iy_nw + 1;

    T nw = (ix_se - ix)    * (iy_se - iy);
    T ne = (ix    - ix_sw) * (iy_sw - iy);
    T sw = (ix_ne - ix)    * (iy    - iy_ne);
    T se = (ix    - ix_nw) * (iy    - iy_nw);

    int batch_idx = elem / C_padded / gH / gW * b_stride;
    int channel_idx = elem % C_padded;
    int base_idx = batch_idx + channel_idx;

    T gix = T(0);
    T giy = T(0);
    if (channel_idx < C) {
        int cot_index = elem / C_padded * C + channel_idx;
        T cot = cotangent[cot_index];
        if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
            int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
            atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

            T I_nw = x[offset];
            gix -= I_nw * (iy_se - iy) * cot;
            giy -= I_nw * (ix_se - ix) * cot;
        }
        if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
            int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
            atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

            T I_ne = x[offset];
            gix += I_ne * (iy_sw - iy) * cot;
            giy -= I_ne * (ix - ix_sw) * cot;
        }
        if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
            int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
            atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

            T I_sw = x[offset];
            gix -= I_sw * (iy - iy_ne) * cot;
            giy += I_sw * (ix_ne - ix) * cot;
        }
        if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
            int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
            atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

            T I_se = x[offset];
            gix += I_se * (iy - iy_nw) * cot;
            giy += I_se * (ix - ix_nw) * cot;
        }
    }

    T gix_mult = W / 2;
    T giy_mult = H / 2;

    // Reduce across each simdgroup first.
    // This is much faster than relying purely on atomics.
    gix = simd_sum(gix);
    giy = simd_sum(giy);

    if (thread_index_in_simdgroup == 0) {
        atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
        atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
    }
"""
kernel = mx.fast.metal_kernel(
    name="grid_sample_grad",
    input_names=["x", "grid", "cotangent"],
    output_names=["x_grad", "grid_grad"],
    source=source,
    atomic_outputs=True,
)

@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
    """Backward pass: accumulate gradients for `x` and `grid`.

    One thread is launched per (batch, gh, gw, padded-channel) element;
    the kernel scatters into the zero-initialized outputs with atomics.
    """
    x, grid = primals
    n_batch, _, _, n_chan = x.shape
    _, out_h, out_w, coord_dim = grid.shape
    assert coord_dim == 2, "Last dim of `grid` must be size 2."

    # Round channels up to the simdgroup size so each simd_sum in the
    # kernel reduces over exactly one (batch, gh, gw) location.
    simd = 32
    padded_chan = -(-n_chan // simd) * simd  # ceil-divide, then scale back
    n_threads = n_batch * out_h * out_w * padded_chan

    x_grad, grid_grad = kernel(
        inputs=[x, grid, cotangent],
        template=[("T", x.dtype)],
        output_shapes=[x.shape, grid.shape],
        output_dtypes=[x.dtype, x.dtype],
        grid=(n_threads, 1, 1),
        threadgroup=(256, 1, 1),
        init_value=0,
    )
    return x_grad, grid_grad

vjp 的加速幅度更大:

676.4ms -> 16.7ms => 40x speed up