diff --git a/CodeEntropy/levels/axes.py b/CodeEntropy/levels/axes.py index 57ef8912..cb0af142 100644 --- a/CodeEntropy/levels/axes.py +++ b/CodeEntropy/levels/axes.py @@ -216,6 +216,161 @@ def get_UA_axes(self, data_container, index: int): return trans_axes, rot_axes, center, moment_of_inertia + def get_UA_axes_from_topology( + self, + *, + u, + residue_atoms, + topology, + box: np.ndarray | None, + ): + """Compute UA axes using cached static topology. + + This is the cached-index equivalent of ``get_UA_axes``. It preserves the + frame-dependent numerical calculations, but avoids repeated MDAnalysis + selection strings for heavy atoms, bonded atoms, and UA masses. + + Args: + u: Current-frame universe. + residue_atoms: AtomGroup for the parent residue in the current frame. + topology: Cached ``UAAxesTopology`` for this UA bead. + box: Current periodic box lengths. If omitted, ``u.dimensions`` is used. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - trans_axes: Translational axes, shape ``(3, 3)``. + - rot_axes: Rotational axes, shape ``(3, 3)``. + - center: Rotation centre, shape ``(3,)``. + - moment_of_inertia: Principal moments, shape ``(3,)``. + + Raises: + ValueError: If cached bonded-axis construction fails. + """ + dimensions = ( + np.asarray(box, dtype=float) + if box is not None + else np.asarray(u.dimensions[:3], dtype=float) + ) + + heavy_atoms = u.atoms[topology.residue_heavy_indices] + heavy_atom = u.atoms[int(topology.heavy_atom_index)] + + if len(heavy_atoms) > 1: + center = residue_atoms.center_of_mass(unwrap=True) + moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( + center_of_mass=center, + positions=heavy_atoms.positions, + masses=topology.residue_ua_masses, + dimensions=dimensions, + ) + trans_axes, _moment_of_inertia = self.get_custom_principal_axes( + moment_of_inertia_tensor + ) + else: + make_whole(residue_atoms) + trans_axes = residue_atoms.principal_axes() + + center = heavy_atom.position + rot_axes, moment_of_inertia = self.get_bonded_axes_from_topology( + u=u, + heavy_atom=heavy_atom, + topology=topology, + dimensions=dimensions, + ) + if rot_axes is None or moment_of_inertia is None: + raise ValueError("Unable to compute bonded axes for cached UA bead.") + + logger.debug("Translational Axes: %s", trans_axes) + logger.debug("Rotational Axes: %s", rot_axes) + logger.debug("Center: %s", center) + logger.debug("Moment of Inertia: %s", moment_of_inertia) + + return trans_axes, rot_axes, center, moment_of_inertia + + def get_bonded_axes_from_topology( + self, + *, + u, + heavy_atom, + topology, + dimensions: np.ndarray, + ): + """Compute UA bonded axes using cached bonded atom indices. + + This mirrors ``get_bonded_axes`` but receives precomputed bonded atom + memberships from ``UAAxesTopology`` instead of rediscovering them with + MDAnalysis selection strings inside the frame loop. + + Args: + u: Current-frame universe. + heavy_atom: Current-frame heavy atom for the UA bead. + topology: Cached ``UAAxesTopology`` for the UA bead. + dimensions: Simulation box lengths, shape ``(3,)``. + + Returns: + Tuple[np.ndarray | None, np.ndarray | None]: + - custom_axes: Custom rotation axes, shape ``(3, 3)``, or ``None``. + - custom_moment_of_inertia: Principal moments, shape ``(3,)``, or + ``None``. + """ + if not heavy_atom.mass > 1.1: + return None, None + + custom_moment_of_inertia = None + custom_axes = None + + heavy_bonded = u.atoms[topology.bonded_heavy_indices] + light_bonded = u.atoms[topology.bonded_light_indices] + ua = u.atoms[topology.ua_atom_indices] + ua_all = u.atoms[topology.ua_all_atom_indices] + + if len(heavy_bonded) == 0: + custom_axes, custom_moment_of_inertia = self.get_vanilla_axes(ua_all) + + if len(heavy_bonded) == 1 and len(light_bonded) == 0: + custom_axes = self.get_custom_axes( + a=heavy_atom.position, + b_list=[heavy_bonded[0].position], + c=np.zeros(3), + dimensions=dimensions, + ) + + if len(heavy_bonded) == 1 and len(light_bonded) >= 1: + custom_axes = self.get_custom_axes( + a=heavy_atom.position, + b_list=[heavy_bonded[0].position], + c=light_bonded[0].position, + dimensions=dimensions, + ) + + if len(heavy_bonded) >= 2: + custom_axes = self.get_custom_axes( + a=heavy_atom.position, + b_list=heavy_bonded.positions, + c=heavy_bonded[1].position, + dimensions=dimensions, + ) + + if custom_axes is None: + return None, None + + if custom_moment_of_inertia is None: + custom_moment_of_inertia = self.get_custom_moment_of_inertia( + UA=ua, + custom_rotation_axes=custom_axes, + center_of_mass=heavy_atom.position, + dimensions=dimensions, + ) + + custom_axes = self.get_flipped_axes( + ua, + custom_axes, + heavy_atom.position, + dimensions, + ) + + return custom_axes, custom_moment_of_inertia + def get_bonded_axes(self, system, atom, dimensions: np.ndarray): r"""Compute UA rotational axes from bonded topology around a heavy atom. diff --git a/CodeEntropy/levels/dihedrals/topology.py b/CodeEntropy/levels/dihedrals/topology.py index 680639a1..31e35ffd 100644 --- a/CodeEntropy/levels/dihedrals/topology.py +++ b/CodeEntropy/levels/dihedrals/topology.py @@ -59,9 +59,7 @@ def _discover_group_dihedral_topology( topologies: list[MoleculeDihedralTopology] = [] for molecule_order, molecule_id in enumerate(molecules): - mol = self._universe_operations.extract_fragment( - data_container, molecule_id - ) + mol = self._extract_topology_fragment(data_container, molecule_id) num_residues = len(mol.residues) ua_dihedrals_by_residue: dict[int, list[Any]] = {} residue_dihedrals: list[Any] = [] @@ -90,6 +88,26 @@ def _discover_group_dihedral_topology( return topologies + def _extract_topology_fragment(self, data_container: Any, molecule_id: Any) -> Any: + """Return a molecule fragment for topology discovery. + + This uses the lightweight AtomGroup extraction helper when available so + static conformational topology discovery does not create a standalone + in-memory universe or copy trajectory frames. The fallback preserves + compatibility with older ``UniverseOperations`` implementations. + + Args: + data_container: Source MDAnalysis universe or universe-like container. + molecule_id: Fragment index identifying the molecule to extract. + + Returns: + MDAnalysis AtomGroup for the selected molecule + """ + return self._universe_operations.extract_fragment_atomgroup( + data_container, + int(molecule_id), + ) + def _select_heavy_residue(self, mol: Any, res_id: int) -> Any: """Select heavy atoms in a residue by residue index. @@ -100,13 +118,15 @@ def _select_heavy_residue(self, mol: Any, res_id: int) -> Any: Returns: AtomGroup containing heavy atoms in the residue selection. """ - selection1 = mol.residues[res_id].atoms.indices[0] - selection2 = mol.residues[res_id].atoms.indices[-1] + residue_atoms = mol.residues[int(res_id)].atoms + selection1 = residue_atoms.indices[0] + selection2 = residue_atoms.indices[-1] - res_container = self._universe_operations.select_atoms( - mol, f"index {selection1}:{selection2}" + res_container = mol.select_atoms( + f"index {selection1}:{selection2}", + updating=False, ) - return self._universe_operations.select_atoms(res_container, "prop mass > 1.1") + return res_container.select_atoms("prop mass > 1.1", updating=False) def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]: """Return dihedral AtomGroups for a container at a given level. @@ -121,26 +141,93 @@ def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]: atom_groups: list[Any] = [] if level == "united_atom": + selected_indices = {int(index) for index in data_container.indices} + for dihedral in data_container.dihedrals: - atom_groups.append(dihedral.atoms) + dihedral_atoms = dihedral.atoms + dihedral_indices = {int(index) for index in dihedral_atoms.indices} + + if len(dihedral_atoms) == 4 and dihedral_indices.issubset( + selected_indices + ): + atom_groups.append(dihedral_atoms) if level == "residue": num_residues = len(data_container.residues) if num_residues >= 4: for residue in range(4, num_residues + 1): - atom1 = data_container.select_atoms( - f"resindex {residue - 4} and bonded resindex {residue - 3}" + residue1 = data_container.residues[residue - 4] + residue2 = data_container.residues[residue - 3] + residue3 = data_container.residues[residue - 2] + residue4 = data_container.residues[residue - 1] + + atom1 = self._atoms_in_source_bonded_to_target( + residue1, + residue2, ) - atom2 = data_container.select_atoms( - f"resindex {residue - 3} and bonded resindex {residue - 4}" + atom2 = self._atoms_in_source_bonded_to_target( + residue2, + residue1, ) - atom3 = data_container.select_atoms( - f"resindex {residue - 2} and bonded resindex {residue - 1}" + atom3 = self._atoms_in_source_bonded_to_target( + residue3, + residue4, ) - atom4 = data_container.select_atoms( - f"resindex {residue - 1} and bonded resindex {residue - 2}" + atom4 = self._atoms_in_source_bonded_to_target( + residue4, + residue3, ) - atom_groups.append(atom1 + atom2 + atom3 + atom4) + + dihedral_atoms = atom1 + atom2 + atom3 + atom4 + + if len(dihedral_atoms) == 4: + atom_groups.append(dihedral_atoms) + else: + logger.debug( + "Skipping residue-level dihedral for local residues " + "%s-%s-%s-%s because it produced %d atoms.", + residue - 4, + residue - 3, + residue - 2, + residue - 1, + len(dihedral_atoms), + ) logger.debug("Level: %s, Dihedrals: %s", level, atom_groups) return atom_groups + + @staticmethod + def _atoms_in_source_bonded_to_target( + source_residue: Any, + target_residue: Any, + ) -> Any: + """Return source-residue atoms bonded to atoms in a target residue. + + This helper is used when constructing residue-level dihedral definitions + from lightweight molecule AtomGroups. It selects atoms from the source + residue that are bonded to any atom in the target residue without using + global ``resindex`` selection strings. + + Args: + source_residue: Residue whose atoms should be tested for bonds. + target_residue: Adjacent residue providing the target bonded atoms. + + Returns: + MDAnalysis AtomGroup containing atoms from ``source_residue`` that are + bonded to at least one atom in ``target_residue``. If no matching + atoms are found, an empty AtomGroup is returned. + """ + source_atoms = source_residue.atoms + target_indices = {int(index) for index in target_residue.atoms.indices} + selected_indices: list[int] = [] + + for atom in source_atoms: + bonded_atoms = getattr(atom, "bonded_atoms", None) + if bonded_atoms is None: + continue + + bonded_indices = {int(index) for index in bonded_atoms.indices} + if bonded_indices.intersection(target_indices): + selected_indices.append(int(atom.index)) + + return source_atoms.universe.atoms[selected_indices] diff --git a/CodeEntropy/levels/level_dag.py b/CodeEntropy/levels/level_dag.py index cc8129a1..d299fd0c 100644 --- a/CodeEntropy/levels/level_dag.py +++ b/CodeEntropy/levels/level_dag.py @@ -19,6 +19,7 @@ from CodeEntropy.levels.frame_dag import FrameGraph from CodeEntropy.levels.neighbors import Neighbors from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode +from CodeEntropy.levels.nodes.axes_topology import BuildAxesTopologyNode from CodeEntropy.levels.nodes.beads import BuildBeadsNode from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode @@ -49,6 +50,11 @@ def build(self) -> LevelDAG: self._add_static("detect_molecules", DetectMoleculesNode()) self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"]) self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"]) + self._add_static( + "build_axes_topology", + BuildAxesTopologyNode(), + deps=["build_beads"], + ) self._add_static( "init_covariance_accumulators", InitCovarianceAccumulatorsNode(), diff --git a/CodeEntropy/levels/nodes/axes_topology.py b/CodeEntropy/levels/nodes/axes_topology.py new file mode 100644 index 00000000..abac3739 --- /dev/null +++ b/CodeEntropy/levels/nodes/axes_topology.py @@ -0,0 +1,223 @@ +"""Build static axes-topology metadata for frame covariance calculations. + +This module caches topology-only atom-index relationships needed by customised +united-atom axes calculations. The cache avoids repeated MDAnalysis selection +parsing inside the frame-local covariance loop while preserving frame-dependent +positions, forces, centres, axes, torques, and moments of inertia. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +UAKey = tuple[int, int, int] + + +@dataclass(frozen=True) +class UAAxesTopology: + """Static topology required to compute customised united-atom axes. + + Attributes: + heavy_atom_index: Reduced-universe atom index for the UA heavy atom. + ua_atom_indices: Atom indices for the UA heavy atom and its bonded + hydrogens/light atoms. + ua_all_atom_indices: Atom indices for the UA heavy atom, bonded heavy + atoms, and bonded hydrogens/light atoms. + bonded_heavy_indices: Heavy atoms bonded to the UA heavy atom. + bonded_light_indices: Hydrogens/light atoms bonded to the UA heavy atom. + residue_heavy_indices: Heavy atoms in the parent residue. + residue_ua_masses: UA masses for heavy atoms in the parent residue. + """ + + heavy_atom_index: int + ua_atom_indices: np.ndarray + ua_all_atom_indices: np.ndarray + bonded_heavy_indices: np.ndarray + bonded_light_indices: np.ndarray + residue_heavy_indices: np.ndarray + residue_ua_masses: np.ndarray + + +@dataclass(frozen=True) +class AxesTopology: + """Cached axes topology for frame covariance calculations. + + Attributes: + ua: Mapping from ``(mol_id, local_residue_id, ua_id)`` to cached + united-atom axes topology. + """ + + ua: dict[UAKey, UAAxesTopology] = field(default_factory=dict) + + +class BuildAxesTopologyNode: + """Build static customised-axes topology before frame covariance execution.""" + + def run(self, shared_data: dict[str, Any]) -> dict[str, Any]: + """Build cached axes topology and write it into shared data. + + The cache is only populated when ``args.customised_axes`` is true. When + customised axes are disabled, an empty cache is still written so later + stages can read ``shared_data["axes_topology"]`` safely. + + Args: + shared_data: Shared workflow data containing ``args`` and, when + customised axes are enabled, ``reduced_universe``, ``levels``, + and ``beads``. + + Returns: + Dict containing the cached ``axes_topology`` object. + """ + args = shared_data["args"] + topology = AxesTopology() + + if not bool(getattr(args, "customised_axes", False)): + shared_data["axes_topology"] = topology + return {"axes_topology": topology} + + u = shared_data["reduced_universe"] + levels = shared_data["levels"] + beads = shared_data["beads"] + + ua_topology: dict[UAKey, UAAxesTopology] = {} + fragments = u.atoms.fragments + + for mol_id, level_list in enumerate(levels): + if "united_atom" not in level_list: + continue + + self._add_ua_topology( + u=u, + mol=fragments[mol_id], + mol_id=mol_id, + beads=beads, + out=ua_topology, + ) + + topology = AxesTopology(ua=ua_topology) + shared_data["axes_topology"] = topology + return {"axes_topology": topology} + + def _add_ua_topology( + self, + *, + u: Any, + mol: Any, + mol_id: int, + beads: dict[Any, list[np.ndarray]], + out: dict[UAKey, UAAxesTopology], + ) -> None: + """Cache static UA axes topology for one molecule. + + Args: + u: Reduced universe used to resolve bead atom-index arrays. + mol: Molecule AtomGroup. + mol_id: Molecule index. + beads: Bead-index mapping produced by ``BuildBeadsNode``. + out: Output UA topology mapping mutated in place. + """ + for local_res_i, residue in enumerate(mol.residues): + bead_key = (mol_id, "united_atom", local_res_i) + bead_idx_list = beads.get(bead_key, []) + + if not bead_idx_list: + continue + + residue_atoms = residue.atoms + residue_heavy = residue_atoms.select_atoms("prop mass > 1.1") + residue_heavy_indices = residue_heavy.indices.astype(int, copy=True) + residue_ua_masses = np.asarray( + self._get_ua_masses_from_topology(residue_atoms), + dtype=float, + ) + + for ua_i, bead_indices in enumerate(bead_idx_list): + bead = u.atoms[bead_indices] + heavy = bead.select_atoms("prop mass > 1.1") + + if len(heavy) == 0: + logger.warning( + "Skipping UA axes topology with no heavy atom: " + "mol=%s residue=%s ua=%s", + mol_id, + local_res_i, + ua_i, + ) + continue + + heavy_atom = heavy[0] + bonded_heavy, bonded_light = self._split_bonded_atoms(heavy_atom) + + heavy_index = np.asarray([int(heavy_atom.index)], dtype=int) + bonded_heavy_indices = bonded_heavy.indices.astype(int, copy=True) + bonded_light_indices = bonded_light.indices.astype(int, copy=True) + + ua_atom_indices = np.concatenate( + [heavy_index, bonded_light_indices], + axis=0, + ) + ua_all_atom_indices = np.concatenate( + [heavy_index, bonded_heavy_indices, bonded_light_indices], + axis=0, + ) + + out[(mol_id, local_res_i, ua_i)] = UAAxesTopology( + heavy_atom_index=int(heavy_atom.index), + ua_atom_indices=ua_atom_indices, + ua_all_atom_indices=ua_all_atom_indices, + bonded_heavy_indices=bonded_heavy_indices, + bonded_light_indices=bonded_light_indices, + residue_heavy_indices=residue_heavy_indices, + residue_ua_masses=residue_ua_masses, + ) + + @staticmethod + def _split_bonded_atoms(atom: Any) -> tuple[Any, Any]: + """Return bonded heavy and light atoms for one atom. + + Args: + atom: MDAnalysis Atom. + + Returns: + Tuple containing bonded heavy atoms and bonded hydrogens/light atoms. + """ + bonded_atoms = atom.bonded_atoms + bonded_heavy = bonded_atoms.select_atoms("mass 2 to 999") + bonded_light = bonded_atoms.select_atoms("mass 1 to 1.1") + return bonded_heavy, bonded_light + + @staticmethod + def _get_ua_masses_from_topology(atom_group: Any) -> list[float]: + """Return UA masses using static bonded atom relationships. + + Args: + atom_group: AtomGroup containing atoms from one residue. + + Returns: + List of UA masses, one for each heavy atom in ``atom_group``. + """ + ua_masses: list[float] = [] + + for atom in atom_group: + if atom.mass <= 1.1: + continue + + ua_mass = float(atom.mass) + bonded_atoms = getattr(atom, "bonded_atoms", None) + if bonded_atoms is None: + ua_masses.append(ua_mass) + continue + + bonded_h_atoms = bonded_atoms.select_atoms("mass 1 to 1.1") + for hydrogen in bonded_h_atoms: + ua_mass += float(hydrogen.mass) + + ua_masses.append(ua_mass) + + return ua_masses diff --git a/CodeEntropy/levels/nodes/covariance.py b/CodeEntropy/levels/nodes/covariance.py index 4331cc6a..e7f829d5 100644 --- a/CodeEntropy/levels/nodes/covariance.py +++ b/CodeEntropy/levels/nodes/covariance.py @@ -79,6 +79,7 @@ def run(self, ctx: FrameCtx) -> dict[str, Any]: beads = shared["beads"] args = shared["args"] axes_manager = shared.get("axes_manager") + axes_topology = shared.get("axes_topology") fp = float(args.force_partitioning) combined = bool(getattr(args, "combined_forcetorque", False)) @@ -110,6 +111,7 @@ def run(self, ctx: FrameCtx) -> dict[str, Any]: group_id=group_id, beads=beads, axes_manager=axes_manager, + axes_topology=axes_topology, box=box, force_partitioning=fp, customised_axes=customised_axes, @@ -172,6 +174,7 @@ def _process_united_atom( group_id: int, beads: dict[Any, list[Any]], axes_manager: Any, + axes_topology: Any | None, box: np.ndarray | None, force_partitioning: float, customised_axes: bool, @@ -189,6 +192,7 @@ def _process_united_atom( group_id: Molecule-group identifier used for within-frame averaging. beads: Mapping of bead keys to reduced-universe atom-index arrays. axes_manager: Axes helper used to build translation and rotation axes. + axes_topology: Optional cached axes topology generated during static setup. box: Optional periodic box vector. force_partitioning: Force partitioning factor for highest-level vectors. customised_axes: Whether customised UA axes should be used. @@ -208,9 +212,13 @@ def _process_united_atom( continue force_vecs, torque_vecs = self._build_ua_vectors( + u=u, + mol_id=mol_id, + local_res_i=local_res_i, residue_atoms=res.atoms, bead_groups=bead_groups, axes_manager=axes_manager, + axes_topology=axes_topology, box=box, force_partitioning=force_partitioning, customised_axes=customised_axes, @@ -388,9 +396,13 @@ def _process_polymer( def _build_ua_vectors( self, *, + u: Any, + mol_id: int, + local_res_i: int, bead_groups: list[Any], residue_atoms: Any, axes_manager: Any, + axes_topology: Any | None, box: np.ndarray | None, force_partitioning: float, customised_axes: bool, @@ -399,9 +411,13 @@ def _build_ua_vectors( """Build force and torque vectors for united-atom beads. Args: + u: Universe-like object used to resolve cached atom indices. + mol_id: Molecule index used in axes-topology lookup keys. + local_res_i: Local residue index used in axes-topology lookup keys. bead_groups: Atom groups representing UA beads in a residue. residue_atoms: Atom group for the parent residue. axes_manager: Axes helper used to select axes, centres, and moments. + axes_topology: Optional cached axes topology generated during static setup. box: Optional periodic box vector. force_partitioning: Force partitioning factor for highest-level vectors. customised_axes: Whether customised UA axes should be used. @@ -415,9 +431,23 @@ def _build_ua_vectors( for ua_i, bead in enumerate(bead_groups): if customised_axes: - trans_axes, rot_axes, center, moi = axes_manager.get_UA_axes( - residue_atoms, ua_i - ) + ua_topology = None + if axes_topology is not None: + ua_topology = axes_topology.ua.get((mol_id, local_res_i, ua_i)) + + if ua_topology is not None: + trans_axes, rot_axes, center, moi = ( + axes_manager.get_UA_axes_from_topology( + u=u, + residue_atoms=residue_atoms, + topology=ua_topology, + box=box, + ) + ) + else: + trans_axes, rot_axes, center, moi = axes_manager.get_UA_axes( + residue_atoms, ua_i + ) else: make_whole(residue_atoms) make_whole(bead) diff --git a/CodeEntropy/trajectory/mda.py b/CodeEntropy/trajectory/mda.py index 04943a4c..1cdb6e75 100644 --- a/CodeEntropy/trajectory/mda.py +++ b/CodeEntropy/trajectory/mda.py @@ -25,7 +25,8 @@ class UniverseOperations: This helper provides methods to: - Build reduced universes by selecting subsets of frames or atoms. - Extract a single fragment (molecule) into a standalone universe. - - Merge coordinates from one trajectory with forces from another trajectory. + - Merge coordinates from one trajectory with forces sourced from another + trajectory. """ def __init__(self) -> None: @@ -186,6 +187,25 @@ def extract_fragment( selection_string = f"index {frag.indices[0]}:{frag.indices[-1]}" return self.select_atoms(universe, selection_string) + def extract_fragment_atomgroup(self, universe: mda.Universe, molecule_id: int): + """Return a molecule fragment as an AtomGroup. + + This helper mirrors the atom-index range used by ``extract_fragment`` but + avoids building a standalone in-memory universe. It is intended for + topology discovery paths where only the static atom selection is needed + and trajectory coordinates do not need to be copied. + + Args: + universe: Source MDAnalysis universe. + molecule_id: Fragment index in ``universe.atoms.fragments``. + + Returns: + MDAnalysis AtomGroup containing the atoms in the selected fragment. + """ + frag = universe.atoms.fragments[int(molecule_id)] + selection_string = f"index {frag.indices[0]}:{frag.indices[-1]}" + return universe.select_atoms(selection_string, updating=False) + def convert_lammps( self, tprfile: str, diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py b/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py index 60affa71..cbc38b19 100644 --- a/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py @@ -1,146 +1,472 @@ from __future__ import annotations -from unittest.mock import MagicMock +import logging +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, call -import numpy as np +import pytest -from CodeEntropy.levels.dihedrals.topology import DihedralTopologyDiscovery +from CodeEntropy.levels.dihedrals.topology import ( + DihedralTopologyDiscovery, + MoleculeDihedralTopology, +) -class _AddableAG: - """Minimal addable AtomGroup test double.""" +@dataclass +class FakeAtom: + """Small test double for an MDAnalysis Atom.""" - def __init__(self, name: str) -> None: - """Initialize the fake AtomGroup. + index: int + bonded_atoms: Any = None - Args: - name: Human-readable identifier used in composed names. - """ - self.name = name - def __add__(self, other: _AddableAG) -> _AddableAG: - """Return a composed fake AtomGroup. +class FakeUniverseAtoms: + """Indexable fake universe atom container.""" - Args: - other: Fake AtomGroup to combine with this object. + def __init__(self, universe: FakeUniverse) -> None: + self._universe = universe - Returns: - New fake AtomGroup containing a composed name. - """ - return _AddableAG(f"({self.name}+{other.name})") + def __getitem__(self, indices: int | list[int] | tuple[int, ...]) -> Any: + if isinstance(indices, int): + return self._universe.atom_by_index[int(indices)] + return FakeAtomGroup( + [self._universe.atom_by_index[int(index)] for index in indices], + universe=self._universe, + ) -class _TopologyDiscovery(DihedralTopologyDiscovery): - """Concrete topology-discovery helper for unit tests.""" - def __init__(self, universe_operations: MagicMock) -> None: - """Initialize the test helper. +class FakeUniverse: + """Small fake universe supporting ``universe.atoms[indices]``.""" - Args: - universe_operations: Mock universe-operation adapter. - """ - self._universe_operations = universe_operations + def __init__(self, atoms: list[FakeAtom]) -> None: + self.atom_by_index = {int(atom.index): atom for atom in atoms} + self.atoms = FakeUniverseAtoms(self) -def test_select_heavy_residue_builds_expected_selections(): - uops = MagicMock() - helper = _TopologyDiscovery(universe_operations=uops) +class FakeAtomGroup: + """Small AtomGroup-like test double for topology discovery tests.""" - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([10, 11, 12], dtype=int) - uops.select_atoms.side_effect = ["residue_atoms", "heavy_atoms"] + def __init__( + self, + atoms: list[FakeAtom], + *, + residues: list[Any] | None = None, + dihedrals: list[Any] | None = None, + select_map: dict[str, Any] | None = None, + universe: FakeUniverse | None = None, + ) -> None: + self._atoms = list(atoms) + self.residues = list(residues or []) + self.dihedrals = list(dihedrals or []) + self._select_map = dict(select_map or {}) + self.universe = universe - out = helper._select_heavy_residue(mol, res_id=0) + @property + def atoms(self) -> FakeAtomGroup: + return self - assert out == "heavy_atoms" - assert uops.select_atoms.call_args_list == [ - ((mol, "index 10:12"),), - (("residue_atoms", "prop mass > 1.1"),), - ] + @property + def indices(self) -> list[int]: + return [int(atom.index) for atom in self._atoms] + + def __iter__(self): + return iter(self._atoms) + def __len__(self) -> int: + return len(self._atoms) -def test_get_dihedrals_united_atom_collects_atoms_from_dihedral_objects(): - helper = _TopologyDiscovery(universe_operations=MagicMock()) + def __add__(self, other: FakeAtomGroup) -> FakeAtomGroup: + return FakeAtomGroup( + self._atoms + other._atoms, + universe=self.universe or other.universe, + ) - d0 = MagicMock() - d0.atoms = "A0" - d1 = MagicMock() - d1.atoms = "A1" + def select_atoms(self, select_string: str, updating: bool = False) -> Any: + if select_string not in self._select_map: + raise AssertionError(f"Unexpected selection: {select_string!r}") + return self._select_map[select_string] - container = MagicMock() - container.dihedrals = [d0, d1] - assert helper._get_dihedrals(container, level="united_atom") == ["A0", "A1"] +@dataclass +class FakeResidue: + """Small residue test double.""" + atoms: FakeAtomGroup -def test_get_dihedrals_residue_returns_empty_when_less_than_four_residues(): - helper = _TopologyDiscovery(universe_operations=MagicMock()) - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock(), MagicMock()] - mol.select_atoms = MagicMock() +@dataclass +class FakeDihedral: + """Small dihedral topology object.""" - assert helper._get_dihedrals(mol, level="residue") == [] - mol.select_atoms.assert_not_called() + atoms: FakeAtomGroup -def test_get_dihedrals_residue_builds_one_dihedral_when_four_residues(): - helper = _TopologyDiscovery(universe_operations=MagicMock()) +def _make_discovery( + universe_operations: Any | None = None, +) -> DihedralTopologyDiscovery: + discovery = DihedralTopologyDiscovery() + discovery._universe_operations = universe_operations or Mock() + return discovery - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] - mol.select_atoms = MagicMock( - side_effect=[ - _AddableAG("a1"), - _AddableAG("a2"), - _AddableAG("a3"), - _AddableAG("a4"), - ] + +def test_molecule_dihedral_topology_stores_expected_fields() -> None: + topology = MoleculeDihedralTopology( + group_id=1, + molecule_id=2, + molecule_order=3, + num_residues=4, + ua_dihedrals_by_residue={0: ["ua"]}, + residue_dihedrals=["res"], ) - out = helper._get_dihedrals(mol, level="residue") + assert topology.group_id == 1 + assert topology.molecule_id == 2 + assert topology.molecule_order == 3 + assert topology.num_residues == 4 + assert topology.ua_dihedrals_by_residue == {0: ["ua"]} + assert topology.residue_dihedrals == ["res"] - assert len(out) == 1 - assert isinstance(out[0], _AddableAG) - assert mol.select_atoms.call_count == 4 +def test_extract_topology_fragment_uses_lightweight_atomgroup_helper() -> None: + universe_operations = Mock() + universe_operations.extract_fragment_atomgroup.return_value = "fragment_atomgroup" + universe_operations.extract_fragment.return_value = "heavy_fragment" -def test_discover_group_dihedral_topology_builds_one_entry_per_molecule(): - uops = MagicMock() - helper = _TopologyDiscovery(universe_operations=uops) + discovery = _make_discovery(universe_operations) - mol0 = MagicMock() - mol0.residues = [MagicMock(), MagicMock()] - mol1 = MagicMock() - mol1.residues = [MagicMock(), MagicMock()] - uops.extract_fragment.side_effect = [mol0, mol1] + result = discovery._extract_topology_fragment("universe", 5) - helper._select_heavy_residue = MagicMock( - side_effect=["heavy0", "heavy1", "heavy2", "heavy3"] + assert result == "fragment_atomgroup" + universe_operations.extract_fragment_atomgroup.assert_called_once_with( + "universe", + 5, ) - helper._get_dihedrals = MagicMock( - side_effect=[ - ["ua0r0"], - ["ua0r1"], - ["res0"], - ["ua1r0"], - ["ua1r1"], - ["res1"], - ] + universe_operations.extract_fragment.assert_not_called() + + +def test_select_heavy_residue_builds_expected_selections() -> None: + discovery = _make_discovery() + + residue_atoms = Mock() + residue_atoms.indices = [10, 11, 12, 13] + residue = SimpleNamespace(atoms=residue_atoms) + + heavy_atoms = object() + residue_container = Mock() + residue_container.select_atoms.return_value = heavy_atoms + + mol = Mock() + mol.residues = [residue] + mol.select_atoms.return_value = residue_container + + result = discovery._select_heavy_residue(mol, 0) + + assert result is heavy_atoms + mol.select_atoms.assert_called_once_with("index 10:13", updating=False) + residue_container.select_atoms.assert_called_once_with( + "prop mass > 1.1", + updating=False, ) + discovery._universe_operations.select_atoms.assert_not_called() + + +def test_get_dihedrals_united_atom_collects_atoms_from_dihedral_objects() -> None: + discovery = _make_discovery() + + atoms = [FakeAtom(index) for index in range(1, 8)] + universe = FakeUniverse(atoms) + + valid_dihedral_atoms = FakeAtomGroup( + [atoms[0], atoms[1], atoms[2], atoms[3]], + universe=universe, + ) + outside_selection_atoms = FakeAtomGroup( + [atoms[0], atoms[1], atoms[2], atoms[6]], + universe=universe, + ) + wrong_size_atoms = FakeAtomGroup( + [atoms[0], atoms[1], atoms[2]], + universe=universe, + ) + + selected_heavy_atoms = FakeAtomGroup( + [atoms[0], atoms[1], atoms[2], atoms[3], atoms[4]], + dihedrals=[ + FakeDihedral(valid_dihedral_atoms), + FakeDihedral(outside_selection_atoms), + FakeDihedral(wrong_size_atoms), + ], + universe=universe, + ) + + result = discovery._get_dihedrals(selected_heavy_atoms, "united_atom") + + assert result == [valid_dihedral_atoms] + + +def test_get_dihedrals_united_atom_returns_empty_when_no_valid_dihedrals() -> None: + discovery = _make_discovery() + + atoms = [FakeAtom(index) for index in range(1, 5)] + universe = FakeUniverse(atoms) + + invalid_dihedral_atoms = FakeAtomGroup( + [atoms[0], atoms[1], atoms[2]], + universe=universe, + ) + selected_heavy_atoms = FakeAtomGroup( + atoms, + dihedrals=[FakeDihedral(invalid_dihedral_atoms)], + universe=universe, + ) + + assert discovery._get_dihedrals(selected_heavy_atoms, "united_atom") == [] + + +def test_get_dihedrals_residue_builds_one_dihedral_when_four_residues() -> None: + discovery = _make_discovery() + + atom1 = FakeAtom(10) + atom2 = FakeAtom(20) + atom3 = FakeAtom(30) + atom4 = FakeAtom(40) + + universe = FakeUniverse([atom1, atom2, atom3, atom4]) + + group1 = FakeAtomGroup([atom1], universe=universe) + group2 = FakeAtomGroup([atom2], universe=universe) + group3 = FakeAtomGroup([atom3], universe=universe) + group4 = FakeAtomGroup([atom4], universe=universe) + + atom1.bonded_atoms = group2 + atom2.bonded_atoms = group1 + atom3.bonded_atoms = group4 + atom4.bonded_atoms = group3 + + mol = FakeAtomGroup( + [atom1, atom2, atom3, atom4], + residues=[ + FakeResidue(group1), + FakeResidue(group2), + FakeResidue(group3), + FakeResidue(group4), + ], + universe=universe, + ) + + result = discovery._get_dihedrals(mol, "residue") + + assert len(result) == 1 + assert result[0].indices == [10, 20, 30, 40] + + +def test_get_dihedrals_residue_skips_invalid_four_residue_window( + caplog: pytest.LogCaptureFixture, +) -> None: + caplog.set_level(logging.DEBUG) + + discovery = _make_discovery() + + atom1 = FakeAtom(10) + atom2 = FakeAtom(20) + atom3 = FakeAtom(30) + atom4 = FakeAtom(40) + + universe = FakeUniverse([atom1, atom2, atom3, atom4]) + + group1 = FakeAtomGroup([atom1], universe=universe) + group2 = FakeAtomGroup([atom2], universe=universe) + group3 = FakeAtomGroup([atom3], universe=universe) + group4 = FakeAtomGroup([atom4], universe=universe) + + atom1.bonded_atoms = group2 + atom2.bonded_atoms = group1 + atom3.bonded_atoms = FakeAtomGroup([], universe=universe) + atom4.bonded_atoms = FakeAtomGroup([], universe=universe) + + mol = FakeAtomGroup( + [atom1, atom2, atom3, atom4], + residues=[ + FakeResidue(group1), + FakeResidue(group2), + FakeResidue(group3), + FakeResidue(group4), + ], + universe=universe, + ) + + result = discovery._get_dihedrals(mol, "residue") + + assert result == [] + assert "Skipping residue-level dihedral" in caplog.text - topologies = helper._discover_group_dihedral_topology( + +def test_get_dihedrals_residue_returns_empty_when_fewer_than_four_residues() -> None: + discovery = _make_discovery() + + atoms = [FakeAtom(1), FakeAtom(2), FakeAtom(3)] + universe = FakeUniverse(atoms) + groups = [FakeAtomGroup([atom], universe=universe) for atom in atoms] + + mol = FakeAtomGroup( + atoms, + residues=[FakeResidue(group) for group in groups], + universe=universe, + ) + + assert discovery._get_dihedrals(mol, "residue") == [] + + +def test_atoms_in_source_bonded_to_target_returns_matching_source_atoms() -> None: + atom1 = FakeAtom(1) + atom2 = FakeAtom(2) + atom3 = FakeAtom(3) + + universe = FakeUniverse([atom1, atom2, atom3]) + + source_group = FakeAtomGroup([atom1, atom3], universe=universe) + target_group = FakeAtomGroup([atom2], universe=universe) + + atom1.bonded_atoms = target_group + atom3.bonded_atoms = FakeAtomGroup([], universe=universe) + + source_residue = FakeResidue(source_group) + target_residue = FakeResidue(target_group) + + result = DihedralTopologyDiscovery._atoms_in_source_bonded_to_target( + source_residue, + target_residue, + ) + + assert result.indices == [1] + + +def test_atoms_in_source_bonded_to_target_returns_empty_when_atom_has_no_bonds() -> ( + None +): + atom1 = FakeAtom(1) + atom2 = FakeAtom(2) + + universe = FakeUniverse([atom1, atom2]) + + source_residue = FakeResidue(FakeAtomGroup([atom1], universe=universe)) + target_residue = FakeResidue(FakeAtomGroup([atom2], universe=universe)) + + result = DihedralTopologyDiscovery._atoms_in_source_bonded_to_target( + source_residue, + target_residue, + ) + + assert result.indices == [] + + +def test_discover_group_dihedral_topology_builds_one_entry_per_molecule( + monkeypatch: pytest.MonkeyPatch, +) -> None: + universe_operations = Mock() + + molecule0 = SimpleNamespace(label="molecule0", residues=[object(), object()]) + molecule1 = SimpleNamespace(label="molecule1", residues=[object()]) + + universe_operations.extract_fragment_atomgroup.side_effect = [ + molecule0, + molecule1, + ] + + discovery = _make_discovery(universe_operations) + + def fake_select_heavy_residue(mol: Any, res_id: int) -> str: + return f"{mol.label}-heavy-{res_id}" + + def fake_get_dihedrals(data_container: Any, level: str) -> list[str]: + if level == "united_atom": + return [f"ua-{data_container}"] + return [f"res-{data_container.label}"] + + monkeypatch.setattr( + discovery, + "_select_heavy_residue", + fake_select_heavy_residue, + ) + monkeypatch.setattr( + discovery, + "_get_dihedrals", + fake_get_dihedrals, + ) + + topologies = discovery._discover_group_dihedral_topology( data_container="universe", - group_id=3, - molecules=[7, 8], + group_id=9, + molecules=[100, 200], level_list=["united_atom", "residue"], ) - assert [topology.molecule_id for topology in topologies] == [7, 8] - assert [topology.molecule_order for topology in topologies] == [0, 1] - assert topologies[0].group_id == 3 - assert topologies[0].ua_dihedrals_by_residue == {0: ["ua0r0"], 1: ["ua0r1"]} - assert topologies[0].residue_dihedrals == ["res0"] - assert topologies[1].ua_dihedrals_by_residue == {0: ["ua1r0"], 1: ["ua1r1"]} - assert topologies[1].residue_dihedrals == ["res1"] + assert topologies == [ + MoleculeDihedralTopology( + group_id=9, + molecule_id=100, + molecule_order=0, + num_residues=2, + ua_dihedrals_by_residue={ + 0: ["ua-molecule0-heavy-0"], + 1: ["ua-molecule0-heavy-1"], + }, + residue_dihedrals=["res-molecule0"], + ), + MoleculeDihedralTopology( + group_id=9, + molecule_id=200, + molecule_order=1, + num_residues=1, + ua_dihedrals_by_residue={ + 0: ["ua-molecule1-heavy-0"], + }, + residue_dihedrals=["res-molecule1"], + ), + ] + + assert universe_operations.extract_fragment_atomgroup.call_args_list == [ + call("universe", 100), + call("universe", 200), + ] + + +def test_discover_group_dihedral_topology_respects_enabled_levels( + monkeypatch: pytest.MonkeyPatch, +) -> None: + universe_operations = Mock() + + molecule = SimpleNamespace(label="molecule", residues=[object(), object()]) + universe_operations.extract_fragment_atomgroup.return_value = molecule + + discovery = _make_discovery(universe_operations) + + select_heavy_residue = Mock(return_value="heavy-residue") + get_dihedrals = Mock(return_value=["residue-dihedral"]) + + monkeypatch.setattr(discovery, "_select_heavy_residue", select_heavy_residue) + monkeypatch.setattr(discovery, "_get_dihedrals", get_dihedrals) + + topologies = discovery._discover_group_dihedral_topology( + data_container="universe", + group_id=1, + molecules=[0], + level_list=["residue"], + ) + + assert topologies == [ + MoleculeDihedralTopology( + group_id=1, + molecule_id=0, + molecule_order=0, + num_residues=2, + ua_dihedrals_by_residue={}, + residue_dihedrals=["residue-dihedral"], + ) + ] + + select_heavy_residue.assert_not_called() + get_dihedrals.assert_called_once_with(molecule, "residue") diff --git a/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py new file mode 100644 index 00000000..680c439e --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_axes_topology.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np + +from CodeEntropy.levels.nodes.axes_topology import ( + AxesTopology, + BuildAxesTopologyNode, + UAAxesTopology, +) + + +class FakeAtomGroup: + """Small AtomGroup-like object for axes-topology tests.""" + + def __init__(self, atoms=None, *, name="ag"): + self._atoms = list(atoms or []) + self.name = name + self.indices = np.asarray([atom.index for atom in self._atoms], dtype=int) + + def __iter__(self): + return iter(self._atoms) + + def __len__(self): + return len(self._atoms) + + def __getitem__(self, index): + if isinstance(index, (list, tuple, np.ndarray)): + return FakeAtomGroup([self._atoms[int(i)] for i in index]) + return self._atoms[int(index)] + + def select_atoms(self, selection): + """Return atoms matching the small mass selections used by the node.""" + if selection == "prop mass > 1.1": + return FakeAtomGroup([atom for atom in self._atoms if atom.mass > 1.1]) + if selection == "mass 2 to 999": + return FakeAtomGroup( + [atom for atom in self._atoms if 2.0 <= atom.mass <= 999.0] + ) + if selection == "mass 1 to 1.1": + return FakeAtomGroup( + [atom for atom in self._atoms if 1.0 <= atom.mass <= 1.1] + ) + raise AssertionError(f"Unexpected selection: {selection}") + + +class FakeAtom: + """Small Atom-like object with mass, index, and bonded atoms.""" + + def __init__(self, index, mass, bonded_atoms=None): + self.index = index + self.mass = mass + self.bonded_atoms = bonded_atoms or FakeAtomGroup([]) + + +class FakeResidue: + """Small residue-like object.""" + + def __init__(self, atoms): + self.atoms = atoms + + +class FakeMolecule: + """Small molecule-like object with residues.""" + + def __init__(self, residues): + self.residues = residues + + +class FakeAtoms: + """Container supporting fragments and u.atoms[index_array].""" + + def __init__(self, fragments, atom_map): + self.fragments = fragments + self._atom_map = dict(atom_map) + + def __getitem__(self, indices): + if isinstance(indices, np.ndarray): + return FakeAtomGroup([self._atom_map[int(index)] for index in indices]) + if isinstance(indices, (list, tuple)): + return FakeAtomGroup([self._atom_map[int(index)] for index in indices]) + return self._atom_map[int(indices)] + + +class FakeUniverse: + """Small universe-like object.""" + + def __init__(self, fragments, atom_map): + self.atoms = FakeAtoms(fragments=fragments, atom_map=atom_map) + + +def _args(*, customised_axes): + return SimpleNamespace(customised_axes=customised_axes) + + +def _single_molecule_universe(): + """Build a small molecule containing one residue and one UA bead.""" + hydrogen = FakeAtom(index=2, mass=1.0) + bonded_heavy = FakeAtom(index=3, mass=12.0) + heavy = FakeAtom( + index=1, + mass=12.0, + bonded_atoms=FakeAtomGroup([hydrogen, bonded_heavy]), + ) + other_residue_heavy = FakeAtom(index=4, mass=14.0) + + residue_atoms = FakeAtomGroup([heavy, hydrogen, bonded_heavy, other_residue_heavy]) + residue = FakeResidue(residue_atoms) + molecule = FakeMolecule([residue]) + + atom_map = { + heavy.index: heavy, + hydrogen.index: hydrogen, + bonded_heavy.index: bonded_heavy, + other_residue_heavy.index: other_residue_heavy, + } + universe = FakeUniverse([molecule], atom_map) + + return universe, molecule, heavy, hydrogen, bonded_heavy, other_residue_heavy + + +def test_ua_axes_topology_dataclass_preserves_arrays(): + topology = UAAxesTopology( + heavy_atom_index=1, + ua_atom_indices=np.array([1, 2]), + ua_all_atom_indices=np.array([1, 3, 2]), + bonded_heavy_indices=np.array([3]), + bonded_light_indices=np.array([2]), + residue_heavy_indices=np.array([1, 3, 4]), + residue_ua_masses=np.array([13.0, 12.0, 14.0]), + ) + + assert topology.heavy_atom_index == 1 + np.testing.assert_array_equal(topology.ua_atom_indices, np.array([1, 2])) + np.testing.assert_array_equal(topology.ua_all_atom_indices, np.array([1, 3, 2])) + np.testing.assert_array_equal(topology.bonded_heavy_indices, np.array([3])) + np.testing.assert_array_equal(topology.bonded_light_indices, np.array([2])) + np.testing.assert_array_equal(topology.residue_heavy_indices, np.array([1, 3, 4])) + np.testing.assert_allclose(topology.residue_ua_masses, np.array([13.0, 12.0, 14.0])) + + +def test_axes_topology_defaults_to_empty_ua_mapping(): + topology = AxesTopology() + + assert topology.ua == {} + + +def test_run_writes_empty_topology_when_customised_axes_disabled(): + node = BuildAxesTopologyNode() + shared_data = {"args": _args(customised_axes=False)} + + result = node.run(shared_data) + + assert isinstance(result["axes_topology"], AxesTopology) + assert result["axes_topology"].ua == {} + assert shared_data["axes_topology"] is result["axes_topology"] + + +def test_run_builds_ua_topology_for_united_atom_levels(): + node = BuildAxesTopologyNode() + universe, _, heavy, hydrogen, bonded_heavy, other_residue_heavy = ( + _single_molecule_universe() + ) + shared_data = { + "args": _args(customised_axes=True), + "reduced_universe": universe, + "levels": [["united_atom"]], + "beads": {(0, "united_atom", 0): [np.array([1, 2])]}, + } + + result = node.run(shared_data) + + axes_topology = result["axes_topology"] + assert shared_data["axes_topology"] is axes_topology + assert set(axes_topology.ua) == {(0, 0, 0)} + + ua_topology = axes_topology.ua[(0, 0, 0)] + assert ua_topology.heavy_atom_index == heavy.index + np.testing.assert_array_equal(ua_topology.ua_atom_indices, np.array([1, 2])) + np.testing.assert_array_equal(ua_topology.ua_all_atom_indices, np.array([1, 3, 2])) + np.testing.assert_array_equal(ua_topology.bonded_heavy_indices, np.array([3])) + np.testing.assert_array_equal(ua_topology.bonded_light_indices, np.array([2])) + np.testing.assert_array_equal( + ua_topology.residue_heavy_indices, + np.array([heavy.index, bonded_heavy.index, other_residue_heavy.index]), + ) + np.testing.assert_allclose( + ua_topology.residue_ua_masses, + np.array([13.0, 12.0, 14.0]), + ) + + +def test_run_ignores_molecules_without_united_atom_level(): + node = BuildAxesTopologyNode() + universe, _, _, _, _, _ = _single_molecule_universe() + shared_data = { + "args": _args(customised_axes=True), + "reduced_universe": universe, + "levels": [["residue"]], + "beads": {(0, "united_atom", 0): [np.array([1, 2])]}, + } + + result = node.run(shared_data) + + assert result["axes_topology"].ua == {} + assert shared_data["axes_topology"].ua == {} + + +def test_add_ua_topology_skips_residues_without_beads(): + node = BuildAxesTopologyNode() + universe, molecule, _, _, _, _ = _single_molecule_universe() + out = {} + + node._add_ua_topology( + u=universe, + mol=molecule, + mol_id=0, + beads={}, + out=out, + ) + + assert out == {} + + +def test_add_ua_topology_skips_ua_beads_without_heavy_atoms(caplog): + node = BuildAxesTopologyNode() + hydrogen = FakeAtom(index=2, mass=1.0) + residue = FakeResidue(FakeAtomGroup([hydrogen])) + molecule = FakeMolecule([residue]) + universe = FakeUniverse([molecule], {2: hydrogen}) + out = {} + + node._add_ua_topology( + u=universe, + mol=molecule, + mol_id=5, + beads={(5, "united_atom", 0): [np.array([2])]}, + out=out, + ) + + assert out == {} + assert "Skipping UA axes topology with no heavy atom" in caplog.text + + +def test_add_ua_topology_handles_multiple_residues_and_ua_beads(): + node = BuildAxesTopologyNode() + + h0 = FakeAtom(index=10, mass=1.0) + c0 = FakeAtom(index=11, mass=12.0, bonded_atoms=FakeAtomGroup([h0])) + residue0 = FakeResidue(FakeAtomGroup([c0, h0])) + + h1 = FakeAtom(index=20, mass=1.0) + c1 = FakeAtom(index=21, mass=12.0, bonded_atoms=FakeAtomGroup([h1])) + residue1 = FakeResidue(FakeAtomGroup([c1, h1])) + + molecule = FakeMolecule([residue0, residue1]) + universe = FakeUniverse( + [molecule], + { + h0.index: h0, + c0.index: c0, + h1.index: h1, + c1.index: c1, + }, + ) + out = {} + + node._add_ua_topology( + u=universe, + mol=molecule, + mol_id=3, + beads={ + (3, "united_atom", 0): [np.array([11, 10])], + (3, "united_atom", 1): [np.array([21, 20])], + }, + out=out, + ) + + assert set(out) == {(3, 0, 0), (3, 1, 0)} + assert out[(3, 0, 0)].heavy_atom_index == 11 + assert out[(3, 1, 0)].heavy_atom_index == 21 + + +def test_split_bonded_atoms_returns_heavy_and_light_atom_groups(): + hydrogen = FakeAtom(index=2, mass=1.0) + heavy_bonded = FakeAtom(index=3, mass=12.0) + atom = FakeAtom( + index=1, + mass=12.0, + bonded_atoms=FakeAtomGroup([hydrogen, heavy_bonded]), + ) + + bonded_heavy, bonded_light = BuildAxesTopologyNode._split_bonded_atoms(atom) + + np.testing.assert_array_equal(bonded_heavy.indices, np.array([3])) + np.testing.assert_array_equal(bonded_light.indices, np.array([2])) + + +def test_get_ua_masses_from_topology_adds_bonded_hydrogen_masses(): + hydrogen = FakeAtom(index=2, mass=1.0) + heavy = FakeAtom(index=1, mass=12.0, bonded_atoms=FakeAtomGroup([hydrogen])) + other_heavy = FakeAtom(index=3, mass=14.0) + atom_group = FakeAtomGroup([heavy, hydrogen, other_heavy]) + + masses = BuildAxesTopologyNode._get_ua_masses_from_topology(atom_group) + + assert masses == [13.0, 14.0] + + +def test_get_ua_masses_from_topology_handles_atoms_without_bonded_atoms_attribute(): + class AtomWithoutBonds: + """Small atom-like object without bonded_atoms.""" + + def __init__(self, index, mass): + self.index = index + self.mass = mass + + heavy = AtomWithoutBonds(index=1, mass=12.0) + hydrogen = AtomWithoutBonds(index=2, mass=1.0) + atom_group = FakeAtomGroup([heavy, hydrogen]) + + masses = BuildAxesTopologyNode._get_ua_masses_from_topology(atom_group) + + assert masses == [12.0] + + +def test_get_ua_masses_from_topology_returns_empty_list_for_light_atoms_only(): + hydrogen = FakeAtom(index=2, mass=1.0) + atom_group = FakeAtomGroup([hydrogen]) + + masses = BuildAxesTopologyNode._get_ua_masses_from_topology(atom_group) + + assert masses == [] diff --git a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py index 8398113c..6a3d2ec8 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py @@ -88,6 +88,7 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): mol = FakeMolecule() universe = FakeUniverse([mol], dimensions=np.array([10.0, 20.0, 30.0, 90.0])) axes_manager = object() + axes_topology = object() ctx = { "shared": { @@ -97,6 +98,7 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): "beads": {}, "args": _args(combined_forcetorque=True, customised_axes=True), "axes_manager": axes_manager, + "axes_topology": axes_topology, } } @@ -115,6 +117,7 @@ def test_run_processes_all_levels_and_writes_frame_covariance(): assert ua_kwargs["mol_id"] == 0 assert ua_kwargs["group_id"] == 7 assert ua_kwargs["axes_manager"] is axes_manager + assert ua_kwargs["axes_topology"] is axes_topology assert ua_kwargs["force_partitioning"] == 0.5 assert ua_kwargs["customised_axes"] is True assert ua_kwargs["is_highest"] is False @@ -165,6 +168,7 @@ def test_process_united_atom_updates_outputs_and_molcount(): group_id=7, beads={(0, "united_atom", 0): [np.array([0])]}, axes_manager="axes", + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, @@ -192,6 +196,7 @@ def test_process_united_atom_returns_when_no_beads_or_empty_atom_groups(): group_id=7, beads={}, axes_manager=None, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, @@ -213,6 +218,7 @@ def test_process_united_atom_returns_when_no_beads_or_empty_atom_groups(): group_id=7, beads={(0, "united_atom", 0): [np.array([0])]}, axes_manager=None, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, @@ -414,9 +420,13 @@ def test_build_ua_vectors_uses_customised_axes(): node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) force_vecs, torque_vecs = node._build_ua_vectors( + u=FakeUniverse([]), + mol_id=0, + local_res_i=0, bead_groups=[FakeAtomGroup("ua")], residue_atoms=FakeAtomGroup("res"), axes_manager=axes_manager, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=True, @@ -428,6 +438,49 @@ def test_build_ua_vectors_uses_customised_axes(): axes_manager.get_UA_axes.assert_called_once() +def test_build_ua_vectors_uses_cached_axes_topology_when_available(): + node = FrameCovarianceNode() + axes_manager = MagicMock() + + u = FakeUniverse([]) + ua_topology = object() + axes_topology = SimpleNamespace(ua={(3, 4, 0): ua_topology}) + + axes_manager.get_UA_axes_from_topology.return_value = ( + np.eye(3), + 2.0 * np.eye(3), + np.ones(3), + np.array([1.0, 2.0, 3.0]), + ) + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + + force_vecs, torque_vecs = node._build_ua_vectors( + u=u, + mol_id=3, + local_res_i=4, + bead_groups=[FakeAtomGroup("ua")], + residue_atoms=FakeAtomGroup("res"), + axes_manager=axes_manager, + axes_topology=axes_topology, + box=None, + force_partitioning=0.5, + customised_axes=True, + is_highest=True, + ) + + assert len(force_vecs) == 1 + assert len(torque_vecs) == 1 + + called_kwargs = axes_manager.get_UA_axes_from_topology.call_args.kwargs + assert called_kwargs["u"] is u + assert called_kwargs["topology"] is ua_topology + assert called_kwargs["box"] is None + assert called_kwargs["residue_atoms"].name == "res" + + axes_manager.get_UA_axes.assert_not_called() + + def test_build_ua_vectors_uses_vanilla_axes_when_not_customised(): node = FrameCovarianceNode() axes_manager = MagicMock() @@ -440,9 +493,13 @@ def test_build_ua_vectors_uses_vanilla_axes_when_not_customised(): with patch("CodeEntropy.levels.nodes.covariance.make_whole") as make_whole: node._build_ua_vectors( + u=FakeUniverse([]), + mol_id=0, + local_res_i=0, bead_groups=[FakeAtomGroup("ua")], residue_atoms=FakeAtomGroup("res"), axes_manager=axes_manager, + axes_topology=None, box=None, force_partitioning=0.5, customised_axes=False, diff --git a/tests/unit/CodeEntropy/levels/test_axes.py b/tests/unit/CodeEntropy/levels/test_axes.py index d6d0093f..68adfb8a 100644 --- a/tests/unit/CodeEntropy/levels/test_axes.py +++ b/tests/unit/CodeEntropy/levels/test_axes.py @@ -4,6 +4,7 @@ import pytest from CodeEntropy.levels.axes import AxesCalculator +from CodeEntropy.levels.nodes.axes_topology import UAAxesTopology class _FakeAtom: @@ -699,3 +700,396 @@ def test_get_bonded_axes_returns_none_none_if_custom_axes_none(monkeypatch): assert custom_axes is None assert moi is None + + +class _FakeIndexedAtoms: + """Container supporting ``u.atoms[index]`` and ``u.atoms[index_array]``.""" + + def __init__(self, atom_map): + self._atom_map = dict(atom_map) + + def __getitem__(self, index): + if isinstance(index, np.ndarray): + return _FakeAtomGroup([self._atom_map[int(i)] for i in index]) + if isinstance(index, (list, tuple)): + return _FakeAtomGroup([self._atom_map[int(i)] for i in index]) + return self._atom_map[int(index)] + + +class _FakeUniverse: + """Small universe-like object with indexed atoms and dimensions.""" + + def __init__(self, atom_map, dimensions=None): + self.atoms = _FakeIndexedAtoms(atom_map) + self.dimensions = np.asarray( + dimensions + if dimensions is not None + else [10.0, 10.0, 10.0, 90.0, 90.0, 90.0], + dtype=float, + ) + + +def _ua_topology( + *, + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1,), + bonded_heavy_indices=(), + bonded_light_indices=(), + residue_heavy_indices=(1,), + residue_ua_masses=(12.0,), +): + """Build a small cached UA topology fixture.""" + return UAAxesTopology( + heavy_atom_index=int(heavy_atom_index), + ua_atom_indices=np.asarray(ua_atom_indices, dtype=int), + ua_all_atom_indices=np.asarray(ua_all_atom_indices, dtype=int), + bonded_heavy_indices=np.asarray(bonded_heavy_indices, dtype=int), + bonded_light_indices=np.asarray(bonded_light_indices, dtype=int), + residue_heavy_indices=np.asarray(residue_heavy_indices, dtype=int), + residue_ua_masses=np.asarray(residue_ua_masses, dtype=float), + ) + + +def test_get_UA_axes_from_topology_multiple_heavy_uses_cached_indices_and_box( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 2.0, 3.0]) + other_heavy = _FakeAtom(3, 14.0, [4.0, 5.0, 6.0]) + universe = _FakeUniverse({1: heavy_atom, 3: other_heavy}) + residue_atoms = MagicMock() + residue_atoms.center_of_mass.return_value = np.array([9.0, 8.0, 7.0]) + + topology = _ua_topology( + heavy_atom_index=1, + residue_heavy_indices=(1, 3), + residue_ua_masses=(13.0, 14.0), + ) + + get_tensor = MagicMock(return_value=np.eye(3)) + get_principal = MagicMock(return_value=(np.eye(3) * 2.0, np.array([3.0, 2.0, 1.0]))) + get_bonded = MagicMock(return_value=(np.eye(3) * 4.0, np.array([1.0, 1.0, 1.0]))) + + monkeypatch.setattr(ax, "get_moment_of_inertia_tensor", get_tensor) + monkeypatch.setattr(ax, "get_custom_principal_axes", get_principal) + monkeypatch.setattr(ax, "get_bonded_axes_from_topology", get_bonded) + + box = np.array([20.0, 30.0, 40.0]) + trans_axes, rot_axes, center, moi = ax.get_UA_axes_from_topology( + u=universe, + residue_atoms=residue_atoms, + topology=topology, + box=box, + ) + + np.testing.assert_allclose(trans_axes, np.eye(3) * 2.0) + np.testing.assert_allclose(rot_axes, np.eye(3) * 4.0) + np.testing.assert_allclose(center, heavy_atom.position) + np.testing.assert_allclose(moi, np.array([1.0, 1.0, 1.0])) + + get_tensor.assert_called_once() + tensor_kwargs = get_tensor.call_args.kwargs + np.testing.assert_allclose( + tensor_kwargs["center_of_mass"], np.array([9.0, 8.0, 7.0]) + ) + np.testing.assert_allclose( + tensor_kwargs["positions"], np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + ) + np.testing.assert_allclose(tensor_kwargs["masses"], np.array([13.0, 14.0])) + np.testing.assert_allclose(tensor_kwargs["dimensions"], box) + + get_principal.assert_called_once() + np.testing.assert_allclose(get_principal.call_args.args[0], np.eye(3)) + get_bonded.assert_called_once_with( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=box, + ) + + +def test_get_UA_axes_from_topology_single_heavy_uses_residue_principal_axes( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse( + {1: heavy_atom}, dimensions=[11.0, 12.0, 13.0, 90.0, 90.0, 90.0] + ) + residue_atoms = MagicMock() + residue_atoms.principal_axes.return_value = np.eye(3) * 5.0 + + topology = _ua_topology(heavy_atom_index=1, residue_heavy_indices=(1,)) + + make_whole = MagicMock() + get_bonded = MagicMock(return_value=(np.eye(3) * 6.0, np.array([6.0, 5.0, 4.0]))) + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", make_whole) + monkeypatch.setattr(ax, "get_bonded_axes_from_topology", get_bonded) + + trans_axes, rot_axes, center, moi = ax.get_UA_axes_from_topology( + u=universe, + residue_atoms=residue_atoms, + topology=topology, + box=None, + ) + + make_whole.assert_called_once_with(residue_atoms) + residue_atoms.principal_axes.assert_called_once() + np.testing.assert_allclose(trans_axes, np.eye(3) * 5.0) + np.testing.assert_allclose(rot_axes, np.eye(3) * 6.0) + np.testing.assert_allclose(center, heavy_atom.position) + np.testing.assert_allclose(moi, np.array([6.0, 5.0, 4.0])) + + called_kwargs = get_bonded.call_args.kwargs + np.testing.assert_allclose( + called_kwargs["dimensions"], np.array([11.0, 12.0, 13.0]) + ) + + +def test_get_UA_axes_from_topology_raises_when_cached_bonded_axes_fail(monkeypatch): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom}) + residue_atoms = MagicMock() + residue_atoms.principal_axes.return_value = np.eye(3) + topology = _ua_topology(heavy_atom_index=1, residue_heavy_indices=(1,)) + + monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None) + monkeypatch.setattr( + ax, "get_bonded_axes_from_topology", lambda **kwargs: (None, None) + ) + + with pytest.raises(ValueError, match="cached UA bead"): + ax.get_UA_axes_from_topology( + u=universe, + residue_atoms=residue_atoms, + topology=topology, + box=None, + ) + + +def test_get_bonded_axes_from_topology_non_heavy_returns_none_none(): + ax = AxesCalculator() + light_atom = _FakeAtom(1, 1.0, [0.0, 0.0, 0.0]) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=MagicMock(), + heavy_atom=light_atom, + topology=_ua_topology(heavy_atom_index=1), + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + assert custom_axes is None + assert moi is None + + +def test_get_bonded_axes_from_topology_no_bonded_heavy_uses_vanilla_axes( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + hydrogen = _FakeAtom(2, 1.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 2: hydrogen}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1, 2), + ua_all_atom_indices=(1, 2), + bonded_heavy_indices=(), + bonded_light_indices=(2,), + ) + + get_vanilla = MagicMock(return_value=(np.eye(3) * 7.0, np.array([7.0, 8.0, 9.0]))) + get_custom_moi = MagicMock() + get_flipped = MagicMock(return_value=np.eye(3) * -7.0) + + monkeypatch.setattr(ax, "get_vanilla_axes", get_vanilla) + monkeypatch.setattr(ax, "get_custom_moment_of_inertia", get_custom_moi) + monkeypatch.setattr(ax, "get_flipped_axes", get_flipped) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3) * -7.0) + np.testing.assert_allclose(moi, np.array([7.0, 8.0, 9.0])) + get_vanilla.assert_called_once() + get_custom_moi.assert_not_called() + get_flipped.assert_called_once() + + +def test_get_bonded_axes_from_topology_one_heavy_no_light_uses_custom_axes( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 3: bonded_heavy}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1, 3), + bonded_heavy_indices=(3,), + bonded_light_indices=(), + ) + + get_custom_axes = MagicMock(return_value=np.eye(3) * 2.0) + get_custom_moi = MagicMock(return_value=np.array([2.0, 3.0, 4.0])) + get_flipped = MagicMock(return_value=np.eye(3) * 3.0) + + monkeypatch.setattr(ax, "get_custom_axes", get_custom_axes) + monkeypatch.setattr(ax, "get_custom_moment_of_inertia", get_custom_moi) + monkeypatch.setattr(ax, "get_flipped_axes", get_flipped) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3) * 3.0) + np.testing.assert_allclose(moi, np.array([2.0, 3.0, 4.0])) + + kwargs = get_custom_axes.call_args.kwargs + np.testing.assert_allclose(kwargs["a"], heavy_atom.position) + np.testing.assert_allclose(kwargs["b_list"][0], bonded_heavy.position) + np.testing.assert_allclose(kwargs["c"], np.zeros(3)) + get_custom_moi.assert_called_once() + get_flipped.assert_called_once() + + +def test_get_bonded_axes_from_topology_one_heavy_with_light_uses_light_as_c( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + bonded_light = _FakeAtom(2, 1.0, [0.0, 1.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 2: bonded_light, 3: bonded_heavy}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1, 2), + ua_all_atom_indices=(1, 3, 2), + bonded_heavy_indices=(3,), + bonded_light_indices=(2,), + ) + + get_custom_axes = MagicMock(return_value=np.eye(3)) + monkeypatch.setattr(ax, "get_custom_axes", get_custom_axes) + monkeypatch.setattr( + ax, + "get_custom_moment_of_inertia", + lambda **kwargs: np.array([1.0, 2.0, 3.0]), + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3)) + np.testing.assert_allclose(moi, np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose( + get_custom_axes.call_args.kwargs["c"], bonded_light.position + ) + + +def test_get_bonded_axes_from_topology_two_heavy_uses_heavy_positions_as_b_list( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy_0 = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + bonded_heavy_1 = _FakeAtom(4, 12.0, [0.0, 1.0, 0.0]) + universe = _FakeUniverse( + { + 1: heavy_atom, + 3: bonded_heavy_0, + 4: bonded_heavy_1, + } + ) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1, 3, 4), + bonded_heavy_indices=(3, 4), + bonded_light_indices=(), + ) + + get_custom_axes = MagicMock(return_value=np.eye(3) * 4.0) + monkeypatch.setattr(ax, "get_custom_axes", get_custom_axes) + monkeypatch.setattr( + ax, + "get_custom_moment_of_inertia", + lambda **kwargs: np.array([4.0, 5.0, 6.0]), + ) + monkeypatch.setattr(ax, "get_flipped_axes", lambda ua, axes, com, dims: axes) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + np.testing.assert_allclose(custom_axes, np.eye(3) * 4.0) + np.testing.assert_allclose(moi, np.array([4.0, 5.0, 6.0])) + np.testing.assert_allclose( + get_custom_axes.call_args.kwargs["b_list"], + np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), + ) + np.testing.assert_allclose( + get_custom_axes.call_args.kwargs["c"], + bonded_heavy_1.position, + ) + + +def test_get_bonded_axes_from_topology_returns_none_when_custom_axes_none( + monkeypatch, +): + ax = AxesCalculator() + + heavy_atom = _FakeAtom(1, 12.0, [0.0, 0.0, 0.0]) + bonded_heavy = _FakeAtom(3, 12.0, [1.0, 0.0, 0.0]) + universe = _FakeUniverse({1: heavy_atom, 3: bonded_heavy}) + topology = _ua_topology( + heavy_atom_index=1, + ua_atom_indices=(1,), + ua_all_atom_indices=(1, 3), + bonded_heavy_indices=(3,), + bonded_light_indices=(), + ) + + get_custom_moi = MagicMock() + get_flipped = MagicMock() + + monkeypatch.setattr(ax, "get_custom_axes", lambda **kwargs: None) + monkeypatch.setattr(ax, "get_custom_moment_of_inertia", get_custom_moi) + monkeypatch.setattr(ax, "get_flipped_axes", get_flipped) + + custom_axes, moi = ax.get_bonded_axes_from_topology( + u=universe, + heavy_atom=heavy_atom, + topology=topology, + dimensions=np.array([10.0, 10.0, 10.0]), + ) + + assert custom_axes is None + assert moi is None + get_custom_moi.assert_not_called() + get_flipped.assert_not_called() diff --git a/tests/unit/CodeEntropy/levels/test_level_dag.py b/tests/unit/CodeEntropy/levels/test_level_dag.py index 1b9f4d4f..11e74e9e 100644 --- a/tests/unit/CodeEntropy/levels/test_level_dag.py +++ b/tests/unit/CodeEntropy/levels/test_level_dag.py @@ -12,6 +12,7 @@ def test_build_registers_static_nodes_and_builds_stage_dags(): patch("CodeEntropy.levels.level_dag.DetectMoleculesNode"), patch("CodeEntropy.levels.level_dag.DetectLevelsNode"), patch("CodeEntropy.levels.level_dag.BuildBeadsNode"), + patch("CodeEntropy.levels.level_dag.BuildAxesTopologyNode"), patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode"), patch("CodeEntropy.levels.level_dag.ConformationDAG"), ): @@ -27,6 +28,7 @@ def test_build_registers_static_nodes_and_builds_stage_dags(): "detect_molecules", "detect_levels", "build_beads", + "build_axes_topology", "init_covariance_accumulators", } assert "find_neighbors" not in dag._static_nodes @@ -34,6 +36,7 @@ def test_build_registers_static_nodes_and_builds_stage_dags(): assert ("detect_molecules", "detect_levels") in dag._static_graph.edges assert ("detect_levels", "build_beads") in dag._static_graph.edges + assert ("build_beads", "build_axes_topology") in dag._static_graph.edges assert ("detect_levels", "init_covariance_accumulators") in dag._static_graph.edges assert ( "detect_levels", diff --git a/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py b/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py index e8d92329..d298c486 100644 --- a/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py +++ b/tests/unit/CodeEntropy/levels/test_mda_universe_operations.py @@ -1,4 +1,5 @@ -from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock import numpy as np import pytest @@ -407,3 +408,19 @@ def test_select_frame_indices_raises_when_frame_indices_empty(): ops.select_frame_indices(u, frame_indices=[]) u.select_atoms.assert_not_called() + + +def test_extract_fragment_atomgroup_returns_lightweight_range_selection() -> None: + universe = Mock() + universe.atoms.fragments = [ + SimpleNamespace(indices=[10, 11, 12, 13]), + ] + universe.select_atoms.return_value = "fragment_atomgroup" + + result = UniverseOperations().extract_fragment_atomgroup(universe, 0) + + assert result == "fragment_atomgroup" + universe.select_atoms.assert_called_once_with( + "index 10:13", + updating=False, + )