diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 9a2770ac8..ba2a88077 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -13,6 +13,7 @@ from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, extend_schema +from pydantic import ValidationError from rest_framework import exceptions as api_exceptions from rest_framework import filters, serializers, status, viewsets from rest_framework.decorators import action @@ -31,6 +32,8 @@ from ami.base.views import ProjectMixin from ami.main.api.schemas import project_id_doc_param from ami.main.api.serializers import TagSerializer +from ami.ml.models.processing_service import ProcessingService +from ami.ml.schemas import PipelineRegistrationResponse from ami.utils.requests import get_default_classification_threshold from ami.utils.storages import ConnectionTestResult @@ -206,6 +209,63 @@ def charts(self, request, pk=None): project = self.get_object() return Response({"summary_data": project.summary_data()}) + @action(detail=True, methods=["post"], url_path="pipelines") + def pipelines(self, request, pk=None): + """ + Receive pipeline registrations for a project. This endpoint is called by the + V2 ML processing services to register available pipelines for a project. + + Expected payload: PipelineRegistrationResponse (pydantic schema) containing a + list of PipelineConfigResponse objects under the `pipelines` key. + + Behavior: + - If the project has no associated ProcessingService, create a dummy one and + associate it with the project. + - Call ProcessingService.create_pipelines() with the provided pipeline configs + and limit the operation to this project. + + Returns the PipelineRegistrationResponse returned by create_pipelines(). + """ + # Parse the incoming payload using the pydantic schema so we convert dicts to + # the expected PipelineConfigResponse models + try: + parsed: PipelineRegistrationResponse = PipelineRegistrationResponse.parse_obj(request.data) + except ValidationError as err: + logger.debug(f"Invalid pipeline registration payload: {err}") + return Response({"detail": str(err)}, status=status.HTTP_400_BAD_REQUEST) + + project: Project = self.get_object() + + # TODO: Discuss the right approach for associating pipelines with projects in V2. + # For now, we create a dummy processing service if none exists (hack). + + # Find an existing processing service for this project + processing_service = ProcessingService.objects.filter(projects=project).first() + + if not processing_service: + # Create a dummy processing service and associate it with the project + processing_service = ProcessingService.objects.create( + name=f"Dummy Processing Service for project {project.pk}", + endpoint_url=f"http://dummy.local/projects/{project.pk}/processing-service", + ) + processing_service.projects.add(project) + processing_service.save() + logger.info(f"Created dummy processing service {processing_service} for project {project.pk}") + + pipeline_configs = parsed.pipelines if parsed and parsed.pipelines else None + + # Call create_pipelines limited to this project + response = processing_service.create_pipelines( + pipeline_configs=pipeline_configs, + projects=Project.objects.filter(pk=project.pk), + ) + + # Save any changes to the processing service + processing_service.save() + + # response is a pydantic model; return its dict representation + return Response(response.dict()) + @extend_schema( parameters=[ OpenApiParameter(