mlx.utils.tree_map
- tree_map(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None) -> Any
Applies fn to the leaves of the Python tree given as tree and returns a new collection with the results.

If rest is provided, every item is assumed to be a superset of tree and the corresponding leaves are provided as extra positional arguments to fn. In that respect, tree_map() is closer to itertools.starmap() than to map().

The keyword argument is_leaf decides what constitutes a leaf from tree, similar to tree_flatten().

```python
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x * x, model.parameters()))
```
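To make the roles of rest and is_leaf more concrete, the following is a simplified pure-Python sketch of the traversal described above. It is not MLX's source code: tree_map_sketch is a hypothetical stand-in, it assumes trees are built only from lists, tuples, and dicts, and it uses plain integers in place of arrays so it runs without MLX installed.

```python
from itertools import starmap


def tree_map_sketch(fn, tree, *rest, is_leaf=None):
    """Hypothetical stand-in that mimics the documented behavior of tree_map."""
    # A node accepted by is_leaf is passed to fn whole, even if it is a container.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)
    # Lists and tuples: recurse by index, taking the same index from each extra tree.
    if isinstance(tree, (list, tuple)):
        return type(tree)(
            tree_map_sketch(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
            for i, child in enumerate(tree)
        )
    # Dicts: recurse by key; keys present only in the extra trees are ignored.
    if isinstance(tree, dict):
        return {
            k: tree_map_sketch(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
            for k, child in tree.items()
        }
    # Anything else is a leaf.
    return fn(tree, *rest)


params = {"weight": [10, 20], "bias": 30}
grads = {"weight": [1, 2], "bias": 3, "unused": 9}  # a superset of params

# Matching leaves of params and grads are passed to fn as positional arguments.
updated = tree_map_sketch(lambda p, g: p - 2 * g, params, grads)

# The flat analogue: starmap unpacks each tuple into positional arguments.
flat = list(starmap(lambda p, g: p - 2 * g, [(10, 1), (20, 2), (30, 3)]))

# With is_leaf, whole tuples are treated as leaves instead of being traversed.
shapes = {"weight": (10, 10), "bias": (10,)}
ranks = tree_map_sketch(len, shapes, is_leaf=lambda x: isinstance(x, tuple))

print(updated)  # {'weight': [8, 16], 'bias': 24}
print(flat)     # [8, 16, 24]
print(ranks)    # {'weight': 2, 'bias': 1}
```

Without the is_leaf argument in the last call, the sketch would descend into each tuple and call len on individual integers, which raises a TypeError. That is the kind of situation the is_leaf hook is meant to handle.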
- Parameters:
fn (callable) -- The function that processes the leaves of the tree.
tree (Any) -- The main Python tree that will be iterated upon.
rest (tuple[Any]) -- Extra trees to be iterated together with tree.
is_leaf (callable, optional) -- An optional callable that returns True if the passed object is considered a leaf or False otherwise.
- Returns:
A Python tree with the new values returned by fn.