diff --git a/openadmet/models/active_learning/committee.py b/openadmet/models/active_learning/committee.py index 5402a70e..d07d7dca 100644 --- a/openadmet/models/active_learning/committee.py +++ b/openadmet/models/active_learning/committee.py @@ -353,7 +353,7 @@ def query(self, X, query_strategy: str = None, **kwargs): return _ACQUISITION_FUNCTIONS[query_strategy](mean, std, **kwargs) - def _predict(self, X, return_std=False, **kwargs): + def _predict(self, X, return_std=False, return_all=False, **kwargs): """ Make predictions using the committee model. @@ -363,35 +363,46 @@ def _predict(self, X, return_std=False, **kwargs): The input samples to predict. return_std : bool, optional Whether to return the standard deviation of the predictions. + Mutually exclusive with ``return_all``. + return_all : bool, optional + Whether to return the raw per-member predictions of shape + (n_samples, n_tasks, n_members) instead of the mean (and std). + Mutually exclusive with ``return_std``. **kwargs : dict Additional keyword arguments to pass to the committee's predict method. Returns ------- - array-like - Predicted values or probabilities, depending on the committee's implementation. + array-like or tuple + mean, or (mean, std), or ndarray of shape (n_samples, n_tasks, n_members) + depending on the values of return_std and return_all. """ - # Make predictions + if return_std and return_all: + raise ValueError( + "return_std and return_all are mutually exclusive. " + "When return_all=True, compute mean and std from the returned array as needed." + ) + + # Make predictions: (n_samples, n_tasks, n_members) preds = np.stack([model.predict(X, **kwargs) for model in self.models], axis=-1) + if return_all: + return preds + # Compute mean mean = np.mean(preds, axis=-1) - # Skip std if not requested - if return_std is False: + if not return_std: return mean # Compute standard deviation, guard against zero std std = np.maximum(np.std(preds, axis=-1), 1e-8) - - # Calibrate std if calibration model is available if self.calibrated: std = self._get_calibration_function()(std) - return mean, std - def predict(self, X, return_std=False, **kwargs): + def predict(self, X, return_std=False, return_all=False, **kwargs): """ Make predictions using the committee model. @@ -401,13 +412,19 @@ def predict(self, X, return_std=False, **kwargs): The input samples to predict. return_std : bool, optional Whether to return the standard deviation of the predictions. + Mutually exclusive with ``return_all``. + return_all : bool, optional + Whether to return the raw per-member predictions of shape + (n_samples, n_tasks, n_members) instead of the mean (and std). + Mutually exclusive with ``return_std``. **kwargs : dict Additional keyword arguments to pass to the committee's predict method. Returns ------- - array-like - Predicted values or probabilities, depending on the committee's implementation. + array-like or tuple + mean, or (mean, std), or ndarray of shape (n_samples, n_tasks, n_members) + depending on the values of return_std and return_all. """ if return_std is True and not self.calibrated: @@ -415,7 +432,7 @@ def predict(self, X, return_std=False, **kwargs): "Standard deviation not calibrated: consider calling `calibrate_uncertainty`." ) - return self._predict(X, return_std=return_std, **kwargs) + return self._predict(X, return_std=return_std, return_all=return_all, **kwargs) def _save_calibration_model(self, path: PathLike = "calibration_model.pkl"): # Save calibration model diff --git a/openadmet/models/tests/unit/active_learning/test_active_learning.py b/openadmet/models/tests/unit/active_learning/test_active_learning.py index f560c967..a6593204 100644 --- a/openadmet/models/tests/unit/active_learning/test_active_learning.py +++ b/openadmet/models/tests/unit/active_learning/test_active_learning.py @@ -36,6 +36,31 @@ def dummy_models(): return models +def test_return_all(toy_data): + """Test that return_all exposes raw per-member predictions with correct shape.""" + X_train, _, X_test, y_train, _, _ = toy_data + n_members = 4 + n_tasks = 1 + + committee = CommitteeRegressor.train( + X_train, + y_train, + mod_class=DummyRegressorModel, + n_models=n_members, + ) + + # return_all returns only the raw array + preds = committee.predict(X_test, return_all=True) + assert preds.shape == (X_test.shape[0], n_tasks, n_members) + # mean and std are derivable from the returned array + assert np.mean(preds, axis=-1).shape == (X_test.shape[0], n_tasks) + assert np.std(preds, axis=-1).shape == (X_test.shape[0], n_tasks) + + # return_all and return_std are mutually exclusive + with pytest.raises(ValueError, match="mutually exclusive"): + committee.predict(X_test, return_std=True, return_all=True) + + @pytest.fixture def trained_committee(dummy_models, toy_data): """