MLX 的自訂擴充#

你可以在 CPU 或 GPU 上用自訂運算擴充 MLX。本指南透過一個簡單範例說明如何做到這點。

範例介紹#

假設你想要一個運算,輸入兩個陣列 xy,分別乘上係數 alphabeta,再相加得到結果 z = alpha * x + beta * y。你可以直接在 MLX 中這麼做:

import mlx.core as mx

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    return alpha * x + beta * y

此函式執行該運算,同時將實作與函式轉換交給 MLX。

不過你可能想自訂底層實作,例如讓它更快。在本教學中,我們會逐步加入自訂擴充。內容涵蓋:

  • MLX 程式庫的結構。

  • 實作 CPU 運算。

  • 使用 Metal 實作 GPU 運算。

  • 加入 vjpjvp 函式轉換。

  • 建置自訂擴充並綁定到 Python。

運算與原語#

MLX 中的運算會建立計算圖。原語提供評估與轉換計算圖的規則。讓我們先更深入介紹運算。

運算#

運算是對陣列操作的前端函式。它們定義於 C++ API(運算),並由 Python API(運算)加以綁定。

我們想要一個名為 axpby() 的運算,接收兩個陣列 xy,以及兩個純量 alphabeta。以下是在 C++ 中的定義方式:

/**
*  Scale and sum two vectors element-wise
*  z = alpha * x + beta * y
*
*  Use NumPy-style broadcasting between x and y
*  Inputs are upcasted to floats if needed
**/
array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s = {} // Stream on which to schedule the operation
);

最簡單的實作方式是使用現有運算:

array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
    // Scale x and y on the provided stream
    auto ax = multiply(array(alpha), x, s);
    auto by = multiply(array(beta), y, s);

    // Add and return
    return add(ax, by, s);
}

運算本身不包含作用於資料的實作,也不包含轉換規則。相反地,它們是易用的介面,以 Primitive 為建構基礎。

原語#

Primitivearray 計算圖的一部分。它定義了如何在給定輸入陣列時產生輸出陣列。此外,Primitive 也有在 CPU 或 GPU 上執行的方法,以及 vjpjvp 等函式轉換。讓我們回到範例來更具體說明:

class Axpby : public Primitive {
  public:
    explicit Axpby(Stream stream, float alpha, float beta)
        : Primitive(stream), alpha_(alpha), beta_(beta){};

    /**
    * A primitive must know how to evaluate itself on the CPU/GPU
    * for the given inputs and populate the output array.
    *
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
    void eval_cpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) override;
    void eval_gpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) override;

    /** The Jacobian-vector product. */
    std::vector<array> jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) override;

    /** The vector-Jacobian product. */
    std::vector<array> vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<array>& outputs) override;

    /**
    * The primitive must know how to vectorize itself across
    * the given axes. The output is a pair containing the array
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
    std::pair<std::vector<array>, std::vector<int>> vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) override;

    /** The name of primitive. */
    const char* name() const override {
      return "Axpby";
    }

    /** Equivalence check **/
    bool is_equivalent(const Primitive& other) const override;

  private:
    float alpha_;
    float beta_;
};

Axpby 類別衍生自基礎 Primitive 類別。Axpbyalphabeta 視為參數,並透過 Axpby::eval_cpu()Axpby::eval_gpu() 提供在給定輸入時產生輸出陣列的實作。同時也在 Axpby::jvp()Axpby::vjp()Axpby::vmap() 中提供轉換規則。

使用原語#

運算可以使用此 Primitive 在計算圖中新增一個 array。建立:class:array 時可提供其資料型別、形狀、負責計算它的 Primitive,以及傳給該原語的 array 輸入。

讓我們以 Axpby 原語重新實作該運算。

array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
    // Promote dtypes between x and y as needed
    auto promoted_dtype = promote_types(x.dtype(), y.dtype());

    // Upcast to float32 for non-floating point inputs x and y
    auto out_dtype = issubdtype(promoted_dtype, float32)
        ? promoted_dtype
        : promote_types(promoted_dtype, float32);

    // Cast x and y up to the determined dtype (on the same stream s)
    auto x_casted = astype(x, out_dtype, s);
    auto y_casted = astype(y, out_dtype, s);

    // Broadcast the shapes of x and y (on the same stream s)
    auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
    auto out_shape = broadcasted_inputs[0].shape();

    // Construct the array as the output of the Axpby primitive
    // with the broadcasted and upcasted arrays as inputs
    return array(
        /* const std::vector<int>& shape = */ out_shape,
        /* Dtype dtype = */ out_dtype,
        /* std::unique_ptr<Primitive> primitive = */
        std::make_shared<Axpby>(to_stream(s), alpha, beta),
        /* const std::vector<array>& inputs = */ broadcasted_inputs);
}

