Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions CodeEntropy/levels/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,64 @@ def get_residue_axes(self, data_container, index: int, residue=None):

return trans_axes, rot_axes, center, moment_of_inertia

def get_residue_axes_from_topology(
self,
*,
u,
mol,
residue_atoms,
topology,
box: np.ndarray | None,
):
"""Compute residue axes using cached static topology.

This is the cached-index equivalent of ``get_residue_axes``. It keeps
all frame-dependent numerical work frame-local, but avoids repeated
MDAnalysis selections for residue heavy atoms, UA masses, and neighbour
bond discovery.

Args:
u: Current-frame universe used to resolve cached atom indices.
mol: Current-frame molecule fragment.
residue_atoms: AtomGroup for the residue in the current frame.
topology: Cached ``ResidueAxesTopology`` for this residue.
box: Current periodic box lengths. If omitted, ``u.dimensions`` is used.

Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
- trans_axes: Translational axes, shape ``(3, 3)``.
- rot_axes: Rotational axes, shape ``(3, 3)``.
- center: Residue centre, shape ``(3,)``.
- moment_of_inertia: Principal moments, shape ``(3,)``.
"""
dimensions = (
np.asarray(box, dtype=float)
if box is not None
else np.asarray(u.dimensions[:3], dtype=float)
)

center = residue_atoms.center_of_mass(unwrap=True)

if not topology.has_neighbor_bonds:
heavy_atoms = u.atoms[topology.residue_heavy_indices]
moment_of_inertia_tensor = self.get_moment_of_inertia_tensor(
center_of_mass=center,
positions=heavy_atoms.positions,
masses=topology.residue_ua_masses,
dimensions=dimensions,
)
rot_axes, moment_of_inertia = self.get_custom_principal_axes(
moment_of_inertia_tensor
)
trans_axes = rot_axes
else:
make_whole(mol.atoms)
trans_axes = mol.atoms.principal_axes()
rot_axes, moment_of_inertia = self.get_vanilla_axes(residue_atoms)
center = residue_atoms.center_of_mass(unwrap=True)

return trans_axes, rot_axes, center, moment_of_inertia

def get_UA_axes(self, data_container, index: int):
"""Compute united-atom-level translational and rotational axes.

Expand Down
119 changes: 106 additions & 13 deletions CodeEntropy/levels/nodes/axes_topology.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Build static axes-topology metadata for frame covariance calculations.

This module caches topology-only atom-index relationships needed by customised
united-atom axes calculations. The cache avoids repeated MDAnalysis selection
parsing inside the frame-local covariance loop while preserving frame-dependent
positions, forces, centres, axes, torques, and moments of inertia.
axes calculations. The cache avoids repeated MDAnalysis selection parsing inside
the frame-local covariance loop while preserving frame-dependent positions,
forces, centres, axes, torques, and moments of inertia.
"""

from __future__ import annotations
Expand All @@ -17,6 +17,7 @@
logger = logging.getLogger(__name__)

UAKey = tuple[int, int, int]
ResidueKey = tuple[int, int]


@dataclass(frozen=True)
Expand Down Expand Up @@ -44,16 +45,35 @@ class UAAxesTopology:
residue_ua_masses: np.ndarray


@dataclass(frozen=True)
class ResidueAxesTopology:
"""Static topology required to compute customised residue axes.

Attributes:
residue_heavy_indices: Heavy atom indices in the residue.
residue_ua_masses: UA masses for heavy atoms in the residue.
has_neighbor_bonds: Whether the residue is bonded to a neighbouring
residue according to the original customised residue-axis selection.
"""

residue_heavy_indices: np.ndarray
residue_ua_masses: np.ndarray
has_neighbor_bonds: bool


@dataclass(frozen=True)
class AxesTopology:
"""Cached axes topology for frame covariance calculations.

Attributes:
ua: Mapping from ``(mol_id, local_residue_id, ua_id)`` to cached
united-atom axes topology.
residue: Mapping from ``(mol_id, local_residue_id)`` to cached
residue axes topology.
"""

ua: dict[UAKey, UAAxesTopology] = field(default_factory=dict)
residue: dict[ResidueKey, ResidueAxesTopology] = field(default_factory=dict)


class BuildAxesTopologyNode:
Expand Down Expand Up @@ -86,24 +106,76 @@ def run(self, shared_data: dict[str, Any]) -> dict[str, Any]:
beads = shared_data["beads"]

ua_topology: dict[UAKey, UAAxesTopology] = {}
residue_topology: dict[ResidueKey, ResidueAxesTopology] = {}
fragments = u.atoms.fragments

for mol_id, level_list in enumerate(levels):
if "united_atom" not in level_list:
continue
mol = fragments[mol_id]

if "residue" in level_list:
self._add_residue_topology(
mol=mol,
mol_id=mol_id,
beads=beads,
out=residue_topology,
)

self._add_ua_topology(
u=u,
mol=fragments[mol_id],
mol_id=mol_id,
beads=beads,
out=ua_topology,
)
if "united_atom" in level_list:
self._add_ua_topology(
u=u,
mol=mol,
mol_id=mol_id,
beads=beads,
out=ua_topology,
)

topology = AxesTopology(ua=ua_topology)
topology = AxesTopology(ua=ua_topology, residue=residue_topology)
shared_data["axes_topology"] = topology
return {"axes_topology": topology}

def _add_residue_topology(
self,
*,
mol: Any,
mol_id: int,
beads: dict[Any, list[np.ndarray]],
out: dict[ResidueKey, ResidueAxesTopology],
) -> None:
"""Cache static residue axes topology for one molecule.

Args:
mol: Molecule AtomGroup.
mol_id: Molecule index.
beads: Bead-index mapping produced by ``BuildBeadsNode``.
out: Output residue topology mapping mutated in place.
"""
bead_key = (mol_id, "residue")
bead_idx_list = beads.get(bead_key, [])
if not bead_idx_list:
return

for local_res_i, residue in enumerate(mol.residues):
if local_res_i >= len(bead_idx_list):
continue

residue_atoms = residue.atoms
residue_heavy = residue_atoms.select_atoms("mass 2 to 999")
residue_heavy_indices = residue_heavy.indices.astype(int, copy=True)
residue_ua_masses = np.asarray(
self._get_ua_masses_from_topology(residue_atoms),
dtype=float,
)
has_neighbor_bonds = self._has_neighbor_bonds(
mol=mol,
local_res_i=local_res_i,
)

out[(mol_id, local_res_i)] = ResidueAxesTopology(
residue_heavy_indices=residue_heavy_indices,
residue_ua_masses=residue_ua_masses,
has_neighbor_bonds=has_neighbor_bonds,
)

