Skip to content
155 changes: 155 additions & 0 deletions CodeEntropy/levels/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,161 @@ def get_UA_axes(self, data_container, index: int):

return trans_axes, rot_axes, center, moment_of_inertia

def get_UA_axes_from_topology(
self,
*,
u,
residue_atoms,
topology,
box: np.ndarray | None,
):
"""Compute UA axes using cached static topology.

This is the cached-index equivalent of ``get_UA_axes``. It preserves the
frame-dependent numerical calculations, but avoids repeated MDAnalysis
selection strings for heavy atoms, bonded atoms, and UA masses.

Args:
u: Current-frame universe.
residue_atoms: AtomGroup for the parent residue in the current frame.
topology: Cached ``UAAxesTopology`` for this UA bead.
box: Current periodic box lengths. If omitted, ``u.dimensions`` is used.

Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
- trans_axes: Translational axes, shape ``(3, 3)``.
- rot_axes: Rotational axes, shape ``(3, 3)``.
- center: Rotation centre, shape ``(3,)``.
- moment_of_inertia: Principal moments, shape ``(3,)``.

Raises:
ValueError: If cached bonded-axis construction fails.
"""
dimensions = (
np.asarray(box, dtype=float)
if box is not None
else np.asarray(u.dimensions[:3], dtype=float)
)

heavy_atoms = u.atoms[topology.residue_heavy_indices]
heavy_atom = u.atoms[int(topology.heavy_atom_index)]

if len(heavy_atoms) > 1:
center = residue_atoms.center_of_mass(unwrap=True)
moment_of_inertia_tensor = self.get_moment_of_inertia_tensor(
center_of_mass=center,
positions=heavy_atoms.positions,
masses=topology.residue_ua_masses,
dimensions=dimensions,
)
trans_axes, _moment_of_inertia = self.get_custom_principal_axes(
moment_of_inertia_tensor
)
else:
make_whole(residue_atoms)
trans_axes = residue_atoms.principal_axes()

center = heavy_atom.position
rot_axes, moment_of_inertia = self.get_bonded_axes_from_topology(
u=u,
heavy_atom=heavy_atom,
topology=topology,
dimensions=dimensions,
)
if rot_axes is None or moment_of_inertia is None:
raise ValueError("Unable to compute bonded axes for cached UA bead.")

logger.debug("Translational Axes: %s", trans_axes)
logger.debug("Rotational Axes: %s", rot_axes)
logger.debug("Center: %s", center)
logger.debug("Moment of Inertia: %s", moment_of_inertia)

return trans_axes, rot_axes, center, moment_of_inertia

def get_bonded_axes_from_topology(
self,
*,
u,
heavy_atom,
topology,
dimensions: np.ndarray,
):
"""Compute UA bonded axes using cached bonded atom indices.

This mirrors ``get_bonded_axes`` but receives precomputed bonded atom
memberships from ``UAAxesTopology`` instead of rediscovering them with
MDAnalysis selection strings inside the frame loop.

Args:
u: Current-frame universe.
heavy_atom: Current-frame heavy atom for the UA bead.
topology: Cached ``UAAxesTopology`` for the UA bead.
dimensions: Simulation box lengths, shape ``(3,)``.

Returns:
Tuple[np.ndarray | None, np.ndarray | None]:
- custom_axes: Custom rotation axes, shape ``(3, 3)``, or ``None``.
- custom_moment_of_inertia: Principal moments, shape ``(3,)``, or
``None``.
"""
if not heavy_atom.mass > 1.1:
return None, None

custom_moment_of_inertia = None
custom_axes = None

heavy_bonded = u.atoms[topology.bonded_heavy_indices]
light_bonded = u.atoms[topology.bonded_light_indices]
ua = u.atoms[topology.ua_atom_indices]
ua_all = u.atoms[topology.ua_all_atom_indices]

if len(heavy_bonded) == 0:
custom_axes, custom_moment_of_inertia = self.get_vanilla_axes(ua_all)

if len(heavy_bonded) == 1 and len(light_bonded) == 0:
custom_axes = self.get_custom_axes(
a=heavy_atom.position,
b_list=[heavy_bonded[0].position],
c=np.zeros(3),
dimensions=dimensions,
)

if len(heavy_bonded) == 1 and len(light_bonded) >= 1:
custom_axes = self.get_custom_axes(
a=heavy_atom.position,
b_list=[heavy_bonded[0].position],
c=light_bonded[0].position,
dimensions=dimensions,
)

if len(heavy_bonded) >= 2:
custom_axes = self.get_custom_axes(
a=heavy_atom.position,
b_list=heavy_bonded.positions,
c=heavy_bonded[1].position,
dimensions=dimensions,
)

if custom_axes is None:
return None, None

if custom_moment_of_inertia is None:
custom_moment_of_inertia = self.get_custom_moment_of_inertia(
UA=ua,
custom_rotation_axes=custom_axes,
center_of_mass=heavy_atom.position,
dimensions=dimensions,
)

custom_axes = self.get_flipped_axes(
ua,
custom_axes,
heavy_atom.position,
dimensions,
)

return custom_axes, custom_moment_of_inertia

def get_bonded_axes(self, system, atom, dimensions: np.ndarray):
r"""Compute UA rotational axes from bonded topology around a heavy atom.

Expand Down
123 changes: 105 additions & 18 deletions CodeEntropy/levels/dihedrals/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def _discover_group_dihedral_topology(
topologies: list[MoleculeDihedralTopology] = []

for molecule_order, molecule_id in enumerate(molecules):
mol = self._universe_operations.extract_fragment(
data_container, molecule_id
)
mol = self._extract_topology_fragment(data_container, molecule_id)
num_residues = len(mol.residues)
ua_dihedrals_by_residue: dict[int, list[Any]] = {}
residue_dihedrals: list[Any] = []
Expand Down Expand Up @@ -90,6 +88,26 @@ def _discover_group_dihedral_topology(

return topologies

def _extract_topology_fragment(self, data_container: Any, molecule_id: Any) -> Any:
"""Return a molecule fragment for topology discovery.

This uses the lightweight AtomGroup extraction helper when available so
static conformational topology discovery does not create a standalone
in-memory universe or copy trajectory frames. The fallback preserves
compatibility with older ``UniverseOperations`` implementations.

Args:
data_container: Source MDAnalysis universe or universe-like container.
molecule_id: Fragment index identifying the molecule to extract.

Returns:
MDAnalysis AtomGroup for the selected molecule
"""
return self._universe_operations.extract_fragment_atomgroup(
data_container,
int(molecule_id),
)

def _select_heavy_residue(self, mol: Any, res_id: int) -> Any:
"""Select heavy atoms in a residue by residue index.

Expand All @@ -100,13 +118,15 @@ def _select_heavy_residue(self, mol: Any, res_id: int) -> Any:
Returns:
AtomGroup containing heavy atoms in the residue selection.
"""
selection1 = mol.residues[res_id].atoms.indices[0]
selection2 = mol.residues[res_id].atoms.indices[-1]
residue_atoms = mol.residues[int(res_id)].atoms
selection1 = residue_atoms.indices[0]
selection2 = residue_atoms.indices[-1]

res_container = self._universe_operations.select_atoms(
mol, f"index {selection1}:{selection2}"
res_container = mol.select_atoms(
f"index {selection1}:{selection2}",
updating=False,
)
return self._universe_operations.select_atoms(res_container, "prop mass > 1.1")
return res_container.select_atoms("prop mass > 1.1", updating=False)

def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]:
"""Return dihedral AtomGroups for a container at a given level.
Expand All @@ -121,26 +141,93 @@ def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]:
atom_groups: list[Any] = []

if level == "united_atom":
selected_indices = {int(index) for index in data_container.indices}

for dihedral in data_container.dihedrals:
atom_groups.append(dihedral.atoms)
dihedral_atoms = dihedral.atoms
dihedral_indices = {int(index) for index in dihedral_atoms.indices}

if len(dihedral_atoms) == 4 and dihedral_indices.issubset(
selected_indices
):
atom_groups.append(dihedral_atoms)

if level == "residue":
num_residues = len(data_container.residues)
if num_residues >= 4:
for residue in range(4, num_residues + 1):
atom1 = data_container.select_atoms(
f"resindex {residue - 4} and bonded resindex {residue - 3}"
residue1 = data_container.residues[residue - 4]
residue2 = data_container.residues[residue - 3]
residue3 = data_container.residues[residue - 2]
residue4 = data_container.residues[residue - 1]

atom1 = self._atoms_in_source_bonded_to_target(
residue1,
residue2,
)
atom2 = data_container.select_atoms(
f"resindex {residue - 3} and bonded resindex {residue - 4}"
atom2 = self._atoms_in_source_bonded_to_target(
residue2,
residue1,
)
atom3 = data_container.select_atoms(
f"resindex {residue - 2} and bonded resindex {residue - 1}"
atom3 = self._atoms_in_source_bonded_to_target(
residue3,
residue4,
)
atom4 = data_container.select_atoms(
f"resindex {residue - 1} and bonded resindex {residue - 2}"
atom4 = self._atoms_in_source_bonded_to_target(
residue4,
residue3,
)
atom_groups.append(atom1 + atom2 + atom3 + atom4)

dihedral_atoms = atom1 + atom2 + atom3 + atom4

if len(dihedral_atoms) == 4:
atom_groups.append(dihedral_atoms)
else:
logger.debug(
"Skipping residue-level dihedral for local residues "
"%s-%s-%s-%s because it produced %d atoms.",
residue - 4,
residue - 3,
residue - 2,
residue - 1,
len(dihedral_atoms),
)

logger.debug("Level: %s, Dihedrals: %s", level, atom_groups)
return atom_groups

@staticmethod
def _atoms_in_source_bonded_to_target(
source_residue: Any,
target_residue: Any,
) -> Any:
"""Return source-residue atoms bonded to atoms in a target residue.

This helper is used when constructing residue-level dihedral definitions
from lightweight molecule AtomGroups. It selects atoms from the source
residue that are bonded to any atom in the target residue without using
global ``resindex`` selection strings.

Args:
source_residue: Residue whose atoms should be tested for bonds.
target_residue: Adjacent residue providing the target bonded atoms.

Returns:
MDAnalysis AtomGroup containing atoms from ``source_residue`` that are
bonded to at least one atom in ``target_residue``. If no matching
atoms are found, an empty AtomGroup is returned.
"""
source_atoms = source_residue.atoms
target_indices = {int(index) for index in target_residue.atoms.indices}
selected_indices: list[int] = []

for atom in source_atoms:
bonded_atoms = getattr(atom, "bonded_atoms", None)
if bonded_atoms is None:
continue

bonded_indices = {int(index) for index in bonded_atoms.indices}
if bonded_indices.intersection(target_indices):
selected_indices.append(int(atom.index))

return source_atoms.universe.atoms[selected_indices]
6 changes: 6 additions & 0 deletions CodeEntropy/levels/level_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from CodeEntropy.levels.frame_dag import FrameGraph
from CodeEntropy.levels.neighbors import Neighbors
from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode
from CodeEntropy.levels.nodes.axes_topology import BuildAxesTopologyNode
from CodeEntropy.levels.nodes.beads import BuildBeadsNode
from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode
from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode
Expand Down Expand Up @@ -49,6 +50,11 @@ def build(self) -> LevelDAG:
self._add_static("detect_molecules", DetectMoleculesNode())
self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"])
self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"])
self._add_static(
"build_axes_topology",
BuildAxesTopologyNode(),
deps=["build_beads"],
)
self._add_static(
"init_covariance_accumulators",
InitCovarianceAccumulatorsNode(),
Expand Down
Loading