
OpTree


Optimized PyTree Utilities.




Installation

Install from PyPI:

pip3 install --upgrade optree

Install from conda-forge:

conda install conda-forge::optree

Install the latest version from GitHub:

pip3 install git+https://github.com/metaopt/optree.git

Or, clone this repo and install manually:

git clone --depth=1 https://github.com/metaopt/optree.git && cd optree

pip3 install .

The following options are available while building the Python C extension from source:

export CMAKE_COMMAND="/path/to/custom/cmake"
export CMAKE_BUILD_TYPE="Debug"
export CMAKE_CXX_STANDARD="20"  # C++17 is tested on Linux/macOS (C++20 is required on Windows)
export OPTREE_CXX_WERROR="OFF"
export _GLIBCXX_USE_CXX11_ABI="1"
export pybind11_DIR="/path/to/custom/pybind11"

pip3 install .

Compiling from source requires Python 3.9+, a C++ compiler (g++ / clang++ / icpx / cl.exe) that supports C++20, and a cmake installation.


PyTrees

A PyTree is a recursive structure that can be an arbitrarily nested Python container (e.g., tuple, list, dict, OrderedDict, namedtuple, etc.) and/or an opaque Python object. The key concepts of tree operations are tree flattening and its inverse (unflattening). Additional tree operations can be built from these two primitives (e.g., tree_map = tree_unflatten ∘ map ∘ tree_flatten).

Tree flattening traverses the entire tree in a left-to-right depth-first manner and returns the leaves in a deterministic order.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_map(lambda x: x**2, tree)
{'b': (4, [9, 16]), 'a': 1, 'c': 25, 'd': 36}

This usually implies that equal pytrees produce equal lists of leaves and the same tree structure. See also section Key Ordering for Dictionaries.

>>> {'a': [1, 2], 'b': [3]} == {'b': [3], 'a': [1, 2]}
True
>>> optree.tree_leaves({'a': [1, 2], 'b': [3]}) == optree.tree_leaves({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_structure({'a': [1, 2], 'b': [3]}) == optree.tree_structure({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_map(lambda x: x**2, {'a': [1, 2], 'b': [3]})
{'a': [1, 4], 'b': [9]}
>>> optree.tree_map(lambda x: x**2, {'b': [3], 'a': [1, 2]})
{'b': [9], 'a': [1, 4]}
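
The composition tree_map = tree_unflatten ∘ map ∘ tree_flatten can be made concrete with a small pure-Python sketch. This is not OpTree's implementation (OpTree is a C++ extension); the names sketch_flatten, sketch_unflatten, and sketch_map are hypothetical, and only tuple, list, and dict are treated as non-leaf nodes.

```python
LEAF = object()  # placeholder that marks a leaf position in the structure


def sketch_flatten(tree):
    """Return (leaves, spec) using a left-to-right depth-first traversal."""
    leaves = []

    def visit(node):
        if type(node) in (tuple, list):
            return (type(node), None, [visit(child) for child in node])
        if type(node) is dict:
            keys = sorted(node)  # key-sorted traversal, so equal dicts give equal leaves
            return (dict, keys, [visit(node[key]) for key in keys])
        leaves.append(node)  # anything else is an opaque leaf
        return LEAF

    spec = visit(tree)
    return leaves, spec


def sketch_unflatten(spec, leaves):
    """Rebuild a tree from a spec and a flat list of leaves."""
    remaining = iter(leaves)

    def build(node_spec):
        if node_spec is LEAF:
            return next(remaining)
        kind, keys, child_specs = node_spec
        children = [build(child_spec) for child_spec in child_specs]
        if kind is dict:
            return dict(zip(keys, children))
        return kind(children)

    return build(spec)


def sketch_map(func, tree):
    """tree_map expressed as unflatten(map(flatten))."""
    leaves, spec = sketch_flatten(tree)
    return sketch_unflatten(spec, [func(leaf) for leaf in leaves])


tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
print(sketch_flatten(tree)[0])             # [1, 2, 3, 4, 5, 6]
print(sketch_map(lambda x: x**2, tree))
```

One simplification to keep in mind: this sketch rebuilds dictionaries in sorted key order, whereas OpTree restores the insertion order of the input dictionary.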

Tip

Since OpTree v0.14.1, the new namespace optree.pytree provides aliases for the optree.tree_* functions. The following examples are equivalent to the ones above:

>>> import optree.pytree as pt
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> pt.flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> pt.flatten(1)
([1], PyTreeSpec(*))
>>> pt.flatten(None)
([], PyTreeSpec(None))
>>> pt.map(lambda x: x**2, tree)
{'b': (4, [9, 16]), 'a': 1, 'c': 25, 'd': 36}
>>> pt.map(lambda x: x**2, {'a': [1, 2], 'b': [3]})
{'a': [1, 4], 'b': [9]}
>>> pt.map(lambda x: x**2, {'b': [3], 'a': [1, 2]})
{'b': [9], 'a': [1, 4]}

Since OpTree v0.16.0, a re-export API optree.pytree.reexport(...) is available to create a new module that exports all the optree.pytree APIs with a given namespace. This is useful for downstream libraries to create their own pytree utilities without passing the namespace argument explicitly.

# foo/__init__.py
import optree
pytree = optree.pytree.reexport(namespace='foo')
del optree

# foo/bar.py
from foo import pytree

@pytree.dataclasses.dataclass
class Bar:
    a: int
    b: float

# User code
>>> import foo.bar  # importing the submodule also binds the package name `foo`

>>> foo.pytree.flatten({'a': 1, 'b': 2, 'c': foo.bar.Bar(3, 4.0)})
(
    [1, 2, 3, 4.0],
    PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo')
)

>>> foo.pytree.functools.reduce(lambda x, y: x * y, {'a': 1, 'b': 2, 'c': foo.bar.Bar(3, 4.0)})
24.0

Tree Nodes and Leaves

A tree is a collection of non-leaf nodes and leaf nodes, where the leaf nodes are opaque objects having no children to flatten. optree.tree_flatten(...) will flatten the tree and return a list of leaf nodes while the non-leaf nodes will be stored in the tree structure specification.

Built-in PyTree Node Types

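OpTree out-of-the-box supports the following Python container types in the global registry:

  • tuple
  • list
  • dict
  • collections.namedtuple and its subclasses
  • collections.OrderedDict
  • collections.defaultdict
  • collections.deque
  • PyStructSequence types created by the C API PyStructSequence_NewType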

These types are considered non-leaf nodes in the tree. Python objects whose type is not registered are treated as leaf nodes. The registry lookup uses the is operator to determine whether the type matches, so subclasses need to be registered explicitly; otherwise, their instances will be treated as leaves. The NoneType is a special case discussed in section None is Non-leaf Node vs. None is Leaf.
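
The difference between an identity-based lookup and an isinstance check is easy to see in plain Python. The sketch below uses a hypothetical REGISTERED_TYPES tuple as a stand-in for the registry; it only illustrates the matching rule and is not how OpTree stores its registrations.

```python
from collections import OrderedDict

# Hypothetical stand-in for the registry: a few registered container types.
REGISTERED_TYPES = (tuple, list, dict, OrderedDict)


def is_node_exact(obj):
    """Exact-type match: the type itself must be registered."""
    return any(type(obj) is registered for registered in REGISTERED_TYPES)


def is_node_isinstance(obj):
    """Subclass-aware match, shown only for contrast."""
    return isinstance(obj, REGISTERED_TYPES)


class MyList(list):
    """A list subclass that was never registered."""


print(is_node_exact([1, 2]))               # True
print(is_node_exact(MyList([1, 2])))       # False: treated as a leaf
print(is_node_isinstance(MyList([1, 2])))  # True
```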

Registering a Container-like Custom Type as Non-leaf Nodes

A container-like Python type can be registered in the type registry with a pair of functions that specify:

  • flatten_func(container) -> (children, metadata, entries): convert an instance of the container type to a (children, metadata, entries) triple, where children is an iterable of subtrees and entries is an iterable of path entries of the container (e.g., indices or keys).
  • unflatten_func(metadata, children) -> container: convert such a pair back to an instance of the container type.

The metadata is some necessary data apart from the children to reconstruct the container, e.g., the keys of the dictionary (the children are values).

The entries are optional: the flatten function may return only a (children, metadata) pair, or return None as the third element. In either case, range(len(children)) (i.e., flat indices) is used as the path entries of the current node. The signature for the flatten function can be one of the following:

  • flatten_func(container) -> (children, metadata, entries)
  • flatten_func(container) -> (children, metadata, None)
  • flatten_func(container) -> (children, metadata)
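
The three accepted return shapes can be reduced to one canonical triple. The helper below, normalize_flatten_result, is a hypothetical name used only for illustration; it sketches the defaulting rule described above rather than OpTree's internal code.

```python
def normalize_flatten_result(result):
    """Turn a (children, metadata[, entries]) result into a full triple."""
    if len(result) == 2:
        children, metadata = result
        entries = None
    else:
        children, metadata, entries = result
    children = list(children)
    if entries is None:
        # Default path entries are flat indices.
        entries = range(len(children))
    return children, metadata, tuple(entries)


# A pair, as returned by a set flattener that sorts its elements.
print(normalize_flatten_result((sorted({3, 1, 2}), None)))
# A triple with explicit entries (attribute names).
print(normalize_flatten_result(([0.01, 0.9], ['lr', 'momentum'], ['lr', 'momentum'])))
```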

The following examples show how to register custom types and use them with tree_flatten and tree_map. Please refer to section Notes about the PyTree Type Registry for more information.

import optree

# Register a custom type with lambda functions
optree.register_pytree_node(
    set,
    lambda s: (sorted(s), None),        # flatten: (set) -> (children, metadata)
    lambda _, children: set(children),  # unflatten: (metadata, children) -> set
    namespace='set',
)

# Register a custom type into a namespace with accessor support
import types

# This can be whatever your container type is.
class MyContainer(types.SimpleNamespace):
    """A simple container type based on SimpleNamespace."""

# (Optional) Define a custom path entry type for your container for accessor support.
# Here we showcase how to define one. In practice, you can use the built-in `optree.GetAttrEntry`.
class MyContainerEntry(optree.PyTreeEntry):
    def __call__(self, obj):
        return getattr(obj, self.entry)

    def codify(self, node=''):
        return f'{node}.{self.entry}'

optree.register_pytree_node(
    MyContainer,
    flatten_func=lambda ct: (                 # flatten: (MyContainer) -> (children, metadata, entries)
        list(vars(ct).values()),
        list(vars(ct).keys()),
        list(vars(ct).keys()),
    ),
    unflatten_func=lambda keys, values: (     # unflatten: (metadata, children) -> MyContainer
        MyContainer(**dict(zip(keys, values)))
    ),
    path_entry_type=MyContainerEntry,
    namespace='mycontainer',
)

>>> tree = {'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}

# Flatten without specifying the namespace
>>> optree.tree_flatten(tree)  # `MyContainer`s are leaf nodes
([MyContainer(lr=0.01, momentum=0.9), 1000], PyTreeSpec({'config': *, 'steps': *}))

# Flatten with the namespace
>>> leaves, treespec = optree.tree_flatten(tree, namespace='mycontainer')
>>> leaves, treespec
(
    [0.01, 0.9, 1000],
    PyTreeSpec(
        {
            'config': CustomTreeNode(MyContainer[['lr', 'momentum']], [*, *]),
            'steps': *
        },
        namespace='mycontainer'
    )
)

# Custom `entries` are defined as attribute names
>>> optree.tree_paths(tree, namespace='mycontainer')
[('config', 'lr'), ('config', 'momentum'), ('steps',)]

# Custom path entry type defines the pytree access behavior
>>> optree.tree_accessors(tree, namespace='mycontainer')
[
    PyTreeAccessor(*['config'].lr, (MappingEntry(key='config', type=<class 'dict'>), MyContainerEntry(entry='lr', type=<class 'MyContainer'>))),
    PyTreeAccessor(*['config'].momentum, (MappingEntry(key='config', type=<class 'dict'>), MyContainerEntry(entry='momentum', type=<class 'MyContainer'>))),
    PyTreeAccessor(*['steps'], (MappingEntry(key='steps', type=<class 'dict'>),))
]

# Unflatten back to a copy of the original object
>>> optree.tree_unflatten(treespec, leaves)
{'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}
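
Conceptually, an accessor is a sequence of path entries applied one after another, each mapping a node to one of its children. The following sketch mimics that idea in plain Python; GetItem, GetAttr, access, and codify are hypothetical names, and the classes are simplified stand-ins rather than OpTree's entry types.

```python
from functools import reduce
from types import SimpleNamespace


class GetItem:
    """Path entry that indexes into a mapping or sequence."""

    def __init__(self, key):
        self.key = key

    def __call__(self, obj):
        return obj[self.key]

    def codify(self, node=''):
        return f'{node}[{self.key!r}]'


class GetAttr:
    """Path entry that reads an attribute."""

    def __init__(self, name):
        self.name = name

    def __call__(self, obj):
        return getattr(obj, self.name)

    def codify(self, node=''):
        return f'{node}.{self.name}'


def access(entries, root):
    """Apply each entry in turn, walking from the root down to a leaf."""
    return reduce(lambda node, entry: entry(node), entries, root)


def codify(entries, root='*'):
    """Render the path as source-like text."""
    return reduce(lambda text, entry: entry.codify(text), entries, root)


tree = {'config': SimpleNamespace(lr=0.01, momentum=0.9), 'steps': 1000}
path = (GetItem('config'), GetAttr('lr'))
print(access(path, tree))  # 0.01
print(codify(path))        # *['config'].lr
```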

Users can also extend the pytree registry by decorating the custom class and defining an instance method __tree_flatten__ and a class method __tree_unflatten__.

from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
    TREE_PATH_ENTRY_TYPE = optree.MappingEntry  # used by accessor APIs

    def __tree_flatten__(self):  # -> (children, metadata, entries)
        reversed_keys = sorted(self.keys(), reverse=True)
        return (
            [self[key] for key in reversed_keys],  # children
            reversed_keys,  # metadata
            reversed_keys,  # entries
        )

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        return cls(zip(metadata, children))

>>> tree = MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))

# Flatten without specifying the namespace
>>> optree.tree_flatten_with_path(tree)  # `MyDict`s are leaf nodes
(
    [()],
    [MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))],
    PyTreeSpec(*)
)

# Flatten with the namespace
>>> optree.tree_flatten_with_path(tree, namespace='mydict')
(
    [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)
>>> optree.tree_flatten_with_accessor(tree, namespace='mydict')
(
    [
        PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyDict'>))),
        PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyDict'>))),
        PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
        PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
    ],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)
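
To show what a class decorator of this kind has to do, here is a minimal pure-Python sketch. The names _REGISTRY, register_class, flatten_one_level, and unflatten_one_level are hypothetical; the sketch only assumes that registrations are keyed by namespace and exact type, and it handles a single level of nesting.

```python
from collections import UserDict

_REGISTRY = {}  # hypothetical: (namespace, cls) -> (flatten, unflatten)


def register_class(namespace):
    """Class decorator that records the class's flatten/unflatten methods."""

    def decorator(cls):
        key = (namespace, cls)
        if key in _REGISTRY:
            raise ValueError(f'{cls.__name__} is already registered in {namespace!r}')
        _REGISTRY[key] = (cls.__tree_flatten__, cls.__tree_unflatten__)
        return cls

    return decorator


def flatten_one_level(obj, namespace):
    flatten, _ = _REGISTRY[(namespace, type(obj))]  # exact-type lookup
    children, metadata, entries = flatten(obj)
    return list(children), metadata, list(entries)


def unflatten_one_level(cls, namespace, metadata, children):
    _, unflatten = _REGISTRY[(namespace, cls)]
    return unflatten(metadata, children)


@register_class(namespace='mydict')
class MyDict(UserDict):
    def __tree_flatten__(self):
        reversed_keys = sorted(self.keys(), reverse=True)
        return [self[key] for key in reversed_keys], reversed_keys, reversed_keys

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        return cls(zip(metadata, children))


children, metadata, entries = flatten_one_level(MyDict(b=4, a=2, c=6), 'mydict')
print(children, metadata)  # [6, 4, 2] ['c', 'b', 'a']
print(unflatten_one_level(MyDict, 'mydict', metadata, children) == MyDict(b=4, a=2, c=6))  # True
```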

Notes about the PyTree Type Registry

There are several key attributes of the pytree type registry:

  1. The type registry is per-interpreter. Registering a custom type affects all modules that use OpTree in the same interpreter. Each interpreter (including subinterpreters) maintains its own registry, while child processes forked via multiprocessing inherit a copy.

Warning

For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.

  2. Duplicate registration is not allowed. Registering the same type in the same namespace a second time raises an error. To update the behavior, first call unregister_pytree_node(cls, namespace=...) and then re-register. Alternatively, register the type under a different namespace.

    Warning

    Any PyTreeSpec objects created before the unregistration still hold a reference to the old registration. Unflattening such a PyTreeSpec will use the old unflatten_func, not the newly registered one.

  3. Built-in types cannot be re-registered. The behavior of the types listed in Built-in PyTree Node Types (e.g., key-sorted traversal for dict and collections.defaultdict) is fixed.

  4. Inherited subclasses are not implicitly registered. The registry lookup uses type(obj) is registered_type rather than isinstance(obj, registered_type). Users need to register the subclasses explicitly. To register all subclasses automatically, use a metaclass or __init_subclass__, for example:

    from collections import UserDict
    
    @optree.register_pytree_node_class(namespace='mydict')
    class MyDict(UserDict):
        TREE_PATH_ENTRY_TYPE = optree.MappingEntry  # used by accessor APIs
    
        def __init_subclass__(cls):  # define this in the base class
            super().__init_subclass__()
            # Register a subclass to namespace 'mydict'
            optree.register_pytree_node_class(cls, namespace='mydict')
    
        def __tree_flatten__(self):  # -> (children, metadata, entries)
            reversed_keys = sorted(self.keys(), reverse=True)
            return (
                [self[key] for key in reversed_keys],  # children
                reversed_keys,  # metadata
                reversed_keys,  # entries
            )
    
        @classmethod
        def __tree_unflatten__(cls, metadata, children):
            return cls(zip(metadata, children))
    
    # Subclasses will be automatically registered in namespace 'mydict'
    class MyAnotherDict(MyDict):
        pass

    >>> tree = MyDict(b=4, a=(2, 3), c=MyAnotherDict({'d': 5, 'f': 6}))
    >>> optree.tree_flatten_with_path(tree, namespace='mydict')
    (
        [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
        [6, 5, 4, 2, 3],
        PyTreeSpec(
            CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]),
            namespace='mydict'
        )
    )
    >>> optree.tree_accessors(tree, namespace='mydict')
    [
        PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyAnotherDict'>))),
        PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyAnotherDict'>))),
        PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
        PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
    ]

  5. Beware of infinite recursion in custom flatten functions. The returned children are recursively flattened and may have the same type as the current node. Ensure your flatten function has a proper termination condition.

    import numpy as np
    
    optree.register_pytree_node(
        np.ndarray,
        # Children are nested lists of Python objects
        lambda array: (np.atleast_1d(array).tolist(), array.ndim == 0),
        lambda scalar, rows: np.asarray(rows) if not scalar else np.asarray(rows[0]),
        namespace='numpy1',
    )
    
    optree.register_pytree_node(
        np.ndarray,
        # Returns a list of `np.ndarray`s without termination condition -> RecursionError!
        lambda array: ([array.ravel()], array.shape),
        lambda shape, children: children[0].reshape(shape),
        namespace='numpy2',
    )
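
The recursion hazard can be reproduced without NumPy. In the sketch below, Box, flatten_with, good_flatten, and bad_flatten are hypothetical names; flatten_with is a toy recursive flattener, not OpTree's traversal. A flatten function whose children are again of the registered type never reaches a leaf.

```python
class Box:
    """Toy container wrapping a single value."""

    def __init__(self, value):
        self.value = value


def flatten_with(flatten_box, tree):
    """Recursively collect leaves, using `flatten_box` for Box nodes."""
    if type(tree) is Box:
        children, _metadata = flatten_box(tree)
        leaves = []
        for child in children:
            leaves.extend(flatten_with(flatten_box, child))
        return leaves
    return [tree]  # any other object is a leaf


def good_flatten(box):
    # The child is the wrapped value, so the recursion bottoms out.
    return [box.value], None


def bad_flatten(box):
    # The child is a fresh Box every time, so there is no termination condition.
    return [Box(box.value)], None


print(flatten_with(good_flatten, Box(3)))  # [3]
try:
    flatten_with(bad_flatten, Box(3))
except RecursionError:
    print('RecursionError')
```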

None is Non-leaf Node vs. None is Leaf

The None object is Python's singleton for "no value" — analogous to null in other languages but also commonly used as a sentinel or implicit return value.

By default, the None object is considered a non-leaf node in the tree with arity 0, i.e., a non-leaf node that has no children. This is like the behavior of an empty tuple. While flattening a tree, it will remain in the tree structure definitions rather than in the leaves list.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(*, NoneIsLeaf))

OpTree provides a keyword argument none_is_leaf to determine whether to consider the None object as a leaf, like other opaque objects. If none_is_leaf=True, the None object will be placed in the leaves list. Otherwise, the None object will remain in the tree structure specification.

>>> import torch

>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters  # a container has None
OrderedDict({
    'weight': Parameter containing:
              tensor([[-0.6677,  0.5209,  0.3295],
                      [-0.4876, -0.3142,  0.1785]], requires_grad=True),
    'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict({
    'weight': tensor([[0., 0., 0.],
                      [0., 0., 0.]]),
    'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
Traceback (most recent call last):
    ...
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType

>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict({
    'weight': tensor([[0., 0., 0.],
                      [0., 0., 0.]]),
    'bias': 0
})
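
The effect of the flag on the leaves list can be sketched in a few lines of plain Python. The function flatten_leaves is a hypothetical name; it handles only tuple, list, and dict, and it treats None as a node with no children unless none_is_leaf is set.

```python
def flatten_leaves(tree, none_is_leaf=False):
    """Collect leaves depth-first; None is an arity-0 node unless none_is_leaf=True."""
    if tree is None:
        return [None] if none_is_leaf else []
    if type(tree) in (tuple, list):
        children = tree
    elif type(tree) is dict:
        children = [tree[key] for key in sorted(tree)]
    else:
        return [tree]
    leaves = []
    for child in children:
        leaves.extend(flatten_leaves(child, none_is_leaf))
    return leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(flatten_leaves(tree))                     # [1, 2, 3, 4, 5]
print(flatten_leaves(tree, none_is_leaf=True))  # [1, 2, 3, 4, None, 5]
```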

Key Ordering for Dictionaries

The built-in Python dictionary (builtins.dict) is a mapping whose leaves are its values. Since Python 3.7, dict is guaranteed to be insertion ordered, but the equality operator (==) ignores key order. To ensure referential transparency — "equal dict" implies "equal ordering of leaves" — the leaves (values) are returned in key-sorted order. The same applies to collections.defaultdict.

>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_flatten({'b': [3], 'a': [1, 2]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))

To preserve insertion order during pytree traversal, use collections.OrderedDict, which considers key order in equality checks:

>>> from collections import OrderedDict
>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict({'a': [*, *], 'b': [*]})))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))

To flatten builtins.dict and collections.defaultdict objects with the insertion order preserved, use the dict_insertion_ordered context manager:

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> with optree.dict_insertion_ordered(True, namespace='some-namespace'):
...     optree.tree_flatten(tree, namespace='some-namespace')
(
    [2, 3, 4, 1, 5],
    PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
)

Since OpTree v0.9.0, the key order of the reconstructed output dictionaries from tree_unflatten is guaranteed to be consistent with the key order of the input dictionaries in tree_flatten.

>>> leaves, treespec = optree.tree_flatten({'b': [3], 'a': [1, 2]})
>>> leaves, treespec
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_unflatten(treespec, leaves)
{'b': [3], 'a': [1, 2]}
>>> optree.tree_map(lambda x: x, {'b': [3], 'a': [1, 2]})
{'b': [3], 'a': [1, 2]}
>>> optree.tree_map(lambda x: x + 1, {'b': [3], 'a': [1, 2]})
{'b': [4], 'a': [2, 3]}

This property is also preserved during serialization/deserialization.

>>> leaves, treespec = optree.tree_flatten({'b': [3], 'a': [1, 2]})
>>> leaves, treespec
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> import pickle
>>> restored_treespec = pickle.loads(pickle.dumps(treespec))
>>> optree.tree_unflatten(treespec, leaves)
{'b': [3], 'a': [1, 2]}
>>> optree.tree_unflatten(restored_treespec, leaves)
{'b': [3], 'a': [1, 2]}
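
One way to get both properties (key-sorted leaves and insertion-ordered output) is to record the original key order alongside the sorted keys. The sketch below does this for a flat dictionary; flatten_dict and unflatten_dict are hypothetical names, and how OpTree actually stores this information inside a PyTreeSpec is an assumption not covered here.

```python
import pickle


def flatten_dict(mapping):
    """Return (leaves, spec); leaves follow sorted key order."""
    insertion_keys = list(mapping)
    sorted_keys = sorted(mapping)
    leaves = [mapping[key] for key in sorted_keys]
    return leaves, (sorted_keys, insertion_keys)


def unflatten_dict(spec, leaves):
    """Rebuild the dictionary in its original insertion order."""
    sorted_keys, insertion_keys = spec
    value_of = dict(zip(sorted_keys, leaves))
    return {key: value_of[key] for key in insertion_keys}


leaves, spec = flatten_dict({'b': 3, 'a': 1})
print(leaves)                        # [1, 3]
print(unflatten_dict(spec, leaves))  # {'b': 3, 'a': 1}

# The recorded order survives a pickle round trip of the spec.
restored_spec = pickle.loads(pickle.dumps(spec))
print(unflatten_dict(restored_spec, leaves))  # {'b': 3, 'a': 1}
```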

Note

The dict keys are not required to be comparable (sortable) or of a single type. Keys are sorted by key=lambda k: k first if possible, otherwise falling back to key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k). This handles most cases.

>>> sorted({1: 2, 1.5: 1}.keys())
[1, 1.5]
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys())
Traceback (most recent call last):
    ...
TypeError: '<' not supported between instances of 'int' and 'str'
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys(), key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k))
[1.5, 1, 'a']
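
The two-step rule can be written as a small helper. sorted_keys is a hypothetical name; the sketch tries a plain sort first and falls back to sorting by the qualified type name and then by the key. It does not handle keys that are incomparable even within a single type, in which case the second sort raises TypeError as well.

```python
def sorted_keys(mapping):
    """Sort keys directly if possible, otherwise group them by qualified type name."""
    keys = list(mapping)
    try:
        return sorted(keys)
    except TypeError:
        return sorted(
            keys,
            key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k),
        )


print(sorted_keys({1: 2, 1.5: 1}))          # [1, 1.5]
print(sorted_keys({'a': 3, 1: 2, 1.5: 1}))  # [1.5, 1, 'a']
```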

Changelog

See CHANGELOG.md.


License

OpTree is released under the Apache License 2.0.

OpTree is based on JAX's implementation of the PyTree utility, with significant refactoring and several improvements. The original licenses can be found at JAX's Apache License 2.0 and TensorFlow's Apache License 2.0.
