匯出函式#

MLX 提供 API 可將函式匯出到檔案並從檔案匯入。這讓你可以在一個 MLX 前端(例如 Python)編寫計算,並在另一個 MLX 前端(例如 C++)執行。

本指南透過一些範例說明 MLX 匯出 API 的基本用法。完整函式清單請參考 API 文件

匯出基礎#

先從一個簡單範例開始:

def fun(x, y):
  return x + y

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)

要匯出函式,請提供可用來呼叫該函式的範例輸入陣列。資料內容不重要,但陣列的形狀與型別很重要。在上述範例中,我們使用兩個 float32 標量陣列匯出 fun。接著可以匯入函式並執行:

add_fun = mx.import_function("add.mlxfn")

out, = add_fun(mx.array(1.0), mx.array(2.0))
# Prints: array(3, dtype=float32)
print(out)

out, = add_fun(mx.array(1.0), mx.array(3.0))
# Prints: array(4, dtype=float32)
print(out)

# Raises an exception
add_fun(mx.array(1), mx.array(3.0))

# Raises an exception
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

請注意對 add_fun 的第三與第四次呼叫會拋出例外,因為輸入的形狀與型別不同於匯出函式時使用的範例輸入。

也請注意,即使原始 fun 只回傳單一輸出陣列,匯入後的函式一律回傳包含一個或多個陣列的元組。

傳給 export_function() 與匯入函式的輸入,可以用可變的位置參數或陣列元組指定:

def fun(x, y):
  return x + y

x = mx.array(1.0)
y = mx.array(1.0)

# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)

# Same as above
mx.export_function("add.mlxfn", fun, (x, y))

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y)

# Also ok
out, = imported_fun((x, y))

你可以用位置參數或關鍵字參數傳入範例輸入。如果你以關鍵字參數匯出函式,呼叫匯入函式時也必須使用相同的關鍵字參數。

def fun(x, y):
  return x + y

# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y=y)

# Also ok
out, = imported_fun((x,), {"y": y})

# Raises since the keyword argument is missing
out, = imported_fun(x, y)

# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)

匯出模組#

An mlx.nn.Module can be exported with or without the parameters included in the exported function. Here's an example:

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x):
   return model(x)

mx.export_function("model.mlxfn", call, mx.zeros(4))

在上述範例中,匯出了 mlx.nn.Linear 模組。其參數也會儲存在 model.mlxfn 檔案中。

備註

對於匯出函式中封閉的陣列,請特別注意要先將其評估。匯出的計算圖會包含產生這些封閉輸入的計算。

如果上述範例缺少 mx.eval(model.parameters(),匯出的函式就會包含 mlx.nn.Module 參數的隨機初始化。

如果你只想匯出 Module.__call__ 函式而不包含參數,請將參數作為輸入傳給 call 包裝函式:

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x, **params):
  # Set the model's parameters to the input parameters
  model.update(tree_unflatten(list(params.items())))
  return model(x)

params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

Exporting with a Callback#

To inspect the exported graph, you can pass a callback instead of a file path to export_function().

def fun(x):
  return x.astype(mx.int32)

def callback(args):
  print(args)

mx.export_function(callback, fun, mx.array([1.0, 2.0]))

The argument to the callback (args) is a dictionary which includes a type field. The possible types are:

  • "inputs": The ordered positional inputs to the exported function

  • "keyword_inputs": The keyword specified inputs to the exported function

  • "outputs": The ordered outputs of the exported function

  • "constants": Any graph constants

  • "primitives": Inner graph nodes representating the operations

Each type has additional fields in the args dictionary.

無形狀匯出#

compile() 相同,函式也能針對動態形狀輸入進行匯出。對 export_function()exporter() 傳入 shapeless=True,即可匯出可用於可變形狀輸入的函式:

mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")

# Ok
out, = imported_abs(mx.array([-1.0]))

# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))

shapeless=False``(預設值),第二次呼叫 ``imported_abs 會因形狀不符而拋出例外。

無形狀匯出的運作方式與無形狀編譯相同,應謹慎使用。詳情請參考 無形狀編譯文件

匯出多個追蹤#

在某些情況下,函式會因輸入參數不同而建立不同的計算圖。一個簡單的管理方式是針對每組輸入匯出到新檔案。這在許多情況下可行,但若匯出函式包含大量重複的常數資料(例如 mlx.nn.Module 的參數)則可能不夠理想。

MLX 的匯出 API 可透過 exporter() 建立匯出內容管理器,將同一函式的多個追蹤匯出到單一檔案:

def fun(x, y=None):
    constant = mx.array(3.0)
    if y is not None:
      x += y
    return x + constant

with mx.exporter("fun.mlxfn", fun) as exporter:
    exporter(mx.array(1.0))
    exporter(mx.array(1.0), y=mx.array(0.0))

imported_function = mx.import_function("fun.mlxfn")

# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)

# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)

在上述範例中,函式的常數資料(即 constant)只會儲存一次。

匯入函式的轉換#

grad()vmap()compile() 等函式轉換,在匯入函式上也能像一般 Python 函式一樣運作:

def fun(x):
    return mx.sin(x)

x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)

imported_fun = mx.import_function("sine.mlxfn")

# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
# Prints: array(1, dtype=float32)
print(dfdx(x))

# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])

在 C++ 中匯入函式#

在 C++ 中匯入與執行函式,基本上與在 Python 中匯入與執行相同。首先依照 指示 建立一個使用 MLX 作為程式庫的簡單 C++ 專案。

接著從 Python 匯出一個簡單函式:

def fun(x, y):
    return mx.exp(x + y)

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)

只需幾行程式碼即可在 C++ 中匯入並執行該函式:

auto fun = mx::import_function("fun.mlxfn");

auto inputs = {mx::array(1.0), mx::array(1.0)};
auto outputs = fun(inputs);

// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;

匯入的函式在 C++ 中也可像 Python 一樣進行轉換。在 C++ 中呼叫匯入函式時,位置參數使用 std::vector<mx::array>,關鍵字參數使用 std::map<std::string, mx::array>

更多範例#

以下是更多完整範例,示範從 Python 匯出較複雜的函式,並在 C++ 中匯入與執行: