From 8ffc395b8f5e90a970f0f9be80d7244d562373b5 Mon Sep 17 00:00:00 2001 From: shivv23 Date: Fri, 22 May 2026 15:48:50 +0530 Subject: [PATCH 1/2] feat: add model name check and count endpoints --- .../app/routers/deep_learning.py | 23 +++++++ .../app/services/deep_learning.py | 20 +++++++ .../tests/test_deep_learning_service.py | 60 ++++++++++++++++++- 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/tensormap-backend/app/routers/deep_learning.py b/tensormap-backend/app/routers/deep_learning.py index 5b04dfee..599d5578 100644 --- a/tensormap-backend/app/routers/deep_learning.py +++ b/tensormap-backend/app/routers/deep_learning.py @@ -9,9 +9,11 @@ from app.database import get_db from app.schemas.deep_learning import ModelNameRequest, ModelSaveRequest, ModelValidateRequest, TrainingConfigRequest from app.services.deep_learning import ( + check_model_name_service, delete_model_service, get_available_model_list, get_code_service, + get_model_count_service, get_model_graph_service, get_training_history_service, model_save_service, @@ -124,3 +126,24 @@ def get_training_history( """Return a paginated list of models with enriched metadata for training history view.""" body, status_code = get_training_history_service(db, project_id=project_id, offset=offset, limit=limit) return JSONResponse(status_code=status_code, content=body) + + +@router.get("/model/check-name") +def check_model_name( + model_name: str = Query(), + project_id: uuid_pkg.UUID | None = Query(None), + db: Session = Depends(get_db), +): + """Check if a model name is available.""" + body, status_code = check_model_name_service(db, model_name=model_name, project_id=project_id) + return JSONResponse(status_code=status_code, content=body) + + +@router.get("/model/count") +def get_model_count( + project_id: uuid_pkg.UUID | None = Query(None), + db: Session = Depends(get_db), +): + """Get the total count of saved models.""" + body, status_code = get_model_count_service(db, project_id=project_id) + return JSONResponse(status_code=status_code, content=body) diff --git a/tensormap-backend/app/services/deep_learning.py b/tensormap-backend/app/services/deep_learning.py index 531ad143..1e145ba0 100644 --- a/tensormap-backend/app/services/deep_learning.py +++ b/tensormap-backend/app/services/deep_learning.py @@ -553,6 +553,26 @@ def get_available_model_list( return body, 200 +def check_model_name_service(db: Session, model_name: str, project_id: uuid_pkg.UUID | None = None) -> tuple: + """Check if a model name is available.""" + stmt = select(ModelBasic).where(ModelBasic.model_name == model_name) + if project_id is not None: + stmt = stmt.where(ModelBasic.project_id == project_id) + existing = db.exec(stmt).first() + if existing: + return {"success": False, "message": "Model name already in use", "data": {"available": False}}, 200 + return {"success": True, "message": "Model name is available", "data": {"available": True}}, 200 + + +def get_model_count_service(db: Session, project_id: uuid_pkg.UUID | None = None) -> tuple: + """Get count of saved models.""" + stmt = select(func.count(ModelBasic.id)) + if project_id is not None: + stmt = stmt.where(ModelBasic.project_id == project_id) + total = db.exec(stmt).one() + return {"success": True, "message": "Model count retrieved", "data": {"count": total}}, 200 + + def get_training_history_service( db: Session, project_id: uuid_pkg.UUID | None = None, offset: int = 0, limit: int = 50 ) -> tuple: diff --git a/tensormap-backend/tests/test_deep_learning_service.py b/tensormap-backend/tests/test_deep_learning_service.py index 3176b620..f3586ef0 100644 --- a/tensormap-backend/tests/test_deep_learning_service.py +++ b/tensormap-backend/tests/test_deep_learning_service.py @@ -17,7 +17,7 @@ sys.modules.setdefault("flatten_json", MagicMock()) from app.models.ml import ModelBasic # noqa: E402 -from app.services.deep_learning import delete_model_service # noqa: E402 +from app.services.deep_learning import check_model_name_service, delete_model_service, get_model_count_service # noqa: E402 # --------------------------------------------------------------------------- # Fixtures @@ -113,3 +113,61 @@ def test_get_called_with_correct_id(self, mock_db, sample_model, tmp_path): delete_model_service(mock_db, model_id=42) mock_db.get.assert_called_once_with(ModelBasic, 42) + + +# --------------------------------------------------------------------------- +# check_model_name_service +# --------------------------------------------------------------------------- + + +class TestCheckModelNameService: + def test_returns_available_when_name_not_taken(self, mock_db): + """A model name that doesn't exist should be reported as available.""" + mock_db.exec.return_value.first.return_value = None + body, status_code = check_model_name_service(mock_db, "new_model") + assert status_code == 200 + assert body["success"] is True + assert body["data"]["available"] is True + + def test_returns_unavailable_when_name_taken(self, mock_db): + """A model name that already exists should be reported as unavailable.""" + mock_db.exec.return_value.first.return_value = MagicMock() + body, status_code = check_model_name_service(mock_db, "existing_model") + assert status_code == 200 + assert body["success"] is False + assert body["data"]["available"] is False + + def test_filters_by_project_id(self, mock_db): + """When project_id is provided, the query should include it.""" + mock_db.exec.return_value.first.return_value = None + check_model_name_service(mock_db, "my_model", project_id="proj-1") + call_stmt = mock_db.exec.call_args[0][0] + assert "project_id" in str(call_stmt) + + +# --------------------------------------------------------------------------- +# get_model_count_service +# --------------------------------------------------------------------------- + + +class TestGetModelCountService: + def test_returns_zero_when_no_models(self, mock_db): + """Count should be 0 when no models exist.""" + mock_db.exec.return_value.one.return_value = 0 + body, status_code = get_model_count_service(mock_db) + assert status_code == 200 + assert body["data"]["count"] == 0 + + def test_returns_correct_count(self, mock_db): + """Count should match the number of models in the database.""" + mock_db.exec.return_value.one.return_value = 5 + body, status_code = get_model_count_service(mock_db) + assert status_code == 200 + assert body["data"]["count"] == 5 + + def test_filters_by_project_id(self, mock_db): + """When project_id is provided, the query should include it.""" + mock_db.exec.return_value.one.return_value = 2 + get_model_count_service(mock_db, project_id="proj-1") + call_stmt = mock_db.exec.call_args[0][0] + assert "project_id" in str(call_stmt) From 8e0e007fb0fdc6e20854733c2a90057594b7e1be Mon Sep 17 00:00:00 2001 From: shivv23 Date: Fri, 22 May 2026 16:01:33 +0530 Subject: [PATCH 2/2] fix: split long import line to satisfy ruff I001 --- tensormap-backend/tests/test_deep_learning_service.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensormap-backend/tests/test_deep_learning_service.py b/tensormap-backend/tests/test_deep_learning_service.py index f3586ef0..aaee8666 100644 --- a/tensormap-backend/tests/test_deep_learning_service.py +++ b/tensormap-backend/tests/test_deep_learning_service.py @@ -17,7 +17,11 @@ sys.modules.setdefault("flatten_json", MagicMock()) from app.models.ml import ModelBasic # noqa: E402 -from app.services.deep_learning import check_model_name_service, delete_model_service, get_model_count_service # noqa: E402 +from app.services.deep_learning import ( # noqa: E402 + check_model_name_service, + delete_model_service, + get_model_count_service, +) # --------------------------------------------------------------------------- # Fixtures