此運算現在會處理以下事項:

  1. 將輸入升型並決定輸出資料型別。

  2. 廣播輸入並決定輸出形狀。

  3. 使用給定的 stream、alphabeta 建立 Axpby 原語。

  4. 使用原語與輸入建立輸出 array

實作原語#

單獨呼叫該運算時不會進行計算。運算只會建立計算圖。當我們對輸出陣列求值時,MLX 會排程計算圖的執行,並根據使用者指定的 stream/裝置呼叫 Axpby::eval_cpu()Axpby::eval_gpu()

警告

當呼叫 Primitive::eval_cpu()Primitive::eval_gpu() 時,輸出陣列尚未分配記憶體。因此需要由這些函式的實作來視需求配置記憶體。

實作 CPU 後端#

讓我們先實作 Axpby::eval_cpu()

這個方法會遍歷輸出陣列的每個元素,找出 xy 對應的輸入元素,並逐元素執行運算。這個邏輯在範本函式 axpby_impl() 中實作。

template <typename T>
void axpby_impl(
    const mx::array& x,
    const mx::array& y,
    mx::array& out,
    float alpha_,
    float beta_,
    mx::Stream stream) {
  out.set_data(mx::allocator::malloc(out.nbytes()));

  // Get the CPU command encoder and register input and output arrays
  auto& encoder = mx::cpu::get_command_encoder(stream);
  encoder.set_input_array(x);
  encoder.set_input_array(y);
  encoder.set_output_array(out);

  // Launch the CPU kernel
  encoder.dispatch([x_ptr = x.data<T>(),
                    y_ptr = y.data<T>(),
                    out_ptr = out.data<T>(),
                    size = out.size(),
                    shape = out.shape(),
                    x_strides = x.strides(),
                    y_strides = y.strides(),
                    alpha_,
                    beta_]() {

    // Cast alpha and beta to the relevant types
    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);

    // Do the element-wise operation for each output
    for (size_t out_idx = 0; out_idx < size; out_idx++) {
      // Map linear indices to offsets in x and y
      auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
      auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

      // We allocate the output to be contiguous and regularly strided
      // (defaults to row major) and hence it doesn't need additional mapping
      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
    }
  });
}

我們的實作應適用於所有輸入的浮點陣列。因此我們為 float32float16bfloat16complex64 加入分派。若遇到非預期型別就會拋出錯誤。

void Axpby::eval_cpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Dispatch to the correct dtype
  if (out.dtype() == mx::float32) {
    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
  } else if (out.dtype() == mx::float16) {
    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
  } else if (out.dtype() == mx::bfloat16) {
    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
  } else if (out.dtype() == mx::complex64) {
    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
  } else {
    throw std::runtime_error(
        "Axpby is only supported for floating point types.");
  }
}

到這裡就足夠在 CPU stream 上執行 axpby() 運算!若你不打算在 GPU 上執行此運算,或不使用包含 Axpby 的計算圖轉換,可以在此停止實作原語。

實作 GPU 後端#

Apple silicon 裝置透過 Metal 著色語言存取 GPU,而 MLX 的 GPU 內核則使用 Metal 撰寫。

備註

如果你剛接觸 Metal,以下資源會很有幫助:

讓 GPU 內核保持簡單。我們會啟動與輸出元素數量相同的執行緒。每個執行緒會從 xy 取出需要的元素,進行逐點運算,並更新其負責的輸出元素。

template <typename T>
[[kernel]] void axpby_general(
        device const T* x [[buffer(0)]],
        device const T* y [[buffer(1)]],
        device T* out [[buffer(2)]],
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
        constant const int64_t* x_strides [[buffer(6)]],
        constant const int64_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {
    // Convert linear indices to offsets in array
    auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
    auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

    // Do the operation and update the output
    out[index] =
        static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}

接著需要為所有浮點型別實例化此範本,並為每個實例指定唯一的 host name 以便識別。

instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

用於決定內核、設定輸入、解析網格維度與分派到 GPU 的邏輯,都在 Axpby::eval_gpu() 中,如下所示。

/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
  const std::vector<array>& inputs,
  std::vector<array>& outputs) {
    // Prepare inputs
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Each primitive carries the stream it should execute on
    // and each stream carries its device identifiers
    auto& s = stream();
    // We get the needed metal device using the stream
    auto& d = metal::device(s.device);

    // Allocate output memory
    out.set_data(allocator::malloc(out.nbytes()));

    // Resolve name of kernel
    std::stream kname;
    kname = "axpby_general_" + type_to_name(out);

    // Load the metal library
    auto lib = d.get_library("mlx_ext", current_binary_dir());

    // Make a kernel from this metal library
    auto kernel = d.get_kernel(kname, lib);

    // Prepare to encode kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);

    // Kernel parameters are registered with buffer indices corresponding to
    // those in the kernel declaration at axpby.metal
    int ndim = out.ndim();
    size_t nelem = out.size();

    // Encode input arrays to kernel
    compute_encoder.set_input_array(x, 0);
    compute_encoder.set_input_array(y, 1);

    // Encode output arrays to kernel
    compute_encoder.set_output_array(out, 2);

    // Encode alpha and beta
    compute_encoder.set_bytes(alpha_, 3);
    compute_encoder.set_bytes(beta_, 4);

    // Encode shape, strides and ndim
    compute_encoder.set_vector_bytes(x.shape(), 5);
    compute_encoder.set_vector_bytes(x.strides(), 6);
    compute_encoder.set_bytes(y.strides(), 7);
    compute_encoder.set_bytes(ndim, 8);

    // We launch 1 thread for each input and make sure that the number of
    // threads in any given threadgroup is not higher than the max allowed
    size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

    // Fix the 3D size of each threadgroup (in terms of threads)
    MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

    // Fix the 3D size of the launch grid (in terms of threads)
    MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

    // Launch the grid with the given number of threads divided among
    // the given threadgroups
    compute_encoder.dispatch_threads(grid_dims, group_dims);
}

現在我們可以在 CPU 和 GPU 上呼叫 axpby() 運算了!

在繼續之前,先留意關於 MLX 與 Metal 的幾點:MLX 會追蹤目前的 command_buffer 與其對應的 MTLCommandBuffer。我們會依賴 d.get_command_encoder() 取得目前使用中的 Metal 計算命令編碼器,而不是自行建立新的編碼器並在最後呼叫 compute_encoder->end_encoding()。MLX 會持續將內核(計算管線)加入目前的命令緩衝區,直到達到指定上限或命令緩衝區需要為同步而清空。

原語轉換#

接下來,我們將為 Primitive 的轉換加入實作。這些轉換可以建構在其他運算之上,包括我們剛剛定義的運算:

/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
    // Forward mode diff that pushes along the tangents
    // The jvp transform on the primitive can be built with ops
    // that are scheduled on the same stream as the primitive

    // If argnums = {0}, we only push along x in which case the
    // jvp is just the tangent scaled by alpha
    // Similarly, if argnums = {1}, the jvp is just the tangent
    // scaled by beta
    if (argnums.size() > 1) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
    }
    // If argnums = {0, 1}, we take contributions from both
    // which gives us jvp = tangent_x * alpha + tangent_y * beta
    else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
    }
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<int>& /* unused */) {
    // Reverse mode diff
    std::vector<array> vjps;
    for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotangents[0].dtype());
        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
    }
    return vjps;
}

注意,轉換不需要完整定義即可開始使用 Primitive

/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
    throw std::runtime_error("[Axpby] vmap not implemented.");
}

建置與綁定#

先看看整體目錄結構。

extensions
├── axpby
│ ├── axpby.cpp
│ ├── axpby.h
│ └── axpby.metal
├── mlx_sample_extensions
│ └── __init__.py
├── bindings.cpp
├── CMakeLists.txt
└── setup.py
  • extensions/axpby/ 定義 C++ 擴充程式庫

  • extensions/mlx_sample_extensions 設定對應的 Python 軟體包結構

  • extensions/bindings.cpp 提供我們運算的 Python 綁定

  • extensions/CMakeLists.txt 包含建置程式庫與 Python 綁定的 CMake 規則

  • extensions/setup.py 包含使用 setuptools 建置與安裝 Python 軟體包的規則

綁定到 Python#

我們使用 nanobind 為 C++ 程式庫建立 Python API。由於 mlx.core.arraymlx.core.stream 等元件已提供綁定,因此加入 axpby() 很簡單。

