Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions cellfinder/core/classify/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,33 @@ def get_model(
inference: bool = False,
continue_training: bool = False,
) -> Model:
"""Returns the correct model based on the arguments passed
:param existing_model: An existing, trained model. This is returned if it
exists
:param model_weights: This file is used to set the model weights if it
exists
:param network_depth: This defines the type of model to be created if
necessary
:param learning_rate: For creating a new model
"""Returns the correct model based on the arguments passed.

If ``existing_model`` is provided it is loaded and returned directly,
regardless of ``network_depth``. Otherwise a new model is built using
``network_depth``.

:param existing_model: Path to an existing, trained model. Takes
precedence over ``network_depth`` when both are supplied.
:param model_weights: Path used to set the model weights when
``inference`` or ``continue_training`` is True.
:param network_depth: Defines the architecture of a new model to build
when no ``existing_model`` is given.
:param learning_rate: Learning rate for a newly built model.
:param inference: If True, will ensure that a trained model exists. E.g.
by using the default one
by using the default one.
:param continue_training: If True, will ensure that a trained model
exists. E.g. by using the default one
:return: A keras model
exists. E.g. by using the default one.
:return: A keras model.

"""
if existing_model is not None or network_depth is None:
if existing_model is not None:
logger.debug(f"Loading model: {existing_model}")
return keras.models.load_model(existing_model)
elif network_depth is None:
raise ValueError(
"Either `existing_model` or `network_depth` must be provided."
)
else:
logger.debug(f"Creating a new instance of model: {network_depth}")
model = build_model(
Expand Down
30 changes: 30 additions & 0 deletions tests/core/test_unit/test_classify/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,33 @@ def test_incorrect_weights(mock_build_model):
inference=True,
model_weights="incorrect_weights.h5",
)


@patch("cellfinder.core.classify.tools.build_model")
@patch("cellfinder.core.classify.tools.keras.models.load_model")
def test_get_model_existing_takes_precedence(
mock_load_model, mock_build_model
):
"""Test that existing_model takes precedence over network_depth."""
tools.get_model(existing_model="/some/path", network_depth="18-layer")
mock_load_model.assert_called_once_with("/some/path")
mock_build_model.assert_not_called()


@patch("cellfinder.core.classify.tools.build_model")
def test_get_model_builds_with_depth_only(mock_build_model):
"""network_depth alone should build and return a new model."""
tools.get_model(network_depth="18-layer")
mock_build_model.assert_called_once_with(
network_depth="18-layer",
learning_rate=0.0001,
)


def test_get_model_no_model_no_depth():
"""Both existing_model and network_depth as None should raise."""
with pytest.raises(
ValueError,
match="Either `existing_model` or `network_depth` must be provided.",
):
tools.get_model(existing_model=None, network_depth=None)