diff --git a/docs/generate_animations.py b/docs/generate_animations.py new file mode 100644 index 00000000..5dabdc89 --- /dev/null +++ b/docs/generate_animations.py @@ -0,0 +1,63 @@ +""" + Regenerate the classification animations embedded in the documentation. + + Run from the repository root:: + + python docs/generate_animations.py + + For each example this writes three assets into ``docs/source/media``: + + - a ``.gif`` for a quick non-interactive preview, + - a ``.html`` interactive player (matplotlib ``to_jshtml``) with play / pause / step / loop + controls, which is what the docs embed, + + and once, a ``classification_legend.png`` colour legend shared by both examples. +""" +import os + +from paulie import get_pauli_string as p +from paulie import animation_anti_commutation_graph +from paulie.helpers.drawing import save_role_legend + +MEDIA_DIR = os.path.join(os.path.dirname(__file__), "source", "media") + +EXAMPLES = { + # A-type canonical graph -> 4*so(5) + "classification_a_type": { + "generators": ["IYZI", "IIXX", "IIYZ", "IXXI", "XXII", "YZII"], + "n": None, + }, + # B-type canonical graph (a_9) -> sp(4) + "classification_b_type": { + "generators": ["XY", "XZ"], + "n": 4, + }, +} + + +def main() -> None: + """Generate the legend and every documentation animation.""" + legend_path = os.path.join(MEDIA_DIR, "classification_legend.png") + save_role_legend(legend_path) + print(f"wrote {legend_path}") + + for name, spec in EXAMPLES.items(): + generators = (p(spec["generators"], n=spec["n"]) if spec["n"] is not None + else p(spec["generators"])) + print(f"Rendering {name} (algebra = {generators.get_algebra()}) ...") + + # Build the animation once, then export it as both a gif and an interactive player. + ani = animation_anti_commutation_graph(generators, interval=1200, show=False) + + gif_path = os.path.join(MEDIA_DIR, f"{name}.gif") + ani.save(filename=gif_path, writer="pillow") + print(f" wrote {gif_path}") + + html_path = os.path.join(MEDIA_DIR, f"{name}.html") + with open(html_path, "w", encoding="utf-8") as handle: + handle.write(ani.to_jshtml(default_mode="once")) + print(f" wrote {html_path}") + + +if __name__ == "__main__": + main() diff --git a/docs/source/media/classification_a_type.gif b/docs/source/media/classification_a_type.gif new file mode 100644 index 00000000..4159288e Binary files /dev/null and b/docs/source/media/classification_a_type.gif differ diff --git a/docs/source/media/classification_a_type.html b/docs/source/media/classification_a_type.html new file mode 100644 index 00000000..0deb7e59 --- /dev/null +++ b/docs/source/media/classification_a_type.html @@ -0,0 +1,2640 @@ + + + + + + +
+ +
+ +
+ + + + + + + + + +
+
+ + + + + + +
+
+
+ + + diff --git a/docs/source/media/classification_b_type.gif b/docs/source/media/classification_b_type.gif new file mode 100644 index 00000000..f44577e5 Binary files /dev/null and b/docs/source/media/classification_b_type.gif differ diff --git a/docs/source/media/classification_b_type.html b/docs/source/media/classification_b_type.html new file mode 100644 index 00000000..fb7e3a2b --- /dev/null +++ b/docs/source/media/classification_b_type.html @@ -0,0 +1,3507 @@ + + + + + + +
+ +
+ +
+ + + + + + + + + +
+
+ + + + + + +
+
+
+ + + diff --git a/docs/source/media/classification_legend.png b/docs/source/media/classification_legend.png new file mode 100644 index 00000000..49504fe2 Binary files /dev/null and b/docs/source/media/classification_legend.png differ diff --git a/docs/source/user/classification.rst b/docs/source/user/classification.rst index 1ff29b52..dd6dd392 100644 --- a/docs/source/user/classification.rst +++ b/docs/source/user/classification.rst @@ -76,6 +76,32 @@ For any generator set consisting of Pauli strings, the anticommutation graph can +----------------+------------------------------------------+-------------------------------------------------------------+ +Visualizing the transformation +------------------------------ +The step-by-step transformation of the anticommutation graph into its canonical form can be +animated with :code:`animation_anti_commutation_graph`. It drives the classification algorithm +through a recording wrapper that observes every step (without changing the result) and renders the +captured frames into an animation: + +.. code-block:: python + + from paulie import get_pauli_string as p, animation_anti_commutation_graph + + generators = p(["XY", "XZ"], n=4) + animation_anti_commutation_graph( + generators, + storage={"filename": "classification.gif", "writer": "pillow"}, + ) + +Each frame highlights the role a vertex currently plays, following this colour legend: + +.. image:: ../media/classification_legend.png + :alt: Colour legend for the node roles tracked while building the canonical graph + :align: center + +The two worked examples below are illustrated with interactive players. Use the controls to play, +pause, or step through the construction frame by frame. + Classification of A-type canonical graph ---------------------------------------- Let's try to classify a generator set that corresponds to an A-type canonical graph. This algebra is generated by :math:`\mathcal{P}=\{IYZI,IIXX,IIYZ,IXXI,XXII,YZII\}`. @@ -92,7 +118,13 @@ outputs algebra = 4*so(5) -.. Here we should add some definitions like lit, dependent etc. +The player below follows the construction of one of the four canonical components. Starting from +the input anticommutation graph, each generator is added in turn (red), attached to the centre or to +a leg (green), and length-one legs in different lit states are merged using the ``p`` (blue) and +``q`` (pink) vertices, until the A-type star graph is obtained: + +.. raw:: html + :file: ../media/classification_a_type.html According to the table, the resultant graph corresponds to :math:`\mathfrak{so}(5)\oplus \mathfrak{so}(5)\oplus \mathfrak{so}(5)\oplus \mathfrak{so}(5)`. But it is worth noting that it also corresponds to :math:`\mathfrak{sp}(2)\oplus \mathfrak{sp}(2)\oplus \mathfrak{sp}(2)\oplus \mathfrak{sp}(2)`. This shows that there is an exceptional isomorphism between :math:`\mathfrak{so}(5)` and :math:`\mathfrak{sp}(2)`. @@ -113,6 +145,12 @@ outputs algebra = sp(4) +Here the contractions reduce the graph to a single long leg attached to the centre, the canonical +form of type B1 corresponding to :math:`\mathfrak{sp}(4)`: + +.. raw:: html + :file: ../media/classification_b_type.html + The Lie algebra plays a pivotal role in quantum control theory to understand the reachability of states. Also measures of operator spread complexity rely on this concept. Furthermore, determining moments of circuits can be significantly simplified when the Lie algebra is known. diff --git a/pyproject.toml b/pyproject.toml index a8e78121..581eeb41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "bitarray>=3.0.0", "memory-profiler>=0.61.0", "matplotlib>=3.10.0", + "pillow>=10.0.0", "six>=1.17.0", "tqdm>=4.67.1", "scipy>=1.17.0", diff --git a/src/paulie/__init__.py b/src/paulie/__init__.py index 292dd53d..29ebe401 100644 --- a/src/paulie/__init__.py +++ b/src/paulie/__init__.py @@ -24,6 +24,7 @@ from .application.get_optimal_su2_n import get_optimal_universal_generators from .application.average_graph_complexity import average_graph_complexity from .application.plot import plot_anti_commutation_graph +from .application.animation import animation_anti_commutation_graph from .application.average_pauli_weight import ( quantum_fourier_entropy, average_pauli_weight, @@ -79,6 +80,7 @@ "get_optimal_universal_generators", "average_graph_complexity", "plot_anti_commutation_graph", + "animation_anti_commutation_graph", "quantum_fourier_entropy", "average_pauli_weight", "get_pauli_weights", diff --git a/src/paulie/application/animation.py b/src/paulie/application/animation.py new file mode 100644 index 00000000..1e410230 --- /dev/null +++ b/src/paulie/application/animation.py @@ -0,0 +1,45 @@ +""" + Module for animating the transformation of the anti-commutation graph into a canonical form. +""" +from matplotlib.animation import Animation +from paulie.helpers._recording import RecordGraph +from paulie.helpers.drawing import _animation_graph +from paulie.common.pauli_string_collection import PauliStringCollection + + +def animation_anti_commutation_graph( + generators: PauliStringCollection, + storage: dict[str, str] | None = None, + interval: int = 1000, + show: bool = False, +) -> Animation: + """ + Generates an animation showing the transformation of the anti-commutation + graph into canonical form. + + The animation is driven by a + :class:`~paulie.classifier.recording_canonicalizer.RecordingCanonicalizer`, which observes the + classification algorithm and records each step. Use the colour legend in + :data:`paulie.helpers.drawing.NODE_ROLE_COLORS` to interpret node roles. + + Args: + generators (PauliStringCollection): Collection of Pauli strings. + storage (dict[str, str], optional): Location and format to save the + animation to. Expected keys are ``"filename"`` and ``"writer"``. + interval (int, optional): Interval between recording frames in + milliseconds. + show (bool, optional): Whether to display the animation window. + + Returns: + matplotlib.animation.Animation + """ + record = RecordGraph() + generators.set_record(record) + generators.classify() + generators.set_record(None) + return _animation_graph( + record, + interval=interval, + storage=storage, + show=show, + ) diff --git a/src/paulie/classifier/canonicalizer.py b/src/paulie/classifier/canonicalizer.py index 2838e845..ba2775fd 100644 --- a/src/paulie/classifier/canonicalizer.py +++ b/src/paulie/classifier/canonicalizer.py @@ -3,12 +3,18 @@ """ from paulie.common.pauli_string_bitarray import PauliString from paulie.classifier.classification import Morph +from paulie.classifier.observer import EventManager class Canonicalizer: """ - Class of canonicalizer of generators + Class of canonicalizer of generators. + + The canonicalizer is the *publisher* of the observer pattern: it owns an + :class:`~paulie.classifier.observer.EventManager` and emits events at every relevant step of + the algorithm through :meth:`_notify`. When no observer is subscribed, the notifications are + cheap no-ops, so the behaviour of the algorithm is unchanged. """ - def __init__(self): + def __init__(self) -> None: """ Initialize a Canonicalizer """ @@ -16,6 +22,37 @@ def __init__(self): self.central_vertex = None self.legs = [] self.vertex_stack = [] + self.events = EventManager() + + def _current_collection(self) -> list[PauliString]: + """ + Snapshot the vertices of the canonical graph being built. + + The central vertex is placed first, followed by the vertices of each leg. Observers use + this to reconstruct the anticommutation graph of the current state. + + Returns: + list[PauliString]: Current vertices, central vertex first. + """ + collection: list[PauliString] = [] + if self.central_vertex is not None: + collection.append(self.central_vertex) + for leg in self.legs: + collection.extend(leg) + return collection + + def _notify(self, event: str, **data) -> None: + """ + Emit an event to subscribed observers. + + Args: + event (str): Name of the event (used as the frame title). + **data: Contextual data describing the event (node roles, explicit collection, ...). + Returns: + None + """ + if self.events.has_subscribers(): + self.events.notify(event, self, data) def _set_vertex_stack(self, vertex_stack: list[PauliString]) -> None: """ @@ -71,6 +108,7 @@ def _build_core(self, v: PauliString) -> PauliString: v = self._tracked_multiply(v, self._representative(self.legs[0][0])) v = self._tracked_multiply(v, self._representative(self.central_vertex)) self.legs.append([v]) + self._notify("Attach to centre", appending=v) return v def _convert_to_single_lit_state(self, p_index: int, q_index: int, @@ -83,21 +121,31 @@ def _convert_to_single_lit_state(self, p_index: int, q_index: int, vertex_stack (list[PauliString]): Generator stack v (PauliString): Append Pauli string """ + self._notify("Single legs in different states (p lit, q unlit)", + lighting=v, p=self.legs[p_index][0], q=self.legs[q_index][0]) pq = (self._representative(self.legs[p_index][0]) @ self._representative(self.legs[q_index][0])) + replaced: list[PauliString] = [] if self._is_lit(v, self.central_vertex): self.central_vertex = self._tracked_multiply(self.central_vertex, pq) + replaced.append(self.central_vertex) for i in range(len(self.legs)): for j in range(len(self.legs[i])): if i != p_index and self._is_lit(v, self.legs[i][j]): self.legs[i][j] = self._tracked_multiply(self.legs[i][j], pq) + replaced.append(self.legs[i][j]) self.legs[p_index].append(v) + self._notify("Append to lit leg", appending=v, replacing=replaced) # Truncate longest leg if necessary, this happens at most once big_leg_cnt = sum(1 for leg in self.legs if len(leg) >= 2) if self.type == 'A' and big_leg_cnt >= 2: self.type = 'B' + removed: list[PauliString] = [] while len(self.legs[-1]) > 4: - vertex_stack.append(self.legs[-1].pop()) + removed.append(self.legs[-1].pop()) + vertex_stack.append(removed[-1]) + if removed: + self._notify("Type becomes B, trim longest leg", removing=removed) def _transfer_lightning(self, lit_2_leg_index: int, v: PauliString) -> PauliString: """ @@ -154,6 +202,7 @@ def _transfer_lightning(self, lit_2_leg_index: int, v: PauliString) -> PauliStri v = self._tracked_multiply(v, self._representative(self.legs[i][1]) @ self._representative( self.legs[i][0]) @ self._representative(self.legs[0][0])) + self._notify("Transfer lightning to long leg", lighting=v) return v def _reduce_lightning(self, vertex_stack: list[PauliString], @@ -177,6 +226,7 @@ def _reduce_lightning(self, vertex_stack: list[PauliString], break if is_append_to_end: self.legs[-1].append(v) + self._notify("Append to leg end", appending=v) return v m = None for i in range(len(self.legs[-1])): @@ -198,9 +248,12 @@ def _reduce_lightning(self, vertex_stack: list[PauliString], # Exit if no element of longest leg is lit if f is None and s is None: self.legs.append([v]) + self._notify("Attach as new leg", appending=v) return v # Otherwise naively reduce until one element is left if s is not None: + self._notify("Reduce lightning on long leg", lighting=v, + contracting=self.legs[-1][f]) # Compute prefix products on the leg to perform operations in O(1) and O(n) overall pref = self._representative(self.legs[-1][f]) for i in range(f, s): @@ -222,10 +275,12 @@ def _reduce_lightning(self, vertex_stack: list[PauliString], self._representative(self.legs[-1][1]) @ self._representative( self.legs[-1][3])) self.legs.append([v]) + self._notify("First vertex lit, type B2 split", appending=v) else: for w in self.legs[-1]: v = self._tracked_multiply(v, self._representative(w)) self.legs[-1].append(v) + self._notify("First vertex lit, extend long leg", appending=v) else: # Now we have to do careful case handling based on the type of the graph # Here f is either the middle or last vertex @@ -237,6 +292,8 @@ def _reduce_lightning(self, vertex_stack: list[PauliString], subleg = self.legs[-1][f:] self.legs[-1] = self.legs[-1][:f] self.legs.append([v] + subleg) + self._notify("Break long leg at lit vertex", appending=v, + replacing=[self.legs[-2][f - 1]]) # Now if the graph was type B, then nothing else has to be done # Let's order the legs in increasing size if len(self.legs[-1]) < len(self.legs[-2]): @@ -247,10 +304,15 @@ def _reduce_lightning(self, vertex_stack: list[PauliString], # This happens at most once if self.type == 'A' and len(self.legs[-1]) >= 2 and len(self.legs[-2]) >= 2: self.type = 'B' + removed: list[PauliString] = [] while len(self.legs[-1]) > 4: - vertex_stack.append(self.legs[-1].pop()) + removed.append(self.legs[-1].pop()) + vertex_stack.append(removed[-1]) while len(self.legs[-2]) > 2: - vertex_stack.append(self.legs[-2].pop()) + removed.append(self.legs[-2].pop()) + vertex_stack.append(removed[-1]) + if removed: + self._notify("Type becomes B, trim legs", removing=removed) return v def _dependency_check(self, length_1_legs: list[list[PauliString]]) -> None: @@ -286,13 +348,22 @@ def _connected_canonical_graph(self, vertex_stack: list[PauliString]) -> None: while vertex_stack: confirmed_legs = [leg for leg in self.legs if len(leg) != 1] length_1_legs = [leg for leg in self.legs if len(leg) == 1] - confirmed_legs.extend(self._dependency_check(length_1_legs)) + independent_legs = self._dependency_check(length_1_legs) + if self.events.has_subscribers(): + independent_set = {leg[0] for leg in independent_legs} + dropped = [leg[0] for leg in length_1_legs if leg[0] not in independent_set] + if dropped: + self._notify("Remove dependent vertices", + dependent=dropped[0], removing=dropped) + confirmed_legs.extend(independent_legs) self.legs = confirmed_legs self.legs.sort(key=len) v = vertex_stack.pop() if self.central_vertex is None: self.central_vertex = v + self._notify("Central vertex") continue + self._notify("Lighting vertex", lighting=v) # Build the core if len(self.legs) < 2: v = self._build_core(v) @@ -333,6 +404,7 @@ def _connected_canonical_graph(self, vertex_stack: list[PauliString]) -> None: break if not any_lit_leg: self.legs.append([v]) + self._notify("Connect to centre only", appending=v) continue if self.type == 'B': @@ -352,7 +424,14 @@ def _connected_canonical_graph(self, vertex_stack: list[PauliString]) -> None: confirmed_legs = [leg for leg in self.legs if len(leg) != 1] length_1_legs = [leg for leg in self.legs if len(leg) == 1] - confirmed_legs.extend(self._dependency_check(length_1_legs)) + independent_legs = self._dependency_check(length_1_legs) + if self.events.has_subscribers(): + independent_set = {leg[0] for leg in independent_legs} + dropped = [leg[0] for leg in length_1_legs if leg[0] not in independent_set] + if dropped: + self._notify("Remove dependent vertices", + dependent=dropped[0], removing=dropped) + confirmed_legs.extend(independent_legs) self.legs = confirmed_legs self.legs.sort(key=len) @@ -376,5 +455,13 @@ def build_canonical_graph(self, vertex_stack: list[PauliString]) -> Morph: vertex_stack (list[PauliString]): Generator stack """ self._set_vertex_stack(vertex_stack) + self._notify("Anticommutation graph", collection=vertex_stack.copy(), init=True) self._connected_canonical_graph(vertex_stack) - return self._get_morph() + morph = self._get_morph() + if self.events.has_subscribers(): + try: + self._notify( + f"Canonical graph of type {morph.get_type().name}: {morph.get_algebra()}") + except Exception: # pylint: disable=broad-except + self._notify("Canonical graph") + return morph diff --git a/src/paulie/classifier/classification.py b/src/paulie/classifier/classification.py index edfdfa51..102f3b42 100644 --- a/src/paulie/classifier/classification.py +++ b/src/paulie/classifier/classification.py @@ -219,6 +219,22 @@ def get_algebra_properties(self) -> tuple[TypeAlgebra,int,int]: return TypeAlgebra.SU, one_legs, 2**(two_legs + 2) return None, None, None + def get_algebra(self) -> str: + """ + Get the Lie algebra of this single canonical graph as a string. + + Returns: + str: The algebra, e.g. ``"sp(4)"`` or ``"4*so(5)"``. + """ + type_algebra, nc, size = self.get_algebra_properties() + name = {TypeAlgebra.U: "u", TypeAlgebra.SU: "su", + TypeAlgebra.SP: "sp", TypeAlgebra.SO: "so"}.get(type_algebra) + if name is None: + return "" + algebra = f"{name}({size})" + multiplicity = nc if nc == 1 else 2**(nc - 1) + return algebra if multiplicity == 1 else f"{multiplicity}*{algebra}" + def gen_independent_pair(self )-> Generator[list[list[PauliString]], None, None]: """ diff --git a/src/paulie/classifier/observer.py b/src/paulie/classifier/observer.py new file mode 100644 index 00000000..d7a9cf27 --- /dev/null +++ b/src/paulie/classifier/observer.py @@ -0,0 +1,106 @@ +""" + Observer pattern for the canonicalizer. + + The :class:`~paulie.classifier.canonicalizer.Canonicalizer` plays the role of the + *publisher*. It owns an :class:`EventManager` and emits events at every relevant step of + the classification algorithm. Objects that want to track these events (for example, the + frame recorder used to animate the transformation) implement :class:`CanonicalizerObserver` + and subscribe to the event manager. Subscribers only observe; they never change the + behaviour of the algorithm. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from paulie.classifier.canonicalizer import Canonicalizer + + +class CanonicalizerObserver: + """ + Subscriber interface for canonicalizer events. + + Concrete subscribers override :meth:`update` to react to events emitted by a + :class:`~paulie.classifier.canonicalizer.Canonicalizer`. + """ + + def update(self, event: str, source: "Canonicalizer", data: dict[str, Any]) -> None: + """ + React to an event emitted by the publisher. + + Args: + event (str): Name of the event (used as the frame title). + source (Canonicalizer): The canonicalizer that emitted the event. Subscribers may + read its current state (``central_vertex``, ``legs``) directly. + data (dict[str, Any]): Contextual data describing the event, with keys matching the + node roles tracked by the recorder (``lighting``, ``appending``, ``contracting``, + ``p``, ``q``, ``removing``, ``replacing``, ``dependent``, ``collection``, ``init``). + Returns: + None + """ + + +class EventManager: + """ + Subscription manager shared by composition. + + Holds the list of subscribers and forwards events to them. The canonicalizer keeps an + instance of this class instead of inheriting the subscription logic, so the recording + machinery can be patched in without touching the class hierarchy. + """ + + def __init__(self) -> None: + """ + Initialize an empty subscriber list. + + Returns: + None + """ + self._subscribers: list[CanonicalizerObserver] = [] + + def subscribe(self, observer: CanonicalizerObserver) -> None: + """ + Register a subscriber. + + Args: + observer (CanonicalizerObserver): Subscriber to register. + Returns: + None + """ + if observer not in self._subscribers: + self._subscribers.append(observer) + + def unsubscribe(self, observer: CanonicalizerObserver) -> None: + """ + Remove a subscriber. + + Args: + observer (CanonicalizerObserver): Subscriber to remove. + Returns: + None + """ + if observer in self._subscribers: + self._subscribers.remove(observer) + + def has_subscribers(self) -> bool: + """ + Check whether any subscriber is registered. + + Returns: + bool: True if at least one subscriber is registered. + """ + return bool(self._subscribers) + + def notify(self, event: str, source: "Canonicalizer", data: dict[str, Any]) -> None: + """ + Notify all subscribers of an event. + + Args: + event (str): Name of the event. + source (Canonicalizer): The canonicalizer emitting the event. + data (dict[str, Any]): Contextual data describing the event. + Returns: + None + """ + for observer in self._subscribers: + observer.update(event, source, data) diff --git a/src/paulie/classifier/recording_canonicalizer.py b/src/paulie/classifier/recording_canonicalizer.py new file mode 100644 index 00000000..87ab637f --- /dev/null +++ b/src/paulie/classifier/recording_canonicalizer.py @@ -0,0 +1,127 @@ +""" + Recording wrapper around the canonicalizer. + + :class:`FrameRecorder` is the *concrete subscriber*: it reacts to events emitted by a + :class:`~paulie.classifier.canonicalizer.Canonicalizer` and turns each of them into a frame in + a :class:`~paulie.helpers._recording.RecordGraph`. + + :class:`RecordingCanonicalizer` is the *client*: it creates a canonicalizer and a recorder and + wires them together. It does not change the algorithm, it only observes it, so it produces a + :class:`~paulie.helpers._recording.RecordGraph` for any input the plain canonicalizer accepts. +""" +from __future__ import annotations + +from typing import Any + +from paulie.common.pauli_string_bitarray import PauliString +from paulie.classifier.canonicalizer import Canonicalizer +from paulie.classifier.observer import CanonicalizerObserver +from paulie.helpers._recording import RecordGraph, recording_graph + + +class FrameRecorder(CanonicalizerObserver): + """ + Subscriber that records canonicalizer events as frames. + """ + + def __init__(self, record: RecordGraph) -> None: + """ + Initialize the recorder. + + Args: + record (RecordGraph): Record to append frames to. + Returns: + None + """ + self.record = record + + def update(self, event: str, source: Canonicalizer, data: dict[str, Any]) -> None: + """ + Translate an event into a frame. + + The current anticommutation graph is rebuilt from the canonicalizer state (or from an + explicit ``collection`` for the initial frame). When a ``lighting`` vertex is present, the + set of lit vertices is derived from the canonicalizer, so the recorder never needs the + algorithm to compute them. + + Args: + event (str): Name of the event, used as the frame title. + source (Canonicalizer): The canonicalizer emitting the event. + data (dict[str, Any]): Contextual data describing the event. + Returns: + None + """ + collection = data.get("collection") + if collection is None: + collection = source._current_collection() # pylint: disable=protected-access + + lighting = data.get("lighting") + lits = None + if lighting is not None: + lits = [w for w in collection + if w != lighting and source._is_lit(lighting, w)] # pylint: disable=protected-access + + recording_graph( + self.record, + collection=collection, + lighting=lighting, + appending=data.get("appending"), + contracting=data.get("contracting"), + lits=lits, + p=data.get("p"), + q=data.get("q"), + removing_vertices=self._as_list(data.get("removing")), + replacing_vertices=self._as_list(data.get("replacing")), + dependent=data.get("dependent"), + title=event, + init=data.get("init", False), + ) + + @staticmethod + def _as_list(value: Any) -> list[PauliString] | None: + """ + Normalize a role payload to a list of Pauli strings. + + Args: + value (Any): A single Pauli string, a list of them, or None. + Returns: + list[PauliString] | None: List form of the payload, or None. + """ + if value is None: + return None + if isinstance(value, list): + return value + return [value] + + +class RecordingCanonicalizer(Canonicalizer): + """ + Canonicalizer that records its own transformation steps. + + Drop-in replacement for :class:`~paulie.classifier.canonicalizer.Canonicalizer`: it accepts the + same input and returns the same :class:`~paulie.classifier.classification.Morph`, while emitting + a frame for every step into a :class:`~paulie.helpers._recording.RecordGraph`. + """ + + def __init__(self, record: RecordGraph | None = None) -> None: + """ + Initialize the recording canonicalizer. + + Args: + record (RecordGraph, optional): Record to append frames to. Defaults to a fresh record. + Returns: + None + """ + super().__init__() + self.record = record if record is not None else RecordGraph() + self.recorder = FrameRecorder(self.record) + self.events.subscribe(self.recorder) + + def get_record(self) -> RecordGraph: + """ + Get the recording produced by this canonicalizer. + + Returns: + RecordGraph: Recording of the transformation steps. + """ + return self.record diff --git a/src/paulie/common/pauli_string_collection.py b/src/paulie/common/pauli_string_collection.py index 504992a5..5ab09b26 100644 --- a/src/paulie/common/pauli_string_collection.py +++ b/src/paulie/common/pauli_string_collection.py @@ -15,6 +15,7 @@ from paulie.classifier.classification import Classification from paulie.classifier.tracked_canonicalizer import TrackedCanonicalizer from paulie.classifier.canonicalizer import Canonicalizer +from paulie.classifier.recording_canonicalizer import RecordingCanonicalizer from paulie.exceptions import PauliStringCollectionException class PauliStringCollection: @@ -34,6 +35,7 @@ def __init__(self, generators: list[PauliString] | Self | None = None) -> None: self.generators: list[PauliString] = [] self.classification: Classification = None self._tracked: bool = False + self._record = None if not generators: return @@ -523,6 +525,22 @@ def get_subgraphs(self) -> list[PauliStringCollection]: return [self._convert(subgraph) for subgraph in sorted(nx.connected_components(g), key=len, reverse=True)] + def set_record(self, record) -> None: + """ + Attach a recording of the canonical graph construction. + + When a record is set, :meth:`classify` drives a + :class:`~paulie.classifier.recording_canonicalizer.RecordingCanonicalizer`, which observes + the algorithm and appends a frame for every transformation step. Pass ``None`` to stop + recording. + + Args: + record (RecordGraph | None): Record to append frames to, or None to disable recording. + Returns: + None + """ + self._record = record + def classify(self) -> Classification: """ Build the canonical graph of the generators and therefore classify its Lie algebra. @@ -541,7 +559,9 @@ def classify(self) -> Classification: vertex_stack = [self.create_instance(pauli_str=s) for s in nx.dfs_preorder_nodes(g.subgraph(cc))] vertex_stack.reverse() - if self._tracked: + if self._record is not None: + conn_canon = RecordingCanonicalizer(self._record) + elif self._tracked: conn_canon = TrackedCanonicalizer() else: conn_canon = Canonicalizer() diff --git a/src/paulie/helpers/_recording.py b/src/paulie/helpers/_recording.py new file mode 100644 index 00000000..0126771e --- /dev/null +++ b/src/paulie/helpers/_recording.py @@ -0,0 +1,449 @@ +""" + Module that contains classes for creating a recording of the anticommutation graph + transformation process. + + These are plain data containers. They are populated by + :class:`~paulie.classifier.recording_canonicalizer.FrameRecorder` (a subscriber to the + canonicalizer events) and consumed by :func:`paulie.helpers.drawing._animation_graph` to + render an animation. +""" +from typing import cast + +import numpy as np +from paulie.common.get_graph import get_graph, GraphWithLabels +from paulie.common.pauli_string_bitarray import PauliString + +class FrameGraph: + """ + Stores the graph data in a frame. + """ + def __init__(self, vertices:list[str], edges:list[tuple[str,str]], + edge_labels:dict[tuple[str,str],str] | None = None + ) -> None: + """ + Initialize the frame with graph data (vertices, edges, edge labels). + + Args: + vertices (list[str]): List of vertices. + edges (list[tuple[str,str]]): List of edges. + edge_labels (dict[tuple[str,str],str]): List of edge labels. + Returns: + None + """ + self.vertices = vertices + self.edges = edges + self.edge_labels = edge_labels + + def get_graph(self + ) -> tuple[list[str], list[tuple[str, str]], dict[tuple[str, str], str] | None]: + """ + Get the graph data in the frame. + + Returns: + Vertices, edges, and labels of edges. + """ + return self.vertices.copy(), self.edges.copy(), self.edge_labels + +class FrameRecord: + """ + A single frame of a recording. + """ + def __init__(self, graph:FrameGraph | None=None, lighting:PauliString | None=None, + appending:PauliString | None=None, contracting:PauliString | None=None, + lits:list[PauliString] | None=None, p:PauliString | None=None, + q:PauliString | None=None, removing_vertices:list[PauliString] | None=None, + replacing_vertices:list[PauliString] | None=None, + dependent:PauliString | None=None, + title:str | None=None, init:bool=False) -> None: + """ + Initialize the frame with graph data and the current state of the algorithm. + + Args: + graph (FrameGraph, optional): Graph data. + lighting (PauliString, optional): Vertex to be added which is generating the lighting. + appending (PauliString, optional): Vertex to which the new vertex will be attached. + contracting (PauliString, optional): Vertex which is being contracted. + lits (list[PauliString], optional): List of lit vertices. + p (PauliString, optional): Lit vertex in a leg of length 1. + q (PauliString, optional): Unlit vertex in a leg of length 1. + removing_vertices (list[PauliString], optional): List of vertices to be removed. + replacing_vertices (list[PauliString], optional): List of vertices to be replaced. + dependent (PauliString, optional): Dependent vertex. + title (str, optional): Title of frame. + init (bool, optional): Whether this is the initial frame. Defaults to `False`. + Returns: + None + """ + self.graph = graph + self.lighting = str(lighting) if lighting else None + self.lits = [str(v) for v in lits] if lits else [] + self.p = str(p) if p else None + self.q = str(q) if q else None + self.removing_vertices = [str(v) for v in removing_vertices] if removing_vertices else [] + self.replacing_vertices = ([str(v) for v in replacing_vertices] + if replacing_vertices else []) + self.dependent = str(dependent) if dependent else None + self.title = title + self.contracting = str(contracting) if contracting else None + self.appending = str(appending) if appending else None + self.init = init + + def get_graph(self + ) -> tuple[list[str], list[tuple[str, str]], dict[tuple[str, str], str] | None] | None: + """ + Get graph data stored in the frame. + + Returns: + Vertices, edges, and labels of edges. + """ + if not self.graph: + return None + return self.graph.get_graph() + + def get_lighting(self) -> str: + """ + Get the vertex to be added. + + Returns: + str: The vertex to be added. + """ + return self.lighting + + def get_title(self) -> str: + """ + Get the title of the frame. + + Returns: + str: The title of the frame. + """ + return self.title + + def is_appending(self) -> bool: + """ + Check if no vertex is being appended to the graph. + + Returns: + bool: True if not appending. + """ + return not self.appending + + def get_is_appending(self, vertex: str) -> bool: + """ + Check if the vertex is being appended to. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is being appended to. + """ + if not self.appending: + return False + return self.appending == vertex + + def get_is_contracting(self, vertex:str) -> bool: + """ + Check if the vertex is being contracted. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is being contracted. + """ + if not self.contracting: + return False + return self.contracting == vertex + + def get_is_p(self, vertex:str) -> bool: + """ + Check if vertex is p. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is p. + """ + if not self.p: + return False + return self.p == vertex + + def get_is_q(self, vertex:str) -> bool: + """ + Check if vertex is q. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is q. + """ + if not self.q: + return False + return self.q == vertex + + def get_is_dependent(self, vertex:str) -> bool: + """ + Check if vertex is dependent. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is dependent. + """ + if not self.dependent: + return False + return self.dependent == vertex + + def get_is_lits(self, vertex:str) -> bool: + """ + Check if vertex is lit. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is lit. + """ + return vertex in self.lits + + def is_removing(self) -> bool: + """ + Check if no vertices are being removed. + + Returns: + bool: True if no vertices are being removed. + """ + return not self.removing_vertices + + def get_is_removing(self, vertex:str) -> bool: + """ + Check if vertex is being removed. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is being removed. + """ + return vertex in self.removing_vertices + + def get_is_replacing(self, vertex:str) -> bool: + """ + Check if vertex is being replaced. + + Args: + vertex (str): Vertex for checking. + Returns: + bool: True if vertex is being replaced. + """ + return vertex in self.replacing_vertices + + def get_init(self) -> bool: + """ + Check if this is the initial frame. + + Returns: + bool: True if this is the initial frame. + """ + return self.init + +class RecordGraph: + """ + Recording of the canonical graph transformation process. + """ + def __init__(self) -> None: + """ + Initialize an empty record. + + Returns: + None + """ + self.frames: list[FrameRecord] = [] + self.positions = {} + self.x_position_lighting = 0 + + def append_frame(self, frame: FrameRecord) -> None: + """ + Append a frame to the record. + + Args: + frame (FrameRecord): Frame to be added to the record. + Returns: + None + """ + self.frames.append(frame) + + def append(self, graph: FrameGraph | None=None, + lighting:PauliString | None=None, appending:PauliString | None=None, + contracting:PauliString | None=None, lits:list[PauliString] | None=None, + p:PauliString | None=None, q:PauliString | None=None, + removing_vertices:list[PauliString] | None=None, + replacing_vertices:list[PauliString] | None=None, + dependent:PauliString | None=None, title:str | None=None, + init:bool=False) -> None: + """ + Make a frame and append it to the record. + + Args: + graph (FrameGraph, optional): Graph data. + lighting (PauliString, optional): Vertex to be added which is generating the lighting. + appending (PauliString, optional): Vertex to which the new vertex will be attached. + contracting (PauliString, optional): Vertex which is being contracted. + lits (list[PauliString], optional): List of lit vertices. + p (PauliString, optional): Lit vertex in a leg of length 1. + q (PauliString, optional): Unlit vertex in a leg of length 1. + removing_vertices (list[PauliString], optional): List of vertices to be removed. + replacing_vertices (list[PauliString], optional): List of vertices to be replaced. + dependent (PauliString, optional): Dependent vertex. + title (str, optional): Title of frame. + init (bool, optional): Whether this is the initial frame. Defaults to `False`. + Returns: + None + """ + self.append_frame(FrameRecord(graph, lighting=lighting, appending=appending, + contracting=contracting, lits=lits, p=p, q=q, + removing_vertices=removing_vertices, + replacing_vertices=replacing_vertices, + dependent=dependent, title=title, init=init) + ) + + def get_frame(self, index:int) -> FrameRecord: + """ + Get a frame by its index. + + Args: + index (int): Index of frame. + Returns: + FrameRecord: Frame of record. + Raises: + ValueError: If is index is greater than the number of frames. + """ + if index > len(self.frames) - 1: + raise ValueError("Out of index") + return self.frames[index] + + def clear(self) -> None: + """ + Clear the record. + + Returns: + None + """ + self.frames = [] + + def get_size(self) -> int: + """ + Get the number of frames in the record. + + Returns: + int: Number of frames. + """ + return len(self.frames) + + def get_graph(self, index:int + ) -> tuple[list[str], list[tuple[str, str]], dict[tuple[str, str], str] | None]|None: + """ + Get graph data by frame index. + + If the frame at ``index`` has no graph data, walk back to the most recent frame that does. + + Args: + index (int): Index of frame. + Returns: + Vertices, edges, and labels of edges. + """ + while index > -1: + frame = self.get_frame(index) + graph = frame.get_graph() + if not graph: + index -= 1 + continue + return graph + return None + + def get_is_prev(self, index:int) -> bool: + """ + Check if a frame has no graph data. + + Args: + index (int): Index of frame. + Returns: + bool: True if the frame has no graph data. + """ + frame = self.get_frame(index) + return frame.get_graph() is None + + def set_positions(self, positions:dict[str,np.array]) -> None: + """ + Set the positions of the vertices. + + Args: + positions (dict[str,numpy.array]): Positions of vertices. + Returns: + None + """ + self.positions = positions + + def get_positions(self) -> dict[str,np.array]: + """ + Get the positions of the vertices. + + Returns: + dict[str,numpy.array]: Positions of vertices. + """ + return self.positions + + def set_x_position_lighting(self, x_position_lighting:int) -> None: + """ + Set x position of the vertex to be added. + + Args: + x_position_lighting (int): x position of the vertex to be added. + Returns: + None + """ + self.x_position_lighting = x_position_lighting + + def get_x_position_lighting(self) -> int: + """ + Get x position of the vertex to be added. + + Returns: + x position of the vertex to be added. + """ + return self.x_position_lighting + +def recording_graph(record:RecordGraph, collection:list[PauliString] | None=None, + lighting:PauliString | None=None, appending:PauliString | None=None, + contracting:PauliString | None=None, lits:list[PauliString] | None=None, + p:PauliString | None=None, q:PauliString | None = None, + removing_vertices:list[PauliString] | None=None, + replacing_vertices:list[PauliString] | None=None, + dependent:PauliString | None=None, title:str | None=None, + init:bool=False) -> None: + """ + Append a frame with the given data to the record. + + Args: + record (RecordGraph): Record of graph building. + collection (list[PauliString], optional): List of vertices. When provided, the + anticommutation graph of these vertices is stored in the frame. + lighting (PauliString, optional): Vertex to be added which is generating the lighting. + appending (PauliString, optional): Vertex to which the new vertex will be attached. + contracting (PauliString, optional): Vertex which is being contracted. + lits (list[PauliString], optional): List of lit vertices. + p (PauliString, optional): Lit vertex in a leg of length 1. + q (PauliString, optional): Unlit vertex in a leg of length 1. + removing_vertices (list[PauliString], optional): List of vertices to be removed. + replacing_vertices (list[PauliString], optional): List of vertices to be replaced. + dependent (PauliString, optional): Dependent vertex. + title (str, optional): Title of frame. + init (bool, optional): Whether this is the initial frame. Defaults to `False`. + Returns: + None + """ + graph = None + if collection is not None: + vertices, edges, edge_labels = cast(GraphWithLabels, get_graph(collection)) + graph = FrameGraph(vertices, edges, edge_labels) + record.append(graph, lighting=lighting, appending=appending, + contracting=contracting, lits=lits, p=p, q=q, + removing_vertices=removing_vertices, + replacing_vertices=replacing_vertices, + dependent=dependent, title=title, init=init) diff --git a/src/paulie/helpers/drawing.py b/src/paulie/helpers/drawing.py index da0a6811..8b4e0302 100644 --- a/src/paulie/helpers/drawing.py +++ b/src/paulie/helpers/drawing.py @@ -1,10 +1,45 @@ """ Module with graph drawing utilities. """ +import math import networkx as nx +import matplotlib.pyplot as plt +import matplotlib.animation +import numpy as np from paulie.common.pauli_string_collection import PauliStringCollection from paulie.common.pauli_string_bitarray import PauliString from paulie.common.get_graph import get_graph +from paulie.helpers._recording import RecordGraph + +#: Colour convention for the node roles tracked by the recorder. +#: The palette is a qualitative set chosen so the roles stay distinguishable from one another +#: (no two near-identical blues or magentas), which the docs legend renders as colour swatches. +NODE_ROLE_COLORS = { + "lighting": "#e6194b", # vertex currently being added (red) + "dependent": "#9a6324", # dependent vertex (brown) + "contracting": "#f58231", # vertex currently being contracted (orange) + "appending": "#3cb44b", # attachment target / appended vertex (green) + "removing": "#000000", # vertex being temporarily removed (black) + "replacing": "#911eb4", # vertex being replaced (purple) + "p": "#4363d8", # lit vertex in a leg of length one (blue) + "q": "#f032e6", # unlit vertex in a leg of length one (pink) + "lit": "#42d4f4", # lit vertex (cyan) + "unlit": "#d3d3d3", # unlit vertex (light gray) +} + +#: Human readable label for each node role, used to build the legend. +NODE_ROLE_LABELS = { + "lighting": "vertex being added (lighting)", + "dependent": "dependent vertex", + "lit": "lit vertex", + "unlit": "unlit vertex", + "contracting": "vertex being contracted", + "appending": "attachment target (appending)", + "removing": "vertex being temporarily removed", + "replacing": "vertex being replaced", + "p": "lit vertex in a length-one leg (p)", + "q": "unlit vertex in a length-one leg (q)", +} def plot_graph(vertices:list[str], edges:list[tuple[str,str]], @@ -46,3 +81,454 @@ def plot_graph_by_nodes(nodes:PauliStringCollection, commutators = [] vertices, edges, edge_labels = get_graph(nodes, commutators) return plot_graph(vertices, edges, edge_labels) + +def save_role_legend(filename: str) -> None: + """ + Render the node-role colour legend as a standalone image. + + Each role is drawn as a filled colour swatch next to its description, so a reader can map the + colours in an animation to their meaning. Colours come from :data:`NODE_ROLE_COLORS` and labels + from :data:`NODE_ROLE_LABELS`. + + Args: + filename (str): Path to write the legend image to. + Returns: + None + """ + roles = list(NODE_ROLE_LABELS) + fig, ax = plt.subplots(figsize=(5, 0.4 * len(roles) + 0.3)) + ax.axis("off") + # One row per role, top to bottom: a square swatch on the left, its label to the right. + for i, role in enumerate(roles): + y = len(roles) - i + ax.add_patch(plt.Rectangle((0, y - 0.4), 0.6, 0.6, + facecolor=NODE_ROLE_COLORS[role], edgecolor="#666666")) + ax.text(0.9, y - 0.1, NODE_ROLE_LABELS[role], va="center", fontsize=10) + ax.set_xlim(0, 6) + ax.set_ylim(0, len(roles) + 1) + fig.tight_layout() + fig.savefig(filename, dpi=150, bbox_inches="tight") + plt.close(fig) + +def _node_color(frame, node: str, lighting: str | None) -> str: + """ + Resolve the colour of a node for a frame from its role. + + Roles are checked in priority order so that the most specific role wins. + + Args: + frame (FrameRecord): Frame being rendered. + node (str): Vertex to colour. + lighting (str | None): The lighting vertex of the frame, if any. + Returns: + str: A matplotlib colour. + """ + if lighting is not None and node == lighting: + if frame.get_is_dependent(node): + return NODE_ROLE_COLORS["dependent"] + return NODE_ROLE_COLORS["lighting"] + if frame.get_is_removing(node): + return NODE_ROLE_COLORS["removing"] + if frame.get_is_dependent(node): + return NODE_ROLE_COLORS["dependent"] + if frame.get_is_replacing(node): + return NODE_ROLE_COLORS["replacing"] + if frame.get_is_appending(node): + return NODE_ROLE_COLORS["appending"] + if frame.get_is_contracting(node): + return NODE_ROLE_COLORS["contracting"] + if frame.get_is_p(node): + return NODE_ROLE_COLORS["p"] + if frame.get_is_q(node): + return NODE_ROLE_COLORS["q"] + if frame.get_is_lits(node): + return NODE_ROLE_COLORS["lit"] + return NODE_ROLE_COLORS["unlit"] + +def _staggered_label_positions(positions: dict) -> dict: + """ + Offset node labels alternately above and below their nodes. + + Within a horizontal row, neighbouring labels overlap once the Pauli strings are about as + wide as the node spacing; alternating the side doubles the horizontal room per label and + keeps the labels off the coloured nodes. + + Args: + positions (dict): Node coordinates. + Returns: + dict: Coordinates to draw each node label at. + """ + rows: dict[float, list] = {} + for node, xy in positions.items(): + rows.setdefault(round(float(xy[1]), 2), []).append(node) + label_positions = {} + for row_nodes in rows.values(): + row_nodes.sort(key=lambda node: float(positions[node][0])) + for i, node in enumerate(row_nodes): + dy = 0.06 if i % 2 == 0 else -0.06 + label_positions[node] = positions[node] + np.array([0.0, dy]) + return label_positions + +def _animation_graph( + record: RecordGraph, + interval: int = 1000, + repeat: bool = False, + storage: dict | None = None, + show: bool = True, +) -> matplotlib.animation.Animation: + """ + Animate the canonical graph construction from a recording. + + Args: + record (RecordGraph): A recording of the canonical graph construction. + interval (int, optional): Interval between frames in milliseconds. + repeat (bool, optional): Whether to loop the animation. + storage (dict, optional): Output location and format. Expected keys: + - "filename": path to the output file + - "writer": matplotlib writer name or writer instance + show (bool, optional): Whether to display the animation window. + + Returns: + matplotlib.animation.Animation + """ + fig, ax = plt.subplots(figsize=(6, 4)) + + # Largest number of legs seen in any frame; used to fix one spacing for all frames. + max_legs_seen = [0] + + def build_positions( + edges: list[tuple[str, str]], + center: str, + dist: float | None = None, + ) -> tuple[dict[str, np.ndarray], float]: + """ + Lay out a star-like canonical graph as 2D coordinates. + + The canonical graph is a central vertex with several legs (chains) hanging off it. The two + longest legs are drawn as a single horizontal backbone through the centre, and the shorter + legs fan out radially. This keeps the long leg readable even when it wraps onto several + rows. + + Args: + edges (list[tuple[str, str]]): Edges of the committed graph for this frame. + center (str): The central vertex (legs are reconstructed by walking out from it). + dist (float, optional): Spacing between adjacent vertices. Defaults to a value + derived from this frame's leg count; pass an explicit value to keep the + spacing identical across all frames of an animation. + Returns: + tuple[dict[str, np.ndarray], float]: Vertex coordinates and the x coordinate at which + the incoming "lighting" vertex should be placed above the graph. + """ + # Reconstruct the legs from the edge list. Each neighbour of the centre starts a leg... + legs = [] + positions: dict[str, np.ndarray] = {} + + for edge in edges: + if center in edge: + v = edge[1] if center == edge[0] else edge[0] + legs.append([v]) + + # ...and we extend every leg outwards, hopping to the next unused neighbour until the + # chain ends. This turns the edge set back into ordered chains of vertices. + for leg in legs: + current = leg[0] + while True: + is_found = False + for edge in edges: + v = edge[1] if current == edge[0] else edge[0] + if current in edge and v not in leg and v != center: + leg.append(v) + current = v + is_found = True + break + if not is_found: + break + + # Order legs shortest to longest so the two longest end up at the back of the list. + legs.sort(key=len) + n_legs_total = len(legs) + max_legs_seen[0] = max(max_legs_seen[0], n_legs_total) + + max_line = 7 # vertices per row before the long leg wraps to a new row + y_dist = 0.25 # vertical gap between wrapped rows + pos_y = 0.0 # baseline y of the backbone + y = pos_y + center_x = 0.0 + x_position_lighting = 0.0 + x_first = 0.0 + x_last = 0.0 + + # Horizontal spacing between adjacent vertices; tighter when there are many legs. + # Must be defined for all branches. + if dist is None: + dist = 2.0 / (8 if n_legs_total > 7 else max(1, n_legs_total)) + + if n_legs_total > 1: + # Backbone: lay the two longest legs in one horizontal line through the centre. + n = 0 + x = 1.0 + dist / 2.0 + + # Second-longest leg fills the right half, placed right-to-left up to the centre. + for v in reversed(legs[-2]): + x -= dist + if x_first == 0: + x_first = x + positions[v] = np.array([x, y]) + n += 1 + + # The centre sits between the two backbone legs. + x -= dist + positions[center] = np.array([x, y]) + center_x = x + n += 1 + direction = -1 + + # Longest leg continues to the left. When a row gets too long it wraps: drop down + # one row and reverse horizontal direction so it snakes back (boustrophedon). + for v in legs[-1]: + if n > max_line: + if x_last == 0: + x_last = x + x += direction * dist + y -= y_dist + n = 0 + max_line = 5 + direction *= -1 + if x_position_lighting == 0: + x_position_lighting = (x + 1 - dist / 2) / 2 + + x += direction * dist + positions[v] = np.array([x, y]) + n += 1 + if x_last == 0: + x_last = x + + # Park the incoming lighting vertex above the middle of the backbone. + if x_position_lighting == 0: + x_position_lighting = (x_first + x_last) / 2 + + # The two longest legs are placed; the rest are handled radially below. + legs = legs[:-2] + + if n_legs_total == 0: + # Lone centre, nothing else to place. + positions[center] = np.array([0.0, pos_y]) + x_position_lighting = 0.0 + + elif n_legs_total == 1: + # Single leg: just the centre and its one neighbour side by side. + positions[legs[0][0]] = np.array([0.0, pos_y]) + positions[center] = np.array([0.25, pos_y]) + x_position_lighting = 0.125 + + elif len(legs) > 0: + # Remaining short legs fan out from the centre at evenly spaced angles, sweeping + # back and forth so they spread above and below the backbone instead of overlapping. + direction = 1 + n = len(legs) + ang = 3 * math.pi / (2 * n) + if ang >= math.pi / 2: + ang = ang / 2 + c_ang = ang + + for leg in legs: + # First vertex of the leg, one step out from the centre along the current angle. + v = leg[0] + y = pos_y + dist * math.sin(c_ang) + x = center_x + dist * math.cos(c_ang) + positions[v] = np.array([x, y]) + + # A length-2 leg gets its second vertex one more step out along the same ray. + if len(leg) > 1: + v = leg[1] + y = pos_y + 2 * dist * math.sin(c_ang) + x = center_x + 2 * dist * math.cos(c_ang) + positions[v] = np.array([x, y]) + + # Advance the angle; once it passes the top, flip the sweep to the other side. + c_ang += direction * ang + if c_ang > 3 * math.pi / 4: + direction *= -1 + c_ang = direction * ang + + return positions, x_position_lighting + + def compute_frame(num: int, dist: float | None = None) -> dict: + """ + Build the graph and its layout for one frame. + + Args: + num (int): Index of the frame. + dist (float, optional): Fixed vertex spacing shared by all frames. + Returns: + dict: Frame, graph, positions, edge labels, lighting vertex, label flag, and + whether a spring layout was used (init frame). + """ + frame = record.get_frame(num) + + vertices, edges, edge_labels = record.get_graph(num) + vertices = list(vertices) + edges = list(edges) + + center = None + with_labels = True + + if len(vertices) > 0: + center = vertices[0] + # Long Pauli strings: the mid-edge labels collide with the node labels on the + # same baseline, so drop them first; past 10 qubits node labels no longer fit. + if len(center) > 6: + edge_labels = None + if len(center) > 10: + with_labels = False + + lighting = frame.get_lighting() + if lighting: + if len(lighting) > 6: + edge_labels = None + if len(lighting) > 10: + with_labels = False + + if lighting not in vertices: + vertices.append(lighting) + + for v in vertices: + if frame.get_is_lits(v): + edges.append((lighting, v)) + + if frame.get_is_dependent(lighting): + dependent = lighting + lighting = f"dependent {lighting}" + edges.append((dependent, lighting)) + + graph = nx.Graph() + graph.add_nodes_from(vertices) + graph.add_edges_from(edges) + + spring = frame.get_init() or center is None + if spring: + positions = nx.spring_layout(graph) + else: + positions, x_position_lighting = build_positions(edges, center, dist) + if lighting: + positions[lighting] = np.array([x_position_lighting, 1.0]) + # Centre each frame's content horizontally. build_positions anchors the backbone + # at an arbitrary x and the long leg grows sideways, so without this the graph + # drifts towards one edge of the canvas as it is built up. + xs = [xy[0] for xy in positions.values()] + offset = np.array([(min(xs) + max(xs)) / 2, 0.0]) + for node in positions: + positions[node] = positions[node] - offset + + # Robustness: make sure every drawn node has a position. + missing = [node for node in graph if node not in positions] + if missing: + fallback = nx.spring_layout(graph, pos=positions, fixed=list(positions) or None) + for node in missing: + positions[node] = fallback[node] + + return {"frame": frame, "graph": graph, "positions": positions, + "edge_labels": edge_labels, "lighting": lighting, + "with_labels": with_labels, "spring": spring} + + # Precompute every frame so the whole animation can share one geometry and one viewport. + # Without this, matplotlib autoscales each frame to its own data and the per-frame vertex + # spacing varies with the leg count, which makes the graph and the gaps between wrapped + # rows jump from frame to frame. + # First pass discovers the largest leg count of any frame; the second pass lays every + # frame out again with the spacing that count dictates, so coordinates are comparable. + for num in range(record.get_size()): + compute_frame(num) + shared_dist = 2.0 / (8 if max_legs_seen[0] > 7 else max(1, max_legs_seen[0])) + frames_data = [compute_frame(num, shared_dist) for num in range(record.get_size())] + + # Fix the axis limits from the canonical-layout frames only. The x range is symmetric + # around the hub (anchored at x = 0 above) so the graph stays horizontally centred. + all_x = [xy[0] for fd in frames_data if not fd["spring"] for xy in fd["positions"].values()] + all_y = [xy[1] for fd in frames_data if not fd["spring"] for xy in fd["positions"].values()] + if not all_x: + all_x = [xy[0] for fd in frames_data for xy in fd["positions"].values()] + all_y = [xy[1] for fd in frames_data for xy in fd["positions"].values()] + margin = 0.1 + if all_x: + half_width = max(abs(x) for x in all_x) + margin + xlim = (-half_width, half_width) + ylim = (min(all_y) - margin, max(all_y) + margin) + else: + xlim = (-1.0, 1.0) + ylim = (-1.0, 1.0) + + # Spring-layout frames (the initial anticommutation graph) use their own coordinate + # system; rescale them into the shared viewport so they do not distort the limits. + for fd in frames_data: + if fd["spring"] and fd["graph"].number_of_nodes() > 0: + fd["positions"] = nx.spring_layout( + fd["graph"], + center=((xlim[0] + xlim[1]) / 2, (ylim[0] + ylim[1]) / 2), + scale=0.45 * min(xlim[1] - xlim[0], ylim[1] - ylim[0]), + ) + + def update(num: int): + ax.clear() + fd = frames_data[num] + frame, graph, positions = fd["frame"], fd["graph"], fd["positions"] + edge_labels, lighting, with_labels = fd["edge_labels"], fd["lighting"], fd["with_labels"] + ax.set_title(f"{frame.get_title()}") + + color_map = [_node_color(frame, node, lighting) for node in graph] + + ax.axis("off") + + if edge_labels is not None: + nx.draw_networkx_edge_labels( + graph, + pos=positions, + edge_labels=edge_labels, + font_color="red", + hide_ticks=True, + node_size=60, + font_size=6, + ax=ax, + ) + + result = nx.draw_networkx( + graph, + pos=positions, + node_color=color_map, + hide_ticks=True, + node_size=60, + font_size=6, + with_labels=False, + edge_color="#aaaaaa", + ax=ax, + ) + + if with_labels: + nx.draw_networkx_labels( + graph, + pos=_staggered_label_positions(positions), + font_size=6, + hide_ticks=True, + ax=ax, + ) + + # Same viewport for every frame so the layout does not jump during the animation. + ax.set_xlim(*xlim) + ax.set_ylim(*ylim) + return result + + ani = matplotlib.animation.FuncAnimation( + fig, + update, + frames=record.get_size(), + interval=interval, + repeat=repeat, + ) + + if storage is not None: + ani.save(filename=storage["filename"], writer=storage["writer"]) + + if show: + plt.show() + + return ani diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..c3023832 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +"""Shared pytest configuration.""" +import matplotlib + +# Use a non-interactive backend so the plotting and animation tests run headless. +matplotlib.use("Agg") diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 62abdadb..7f5fc623 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -46,6 +46,7 @@ "get_optimal_universal_generators", "average_graph_complexity", "plot_anti_commutation_graph", + "animation_anti_commutation_graph", "quantum_fourier_entropy", "average_pauli_weight", "get_pauli_weights", diff --git a/tests/test_recording.py b/tests/test_recording.py new file mode 100644 index 00000000..a3a4319e --- /dev/null +++ b/tests/test_recording.py @@ -0,0 +1,157 @@ +"""Tests for the recording / animation machinery wrapping the canonicalizer.""" +import matplotlib.pyplot as plt +import pytest + +from paulie import get_pauli_string as p +from paulie.helpers._recording import RecordGraph +from paulie.helpers.drawing import save_role_legend, NODE_ROLE_COLORS, NODE_ROLE_LABELS +from paulie.classifier.recording_canonicalizer import RecordingCanonicalizer +from paulie.application.animation import animation_anti_commutation_graph + +# The two worked examples from docs/source/user/classification.rst. +A_TYPE = (["IYZI", "IIXX", "IIYZ", "IXXI", "XXII", "YZII"], None, "4*so(5)") +B_TYPE = (["XY", "XZ"], 4, "sp(4)") + +SMALL_CASES = [ + (["XY"], 3, "so(3)"), + (["XY", "XZ"], 3, None), + A_TYPE, + B_TYPE, +] + + +def _collection(generators, n): + return p(generators, n=n) if n is not None else p(generators) + + +def _frames_via_collection(generators, n): + """Run a recorded classification and return the populated RecordGraph.""" + record = RecordGraph() + collection = _collection(generators, n) + collection.set_record(record) + collection.classify() + return record, collection + + +def test_recording_does_not_change_classification(): + """Recording must only observe: the classified algebra is unchanged.""" + for generators, n, _ in SMALL_CASES: + plain = _collection(generators, n).get_algebra() + record = RecordGraph() + recorded_collection = _collection(generators, n) + recorded_collection.set_record(record) + recorded = recorded_collection.classify().get_algebra() + assert plain == recorded + + +def test_record_is_non_empty(): + """A recorded run produces at least an initial and a final frame.""" + for generators, n, _ in SMALL_CASES: + record, _ = _frames_via_collection(generators, n) + assert record.get_size() >= 2 + + +def test_initial_frame_is_input_graph(): + """The first frame is flagged as the init frame and holds the input graph.""" + generators, n, _ = B_TYPE + record, _ = _frames_via_collection(generators, n) + first = record.get_frame(0) + assert first.get_init() is True + vertices, edges, _labels = first.get_graph() + # The initial frame is the input anticommutation graph of the first connected component. + assert len(vertices) == len(set(vertices)) + assert len(vertices) >= 2 + vertex_set = set(vertices) + for a, b in edges: + assert a in vertex_set and b in vertex_set + + +def test_terminal_frame_is_canonical_type(): + """The final frame title announces a canonical type and carries no lighting vertex.""" + generators, n, _ = B_TYPE + record, _ = _frames_via_collection(generators, n) + last = record.get_frame(record.get_size() - 1) + assert last.get_title().startswith("Canonical graph") + assert last.get_lighting() is None + + +def test_terminal_frame_title_has_algebra_name(): + """The final frame title includes the algebra of the component (gksmail review).""" + generators, n, expected = B_TYPE + record, _ = _frames_via_collection(generators, n) + last = record.get_frame(record.get_size() - 1) + assert expected in last.get_title() # e.g. "Canonical graph of type B1: sp(4)" + + +def test_role_legend_renders(tmp_path): + """The colour legend renders to an image file.""" + # Every labelled role must have a colour, so the legend is complete. + assert set(NODE_ROLE_LABELS) <= set(NODE_ROLE_COLORS) + out = tmp_path / "legend.png" + save_role_legend(str(out)) + assert out.exists() and out.stat().st_size > 0 + + +def test_every_frame_graph_is_well_formed(): + """Each frame resolves to a graph whose edges only reference existing vertices.""" + for generators, n, _ in SMALL_CASES: + record, _ = _frames_via_collection(generators, n) + for i in range(record.get_size()): + graph = record.get_graph(i) + assert graph is not None + vertices, edges, _labels = graph + assert len(vertices) == len(set(vertices)) + vertex_set = set(vertices) + for a, b in edges: + assert a in vertex_set + assert b in vertex_set + + +def test_recording_canonicalizer_returns_morph(): + """RecordingCanonicalizer is a drop-in returning the same kind of Morph.""" + generators, n, _ = B_TYPE + collection = _collection(generators, n) + record = RecordGraph() + canon = RecordingCanonicalizer(record) + # Build the same vertex stack classify() would, for a single connected component. + subgraphs = collection.get_subgraphs() + assert len(subgraphs) == 1 + stack = subgraphs[0].get() + morph = canon.build_canonical_graph(stack) + assert morph.get_type().name in {"A", "B1", "B2", "B3", "NONE"} + assert canon.get_record().get_size() >= 2 + + +def test_lit_vertices_have_edges_to_lighting(): + """When a frame has a lighting vertex, every lit vertex is recorded for it.""" + generators, n, _ = A_TYPE + record, _ = _frames_via_collection(generators, n) + saw_lighting = False + for i in range(record.get_size()): + frame = record.get_frame(i) + lighting = frame.get_lighting() + if lighting is None: + continue + saw_lighting = True + vertices, _edges, _labels = record.get_graph(i) + # The lit set is a subset of the current vertices (excluding lighting itself). + for v in vertices: + if frame.get_is_lits(v): + assert v != lighting + assert saw_lighting + + +@pytest.mark.parametrize("generators,n", [(A_TYPE[0], A_TYPE[1]), (B_TYPE[0], B_TYPE[1])]) +def test_animation_renders_to_gif(generators, n, tmp_path): + """The renderer turns a run into a GIF, exercising every frame's draw callback.""" + collection = _collection(generators, n) + out = tmp_path / "anim.gif" + animation_anti_commutation_graph( + collection, + storage={"filename": str(out), "writer": "pillow"}, + interval=200, + show=False, + ) + plt.close("all") + assert out.exists() + assert out.stat().st_size > 0 diff --git a/uv.lock b/uv.lock index eb30035c..c5e60b5e 100644 --- a/uv.lock +++ b/uv.lock @@ -753,14 +753,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206, upload-time = "2026-01-27T03:59:45.137Z" }, ] -[[package]] -name = "pauliarray" -version = "0.0.1" -source = { git = "https://github.com/algolab-quantique/pauliarray.git#e0e245aa8c9470324f3454c9624e3412018fb106" } -dependencies = [ - { name = "numpy" }, -] - [[package]] name = "paulie" version = "0.0.1" @@ -772,7 +764,7 @@ dependencies = [ { name = "memory-profiler" }, { name = "networkx" }, { name = "numpy" }, - { name = "pauliarray" }, + { name = "pillow" }, { name = "scipy" }, { name = "six" }, { name = "tqdm" }, @@ -800,7 +792,7 @@ requires-dist = [ { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.0" }, { name = "networkx", specifier = ">=3.3" }, { name = "numpy", specifier = ">=2.2.2" }, - { name = "pauliarray", git = "https://github.com/algolab-quantique/pauliarray.git" }, + { name = "pillow", specifier = ">=10.0.0" }, { name = "pydata-sphinx-theme", marker = "extra == 'dev'", specifier = ">=0.16.1" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.5" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.9.5" },