From 7fd557107cc75bb6fc781e322fc3342b9e805d9a Mon Sep 17 00:00:00 2001 From: Alphaharrius Date: Tue, 26 May 2026 02:31:40 +0800 Subject: [PATCH] Establish foundational RG tree API and visualization support Introduce qrg.tree.Node as a first-class tree abstraction for renormalization group workflows, with a coherent set of APIs for growth, navigation, addressing, structural inspection, subtree manipulation, transform access, and visualization. This commit defines the core tree model around parent/child structure and leaf-key addressing, and adds the main operational interfaces needed to work with RG trees in practice: branches(), leaf(), parent(), root(), is_root(), is_leaf(), path(), trace(), cut(), find(), and get_transform(). Together these make the tree usable as a navigable and queryable computational object rather than only a storage container for tensors. The plotting layer is also established here, with tree rendering backends for networkx, matplotlib, and plotly. The implementation supports both structural graph extraction and direct visual inspection of RG trees, including default tree rendering behavior and backend-specific output. Top-level package exports and dependency metadata are updated so the tree API is part of the public package surface and its plotting backends are declared explicitly. Tests, typing, lint, and formatting are brought into alignment with the expanded API. Add a real Plotly tree backend while preserving matplotlib as the default for plot("tree"), and update tests to cover tree traversal, search, branch detachment, path/trace round-tripping, Plotly rendering, and transform composition behavior. Document get_transform() with a warning that inverse/backward use can create very large intermediates, especially for momentum-block tensors, and recommend staged application via composed=False for actual transport workflows. --- .gitignore | 3 + pyproject.toml | 3 + src/qrg/__init__.py | 4 +- src/qrg/py.typed | 0 src/qrg/tree.py | 1335 +++++++++++++++++++++++++++++++++++++++++++ tests/test_tree.py | 712 +++++++++++++++++++++++ 6 files changed, 2056 insertions(+), 1 deletion(-) create mode 100644 src/qrg/py.typed create mode 100644 src/qrg/tree.py create mode 100644 tests/test_tree.py diff --git a/.gitignore b/.gitignore index 4d3f601..77aa175 100644 --- a/.gitignore +++ b/.gitignore @@ -213,3 +213,6 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# ipynb +*.ipynb diff --git a/pyproject.toml b/pyproject.toml index 2785a60..f8b563b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,9 @@ authors = [ { name = "eyjafjallac" }, ] dependencies = [ + "matplotlib>=3.10.7", + "networkx>=3.0", + "plotly>=6.5.0", "qten>=0.4.2", ] [project.optional-dependencies] diff --git a/src/qrg/__init__.py b/src/qrg/__init__.py index 00695fa..c258ffe 100644 --- a/src/qrg/__init__.py +++ b/src/qrg/__init__.py @@ -1,5 +1,7 @@ """qRG package.""" -__all__ = ["__version__"] +from .tree import Node + +__all__ = ["Node", "__version__"] __version__ = "0.1.0" diff --git a/src/qrg/py.typed b/src/qrg/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/qrg/tree.py b/src/qrg/tree.py new file mode 100644 index 0000000..14b1b1c --- /dev/null +++ b/src/qrg/tree.py @@ -0,0 +1,1335 @@ +""" +Tree-building helpers for renormalization workflows. + +This module provides a lightweight `Node` object that stores named +`qten.Tensor` instances together with expansion methods for growing a tree of +derived nodes. Each node behaves like a mapping from string keys to tensors, +and each growth step records a `_LeavePath` describing both the child node and +the transform tensor that connects the current node's target tensor to the +child target tensor. + +The intended workflow is: + +1. Create a root node with [`Node.new`][qrg.tree.Node.new]. +2. Register one or more growth methods that expand a node into child nodes. +3. Call [`Node.grow`][qrg.tree.Node.grow] to materialize leaf paths. +4. Traverse children through [`Node.leaf`][qrg.tree.Node.leaf] and + [`Node.parent`][qrg.tree.Node.parent]. + +Notes +----- +The public growth-method contract is intentionally simple: a growth method +returns `dict[str, Node]`. The returned dictionary key is used as the lookup +key for the new leaf under the current node, while the child node's private +`_target` field determines which tensor inside the child becomes the leaf-path +transform. + +Examples +-------- +Create and grow a minimal tree: + +>>> import qten +>>> from qrg.tree import Node +>>> root_tensor = qten.eye(2) +>>> child_tensor = qten.eye(2) +>>> def grow_once(node: Node) -> dict[str, Node]: +... return { +... "child": Node( +... name="child", +... _target="iso", +... _data={"iso": child_tensor}, +... _parent=None, +... _leaves={}, +... _methods={}, +... ) +... } +>>> root = Node.new(name="root", target=root_tensor, methods={"step": grow_once}) +>>> root.grow("step") +Node(...) +>>> path = root.leaf("child") +>>> path.node.parent() is root +True +""" + +import importlib.util +import re +import textwrap +from collections.abc import Callable, Iterator, Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Self, cast + +import matplotlib.patheffects as pe +import matplotlib.pyplot as plt +import networkx as nx # type: ignore[import-untyped] +import plotly.graph_objects as go # type: ignore[import-untyped] +import qten +from matplotlib.patches import PathPatch +from matplotlib.path import Path as MplPath + +TensorLike = qten.Tensor[Any] + + +@dataclass +class _LeavePath: + """ + Leaf metadata stored on a parent node after growth. + + Parameters + ---------- + node : Node + Child node reached from the parent node. + transform : qten.Tensor + Tensor that maps the parent node's target tensor into the child node's + target tensor basis. + + Notes + ----- + The expected interpretation is + `transform.h(-2, -1) @ target @ transform`, where `target` is the parent + node's target tensor. This class does not perform that contraction itself; + it only stores the bookkeeping needed to do so elsewhere. + + Examples + -------- + A `_LeavePath` is usually created by [`Node.grow`][qrg.tree.Node.grow] + rather than instantiated directly. + """ + + node: "Node" + transform: TensorLike + + +_GrowMethod = Callable[["Node"], dict[str, "Node"]] +_GROWTH_MARKER = "_qrg_growth_name" + + +@dataclass +class Node(Mapping[str, TensorLike], qten.plottings.Plottable): + """ + Tree node containing tensors, growth methods, and derived leaf paths. + + `Node` serves two roles: + + - A mapping from string names to `qten.Tensor` objects. + - A tree vertex that can spawn children through registered growth methods. + + Parameters + ---------- + name : str + Human-readable identifier for the node. + _target : str + Key in `_data` identifying the node's target tensor. This key is used + by the parent node when extracting the transform tensor for a leaf path. + _data : dict[str, qten.Tensor] + Tensor payload owned by the node. + _parent : Node or None + Parent node in the tree. The root node stores `None`. + _leaves : dict[str, _LeavePath] + Cached leaf paths produced by growth methods. + _methods : dict[str, callable] + Registered growth methods. Each method receives the current node and + returns a mapping from leaf lookup names to child nodes. + + Notes + ----- + `Node` inherits from `Mapping`, so `node[key]`, `key in node`, iteration, + and `len(node)` are delegated to `_data`. + + Growth methods are owned by the root node. Registering a method from any + node forwards the registration to the root, and growth-method lookup also + resolves through the root only. + + The tree API distinguishes between: + + - The leaf lookup key in `_leaves`, chosen by the growth method's returned + dictionary key. + - The child node's `_target`, which identifies the tensor inside the child + that becomes the stored `_LeavePath.transform`. + + When a node is pickled, `_methods` is excluded from the serialized state. + After unpickling, the method registry is restored as an empty dictionary + and must be repopulated explicitly by the caller. + + Through [`plot()`][qten.plottings.Plottable.plot], nodes also support a + `tree` plot method with backends `networkx`, `matplotlib`, and `plotly`. + The `networkx` backend returns a directed graph for the whole tree rooted + at `node.root()`, with the current node highlighted in node attributes and + layout positions stored in graph-level metadata. The `matplotlib` and + `plotly` backends render that tree with boxed, wrapped labels and + highlight the current node directly in the figure. Calling + `node.plot("tree")` without an explicit backend defaults to the + `matplotlib` tree renderer. + + Examples + -------- + Build a root node and access its target tensor by mapping key: + + >>> import qten + >>> root = Node.new(name="root", target=qten.eye(2), methods={}) + >>> "root" in root + True + >>> root["root"] + Tensor(...) + """ + + name: str + _target: str + _data: dict[str, TensorLike] + _parent: "Node | None" + _leaves: dict[str, _LeavePath] + _methods: dict[str, _GrowMethod] + + def __getitem__(self, key: str) -> TensorLike: + """ + Return a tensor stored on this node by name. + + Parameters + ---------- + key : str + Tensor name to retrieve from the node payload. + + Returns + ------- + qten.Tensor + Stored tensor associated with `key`. + + Raises + ------ + KeyError + If `key` is not present in `_data`. + """ + return self._data[key] + + def __iter__(self) -> Iterator[str]: + """ + Iterate over stored tensor names. + + Returns + ------- + iterator of str + Iterator over `_data` keys in insertion order. + """ + return iter(self._data) + + def __len__(self) -> int: + """ + Return the number of tensors stored on this node. + + Returns + ------- + int + Number of entries in `_data`. + """ + return len(self._data) + + def plot( + self, + method: str, + backend: str | None = None, + *args: object, + **kwargs: object, + ) -> object: + """ + Dispatch plotting calls, defaulting tree plots to matplotlib output. + + Parameters + ---------- + method : str + Plot method name. + backend : str or None, optional + Backend name requested by the caller. When omitted, tree plots + default to `matplotlib` while other plot methods keep the usual + `plotly` default. + args + Positional arguments forwarded to the selected backend. + kwargs + Keyword arguments forwarded to the selected backend. + + Returns + ------- + object + Backend-specific plot object returned by the selected renderer. + """ + if backend is None or backend == "": + backend = "matplotlib" if method == "tree" else "plotly" + return super().plot(method, backend, *args, **kwargs) + + def __getstate__(self) -> dict[str, object]: + """ + Return pickle state for the node without the method registry. + + Returns + ------- + dict + Instance state with `_methods` replaced by an empty dictionary. + + Notes + ----- + Excluding `_methods` keeps serialization independent of whether the + registered callables are picklable. Callers are expected to restore the + root method registry manually after loading. + """ + state = cast(dict[str, object], self.__dict__.copy()) + state["_methods"] = {} + return state + + def __setstate__(self, state: dict[str, object]) -> None: + """ + Restore node state from pickle data. + + Parameters + ---------- + state : dict + Previously pickled instance state. + + Notes + ----- + `_methods` is always restored to an empty dictionary if it is missing + or was intentionally stripped during pickling. + """ + self.__dict__.update(state) + self._methods = {} + + def register_method(self, method_name: str, method: _GrowMethod) -> Self: + """ + Register or replace a growth method on the root registry. + + Parameters + ---------- + method_name : str + Name used later with [`grow()`][qrg.tree.Node.grow]. + method : callable + Callable receiving this node and returning + `dict[str, Node]`. The returned dictionary keys become leaf lookup + names under this node. + + Returns + ------- + Node + The current node, allowing fluent chaining even when registration + is initiated from a descendant. + + Notes + ----- + Registration is always applied to the root node's `_methods` + dictionary. Registering a method with an existing name replaces the + previous method without warning. + + Examples + -------- + >>> root = Node.new(name="root", target=qten.eye(2), methods={}) + >>> root.register_method("step", lambda node: {}) + Node(...) + """ + self.root()._methods[method_name] = method + return self + + @staticmethod + def growth(name: str) -> Callable[[_GrowMethod], _GrowMethod]: + """ + Mark a function in an external script as a loadable growth method. + + Parameters + ---------- + name : str + Registry name to use when the function is discovered by + [`register_methods()`][qrg.tree.Node.register_methods]. + + Returns + ------- + callable + Decorator that attaches growth-registration metadata to the + function and returns it unchanged. + + Notes + ----- + This decorator does not register the function immediately. It only + marks the function so that [`register_methods()`][qrg.tree.Node.register_methods] + can discover it when loading a script file. + + Examples + -------- + >>> @Node.growth("step") + ... def grow_step(node: Node) -> dict[str, Node]: + ... return {} + """ + + def decorator(method: _GrowMethod) -> _GrowMethod: + setattr(method, _GROWTH_MARKER, name) + return method + + return decorator + + def register_methods(self, path: str) -> Self: + """ + Load and register annotated growth methods from a Python script. + + Parameters + ---------- + path : str + Path to a Python source file. Relative paths are resolved against + the current working directory. + + Returns + ------- + Node + The current node, allowing fluent chaining even when registration + is initiated from a descendant. + + Raises + ------ + FileNotFoundError + If `path` does not point to an existing file. + ImportError + If the script cannot be loaded as a Python module. + Exception + Propagates any exception raised while executing the script. + + Notes + ----- + Only callables marked with [`@Node.growth(...)`][qrg.tree.Node.growth] + are registered. Unannotated functions in the script are ignored. + + Registration is forwarded to the root node's registry in the same way + as [`register_method()`][qrg.tree.Node.register_method]. + + Examples + -------- + >>> root = Node.new(name="root", target=qten.eye(2), methods={}) + >>> root.register_methods("growths.py") + Node(...) + """ + script_path = Path(path) + if not script_path.is_absolute(): + script_path = Path.cwd() / script_path + script_path = script_path.resolve() + + if not script_path.is_file(): + raise FileNotFoundError(f"Growth-method script not found: {script_path}") + + module_name = f"_qrg_growths_{script_path.stem}_{abs(hash(script_path))}" + spec = importlib.util.spec_from_file_location(module_name, script_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load growth-method script: {script_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + for value in vars(module).values(): + growth_name = getattr(value, _GROWTH_MARKER, None) + if growth_name is None: + continue + self.register_method(growth_name, value) + return self + + def target(self) -> TensorLike: + """ + Return the current target tensor for this node. + + Returns + ------- + qten.Tensor + Tensor stored at the node's `_target` key. + + Raises + ------ + KeyError + If `_target` is not present in `_data`. + + Notes + ----- + This is a convenience accessor for `self[self._target]`. + + Examples + -------- + >>> root = Node.new(name="root", target=qten.eye(2), methods={}) + >>> root.target() + Tensor(...) + """ + return self[self._target] + + def compute(self, name: str, f: Callable[["Node"], TensorLike]) -> Self: + """ + Compute and store a derived tensor on this node. + + Parameters + ---------- + name : str + Key under which to store the computed tensor in `_data`. + f : callable + Callable receiving the current node and returning the tensor to + store. + + Returns + ------- + Node + The current node, allowing fluent chaining. + + Raises + ------ + Exception + Propagates any exception raised by `f(node)`. + + Notes + ----- + If `name` already exists in `_data`, its previous value is replaced. + + Examples + -------- + >>> root = Node.new(name="root", target=qten.eye(2), methods={}) + >>> root.compute("copy", lambda node: node.target()) + Node(...) + >>> root["copy"] + Tensor(...) + """ + self._data[name] = f(self) + return self + + def leaf(self, target: str) -> _LeavePath: + """ + Return a previously materialized leaf path by lookup key. + + Parameters + ---------- + target : str + Leaf lookup key produced by a prior call to + [`grow()`][qrg.tree.Node.grow]. + + Returns + ------- + _LeavePath + Stored path containing the child node and its transform tensor. + + Raises + ------ + ValueError + If `target` is not present in `_leaves`. This usually means the + node has not been grown with the relevant method yet, or the + requested leaf name does not exist. + + Examples + -------- + After `grow()` completes, retrieve one child path: + + >>> path = root.leaf("child") + >>> path.node + Node(...) + """ + if target not in self._leaves: + raise ValueError(f"Target {target} not found in leaves of node {self.name}") + return self._leaves[target] + + def branches(self) -> dict[str, _LeavePath]: + """ + Return a shallow copy of the current node's direct branches. + + Returns + ------- + dict[str, _LeavePath] + Mapping from direct leaf lookup keys to stored leaf-path metadata. + + Notes + ----- + The returned dictionary is a shallow copy, so mutating it does not + modify the node's internal `_leaves` mapping. + """ + return dict(self._leaves) + + def cut(self, name: str) -> "Node": + """ + Detach a direct child branch from the current node. + + Parameters + ---------- + name : str + Leaf lookup key under the current node identifying the branch to + detach. + + Returns + ------- + Node + Root node of the detached subtree. + + Raises + ------ + ValueError + If `name` is not present in the current node's leaves. + + Notes + ----- + The returned node becomes the root of its own tree by clearing its + parent link. Only the direct branch selected by `name` is removed from + the current node. + """ + if name not in self._leaves: + raise ValueError(f"Target {name} not found in leaves of node {self.name}") + leaf = self._leaves.pop(name) + leaf.node._parent = None + return leaf.node + + def trace(self, expr: str) -> "Node": + """ + Resolve a node by dot-delimited leaf path. + + Parameters + ---------- + expr : str + Trace expression to resolve. Expressions of the form + `"a.b.c"` are resolved from the root node. Expressions beginning + with `"."`, such as `".a.b"`, are resolved relative to the + current node. The empty absolute path `""` resolves to the root, + and `"."` resolves to the current node. + + Returns + ------- + Node + Node reached by following the requested leaf path. + + Raises + ------ + ValueError + If the trace expression is malformed or cannot be fully resolved. + Errors identify the segment where resolution failed as + `trace cannot be resolve at "a.[b].c"`. + """ + is_relative = expr.startswith(".") + raw_path = expr[1:] if is_relative else expr + current = self if is_relative else self.root() + + if raw_path == "": + return current + + parts = raw_path.split(".") + if not parts or any(part == "" for part in parts): + raise ValueError(f'Invalid trace expression "{expr}"') + for index, part in enumerate(parts): + if part not in current._leaves: + resolved = parts[:index] + unresolved = parts[index + 1 :] + highlighted = ".".join(resolved + [f"[{part}]"] + unresolved) + raise ValueError(f'trace cannot be resolve at "{highlighted}"') + current = current._leaves[part].node + return current + + def path(self) -> str: + """ + Return the dot-delimited absolute leaf-key path from the root. + + Returns + ------- + str + Canonical absolute path from the root to this node. The root + node's path is the empty string `""`. + + Raises + ------ + ValueError + If the node is not reachable from its recorded parent chain using + the parents' leaf mappings. + """ + if self._parent is None: + return "" + + parts: list[str] = [] + current = self + while current._parent is not None: + parent = current._parent + for leaf_name, leaf in parent._leaves.items(): + if leaf.node is current: + parts.append(leaf_name) + current = parent + break + else: + raise RuntimeError( + f"Node {current.name} is not reachable from parent {parent.name}" + ) + parts.reverse() + return ".".join(parts) + + def get_transform(self, expr: str, *, composed: bool = True) -> TensorLike | list[TensorLike]: + """ + Return the transform(s) from the current node to a traced descendant. + + Parameters + ---------- + expr : str + Trace expression interpreted with the same rules as + [`trace()`][qrg.tree.Node.trace]. + composed : bool, default=True + If `True`, return a single composed transform tensor `T`. If + `False`, return the ordered list of direct branch transforms whose + product would form `T`. + + Returns + ------- + qten.Tensor or list[qten.Tensor] + Either the composed transform or the ordered list of transforms + from the current node to the traced node. + + Raises + ------ + ValueError + If `expr` resolves to a node outside the current node's subtree. + + Notes + ----- + For a descendant reached by transforms `[T1, T2, ..., Tn]`, the + composed result is `T1 @ T2 @ ... @ Tn`, so the descendant target is + obtained from the current target as + `T.h(-2, -1) @ target @ T`. + + Using this transform chain in the inverse/backward direction may + create very large intermediate tensors. In particular, for structured + tensor types such as momentum-block tensors, lifting a descendant + operator back toward an ancestor can expand into a large pair-resolved + space. For actual transport computations, staged application with + `composed=False` is usually preferable to materializing one large + composed transform. + """ + destination = self.trace(expr) + source_path = self.path() + destination_path = destination.path() + + if source_path == destination_path: + relative_parts: list[str] = [] + elif source_path == "": + relative_parts = destination_path.split(".") + elif destination_path.startswith(source_path + "."): + relative_parts = destination_path[len(source_path) + 1 :].split(".") + else: + raise ValueError( + f'Trace "{expr}" does not resolve to the current node or its descendants' + ) + + transforms: list[TensorLike] = [] + current = self + for part in relative_parts: + leaf = current.leaf(part) + transforms.append(leaf.transform) + current = leaf.node + + if not composed: + return transforms + + if not transforms: + target = self.target() + return qten.eye(target.dims, device=target.device) + + transform = transforms[0] + for step in transforms[1:]: + transform = transform @ step + return transform + + def find( + self, *, regex: str | None = None, predicate: Callable[["Node"], bool] | None = None + ) -> list["Node"]: + """ + Find nodes in the current subtree by name regex and/or predicate. + + Parameters + ---------- + regex : str, optional + Regular expression matched against each node's `name`. + predicate : callable, optional + Additional boolean filter applied to each visited node. + + Returns + ------- + list[Node] + All matching nodes in depth-first traversal order, starting from + the current node. + + Raises + ------ + ValueError + If neither `regex` nor `predicate` is provided. + + Notes + ----- + When both filters are supplied, a node must satisfy both. + """ + if regex is None and predicate is None: + raise ValueError("find() requires regex and/or predicate") + + pattern = re.compile(regex) if regex is not None else None + matches: list[Node] = [] + stack = [self] + + while stack: + current = stack.pop() + regex_match = pattern.search(current.name) is not None if pattern else True + predicate_match = predicate(current) if predicate else True + if regex_match and predicate_match: + matches.append(current) + + children = [path.node for path in current._leaves.values()] + stack.extend(reversed(children)) + + return matches + + def parent(self) -> "Node": + """ + Return the parent node. + + Returns + ------- + Node + Parent of the current node. + + Raises + ------ + ValueError + If this node is the root and therefore has no parent. + + Notes + ----- + Child parent links are assigned automatically by + [`grow()`][qrg.tree.Node.grow]. + """ + if self._parent is None: + raise ValueError(f"Node {self.name} has no parent") + return self._parent + + def root(self) -> "Node": + """ + Return the root node of the current tree. + + Returns + ------- + Node + Top-most ancestor reachable by following parent links. For the root + node itself, this method returns `self`. + + Notes + ----- + The root owns the tree-wide growth-method registry used by + [`register_method()`][qrg.tree.Node.register_method] and + [`grow()`][qrg.tree.Node.grow]. + + Examples + -------- + >>> root = Node.new(name="root", target=qten.eye(2), methods={}) + >>> root.root() is root + True + """ + root = self + while root._parent is not None: + root = root._parent + return root + + def is_root(self) -> bool: + """ + Return whether this node is the root of its tree. + + Returns + ------- + bool + `True` when the node has no parent, otherwise `False`. + """ + return self._parent is None + + def is_leaf(self) -> bool: + """ + Return whether this node currently has no child branches. + + Returns + ------- + bool + `True` when the node has no direct leaves, otherwise `False`. + """ + return len(self._leaves) == 0 + + def grow(self, method: str) -> Self: + """ + Expand this node using a registered growth method. + + Parameters + ---------- + method : str + Name of a previously registered growth method. + + Returns + ------- + Node + The current node after updating `_leaves`. + + Raises + ------ + ValueError + If `method` has not been registered on this node. + KeyError + If a returned child node does not contain its `_target` key in its + `_data` mapping. In that case the leaf transform cannot be derived. + + Notes + ----- + The growth method returns `dict[str, Node]`. For each returned child: + + - The dictionary key becomes the lookup key in `self._leaves`. + - The transform stored in `_LeavePath` is `child[child._target]`. + - The child node's `_parent` is set to `self`. + - The growth method itself is always resolved from the root node's + `_methods` registry, regardless of which node calls `grow()`. + + If a leaf key already exists, it is overwritten by the latest growth + result with the same name. + + Examples + -------- + >>> root.grow("step") + Node(...) + >>> child_path = root.leaf("child") + >>> child_path.node.parent() is root + True + """ + root = self.root() + if method not in root._methods: + raise ValueError(f"Method {method} not found in node {self.name}") + new_nodes = root._methods[method](self) + new_leaves = { + target: _LeavePath(node=node, transform=node[node._target]) + for target, node in new_nodes.items() + } + for leaf in new_leaves.values(): + leaf.node._parent = self + self._leaves.update(new_leaves) + return self + + @staticmethod + def new(*, name: str, target: TensorLike, methods: dict[str, _GrowMethod]) -> "Node": + """ + Construct a root node from a single target tensor. + + Parameters + ---------- + name : str + Root node name. This also becomes the initial `_target` key. + target : qten.Tensor + Target tensor stored on the root node under `name`. + methods : dict[str, callable] + Initial growth-method registry for the root node. + + Returns + ------- + Node + Newly created root node with no parent and no leaves. + + Notes + ----- + The created node stores: + + - `_target == name` + - `_data == {name: target}` + - `_parent is None` + + Examples + -------- + >>> import qten + >>> root = Node.new(name="hamiltonian", target=qten.eye(2), methods={}) + >>> root.name + 'hamiltonian' + >>> list(root) + ['hamiltonian'] + """ + return Node( + name=name, + _target=name, + _data={name: target}, + _parent=None, + _leaves={}, + _methods=methods, + ) + + +def _tree_layout(root: Node) -> dict[int, tuple[float, float]]: + positions: dict[int, tuple[float, float]] = {} + next_x = 0.0 + + def assign(node: Node, depth: int) -> float: + nonlocal next_x + children = [path.node for path in node._leaves.values()] + if not children: + x = next_x + next_x += 1.0 + else: + child_xs = [assign(child, depth + 1) for child in children] + x = sum(child_xs) / len(child_xs) + positions[id(node)] = (x, -float(depth)) + return x + + assign(root, 0) + return positions + + +def _wrap_node_label(label: str, target_key: str, *, width: int) -> str: + text = f"{label}\n[{target_key}]" + return "\n".join(textwrap.fill(line, width=width) for line in text.splitlines()) + + +def _tree_node_style(*, is_current: bool, is_root: bool) -> dict[str, object]: + if is_current: + return { + "facecolor": "#d1495b", + "edgecolor": "#7f1d1d", + "textcolor": "white", + "linewidth": 2.2, + } + if is_root: + return { + "facecolor": "#edc948", + "edgecolor": "#8c6d1f", + "textcolor": "black", + "linewidth": 2.0, + } + return { + "facecolor": "#d6ebf5", + "edgecolor": "#355070", + "textcolor": "#1f2933", + "linewidth": 1.6, + } + + +@Node.register_plot_method("tree", backend="networkx") # type: ignore[untyped-decorator] +def _plot_node_tree_networkx(node: Node) -> nx.DiGraph: + """ + Build a `networkx` graph representation of the full node tree. + + Parameters + ---------- + node : Node + Current node from which plotting was requested. + + Returns + ------- + networkx.DiGraph + Directed graph containing every node reachable from `node.root()`. + Graph metadata includes: + + - `graph["root"]`: root node of the tree + - `graph["current"]`: node on which `plot()` was invoked + - `graph["positions"]`: hierarchical layout positions keyed by graph + node id + + Graph node ids are `id(node)` for the corresponding `Node` object. + Each graph node stores: + + - `node`: original `Node` instance + - `label`: node name + - `target_key`: node `_target` + - `is_current`: whether this is the requested node + - `is_root`: whether this is the root node + + Each directed edge stores: + + - `target`: leaf lookup key on the parent + - `transform`: corresponding `_LeavePath.transform` + + """ + root = node.root() + positions = _tree_layout(root) + graph = nx.DiGraph() + graph.graph["root"] = root + graph.graph["current"] = node + graph.graph["positions"] = positions + + stack = [root] + while stack: + current = stack.pop() + current_id = id(current) + graph.add_node( + current_id, + node=current, + label=current.name, + target_key=current._target, + is_current=current is node, + is_root=current is root, + ) + for target, path in current._leaves.items(): + child = path.node + child_id = id(child) + graph.add_node( + child_id, + node=child, + label=child.name, + target_key=child._target, + is_current=child is node, + is_root=False, + ) + graph.add_edge(current_id, child_id, target=target, transform=path.transform) + stack.append(child) + + return graph + + +@Node.register_plot_method("tree", backend="matplotlib") # type: ignore[untyped-decorator] +def _( + node: Node, + *, + ax: Any = None, + figsize: tuple[float, float] = (10.0, 6.0), + label_width: int = 18, +) -> Any: + """ + Render the full node tree as a matplotlib figure with wrapped node labels. + + Parameters + ---------- + node : Node + Current node from which plotting was requested. + ax : matplotlib.axes.Axes, optional + Existing axes to draw on. If omitted, a new figure and axes are + created. + figsize : tuple[float, float], default=(10.0, 6.0) + Figure size used when `ax` is not provided. + label_width : int, default=18 + Maximum character width before node labels are wrapped onto new lines. + + Returns + ------- + matplotlib.figure.Figure + Figure containing the rendered tree. + + Notes + ----- + The current node is highlighted in crimson, the root in gold, and all + other nodes in light blue. Labels are wrapped before being placed inside + rounded bounding boxes so long names fit within node shapes more cleanly. + Connectors are drawn as orthogonal rounded-corner arrows so branching + structure reads more like a diagram than a raw graph. + """ + graph = _plot_node_tree_networkx(node) + positions = graph.graph["positions"] + + created_fig = False + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + created_fig = True + else: + fig = ax.figure + + fig.patch.set_facecolor("#f7f3ea") + ax.set_facecolor("#f7f3ea") + + for src, dst in graph.edges: + x0, y0 = positions[src] + x1, y1 = positions[dst] + y_start = y0 - 0.16 + junction_y = y0 - 0.42 + y_end = y1 + 0.04 + + connector_path = MplPath( + [ + (x0, y_start), + (x0, junction_y), + (x1, junction_y), + (x1, y_end), + ], + [ + MplPath.MOVETO, + MplPath.LINETO, + MplPath.LINETO, + MplPath.LINETO, + ], + ) + connector = PathPatch( + connector_path, + facecolor="none", + edgecolor="#6b7a8f", + linewidth=1.8, + capstyle="round", + joinstyle="round", + zorder=2, + alpha=0.9, + ) + connector.set_path_effects([pe.Stroke(linewidth=3.2, foreground="#ebe4d8"), pe.Normal()]) + ax.add_patch(connector) + + for node_id, attrs in graph.nodes(data=True): + x, y = positions[node_id] + style = _tree_node_style( + is_current=attrs["is_current"], + is_root=attrs["is_root"], + ) + + label = _wrap_node_label(attrs["label"], attrs["target_key"], width=label_width) + text = ax.text( + x, + y, + label, + ha="center", + va="center", + color=style["textcolor"], + fontsize=10, + fontweight="bold", + fontfamily="monospace", + bbox={ + "boxstyle": "round,pad=0.55,rounding_size=0.25", + "facecolor": style["facecolor"], + "edgecolor": style["edgecolor"], + "linewidth": style["linewidth"], + }, + zorder=3, + ) + text.set_path_effects([pe.withStroke(linewidth=0.5, foreground=style["facecolor"])]) + + xs = [x for x, _ in positions.values()] + ys = [y for _, y in positions.values()] + ax.set_xlim(min(xs) - 0.95, max(xs) + 0.95) + ax.set_ylim(min(ys) - 0.9, max(ys) + 0.9) + ax.set_axis_off() + ax.set_title( + f"Tree View: {graph.graph['current'].name}", + fontsize=13, + fontweight="bold", + color="#243447", + pad=18, + ) + + if created_fig: + fig.tight_layout() + return fig + + +@Node.register_plot_method("tree", backend="plotly") # type: ignore[untyped-decorator] +def _( + node: Node, + *, + figure: Any = None, + label_width: int = 18, +) -> Any: + """ + Render the full node tree as an interactive Plotly figure. + + Parameters + ---------- + node : Node + Current node from which plotting was requested. + figure : plotly.graph_objects.Figure, optional + Existing figure to populate. If omitted, a new figure is created. + label_width : int, default=18 + Maximum character width before node labels are wrapped onto new lines. + + Returns + ------- + plotly.graph_objects.Figure + Figure containing the rendered tree. + """ + graph = _plot_node_tree_networkx(node) + positions = graph.graph["positions"] + + if figure is None: + fig = go.Figure() + else: + fig = figure + + edge_x: list[float | None] = [] + edge_y: list[float | None] = [] + for src, dst in graph.edges: + x0, y0 = positions[src] + x1, y1 = positions[dst] + y_start = y0 - 0.16 + junction_y = y0 - 0.42 + y_end = y1 + 0.04 + edge_x.extend([x0, x0, x1, x1, None]) + edge_y.extend([y_start, junction_y, junction_y, y_end, None]) + + fig.add_trace( + go.Scatter( + x=edge_x, + y=edge_y, + mode="lines", + line={"color": "#6b7a8f", "width": 2}, + hoverinfo="skip", + showlegend=False, + ) + ) + + node_x: list[float] = [] + node_y: list[float] = [] + hover_text: list[str] = [] + marker_color: list[str] = [] + marker_line_color: list[str] = [] + marker_line_width: list[float] = [] + + for node_id, attrs in graph.nodes(data=True): + x, y = positions[node_id] + style = _tree_node_style( + is_current=attrs["is_current"], + is_root=attrs["is_root"], + ) + label = _wrap_node_label( + cast(str, attrs["label"]), + cast(str, attrs["target_key"]), + width=label_width, + ) + fig.add_annotation( + x=x, + y=y, + text=label.replace("\n", "
"), + showarrow=False, + xanchor="center", + yanchor="middle", + align="center", + font={ + "color": style["textcolor"], + "size": 11, + "family": "Courier New, monospace", + }, + bgcolor=style["facecolor"], + bordercolor=style["edgecolor"], + borderwidth=style["linewidth"], + borderpad=10, + ) + node_x.append(x) + node_y.append(y) + hover_text.append( + f"name: {cast(str, attrs['label'])}
target: {cast(str, attrs['target_key'])}" + ) + marker_color.append(cast(str, style["facecolor"])) + marker_line_color.append(cast(str, style["edgecolor"])) + marker_line_width.append(cast(float, style["linewidth"])) + + fig.add_trace( + go.Scatter( + x=node_x, + y=node_y, + mode="markers", + marker={ + "size": 18, + "color": marker_color, + "line": { + "color": marker_line_color, + "width": marker_line_width, + }, + "opacity": 0.0, + }, + text=hover_text, + hovertemplate="%{text}", + showlegend=False, + ) + ) + + xs = [x for x, _ in positions.values()] + ys = [y for _, y in positions.values()] + fig.update_layout( + title={ + "text": f"Tree View: {graph.graph['current'].name}", + "font": {"size": 18, "color": "#243447"}, + "x": 0.5, + }, + paper_bgcolor="#f7f3ea", + plot_bgcolor="#f7f3ea", + xaxis={ + "range": [min(xs) - 0.95, max(xs) + 0.95], + "visible": False, + }, + yaxis={ + "range": [min(ys) - 0.9, max(ys) + 0.9], + "visible": False, + "scaleanchor": "x", + "scaleratio": 1, + }, + margin={"l": 20, "r": 20, "t": 60, "b": 20}, + ) + return fig diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 0000000..ce09d08 --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,712 @@ +import pickle +from pathlib import Path +from typing import Any, cast + +import pytest + +import qrg.tree as tree_module +from qrg.tree import Node + + +class FakeTensor: + def __init__( + self, + name: str, + dims: tuple[str, str] = ("d0", "d1"), + device: str = "cpu", + ) -> None: + self.name = name + self.dims = dims + self.device = device + + def __matmul__(self, other: "FakeTensor") -> "FakeTensor": + return FakeTensor(f"({self.name}@{other.name})", dims=self.dims, device=self.device) + + def h(self, *_axes: int) -> "FakeTensor": + return FakeTensor(f"{self.name}.h", dims=self.dims, device=self.device) + + +def _tensor_stub() -> Any: + return cast(Any, object()) + + +def test_grow_sets_parent_and_uses_child_target_for_transform() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"external_target_name": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + leaf = root.leaf("external_target_name") + assert leaf.node.parent() is root + assert leaf.transform is child_tensor + + +def test_child_grow_looks_up_methods_from_root() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + grandchild_tensor = _tensor_stub() + + def root_grow(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + def child_grow(_: Node) -> dict[str, Node]: + grandchild = Node( + name="grandchild", + _target="grandchild_transform", + _data={"grandchild_transform": grandchild_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"grandchild": grandchild} + + root = Node.new( + name="root", + target=root_tensor, + methods={"root_grow": root_grow, "child_grow": child_grow}, + ) + root.grow("root_grow") + + child = root.leaf("child").node + child.grow("child_grow") + grandchild = child.leaf("grandchild") + assert grandchild.node.parent() is child + assert grandchild.transform is grandchild_tensor + + +def test_register_method_from_child_updates_root_registry() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + def new_method(_: Node) -> dict[str, Node]: + return {} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + child = root.leaf("child").node + child.register_method("new", new_method) + assert root._methods["new"] is new_method + assert "new" not in child._methods + + +def test_target_returns_tensor_at_current_target_key() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + child = root.leaf("child").node + assert root.target() is root_tensor + assert child.target() is child_tensor + + +def test_compute_stores_derived_tensor_in_data() -> None: + root_tensor = _tensor_stub() + derived_tensor = _tensor_stub() + + root = Node.new(name="root", target=root_tensor, methods={}) + returned = root.compute("derived", lambda node: derived_tensor) + + assert returned is root + assert root["derived"] is derived_tensor + + +def test_cut_detaches_direct_branch_into_new_root() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + child = root.cut("child") + + assert child.name == "child" + assert child.root() is child + with pytest.raises(ValueError, match="Target child not found in leaves of node root"): + root.leaf("child") + with pytest.raises(ValueError, match="Node child has no parent"): + child.parent() + + +def test_cut_raises_for_missing_direct_branch() -> None: + root = Node.new(name="root", target=_tensor_stub(), methods={}) + + with pytest.raises(ValueError, match="Target missing not found in leaves of node root"): + root.cut("missing") + + +def test_branches_returns_shallow_copy_of_direct_leaves() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + branches = root.branches() + + assert set(branches) == {"child"} + assert branches["child"].node is root.leaf("child").node + + branches.pop("child") + assert set(root.branches()) == {"child"} + + +def test_root_returns_topmost_ancestor() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + child = root.leaf("child").node + assert root.root() is root + assert child.root() is root + + +def test_is_root_reflects_parent_link() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + child = root.leaf("child").node + + assert root.is_root() is True + assert child.is_root() is False + + +def test_is_leaf_reflects_whether_node_has_direct_branches() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + + assert root.is_leaf() is True + + root.grow("grow") + child = root.leaf("child").node + + assert root.is_leaf() is False + assert child.is_leaf() is True + + +def test_trace_resolves_absolute_and_relative_paths() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + grandchild_tensor = _tensor_stub() + + def child_grow(_: Node) -> dict[str, Node]: + grandchild = Node( + name="grandchild", + _target="grandchild_transform", + _data={"grandchild_transform": grandchild_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"grandchild": grandchild} + + def root_grow(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new( + name="root", + target=root_tensor, + methods={"root_grow": root_grow, "child_grow": child_grow}, + ) + root.grow("root_grow") + child = root.leaf("child").node + child.grow("child_grow") + grandchild = child.leaf("grandchild").node + + assert root.trace("") is root + assert child.trace("") is root + assert child.trace(".") is child + assert root.trace("child.grandchild") is grandchild + assert child.trace(".grandchild") is grandchild + + +def test_path_returns_absolute_leaf_key_trace() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + grandchild_tensor = _tensor_stub() + + def child_grow(_: Node) -> dict[str, Node]: + grandchild = Node( + name="grandchild", + _target="grandchild_transform", + _data={"grandchild_transform": grandchild_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"grandchild": grandchild} + + def root_grow(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new( + name="root", + target=root_tensor, + methods={"root_grow": root_grow, "child_grow": child_grow}, + ) + root.grow("root_grow") + child = root.leaf("child").node + child.grow("child_grow") + grandchild = child.leaf("grandchild").node + + assert root.path() == "" + assert child.path() == "child" + assert grandchild.path() == "child.grandchild" + assert root.trace(grandchild.path()) is grandchild + + +def test_trace_raises_with_explicit_break_position() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def root_grow(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": root_grow}) + root.grow("grow") + child = root.leaf("child").node + + with pytest.raises(ValueError, match=r'trace cannot be resolve at "child\.\[missing\]\.leaf"'): + root.trace("child.missing.leaf") + with pytest.raises(ValueError, match=r'trace cannot be resolve at "\[missing\]\.leaf"'): + child.trace(".missing.leaf") + with pytest.raises(ValueError, match='Invalid trace expression "\\.child\\."'): + root.trace(".child.") + + +def test_find_filters_current_subtree_by_regex_and_predicate() -> None: + root_tensor = _tensor_stub() + alpha_tensor = _tensor_stub() + beta_tensor = _tensor_stub() + gamma_tensor = _tensor_stub() + + def gamma_grow(_: Node) -> dict[str, Node]: + gamma = Node( + name="gamma-child", + _target="gamma_transform", + _data={"gamma_transform": gamma_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"gamma": gamma} + + def root_grow(_: Node) -> dict[str, Node]: + alpha = Node( + name="alpha-child", + _target="alpha_transform", + _data={"alpha_transform": alpha_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + beta = Node( + name="beta-child", + _target="beta_transform", + _data={"beta_transform": beta_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"alpha": alpha, "beta": beta} + + root = Node.new( + name="root", + target=root_tensor, + methods={"root_grow": root_grow, "gamma_grow": gamma_grow}, + ) + root.grow("root_grow") + alpha = root.leaf("alpha").node + beta = root.leaf("beta").node + alpha.grow("gamma_grow") + gamma = alpha.leaf("gamma").node + + assert root.find(regex="child$") == [alpha, gamma, beta] + assert root.find(predicate=lambda node: node is not root) == [alpha, gamma, beta] + assert root.find(regex="^a", predicate=lambda node: "alpha" in node.name) == [alpha] + assert alpha.find(regex="child$") == [alpha, gamma] + + +def test_find_requires_at_least_one_filter() -> None: + root = Node.new(name="root", target=_tensor_stub(), methods={}) + + with pytest.raises(ValueError, match="find\\(\\) requires regex and/or predicate"): + root.find() + + +def test_get_transform_returns_ordered_list_and_composed_tensor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + root_tensor: Any = FakeTensor("root_target") + child_transform: Any = FakeTensor("child_transform") + grandchild_transform: Any = FakeTensor("grandchild_transform") + + def child_grow(_: Node) -> dict[str, Node]: + grandchild = Node( + name="grandchild", + _target="grandchild_transform", + _data={"grandchild_transform": grandchild_transform}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"grandchild": grandchild} + + def root_grow(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_transform}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new( + name="root", + target=root_tensor, + methods={"root_grow": root_grow, "child_grow": child_grow}, + ) + root.grow("root_grow") + child = root.leaf("child").node + child.grow("child_grow") + + monkeypatch.setattr( + cast(Any, tree_module.qten), # type: ignore[attr-defined] + "eye", + lambda dims, *, device=None: FakeTensor("identity", dims=dims, device=device), + ) + + assert root.get_transform("child.grandchild", composed=False) == [ + child_transform, + grandchild_transform, + ] + assert cast(Any, root.get_transform("child.grandchild")).name == ( + "(child_transform@grandchild_transform)" + ) + assert child.get_transform(".grandchild", composed=False) == [grandchild_transform] + assert cast(Any, child.get_transform(".")).name == "identity" + + +def test_get_transform_rejects_non_descendant_trace() -> None: + root_tensor: Any = FakeTensor("root_target") + left_transform: Any = FakeTensor("left_transform") + right_transform: Any = FakeTensor("right_transform") + + def root_grow(_: Node) -> dict[str, Node]: + left = Node( + name="left", + _target="left_transform", + _data={"left_transform": left_transform}, + _parent=None, + _leaves={}, + _methods={}, + ) + right = Node( + name="right", + _target="right_transform", + _data={"right_transform": right_transform}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"left": left, "right": right} + + root = Node.new(name="root", target=root_tensor, methods={"grow": root_grow}) + root.grow("grow") + left = root.leaf("left").node + + with pytest.raises( + ValueError, + match='Trace "right" does not resolve to the current node or its descendants', + ): + left.get_transform("right") + + +def test_pickle_roundtrip_excludes_methods_registry() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child": child} + + root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method}) + root.grow("grow") + + restored = pickle.loads(pickle.dumps(root)) + restored_child = restored.leaf("child").node + + assert restored._methods == {} + assert restored_child._methods == {} + assert restored_child.parent() is restored + + +def test_register_methods_loads_only_annotated_growths( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + script_path = tmp_path / "growths.py" + script_path.write_text( + "\n".join( + [ + "from qrg.tree import Node", + "", + '@Node.growth("loaded")', + "def loaded_method(node: Node) -> dict[str, Node]:", + " return {}", + "", + "def ignored_method(node: Node) -> dict[str, Node]:", + " return {}", + ] + ) + ) + + monkeypatch.chdir(tmp_path) + root = Node.new(name="root", target=_tensor_stub(), methods={}) + root.register_methods("growths.py") + + assert "loaded" in root._methods + assert root._methods["loaded"].__name__ == "loaded_method" + assert "ignored_method" not in root._methods + + +def test_plot_tree_networkx_returns_whole_tree_with_current_highlight() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + grandchild_tensor = _tensor_stub() + + def child_grow(_: Node) -> dict[str, Node]: + grandchild = Node( + name="grandchild", + _target="grandchild_transform", + _data={"grandchild_transform": grandchild_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"grandchild_leaf": grandchild} + + def root_grow(_: Node) -> dict[str, Node]: + child = Node( + name="child", + _target="child_transform", + _data={"child_transform": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child_leaf": child} + + root = Node.new( + name="root", + target=root_tensor, + methods={"root_grow": root_grow, "child_grow": child_grow}, + ) + root.grow("root_grow") + child = root.leaf("child_leaf").node + child.grow("child_grow") + grandchild = child.leaf("grandchild_leaf").node + + graph = cast(Any, child.plot("tree", backend="networkx")) + root_id = id(root) + child_id = id(child) + grandchild_id = id(grandchild) + + assert set(graph.nodes) == {root_id, child_id, grandchild_id} + assert set(graph.edges) == {(root_id, child_id), (child_id, grandchild_id)} + assert graph.graph["root"] is root + assert graph.graph["current"] is child + assert graph.nodes[root_id]["node"] is root + assert graph.nodes[root_id]["is_root"] is True + assert graph.nodes[root_id]["is_current"] is False + assert graph.nodes[child_id]["is_current"] is True + assert graph.edges[root_id, child_id]["target"] == "child_leaf" + assert graph.edges[child_id, grandchild_id]["target"] == "grandchild_leaf" + assert graph.graph["positions"][root_id][1] == 0.0 + assert graph.graph["positions"][child_id][1] == -1.0 + + +def test_plot_tree_matplotlib_wraps_labels() -> None: + root_tensor = _tensor_stub() + child_tensor = _tensor_stub() + + def grow_method(_: Node) -> dict[str, Node]: + child = Node( + name="child node with a very long label", + _target="a_very_long_target_key_name", + _data={"a_very_long_target_key_name": child_tensor}, + _parent=None, + _leaves={}, + _methods={}, + ) + return {"child_leaf": child} + + root = Node.new( + name="root node with a very long label", + target=root_tensor, + methods={"grow": grow_method}, + ) + root.grow("grow") + + fig = cast(Any, root.plot("tree", backend="matplotlib", label_width=12)) + labels = [text.get_text() for text in fig.axes[0].texts] + + assert any("\n" in label for label in labels) + + +def test_plot_tree_defaults_to_matplotlib_figure() -> None: + root = Node.new(name="root", target=_tensor_stub(), methods={}) + + fig = cast(Any, root.plot("tree")) + + assert hasattr(fig, "axes") + + +def test_plot_tree_plotly_returns_plotly_figure() -> None: + pytest.importorskip("plotly") + + root = Node.new(name="root", target=_tensor_stub(), methods={}) + + fig = cast(Any, root.plot("tree", backend="plotly")) + + assert hasattr(fig, "to_plotly_json") + assert fig.layout.title.text == "Tree View: root" + assert len(fig.layout.annotations) == 1