diff --git a/.gitignore b/.gitignore
index 4d3f601..77aa175 100644
--- a/.gitignore
+++ b/.gitignore
@@ -213,3 +213,6 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
+
+# ipynb
+*.ipynb
diff --git a/pyproject.toml b/pyproject.toml
index 2785a60..f8b563b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,9 @@ authors = [
{ name = "eyjafjallac" },
]
dependencies = [
+ "matplotlib>=3.10.7",
+ "networkx>=3.0",
+ "plotly>=6.5.0",
"qten>=0.4.2",
]
[project.optional-dependencies]
diff --git a/src/qrg/__init__.py b/src/qrg/__init__.py
index 00695fa..c258ffe 100644
--- a/src/qrg/__init__.py
+++ b/src/qrg/__init__.py
@@ -1,5 +1,7 @@
"""qRG package."""
-__all__ = ["__version__"]
+from .tree import Node
+
+__all__ = ["Node", "__version__"]
__version__ = "0.1.0"
diff --git a/src/qrg/py.typed b/src/qrg/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/src/qrg/tree.py b/src/qrg/tree.py
new file mode 100644
index 0000000..14b1b1c
--- /dev/null
+++ b/src/qrg/tree.py
@@ -0,0 +1,1335 @@
+"""
+Tree-building helpers for renormalization workflows.
+
+This module provides a lightweight `Node` object that stores named
+`qten.Tensor` instances together with expansion methods for growing a tree of
+derived nodes. Each node behaves like a mapping from string keys to tensors,
+and each growth step records a `_LeavePath` describing both the child node and
+the transform tensor that connects the current node's target tensor to the
+child target tensor.
+
+The intended workflow is:
+
+1. Create a root node with [`Node.new`][qrg.tree.Node.new].
+2. Register one or more growth methods that expand a node into child nodes.
+3. Call [`Node.grow`][qrg.tree.Node.grow] to materialize leaf paths.
+4. Traverse children through [`Node.leaf`][qrg.tree.Node.leaf] and
+ [`Node.parent`][qrg.tree.Node.parent].
+
+Notes
+-----
+The public growth-method contract is intentionally simple: a growth method
+returns `dict[str, Node]`. The returned dictionary key is used as the lookup
+key for the new leaf under the current node, while the child node's private
+`_target` field determines which tensor inside the child becomes the leaf-path
+transform.
+
+Examples
+--------
+Create and grow a minimal tree:
+
+>>> import qten
+>>> from qrg.tree import Node
+>>> root_tensor = qten.eye(2)
+>>> child_tensor = qten.eye(2)
+>>> def grow_once(node: Node) -> dict[str, Node]:
+... return {
+... "child": Node(
+... name="child",
+... _target="iso",
+... _data={"iso": child_tensor},
+... _parent=None,
+... _leaves={},
+... _methods={},
+... )
+... }
+>>> root = Node.new(name="root", target=root_tensor, methods={"step": grow_once})
+>>> root.grow("step")
+Node(...)
+>>> path = root.leaf("child")
+>>> path.node.parent() is root
+True
+"""
+
+import importlib.util
+import re
+import textwrap
+from collections.abc import Callable, Iterator, Mapping
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Self, cast
+
+import matplotlib.patheffects as pe
+import matplotlib.pyplot as plt
+import networkx as nx # type: ignore[import-untyped]
+import plotly.graph_objects as go # type: ignore[import-untyped]
+import qten
+from matplotlib.patches import PathPatch
+from matplotlib.path import Path as MplPath
+
+TensorLike = qten.Tensor[Any]
+
+
+@dataclass
+class _LeavePath:
+ """
+ Leaf metadata stored on a parent node after growth.
+
+ Parameters
+ ----------
+ node : Node
+ Child node reached from the parent node.
+ transform : qten.Tensor
+ Tensor that maps the parent node's target tensor into the child node's
+ target tensor basis.
+
+ Notes
+ -----
+ The expected interpretation is
+ `transform.h(-2, -1) @ target @ transform`, where `target` is the parent
+ node's target tensor. This class does not perform that contraction itself;
+ it only stores the bookkeeping needed to do so elsewhere.
+
+ Examples
+ --------
+ A `_LeavePath` is usually created by [`Node.grow`][qrg.tree.Node.grow]
+ rather than instantiated directly.
+ """
+
+ node: "Node"
+ transform: TensorLike
+
+
+_GrowMethod = Callable[["Node"], dict[str, "Node"]]
+_GROWTH_MARKER = "_qrg_growth_name"
+
+
+@dataclass
+class Node(Mapping[str, TensorLike], qten.plottings.Plottable):
+ """
+ Tree node containing tensors, growth methods, and derived leaf paths.
+
+ `Node` serves two roles:
+
+ - A mapping from string names to `qten.Tensor` objects.
+ - A tree vertex that can spawn children through registered growth methods.
+
+ Parameters
+ ----------
+ name : str
+ Human-readable identifier for the node.
+ _target : str
+ Key in `_data` identifying the node's target tensor. This key is used
+ by the parent node when extracting the transform tensor for a leaf path.
+ _data : dict[str, qten.Tensor]
+ Tensor payload owned by the node.
+ _parent : Node or None
+ Parent node in the tree. The root node stores `None`.
+ _leaves : dict[str, _LeavePath]
+ Cached leaf paths produced by growth methods.
+ _methods : dict[str, callable]
+ Registered growth methods. Each method receives the current node and
+ returns a mapping from leaf lookup names to child nodes.
+
+ Notes
+ -----
+ `Node` inherits from `Mapping`, so `node[key]`, `key in node`, iteration,
+ and `len(node)` are delegated to `_data`.
+
+ Growth methods are owned by the root node. Registering a method from any
+ node forwards the registration to the root, and growth-method lookup also
+ resolves through the root only.
+
+ The tree API distinguishes between:
+
+ - The leaf lookup key in `_leaves`, chosen by the growth method's returned
+ dictionary key.
+ - The child node's `_target`, which identifies the tensor inside the child
+ that becomes the stored `_LeavePath.transform`.
+
+ When a node is pickled, `_methods` is excluded from the serialized state.
+ After unpickling, the method registry is restored as an empty dictionary
+ and must be repopulated explicitly by the caller.
+
+ Through [`plot()`][qten.plottings.Plottable.plot], nodes also support a
+ `tree` plot method with backends `networkx`, `matplotlib`, and `plotly`.
+ The `networkx` backend returns a directed graph for the whole tree rooted
+ at `node.root()`, with the current node highlighted in node attributes and
+ layout positions stored in graph-level metadata. The `matplotlib` and
+ `plotly` backends render that tree with boxed, wrapped labels and
+ highlight the current node directly in the figure. Calling
+ `node.plot("tree")` without an explicit backend defaults to the
+ `matplotlib` tree renderer.
+
+ Examples
+ --------
+ Build a root node and access its target tensor by mapping key:
+
+ >>> import qten
+ >>> root = Node.new(name="root", target=qten.eye(2), methods={})
+ >>> "root" in root
+ True
+ >>> root["root"]
+ Tensor(...)
+ """
+
+ name: str
+ _target: str
+ _data: dict[str, TensorLike]
+ _parent: "Node | None"
+ _leaves: dict[str, _LeavePath]
+ _methods: dict[str, _GrowMethod]
+
+ def __getitem__(self, key: str) -> TensorLike:
+ """
+ Return a tensor stored on this node by name.
+
+ Parameters
+ ----------
+ key : str
+ Tensor name to retrieve from the node payload.
+
+ Returns
+ -------
+ qten.Tensor
+ Stored tensor associated with `key`.
+
+ Raises
+ ------
+ KeyError
+ If `key` is not present in `_data`.
+ """
+ return self._data[key]
+
+ def __iter__(self) -> Iterator[str]:
+ """
+ Iterate over stored tensor names.
+
+ Returns
+ -------
+ iterator of str
+ Iterator over `_data` keys in insertion order.
+ """
+ return iter(self._data)
+
+ def __len__(self) -> int:
+ """
+ Return the number of tensors stored on this node.
+
+ Returns
+ -------
+ int
+ Number of entries in `_data`.
+ """
+ return len(self._data)
+
+ def plot(
+ self,
+ method: str,
+ backend: str | None = None,
+ *args: object,
+ **kwargs: object,
+ ) -> object:
+ """
+ Dispatch plotting calls, defaulting tree plots to matplotlib output.
+
+ Parameters
+ ----------
+ method : str
+ Plot method name.
+ backend : str or None, optional
+ Backend name requested by the caller. When omitted, tree plots
+ default to `matplotlib` while other plot methods keep the usual
+ `plotly` default.
+ args
+ Positional arguments forwarded to the selected backend.
+ kwargs
+ Keyword arguments forwarded to the selected backend.
+
+ Returns
+ -------
+ object
+ Backend-specific plot object returned by the selected renderer.
+ """
+ if backend is None or backend == "":
+ backend = "matplotlib" if method == "tree" else "plotly"
+ return super().plot(method, backend, *args, **kwargs)
+
+ def __getstate__(self) -> dict[str, object]:
+ """
+ Return pickle state for the node without the method registry.
+
+ Returns
+ -------
+ dict
+ Instance state with `_methods` replaced by an empty dictionary.
+
+ Notes
+ -----
+ Excluding `_methods` keeps serialization independent of whether the
+ registered callables are picklable. Callers are expected to restore the
+ root method registry manually after loading.
+ """
+ state = cast(dict[str, object], self.__dict__.copy())
+ state["_methods"] = {}
+ return state
+
+ def __setstate__(self, state: dict[str, object]) -> None:
+ """
+ Restore node state from pickle data.
+
+ Parameters
+ ----------
+ state : dict
+ Previously pickled instance state.
+
+ Notes
+ -----
+ `_methods` is always restored to an empty dictionary if it is missing
+ or was intentionally stripped during pickling.
+ """
+ self.__dict__.update(state)
+ self._methods = {}
+
+ def register_method(self, method_name: str, method: _GrowMethod) -> Self:
+ """
+ Register or replace a growth method on the root registry.
+
+ Parameters
+ ----------
+ method_name : str
+ Name used later with [`grow()`][qrg.tree.Node.grow].
+ method : callable
+ Callable receiving this node and returning
+ `dict[str, Node]`. The returned dictionary keys become leaf lookup
+ names under this node.
+
+ Returns
+ -------
+ Node
+ The current node, allowing fluent chaining even when registration
+ is initiated from a descendant.
+
+ Notes
+ -----
+ Registration is always applied to the root node's `_methods`
+ dictionary. Registering a method with an existing name replaces the
+ previous method without warning.
+
+ Examples
+ --------
+ >>> root = Node.new(name="root", target=qten.eye(2), methods={})
+ >>> root.register_method("step", lambda node: {})
+ Node(...)
+ """
+ self.root()._methods[method_name] = method
+ return self
+
+ @staticmethod
+ def growth(name: str) -> Callable[[_GrowMethod], _GrowMethod]:
+ """
+ Mark a function in an external script as a loadable growth method.
+
+ Parameters
+ ----------
+ name : str
+ Registry name to use when the function is discovered by
+ [`register_methods()`][qrg.tree.Node.register_methods].
+
+ Returns
+ -------
+ callable
+ Decorator that attaches growth-registration metadata to the
+ function and returns it unchanged.
+
+ Notes
+ -----
+ This decorator does not register the function immediately. It only
+ marks the function so that [`register_methods()`][qrg.tree.Node.register_methods]
+ can discover it when loading a script file.
+
+ Examples
+ --------
+ >>> @Node.growth("step")
+ ... def grow_step(node: Node) -> dict[str, Node]:
+ ... return {}
+ """
+
+ def decorator(method: _GrowMethod) -> _GrowMethod:
+ setattr(method, _GROWTH_MARKER, name)
+ return method
+
+ return decorator
+
+ def register_methods(self, path: str) -> Self:
+ """
+ Load and register annotated growth methods from a Python script.
+
+ Parameters
+ ----------
+ path : str
+ Path to a Python source file. Relative paths are resolved against
+ the current working directory.
+
+ Returns
+ -------
+ Node
+ The current node, allowing fluent chaining even when registration
+ is initiated from a descendant.
+
+ Raises
+ ------
+ FileNotFoundError
+ If `path` does not point to an existing file.
+ ImportError
+ If the script cannot be loaded as a Python module.
+ Exception
+ Propagates any exception raised while executing the script.
+
+ Notes
+ -----
+ Only callables marked with [`@Node.growth(...)`][qrg.tree.Node.growth]
+ are registered. Unannotated functions in the script are ignored.
+
+ Registration is forwarded to the root node's registry in the same way
+ as [`register_method()`][qrg.tree.Node.register_method].
+
+ Examples
+ --------
+ >>> root = Node.new(name="root", target=qten.eye(2), methods={})
+ >>> root.register_methods("growths.py")
+ Node(...)
+ """
+ script_path = Path(path)
+ if not script_path.is_absolute():
+ script_path = Path.cwd() / script_path
+ script_path = script_path.resolve()
+
+ if not script_path.is_file():
+ raise FileNotFoundError(f"Growth-method script not found: {script_path}")
+
+ module_name = f"_qrg_growths_{script_path.stem}_{abs(hash(script_path))}"
+ spec = importlib.util.spec_from_file_location(module_name, script_path)
+ if spec is None or spec.loader is None:
+ raise ImportError(f"Could not load growth-method script: {script_path}")
+
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+
+ for value in vars(module).values():
+ growth_name = getattr(value, _GROWTH_MARKER, None)
+ if growth_name is None:
+ continue
+ self.register_method(growth_name, value)
+ return self
+
+ def target(self) -> TensorLike:
+ """
+ Return the current target tensor for this node.
+
+ Returns
+ -------
+ qten.Tensor
+ Tensor stored at the node's `_target` key.
+
+ Raises
+ ------
+ KeyError
+ If `_target` is not present in `_data`.
+
+ Notes
+ -----
+ This is a convenience accessor for `self[self._target]`.
+
+ Examples
+ --------
+ >>> root = Node.new(name="root", target=qten.eye(2), methods={})
+ >>> root.target()
+ Tensor(...)
+ """
+ return self[self._target]
+
+ def compute(self, name: str, f: Callable[["Node"], TensorLike]) -> Self:
+ """
+ Compute and store a derived tensor on this node.
+
+ Parameters
+ ----------
+ name : str
+ Key under which to store the computed tensor in `_data`.
+ f : callable
+ Callable receiving the current node and returning the tensor to
+ store.
+
+ Returns
+ -------
+ Node
+ The current node, allowing fluent chaining.
+
+ Raises
+ ------
+ Exception
+ Propagates any exception raised by `f(node)`.
+
+ Notes
+ -----
+ If `name` already exists in `_data`, its previous value is replaced.
+
+ Examples
+ --------
+ >>> root = Node.new(name="root", target=qten.eye(2), methods={})
+ >>> root.compute("copy", lambda node: node.target())
+ Node(...)
+ >>> root["copy"]
+ Tensor(...)
+ """
+ self._data[name] = f(self)
+ return self
+
+ def leaf(self, target: str) -> _LeavePath:
+ """
+ Return a previously materialized leaf path by lookup key.
+
+ Parameters
+ ----------
+ target : str
+ Leaf lookup key produced by a prior call to
+ [`grow()`][qrg.tree.Node.grow].
+
+ Returns
+ -------
+ _LeavePath
+ Stored path containing the child node and its transform tensor.
+
+ Raises
+ ------
+ ValueError
+ If `target` is not present in `_leaves`. This usually means the
+ node has not been grown with the relevant method yet, or the
+ requested leaf name does not exist.
+
+ Examples
+ --------
+ After `grow()` completes, retrieve one child path:
+
+ >>> path = root.leaf("child")
+ >>> path.node
+ Node(...)
+ """
+ if target not in self._leaves:
+ raise ValueError(f"Target {target} not found in leaves of node {self.name}")
+ return self._leaves[target]
+
+ def branches(self) -> dict[str, _LeavePath]:
+ """
+ Return a shallow copy of the current node's direct branches.
+
+ Returns
+ -------
+ dict[str, _LeavePath]
+ Mapping from direct leaf lookup keys to stored leaf-path metadata.
+
+ Notes
+ -----
+ The returned dictionary is a shallow copy, so mutating it does not
+ modify the node's internal `_leaves` mapping.
+ """
+ return dict(self._leaves)
+
+ def cut(self, name: str) -> "Node":
+ """
+ Detach a direct child branch from the current node.
+
+ Parameters
+ ----------
+ name : str
+ Leaf lookup key under the current node identifying the branch to
+ detach.
+
+ Returns
+ -------
+ Node
+ Root node of the detached subtree.
+
+ Raises
+ ------
+ ValueError
+ If `name` is not present in the current node's leaves.
+
+ Notes
+ -----
+ The returned node becomes the root of its own tree by clearing its
+ parent link. Only the direct branch selected by `name` is removed from
+ the current node.
+ """
+ if name not in self._leaves:
+ raise ValueError(f"Target {name} not found in leaves of node {self.name}")
+ leaf = self._leaves.pop(name)
+ leaf.node._parent = None
+ return leaf.node
+
+ def trace(self, expr: str) -> "Node":
+ """
+ Resolve a node by dot-delimited leaf path.
+
+ Parameters
+ ----------
+ expr : str
+ Trace expression to resolve. Expressions of the form
+ `"a.b.c"` are resolved from the root node. Expressions beginning
+ with `"."`, such as `".a.b"`, are resolved relative to the
+ current node. The empty absolute path `""` resolves to the root,
+ and `"."` resolves to the current node.
+
+ Returns
+ -------
+ Node
+ Node reached by following the requested leaf path.
+
+ Raises
+ ------
+ ValueError
+ If the trace expression is malformed or cannot be fully resolved.
+ Errors identify the segment where resolution failed as
+ `trace cannot be resolve at "a.[b].c"`.
+ """
+ is_relative = expr.startswith(".")
+ raw_path = expr[1:] if is_relative else expr
+ current = self if is_relative else self.root()
+
+ if raw_path == "":
+ return current
+
+ parts = raw_path.split(".")
+ if not parts or any(part == "" for part in parts):
+ raise ValueError(f'Invalid trace expression "{expr}"')
+ for index, part in enumerate(parts):
+ if part not in current._leaves:
+ resolved = parts[:index]
+ unresolved = parts[index + 1 :]
+ highlighted = ".".join(resolved + [f"[{part}]"] + unresolved)
+ raise ValueError(f'trace cannot be resolve at "{highlighted}"')
+ current = current._leaves[part].node
+ return current
+
+ def path(self) -> str:
+ """
+ Return the dot-delimited absolute leaf-key path from the root.
+
+ Returns
+ -------
+ str
+ Canonical absolute path from the root to this node. The root
+ node's path is the empty string `""`.
+
+ Raises
+ ------
+ ValueError
+ If the node is not reachable from its recorded parent chain using
+ the parents' leaf mappings.
+ """
+ if self._parent is None:
+ return ""
+
+ parts: list[str] = []
+ current = self
+ while current._parent is not None:
+ parent = current._parent
+ for leaf_name, leaf in parent._leaves.items():
+ if leaf.node is current:
+ parts.append(leaf_name)
+ current = parent
+ break
+ else:
+ raise RuntimeError(
+ f"Node {current.name} is not reachable from parent {parent.name}"
+ )
+ parts.reverse()
+ return ".".join(parts)
+
+ def get_transform(self, expr: str, *, composed: bool = True) -> TensorLike | list[TensorLike]:
+ """
+ Return the transform(s) from the current node to a traced descendant.
+
+ Parameters
+ ----------
+ expr : str
+ Trace expression interpreted with the same rules as
+ [`trace()`][qrg.tree.Node.trace].
+ composed : bool, default=True
+ If `True`, return a single composed transform tensor `T`. If
+ `False`, return the ordered list of direct branch transforms whose
+ product would form `T`.
+
+ Returns
+ -------
+ qten.Tensor or list[qten.Tensor]
+ Either the composed transform or the ordered list of transforms
+ from the current node to the traced node.
+
+ Raises
+ ------
+ ValueError
+ If `expr` resolves to a node outside the current node's subtree.
+
+ Notes
+ -----
+ For a descendant reached by transforms `[T1, T2, ..., Tn]`, the
+ composed result is `T1 @ T2 @ ... @ Tn`, so the descendant target is
+ obtained from the current target as
+ `T.h(-2, -1) @ target @ T`.
+
+ Using this transform chain in the inverse/backward direction may
+ create very large intermediate tensors. In particular, for structured
+ tensor types such as momentum-block tensors, lifting a descendant
+ operator back toward an ancestor can expand into a large pair-resolved
+ space. For actual transport computations, staged application with
+ `composed=False` is usually preferable to materializing one large
+ composed transform.
+ """
+ destination = self.trace(expr)
+ source_path = self.path()
+ destination_path = destination.path()
+
+ if source_path == destination_path:
+ relative_parts: list[str] = []
+ elif source_path == "":
+ relative_parts = destination_path.split(".")
+ elif destination_path.startswith(source_path + "."):
+ relative_parts = destination_path[len(source_path) + 1 :].split(".")
+ else:
+ raise ValueError(
+ f'Trace "{expr}" does not resolve to the current node or its descendants'
+ )
+
+ transforms: list[TensorLike] = []
+ current = self
+ for part in relative_parts:
+ leaf = current.leaf(part)
+ transforms.append(leaf.transform)
+ current = leaf.node
+
+ if not composed:
+ return transforms
+
+ if not transforms:
+ target = self.target()
+ return qten.eye(target.dims, device=target.device)
+
+ transform = transforms[0]
+ for step in transforms[1:]:
+ transform = transform @ step
+ return transform
+
+ def find(
+ self, *, regex: str | None = None, predicate: Callable[["Node"], bool] | None = None
+ ) -> list["Node"]:
+ """
+ Find nodes in the current subtree by name regex and/or predicate.
+
+ Parameters
+ ----------
+ regex : str, optional
+ Regular expression matched against each node's `name`.
+ predicate : callable, optional
+ Additional boolean filter applied to each visited node.
+
+ Returns
+ -------
+ list[Node]
+ All matching nodes in depth-first traversal order, starting from
+ the current node.
+
+ Raises
+ ------
+ ValueError
+ If neither `regex` nor `predicate` is provided.
+
+ Notes
+ -----
+ When both filters are supplied, a node must satisfy both.
+ """
+ if regex is None and predicate is None:
+ raise ValueError("find() requires regex and/or predicate")
+
+ pattern = re.compile(regex) if regex is not None else None
+ matches: list[Node] = []
+ stack = [self]
+
+ while stack:
+ current = stack.pop()
+ regex_match = pattern.search(current.name) is not None if pattern else True
+ predicate_match = predicate(current) if predicate else True
+ if regex_match and predicate_match:
+ matches.append(current)
+
+ children = [path.node for path in current._leaves.values()]
+ stack.extend(reversed(children))
+
+ return matches
+
+ def parent(self) -> "Node":
+ """
+ Return the parent node.
+
+ Returns
+ -------
+ Node
+ Parent of the current node.
+
+ Raises
+ ------
+ ValueError
+ If this node is the root and therefore has no parent.
+
+ Notes
+ -----
+ Child parent links are assigned automatically by
+ [`grow()`][qrg.tree.Node.grow].
+ """
+ if self._parent is None:
+ raise ValueError(f"Node {self.name} has no parent")
+ return self._parent
+
+ def root(self) -> "Node":
+ """
+ Return the root node of the current tree.
+
+ Returns
+ -------
+ Node
+ Top-most ancestor reachable by following parent links. For the root
+ node itself, this method returns `self`.
+
+ Notes
+ -----
+ The root owns the tree-wide growth-method registry used by
+ [`register_method()`][qrg.tree.Node.register_method] and
+ [`grow()`][qrg.tree.Node.grow].
+
+ Examples
+ --------
+ >>> root = Node.new(name="root", target=qten.eye(2), methods={})
+ >>> root.root() is root
+ True
+ """
+ root = self
+ while root._parent is not None:
+ root = root._parent
+ return root
+
+ def is_root(self) -> bool:
+ """
+ Return whether this node is the root of its tree.
+
+ Returns
+ -------
+ bool
+ `True` when the node has no parent, otherwise `False`.
+ """
+ return self._parent is None
+
+ def is_leaf(self) -> bool:
+ """
+ Return whether this node currently has no child branches.
+
+ Returns
+ -------
+ bool
+ `True` when the node has no direct leaves, otherwise `False`.
+ """
+ return len(self._leaves) == 0
+
+ def grow(self, method: str) -> Self:
+ """
+ Expand this node using a registered growth method.
+
+ Parameters
+ ----------
+ method : str
+ Name of a previously registered growth method.
+
+ Returns
+ -------
+ Node
+ The current node after updating `_leaves`.
+
+ Raises
+ ------
+ ValueError
+ If `method` has not been registered on this node.
+ KeyError
+ If a returned child node does not contain its `_target` key in its
+ `_data` mapping. In that case the leaf transform cannot be derived.
+
+ Notes
+ -----
+ The growth method returns `dict[str, Node]`. For each returned child:
+
+ - The dictionary key becomes the lookup key in `self._leaves`.
+ - The transform stored in `_LeavePath` is `child[child._target]`.
+ - The child node's `_parent` is set to `self`.
+ - The growth method itself is always resolved from the root node's
+ `_methods` registry, regardless of which node calls `grow()`.
+
+ If a leaf key already exists, it is overwritten by the latest growth
+ result with the same name.
+
+ Examples
+ --------
+ >>> root.grow("step")
+ Node(...)
+ >>> child_path = root.leaf("child")
+ >>> child_path.node.parent() is root
+ True
+ """
+ root = self.root()
+ if method not in root._methods:
+ raise ValueError(f"Method {method} not found in node {self.name}")
+ new_nodes = root._methods[method](self)
+ new_leaves = {
+ target: _LeavePath(node=node, transform=node[node._target])
+ for target, node in new_nodes.items()
+ }
+ for leaf in new_leaves.values():
+ leaf.node._parent = self
+ self._leaves.update(new_leaves)
+ return self
+
+ @staticmethod
+ def new(*, name: str, target: TensorLike, methods: dict[str, _GrowMethod]) -> "Node":
+ """
+ Construct a root node from a single target tensor.
+
+ Parameters
+ ----------
+ name : str
+ Root node name. This also becomes the initial `_target` key.
+ target : qten.Tensor
+ Target tensor stored on the root node under `name`.
+ methods : dict[str, callable]
+ Initial growth-method registry for the root node.
+
+ Returns
+ -------
+ Node
+ Newly created root node with no parent and no leaves.
+
+ Notes
+ -----
+ The created node stores:
+
+ - `_target == name`
+ - `_data == {name: target}`
+ - `_parent is None`
+
+ Examples
+ --------
+ >>> import qten
+ >>> root = Node.new(name="hamiltonian", target=qten.eye(2), methods={})
+ >>> root.name
+ 'hamiltonian'
+ >>> list(root)
+ ['hamiltonian']
+ """
+ return Node(
+ name=name,
+ _target=name,
+ _data={name: target},
+ _parent=None,
+ _leaves={},
+ _methods=methods,
+ )
+
+
+def _tree_layout(root: Node) -> dict[int, tuple[float, float]]:
+ positions: dict[int, tuple[float, float]] = {}
+ next_x = 0.0
+
+ def assign(node: Node, depth: int) -> float:
+ nonlocal next_x
+ children = [path.node for path in node._leaves.values()]
+ if not children:
+ x = next_x
+ next_x += 1.0
+ else:
+ child_xs = [assign(child, depth + 1) for child in children]
+ x = sum(child_xs) / len(child_xs)
+ positions[id(node)] = (x, -float(depth))
+ return x
+
+ assign(root, 0)
+ return positions
+
+
+def _wrap_node_label(label: str, target_key: str, *, width: int) -> str:
+ text = f"{label}\n[{target_key}]"
+ return "\n".join(textwrap.fill(line, width=width) for line in text.splitlines())
+
+
+def _tree_node_style(*, is_current: bool, is_root: bool) -> dict[str, object]:
+ if is_current:
+ return {
+ "facecolor": "#d1495b",
+ "edgecolor": "#7f1d1d",
+ "textcolor": "white",
+ "linewidth": 2.2,
+ }
+ if is_root:
+ return {
+ "facecolor": "#edc948",
+ "edgecolor": "#8c6d1f",
+ "textcolor": "black",
+ "linewidth": 2.0,
+ }
+ return {
+ "facecolor": "#d6ebf5",
+ "edgecolor": "#355070",
+ "textcolor": "#1f2933",
+ "linewidth": 1.6,
+ }
+
+
+@Node.register_plot_method("tree", backend="networkx") # type: ignore[untyped-decorator]
+def _plot_node_tree_networkx(node: Node) -> nx.DiGraph:
+ """
+ Build a `networkx` graph representation of the full node tree.
+
+ Parameters
+ ----------
+ node : Node
+ Current node from which plotting was requested.
+
+ Returns
+ -------
+ networkx.DiGraph
+ Directed graph containing every node reachable from `node.root()`.
+ Graph metadata includes:
+
+ - `graph["root"]`: root node of the tree
+ - `graph["current"]`: node on which `plot()` was invoked
+ - `graph["positions"]`: hierarchical layout positions keyed by graph
+ node id
+
+ Graph node ids are `id(node)` for the corresponding `Node` object.
+ Each graph node stores:
+
+ - `node`: original `Node` instance
+ - `label`: node name
+ - `target_key`: node `_target`
+ - `is_current`: whether this is the requested node
+ - `is_root`: whether this is the root node
+
+ Each directed edge stores:
+
+ - `target`: leaf lookup key on the parent
+ - `transform`: corresponding `_LeavePath.transform`
+
+ """
+ root = node.root()
+ positions = _tree_layout(root)
+ graph = nx.DiGraph()
+ graph.graph["root"] = root
+ graph.graph["current"] = node
+ graph.graph["positions"] = positions
+
+ stack = [root]
+ while stack:
+ current = stack.pop()
+ current_id = id(current)
+ graph.add_node(
+ current_id,
+ node=current,
+ label=current.name,
+ target_key=current._target,
+ is_current=current is node,
+ is_root=current is root,
+ )
+ for target, path in current._leaves.items():
+ child = path.node
+ child_id = id(child)
+ graph.add_node(
+ child_id,
+ node=child,
+ label=child.name,
+ target_key=child._target,
+ is_current=child is node,
+ is_root=False,
+ )
+ graph.add_edge(current_id, child_id, target=target, transform=path.transform)
+ stack.append(child)
+
+ return graph
+
+
+@Node.register_plot_method("tree", backend="matplotlib") # type: ignore[untyped-decorator]
+def _(
+ node: Node,
+ *,
+ ax: Any = None,
+ figsize: tuple[float, float] = (10.0, 6.0),
+ label_width: int = 18,
+) -> Any:
+ """
+ Render the full node tree as a matplotlib figure with wrapped node labels.
+
+ Parameters
+ ----------
+ node : Node
+ Current node from which plotting was requested.
+ ax : matplotlib.axes.Axes, optional
+ Existing axes to draw on. If omitted, a new figure and axes are
+ created.
+ figsize : tuple[float, float], default=(10.0, 6.0)
+ Figure size used when `ax` is not provided.
+ label_width : int, default=18
+ Maximum character width before node labels are wrapped onto new lines.
+
+ Returns
+ -------
+ matplotlib.figure.Figure
+ Figure containing the rendered tree.
+
+ Notes
+ -----
+ The current node is highlighted in crimson, the root in gold, and all
+ other nodes in light blue. Labels are wrapped before being placed inside
+ rounded bounding boxes so long names fit within node shapes more cleanly.
+ Connectors are drawn as orthogonal rounded-corner arrows so branching
+ structure reads more like a diagram than a raw graph.
+ """
+ graph = _plot_node_tree_networkx(node)
+ positions = graph.graph["positions"]
+
+ created_fig = False
+ if ax is None:
+ fig, ax = plt.subplots(figsize=figsize)
+ created_fig = True
+ else:
+ fig = ax.figure
+
+ fig.patch.set_facecolor("#f7f3ea")
+ ax.set_facecolor("#f7f3ea")
+
+ for src, dst in graph.edges:
+ x0, y0 = positions[src]
+ x1, y1 = positions[dst]
+ y_start = y0 - 0.16
+ junction_y = y0 - 0.42
+ y_end = y1 + 0.04
+
+ connector_path = MplPath(
+ [
+ (x0, y_start),
+ (x0, junction_y),
+ (x1, junction_y),
+ (x1, y_end),
+ ],
+ [
+ MplPath.MOVETO,
+ MplPath.LINETO,
+ MplPath.LINETO,
+ MplPath.LINETO,
+ ],
+ )
+ connector = PathPatch(
+ connector_path,
+ facecolor="none",
+ edgecolor="#6b7a8f",
+ linewidth=1.8,
+ capstyle="round",
+ joinstyle="round",
+ zorder=2,
+ alpha=0.9,
+ )
+ connector.set_path_effects([pe.Stroke(linewidth=3.2, foreground="#ebe4d8"), pe.Normal()])
+ ax.add_patch(connector)
+
+ for node_id, attrs in graph.nodes(data=True):
+ x, y = positions[node_id]
+ style = _tree_node_style(
+ is_current=attrs["is_current"],
+ is_root=attrs["is_root"],
+ )
+
+ label = _wrap_node_label(attrs["label"], attrs["target_key"], width=label_width)
+ text = ax.text(
+ x,
+ y,
+ label,
+ ha="center",
+ va="center",
+ color=style["textcolor"],
+ fontsize=10,
+ fontweight="bold",
+ fontfamily="monospace",
+ bbox={
+ "boxstyle": "round,pad=0.55,rounding_size=0.25",
+ "facecolor": style["facecolor"],
+ "edgecolor": style["edgecolor"],
+ "linewidth": style["linewidth"],
+ },
+ zorder=3,
+ )
+ text.set_path_effects([pe.withStroke(linewidth=0.5, foreground=style["facecolor"])])
+
+ xs = [x for x, _ in positions.values()]
+ ys = [y for _, y in positions.values()]
+ ax.set_xlim(min(xs) - 0.95, max(xs) + 0.95)
+ ax.set_ylim(min(ys) - 0.9, max(ys) + 0.9)
+ ax.set_axis_off()
+ ax.set_title(
+ f"Tree View: {graph.graph['current'].name}",
+ fontsize=13,
+ fontweight="bold",
+ color="#243447",
+ pad=18,
+ )
+
+ if created_fig:
+ fig.tight_layout()
+ return fig
+
+
+@Node.register_plot_method("tree", backend="plotly") # type: ignore[untyped-decorator]
+def _(
+ node: Node,
+ *,
+ figure: Any = None,
+ label_width: int = 18,
+) -> Any:
+ """
+ Render the full node tree as an interactive Plotly figure.
+
+ Parameters
+ ----------
+ node : Node
+ Current node from which plotting was requested.
+ figure : plotly.graph_objects.Figure, optional
+ Existing figure to populate. If omitted, a new figure is created.
+ label_width : int, default=18
+ Maximum character width before node labels are wrapped onto new lines.
+
+ Returns
+ -------
+ plotly.graph_objects.Figure
+ Figure containing the rendered tree.
+ """
+ graph = _plot_node_tree_networkx(node)
+ positions = graph.graph["positions"]
+
+ if figure is None:
+ fig = go.Figure()
+ else:
+ fig = figure
+
+ edge_x: list[float | None] = []
+ edge_y: list[float | None] = []
+ for src, dst in graph.edges:
+ x0, y0 = positions[src]
+ x1, y1 = positions[dst]
+ y_start = y0 - 0.16
+ junction_y = y0 - 0.42
+ y_end = y1 + 0.04
+ edge_x.extend([x0, x0, x1, x1, None])
+ edge_y.extend([y_start, junction_y, junction_y, y_end, None])
+
+ fig.add_trace(
+ go.Scatter(
+ x=edge_x,
+ y=edge_y,
+ mode="lines",
+ line={"color": "#6b7a8f", "width": 2},
+ hoverinfo="skip",
+ showlegend=False,
+ )
+ )
+
+ node_x: list[float] = []
+ node_y: list[float] = []
+ hover_text: list[str] = []
+ marker_color: list[str] = []
+ marker_line_color: list[str] = []
+ marker_line_width: list[float] = []
+
+ for node_id, attrs in graph.nodes(data=True):
+ x, y = positions[node_id]
+ style = _tree_node_style(
+ is_current=attrs["is_current"],
+ is_root=attrs["is_root"],
+ )
+ label = _wrap_node_label(
+ cast(str, attrs["label"]),
+ cast(str, attrs["target_key"]),
+ width=label_width,
+ )
+ fig.add_annotation(
+ x=x,
+ y=y,
+ text=label.replace("\n", "
"),
+ showarrow=False,
+ xanchor="center",
+ yanchor="middle",
+ align="center",
+ font={
+ "color": style["textcolor"],
+ "size": 11,
+ "family": "Courier New, monospace",
+ },
+ bgcolor=style["facecolor"],
+ bordercolor=style["edgecolor"],
+ borderwidth=style["linewidth"],
+ borderpad=10,
+ )
+ node_x.append(x)
+ node_y.append(y)
+ hover_text.append(
+ f"name: {cast(str, attrs['label'])}
target: {cast(str, attrs['target_key'])}"
+ )
+ marker_color.append(cast(str, style["facecolor"]))
+ marker_line_color.append(cast(str, style["edgecolor"]))
+ marker_line_width.append(cast(float, style["linewidth"]))
+
+ fig.add_trace(
+ go.Scatter(
+ x=node_x,
+ y=node_y,
+ mode="markers",
+ marker={
+ "size": 18,
+ "color": marker_color,
+ "line": {
+ "color": marker_line_color,
+ "width": marker_line_width,
+ },
+ "opacity": 0.0,
+ },
+ text=hover_text,
+ hovertemplate="%{text}",
+ showlegend=False,
+ )
+ )
+
+ xs = [x for x, _ in positions.values()]
+ ys = [y for _, y in positions.values()]
+ fig.update_layout(
+ title={
+ "text": f"Tree View: {graph.graph['current'].name}",
+ "font": {"size": 18, "color": "#243447"},
+ "x": 0.5,
+ },
+ paper_bgcolor="#f7f3ea",
+ plot_bgcolor="#f7f3ea",
+ xaxis={
+ "range": [min(xs) - 0.95, max(xs) + 0.95],
+ "visible": False,
+ },
+ yaxis={
+ "range": [min(ys) - 0.9, max(ys) + 0.9],
+ "visible": False,
+ "scaleanchor": "x",
+ "scaleratio": 1,
+ },
+ margin={"l": 20, "r": 20, "t": 60, "b": 20},
+ )
+ return fig
diff --git a/tests/test_tree.py b/tests/test_tree.py
new file mode 100644
index 0000000..ce09d08
--- /dev/null
+++ b/tests/test_tree.py
@@ -0,0 +1,712 @@
+import pickle
+from pathlib import Path
+from typing import Any, cast
+
+import pytest
+
+import qrg.tree as tree_module
+from qrg.tree import Node
+
+
+class FakeTensor:
+ def __init__(
+ self,
+ name: str,
+ dims: tuple[str, str] = ("d0", "d1"),
+ device: str = "cpu",
+ ) -> None:
+ self.name = name
+ self.dims = dims
+ self.device = device
+
+ def __matmul__(self, other: "FakeTensor") -> "FakeTensor":
+ return FakeTensor(f"({self.name}@{other.name})", dims=self.dims, device=self.device)
+
+ def h(self, *_axes: int) -> "FakeTensor":
+ return FakeTensor(f"{self.name}.h", dims=self.dims, device=self.device)
+
+
+def _tensor_stub() -> Any:
+ return cast(Any, object())
+
+
+def test_grow_sets_parent_and_uses_child_target_for_transform() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"external_target_name": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ leaf = root.leaf("external_target_name")
+ assert leaf.node.parent() is root
+ assert leaf.transform is child_tensor
+
+
+def test_child_grow_looks_up_methods_from_root() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+ grandchild_tensor = _tensor_stub()
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ def child_grow(_: Node) -> dict[str, Node]:
+ grandchild = Node(
+ name="grandchild",
+ _target="grandchild_transform",
+ _data={"grandchild_transform": grandchild_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"grandchild": grandchild}
+
+ root = Node.new(
+ name="root",
+ target=root_tensor,
+ methods={"root_grow": root_grow, "child_grow": child_grow},
+ )
+ root.grow("root_grow")
+
+ child = root.leaf("child").node
+ child.grow("child_grow")
+ grandchild = child.leaf("grandchild")
+ assert grandchild.node.parent() is child
+ assert grandchild.transform is grandchild_tensor
+
+
+def test_register_method_from_child_updates_root_registry() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ def new_method(_: Node) -> dict[str, Node]:
+ return {}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ child = root.leaf("child").node
+ child.register_method("new", new_method)
+ assert root._methods["new"] is new_method
+ assert "new" not in child._methods
+
+
+def test_target_returns_tensor_at_current_target_key() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ child = root.leaf("child").node
+ assert root.target() is root_tensor
+ assert child.target() is child_tensor
+
+
+def test_compute_stores_derived_tensor_in_data() -> None:
+ root_tensor = _tensor_stub()
+ derived_tensor = _tensor_stub()
+
+ root = Node.new(name="root", target=root_tensor, methods={})
+ returned = root.compute("derived", lambda node: derived_tensor)
+
+ assert returned is root
+ assert root["derived"] is derived_tensor
+
+
+def test_cut_detaches_direct_branch_into_new_root() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ child = root.cut("child")
+
+ assert child.name == "child"
+ assert child.root() is child
+ with pytest.raises(ValueError, match="Target child not found in leaves of node root"):
+ root.leaf("child")
+ with pytest.raises(ValueError, match="Node child has no parent"):
+ child.parent()
+
+
+def test_cut_raises_for_missing_direct_branch() -> None:
+ root = Node.new(name="root", target=_tensor_stub(), methods={})
+
+ with pytest.raises(ValueError, match="Target missing not found in leaves of node root"):
+ root.cut("missing")
+
+
+def test_branches_returns_shallow_copy_of_direct_leaves() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ branches = root.branches()
+
+ assert set(branches) == {"child"}
+ assert branches["child"].node is root.leaf("child").node
+
+ branches.pop("child")
+ assert set(root.branches()) == {"child"}
+
+
+def test_root_returns_topmost_ancestor() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ child = root.leaf("child").node
+ assert root.root() is root
+ assert child.root() is root
+
+
+def test_is_root_reflects_parent_link() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+ child = root.leaf("child").node
+
+ assert root.is_root() is True
+ assert child.is_root() is False
+
+
+def test_is_leaf_reflects_whether_node_has_direct_branches() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+
+ assert root.is_leaf() is True
+
+ root.grow("grow")
+ child = root.leaf("child").node
+
+ assert root.is_leaf() is False
+ assert child.is_leaf() is True
+
+
+def test_trace_resolves_absolute_and_relative_paths() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+ grandchild_tensor = _tensor_stub()
+
+ def child_grow(_: Node) -> dict[str, Node]:
+ grandchild = Node(
+ name="grandchild",
+ _target="grandchild_transform",
+ _data={"grandchild_transform": grandchild_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"grandchild": grandchild}
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(
+ name="root",
+ target=root_tensor,
+ methods={"root_grow": root_grow, "child_grow": child_grow},
+ )
+ root.grow("root_grow")
+ child = root.leaf("child").node
+ child.grow("child_grow")
+ grandchild = child.leaf("grandchild").node
+
+ assert root.trace("") is root
+ assert child.trace("") is root
+ assert child.trace(".") is child
+ assert root.trace("child.grandchild") is grandchild
+ assert child.trace(".grandchild") is grandchild
+
+
+def test_path_returns_absolute_leaf_key_trace() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+ grandchild_tensor = _tensor_stub()
+
+ def child_grow(_: Node) -> dict[str, Node]:
+ grandchild = Node(
+ name="grandchild",
+ _target="grandchild_transform",
+ _data={"grandchild_transform": grandchild_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"grandchild": grandchild}
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(
+ name="root",
+ target=root_tensor,
+ methods={"root_grow": root_grow, "child_grow": child_grow},
+ )
+ root.grow("root_grow")
+ child = root.leaf("child").node
+ child.grow("child_grow")
+ grandchild = child.leaf("grandchild").node
+
+ assert root.path() == ""
+ assert child.path() == "child"
+ assert grandchild.path() == "child.grandchild"
+ assert root.trace(grandchild.path()) is grandchild
+
+
+def test_trace_raises_with_explicit_break_position() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": root_grow})
+ root.grow("grow")
+ child = root.leaf("child").node
+
+ with pytest.raises(ValueError, match=r'trace cannot be resolve at "child\.\[missing\]\.leaf"'):
+ root.trace("child.missing.leaf")
+ with pytest.raises(ValueError, match=r'trace cannot be resolve at "\[missing\]\.leaf"'):
+ child.trace(".missing.leaf")
+ with pytest.raises(ValueError, match='Invalid trace expression "\\.child\\."'):
+ root.trace(".child.")
+
+
+def test_find_filters_current_subtree_by_regex_and_predicate() -> None:
+ root_tensor = _tensor_stub()
+ alpha_tensor = _tensor_stub()
+ beta_tensor = _tensor_stub()
+ gamma_tensor = _tensor_stub()
+
+ def gamma_grow(_: Node) -> dict[str, Node]:
+ gamma = Node(
+ name="gamma-child",
+ _target="gamma_transform",
+ _data={"gamma_transform": gamma_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"gamma": gamma}
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ alpha = Node(
+ name="alpha-child",
+ _target="alpha_transform",
+ _data={"alpha_transform": alpha_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ beta = Node(
+ name="beta-child",
+ _target="beta_transform",
+ _data={"beta_transform": beta_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"alpha": alpha, "beta": beta}
+
+ root = Node.new(
+ name="root",
+ target=root_tensor,
+ methods={"root_grow": root_grow, "gamma_grow": gamma_grow},
+ )
+ root.grow("root_grow")
+ alpha = root.leaf("alpha").node
+ beta = root.leaf("beta").node
+ alpha.grow("gamma_grow")
+ gamma = alpha.leaf("gamma").node
+
+ assert root.find(regex="child$") == [alpha, gamma, beta]
+ assert root.find(predicate=lambda node: node is not root) == [alpha, gamma, beta]
+ assert root.find(regex="^a", predicate=lambda node: "alpha" in node.name) == [alpha]
+ assert alpha.find(regex="child$") == [alpha, gamma]
+
+
+def test_find_requires_at_least_one_filter() -> None:
+ root = Node.new(name="root", target=_tensor_stub(), methods={})
+
+ with pytest.raises(ValueError, match="find\\(\\) requires regex and/or predicate"):
+ root.find()
+
+
+def test_get_transform_returns_ordered_list_and_composed_tensor(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ root_tensor: Any = FakeTensor("root_target")
+ child_transform: Any = FakeTensor("child_transform")
+ grandchild_transform: Any = FakeTensor("grandchild_transform")
+
+ def child_grow(_: Node) -> dict[str, Node]:
+ grandchild = Node(
+ name="grandchild",
+ _target="grandchild_transform",
+ _data={"grandchild_transform": grandchild_transform},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"grandchild": grandchild}
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_transform},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(
+ name="root",
+ target=root_tensor,
+ methods={"root_grow": root_grow, "child_grow": child_grow},
+ )
+ root.grow("root_grow")
+ child = root.leaf("child").node
+ child.grow("child_grow")
+
+ monkeypatch.setattr(
+ cast(Any, tree_module.qten), # type: ignore[attr-defined]
+ "eye",
+ lambda dims, *, device=None: FakeTensor("identity", dims=dims, device=device),
+ )
+
+ assert root.get_transform("child.grandchild", composed=False) == [
+ child_transform,
+ grandchild_transform,
+ ]
+ assert cast(Any, root.get_transform("child.grandchild")).name == (
+ "(child_transform@grandchild_transform)"
+ )
+ assert child.get_transform(".grandchild", composed=False) == [grandchild_transform]
+ assert cast(Any, child.get_transform(".")).name == "identity"
+
+
+def test_get_transform_rejects_non_descendant_trace() -> None:
+ root_tensor: Any = FakeTensor("root_target")
+ left_transform: Any = FakeTensor("left_transform")
+ right_transform: Any = FakeTensor("right_transform")
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ left = Node(
+ name="left",
+ _target="left_transform",
+ _data={"left_transform": left_transform},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ right = Node(
+ name="right",
+ _target="right_transform",
+ _data={"right_transform": right_transform},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"left": left, "right": right}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": root_grow})
+ root.grow("grow")
+ left = root.leaf("left").node
+
+ with pytest.raises(
+ ValueError,
+ match='Trace "right" does not resolve to the current node or its descendants',
+ ):
+ left.get_transform("right")
+
+
+def test_pickle_roundtrip_excludes_methods_registry() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child": child}
+
+ root = Node.new(name="root", target=root_tensor, methods={"grow": grow_method})
+ root.grow("grow")
+
+ restored = pickle.loads(pickle.dumps(root))
+ restored_child = restored.leaf("child").node
+
+ assert restored._methods == {}
+ assert restored_child._methods == {}
+ assert restored_child.parent() is restored
+
+
+def test_register_methods_loads_only_annotated_growths(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ script_path = tmp_path / "growths.py"
+ script_path.write_text(
+ "\n".join(
+ [
+ "from qrg.tree import Node",
+ "",
+ '@Node.growth("loaded")',
+ "def loaded_method(node: Node) -> dict[str, Node]:",
+ " return {}",
+ "",
+ "def ignored_method(node: Node) -> dict[str, Node]:",
+ " return {}",
+ ]
+ )
+ )
+
+ monkeypatch.chdir(tmp_path)
+ root = Node.new(name="root", target=_tensor_stub(), methods={})
+ root.register_methods("growths.py")
+
+ assert "loaded" in root._methods
+ assert root._methods["loaded"].__name__ == "loaded_method"
+ assert "ignored_method" not in root._methods
+
+
+def test_plot_tree_networkx_returns_whole_tree_with_current_highlight() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+ grandchild_tensor = _tensor_stub()
+
+ def child_grow(_: Node) -> dict[str, Node]:
+ grandchild = Node(
+ name="grandchild",
+ _target="grandchild_transform",
+ _data={"grandchild_transform": grandchild_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"grandchild_leaf": grandchild}
+
+ def root_grow(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child",
+ _target="child_transform",
+ _data={"child_transform": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child_leaf": child}
+
+ root = Node.new(
+ name="root",
+ target=root_tensor,
+ methods={"root_grow": root_grow, "child_grow": child_grow},
+ )
+ root.grow("root_grow")
+ child = root.leaf("child_leaf").node
+ child.grow("child_grow")
+ grandchild = child.leaf("grandchild_leaf").node
+
+ graph = cast(Any, child.plot("tree", backend="networkx"))
+ root_id = id(root)
+ child_id = id(child)
+ grandchild_id = id(grandchild)
+
+ assert set(graph.nodes) == {root_id, child_id, grandchild_id}
+ assert set(graph.edges) == {(root_id, child_id), (child_id, grandchild_id)}
+ assert graph.graph["root"] is root
+ assert graph.graph["current"] is child
+ assert graph.nodes[root_id]["node"] is root
+ assert graph.nodes[root_id]["is_root"] is True
+ assert graph.nodes[root_id]["is_current"] is False
+ assert graph.nodes[child_id]["is_current"] is True
+ assert graph.edges[root_id, child_id]["target"] == "child_leaf"
+ assert graph.edges[child_id, grandchild_id]["target"] == "grandchild_leaf"
+ assert graph.graph["positions"][root_id][1] == 0.0
+ assert graph.graph["positions"][child_id][1] == -1.0
+
+
+def test_plot_tree_matplotlib_wraps_labels() -> None:
+ root_tensor = _tensor_stub()
+ child_tensor = _tensor_stub()
+
+ def grow_method(_: Node) -> dict[str, Node]:
+ child = Node(
+ name="child node with a very long label",
+ _target="a_very_long_target_key_name",
+ _data={"a_very_long_target_key_name": child_tensor},
+ _parent=None,
+ _leaves={},
+ _methods={},
+ )
+ return {"child_leaf": child}
+
+ root = Node.new(
+ name="root node with a very long label",
+ target=root_tensor,
+ methods={"grow": grow_method},
+ )
+ root.grow("grow")
+
+ fig = cast(Any, root.plot("tree", backend="matplotlib", label_width=12))
+ labels = [text.get_text() for text in fig.axes[0].texts]
+
+ assert any("\n" in label for label in labels)
+
+
+def test_plot_tree_defaults_to_matplotlib_figure() -> None:
+ root = Node.new(name="root", target=_tensor_stub(), methods={})
+
+ fig = cast(Any, root.plot("tree"))
+
+ assert hasattr(fig, "axes")
+
+
+def test_plot_tree_plotly_returns_plotly_figure() -> None:
+ pytest.importorskip("plotly")
+
+ root = Node.new(name="root", target=_tensor_stub(), methods={})
+
+ fig = cast(Any, root.plot("tree", backend="plotly"))
+
+ assert hasattr(fig, "to_plotly_json")
+ assert fig.layout.title.text == "Tree View: root"
+ assert len(fig.layout.annotations) == 1