diff --git a/red/equilibration.py b/red/equilibration.py index f71f7c3..d6c90bf 100644 --- a/red/equilibration.py +++ b/red/equilibration.py @@ -2,6 +2,7 @@ from pathlib import Path as _Path from typing import Callable as _Callable +from typing import Literal as _Literal from typing import Optional as _Optional from typing import Tuple as _Tuple from typing import Union as _Union @@ -200,6 +201,7 @@ def detect_equilibration_window( time_units: str = "ns", data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$", plot_window_size: bool = True, + backend: _Literal["numba", "numpy"] = "numba", ) -> _Tuple[_Union[float, int], float, float]: r""" Detect the equilibration time of a time series by finding the minimum @@ -258,6 +260,10 @@ def detect_equilibration_window( plot_window_size : bool, optional, default=True Whether to plot the window size used to estimate the variance. + backend : str, optional, default="numba" + The backend to use for computation. Can be "numba" (faster, requires numba) + or "numpy" (pure numpy, no numba dependency). + Returns ------- equil_time: float | int @@ -292,6 +298,7 @@ def detect_equilibration_window( window_size_fn=window_size_fn, window_size=window_size, frac_padding=frac_padding, + backend=backend, ) # Get the corresponding times (or indices). diff --git a/red/ess.py b/red/ess.py index 35150fb..1d83ae8 100644 --- a/red/ess.py +++ b/red/ess.py @@ -115,7 +115,7 @@ def get_ess_series_window( kernel: _Callable[[int], _npt.NDArray[_np.float64]] = _np.bartlett, # type: ignore window_size_fn: _Optional[_Callable[[int], int]] = lambda x: round(x**0.5), window_size: _Optional[int] = None, -) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.float64]]: +) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.int64]]: """ Compute a series of effective sample sizes for a time series as data is discarded from the beginning of the time series. 
The squared standard diff --git a/red/plot.py b/red/plot.py index feb68cf..e0b6c7c 100644 --- a/red/plot.py +++ b/red/plot.py @@ -202,7 +202,7 @@ def plot_sse( ax: _Axes, sse: _npt.NDArray[_np.float64], max_lags: _Optional[_npt.NDArray[_np.float64]], - window_sizes: _Optional[_npt.NDArray[_np.float64]], + window_sizes: _Optional[_npt.NDArray[_np.int64]], times: _npt.NDArray[_Union[_np.int64, _np.float64]], time_units: str = "ns", variance_y_label: str = r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$", @@ -412,7 +412,7 @@ def plot_equilibration_min_sse( data_times: _npt.NDArray[_Union[_np.int64, _np.float64]], sse_times: _npt.NDArray[_Union[_np.int64, _np.float64]], max_lag_series: _Optional[_npt.NDArray[_np.float64]] = None, - window_size_series: _Optional[_npt.NDArray[_np.float64]] = None, + window_size_series: _Optional[_npt.NDArray[_np.int64]] = None, time_units: str = "ns", data_y_label: str = r"$\Delta G$ / kcal mol$^{-1}$", variance_y_label: str = r"$\frac{1}{\sigma^2(\Delta G)}$ / kcal$^{-2}$ mol$^2$", diff --git a/red/sse.py b/red/sse.py index 19a1eda..636e81a 100644 --- a/red/sse.py +++ b/red/sse.py @@ -1,6 +1,7 @@ """Functions to calculate the squared standard error series.""" from typing import Callable as _Callable +from typing import Literal as _Literal from typing import Optional as _Optional from typing import Tuple as _Tuple @@ -90,7 +91,8 @@ def get_sse_series_window( window_size_fn: _Optional[_Callable[[int], int]] = lambda x: round(x**0.5), window_size: _Optional[int] = None, frac_padding: float = 0.1, -) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.float64]]: + backend: _Literal["numba", "numpy"] = "numba", +) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.int64]]: """ Compute a series of squared standard errors for a time series as data is discarded from the beginning of the time series. The squared standard @@ -118,6 +120,10 @@ def get_sse_series_window( for the first 90% of the time series. 
This helps to avoid noise in the variance when there are few data points. + backend : str, optional, default="numba" + The backend to use for computation. Can be "numba" (faster, requires numba) + or "numpy" (pure numpy, no numba dependency). + Returns ------- np.ndarray @@ -137,6 +143,7 @@ def get_sse_series_window( window_size_fn=window_size_fn, window_size=window_size, frac_padding=frac_padding, + backend=backend, ) # Compute the squared standard error series by dividing the variance series by diff --git a/red/tests/test_equilibration.py b/red/tests/test_equilibration.py index eab8d34..d1e18a7 100644 --- a/red/tests/test_equilibration.py +++ b/red/tests/test_equilibration.py @@ -121,6 +121,37 @@ def test_detect_equilibration_window(example_timeseries, example_times, tmpdir): assert tmp_output.with_suffix(".png").exists() +def test_detect_equilibration_window_backends(example_timeseries, example_times): + """ + Test that numba and numpy backends yield the same results. + """ + # Use the mean time to make this faster. 
+ example_timeseries = example_timeseries.mean(axis=0) + + # Compute with numba backend + equil_idx_numba, equil_g_numba, equil_ess_numba = detect_equilibration_window( + data=example_timeseries, + times=example_times, + method="min_sse", + plot=False, + backend="numba", + ) + + # Compute with numpy backend + equil_idx_numpy, equil_g_numpy, equil_ess_numpy = detect_equilibration_window( + data=example_timeseries, + times=example_times, + method="min_sse", + plot=False, + backend="numpy", + ) + + # Check that results match + assert equil_idx_numba == equil_idx_numpy + assert equil_g_numba == pytest.approx(equil_g_numpy, rel=1e-10) + assert equil_ess_numba == pytest.approx(equil_ess_numpy, rel=1e-10) + + @pytest.mark.parametrize( "equil_fn, equil_fn_args", [ diff --git a/red/variance.py b/red/variance.py index b1ab069..becc81b 100644 --- a/red/variance.py +++ b/red/variance.py @@ -12,14 +12,16 @@ """ +from typing import Any as _Any from typing import Callable as _Callable +from typing import Literal as _Literal from typing import Optional as _Optional from typing import Tuple as _Tuple +from typing import TypeVar as _TypeVar from typing import Union as _Union from typing import cast as _cast from warnings import warn as _warn -import numba as _numba import numpy as _np import numpy.typing as _npt from statsmodels.tsa.stattools import acovf as _acovf @@ -27,11 +29,34 @@ from ._exceptions import AnalysisError, InvalidInputError from ._validation import check_data as _check_data +# Optional numba import +try: + import numba as _numba + + _NUMBA_AVAILABLE = True + _prange = _numba.prange +except ImportError: + _numba = None # type: ignore + _NUMBA_AVAILABLE = False + _prange = range # type: ignore + + +_F = _TypeVar("_F", bound=_Callable[..., _Any]) + + +def _optional_njit(*args: _Any, **kwargs: _Any) -> _Callable[[_F], _F]: + """Decorator that applies numba.njit if available, otherwise returns the function unchanged.""" + def decorator(func: _F) -> _F: + if 
_NUMBA_AVAILABLE: + return _numba.njit(*args, **kwargs)(func) # type: ignore[no-any-return] + return func + return decorator + ####### Private functions ####### # No need to thoroughly validate input as this is done in the public functions. -@_numba.njit(cache=True) # type: ignore +@_optional_njit(cache=True) # type: ignore def _compute_autocovariance_no_fft( data: _npt.NDArray[_np.float64], max_lag: int ) -> _npt.NDArray[_np.float64]: @@ -153,7 +178,7 @@ def _get_autocovariance( return compute_autocov_fn(data, max_lag) -@_numba.njit(cache=True) # type: ignore +@_optional_njit(cache=True) # type: ignore def _get_gamma_cap( autocov_series: _npt.NDArray[_np.float64], ) -> _npt.NDArray[_np.float64]: @@ -190,7 +215,7 @@ def _get_gamma_cap( return gamma -@_numba.njit(cache=True) # type: ignore +@_optional_njit(cache=True) # type: ignore def _get_initial_positive_sequence( gamma_cap: _npt.NDArray[_np.float64], min_max_lag_time: int = 3, @@ -226,7 +251,7 @@ def _get_initial_positive_sequence( return gamma_cap -@_numba.njit(cache=True) # type: ignore +@_optional_njit(cache=True) # type: ignore def _get_initial_monotone_sequence( gamma_cap: _npt.NDArray[_np.float64], min_max_lag_time: int = 3, @@ -262,7 +287,7 @@ def _get_initial_monotone_sequence( return gamma_cap -@_numba.njit(cache=True) # type: ignore +@_optional_njit(cache=True) # type: ignore def _get_initial_convex_sequence( gamma_cap: _npt.NDArray[_np.float64], min_max_lag_time: int = 3, @@ -338,7 +363,7 @@ def _get_initial_convex_sequence( gamma_con[j] = gamma_con[j - 1] + mean_pooled_value j += 1 - return _cast(_npt.NDArray[_np.float64], gamma_con) + return gamma_con def _get_autocovariance_window( @@ -795,13 +820,180 @@ def get_variance_window( return float(max(corr_var, var)) +@_optional_njit(cache=True, parallel=True) # type: ignore +def _compute_variance_series_window_indexed( + data: _npt.NDArray[_np.float64], + indices: _npt.NDArray[_np.int64], + window_sizes: _npt.NDArray[_np.int64], + windows_flat: 
_npt.NDArray[_np.float64], + window_offsets: _npt.NDArray[_np.int64], + window_size_to_idx: _npt.NDArray[_np.int64], + max_window_size: int, +) -> _npt.NDArray[_np.float64]: + """ + Fast numba-accelerated computation of variance series using window estimators. + Parallelized over the specified indices with minimal overhead. + + Parameters + ---------- + data : numpy.ndarray + A time series of data with shape (n_runs, n_samples). + + indices : numpy.ndarray + Array of indices to compute variance for. + + window_sizes : numpy.ndarray + Window size to use at each position in indices array. + + windows_flat : numpy.ndarray + Flattened array containing all precomputed windows. + + window_offsets : numpy.ndarray + Start offset for each unique window size in windows_flat. + + window_size_to_idx : numpy.ndarray + Mapping from window size to index in window_offsets. + + max_window_size : int + Maximum window size (for array sizing). + + Returns + ------- + numpy.ndarray + The variance series (one value per index in indices). 
+ """ + n_runs, n_samples = data.shape + n_indices = len(indices) + variance_series = _np.zeros(n_indices) + + for idx_pos in _prange(n_indices): + index = indices[idx_pos] + curr_window_size = window_sizes[idx_pos] + remaining_samples = n_samples - index + + # Compute the mean for this truncated series + curr_sum = 0.0 + for run in range(n_runs): + for i in range(index, n_samples): + curr_sum += data[run, i] + curr_mean = curr_sum / (n_runs * remaining_samples) + + # Get window offset + ws_idx = window_size_to_idx[curr_window_size] + window_start = window_offsets[ws_idx] + + # Compute windowed autocovariance sum directly (combine autocov computation and windowing) + windowed_sum = 0.0 + for lag in range(curr_window_size + 1): + # Get window value + window_val = windows_flat[window_start + lag] + + # Compute autocovariance for this lag across all runs + acov_sum = 0.0 + for run in range(n_runs): + run_sum = 0.0 + for i in range(index, n_samples - lag): + run_sum += (data[run, i] - curr_mean) * (data[run, i + lag] - curr_mean) + acov_sum += run_sum + acov_sum /= n_runs * remaining_samples + + windowed_sum += acov_sum * window_val + + # Compute uncorrelated variance + var_sum = 0.0 + for run in range(n_runs): + for i in range(index, n_samples): + diff = data[run, i] - curr_mean + var_sum += diff * diff + var_uncorr = var_sum / (n_runs * remaining_samples) + + # Account for correlation in both directions + corr_var = windowed_sum * 2 - var_uncorr + + # Ensure variance is at least the uncorrelated value + if corr_var > var_uncorr: + variance_series[idx_pos] = corr_var + else: + variance_series[idx_pos] = var_uncorr + + return variance_series + + +def _compute_variance_series_window_numpy( + data: _npt.NDArray[_np.float64], + indices: _npt.NDArray[_np.int64], + window_sizes: _npt.NDArray[_np.int64], + kernel: _Callable[[int], _npt.NDArray[_np.float64]], +) -> _npt.NDArray[_np.float64]: + """ + Pure numpy implementation of variance series computation using window 
estimators. + This is used when backend="numpy" is requested or when numba is not available. + + Parameters + ---------- + data : numpy.ndarray + A time series of data with shape (n_runs, n_samples). + + indices : numpy.ndarray + Array of indices to compute variance for. + + window_sizes : numpy.ndarray + Window size to use at each position in indices array. + + kernel : callable + A function that takes a window size and returns a window function. + + Returns + ------- + numpy.ndarray + The variance series (one value per index in indices). + """ + n_runs, _ = data.shape + n_indices = len(indices) + variance_series = _np.zeros(n_indices) + + # Precompute unique windows + unique_window_sizes = _np.unique(window_sizes) + windows_cache = {ws: kernel(2 * ws + 1)[ws:] for ws in unique_window_sizes} + + for idx_pos, index in enumerate(indices): + curr_window_size = window_sizes[idx_pos] + truncated_data = data[:, index:] + + # Get the shared mean across all runs + curr_mean = truncated_data.mean() + + # Compute autocovariance using direct computation + autocov = _np.zeros(curr_window_size + 1) + for run in range(n_runs): + run_data = truncated_data[run] - curr_mean + autocov += _compute_autocovariance_no_fft(run_data, curr_window_size) + autocov /= n_runs + + # Get the precomputed window and apply it + window = windows_cache[curr_window_size] + windowed_autocov = autocov * window + + # Compute uncorrelated variance + var_uncorr = truncated_data.var() + + # Account for correlation in both directions + corr_var = windowed_autocov.sum() * 2 - var_uncorr + + # Ensure variance is at least the uncorrelated value + variance_series[idx_pos] = max(corr_var, var_uncorr) + + return variance_series + + def get_variance_series_window( data: _npt.NDArray[_np.float64], kernel: _Callable[[int], _npt.NDArray[_np.float64]] = _np.bartlett, # type: ignore window_size_fn: _Optional[_Callable[[int], int]] = lambda x: round(x**0.5), window_size: _Optional[int] = None, frac_padding: float = 0.1, -) ->
_Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.float64]]: + backend: _Literal["numba", "numpy"] = "numba", +) -> _Tuple[_npt.NDArray[_np.float64], _npt.NDArray[_np.int64]]: """ Repeatedly calculate the variance of a time series while discarding increasing numbers of samples from the start of the time series. The variance is calculated @@ -829,6 +1021,10 @@ def get_variance_series_window( for the first 90% of the time series. This helps to avoid noise in the variance when there are few data points. + backend : str, optional, default="numba" + The backend to use for computation. Can be "numba" (faster, requires numba) + or "numpy" (pure numpy, no numba dependency). + Returns ------- numpy.ndarray @@ -839,7 +1035,7 @@ def get_variance_series_window( """ # Check that the data is valid. data = _check_data(data, one_dim_allowed=True) - n_samples = data.shape[1] + n_runs, n_samples = data.shape # Check that only one of window_size_fn and window_size is not None. if window_size_fn is not None and window_size is not None: @@ -871,18 +1067,68 @@ def get_variance_series_window( frac_padding_max_index = round(n_samples * (1 - frac_padding)) max_index = min(max_index, frac_padding_max_index) - # Calculate the variance at each index and store the window size used. 
- variance_series = _np.zeros(max_index + 1) - window_size_series = _np.zeros(max_index + 1, dtype=int) + # Generate the indices to evaluate (all indices from 0 to max_index) + indices = _np.arange(0, max_index + 1, dtype=_np.int64) + n_indices = len(indices) - for index in range(max_index + 1): - window_size = window_size_fn(n_samples - index) if window_size_fn else window_size - variance_series[index] = get_variance_window( - data[:, index:], - kernel=kernel, - window_size=window_size, # type: ignore + # Precompute all window sizes + window_size_series = _np.zeros(n_indices, dtype=_np.int64) + if window_size_fn is not None: + for i in range(n_indices): + window_size_series[i] = window_size_fn(n_samples - i) + else: + window_size_series[:] = window_size # type: ignore + + # Ensure data is contiguous and float64 + data = _np.ascontiguousarray(data, dtype=_np.float64) + + # Choose backend + use_numba = backend == "numba" and _NUMBA_AVAILABLE + if backend == "numba" and not _NUMBA_AVAILABLE: + _warn( + "Numba backend requested but numba is not installed. 
Falling back to numpy backend.", + stacklevel=2, + ) + + if use_numba: + # Precompute unique windows and create lookup structures for numba + unique_window_sizes = _np.unique(window_size_series) + max_window_size = int(unique_window_sizes.max()) + + # Create flattened windows array and offset lookup + total_size = sum(ws + 1 for ws in unique_window_sizes) + windows_flat = _np.zeros(total_size, dtype=_np.float64) + window_offsets = _np.zeros(len(unique_window_sizes), dtype=_np.int64) + + # Create mapping from window size to index + window_size_to_idx = _np.zeros(max_window_size + 1, dtype=_np.int64) + + offset = 0 + for i, ws in enumerate(unique_window_sizes): + window_offsets[i] = offset + window_size_to_idx[ws] = i + window = kernel(2 * ws + 1)[ws:] + windows_flat[offset : offset + ws + 1] = window + offset += ws + 1 + + # Use fast numba implementation + variance_series = _compute_variance_series_window_indexed( + data, + indices, + window_size_series, + windows_flat, + window_offsets, + window_size_to_idx, + max_window_size, + ) + else: + # Use pure numpy implementation + variance_series = _compute_variance_series_window_numpy( + data, + indices, + window_size_series, + kernel, ) - window_size_series[index] = window_size return variance_series, window_size_series