From f340d9456accf81237528f618e214c4d148a16e8 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Wed, 24 Jun 2026 10:14:14 +0100 Subject: [PATCH 1/3] perf(axes): cache united-atom topology for customised axes --- CodeEntropy/levels/axes.py | 155 +++++++++++++++ CodeEntropy/levels/level_dag.py | 6 + CodeEntropy/levels/nodes/axes_topology.py | 223 ++++++++++++++++++++++ CodeEntropy/levels/nodes/covariance.py | 36 +++- 4 files changed, 417 insertions(+), 3 deletions(-) create mode 100644 CodeEntropy/levels/nodes/axes_topology.py diff --git a/CodeEntropy/levels/axes.py b/CodeEntropy/levels/axes.py index 57ef8912..cb0af142 100644 --- a/CodeEntropy/levels/axes.py +++ b/CodeEntropy/levels/axes.py @@ -216,6 +216,161 @@ def get_UA_axes(self, data_container, index: int): return trans_axes, rot_axes, center, moment_of_inertia + def get_UA_axes_from_topology( + self, + *, + u, + residue_atoms, + topology, + box: np.ndarray | None, + ): + """Compute UA axes using cached static topology. + + This is the cached-index equivalent of ``get_UA_axes``. It preserves the + frame-dependent numerical calculations, but avoids repeated MDAnalysis + selection strings for heavy atoms, bonded atoms, and UA masses. + + Args: + u: Current-frame universe. + residue_atoms: AtomGroup for the parent residue in the current frame. + topology: Cached ``UAAxesTopology`` for this UA bead. + box: Current periodic box lengths. If omitted, ``u.dimensions`` is used. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - trans_axes: Translational axes, shape ``(3, 3)``. + - rot_axes: Rotational axes, shape ``(3, 3)``. + - center: Rotation centre, shape ``(3,)``. + - moment_of_inertia: Principal moments, shape ``(3,)``. + + Raises: + ValueError: If cached bonded-axis construction fails. + """ + dimensions = ( + np.asarray(box, dtype=float) + if box is not None + else np.asarray(u.dimensions[:3], dtype=float) + ) + + heavy_atoms = u.atoms[topology.residue_heavy_indices] + heavy_atom = u.atoms[int(topology.heavy_atom_index)] + + if len(heavy_atoms) > 1: + center = residue_atoms.center_of_mass(unwrap=True) + moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( + center_of_mass=center, + positions=heavy_atoms.positions, + masses=topology.residue_ua_masses, + dimensions=dimensions, + ) + trans_axes, _moment_of_inertia = self.get_custom_principal_axes( + moment_of_inertia_tensor + ) + else: + make_whole(residue_atoms) + trans_axes = residue_atoms.principal_axes() + + center = heavy_atom.position + rot_axes, moment_of_inertia = self.get_bonded_axes_from_topology( + u=u, + heavy_atom=heavy_atom, + topology=topology, + dimensions=dimensions, + ) + if rot_axes is None or moment_of_inertia is None: + raise ValueError("Unable to compute bonded axes for cached UA bead.") + + logger.debug("Translational Axes: %s", trans_axes) + logger.debug("Rotational Axes: %s", rot_axes) + logger.debug("Center: %s", center) + logger.debug("Moment of Inertia: %s", moment_of_inertia) + + return trans_axes, rot_axes, center, moment_of_inertia + + def get_bonded_axes_from_topology( + self, + *, + u, + heavy_atom, + topology, + dimensions: np.ndarray, + ): + """Compute UA bonded axes using cached bonded atom indices. + + This mirrors ``get_bonded_axes`` but receives precomputed bonded atom + memberships from ``UAAxesTopology`` instead of rediscovering them with + MDAnalysis selection strings inside the frame loop. + + Args: + u: Current-frame universe. + heavy_atom: Current-frame heavy atom for the UA bead. + topology: Cached ``UAAxesTopology`` for the UA bead. + dimensions: Simulation box lengths, shape ``(3,)``. + + Returns: + Tuple[np.ndarray | None, np.ndarray | None]: + - custom_axes: Custom rotation axes, shape ``(3, 3)``, or ``None``. + - custom_moment_of_inertia: Principal moments, shape ``(3,)``, or + ``None``. + """ + if not heavy_atom.mass > 1.1: + return None, None + + custom_moment_of_inertia = None + custom_axes = None + + heavy_bonded = u.atoms[topology.bonded_heavy_indices] + light_bonded = u.atoms[topology.bonded_light_indices] + ua = u.atoms[topology.ua_atom_indices] + ua_all = u.atoms[topology.ua_all_atom_indices] + + if len(heavy_bonded) == 0: + custom_axes, custom_moment_of_inertia = self.get_vanilla_axes(ua_all) + + if len(heavy_bonded) == 1 and len(light_bonded) == 0: + custom_axes = self.get_custom_axes( + a=heavy_atom.position, + b_list=[heavy_bonded[0].position], + c=np.zeros(3), + dimensions=dimensions, + ) + + if len(heavy_bonded) == 1 and len(light_bonded) >= 1: + custom_axes = self.get_custom_axes( + a=heavy_atom.position, + b_list=[heavy_bonded[0].position], + c=light_bonded[0].position, + dimensions=dimensions, + ) + + if len(heavy_bonded) >= 2: + custom_axes = self.get_custom_axes( + a=heavy_atom.position, + b_list=heavy_bonded.positions, + c=heavy_bonded[1].position, + dimensions=dimensions, + ) + + if custom_axes is None: + return None, None + + if custom_moment_of_inertia is None: + custom_moment_of_inertia = self.get_custom_moment_of_inertia( + UA=ua, + custom_rotation_axes=custom_axes, + center_of_mass=heavy_atom.position, + dimensions=dimensions, + ) + + custom_axes = self.get_flipped_axes( + ua, + custom_axes, + heavy_atom.position, + dimensions, + ) + + return custom_axes, custom_moment_of_inertia + def get_bonded_axes(self, system, atom, dimensions: np.ndarray): r"""Compute UA rotational axes from bonded topology around a heavy atom. diff --git a/CodeEntropy/levels/level_dag.py b/CodeEntropy/levels/level_dag.py index cc8129a1..d299fd0c 100644 --- a/CodeEntropy/levels/level_dag.py +++ b/CodeEntropy/levels/level_dag.py @@ -19,6 +19,7 @@ from CodeEntropy.levels.frame_dag import FrameGraph from CodeEntropy.levels.neighbors import Neighbors from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode +from CodeEntropy.levels.nodes.axes_topology import BuildAxesTopologyNode from CodeEntropy.levels.nodes.beads import BuildBeadsNode from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode @@ -49,6 +50,11 @@ def build(self) -> LevelDAG: self._add_static("detect_molecules", DetectMoleculesNode()) self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"]) self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"]) + self._add_static( + "build_axes_topology", + BuildAxesTopologyNode(), + deps=["build_beads"], + ) self._add_static( "init_covariance_accumulators", InitCovarianceAccumulatorsNode(), diff --git a/CodeEntropy/levels/nodes/axes_topology.py b/CodeEntropy/levels/nodes/axes_topology.py new file mode 100644 index 00000000..abac3739 --- /dev/null +++ b/CodeEntropy/levels/nodes/axes_topology.py @@ -0,0 +1,223 @@ +"""Build static axes-topology metadata for frame covariance calculations. + +This module caches topology-only atom-index relationships needed by customised +united-atom axes calculations. The cache avoids repeated MDAnalysis selection +parsing inside the frame-local covariance loop while preserving frame-dependent +positions, forces, centres, axes, torques, and moments of inertia. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +UAKey = tuple[int, int, int] + + +@dataclass(frozen=True) +class UAAxesTopology: + """Static topology required to compute customised united-atom axes. + + Attributes: + heavy_atom_index: Reduced-universe atom index for the UA heavy atom. + ua_atom_indices: Atom indices for the UA heavy atom and its bonded + hydrogens/light atoms. + ua_all_atom_indices: Atom indices for the UA heavy atom, bonded heavy + atoms, and bonded hydrogens/light atoms. + bonded_heavy_indices: Heavy atoms bonded to the UA heavy atom. + bonded_light_indices: Hydrogens/light atoms bonded to the UA heavy atom. + residue_heavy_indices: Heavy atoms in the parent residue. + residue_ua_masses: UA masses for heavy atoms in the parent residue. + """ + + heavy_atom_index: int + ua_atom_indices: np.ndarray + ua_all_atom_indices: np.ndarray + bonded_heavy_indices: np.ndarray + bonded_light_indices: np.ndarray + residue_heavy_indices: np.ndarray + residue_ua_masses: np.ndarray + + +@dataclass(frozen=True) +class AxesTopology: + """Cached axes topology for frame covariance calculations. + + Attributes: + ua: Mapping from ``(mol_id, local_residue_id, ua_id)`` to cached + united-atom axes topology. + """ + + ua: dict[UAKey, UAAxesTopology] = field(default_factory=dict) + + +class BuildAxesTopologyNode: + """Build static customised-axes topology before frame covariance execution.""" + + def run(self, shared_data: dict[str, Any]) -> dict[str, Any]: + """Build cached axes topology and write it into shared data. + + The cache is only populated when ``args.customised_axes`` is true. When + customised axes are disabled, an empty cache is still written so later + stages can read ``shared_data["axes_topology"]`` safely. + + Args: + shared_data: Shared workflow data containing ``args`` and, when + customised axes are enabled, ``reduced_universe``, ``levels``, + and ``beads``. + + Returns: + Dict containing the cached ``axes_topology`` object. + """ + args = shared_data["args"] + topology = AxesTopology() + + if not bool(getattr(args, "customised_axes", False)): + shared_data["axes_topology"] = topology + return {"axes_topology": topology} + + u = shared_data["reduced_universe"] + levels = shared_data["levels"] + beads = shared_data["beads"] + + ua_topology: dict[UAKey, UAAxesTopology] = {} + fragments = u.atoms.fragments + + for mol_id, level_list in enumerate(levels): + if "united_atom" not in level_list: + continue + + self._add_ua_topology( + u=u, + mol=fragments[mol_id], + mol_id=mol_id, + beads=beads, + out=ua_topology, + ) + + topology = AxesTopology(ua=ua_topology) + shared_data["axes_topology"] = topology + return {"axes_topology": topology} + + def _add_ua_topology( + self, + *, + u: Any, + mol: Any, + mol_id: int, + beads: dict[Any, list[np.ndarray]], + out: dict[UAKey, UAAxesTopology], + ) -> None: + """Cache static UA axes topology for one molecule. + + Args: + u: Reduced universe used to resolve bead atom-index arrays. + mol: Molecule AtomGroup. + mol_id: Molecule index. + beads: Bead-index mapping produced by ``BuildBeadsNode``. + out: Output UA topology mapping mutated in place. + """ + for local_res_i, residue in enumerate(mol.residues): + bead_key = (mol_id, "united_atom", local_res_i) + bead_idx_list = beads.get(bead_key, []) + + if not bead_idx_list: + continue + + residue_atoms = residue.atoms + residue_heavy = residue_atoms.select_atoms("prop mass > 1.1") + residue_heavy_indices = residue_heavy.indices.astype(int, copy=True) + residue_ua_masses = np.asarray( + self._get_ua_masses_from_topology(residue_atoms), + dtype=float, + ) + + for ua_i, bead_indices in enumerate(bead_idx_list): + bead = u.atoms[bead_indices] + heavy = bead.select_atoms("prop mass > 1.1") + + if len(heavy) == 0: + logger.warning( + "Skipping UA axes topology with no heavy atom: " + "mol=%s residue=%s ua=%s", + mol_id, + local_res_i, + ua_i, + ) + continue + + heavy_atom = heavy[0] + bonded_heavy, bonded_light = self._split_bonded_atoms(heavy_atom) + + heavy_index = np.asarray([int(heavy_atom.index)], dtype=int) + bonded_heavy_indices = bonded_heavy.indices.astype(int, copy=True) + bonded_light_indices = bonded_light.indices.astype(int, copy=True) + + ua_atom_indices = np.concatenate( + [heavy_index, bonded_light_indices], + axis=0, + ) + ua_all_atom_indices = np.concatenate( + [heavy_index, bonded_heavy_indices, bonded_light_indices], + axis=0, + ) + + out[(mol_id, local_res_i, ua_i)] = UAAxesTopology( + heavy_atom_index=int(heavy_atom.index), + ua_atom_indices=ua_atom_indices, + ua_all_atom_indices=ua_all_atom_indices, + bonded_heavy_indices=bonded_heavy_indices, + bonded_light_indices=bonded_light_indices, + residue_heavy_indices=residue_heavy_indices, + residue_ua_masses=residue_ua_masses, + ) + + @staticmethod + def _split_bonded_atoms(atom: Any) -> tuple[Any, Any]: + """Return bonded heavy and light atoms for one atom. + + Args: + atom: MDAnalysis Atom. + + Returns: + Tuple containing bonded heavy atoms and bonded hydrogens/light atoms. + """ + bonded_atoms = atom.bonded_atoms + bonded_heavy = bonded_atoms.select_atoms("mass 2 to 999") + bonded_light = bonded_atoms.select_atoms("mass 1 to 1.1") + return bonded_heavy, bonded_light + + @staticmethod + def _get_ua_masses_from_topology(atom_group: Any) -> list[float]: + """Return UA masses using static bonded atom relationships. + + Args: + atom_group: AtomGroup containing atoms from one residue. + + Returns: + List of UA masses, one for each heavy atom in ``atom_group``. + """ + ua_masses: list[float] = [] + + for atom in atom_group: + if atom.mass <= 1.1: + continue + + ua_mass = float(atom.mass) + bonded_atoms = getattr(atom, "bonded_atoms", None) + if bonded_atoms is None: + ua_masses.append(ua_mass) + continue + + bonded_h_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") + for hydrogen in bonded_h_atoms: + ua_mass += float(hydrogen.mass) + + ua_masses.append(ua_mass) + + return ua_masses diff --git a/CodeEntropy/levels/nodes/covariance.py b/CodeEntropy/levels/nodes/covariance.py index 4331cc6a..e7f829d5 100644 --- a/CodeEntropy/levels/nodes/covariance.py +++ b/CodeEntropy/levels/nodes/covariance.py @@ -79,6 +79,7 @@ def run(self, ctx: FrameCtx) -> dict[str, Any]: beads = shared["beads"] args = shared["args"] axes_manager = shared.get("axes_manager") + axes_topology = shared.get("axes_topology") fp = float(args.force_partitioning) combined = bool(getattr(args, "combined_forcetorque", False)) @@ -110,6 +111,7 @@ def run(self, ctx: FrameCtx) -> dict[str, Any]: group_id=group_id, beads=beads, axes_manager=axes_manager, + axes_topology=axes_topology, box=box, force_partitioning=fp, customised_axes=customised_axes, @@ -172,6 +174,7 @@ def _process_united_atom( group_id: int, beads: dict[Any, list[Any]], axes_manager: Any, + axes_topology: Any | None, box: np.ndarray | None, force_partitioning: float, customised_axes: bool, @@ -189,6 +192,7 @@ def _process_united_atom( group_id: Molecule-group identifier used for within-frame averaging. beads: Mapping of bead keys to reduced-universe atom-index arrays. axes_manager: Axes helper used to build translation and rotation axes. + axes_topology: Optional cached axes topology generated during static setup. box: Optional periodic box vector. force_partitioning: Force partitioning factor for highest-level vectors. customised_axes: Whether customised UA axes should be used. @@ -208,9 +212,13 @@ def _process_united_atom( continue force_vecs, torque_vecs = self._build_ua_vectors( + u=u, + mol_id=mol_id, + local_res_i=local_res_i, residue_atoms=res.atoms, bead_groups=bead_groups, axes_manager=axes_manager, + axes_topology=axes_topology, box=box, force_partitioning=force_partitioning, customised_axes=customised_axes, @@ -388,9 +396,13 @@ def _process_polymer( def _build_ua_vectors( self, *, + u: Any, + mol_id: int, + local_res_i: int, bead_groups: list[Any], residue_atoms: Any, axes_manager: Any, + axes_topology: Any | None, box: np.ndarray | None, force_partitioning: float, customised_axes: bool, @@ -399,9 +411,13 @@ def _build_ua_vectors( """Build force and torque vectors for united-atom beads. Args: + u: Universe-like object used to resolve cached atom indices. + mol_id: Molecule index used in axes-topology lookup keys. + local_res_i: Local residue index used in axes-topology lookup keys. bead_groups: Atom groups representing UA beads in a residue. residue_atoms: Atom group for the parent residue. axes_manager: Axes helper used to select axes, centres, and moments. + axes_topology: Optional cached axes topology generated during static setup. box: Optional periodic box vector. force_partitioning: Force partitioning factor for highest-level vectors. customised_axes: Whether customised UA axes should be used. @@ -415,9 +431,23 @@ def _build_ua_vectors( for ua_i, bead in enumerate(bead_groups): if customised_axes: - trans_axes, rot_axes, center, moi = axes_manager.get_UA_axes( - residue_atoms, ua_i - ) + ua_topology = None + if axes_topology is not None: + ua_topology = axes_topology.ua.get((mol_id, local_res_i, ua_i)) + + if ua_topology is not None: + trans_axes, rot_axes, center, moi = ( + axes_manager.get_UA_axes_from_topology( + u=u, + residue_atoms=residue_atoms, + topology=ua_topology, + box=box, + ) + ) + else: + trans_axes, rot_axes, center, moi = axes_manager.get_UA_axes( + residue_atoms, ua_i + ) else: make_whole(residue_atoms) make_whole(bead) From 61c920d27b7c1862187bf8161a9dc3671e2a2728 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Wed, 24 Jun 2026 10:14:50 +0100 Subject: [PATCH 2/3] test(unit): cover cached united-atom axes topology --- .../levels/nodes/test_axes_topology.py | 336 +++++++++++++++ .../levels/nodes/test_covariance_node.py | 57 +++ tests/unit/CodeEntropy/levels/test_axes.py | 394 ++++++++++++++++++ .../unit/CodeEntropy/levels/test_level_dag.py | 3 + 4 files changed, 790 insertions(+) create mode 100644 tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py diff --git a/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py new file mode 100644 index 00000000..db31f026 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py @@ -0,0 +1,336 @@ +"""Atomic unit tests for static axes-topology construction.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np + +from CodeEntropy.levels.nodes.axes_topology import ( + AxesTopology, + BuildAxesTopologyNode, + UAAxesTopology, +) + + +class FakeAtomGroup: + """Small AtomGroup-like object for axes-topology tests.""" + + def __init__(self, atoms=None, *, name="ag"): + self._atoms = list(atoms or []) + self.name = name + self.indices = np.asarray([atom.index for atom in self._atoms], dtype=int) + + def __iter__(self): + return iter(self._atoms) + + def __len__(self): + return len(self._atoms) + + def __getitem__(self, index): + if isinstance(index, (list, tuple, np.ndarray)): + return FakeAtomGroup([self._atoms[int(i)] for i in index]) + return self._atoms[int(index)] + + def select_atoms(self, selection): + """Return atoms matching the small mass selections used by the node.""" + if selection == "prop mass > 1.1": + return FakeAtomGroup([atom for atom in self._atoms if atom.mass > 1.1]) + if selection == "mass 2 to 999": + return FakeAtomGroup( + [atom for atom in self._atoms if 2.0 <= atom.mass <= 999.0] + ) + if selection == "mass 1 to 1.1": + return FakeAtomGroup( + [atom for atom in self._atoms if 1.0 <= atom.mass <= 1.1] + ) + raise AssertionError(f"Unexpected selection: {selection}") + + +class FakeAtom: + """Small Atom-like object with mass, index, and bonded atoms.""" + + def __init__(self, index, mass, bonded_atoms=None): + self.index = index + self.mass = mass + self.bonded_atoms = bonded_atoms or FakeAtomGroup([]) + + +class FakeResidue: + """Small residue-like object.""" + + def __init__(self, atoms): + self.atoms = atoms + + +class FakeMolecule: + """Small molecule-like object with residues.""" + + def __init__(self, residues): + self.residues = residues + + +class FakeAtoms: + """Container supporting fragments and u.atoms[index_array].""" + + def __init__(self, fragments, atom_map): + self.fragments = fragments + self._atom_map = dict(atom_map) + + def __getitem__(self, indices): + if isinstance(indices, np.ndarray): + return FakeAtomGroup([self._atom_map[int(index)] for index in indices]) + if isinstance(indices, (list, tuple)): + return FakeAtomGroup([self._atom_map[int(index)] for index in indices]) + return self._atom_map[int(indices)] + + +class FakeUniverse: + """Small universe-like object.""" + + def __init__(self, fragments, atom_map): + self.atoms = FakeAtoms(fragments=fragments, atom_map=atom_map) + + +def _args(*, customised_axes): + return SimpleNamespace(customised_axes=customised_axes) + + +def _single_molecule_universe(): + """Build a small molecule containing one residue and one UA bead.""" + hydrogen = FakeAtom(index=2, mass=1.0) + bonded_heavy = FakeAtom(index=3, mass=12.0) + heavy = FakeAtom( + index=1, + mass=12.0, + bonded_atoms=FakeAtomGroup([hydrogen, bonded_heavy]), + ) + other_residue_heavy = FakeAtom(index=4, mass=14.0) + + residue_atoms = FakeAtomGroup([heavy, hydrogen, bonded_heavy, other_residue_heavy]) + residue = FakeResidue(residue_atoms) + molecule = FakeMolecule([residue]) + + atom_map = { + heavy.index: heavy, + hydrogen.index: hydrogen, + bonded_heavy.index: bonded_heavy, + other_residue_heavy.index: other_residue_heavy, + } + universe = FakeUniverse([molecule], atom_map) + + return universe, molecule, heavy, hydrogen, bonded_heavy, other_residue_heavy + + +def test_ua_axes_topology_dataclass_preserves_arrays(): + topology = UAAxesTopology( + heavy_atom_index=1, + ua_atom_indices=np.array([1, 2]), + ua_all_atom_indices=np.array([1, 3, 2]), + bonded_heavy_indices=np.array([3]), + bonded_light_indices=np.array([2]), + residue_heavy_indices=np.array([1, 3, 4]), + residue_ua_masses=np.array([13.0, 12.0, 14.0]), + ) + + assert topology.heavy_atom_index == 1 + np.testing.assert_array_equal(topology.ua_atom_indices, np.array([1, 2])) + np.testing.assert_array_equal(topology.ua_all_atom_indices, np.array([1, 3, 2])) + np.testing.assert_array_equal(topology.bonded_heavy_indices, np.array([3])) + np.testing.assert_array_equal(topology.bonded_light_indices, np.array([2])) + np.testing.assert_array_equal(topology.residue_heavy_indices, np.array([1, 3, 4])) + np.testing.assert_allclose(topology.residue_ua_masses, np.array([13.0, 12.0, 14.0])) + + +def test_axes_topology_defaults_to_empty_ua_mapping(): + topology = AxesTopology() + + assert topology.ua == {} + + +def test_run_writes_empty_topology_when_customised_axes_disabled(): + node = BuildAxesTopologyNode() + shared_data = {"args": _args(customised_axes=False)} + + result = node.run(shared_data) + + assert isinstance(result["axes_topology"], AxesTopology) + assert result["axes_topology"].ua == {} + assert shared_data["axes_topology"] is result["axes_topology"] + + +def test_run_builds_ua_topology_for_united_atom_levels(): + node = BuildAxesTopologyNode() + universe, _, heavy, hydrogen, bonded_heavy, other_residue_heavy = ( + _single_molecule_universe() + ) + shared_data = { + "args": _args(customised_axes=True), + "reduced_universe": universe, + "levels": [["united_atom"]], + "beads": {(0, "united_atom", 0): [np.array([1, 2])]}, + } + + result = node.run(shared_data) + + axes_topology = result["axes_topology"] + assert shared_data["axes_topology"] is axes_topology + assert set(axes_topology.ua) == {(0, 0, 0)} + + ua_topology = axes_topology.ua[(0, 0, 0)] + assert ua_topology.heavy_atom_index == heavy.index + np.testing.assert_array_equal(ua_topology.ua_atom_indices, np.array([1, 2])) + np.testing.assert_array_equal(ua_topology.ua_all_atom_indices, np.array([1, 3, 2])) + np.testing.assert_array_equal(ua_topology.bonded_heavy_indices, np.array([3])) + np.testing.assert_array_equal(ua_topology.bonded_light_indices, np.array([2])) + np.testing.assert_array_equal( + ua_topology.residue_heavy_indices, + np.array([heavy.index, bonded_heavy.index, other_residue_heavy.index]), + ) + np.testing.assert_allclose( + ua_topology.residue_ua_masses, + np.array([13.0, 12.0, 14.0]), + ) + + +def test_run_ignores_molecules_without_united_atom_level(): + node = BuildAxesTopologyNode() + universe, _, _, _, _, _ = _single_molecule_universe() + shared_data = { + "args": _args(customised_axes=True), + "reduced_universe": universe, + "levels": [["residue"]], + "beads": {(0, "united_atom", 0): [np.array([1, 2])]}, + } + + result = node.run(shared_data) + + assert result["axes_topology"].ua == {} + assert shared_data["axes_topology"].ua == {} + + +def test_add_ua_topology_skips_residues_without_beads(): + node = BuildAxesTopologyNode() + universe, molecule, _, _, _, _ = _single_molecule_universe() + out = {} + + node._add_ua_topology( + u=universe, + mol=molecule, + mol_id=0, + beads={}, + out=out, + ) + + assert out == {} + + +def test_add_ua_topology_skips_ua_beads_without_heavy_atoms(caplog): + node = BuildAxesTopologyNode() + hydrogen = FakeAtom(index=2, mass=1.0) + residue = FakeResidue(FakeAtomGroup([hydrogen])) + molecule = FakeMolecule([residue]) + universe = FakeUniverse([molecule], {2: hydrogen}) + out = {} + + node._add_ua_topology( + u=universe, + mol=molecule, + mol_id=5, + beads={(5, "united_atom", 0): [np.array([2])]}, + out=out, + ) + + assert out == {} + assert "Skipping UA axes topology with no heavy atom" in caplog.text + + +def test_add_ua_topology_handles_multiple_residues_and_ua_beads(): + node = BuildAxesTopologyNode() + + h0 = FakeAtom(index=10, mass=1.0) + c0 = FakeAtom(index=11, mass=12.0, bonded_atoms=FakeAtomGroup([h0])) + residue0 = FakeResidue(FakeAtomGroup([c0, h0])) + + h1 = FakeAtom(index=20, mass=1.0) + c1 = FakeAtom(index=21, mass=12.0, bonded_atoms=FakeAtomGroup([h1])) + residue1 = FakeResidue(FakeAtomGroup([c1, h1])) + + molecule = FakeMolecule([residue0, residue1]) + universe = FakeUniverse( + [molecule], + { + h0.index: h0, + c0.index: c0, + h1.index: h1, + c1.index: c1, + }, + ) + out = {} + + node._add_ua_topology( + u=universe, + mol=molecule, + mol_id=3, + beads={ + (3, "united_atom", 0): [np.array([11, 10])], + (3, "united_atom", 1): [np.array([21, 20])], + }, + out=out, + ) + + assert set(out) == {(3, 0, 0), (3, 1, 0)} + assert out[(3, 0, 0)].heavy_atom_index == 11 + assert out[(3, 1, 0)].heavy_atom_index == 21 + + +def test_split_bonded_atoms_returns_heavy_and_light_atom_groups(): + hydrogen = FakeAtom(index=2, mass=1.0) + heavy_bonded = FakeAtom(index=3, mass=12.0) + atom = FakeAtom( + index=1, + mass=12.0, + bonded_atoms=FakeAtomGroup([hydrogen, heavy_bonded]), + ) + + bonded_heavy, bonded_light = BuildAxesTopologyNode._split_bonded_atoms(atom) + + np.testing.assert_array_equal(bonded_heavy.indices, np.array([3])) + np.testing.assert_array_equal(bonded_light.indices, np.array([2])) + + +def test_get_ua_masses_from_topology_adds_bonded_hydrogen_masses(): + hydrogen = FakeAtom(index=2, mass=1.0) + heavy = FakeAtom(index=1, mass=12.0, bonded_atoms=FakeAtomGroup([hydrogen])) + other_heavy = FakeAtom(index=3, mass=14.0) + atom_group = FakeAtomGroup([heavy, hydrogen, other_heavy]) + + masses = BuildAxesTopologyNode._get_ua_masses_from_topology(atom_group) + + assert masses == [13.0, 14.0] + + +def test_get_ua_masses_from_topology_handles_atoms_without_bonded_atoms_attribute(): + class AtomWithoutBonds: + """Small atom-like object without bonded_atoms.""" + + def __init__(self, index, mass): + self.index = index + self.mass = mass + + heavy = AtomWithoutBonds(index=1, mass=12.0) + hydrogen = AtomWithoutBonds(index=2, mass=1.0) + atom_group = FakeAtomGroup([heavy, hydrogen]) + + masses = BuildAxesTopologyNode._get_ua_masses_from_topology(atom_group) + + assert masses == [12.0] + + +def test_get_ua_masses_from_topology_returns_empty_list_for_light_atoms_only(): + hydrogen = FakeAtom(index=2, mass=1.0) + atom_group = FakeAtomGroup([hydrogen]) + + masses = BuildAxesTopologyNode._get_ua_masses_from_topology(atom_group) + + assert masses == [] diff --git a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py index 8398113c..6a3d2ec8 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py @@ -88,6 +88,7 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): mol = FakeMolecule() universe = FakeUniverse([mol], dimensions=np.array([10.0, 20.0, 30.0, 90.0])) axes_manager = object() + axes_topology = object() ctx = { "shared": { @@ -97,6 +98,7 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): "beads": {}, "args": _args(combined_forcetorque=True, customised_axes=True), "axes_manager": axes_manager, + "axes_topology": axes_topology, } } @@ -115,6 +117,7 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): assert ua_kwargs["mol_id"] == 0 assert ua_kwargs["group_id"] == 7 assert ua_kwargs["axes_manager"] is axes_manager + assert ua_kwargs["axes_topology"] is axes_topology assert ua_kwargs["force_partitioning"] == 0.5 assert ua_kwargs["customised_axes"] is True assert ua_kwargs["is_highest"] is False @@ -165,6 +168,7 @@ def test_process_united_atom_updates_outputs_and_molcount(): group_id=7, beads={(0, "united_atom", 0): [np.array([0])]}, axes_manager="axes", + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, @@ -192,6 +196,7 @@ def test_process_united_atom_returns_when_no_beads_or_empty_atom_groups(): group_id=7, beads={}, axes_manager=None, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, @@ -213,6 +218,7 @@ def test_process_united_atom_returns_when_no_beads_or_empty_atom_groups(): group_id=7, beads={(0, "united_atom", 0): [np.array([0])]}, axes_manager=None, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, @@ -414,9 +420,13 @@ def test_build_ua_vectors_uses_customised_axes(): node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) force_vecs, torque_vecs = node._build_ua_vectors( + u=FakeUniverse([]), + mol_id=0, + local_res_i=0, bead_groups=[FakeAtomGroup("ua")], residue_atoms=FakeAtomGroup("res"), axes_manager=axes_manager, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=True, @@ -428,6 +438,49 @@ def test_build_ua_vectors_uses_customised_axes(): axes_manager.get_UA_axes.assert_called_once() +def test_build_ua_vectors_uses_cached_axes_topology_when_available(): + node = FrameCovarianceNode() + axes_manager = MagicMock() + + u = FakeUniverse([]) + ua_topology = object() + axes_topology = SimpleNamespace(ua={(3, 4, 0): ua_topology}) + + axes_manager.get_UA_axes_from_topology.return_value = ( + np.eye(3), + 2.0 * np.eye(3), + np.ones(3), + np.array([1.0, 2.0, 3.0]), + ) + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + + force_vecs, torque_vecs = node._build_ua_vectors( + u=u, + mol_id=3, + local_res_i=4, + bead_groups=[FakeAtomGroup("ua")], + residue_atoms=FakeAtomGroup("res"), + axes_manager=axes_manager, + axes_topology=axes_topology, + box=None, + force_partitioning=0.5, + customised_axes=True, + is_highest=True, + ) + + assert len(force_vecs) == 1 + assert len(torque_vecs) == 1 + + called_kwargs = axes_manager.get_UA_axes_from_topology.call_args.kwargs + assert called_kwargs["u"] is u + assert called_kwargs["topology"] is ua_topology + assert called_kwargs["box"] is None + assert called_kwargs["residue_atoms"].name == "res" + + axes_manager.get_UA_axes.assert_not_called() + + def test_build_ua_vectors_uses_vanilla_axes_when_not_customised(): node = FrameCovarianceNode() axes_manager = MagicMock() @@ -440,9 +493,13 @@ def test_build_ua_vectors_uses_vanilla_axes_when_not_customised(): with patch("CodeEntropy.levels.nodes.covariance.make_whole") as make_whole: node._build_ua_vectors( + u=FakeUniverse([]), + mol_id=0, + local_res_i=0, bead_groups=[FakeAtomGroup("ua")], residue_atoms=FakeAtomGroup("res"), axes_manager=axes_manager, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, diff --git a/tests/unit/CodeEntropy/levels/test_axes.py b/tests/unit/CodeEntropy/levels/test_axes.py index d6d0093f..68adfb8a 100644 --- a/tests/unit/CodeEntropy/levels/test_axes.py +++ b/tests/unit/CodeEntropy/levels/test_axes.py @@ -4,6 +4,7 @@ import pytest from CodeEntropy.levels.axes import AxesCalculator +from CodeEntropy.levels.nodes.axes_topology import UAAxesTopology class _FakeAtom: @@ -699,3 +700,396 @@ def test_get_bonded_axes_returns_none_none_if_custom_axes_none(monkeypatch): assert custom_axes is None assert moi is None + + +class _FakeIndexedAtoms: + """Container supporting ``u.atoms[index]`` and ``u.atoms[index_array]``.""" + + def __init__(self, atom_map): + self._atom_map = dict(atom_map) + + def __getitem__(self, index): + if isinstance(index, np.ndarray): + return _FakeAtomGroup([self._atom_map[int(i)] for i in index]) + if isinstance(index, (list, tuple)): + return _FakeAtomGroup([self._atom_map[int(i)] for i in index]) + return self._atom_map[int(index)] + + +class _FakeUniverse: + """Small universe-like object with indexed atoms and dimensions.""" + + def __init__(self, atom_map, dimensions=None): + self.atoms = _FakeIndexedAtoms(atom_map) + self.dimensions = np.asarray( + dimensions + if dimensions is not None + else [10.0, 10.0, 10.0, 90.0, 90.0, 90.0], + dtype=float, + ) + + +def _ua_topology( + *, + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1,), + bonded_heavy_indices=(), + bonded_light_indices=(), + residue_heavy_indices=(1,), + residue_ua_masses=(12.0,), +): + """Build a small cached UA topology fixture.""" + return UAAxesTopology( + heavy_atom_index=int(heavy_atom_index), + ua_atom_indices=np.asarray(ua_atom_indices, dtype=int), + ua_all_atom_indices=np.asarray(ua_all_atom_indices, dtype=int), + bonded_heavy_indices=np.asarray(bonded_heavy_indices, dtype=int), + bonded_light_indices=np.asarray(bonded_light_indices, dtype=int), + residue_heavy_indices=np.asarray(residue_heavy_indices, dtype=int), + residue_ua_masses=np.asarray(residue_ua_masses, dtype=float), + ) + + +def test_get_UA_axes_from_topology_multiple_heavy_uses_cached_indices_and_box( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 2.0, 3.0]) + other_heavy = _FakeAtom(3, 14.0, [4.0, 5.0, 6.0]) + universe = _FakeUniverse({1: heavy_atom, 3: other_heavy}) + residue_atoms = MagicMock() + residue_atoms.center_of_mass.return_value = np.array([9.0, 8.0, 7.0]) + + topology = _ua_topology( + heavy_atom_index=1, + residue_heavy_indices=(1, 3), + residue_ua_masses=(13.0, 14.0), + ) + + get_tensor = MagicMock(return_value=np.eye(3)) + get_principal = MagicMock(return_value=(np.eye(3) * 2.0, np.array([3.0, 2.0, 1.0]))) + get_bonded = MagicMock(return_value=(np.eye(3) * 4.0, np.array([1.0, 1.0, 1.0]))) + + monkeypatch.setattr(ax, "get_moment_of_inertia_tensor", get_tensor) + monkeypatch.setattr(ax, "get_custom_principal_axes", get_principal) + monkeypatch.setattr(ax, "get_bonded_axes_from_topology", get_bonded) + + box = np.array([20.0, 30.0, 40.0]) + trans_axes, rot_axes, center, moi = ax.get_UA_axes_from_topology( + u=universe, + residue_atoms=residue_atoms, + topology=topology, + box=box, + ) + + np.testing.assert_allclose(trans_axes, np.eye(3) * 2.0) + np.testing.assert_allclose(rot_axes, np.eye(3) * 4.0) + np.testing.assert_allclose(center, heavy_atom.position) + np.testing.assert_allclose(moi, np.array([1.0, 1.0, 1.0])) + + get_tensor.assert_called_once() + tensor_kwargs = get_tensor.call_args.kwargs + np.testing.assert_allclose( + tensor_kwargs["center_of_mass"], np.array([9.0, 8.0, 7.0]) + ) + np.testing.assert_allclose( + tensor_kwargs["positions"], np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + ) + np.testing.assert_allclose(tensor_kwargs["masses"], np.array([13.0, 14.0])) + np.testing.assert_allclose(tensor_kwargs["dimensions"], box) + + get_principal.assert_called_once() + np.testing.assert_allclose(get_principal.call_args.args[0], np.eye(3)) + get_bonded.assert_called_once_with( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=box, + ) + + +def test_get_UA_axes_from_topology_single_heavy_uses_residue_principal_axes( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse( + {1: heavy_atom}, dimensions=[11.0, 12.0, 13.0, 90.0, 90.0, 90.0] + ) + residue_atoms = MagicMock() + residue_atoms.principal_axes.return_value = np.eye(3) * 5.0 + + topology = _ua_topology(heavy_atom_index=1, residue_heavy_indices=(1,)) + + make_whole = MagicMock() + get_bonded = MagicMock(return_value=(np.eye(3) * 6.0, np.array([6.0, 5.0, 4.0]))) + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", make_whole) + monkeypatch.setattr(ax, "get_bonded_axes_from_topology", get_bonded) + + trans_axes, rot_axes, center, moi = ax.get_UA_axes_from_topology( + u=universe, + residue_atoms=residue_atoms, + topology=topology, + box=None, + ) + + make_whole.assert_called_once_with(residue_atoms) + residue_atoms.principal_axes.assert_called_once() + np.testing.assert_allclose(trans_axes, np.eye(3) * 5.0) + np.testing.assert_allclose(rot_axes, np.eye(3) * 6.0) + np.testing.assert_allclose(center, heavy_atom.position) + np.testing.assert_allclose(moi, np.array([6.0, 5.0, 4.0])) + + called_kwargs = get_bonded.call_args.kwargs + np.testing.assert_allclose( + called_kwargs["dimensions"], np.array([11.0, 12.0, 13.0]) + ) + + +def test_get_UA_axes_from_topology_raises_when_cached_bonded_axes_fail(monkeypatch): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom}) + residue_atoms = MagicMock() + residue_atoms.principal_axes.return_value = np.eye(3) + topology = _ua_topology(heavy_atom_index=1, residue_heavy_indices=(1,)) + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None) + monkeypatch.setattr( + ax, "get_bonded_axes_from_topology", lambda **kwargs: (None, None) + ) + + with pytest.raises(ValueError, match="cached UA bead"): + ax.get_UA_axes_from_topology( + u=universe, + residue_atoms=residue_atoms, + topology=topology, + box=None, + ) + + +def test_get_bonded_axes_from_topology_non_heavy_returns_none_none(): + ax = AxesCalculator() + light_atom = _FakeAtom(1, 1.0, [0.0, 0.0, 0.0]) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=MagicMock(), + heavy_atom=light_atom, + topology=_ua_topology(heavy_atom_index=1), + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + assert custom_axes is None + assert moi is None + + +def test_get_bonded_axes_from_topology_no_bonded_heavy_uses_vanilla_axes( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + hydrogen = _FakeAtom(2, 1.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 2: hydrogen}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1, 2), + ua_all_atom_indices=(1, 2), + bonded_heavy_indices=(), + bonded_light_indices=(2,), + ) + + get_vanilla = MagicMock(return_value=(np.eye(3) * 7.0, np.array([7.0, 8.0, 9.0]))) + get_custom_moi = MagicMock() + get_flipped = MagicMock(return_value=np.eye(3) * -7.0) + + monkeypatch.setattr(ax, "get_vanilla_axes", get_vanilla) + monkeypatch.setattr(ax, "get_custom_moment_of_inertia", get_custom_moi) + monkeypatch.setattr(ax, "get_flipped_axes", get_flipped) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3) * -7.0) + np.testing.assert_allclose(moi, np.array([7.0, 8.0, 9.0])) + get_vanilla.assert_called_once() + get_custom_moi.assert_not_called() + get_flipped.assert_called_once() + + +def test_get_bonded_axes_from_topology_one_heavy_no_light_uses_custom_axes( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 3: bonded_heavy}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1, 3), + bonded_heavy_indices=(3,), + bonded_light_indices=(), + ) + + get_custom_axes = MagicMock(return_value=np.eye(3) * 2.0) + get_custom_moi = MagicMock(return_value=np.array([2.0, 3.0, 4.0])) + get_flipped = MagicMock(return_value=np.eye(3) * 3.0) + + monkeypatch.setattr(ax, "get_custom_axes", get_custom_axes) + monkeypatch.setattr(ax, "get_custom_moment_of_inertia", get_custom_moi) + monkeypatch.setattr(ax, "get_flipped_axes", get_flipped) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3) * 3.0) + np.testing.assert_allclose(moi, np.array([2.0, 3.0, 4.0])) + + kwargs = get_custom_axes.call_args.kwargs + np.testing.assert_allclose(kwargs["a"], heavy_atom.position) + np.testing.assert_allclose(kwargs["b_list"][0], bonded_heavy.position) + np.testing.assert_allclose(kwargs["c"], np.zeros(3)) + get_custom_moi.assert_called_once() + get_flipped.assert_called_once() + + +def test_get_bonded_axes_from_topology_one_heavy_with_light_uses_light_as_c( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + bonded_light = _FakeAtom(2, 1.0, [0.0, 1.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 2: bonded_light, 3: bonded_heavy}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1, 2), + ua_all_atom_indices=(1, 3, 2), + bonded_heavy_indices=(3,), + bonded_light_indices=(2,), + ) + + get_custom_axes = MagicMock(return_value=np.eye(3)) + monkeypatch.setattr(ax, "get_custom_axes", get_custom_axes) + monkeypatch.setattr( + ax, + "get_custom_moment_of_inertia", + lambda **kwargs: np.array([1.0, 2.0, 3.0]), + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3)) + np.testing.assert_allclose(moi, np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose( + get_custom_axes.call_args.kwargs["c"], bonded_light.position + ) + + +def test_get_bonded_axes_from_topology_two_heavy_uses_heavy_positions_as_b_list( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy_0 = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + bonded_heavy_1 = _FakeAtom(4, 12.0, [0.0, 1.0, 0.0]) + universe = _FakeUniverse( + { + 1: heavy_atom, + 3: bonded_heavy_0, + 4: bonded_heavy_1, + } + ) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1, 3, 4), + bonded_heavy_indices=(3, 4), + bonded_light_indices=(), + ) + + get_custom_axes = MagicMock(return_value=np.eye(3) * 4.0) + monkeypatch.setattr(ax, "get_custom_axes", get_custom_axes) + monkeypatch.setattr( + ax, + "get_custom_moment_of_inertia", + lambda **kwargs: np.array([4.0, 5.0, 6.0]), + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3) * 4.0) + np.testing.assert_allclose(moi, np.array([4.0, 5.0, 6.0])) + np.testing.assert_allclose( + get_custom_axes.call_args.kwargs["b_list"], + np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), + ) + np.testing.assert_allclose( + get_custom_axes.call_args.kwargs["c"], + bonded_heavy_1.position, + ) + + +def test_get_bonded_axes_from_topology_returns_none_when_custom_axes_none( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 3: bonded_heavy}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1, 3), + bonded_heavy_indices=(3,), + bonded_light_indices=(), + ) + + get_custom_moi = MagicMock() + get_flipped = MagicMock() + + monkeypatch.setattr(ax, "get_custom_axes", lambda **kwargs: None) + monkeypatch.setattr(ax, "get_custom_moment_of_inertia", get_custom_moi) + monkeypatch.setattr(ax, "get_flipped_axes", get_flipped) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + assert custom_axes is None + assert moi is None + get_custom_moi.assert_not_called() + get_flipped.assert_not_called() diff --git a/tests/unit/CodeEntropy/levels/test_level_dag.py b/tests/unit/CodeEntropy/levels/test_level_dag.py index 1b9f4d4f..11e74e9e 100644 --- a/tests/unit/CodeEntropy/levels/test_level_dag.py +++ b/tests/unit/CodeEntropy/levels/test_level_dag.py @@ -12,6 +12,7 @@ def test_build_registers_static_nodes_and_builds_stage_dags(): patch("CodeEntropy.levels.level_dag.DetectMoleculesNode"), patch("CodeEntropy.levels.level_dag.DetectLevelsNode"), patch("CodeEntropy.levels.level_dag.BuildBeadsNode"), + patch("CodeEntropy.levels.level_dag.BuildAxesTopologyNode"), patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode"), patch("CodeEntropy.levels.level_dag.ConformationDAG"), ): @@ -27,6 +28,7 @@ def test_build_registers_static_nodes_and_builds_stage_dags(): "detect_molecules", "detect_levels", "build_beads", + "build_axes_topology", "init_covariance_accumulators", } assert "find_neighbors" not in dag._static_nodes @@ -34,6 +36,7 @@ def test_build_registers_static_nodes_and_builds_stage_dags(): assert ("detect_molecules", "detect_levels") in dag._static_graph.edges assert ("detect_levels", "build_beads") in dag._static_graph.edges + assert ("build_beads", "build_axes_topology") in dag._static_graph.edges assert ("detect_levels", "init_covariance_accumulators") in dag._static_graph.edges assert ( "detect_levels", From e815cb1e3ca568e710c2c06a5668c869374c365d Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Wed, 24 Jun 2026 10:21:30 +0100 Subject: [PATCH 3/3] tests(unit): tidy up comments within axes test file --- tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py index db31f026..680c439e 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py @@ -1,5 +1,3 @@ -"""Atomic unit tests for static axes-topology construction.""" - from __future__ import annotations from types import SimpleNamespace