匯出函式#
MLX 提供 API 可將函式匯出到檔案並從檔案匯入。這讓你可以在一個 MLX 前端(例如 Python)編寫計算,並在另一個 MLX 前端(例如 C++)執行。
本指南透過一些範例說明 MLX 匯出 API 的基本用法。完整函式清單請參考 API 文件。
匯出基礎#
先從一個簡單範例開始:
def fun(x, y):
return x + y
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)
要匯出函式,請提供可用來呼叫該函式的範例輸入陣列。資料內容不重要,但陣列的形狀與型別很重要。在上述範例中,我們使用兩個 float32 標量陣列匯出 fun。接著可以匯入函式並執行:
add_fun = mx.import_function("add.mlxfn")
out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)
out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)
# Raises an exception
add_fun(mx.array(1), mx.array(3.0))
# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
請注意對 add_fun 的第三與第四次呼叫會拋出例外,因為輸入的形狀與型別不同於匯出函式時使用的範例輸入。
也請注意,即使原始 fun 只回傳單一輸出陣列,匯入後的函式一律回傳包含一個或多個陣列的元組。
傳給 export_function() 與匯入函式的輸入,可以用可變的位置參數或陣列元組指定:
def fun(x, y):
return x + y
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
# Same as above
mx.export_function("add.mlxfn", fun, (x, y))
imported_fun = mx.import_function("add.mlxfn")
# Ok
out, = imported_fun(x, y)
# Also ok
out, = imported_fun((x, y))
你可以用位置參數或關鍵字參數傳入範例輸入。如果你以關鍵字參數匯出函式,呼叫匯入函式時也必須使用相同的關鍵字參數。
def fun(x, y):
return x + y
# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)
imported_fun = mx.import_function("add.mlxfn")
# Ok
out, = imported_fun(x, y=y)
# Also ok
out, = imported_fun((x,), {"y": y})
# Raises since the keyword argument is missing
out, = imported_fun(x, y)
# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)
匯出模組#
An mlx.nn.Module can be exported with or without the parameters included
in the exported function. Here's an example:
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x):
return model(x)
mx.export_function("model.mlxfn", call, mx.zeros(4))
在上述範例中,匯出了 mlx.nn.Linear 模組。其參數也會儲存在 model.mlxfn 檔案中。
備註
對於匯出函式中封閉的陣列,請特別注意要先將其評估。匯出的計算圖會包含產生這些封閉輸入的計算。
如果上述範例缺少 mx.eval(model.parameters(),匯出的函式就會包含 mlx.nn.Module 參數的隨機初始化。
如果你只想匯出 Module.__call__ 函式而不包含參數,請將參數作為輸入傳給 call 包裝函式:
model = nn.Linear(4, 4)
mx.eval(model.parameters())
def call(x, **params):
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
Exporting with a Callback#
To inspect the exported graph, you can pass a callback instead of a file path
to export_function().
def fun(x):
return x.astype(mx.int32)
def callback(args):
print(args)
mx.export_function(callback, fun, mx.array([1.0, 2.0]))
The argument to the callback (args) is a dictionary which includes a
type field. The possible types are:
"inputs": The ordered positional inputs to the exported function"keyword_inputs": The keyword specified inputs to the exported function"outputs": The ordered outputs of the exported function"constants": Any graph constants"primitives": Inner graph nodes representating the operations
Each type has additional fields in the args dictionary.
無形狀匯出#
與 compile() 相同,函式也能針對動態形狀輸入進行匯出。對 export_function() 或 exporter() 傳入 shapeless=True,即可匯出可用於可變形狀輸入的函式:
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array([-1.0]))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
若 shapeless=False``(預設值),第二次呼叫 ``imported_abs 會因形狀不符而拋出例外。
無形狀匯出的運作方式與無形狀編譯相同,應謹慎使用。詳情請參考 無形狀編譯文件。
匯出多個追蹤#
在某些情況下,函式會因輸入參數不同而建立不同的計算圖。一個簡單的管理方式是針對每組輸入匯出到新檔案。這在許多情況下可行,但若匯出函式包含大量重複的常數資料(例如 mlx.nn.Module 的參數)則可能不夠理想。
MLX 的匯出 API 可透過 exporter() 建立匯出內容管理器,將同一函式的多個追蹤匯出到單一檔案:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
exporter(mx.array(1.0))
exporter(mx.array(1.0), y=mx.array(0.0))
imported_function = mx.import_function("fun.mlxfn")
# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)
# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)
在上述範例中,函式的常數資料(即 constant)只會儲存一次。
匯入函式的轉換#
像 grad()、vmap() 與 compile() 等函式轉換,在匯入函式上也能像一般 Python 函式一樣運作:
def fun(x):
return mx.sin(x)
x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)
imported_fun = mx.import_function("sine.mlxfn")
# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
在 C++ 中匯入函式#
在 C++ 中匯入與執行函式,基本上與在 Python 中匯入與執行相同。首先依照 指示 建立一個使用 MLX 作為程式庫的簡單 C++ 專案。
接著從 Python 匯出一個簡單函式:
def fun(x, y):
return mx.exp(x + y)
x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)
只需幾行程式碼即可在 C++ 中匯入並執行該函式:
auto fun = mx::import_function("fun.mlxfn");
auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
匯入的函式在 C++ 中也可像 Python 一樣進行轉換。在 C++ 中呼叫匯入函式時,位置參數使用 std::vector<mx::array>,關鍵字參數使用 std::map<std::string, mx::array>。
更多範例#
以下是更多完整範例,示範從 Python 匯出較複雜的函式,並在 C++ 中匯入與執行: