Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions CodeEntropy/levels/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,161 @@ def get_UA_axes(self, data_container, index: int):

return trans_axes, rot_axes, center, moment_of_inertia

def get_UA_axes_from_topology(
self,
*,
u,
residue_atoms,
topology,
box: np.ndarray | None,
):
"""Compute UA axes using cached static topology.

This is the cached-index equivalent of ``get_UA_axes``. It preserves the
frame-dependent numerical calculations, but avoids repeated MDAnalysis
selection strings for heavy atoms, bonded atoms, and UA masses.

Args:
u: Current-frame universe.
residue_atoms: AtomGroup for the parent residue in the current frame.
topology: Cached ``UAAxesTopology`` for this UA bead.
box: Current periodic box lengths. If omitted, ``u.dimensions`` is used.

Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
- trans_axes: Translational axes, shape ``(3, 3)``.
- rot_axes: Rotational axes, shape ``(3, 3)``.
- center: Rotation centre, shape ``(3,)``.
- moment_of_inertia: Principal moments, shape ``(3,)``.

Raises:
ValueError: If cached bonded-axis construction fails.
"""
dimensions = (
np.asarray(box, dtype=float)
if box is not None
else np.asarray(u.dimensions[:3], dtype=float)
)

heavy_atoms = u.atoms[topology.residue_heavy_indices]
heavy_atom = u.atoms[int(topology.heavy_atom_index)]

if len(heavy_atoms) > 1:
center = residue_atoms.center_of_mass(unwrap=True)
moment_of_inertia_tensor = self.get_moment_of_inertia_tensor(
center_of_mass=center,
positions=heavy_atoms.positions,
masses=topology.residue_ua_masses,
dimensions=dimensions,
)
trans_axes, _moment_of_inertia = self.get_custom_principal_axes(
moment_of_inertia_tensor
)
else:
make_whole(residue_atoms)
trans_axes = residue_atoms.principal_axes()

center = heavy_atom.position
rot_axes, moment_of_inertia = self.get_bonded_axes_from_topology(
u=u,
heavy_atom=heavy_atom,
topology=topology,
dimensions=dimensions,
)
if rot_axes is None or moment_of_inertia is None:
raise ValueError("Unable to compute bonded axes for cached UA bead.")

logger.debug("Translational Axes: %s", trans_axes)
logger.debug("Rotational Axes: %s", rot_axes)
logger.debug("Center: %s", center)
logger.debug("Moment of Inertia: %s", moment_of_inertia)

return trans_axes, rot_axes, center, moment_of_inertia

def get_bonded_axes_from_topology(
self,
*,
u,
heavy_atom,
topology,
dimensions: np.ndarray,
):
"""Compute UA bonded axes using cached bonded atom indices.

This mirrors ``get_bonded_axes`` but receives precomputed bonded atom
memberships from ``UAAxesTopology`` instead of rediscovering them with
MDAnalysis selection strings inside the frame loop.

Args:
u: Current-frame universe.
heavy_atom: Current-frame heavy atom for the UA bead.
topology: Cached ``UAAxesTopology`` for the UA bead.
dimensions: Simulation box lengths, shape ``(3,)``.

Returns:
Tuple[np.ndarray | None, np.ndarray | None]:
- custom_axes: Custom rotation axes, shape ``(3, 3)``, or ``None``.
- custom_moment_of_inertia: Principal moments, shape ``(3,)``, or
``None``.
"""
if not heavy_atom.mass > 1.1:
return None, None

custom_moment_of_inertia = None
custom_axes = None

heavy_bonded = u.atoms[topology.bonded_heavy_indices]
light_bonded = u.atoms[topology.bonded_light_indices]
ua = u.atoms[topology.ua_atom_indices]
ua_all = u.atoms[topology.ua_all_atom_indices]

if len(heavy_bonded) == 0:
custom_axes, custom_moment_of_inertia = self.get_vanilla_axes(ua_all)

if len(heavy_bonded) == 1 and len(light_bonded) == 0:
custom_axes = self.get_custom_axes(
a=heavy_atom.position,
b_list=[heavy_bonded[0].position],
c=np.zeros(3),
dimensions=dimensions,
)

if len(heavy_bonded) == 1 and len(light_bonded) >= 1:
custom_axes = self.get_custom_axes(
a=heavy_atom.position,
b_list=[heavy_bonded[0].position],
c=light_bonded[0].position,
dimensions=dimensions,
)

if len(heavy_bonded) >= 2:
custom_axes = self.get_custom_axes(
a=heavy_atom.position,
b_list=heavy_bonded.positions,
c=heavy_bonded[1].position,
dimensions=dimensions,
)

if custom_axes is None:
return None, None

if custom_moment_of_inertia is None:
custom_moment_of_inertia = self.get_custom_moment_of_inertia(
UA=ua,
custom_rotation_axes=custom_axes,
center_of_mass=heavy_atom.position,
dimensions=dimensions,
)

custom_axes = self.get_flipped_axes(
ua,
custom_axes,
heavy_atom.position,
dimensions,
)

return custom_axes, custom_moment_of_inertia

def get_bonded_axes(self, system, atom, dimensions: np.ndarray):
r"""Compute UA rotational axes from bonded topology around a heavy atom.

Expand Down
6 changes: 6 additions & 0 deletions CodeEntropy/levels/level_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from CodeEntropy.levels.frame_dag import FrameGraph
from CodeEntropy.levels.neighbors import Neighbors
from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode
from CodeEntropy.levels.nodes.axes_topology import BuildAxesTopologyNode
from CodeEntropy.levels.nodes.beads import BuildBeadsNode
from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode
from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode
Expand Down Expand Up @@ -49,6 +50,11 @@ def build(self) -> LevelDAG:
self._add_static("detect_molecules", DetectMoleculesNode())
self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"])
self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"])
self._add_static(
"build_axes_topology",
BuildAxesTopologyNode(),
deps=["build_beads"],
)
self._add_static(
"init_covariance_accumulators",
InitCovarianceAccumulatorsNode(),
Expand Down
Loading