Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions red/equilibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pathlib import Path as _Path
from typing import Callable as _Callable
from typing import Literal as _Literal
from typing import Optional as _Optional
from typing import Tuple as _Tuple
from typing import Union as _Union
Expand Down Expand Up @@ -200,6 +201,7 @@ def detect_equilibration_window(
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
plot_window_size: bool = True,
backend: _Literal["numba", "numpy"] = "numba",
) -> _Tuple[_Union[float, int], float, float]:
r"""
Detect the equilibration time of a time series by finding the minimum
Expand Down Expand Up @@ -258,6 +260,10 @@ def detect_equilibration_window(
plot_window_size : bool, optional, default=True
Whether to plot the window size used to estimate the variance.

backend : str, optional, default="numba"
The backend to use for computation. Can be "numba" (faster, requires numba)
or "numpy" (pure numpy, no numba dependency).

Returns
-------
equil_time: float | int
Expand Down Expand Up @@ -292,6 +298,7 @@ def detect_equilibration_window(
window_size_fn=window_size_fn,
window_size=window_size,
frac_padding=frac_padding,
backend=backend,
)

# Get the corresponding times (or indices).
Expand Down
2 changes: 1 addition & 1 deletion red/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_ess_series_window(
kernel: _Callable[[int], _npt.NDArray[_np.float64]] = _np.bartlett, # type: ignore
window_size_fn: _Optional[_Callable[[int], int]] = lambda x: round(x**0.5),
window_size: _Optional[int] = None,
) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.float64]]:
) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.int64]]:
"""
Compute a series of effective sample sizes for a time series as data
is discarded from the beginning of the time series. The squared standard
Expand Down
4 changes: 2 additions & 2 deletions red/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def plot_sse(
ax: _Axes,
sse: _npt.NDArray[_np.float64],
max_lags: _Optional[_npt.NDArray[_np.float64]],
window_sizes: _Optional[_npt.NDArray[_np.float64]],
window_sizes: _Optional[_npt.NDArray[_np.int64]],
times: _npt.NDArray[_Union[_np.int64, _np.float64]],
time_units: str = "ns",
variance_y_label: str = r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$",
Expand Down Expand Up @@ -412,7 +412,7 @@ def plot_equilibration_min_sse(
data_times: _npt.NDArray[_Union[_np.int64, _np.float64]],
sse_times: _npt.NDArray[_Union[_np.int64, _np.float64]],
max_lag_series: _Optional[_npt.NDArray[_np.float64]] = None,
window_size_series: _Optional[_npt.NDArray[_np.float64]] = None,
window_size_series: _Optional[_npt.NDArray[_np.int64]] = None,
time_units: str = "ns",
data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$",
variance_y_label: str = r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$",
Expand Down
9 changes: 8 additions & 1 deletion red/sse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions to calculate the squared standard error series."""

from typing import Callable as _Callable
from typing import Literal as _Literal
from typing import Optional as _Optional
from typing import Tuple as _Tuple

Expand Down Expand Up @@ -90,7 +91,8 @@ def get_sse_series_window(
window_size_fn: _Optional[_Callable[[int], int]] = lambda x: round(x**0.5),
window_size: _Optional[int] = None,
frac_padding: float = 0.1,
) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.float64]]:
backend: _Literal["numba", "numpy"] = "numba",
) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.int64]]:
"""
Compute a series of squared standard errors for a time series as data
is discarded from the beginning of the time series. The squared standard
Expand Down Expand Up @@ -118,6 +120,10 @@ def get_sse_series_window(
for the first 90% of the time series. This helps to avoid noise in the
variance when there are few data points.

backend : str, optional, default="numba"
The backend to use for computation. Can be "numba" (faster, requires numba)
or "numpy" (pure numpy, no numba dependency).

Returns
-------
np.ndarray
Expand All @@ -137,6 +143,7 @@ def get_sse_series_window(
window_size_fn=window_size_fn,
window_size=window_size,
frac_padding=frac_padding,
backend=backend,
)

# Compute the squared standard error series by dividing the variance series by
Expand Down
31 changes: 31 additions & 0 deletions red/tests/test_equilibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,37 @@ def test_detect_equilibration_window(example_timeseries, example_times, tmpdir):
assert tmp_output.with_suffix(".png").exists()


def test_detect_equilibration_window_backends(example_timeseries, example_times):
"""
Test that numba and numpy backends yield the same results.
"""
# Use the mean time to make this faster.
example_timeseries = example_timeseries.mean(axis=0)

# Compute with numba backend
equil_idx_numba, equil_g_numba, equil_ess_numba = detect_equilibration_window(
data=example_timeseries,
times=example_times,
method="min_sse",
plot=False,
backend="numba",
)

# Compute with numpy backend
equil_idx_numpy, equil_g_numpy, equil_ess_numpy = detect_equilibration_window(
data=example_timeseries,
times=example_times,
method="min_sse",
plot=False,
backend="numpy",
)

# Check that results match
assert equil_idx_numba == equil_idx_numpy
assert equil_g_numba == pytest.approx(equil_g_numpy, rel=1e-10)
assert equil_ess_numba == pytest.approx(equil_ess_numpy, rel=1e-10)


@pytest.mark.parametrize(
"equil_fn, equil_fn_args",
[
Expand Down
Loading
Loading