def _add_ua_topology(
self,
*,
Expand Down Expand Up @@ -177,6 +249,27 @@ def _add_ua_topology(
residue_ua_masses=residue_ua_masses,
)

@staticmethod
def _has_neighbor_bonds(*, mol: Any, local_res_i: int) -> bool:
"""Return whether a residue is bonded to neighbouring residues.

Args:
mol: Molecule AtomGroup used for the original bonded-neighbour
selection.
local_res_i: Residue index local to ``mol``.

Returns:
True when the residue has bonded atoms in the previous or next
residue according to the original customised residue-axis query.
"""
index_prev = local_res_i - 1
index_next = local_res_i + 1
atom_set = mol.select_atoms(
f"(resindex {index_prev} or resindex {index_next}) "
f"and bonded resid {local_res_i}"
)
return len(atom_set) > 0

@staticmethod
def _split_bonded_atoms(atom: Any) -> tuple[Any, Any]:
"""Return bonded heavy and light atoms for one atom.
Expand Down
37 changes: 37 additions & 0 deletions CodeEntropy/levels/nodes/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def run(self, ctx: FrameCtx) -> dict[str, Any]:
group_id=group_id,
beads=beads,
axes_manager=axes_manager,
axes_topology=axes_topology,
box=box,
customised_axes=customised_axes,
force_partitioning=fp,
Expand Down Expand Up @@ -242,6 +243,7 @@ def _process_residue(
group_id: int,
beads: dict[Any, list[Any]],
axes_manager: Any,
axes_topology: Any | None,
box: np.ndarray | None,
customised_axes: bool,
force_partitioning: float,
Expand All @@ -261,6 +263,7 @@ def _process_residue(
group_id: Molecule-group identifier used for within-frame averaging.
beads: Mapping of bead keys to reduced-universe atom-index arrays.
axes_manager: Axes helper used to build translation and rotation axes.
axes_topology: Optional cached axes topology generated during static setup.
box: Optional periodic box vector.
customised_axes: Whether customised residue axes should be used.
force_partitioning: Force partitioning factor for highest-level vectors.
Expand All @@ -281,9 +284,12 @@ def _process_residue(
return

force_vecs, torque_vecs = self._build_residue_vectors(
u=u,
mol=mol,
mol_id=mol_id,
bead_groups=bead_groups,
axes_manager=axes_manager,
axes_topology=axes_topology,
box=box,
customised_axes=customised_axes,
force_partitioning=force_partitioning,
Expand Down Expand Up @@ -481,9 +487,12 @@ def _build_ua_vectors(
def _build_residue_vectors(
self,
*,
u: Any,
mol: Any,
mol_id: int,
bead_groups: list[Any],
axes_manager: Any,
axes_topology: Any | None,
box: np.ndarray | None,
customised_axes: bool,
force_partitioning: float,
Expand All @@ -492,9 +501,12 @@ def _build_residue_vectors(
"""Build force and torque vectors for residue beads.

Args:
u: Universe-like object used to resolve cached atom indices.
mol: Molecule fragment containing residues and atoms.
mol_id: Molecule index used in axes-topology lookup keys.
bead_groups: Atom groups representing residue beads.
axes_manager: Axes helper used to select axes, centres, and moments.
axes_topology: Optional cached axes topology generated during static setup.
box: Optional periodic box vector.
customised_axes: Whether customised residue axes should be used.
force_partitioning: Force partitioning factor for highest-level vectors.
Expand All @@ -508,10 +520,14 @@ def _build_residue_vectors(

for local_res_i, bead in enumerate(bead_groups):
trans_axes, rot_axes, center, moi = self._get_residue_axes(
u=u,
mol=mol,
mol_id=mol_id,
bead=bead,
local_res_i=local_res_i,
axes_manager=axes_manager,
axes_topology=axes_topology,
box=box,
customised_axes=customised_axes,
)

Expand Down Expand Up @@ -540,26 +556,47 @@ def _build_residue_vectors(
def _get_residue_axes(
self,
*,
u: Any,
mol: Any,
mol_id: int,
bead: Any,
local_res_i: int,
axes_manager: Any,
axes_topology: Any | None,
box: np.ndarray | None,
customised_axes: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Return axes, centre, and inertia data for a residue bead.

Args:
u: Universe-like object used to resolve cached atom indices.
mol: Molecule fragment containing residues and atoms.
mol_id: Molecule index used in axes-topology lookup keys.
bead: Atom group representing the residue bead.
local_res_i: Residue index local to ``mol``.
axes_manager: Axes helper used to select axes, centres, and moments.
axes_topology: Optional cached axes topology generated during static setup.
box: Optional periodic box vector.
customised_axes: Whether customised residue axes should be used.

Returns:
A tuple of translation axes, rotation axes, centre, and moments of inertia.
"""
if customised_axes:
res = mol.residues[local_res_i]
residue_topology = None
if axes_topology is not None:
residue_topology = axes_topology.residue.get((mol_id, local_res_i))

if residue_topology is not None:
return axes_manager.get_residue_axes_from_topology(
u=u,
mol=mol,
residue_atoms=res.atoms,
topology=residue_topology,
box=box,
)

return axes_manager.get_residue_axes(mol, local_res_i, residue=res.atoms)

make_whole(mol.atoms)
Expand Down
Loading