From b5e4cd2f7caeb2c0a393871109b7b22363eb1aa6 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Sun, 22 Feb 2026 10:42:12 +0000 Subject: [PATCH] Fix OR vs AND logic error in get_model and guard both-None case --- cellfinder/core/classify/tools.py | 33 ++++++++++++------- .../test_unit/test_classify/test_tools.py | 30 +++++++++++++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/cellfinder/core/classify/tools.py b/cellfinder/core/classify/tools.py index 7ad13546..4a6951df 100644 --- a/cellfinder/core/classify/tools.py +++ b/cellfinder/core/classify/tools.py @@ -18,24 +18,33 @@ def get_model( inference: bool = False, continue_training: bool = False, ) -> Model: - """Returns the correct model based on the arguments passed - :param existing_model: An existing, trained model. This is returned if it - exists - :param model_weights: This file is used to set the model weights if it - exists - :param network_depth: This defines the type of model to be created if - necessary - :param learning_rate: For creating a new model + """Returns the correct model based on the arguments passed. + + If ``existing_model`` is provided it is loaded and returned directly, + regardless of ``network_depth``. Otherwise a new model is built using + ``network_depth``. + + :param existing_model: Path to an existing, trained model. Takes + precedence over ``network_depth`` when both are supplied. + :param model_weights: Path used to set the model weights when + ``inference`` or ``continue_training`` is True. + :param network_depth: Defines the architecture of a new model to build + when no ``existing_model`` is given. + :param learning_rate: Learning rate for a newly built model. :param inference: If True, will ensure that a trained model exists. E.g. - by using the default one + by using the default one. :param continue_training: If True, will ensure that a trained model - exists. E.g. by using the default one - :return: A keras model + exists. E.g. by using the default one. + :return: A keras model. """ - if existing_model is not None or network_depth is None: + if existing_model is not None: logger.debug(f"Loading model: {existing_model}") return keras.models.load_model(existing_model) + elif network_depth is None: + raise ValueError( + "Either `existing_model` or `network_depth` must be provided." + ) else: logger.debug(f"Creating a new instance of model: {network_depth}") model = build_model( diff --git a/tests/core/test_unit/test_classify/test_tools.py b/tests/core/test_unit/test_classify/test_tools.py index 87a4e17a..6648ae51 100644 --- a/tests/core/test_unit/test_classify/test_tools.py +++ b/tests/core/test_unit/test_classify/test_tools.py @@ -38,3 +38,33 @@ def test_incorrect_weights(mock_build_model): inference=True, model_weights="incorrect_weights.h5", ) + + +@patch("cellfinder.core.classify.tools.build_model") +@patch("cellfinder.core.classify.tools.keras.models.load_model") +def test_get_model_existing_takes_precedence( + mock_load_model, mock_build_model +): + """Test that existing_model takes precedence over network_depth.""" + tools.get_model(existing_model="/some/path", network_depth="18-layer") + mock_load_model.assert_called_once_with("/some/path") + mock_build_model.assert_not_called() + + +@patch("cellfinder.core.classify.tools.build_model") +def test_get_model_builds_with_depth_only(mock_build_model): + """network_depth alone should build and return a new model.""" + tools.get_model(network_depth="18-layer") + mock_build_model.assert_called_once_with( + network_depth="18-layer", + learning_rate=0.0001, + ) + + +def test_get_model_no_model_no_depth(): + """Both existing_model and network_depth as None should raise.""" + with pytest.raises( + ValueError, + match="Either `existing_model` or `network_depth` must be provided.", + ): + tools.get_model(existing_model=None, network_depth=None)