diff --git a/climatecritters/model_critters/bistable_melcher.py b/climatecritters/model_critters/bistable_melcher.py index 35e69c5..c4984b6 100644 --- a/climatecritters/model_critters/bistable_melcher.py +++ b/climatecritters/model_critters/bistable_melcher.py @@ -283,23 +283,20 @@ def populate_diagnostics_from_history(self, time, history): alpha_raw = self.param_values.get('alpha', 0.0) if callable(alpha_raw) or hasattr(alpha_raw, 'get_forcing'): - alpha_vals = np.array([ + alpha_for_thresh = np.array([ float(self._resolve_param(alpha_raw, t, history[i])) for i, t in enumerate(time) ]) - alpha_for_thresh = float(np.mean(alpha_vals)) - elif hasattr(alpha_raw, '__len__'): - alpha_for_thresh = float(np.mean(alpha_raw)) else: - alpha_for_thresh = float(alpha_raw) + alpha_for_thresh = alpha_raw stadial, interstadial = self.compute_stability_thresholds(alpha_for_thresh) self.stadial_threshold = stadial self.interstadial_threshold = interstadial - self.diagnostic_variables['states'] = list( - _classify_states(db_vals, stadial, interstadial) - ) + self.diagnostic_variables = { + 'states': _classify_states(db_vals, stadial, interstadial) + } def compute_stability_thresholds(self, alpha): """Compute stadial and interstadial thresholds via Jacobian stability analysis. diff --git a/climatecritters/tests/test_signal_models_bistable_melcher.py b/climatecritters/tests/test_signal_models_bistable_melcher.py new file mode 100644 index 0000000..613e2e9 --- /dev/null +++ b/climatecritters/tests/test_signal_models_bistable_melcher.py @@ -0,0 +1,140 @@ +"""Tests for climatecritters.model_critters.bistable_melcher.""" + +import numpy as np +import pytest + +from climatecritters.model_critters.bistable_melcher import ( + BistableMelcherModel, classify_bistable_states, +) + + +class TestSignalModelsBistableMelcherIntegrate: + @pytest.mark.parametrize('y0', [[1.0, 0.0], [0.6, 0.0]]) + @pytest.mark.parametrize('method', ['heun_maruyama', 'euler_maruyama', 'milstein']) + def test_integrate_t0(self, y0, method): + model = BistableMelcherModel(sigma=0.2, gamma=1.5, alpha=-0.4) + output = model.integrate( + t_span=(0, 12), y0=y0, method=method, dt=0.012, + kwargs={'random_seed': 0, 'si': 0.12}, + ) + assert model.state_variables.dtype.names == ('db', 'B') + assert 'states' in output.diagnostic_variables + assert np.all(np.isfinite(output.state_variables['db'])) + assert np.all(np.isfinite(output.state_variables['B'])) + assert set(np.unique(output.diagnostic_variables['states'])) <= {0.0, 1.0} + + def test_integrate_with_deterministic_method_t1(self): + """uses_post_history=True models should also work with non-SDE methods.""" + model = BistableMelcherModel(alpha=-0.4) + output = model.integrate(t_span=(0, 1.2), y0=[1.0, 0.0], method='euler', dt=0.012) + assert np.all(np.isfinite(output.state_variables['db'])) + assert model.stadial_threshold is not None + + +class TestSignalModelsBistableMelcherThresholds: + def test_thresholds_set_after_integrate_t0(self): + model = BistableMelcherModel(alpha=-0.4) + model.integrate( + t_span=(0, 12), y0=[1.0, 0.0], method='heun_maruyama', dt=0.012, + kwargs={'random_seed': 1, 'si': 0.12}, + ) + assert model.stadial_threshold is not None + assert model.interstadial_threshold is not None + assert model.stadial_threshold < model.interstadial_threshold + + def test_thresholds_match_compute_stability_thresholds_t1(self): + model = BistableMelcherModel(alpha=-0.4) + model.integrate( + t_span=(0, 12), y0=[1.0, 0.0], method='heun_maruyama', dt=0.012, + kwargs={'random_seed': 1, 'si': 0.12}, + ) + expected_stadial, expected_interstadial = model.compute_stability_thresholds(-0.4) + assert model.stadial_threshold == expected_stadial + assert model.interstadial_threshold == expected_interstadial + + def test_callable_alpha_thresholds_match_scalar_t2(self): + """Regression test: populate_diagnostics_from_history resolves a + constant-valued callable alpha to the same thresholds as passing the + same constant directly (exercises the branch reworked alongside + compute_stability_thresholds's own mean/float reduction).""" + const_model = BistableMelcherModel(alpha=-0.4) + tv_model = BistableMelcherModel(alpha=lambda t: -0.4) + + t_span, y0, dt = (0, 12), [1.0, 0.0], 0.012 + + const_model.integrate(t_span=t_span, y0=y0, method='heun_maruyama', dt=dt, + kwargs={'random_seed': 7, 'si': 0.12}) + tv_model.integrate(t_span=t_span, y0=y0, method='heun_maruyama', dt=dt, + kwargs={'random_seed': 7, 'si': 0.12}) + + assert tv_model.stadial_threshold == const_model.stadial_threshold + assert tv_model.interstadial_threshold == const_model.interstadial_threshold + + def test_array_alpha_thresholds_use_mean_t3(self): + """compute_stability_thresholds should reduce an array-like alpha to its mean.""" + model = BistableMelcherModel() + alpha_arr = np.array([-0.2, -0.6]) + stadial, interstadial = model.compute_stability_thresholds(alpha_arr) + expected_stadial, expected_interstadial = model.compute_stability_thresholds(np.mean(alpha_arr)) + assert stadial == expected_stadial + assert interstadial == expected_interstadial + + +class TestSignalModelsBistableMelcherClassifyStandalone: + def test_classify_bistable_states_matches_model_t0(self): + model = BistableMelcherModel(alpha=-0.4) + output = model.integrate( + t_span=(0, 12), y0=[1.0, 0.0], method='heun_maruyama', dt=0.012, + kwargs={'random_seed': 3, 'si': 0.12}, + ) + db = output.state_variables['db'] + states_from_model = output.diagnostic_variables['states'] + states_standalone = classify_bistable_states(db, alpha=-0.4) + assert np.array_equal(states_from_model, states_standalone) + + def test_classify_bistable_states_hysteresis_t1(self): + """A signal that dips below the stadial threshold and rises above the + interstadial threshold should flip states with hysteresis (no chatter + for values between the two thresholds).""" + model = BistableMelcherModel(alpha=-0.4) + stadial, interstadial = model.compute_stability_thresholds(-0.4) + mid = 0.5 * (stadial + interstadial) + signal = np.array([interstadial + 0.1, mid, stadial - 0.1, mid, interstadial + 0.1]) + states = classify_bistable_states(signal, alpha=-0.4) + assert list(states) == [0, 0, 1, 1, 0] + + +class TestSignalModelsBistableMelcherSDENoise: + def test_zero_sigma_is_deterministic_t0(self): + """sigma=0 should make euler_maruyama reduce to the deterministic drift, + independent of the random seed.""" + model_a = BistableMelcherModel(sigma=0.0, alpha=-0.4) + model_b = BistableMelcherModel(sigma=0.0, alpha=-0.4) + t_span, y0, dt = (0, 12), [1.0, 0.0], 0.012 + out_a = model_a.integrate(t_span=t_span, y0=y0, method='euler_maruyama', dt=dt, + kwargs={'random_seed': 1, 'si': 0.12}) + out_b = model_b.integrate(t_span=t_span, y0=y0, method='euler_maruyama', dt=dt, + kwargs={'random_seed': 999, 'si': 0.12}) + assert np.allclose(out_a.state_variables['db'], out_b.state_variables['db']) + + def test_sde_noise_shape_and_scale_t1(self): + model = BistableMelcherModel(sigma=0.3) + diffusion = model.sde_noise(0.0, [1.0, 0.0]) + assert diffusion.shape == (2,) + assert np.allclose(diffusion, 0.3) + + +class TestSignalModelsBistableMelcherTimeVaryingParams: + def test_time_varying_params_match_constants_t0(self): + model_const = BistableMelcherModel(gamma=1.5, alpha=-0.4) + model_tv = BistableMelcherModel( + gamma=lambda t: 1.5, + alpha=lambda t, x: -0.4, + ) + t_span, y0, dt = (0, 1.2), [1.0, 0.0], 0.012 + model_const.integrate(t_span=t_span, y0=y0, method='euler', dt=dt) + model_tv.integrate(t_span=t_span, y0=y0, method='euler', dt=dt) + + const_last = np.array([model_const.state_variables['db'][-1], model_const.state_variables['B'][-1]]) + tv_last = np.array([model_tv.state_variables['db'][-1], model_tv.state_variables['B'][-1]]) + assert np.allclose(const_last, tv_last, rtol=1e-8, atol=1e-10)