NB_MODULE(_ext, m) {
     m.doc() = "Sample extension for MLX";

     m.def(
         "axpby",
         &axpby,
         "x"_a,
         "y"_a,
         "alpha"_a,
         "beta"_a,
         nb::kw_only(),
         "stream"_a = nb::none(),
         R"(
             Scale and sum two vectors element-wise
             ``z = alpha * x + beta * y``

             Follows numpy style broadcasting between ``x`` and ``y``
             Inputs are upcasted to floats if needed

             Args:
                 x (array): Input array.
                 y (array): Input array.
                 alpha (float): Scaling factor for ``x``.
                 beta (float): Scaling factor for ``y``.

             Returns:
                 array: ``alpha * x + beta * y``
         )");
 }

上述範例的大部分複雜度來自額外的裝飾細節,例如字面名稱與文件字串。

警告

必須先匯入 mlx.core,再匯入以上 nanobind 模組定義的 mlx_sample_extensions,以確保 mlx.core 元件(如 mlx.core.array)的型別轉換器可用。

使用 CMake 建置#

建置 C++ 擴充程式庫只需要 find_package(MLX CONFIG),接著連結到你的程式庫即可。

# Add library
add_library(mlx_ext)

# Add sources
target_sources(
    mlx_ext
    PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)

# Add include headers
target_include_directories(
    mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)

# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)

我們也需要建置附帶的 Metal 程式庫。為了方便,我們提供 mlx_build_metallib() 函式,可根據來源碼、標頭、目的地等資訊建置 .metallib 目標(定義於 cmake/extension.cmake,並在 MLX 軟體包中自動匯入)。

實際用法如下:

# Build metallib
if(MLX_BUILD_METAL)

mlx_build_metallib(
    TARGET mlx_ext_metallib
    TITLE mlx_ext
    SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
    INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)

add_dependencies(
    mlx_ext
    mlx_ext_metallib
)

endif()

最後,我們建置 nanobind 綁定

nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

使用 setuptools 建置#

在依照上述說明設定 CMake 建置規則之後,我們可以使用 mlx.extension 中定義的建置工具:

from mlx import extension
from setuptools import setup

if __name__ == "__main__":
    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev":[]},
        zip_safe=False,
        python_requires=">=3.8",
    )

備註

我們將 extensions/mlx_sample_extensions 視為軟體包目錄,即使它只包含 __init__.py,也能確保以下事項:

  • 必須先匯入 mlx.core,再匯入 _ext

  • C++ 擴充程式庫與 Metal 程式庫會與 Python 綁定放在同一位置,並在安裝軟體包時一併複製

要建置此軟體包,請先使用 pip install -r requirements.txt 安裝建置相依。接著可在 extensions/ 內使用 python setup.py build_ext -j8 --inplace 進行就地建置以供開發。

結果會產生以下目錄結構:

extensions
├── mlx_sample_extensions
│ ├── __init__.py
│ ├── libmlx_ext.dylib # C++ 擴充程式庫
│ ├── mlx_ext.metallib # Metal 程式庫
│ └── _ext.cpython-3x-darwin.so # Python 綁定
...

當你在 extensions/ 中使用 python -m pip install . 安裝時,軟體包會以 extensions/mlx_sample_extensions 相同的結構安裝,且因為它們被指定為 package_data,所以 C++ 與 Metal 程式庫會與 Python 綁定一併複製。

使用方式#

依照上述方式安裝擴充後,你應該可以直接匯入 Python 軟體包,像使用其他 MLX 運算一樣進行測試。

讓我們看看簡單腳本及其結果:

import mlx.core as mx
from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}")

輸出:

c shape: [3, 4]
c dtype: float32
c is correct: True

結果#

讓我們跑一個簡單的效能測試,看看新的 axpby 運算與一開始定義的簡單 simple_axpby() 相比如何。

import mlx.core as mx
from mlx_sample_extensions import axpby
import time

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    return alpha * x + beta * y

M = 4096
N = 4096

x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0

mx.eval(x, y)

def bench(f):
    # Warm up
    for i in range(5):
        z = f(x, y, alpha, beta)
        mx.eval(z)

    # Timed run
    s = time.perf_counter()
    for i in range(100):
        z = f(x, y, alpha, beta)
        mx.eval(z)
    e = time.perf_counter()
    return 1000 * (e - s) / 100

simple_time = bench(simple_axpby)
custom_time = bench(axpby)

print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

結果為 Simple axpby: 1.559 ms | Custom axpby: 0.774 ms。我們立刻看到適度的提升!

這個運算現在可以用來建構其他運算、用於 mlx.nn.Module 呼叫,也能作為 grad() 等圖轉換的一部分。

腳本#

下載程式碼

完整範例程式碼可在 mlx 取得。