From f349c09c00fff8dd2c30ae2d2ba335d7689414c5 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Wed, 24 Jun 2026 11:53:43 +0100 Subject: [PATCH 1/2] perf(axes): cache residue topology for customised axes --- CodeEntropy/levels/axes.py | 58 +++++++++++ CodeEntropy/levels/nodes/axes_topology.py | 119 +++++++++++++++++++--- CodeEntropy/levels/nodes/covariance.py | 37 +++++++ 3 files changed, 201 insertions(+), 13 deletions(-) diff --git a/CodeEntropy/levels/axes.py b/CodeEntropy/levels/axes.py index cb0af14..4d4cf89 100644 --- a/CodeEntropy/levels/axes.py +++ b/CodeEntropy/levels/axes.py @@ -135,6 +135,64 @@ def get_residue_axes(self, data_container, index: int, residue=None): return trans_axes, rot_axes, center, moment_of_inertia + def get_residue_axes_from_topology( + self, + *, + u, + mol, + residue_atoms, + topology, + box: np.ndarray | None, + ): + """Compute residue axes using cached static topology. + + This is the cached-index equivalent of ``get_residue_axes``. It keeps + all frame-dependent numerical work frame-local, but avoids repeated + MDAnalysis selections for residue heavy atoms, UA masses, and neighbour + bond discovery. + + Args: + u: Current-frame universe used to resolve cached atom indices. + mol: Current-frame molecule fragment. + residue_atoms: AtomGroup for the residue in the current frame. + topology: Cached ``ResidueAxesTopology`` for this residue. + box: Current periodic box lengths. If omitted, ``u.dimensions`` is used. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - trans_axes: Translational axes, shape ``(3, 3)``. + - rot_axes: Rotational axes, shape ``(3, 3)``. + - center: Residue centre, shape ``(3,)``. + - moment_of_inertia: Principal moments, shape ``(3,)``. + """ + dimensions = ( + np.asarray(box, dtype=float) + if box is not None + else np.asarray(u.dimensions[:3], dtype=float) + ) + + center = residue_atoms.center_of_mass(unwrap=True) + + if not topology.has_neighbor_bonds: + heavy_atoms = u.atoms[topology.residue_heavy_indices] + moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( + center_of_mass=center, + positions=heavy_atoms.positions, + masses=topology.residue_ua_masses, + dimensions=dimensions, + ) + rot_axes, moment_of_inertia = self.get_custom_principal_axes( + moment_of_inertia_tensor + ) + trans_axes = rot_axes + else: + make_whole(mol.atoms) + trans_axes = mol.atoms.principal_axes() + rot_axes, moment_of_inertia = self.get_vanilla_axes(residue_atoms) + center = residue_atoms.center_of_mass(unwrap=True) + + return trans_axes, rot_axes, center, moment_of_inertia + def get_UA_axes(self, data_container, index: int): """Compute united-atom-level translational and rotational axes. diff --git a/CodeEntropy/levels/nodes/axes_topology.py b/CodeEntropy/levels/nodes/axes_topology.py index abac373..5197d75 100644 --- a/CodeEntropy/levels/nodes/axes_topology.py +++ b/CodeEntropy/levels/nodes/axes_topology.py @@ -1,9 +1,9 @@ """Build static axes-topology metadata for frame covariance calculations. This module caches topology-only atom-index relationships needed by customised -united-atom axes calculations. The cache avoids repeated MDAnalysis selection -parsing inside the frame-local covariance loop while preserving frame-dependent -positions, forces, centres, axes, torques, and moments of inertia. +axes calculations. The cache avoids repeated MDAnalysis selection parsing inside +the frame-local covariance loop while preserving frame-dependent positions, +forces, centres, axes, torques, and moments of inertia. """ from __future__ import annotations @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) UAKey = tuple[int, int, int] +ResidueKey = tuple[int, int] @dataclass(frozen=True) @@ -44,6 +45,22 @@ class UAAxesTopology: residue_ua_masses: np.ndarray +@dataclass(frozen=True) +class ResidueAxesTopology: + """Static topology required to compute customised residue axes. + + Attributes: + residue_heavy_indices: Heavy atom indices in the residue. + residue_ua_masses: UA masses for heavy atoms in the residue. + has_neighbor_bonds: Whether the residue is bonded to a neighbouring + residue according to the original customised residue-axis selection. + """ + + residue_heavy_indices: np.ndarray + residue_ua_masses: np.ndarray + has_neighbor_bonds: bool + + @dataclass(frozen=True) class AxesTopology: """Cached axes topology for frame covariance calculations. @@ -51,9 +68,12 @@ class AxesTopology: Attributes: ua: Mapping from ``(mol_id, local_residue_id, ua_id)`` to cached united-atom axes topology. + residue: Mapping from ``(mol_id, local_residue_id)`` to cached + residue axes topology. """ ua: dict[UAKey, UAAxesTopology] = field(default_factory=dict) + residue: dict[ResidueKey, ResidueAxesTopology] = field(default_factory=dict) class BuildAxesTopologyNode: @@ -86,24 +106,76 @@ def run(self, shared_data: dict[str, Any]) -> dict[str, Any]: beads = shared_data["beads"] ua_topology: dict[UAKey, UAAxesTopology] = {} + residue_topology: dict[ResidueKey, ResidueAxesTopology] = {} fragments = u.atoms.fragments for mol_id, level_list in enumerate(levels): - if "united_atom" not in level_list: - continue + mol = fragments[mol_id] + + if "residue" in level_list: + self._add_residue_topology( + mol=mol, + mol_id=mol_id, + beads=beads, + out=residue_topology, + ) - self._add_ua_topology( - u=u, - mol=fragments[mol_id], - mol_id=mol_id, - beads=beads, - out=ua_topology, - ) + if "united_atom" in level_list: + self._add_ua_topology( + u=u, + mol=mol, + mol_id=mol_id, + beads=beads, + out=ua_topology, + ) - topology = AxesTopology(ua=ua_topology) + topology = AxesTopology(ua=ua_topology, residue=residue_topology) shared_data["axes_topology"] = topology return {"axes_topology": topology} + def _add_residue_topology( + self, + *, + mol: Any, + mol_id: int, + beads: dict[Any, list[np.ndarray]], + out: dict[ResidueKey, ResidueAxesTopology], + ) -> None: + """Cache static residue axes topology for one molecule. + + Args: + mol: Molecule AtomGroup. + mol_id: Molecule index. + beads: Bead-index mapping produced by ``BuildBeadsNode``. + out: Output residue topology mapping mutated in place. + """ + bead_key = (mol_id, "residue") + bead_idx_list = beads.get(bead_key, []) + if not bead_idx_list: + return + + for local_res_i, residue in enumerate(mol.residues): + if local_res_i >= len(bead_idx_list): + continue + + residue_atoms = residue.atoms + residue_heavy = residue_atoms.select_atoms("mass 2 to 999") + residue_heavy_indices = residue_heavy.indices.astype(int, copy=True) + residue_ua_masses = np.asarray( + self._get_ua_masses_from_topology(residue_atoms), + dtype=float, + ) + has_neighbor_bonds = self._has_neighbor_bonds( + mol=mol, + local_res_i=local_res_i, + ) + + out[(mol_id, local_res_i)] = ResidueAxesTopology( + residue_heavy_indices=residue_heavy_indices, + residue_ua_masses=residue_ua_masses, + has_neighbor_bonds=has_neighbor_bonds, + ) + def _add_ua_topology( self, *, @@ -177,6 +249,27 @@ def _add_ua_topology( residue_ua_masses=residue_ua_masses, ) + @staticmethod + def _has_neighbor_bonds(*, mol: Any, local_res_i: int) -> bool: + """Return whether a residue is bonded to neighbouring residues. + + Args: + mol: Molecule AtomGroup used for the original bonded-neighbour + selection. + local_res_i: Residue index local to ``mol``. + + Returns: + True when the residue has bonded atoms in the previous or next + residue according to the original customised residue-axis query. + """ + index_prev = local_res_i - 1 + index_next = local_res_i + 1 + atom_set = mol.select_atoms( + f"(resindex {index_prev} or resindex {index_next}) " + f"and bonded resid {local_res_i}" + ) + return len(atom_set) > 0 + @staticmethod def _split_bonded_atoms(atom: Any) -> tuple[Any, Any]: """Return bonded heavy and light atoms for one atom. diff --git a/CodeEntropy/levels/nodes/covariance.py b/CodeEntropy/levels/nodes/covariance.py index e7f829d..39ff61a 100644 --- a/CodeEntropy/levels/nodes/covariance.py +++ b/CodeEntropy/levels/nodes/covariance.py @@ -129,6 +129,7 @@ def run(self, ctx: FrameCtx) -> dict[str, Any]: group_id=group_id, beads=beads, axes_manager=axes_manager, + axes_topology=axes_topology, box=box, customised_axes=customised_axes, force_partitioning=fp, @@ -242,6 +243,7 @@ def _process_residue( group_id: int, beads: dict[Any, list[Any]], axes_manager: Any, + axes_topology: Any | None, box: np.ndarray | None, customised_axes: bool, force_partitioning: float, @@ -261,6 +263,7 @@ def _process_residue( group_id: Molecule-group identifier used for within-frame averaging. beads: Mapping of bead keys to reduced-universe atom-index arrays. axes_manager: Axes helper used to build translation and rotation axes. + axes_topology: Optional cached axes topology generated during static setup. box: Optional periodic box vector. customised_axes: Whether customised residue axes should be used. force_partitioning: Force partitioning factor for highest-level vectors. @@ -281,9 +284,12 @@ def _process_residue( return force_vecs, torque_vecs = self._build_residue_vectors( + u=u, mol=mol, + mol_id=mol_id, bead_groups=bead_groups, axes_manager=axes_manager, + axes_topology=axes_topology, box=box, customised_axes=customised_axes, force_partitioning=force_partitioning, @@ -481,9 +487,12 @@ def _build_ua_vectors( def _build_residue_vectors( self, *, + u: Any, mol: Any, + mol_id: int, bead_groups: list[Any], axes_manager: Any, + axes_topology: Any | None, box: np.ndarray | None, customised_axes: bool, force_partitioning: float, @@ -492,9 +501,12 @@ def _build_residue_vectors( """Build force and torque vectors for residue beads. Args: + u: Universe-like object used to resolve cached atom indices. mol: Molecule fragment containing residues and atoms. + mol_id: Molecule index used in axes-topology lookup keys. bead_groups: Atom groups representing residue beads. axes_manager: Axes helper used to select axes, centres, and moments. + axes_topology: Optional cached axes topology generated during static setup. box: Optional periodic box vector. customised_axes: Whether customised residue axes should be used. force_partitioning: Force partitioning factor for highest-level vectors. @@ -508,10 +520,14 @@ def _build_residue_vectors( for local_res_i, bead in enumerate(bead_groups): trans_axes, rot_axes, center, moi = self._get_residue_axes( + u=u, mol=mol, + mol_id=mol_id, bead=bead, local_res_i=local_res_i, axes_manager=axes_manager, + axes_topology=axes_topology, + box=box, customised_axes=customised_axes, ) @@ -540,19 +556,27 @@ def _build_residue_vectors( def _get_residue_axes( self, *, + u: Any, mol: Any, + mol_id: int, bead: Any, local_res_i: int, axes_manager: Any, + axes_topology: Any | None, + box: np.ndarray | None, customised_axes: bool, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Return axes, centre, and inertia data for a residue bead. Args: + u: Universe-like object used to resolve cached atom indices. mol: Molecule fragment containing residues and atoms. + mol_id: Molecule index used in axes-topology lookup keys. bead: Atom group representing the residue bead. local_res_i: Residue index local to ``mol``. axes_manager: Axes helper used to select axes, centres, and moments. + axes_topology: Optional cached axes topology generated during static setup. + box: Optional periodic box vector. customised_axes: Whether customised residue axes should be used. Returns: @@ -560,6 +584,19 @@ def _get_residue_axes( """ if customised_axes: res = mol.residues[local_res_i] + residue_topology = None + if axes_topology is not None: + residue_topology = axes_topology.residue.get((mol_id, local_res_i)) + + if residue_topology is not None: + return axes_manager.get_residue_axes_from_topology( + u=u, + mol=mol, + residue_atoms=res.atoms, + topology=residue_topology, + box=box, + ) + return axes_manager.get_residue_axes(mol, local_res_i, residue=res.atoms) make_whole(mol.atoms) From 09bc892872fffde51646ddd4088562e5a34598c9 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Wed, 24 Jun 2026 11:55:56 +0100 Subject: [PATCH 2/2] tests(unit): cover cached residue axes topology --- .../levels/nodes/test_axes_topology.py | 87 ++++++++++++++- .../levels/nodes/test_covariance_node.py | 47 +++++++- tests/unit/CodeEntropy/levels/test_axes.py | 105 +++++++++++++++++- 3 files changed, 233 insertions(+), 6 deletions(-) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py index 680c439..2c31998 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py @@ -7,6 +7,7 @@ from CodeEntropy.levels.nodes.axes_topology import ( AxesTopology, BuildAxesTopologyNode, + ResidueAxesTopology, UAAxesTopology, ) @@ -62,10 +63,14 @@ def __init__(self, atoms): class FakeMolecule: - """Small molecule-like object with residues.""" + """Small molecule-like object with residues and selections.""" - def __init__(self, residues): + def __init__(self, residues, select_map=None): self.residues = residues + self._select_map = dict(select_map or {}) + + def select_atoms(self, selection): + return self._select_map.get(selection, FakeAtomGroup([])) class FakeAtoms: @@ -140,10 +145,23 @@ def test_ua_axes_topology_dataclass_preserves_arrays(): np.testing.assert_allclose(topology.residue_ua_masses, np.array([13.0, 12.0, 14.0])) -def test_axes_topology_defaults_to_empty_ua_mapping(): +def test_residue_axes_topology_dataclass_preserves_arrays(): + topology = ResidueAxesTopology( + residue_heavy_indices=np.array([1, 3, 4]), + residue_ua_masses=np.array([13.0, 12.0, 14.0]), + has_neighbor_bonds=True, + ) + + np.testing.assert_array_equal(topology.residue_heavy_indices, np.array([1, 3, 4])) + np.testing.assert_allclose(topology.residue_ua_masses, np.array([13.0, 12.0, 14.0])) + assert topology.has_neighbor_bonds is True + + +def test_axes_topology_defaults_to_empty_mappings(): topology = AxesTopology() assert topology.ua == {} + assert topology.residue == {} def test_run_writes_empty_topology_when_customised_axes_disabled(): @@ -154,9 +172,72 @@ def test_run_writes_empty_topology_when_customised_axes_disabled(): assert isinstance(result["axes_topology"], AxesTopology) assert result["axes_topology"].ua == {} + assert result["axes_topology"].residue == {} assert shared_data["axes_topology"] is result["axes_topology"] +def test_run_builds_residue_topology_for_residue_levels(): + node = BuildAxesTopologyNode() + universe, molecule, heavy, hydrogen, bonded_heavy, other_residue_heavy = ( + _single_molecule_universe() + ) + neighbor_query = "(resindex -1 or resindex 1) and bonded resid 0" + molecule._select_map[neighbor_query] = FakeAtomGroup([bonded_heavy]) + shared_data = { + "args": _args(customised_axes=True), + "reduced_universe": universe, + "levels": [["residue"]], + "beads": {(0, "residue"): [np.array([1, 2, 3, 4])]}, + } + + result = node.run(shared_data) + + axes_topology = result["axes_topology"] + assert axes_topology.ua == {} + assert set(axes_topology.residue) == {(0, 0)} + + residue_topology = axes_topology.residue[(0, 0)] + np.testing.assert_array_equal( + residue_topology.residue_heavy_indices, + np.array([heavy.index, bonded_heavy.index, other_residue_heavy.index]), + ) + np.testing.assert_allclose( + residue_topology.residue_ua_masses, + np.array([13.0, 12.0, 14.0]), + ) + assert residue_topology.has_neighbor_bonds is True + + +def test_add_residue_topology_skips_when_no_residue_beads(): + node = BuildAxesTopologyNode() + _, molecule, _, _, _, _ = _single_molecule_universe() + out = {} + + node._add_residue_topology( + mol=molecule, + mol_id=0, + beads={}, + out=out, + ) + + assert out == {} + + +def test_has_neighbor_bonds_uses_original_residue_selection(): + node = BuildAxesTopologyNode() + query = "(resindex 0 or resindex 2) and bonded resid 1" + molecule = FakeMolecule([], select_map={query: FakeAtomGroup([FakeAtom(1, 12.0)])}) + + assert node._has_neighbor_bonds(mol=molecule, local_res_i=1) is True + + +def test_has_neighbor_bonds_returns_false_for_empty_selection(): + node = BuildAxesTopologyNode() + molecule = FakeMolecule([]) + + assert node._has_neighbor_bonds(mol=molecule, local_res_i=1) is False + + def test_run_builds_ua_topology_for_united_atom_levels(): node = BuildAxesTopologyNode() universe, _, heavy, hydrogen, bonded_heavy, other_residue_heavy = ( diff --git a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py index 6a3d2ec..9d66e88 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py @@ -1,5 +1,3 @@ -"""Atomic unit tests for frame-local covariance construction.""" - from __future__ import annotations from types import SimpleNamespace @@ -122,6 +120,9 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): assert ua_kwargs["customised_axes"] is True assert ua_kwargs["is_highest"] is False + res_kwargs = node._process_residue.call_args.kwargs + assert res_kwargs["axes_topology"] is axes_topology + def test_run_omits_forcetorque_when_combined_is_false(): node = FrameCovarianceNode() @@ -255,6 +256,7 @@ def test_process_residue_updates_outputs_and_combined_ft(): group_id=7, beads={(0, "residue"): [np.array([0])]}, axes_manager="axes", + axes_topology=None, box=None, customised_axes=True, force_partitioning=0.5, @@ -284,6 +286,7 @@ def test_process_residue_returns_when_no_beads_or_empty_groups(): group_id=7, beads={}, axes_manager=None, + axes_topology=None, box=None, customised_axes=False, force_partitioning=0.5, @@ -305,6 +308,7 @@ def test_process_residue_returns_when_no_beads_or_empty_groups(): group_id=7, beads={(0, "residue"): [np.array([0])]}, axes_manager=None, + axes_topology=None, box=None, customised_axes=False, force_partitioning=0.5, @@ -522,9 +526,12 @@ def test_build_residue_vectors_uses_residue_axes(): node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) force_vecs, torque_vecs = node._build_residue_vectors( + u=FakeUniverse([mol]), mol=mol, + mol_id=0, bead_groups=[FakeAtomGroup("res")], axes_manager=axes_manager, + axes_topology=None, box=None, customised_axes=True, force_partitioning=0.5, @@ -536,6 +543,34 @@ def test_build_residue_vectors_uses_residue_axes(): node._get_residue_axes.assert_called_once() +def test_get_residue_axes_customised_uses_cached_topology_when_available(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + axes_manager = MagicMock() + expected = (np.eye(3), np.eye(3) * 2.0, np.zeros(3), np.ones(3)) + residue_topology = object() + axes_topology = SimpleNamespace(residue={(3, 0): residue_topology}) + axes_manager.get_residue_axes_from_topology.return_value = expected + + result = node._get_residue_axes( + u=FakeUniverse([mol]), + mol=mol, + mol_id=3, + bead=FakeAtomGroup("res"), + local_res_i=0, + axes_manager=axes_manager, + axes_topology=axes_topology, + box=None, + customised_axes=True, + ) + + assert result == expected + called_kwargs = axes_manager.get_residue_axes_from_topology.call_args.kwargs + assert called_kwargs["topology"] is residue_topology + assert called_kwargs["residue_atoms"] is mol.residues[0].atoms + axes_manager.get_residue_axes.assert_not_called() + + def test_get_residue_axes_customised_delegates_to_axes_manager(): node = FrameCovarianceNode() mol = FakeMolecule(n_residues=1) @@ -545,10 +580,14 @@ def test_get_residue_axes_customised_delegates_to_axes_manager(): assert ( node._get_residue_axes( + u=FakeUniverse([mol]), mol=mol, + mol_id=0, bead=FakeAtomGroup("res"), local_res_i=0, axes_manager=axes_manager, + axes_topology=None, + box=None, customised_axes=True, ) == expected @@ -573,10 +612,14 @@ def test_get_residue_axes_vanilla_uses_make_whole_and_vanilla_axes(): with patch("CodeEntropy.levels.nodes.covariance.make_whole") as make_whole: trans_axes, rot_axes, center, moi = node._get_residue_axes( + u=FakeUniverse([mol]), mol=mol, + mol_id=0, bead=bead, local_res_i=0, axes_manager=axes_manager, + axes_topology=None, + box=None, customised_axes=False, ) diff --git a/tests/unit/CodeEntropy/levels/test_axes.py b/tests/unit/CodeEntropy/levels/test_axes.py index 68adfb8..e68f4d1 100644 --- a/tests/unit/CodeEntropy/levels/test_axes.py +++ b/tests/unit/CodeEntropy/levels/test_axes.py @@ -4,7 +4,7 @@ import pytest from CodeEntropy.levels.axes import AxesCalculator -from CodeEntropy.levels.nodes.axes_topology import UAAxesTopology +from CodeEntropy.levels.nodes.axes_topology import ResidueAxesTopology, UAAxesTopology class _FakeAtom: @@ -751,6 +751,109 @@ def _ua_topology( ) +def _residue_topology( + *, + residue_heavy_indices=(1,), + residue_ua_masses=(12.0,), + has_neighbor_bonds=False, +): + """Build a small cached residue topology fixture.""" + return ResidueAxesTopology( + residue_heavy_indices=np.asarray(residue_heavy_indices, dtype=int), + residue_ua_masses=np.asarray(residue_ua_masses, dtype=float), + has_neighbor_bonds=bool(has_neighbor_bonds), + ) + + +def test_get_residue_axes_from_topology_no_neighbor_bonds_uses_cached_indices( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 2.0, 3.0]) + other_heavy = _FakeAtom(3, 14.0, [4.0, 5.0, 6.0]) + universe = _FakeUniverse({1: heavy_atom, 3: other_heavy}) + mol = MagicMock() + residue_atoms = MagicMock() + residue_atoms.center_of_mass.return_value = np.array([9.0, 8.0, 7.0]) + topology = _residue_topology( + residue_heavy_indices=(1, 3), + residue_ua_masses=(13.0, 14.0), + has_neighbor_bonds=False, + ) + + get_tensor = MagicMock(return_value=np.eye(3)) + get_principal = MagicMock(return_value=(np.eye(3) * 2.0, np.array([3.0, 2.0, 1.0]))) + + monkeypatch.setattr(ax, "get_moment_of_inertia_tensor", get_tensor) + monkeypatch.setattr(ax, "get_custom_principal_axes", get_principal) + + box = np.array([20.0, 30.0, 40.0]) + trans_axes, rot_axes, center, moi = ax.get_residue_axes_from_topology( + u=universe, + mol=mol, + residue_atoms=residue_atoms, + topology=topology, + box=box, + ) + + np.testing.assert_allclose(trans_axes, np.eye(3) * 2.0) + np.testing.assert_allclose(rot_axes, np.eye(3) * 2.0) + np.testing.assert_allclose(center, np.array([9.0, 8.0, 7.0])) + np.testing.assert_allclose(moi, np.array([3.0, 2.0, 1.0])) + + tensor_kwargs = get_tensor.call_args.kwargs + np.testing.assert_allclose( + tensor_kwargs["center_of_mass"], + np.array([9.0, 8.0, 7.0]), + ) + np.testing.assert_allclose( + tensor_kwargs["positions"], + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + ) + np.testing.assert_allclose(tensor_kwargs["masses"], np.array([13.0, 14.0])) + np.testing.assert_allclose(tensor_kwargs["dimensions"], box) + get_principal.assert_called_once() + + +def test_get_residue_axes_from_topology_neighbor_bonds_uses_vanilla_axes( + monkeypatch, +): + ax = AxesCalculator() + + universe = _FakeUniverse( + {}, + dimensions=[11.0, 12.0, 13.0, 90.0, 90.0, 90.0], + ) + mol = MagicMock() + mol.atoms.principal_axes.return_value = np.eye(3) * 5.0 + residue_atoms = MagicMock() + residue_atoms.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) + topology = _residue_topology(has_neighbor_bonds=True) + + make_whole = MagicMock() + get_vanilla = MagicMock(return_value=(np.eye(3) * 6.0, np.array([6.0, 5.0, 4.0]))) + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", make_whole) + monkeypatch.setattr(ax, "get_vanilla_axes", get_vanilla) + + trans_axes, rot_axes, center, moi = ax.get_residue_axes_from_topology( + u=universe, + mol=mol, + residue_atoms=residue_atoms, + topology=topology, + box=None, + ) + + make_whole.assert_called_once_with(mol.atoms) + mol.atoms.principal_axes.assert_called_once() + get_vanilla.assert_called_once_with(residue_atoms) + np.testing.assert_allclose(trans_axes, np.eye(3) * 5.0) + np.testing.assert_allclose(rot_axes, np.eye(3) * 6.0) + np.testing.assert_allclose(center, np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose(moi, np.array([6.0, 5.0, 4.0])) + + def test_get_UA_axes_from_topology_multiple_heavy_uses_cached_indices_and_box( monkeypatch, ):