Module#
- class Module#
使用 MLX 建構神經網路的基底類別。
mlx.nn.layers中提供的所有層都繼承此類別,你的模型也應如此。Module可以在任意巢狀的 Python list 或 dict 中包含其他Module實例或mlx.core.array實例。之後可使用mlx.nn.Module.parameters()遞迴擷取所有mlx.core.array實例。此外,
Module有可訓練與不可訓練(稱為「凍結」)參數的概念。使用mlx.nn.value_and_grad()時,回傳的梯度僅針對可訓練參數。模組中的所有陣列預設為可訓練,除非透過freeze()加入「凍結」集合。import mlx.core as mx import mlx.nn as nn class MyMLP(nn.Module): def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16): super().__init__() self.in_proj = nn.Linear(in_dims, hidden_dims) self.out_proj = nn.Linear(hidden_dims, out_dims) def __call__(self, x): x = self.in_proj(x) x = mx.maximum(x, 0) return self.out_proj(x) model = MyMLP(2, 1) # All the model parameters are created but since MLX is lazy by # default, they are not evaluated yet. Calling `mx.eval` actually # allocates memory and initializes the parameters. mx.eval(model.parameters()) # Setting a parameter to a new value is as simply as accessing that # parameter and assigning a new array to it. model.in_proj.weight = model.in_proj.weight * 2 mx.eval(model.parameters())
屬性
布林值,表示模型是否處於訓練模式。
模組的狀態字典
方法
Module.apply(map_fn[, filter_fn])使用提供的
map_fn映射所有參數,並立即用映射後的參數更新模組。Module.apply_to_modules(apply_fn)對此實例中的所有模組(包含此實例)套用函式。
回傳此 Module 實例的直接子孫模組。
將模型設為評估模式。
Module.filter_and_map(filter_fn[, map_fn, ...])使用
filter_fn遞迴篩選模組內容,只保留filter_fn回傳 true 的鍵和值。Module.freeze(*[, recurse, keys, strict])凍結 Module 的參數或其中一部分。
回傳不包含其他模組的子模組。
Module.load_weights(file_or_weights[, strict])從
.npz或.safetensors檔案,或清單更新模型權重。回傳此實例中所有模組的清單。
回傳此實例中所有模組及其點記法名稱的清單。
以字典與清單的巢狀結構遞迴回傳此 Module 的所有
mlx.core.array成員。Module.save_weights(file)將模型權重儲存到檔案。
Module.set_dtype(dtype[, predicate])設定模組參數的 dtype。
Module.train([mode])設定模型為訓練或非訓練模式。
以字典與清單的巢狀結構遞迴回傳此 Module 所有未凍結的
mlx.core.array成員。Module.unfreeze(*[, recurse, keys, strict])解除凍結 Module 的參數或其中一部分。
Module.update(parameters[, strict])將此 Module 的參數替換為提供的巢狀字典與清單中的參數。
Module.update_modules(modules[, strict])以巢狀字典與清單中的提供值替換此
Module實例的子模組。