You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

759 lines
24 KiB

"""
Python polyfills for torch.utils.pytree
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING, TypeVar
import optree
import optree._C
import optree.utils
from optree import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
namedtuple_fields,
structseq_fields,
)
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
from ..decorators import substitute_in_graph
if TYPE_CHECKING:
import builtins
from collections.abc import Callable, Iterable, Mapping
from typing_extensions import Self, TypeIs
from torch.utils._cxx_pytree import PyTree
# Names exported by this module.
__all__ = [
# Helpers imported from optree at the top of this file.
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
# Tree-spec constructors and tree operations defined in this file.
"treespec_leaf",
"treespec_tuple",
"treespec_dict",
"tree_is_leaf",
"tree_iter",
"tree_leaves",
"tree_flatten",
"tree_flatten_with_path",
"tree_structure",
"tree_unflatten",
]
# Generic type variables used in annotations further down:
# `_T` for arbitrary items, `_KT` / `_VT` for dictionary keys / values.
_T = TypeVar("_T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
@substitute_in_graph(
    optree._C.is_dict_insertion_ordered,
    can_constant_fold_through=True,
)
def _(*args: Any, **kwargs: Any) -> bool:
    """Placeholder registered for ``optree._C.is_dict_insertion_ordered``.

    The body never produces a value: it unconditionally raises, because the
    registration is marked ``can_constant_fold_through=True`` and the original
    function is expected to be used on the constant-fold path instead.

    Raises:
        ValueError: always, if this placeholder is ever invoked directly.
    """
    # In namespace 'torch', the dictionary is always traversed in insertion order,
    # so the function being replaced returns True.
    message = (
        "Should not be called directly "
        "because the original function will be called in the constant fold path."
    )
    raise ValueError(message)
# Rebind each optree helper listed below in this module's globals to the result
# of registering its `__python_implementation__` attribute through
# `substitute_in_graph` (constant folding allowed).
# `__name` is pre-bound so the trailing `del __name` is valid even if the loop
# body never runs; the per-iteration temporaries are deleted inside the loop.
__name = ""
for __name, __func in (
    ("is_namedtuple", is_namedtuple),
    ("is_namedtuple_class", is_namedtuple_class),
    ("is_namedtuple_instance", is_namedtuple_instance),
    ("is_structseq", is_structseq),
    ("is_structseq_class", is_structseq_class),
    ("is_structseq_instance", is_structseq_instance),
    ("namedtuple_fields", namedtuple_fields),
    ("structseq_fields", structseq_fields),
):
    __register = substitute_in_graph(
        __func,  # type: ignore[arg-type]
        can_constant_fold_through=True,
    )
    globals()[__name] = __register(
        __func.__python_implementation__  # type: ignore[attr-defined]
    )
    # Do not leak the loop temporaries into the module namespace.
    del __func, __register
del __name
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)  # type: ignore[arg-type]
def tree_is_leaf(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> bool:
    """Return whether ``tree`` is treated as a leaf.

    Args:
        tree: The object to test.
        is_leaf: Optional predicate; when it returns true, ``tree`` is a leaf.
        none_is_leaf: When true, ``None`` counts as a leaf.
        namespace: Namespace used to look up the node registration.

    Returns:
        ``True`` if ``tree`` is ``None`` with ``none_is_leaf`` set, if ``is_leaf``
        accepts it, or if ``type(tree)`` has no registration in ``namespace``;
        ``False`` otherwise.
    """
    if tree is None and none_is_leaf:
        return True
    if is_leaf is not None and is_leaf(tree):
        return True
    # A type without a registered handler cannot be flattened further.
    registration = optree.register_pytree_node.get(type(tree), namespace=namespace)
    return registration is None
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)  # type: ignore[arg-type]
def tree_iter(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> Iterable[Any]:
    """Lazily yield the leaves of ``tree``.

    The traversal is iterative and depth-first, driven by an explicit stack.

    Args:
        tree: The pytree to walk.
        is_leaf: Optional predicate forwarded to the leaf test and to the
            one-level flatten.
        none_is_leaf: Forwarded as-is to both calls.
        namespace: Forwarded as-is to both calls.

    Yields:
        Each leaf, with the children of every node visited first-to-last.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        if not tree_is_leaf(
            current,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ):
            children, *_ = optree.tree_flatten_one_level(
                current,
                is_leaf=is_leaf,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )
            # Push in reverse so that the first child is popped (visited) first.
            pending.extend(reversed(children))
        else:
            yield current
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)  # type: ignore[arg-type]
def tree_leaves(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> list[Any]:
    """Return the leaves of ``tree`` as a list.

    This is the eager counterpart of ``tree_iter``: all arguments are forwarded
    unchanged and the resulting iterator is materialized.
    """
    leaf_iterator = tree_iter(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return list(leaf_iterator)
class _Asterisk(str):
    """String ``"*"`` whose ``repr`` is printed without surrounding quotes."""

    __slots__ = ()

    def __new__(cls) -> Self:
        # The value is fixed; the constructor takes no arguments.
        return str.__new__(cls, "*")

    def __repr__(self) -> str:
        return "*"  # no quotes


# Single shared marker instance; the class itself is removed from the module
# namespace right after creating it.
_asterisk = _Asterisk()
del _Asterisk
@dataclass(frozen=True)
class PyTreeSpec:
"""Analog for :class:`optree.PyTreeSpec` in Python."""
_children: tuple[PyTreeSpec, ...]
_type: builtins.type | None
_metadata: Any
_entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
none_is_leaf: bool
namespace: str
num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
def __post_init__(self, /) -> None:
if self._type is None:
assert len(self._children) == 0
assert self._metadata is None
assert self._entries == ()
assert self._unflatten_func is None
num_nodes = 1
num_leaves = 1
num_children = 0
else:
assert callable(self._unflatten_func)
num_nodes = 1
num_leaves = 0
for child in self._children:
num_nodes += child.num_nodes
num_leaves += child.num_leaves
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
def __repr__(self, /) -> str:
def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf():
assert treespec.type is None
return _asterisk
assert treespec.type is not None
assert callable(treespec._unflatten_func)
children_representations = [
helper(subspec) for subspec in treespec._children
]
if (
treespec.type in BUILTIN_TYPES
or (treespec.type is type(None) and not self.none_is_leaf)
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
# pyrefly: ignore [bad-return]
return treespec._unflatten_func(
treespec._metadata,
children_representations,
)
return (
f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
f"[{', '.join(children_representations)}])"
)
inner = [
str(helper(self)),
*(["NoneIsLeaf"] if self.none_is_leaf else []),
f"namespace={self.namespace!r}",
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self, /) -> int:
return self.num_leaves
@property
def type(self, /) -> builtins.type | None:
return self._type
def is_leaf(self, /) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def paths(self, /) -> list[tuple[Any, ...]]:
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
if treespec.is_leaf():
paths.append(path_prefix)
return
for entry, subspec in zip(
treespec._entries,
treespec._children,
strict=True,
):
helper(subspec, path_prefix + [entry])
paths: list[list[Any]] = []
helper(self, [])
return [tuple(path) for path in paths]
def accessors(self, /) -> list[optree.PyTreeAccessor]:
def helper(
treespec: PyTreeSpec,
entry_path_prefix: list[optree.PyTreeEntry],
) -> None:
if treespec.is_leaf():
entry_paths.append(entry_path_prefix)
return
node_type = treespec.type
assert node_type is not None
handler = optree.register_pytree_node.get(
node_type, namespace=treespec.namespace
)
assert handler is not None
kind: optree.PyTreeKind = handler.kind
path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
for entry, subspec in zip(
treespec._entries,
treespec._children,
strict=True,
):
helper(
subspec,
entry_path_prefix + [path_entry_type(entry, node_type, kind)],
)
entry_paths: list[list[optree.PyTreeEntry]] = []
helper(self, [])
return [optree.PyTreeAccessor(path) for path in entry_paths]
def children(self, /) -> list[PyTreeSpec]:
return list(self._children)
def child(self, index: int, /) -> PyTreeSpec:
return self._children[index]
def entries(self, /) -> list[Any]:
return list(self._entries)
def entry(self, index: int, /) -> Any:
return self._entries[index]
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
def helper(
treespec: PyTreeSpec,
node: PyTree,
subtrees: list[PyTree],
) -> None:
if treespec.is_leaf():
subtrees.append(node)
return
node_type = type(node)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
raise ValueError(
f"Type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if metadata != treespec._metadata:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
else:
# For builtin dictionary types, we allow some flexibility
# Otherwise, we require exact matches
both_standard_dict = (
treespec.type in STANDARD_DICT_TYPES
and node_type in STANDARD_DICT_TYPES
)
if not both_standard_dict and node_type != treespec.type:
raise ValueError(
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(node) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(node)}.",
)
if both_standard_dict:
# dictionary types are compatible with each other
expected_keys = treespec.entries()
got_key_set = set(node)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
extra_keys = got_key_set.difference(expected_key_set)
message = ""
if missing_keys:
message += f"; missing key(s): {missing_keys}"
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [node[key] for key in expected_keys]
else:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and metadata != treespec._metadata:
raise ValueError(
f"Node metadata mismatch for node type {treespec.type!r}; "
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec._children, strict=True):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
helper(self, tree, subtrees)
return subtrees
def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
raise ValueError(
f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
f"but the spec refers to a pytree that holds {self.num_leaves} "
f"items ({self}).",
)
if self.is_leaf():
return leaves[0]
# Recursively unflatten the children
start = 0
end = 0
subtrees = []
for subspec in self._children:
end += subspec.num_leaves
subtrees.append(subspec.unflatten(leaves[start:end]))
start = end
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
    """Return whether ``obj`` is a ``PyTreeSpec`` or a ``python_pytree.TreeSpec``."""
    accepted_types = (PyTreeSpec, python_pytree.TreeSpec)
    return isinstance(obj, accepted_types)
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_leaf,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_leaf(
    *,
    none_is_leaf: bool = False,
    namespace: str = "",  # unused
) -> PyTreeSpec:
    """Build the spec of a single leaf.

    Args:
        none_is_leaf: Stored on the resulting spec.
        namespace: Accepted but ignored; the resulting spec always carries the
            empty namespace.
    """
    return PyTreeSpec(
        _children=(),
        _type=None,
        _metadata=None,
        _entries=(),
        _unflatten_func=None,
        none_is_leaf=none_is_leaf,
        namespace="",
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_tuple,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_tuple(
    iterable: Iterable[PyTreeSpec] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> PyTreeSpec:
    """Build the spec of a tuple whose elements have the given child specs.

    Args:
        iterable: Child specs, in element order.
        none_is_leaf: Must equal ``none_is_leaf`` of every child.
        namespace: Every child's namespace must be this value or ``""``.

    Raises:
        ValueError: if an item is not a tree spec, or a child disagrees on
            ``none_is_leaf`` or ``namespace``.
    """
    children = tuple(iterable)

    # The three checks run as separate passes so that a type error anywhere in
    # the sequence is reported before any `none_is_leaf` / `namespace` mismatch.
    for child in children:
        if not _is_pytreespec_instance(child):
            raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.")
    for child in children:
        if child.none_is_leaf != none_is_leaf:
            raise ValueError(
                "All children PyTreeSpecs must have the same `none_is_leaf` value "
                f"as the parent; expected {none_is_leaf}, got: {children!r}.",
            )
    for child in children:
        if child.namespace not in (namespace, ""):
            raise ValueError(
                "All children PyTreeSpecs must have the same `namespace` value "
                f"as the parent; expected {namespace!r}, got: {children!r}.",
            )

    tuple_handler = optree.register_pytree_node.get(tuple, namespace=namespace)
    assert tuple_handler is not None
    return PyTreeSpec(
        _children=children,
        _type=tuple,
        _metadata=None,
        # Entries of a tuple are its positional indices.
        _entries=tuple(range(len(children))),
        _unflatten_func=tuple_handler.unflatten_func,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.treespec_dict,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def treespec_dict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Build the spec of a dict whose values have the given child specs.

    Args:
        mapping: Mapping (or iterable of pairs) from key to child spec.
        none_is_leaf: Must equal ``none_is_leaf`` of every child.
        namespace: Every child's namespace must be this value or ``""``.
        **kwargs: Additional ``key=spec`` items merged on top of ``mapping``.

    Raises:
        ValueError: if a value is not a tree spec, or a child disagrees on
            ``none_is_leaf`` or ``namespace``.
    """
    dct = dict(mapping, **kwargs)
    child_specs = tuple(dct.values())

    if not all(_is_pytreespec_instance(child) for child in child_specs):
        raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.")
    if any(child.none_is_leaf != none_is_leaf for child in child_specs):
        raise ValueError(
            "All children PyTreeSpecs must have the same `none_is_leaf` value "
            f"as the parent; expected {none_is_leaf}, got: {dct!r}.",
        )
    if any(child.namespace not in (namespace, "") for child in child_specs):
        raise ValueError(
            "All children PyTreeSpecs must have the same `namespace` value "
            f"as the parent; expected {namespace!r}, got: {dct!r}.",
        )

    # Let optree flatten the dict one level to obtain children, metadata,
    # entries and the matching unflatten function.
    one_level = optree.tree_flatten_one_level(  # type: ignore[assignment,var-annotated]
        dct,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    children, metadata, entries, unflatten_func = one_level
    return PyTreeSpec(
        _children=tuple(children),  # type: ignore[arg-type]
        _type=dict,
        _metadata=metadata,
        _entries=entries,
        _unflatten_func=unflatten_func,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_flatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
    """Flatten ``tree`` into its leaves and a ``PyTreeSpec`` describing it.

    Args:
        tree: The pytree to flatten.
        is_leaf: Optional predicate forwarded to the leaf test and to the
            one-level flatten.
        none_is_leaf: Forwarded, and stored on every spec that is created.
        namespace: Forwarded, and stored on every spec that is created.

    Returns:
        ``(leaves, treespec)`` where ``leaves`` is filled depth-first in child
        order.
    """
    collected: list[Any] = []

    def build_spec(node: PyTree, out: list[Any]) -> PyTreeSpec:
        # Leaves end the recursion: record the value and return a leaf spec.
        if tree_is_leaf(
            node,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ):
            out.append(node)
            return PyTreeSpec(
                _children=(),
                _type=None,
                _metadata=None,
                _entries=(),
                _unflatten_func=None,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )

        one_level = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
        children, metadata, entries, unflatten_func = one_level

        # Recursively flatten the children
        child_specs = tuple([build_spec(child, out) for child in children])
        return PyTreeSpec(
            _children=child_specs,
            _type=type(node),
            _metadata=metadata,
            _entries=entries,
            _unflatten_func=unflatten_func,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )  # type: ignore[arg-type]

    root_spec = build_spec(tree, collected)
    return collected, root_spec
@substitute_in_graph(  # type: ignore[arg-type]
    optree._C.flatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def _C_flatten(
    tree: PyTree,
    /,
    leaf_predicate: Callable[[PyTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
    """Stand-in for ``optree._C.flatten`` that delegates to ``tree_flatten``.

    Unlike ``tree_flatten``, the options here may be passed positionally and the
    predicate is named ``leaf_predicate``; it is forwarded as ``is_leaf``.
    """
    leaves, treespec = tree_flatten(  # type: ignore[return-value]
        tree,
        is_leaf=leaf_predicate,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return leaves, treespec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_flatten_with_path,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten_with_path(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
    """Flatten ``tree`` and also report the path to every leaf.

    Returns:
        ``(paths, leaves, treespec)``; ``paths`` comes from ``treespec.paths()``
        and the other two items from ``tree_flatten``.
    """
    flat_leaves, spec = tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    leaf_paths = spec.paths()
    return leaf_paths, flat_leaves, spec  # type: ignore[return-value]
@substitute_in_graph(  # type: ignore[arg-type]
    optree._C.flatten_with_path,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def _C_flatten_with_path(
    tree: PyTree,
    /,
    leaf_predicate: Callable[[PyTree], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
    """Stand-in for ``optree._C.flatten_with_path``.

    Delegates to ``tree_flatten_with_path``; ``leaf_predicate`` is forwarded as
    ``is_leaf`` and the options may be passed positionally.
    """
    paths, leaves, treespec = tree_flatten_with_path(  # type: ignore[return-value]
        tree,
        is_leaf=leaf_predicate,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return paths, leaves, treespec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_structure,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_structure(
    tree: PyTree,
    /,
    is_leaf: Callable[[PyTree], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> PyTreeSpec:
    """Return only the ``PyTreeSpec`` of ``tree``.

    The tree is flattened with ``tree_flatten`` and the leaves are discarded.
    """
    _leaves, treespec = tree_flatten(  # type: ignore[return-value]
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return treespec
@substitute_in_graph(  # type: ignore[arg-type]
    optree.tree_unflatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
    """Rebuild a pytree from ``treespec`` and ``leaves``.

    Raises:
        TypeError: if ``treespec`` is not a tree spec instance.
    """
    if _is_pytreespec_instance(treespec):
        return treespec.unflatten(leaves)
    raise TypeError(
        f"Expected `treespec` to be an instance of "
        f"PyTreeSpec but got item of type {type(treespec)}."
    )
# Registration optree holds for `NoneType`; its unflatten function is the one
# being substituted below.
_none_registration = optree.register_pytree_node.get(type(None))
assert _none_registration is not None


@substitute_in_graph(  # type: ignore[arg-type]
    _none_registration.unflatten_func,
    can_constant_fold_through=True,
    skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[_T], /) -> None:
    """Rebuild ``None`` from its (necessarily empty) children.

    Raises:
        ValueError: if ``children`` yields any item.
    """
    received = list(children)
    if len(received) != 0:
        raise ValueError("Expected no children.")
    return None
# Look up the `dict` registration while `optree.dict_insertion_ordered(False,
# namespace="torch")` is active; the flatten/unflatten functions of that
# registration are the ones substituted below.
with optree.dict_insertion_ordered(False, namespace="torch"):
    _dict_registration = optree.register_pytree_node.get(dict)
assert _dict_registration is not None


@substitute_in_graph(  # type: ignore[arg-type]
    _dict_registration.flatten_func,
    can_constant_fold_through=True,
    skip_signature_check=True,
)
def dict_flatten(
    dct: dict[_KT, _VT], /
) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]:
    """Flatten ``dct`` one level, with values ordered by sorted key.

    Returns:
        ``(values, (original_keys, sorted_keys), entries)`` where ``values``
        follows ``sorted_keys`` (as produced by
        ``optree.utils.total_order_sorted``), ``original_keys`` preserves the
        iteration order of ``dct``, and ``entries`` is ``tuple(sorted_keys)``.
    """
    keys_as_inserted = list(dct)
    keys_sorted = optree.utils.total_order_sorted(dct)
    ordered_values = [dct[key] for key in keys_sorted]
    # Both key orders are kept in the metadata so that unflattening can restore
    # the original order.
    metadata = (keys_as_inserted, keys_sorted)
    return ordered_values, metadata, tuple(keys_sorted)
@substitute_in_graph(  # type: ignore[arg-type]
    _dict_registration.unflatten_func,
    can_constant_fold_through=True,
    skip_signature_check=True,
)
def dict_unflatten(
    metadata: tuple[list[_KT], list[_KT]],
    values: Iterable[_VT],
    /,
) -> dict[_KT, _VT]:
    """Rebuild a dict from ``(original_keys, sorted_keys)`` and ``values``.

    Args:
        metadata: ``(original_keys, sorted_keys)``; the first fixes the key
            order of the result, the second says which key each value belongs to.
        values: One value per entry of ``sorted_keys``, in the same order.

    Raises:
        ValueError: from ``zip(..., strict=True)`` if ``values`` and
            ``sorted_keys`` differ in length.
    """
    keys_as_inserted, keys_sorted = metadata
    # Seed every key with `None` first so that the result iterates in the
    # original key order, then fill in the real values.
    rebuilt = {key: None for key in keys_as_inserted}
    rebuilt.update(zip(keys_sorted, values, strict=True))
    return rebuilt  # type: ignore[return-value]