Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tensormap-backend/app/routers/deep_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from app.database import get_db
from app.schemas.deep_learning import ModelNameRequest, ModelSaveRequest, ModelValidateRequest, TrainingConfigRequest
from app.services.deep_learning import (
check_model_name_service,
delete_model_service,
get_available_model_list,
get_code_service,
get_model_count_service,
get_model_graph_service,
get_training_history_service,
model_save_service,
Expand Down Expand Up @@ -124,3 +126,24 @@ def get_training_history(
"""Return a paginated list of models with enriched metadata for training history view."""
body, status_code = get_training_history_service(db, project_id=project_id, offset=offset, limit=limit)
return JSONResponse(status_code=status_code, content=body)


@router.get("/model/check-name")
def check_model_name(
model_name: str = Query(),
project_id: uuid_pkg.UUID | None = Query(None),
db: Session = Depends(get_db),
):
"""Check if a model name is available."""
body, status_code = check_model_name_service(db, model_name=model_name, project_id=project_id)
return JSONResponse(status_code=status_code, content=body)


@router.get("/model/count")
def get_model_count(
project_id: uuid_pkg.UUID | None = Query(None),
db: Session = Depends(get_db),
):
"""Get the total count of saved models."""
body, status_code = get_model_count_service(db, project_id=project_id)
return JSONResponse(status_code=status_code, content=body)
20 changes: 20 additions & 0 deletions tensormap-backend/app/services/deep_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,26 @@ def get_available_model_list(
return body, 200


def check_model_name_service(db: Session, model_name: str, project_id: uuid_pkg.UUID | None = None) -> tuple:
"""Check if a model name is available."""
stmt = select(ModelBasic).where(ModelBasic.model_name == model_name)
if project_id is not None:
stmt = stmt.where(ModelBasic.project_id == project_id)
existing = db.exec(stmt).first()
if existing:
return {"success": False, "message": "Model name already in use", "data": {"available": False}}, 200
return {"success": True, "message": "Model name is available", "data": {"available": True}}, 200


def get_model_count_service(db: Session, project_id: uuid_pkg.UUID | None = None) -> tuple:
"""Get count of saved models."""
stmt = select(func.count(ModelBasic.id))
if project_id is not None:
stmt = stmt.where(ModelBasic.project_id == project_id)
total = db.exec(stmt).one()
return {"success": True, "message": "Model count retrieved", "data": {"count": total}}, 200


def get_training_history_service(
db: Session, project_id: uuid_pkg.UUID | None = None, offset: int = 0, limit: int = 50
) -> tuple:
Expand Down
64 changes: 63 additions & 1 deletion tensormap-backend/tests/test_deep_learning_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
sys.modules.setdefault("flatten_json", MagicMock())

from app.models.ml import ModelBasic # noqa: E402
from app.services.deep_learning import delete_model_service # noqa: E402
from app.services.deep_learning import ( # noqa: E402
check_model_name_service,
delete_model_service,
get_model_count_service,
)

# ---------------------------------------------------------------------------
# Fixtures
Expand Down Expand Up @@ -113,3 +117,61 @@ def test_get_called_with_correct_id(self, mock_db, sample_model, tmp_path):
delete_model_service(mock_db, model_id=42)

mock_db.get.assert_called_once_with(ModelBasic, 42)


# ---------------------------------------------------------------------------
# check_model_name_service
# ---------------------------------------------------------------------------


class TestCheckModelNameService:
def test_returns_available_when_name_not_taken(self, mock_db):
"""A model name that doesn't exist should be reported as available."""
mock_db.exec.return_value.first.return_value = None
body, status_code = check_model_name_service(mock_db, "new_model")
assert status_code == 200
assert body["success"] is True
assert body["data"]["available"] is True

def test_returns_unavailable_when_name_taken(self, mock_db):
"""A model name that already exists should be reported as unavailable."""
mock_db.exec.return_value.first.return_value = MagicMock()
body, status_code = check_model_name_service(mock_db, "existing_model")
assert status_code == 200
assert body["success"] is False
assert body["data"]["available"] is False

def test_filters_by_project_id(self, mock_db):
"""When project_id is provided, the query should include it."""
mock_db.exec.return_value.first.return_value = None
check_model_name_service(mock_db, "my_model", project_id="proj-1")
call_stmt = mock_db.exec.call_args[0][0]
assert "project_id" in str(call_stmt)


# ---------------------------------------------------------------------------
# get_model_count_service
# ---------------------------------------------------------------------------


class TestGetModelCountService:
def test_returns_zero_when_no_models(self, mock_db):
"""Count should be 0 when no models exist."""
mock_db.exec.return_value.one.return_value = 0
body, status_code = get_model_count_service(mock_db)
assert status_code == 200
assert body["data"]["count"] == 0

def test_returns_correct_count(self, mock_db):
"""Count should match the number of models in the database."""
mock_db.exec.return_value.one.return_value = 5
body, status_code = get_model_count_service(mock_db)
assert status_code == 200
assert body["data"]["count"] == 5

def test_filters_by_project_id(self, mock_db):
"""When project_id is provided, the query should include it."""
mock_db.exec.return_value.one.return_value = 2
get_model_count_service(mock_db, project_id="proj-1")
call_stmt = mock_db.exec.call_args[0][0]
assert "project_id" in str(call_stmt)
Loading