Module

目錄

Module#

class Module#

使用 MLX 建構神經網路的基底類別。

mlx.nn.layers 中提供的所有層都繼承此類別,你的模型也應如此。

Module 可以在任意巢狀的 Python list 或 dict 中包含其他 Module 實例或 mlx.core.array 實例。之後可使用 mlx.nn.Module.parameters() 遞迴擷取所有 mlx.core.array 實例。

此外,Module 有可訓練與不可訓練(稱為「凍結」)參數的概念。使用 mlx.nn.value_and_grad() 時,回傳的梯度僅針對可訓練參數。模組中的所有陣列預設為可訓練,除非透過 freeze() 加入「凍結」集合。

import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()

        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simply as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())

屬性

Module.training

布林值,表示模型是否處於訓練模式。

Module.state

模組的狀態字典

方法

Module.apply(map_fn[, filter_fn])

使用提供的 map_fn 映射所有參數,並立即用映射後的參數更新模組。

Module.apply_to_modules(apply_fn)

對此實例中的所有模組(包含此實例)套用函式。

Module.children()

回傳此 Module 實例的直接子孫模組。

Module.eval()

將模型設為評估模式。

Module.filter_and_map(filter_fn[, map_fn, ...])

使用 filter_fn 遞迴篩選模組內容,只保留 filter_fn 回傳 true 的鍵和值。

Module.freeze(*[, recurse, keys, strict])

凍結 Module 的參數或其中一部分。

Module.leaf_modules()

回傳不包含其他模組的子模組。

Module.load_weights(file_or_weights[, strict])

.npz.safetensors 檔案,或清單更新模型權重。

Module.modules()

回傳此實例中所有模組的清單。

Module.named_modules()

回傳此實例中所有模組及其點記法名稱的清單。

Module.parameters()

以字典與清單的巢狀結構遞迴回傳此 Module 的所有 mlx.core.array 成員。

Module.save_weights(file)

將模型權重儲存到檔案。

Module.set_dtype(dtype[, predicate])

設定模組參數的 dtype。

Module.train([mode])

設定模型為訓練或非訓練模式。

Module.trainable_parameters()

以字典與清單的巢狀結構遞迴回傳此 Module 所有未凍結的 mlx.core.array 成員。

Module.unfreeze(*[, recurse, keys, strict])

解除凍結 Module 的參數或其中一部分。

Module.update(parameters[, strict])

將此 Module 的參數替換為提供的巢狀字典與清單中的參數。

Module.update_modules(modules[, strict])

以巢狀字典與清單中的提供值替換此 Module 實例的子模組。