MLX 的自訂擴充#
你可以在 CPU 或 GPU 上用自訂運算擴充 MLX。本指南透過一個簡單範例說明如何做到這點。
範例介紹#
假設你想要一個運算,輸入兩個陣列 x 和 y,分別乘上係數 alpha 與 beta,再相加得到結果 z = alpha * x + beta * y。你可以直接在 MLX 中這麼做:
import mlx.core as mx
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
此函式執行該運算,同時將實作與函式轉換交給 MLX。
不過你可能想自訂底層實作,例如讓它更快。在本教學中,我們會逐步加入自訂擴充。內容涵蓋:
MLX 程式庫的結構。
實作 CPU 運算。
使用 Metal 實作 GPU 運算。
加入
vjp與jvp函式轉換。建置自訂擴充並綁定到 Python。
運算與原語#
MLX 中的運算會建立計算圖。原語提供評估與轉換計算圖的規則。讓我們先更深入介紹運算。
運算#
運算是對陣列操作的前端函式。它們定義於 C++ API(運算),並由 Python API(運算)加以綁定。
我們想要一個名為 axpby() 的運算,接收兩個陣列 x 與 y,以及兩個純量 alpha 與 beta。以下是在 C++ 中的定義方式:
/**
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Use NumPy-style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
最簡單的實作方式是使用現有運算:
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
運算本身不包含作用於資料的實作,也不包含轉換規則。相反地,它們是易用的介面,以 Primitive 為建構基礎。
原語#
Primitive 是 array 計算圖的一部分。它定義了如何在給定輸入陣列時產生輸出陣列。此外,Primitive 也有在 CPU 或 GPU 上執行的方法,以及 vjp、jvp 等函式轉換。讓我們回到範例來更具體說明:
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
};
Axpby 類別衍生自基礎 Primitive 類別。Axpby 將 alpha 與 beta 視為參數,並透過 Axpby::eval_cpu() 與 Axpby::eval_gpu() 提供在給定輸入時產生輸出陣列的實作。同時也在 Axpby::jvp()、Axpby::vjp() 與 Axpby::vmap() 中提供轉換規則。
使用原語#
運算可以使用此 Primitive 在計算圖中新增一個 array。建立:class:array 時可提供其資料型別、形狀、負責計算它的 Primitive,以及傳給該原語的 array 輸入。
讓我們以 Axpby 原語重新實作該運算。
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
此運算現在會處理以下事項:
將輸入升型並決定輸出資料型別。
廣播輸入並決定輸出形狀。
使用給定的 stream、
alpha與beta建立Axpby原語。使用原語與輸入建立輸出
array。
實作原語#
單獨呼叫該運算時不會進行計算。運算只會建立計算圖。當我們對輸出陣列求值時,MLX 會排程計算圖的執行,並根據使用者指定的 stream/裝置呼叫 Axpby::eval_cpu() 或 Axpby::eval_gpu()。
警告
當呼叫 Primitive::eval_cpu() 或 Primitive::eval_gpu() 時,輸出陣列尚未分配記憶體。因此需要由這些函式的實作來視需求配置記憶體。
實作 CPU 後端#
讓我們先實作 Axpby::eval_cpu()。
這個方法會遍歷輸出陣列的每個元素,找出 x 與 y 對應的輸入元素,並逐元素執行運算。這個邏輯在範本函式 axpby_impl() 中實作。
template <typename T>
void axpby_impl(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
我們的實作應適用於所有輸入的浮點陣列。因此我們為 float32、float16、bfloat16 與 complex64 加入分派。若遇到非預期型別就會拋出錯誤。
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
到這裡就足夠在 CPU stream 上執行 axpby() 運算!若你不打算在 GPU 上執行此運算,或不使用包含 Axpby 的計算圖轉換,可以在此停止實作原語。
實作 GPU 後端#
Apple silicon 裝置透過 Metal 著色語言存取 GPU,而 MLX 的 GPU 內核則使用 Metal 撰寫。
備註
如果你剛接觸 Metal,以下資源會很有幫助:
Metal 計算管線導覽:Metal Example
Metal 著色語言文件:Metal Specification
從 C++ 使用 Metal:Metal-cpp
讓 GPU 內核保持簡單。我們會啟動與輸出元素數量相同的執行緒。每個執行緒會從 x 與 y 取出需要的元素,進行逐點運算,並更新其負責的輸出元素。
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
接著需要為所有浮點型別實例化此範本,並為每個實例指定唯一的 host name 以便識別。
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
用於決定內核、設定輸入、解析網格維度與分派到 GPU 的邏輯,都在 Axpby::eval_gpu() 中,如下所示。
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::stream kname;
kname = "axpby_general_" + type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
現在我們可以在 CPU 和 GPU 上呼叫 axpby() 運算了!
在繼續之前,先留意關於 MLX 與 Metal 的幾點:MLX 會追蹤目前的 command_buffer 與其對應的 MTLCommandBuffer。我們會依賴 d.get_command_encoder() 取得目前使用中的 Metal 計算命令編碼器,而不是自行建立新的編碼器並在最後呼叫 compute_encoder->end_encoding()。MLX 會持續將內核(計算管線)加入目前的命令緩衝區,直到達到指定上限或命令緩衝區需要為同步而清空。
原語轉換#
接下來,我們將為 Primitive 的轉換加入實作。這些轉換可以建構在其他運算之上,包括我們剛剛定義的運算:
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
注意,轉換不需要完整定義即可開始使用 Primitive。
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
}
建置與綁定#
先看看整體目錄結構。
extensions/axpby/定義 C++ 擴充程式庫extensions/mlx_sample_extensions設定對應的 Python 軟體包結構extensions/bindings.cpp提供我們運算的 Python 綁定extensions/CMakeLists.txt包含建置程式庫與 Python 綁定的 CMake 規則extensions/setup.py包含使用setuptools建置與安裝 Python 軟體包的規則
綁定到 Python#
我們使用 nanobind 為 C++ 程式庫建立 Python API。由於 mlx.core.array、mlx.core.stream 等元件已提供綁定,因此加入 axpby() 很簡單。
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)");
}
上述範例的大部分複雜度來自額外的裝飾細節,例如字面名稱與文件字串。
警告
必須先匯入 mlx.core,再匯入以上 nanobind 模組定義的 mlx_sample_extensions,以確保 mlx.core 元件(如 mlx.core.array)的型別轉換器可用。
使用 CMake 建置#
建置 C++ 擴充程式庫只需要 find_package(MLX CONFIG),接著連結到你的程式庫即可。
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
我們也需要建置附帶的 Metal 程式庫。為了方便,我們提供 mlx_build_metallib() 函式,可根據來源碼、標頭、目的地等資訊建置 .metallib 目標(定義於 cmake/extension.cmake,並在 MLX 軟體包中自動匯入)。
實際用法如下:
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
最後,我們建置 nanobind 綁定
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
使用 setuptools 建置#
在依照上述說明設定 CMake 建置規則之後,我們可以使用 mlx.extension 中定義的建置工具:
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
zip_safe=False,
python_requires=">=3.8",
)
備註
我們將 extensions/mlx_sample_extensions 視為軟體包目錄,即使它只包含 __init__.py,也能確保以下事項:
必須先匯入
mlx.core,再匯入_extC++ 擴充程式庫與 Metal 程式庫會與 Python 綁定放在同一位置,並在安裝軟體包時一併複製
要建置此軟體包,請先使用 pip install -r requirements.txt 安裝建置相依。接著可在 extensions/ 內使用 python setup.py build_ext -j8 --inplace 進行就地建置以供開發。
結果會產生以下目錄結構:
當你在 extensions/ 中使用 python -m pip install . 安裝時,軟體包會以 extensions/mlx_sample_extensions 相同的結構安裝,且因為它們被指定為 package_data,所以 C++ 與 Metal 程式庫會與 Python 綁定一併複製。
使用方式#
依照上述方式安裝擴充後,你應該可以直接匯入 Python 軟體包,像使用其他 MLX 運算一樣進行測試。
讓我們看看簡單腳本及其結果:
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
輸出:
c shape: [3, 4]
c dtype: float32
c is correct: True
結果#
讓我們跑一個簡單的效能測試,看看新的 axpby 運算與一開始定義的簡單 simple_axpby() 相比如何。
import mlx.core as mx
from mlx_sample_extensions import axpby
import time
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 4096
N = 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0
mx.eval(x, y)
def bench(f):
# Warm up
for i in range(5):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.perf_counter()
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.perf_counter()
return 1000 * (e - s) / 100
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
結果為 Simple axpby: 1.559 ms | Custom axpby: 0.774 ms。我們立刻看到適度的提升!
這個運算現在可以用來建構其他運算、用於 mlx.nn.Module 呼叫,也能作為 grad() 等圖轉換的一部分。
腳本#
下載程式碼
完整範例程式碼可在 mlx 取得。