From 4d5be92d3c0b8a1eb499c37a44b972c4d56cd7b8 Mon Sep 17 00:00:00 2001 From: Sylvain Zimmer Date: Mon, 23 Feb 2026 22:52:54 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(tasks)=20switch=20from=20celery=20to?= =?UTF-8?q?=20dramatiq?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 13 + compose.yaml | 2 +- src/backend/core/admin.py | 3 +- src/backend/core/api/viewsets/send.py | 3 +- src/backend/core/api/viewsets/task.py | 115 +++++---- .../core/management/commands/run_task.py | 2 +- src/backend/core/mda/inbound_tasks.py | 20 +- src/backend/core/mda/outbound.py | 2 +- src/backend/core/mda/outbound_tasks.py | 138 +++++------ src/backend/core/services/exporter/tasks.py | 49 ++-- .../core/services/importer/eml_tasks.py | 19 +- src/backend/core/services/importer/imap.py | 50 ++-- .../core/services/importer/imap_tasks.py | 16 +- .../core/services/importer/mbox_tasks.py | 43 ++-- .../core/services/importer/pst_tasks.py | 44 ++-- src/backend/core/services/importer/service.py | 21 +- src/backend/core/services/search/tasks.py | 98 ++++---- src/backend/core/tasks.py | 3 +- .../core/tests/api/test_draft_attachments.py | 4 +- .../core/tests/api/test_import_file_upload.py | 52 ++-- .../core/tests/api/test_messages_create.py | 1 + .../core/tests/api/test_messages_import.py | 56 ++--- .../core/tests/exporter/test_export_task.py | 124 +++------- .../core/tests/importer/test_file_import.py | 185 ++++++-------- .../core/tests/importer/test_imap_import.py | 228 +++++++----------- .../tests/importer/test_import_service.py | 172 +++++-------- .../core/tests/importer/test_pst_import.py | 56 ++--- src/backend/core/tests/mda/test_retry.py | 90 +++---- .../core/tests/mda/test_spam_processing.py | 20 +- .../core/tests/tasks/test_task_importer.py | 90 +++---- .../tests/tasks/test_task_send_message.py | 30 +-- src/backend/core/tests/test_worker.py | 160 ++++++------ src/backend/core/utils.py | 149 ++++++++++++ src/backend/messages/__init__.py | 13 +- src/backend/messages/celery_app.py | 48 ---- src/backend/messages/settings.py | 114 +++++---- src/backend/pyproject.toml | 6 +- src/backend/uv.lock | 134 +++++----- src/backend/worker.py | 95 +++++--- 39 files changed, 1204 insertions(+), 1264 deletions(-) delete mode 100644 src/backend/messages/celery_app.py mode change 100755 => 100644 src/backend/messages/settings.py diff --git a/Makefile b/Makefile index 0e92ad6c7..b4936bcd7 100644 --- a/Makefile +++ b/Makefile @@ -133,6 +133,19 @@ logs: ## display all services logs (follow mode) @$(COMPOSE) logs -f .PHONY: logs +logs-worker: ## display worker logs (follow mode) + @$(COMPOSE) logs worker-dev -f +.PHONY: logs-worker + +logs-back: ## display backend logs (follow mode) + @$(COMPOSE) logs backend-dev -f +.PHONY: logs-back + +logs-front: ## display frontend logs (follow mode) + @$(COMPOSE) logs frontend-dev -f +.PHONY: logs-front + + start: ## start all development services @$(COMPOSE) up --force-recreate --build -d frontend-dev backend-dev worker-dev mta-in --wait .PHONY: start diff --git a/compose.yaml b/compose.yaml index 0f3ce3ef2..82c11169c 100644 --- a/compose.yaml +++ b/compose.yaml @@ -154,7 +154,7 @@ services: args: DOCKER_USER: ${DOCKER_USER:-1000} user: ${DOCKER_USER:-1000} - command: ["python", "worker.py", "--loglevel=DEBUG"] + command: ["python", "worker.py", "-v", "2"] environment: - DJANGO_CONFIGURATION=Development env_file: diff --git a/src/backend/core/admin.py b/src/backend/core/admin.py index f204cd004..4222a52d0 100644 --- a/src/backend/core/admin.py +++ b/src/backend/core/admin.py @@ -17,7 +17,6 @@ from sentry_sdk import capture_exception from core.api.utils import get_file_key -from core.api.viewsets.task import register_task_owner from core.mda.outbound_tasks import retry_messages_task from core.services.dns.provisioning import provision_domain_dns from core.services.exporter.tasks import export_mailbox_task @@ -327,7 +326,7 @@ def export_messages_view(self, request, object_id): # Start the export task try: task = export_mailbox_task.delay(str(mailbox_obj.id), str(request.user.id)) - register_task_owner(task.id, request.user.id) + task.track_owner(request.user.id) except Exception: # pylint: disable=broad-exception-caught logging.exception( "Failed to queue export task for mailbox %s", mailbox_obj.id diff --git a/src/backend/core/api/viewsets/send.py b/src/backend/core/api/viewsets/send.py index a84e15477..99977890b 100644 --- a/src/backend/core/api/viewsets/send.py +++ b/src/backend/core/api/viewsets/send.py @@ -14,7 +14,6 @@ from rest_framework.views import APIView from core import models -from core.api.viewsets.task import register_task_owner from core.mda.outbound import prepare_outbound_message from core.mda.outbound_tasks import send_message_task @@ -122,7 +121,7 @@ def post(self, request): # Launch async task for sending the message task = send_message_task.delay(str(message.id), must_archive=must_archive) - register_task_owner(task.id, request.user.id) + task.track_owner(request.user.id) # --- Finalize --- # Message state should be updated by prepare_outbound_message/send_message diff --git a/src/backend/core/api/viewsets/task.py b/src/backend/core/api/viewsets/task.py index a1cb5fd9f..673038663 100644 --- a/src/backend/core/api/viewsets/task.py +++ b/src/backend/core/api/viewsets/task.py @@ -1,11 +1,10 @@ -"""API ViewSet for Celery task status.""" +"""API ViewSet for asynchronous task statuses.""" import logging -from django.core.cache import cache +import dramatiq +from dramatiq.results import ResultFailure, ResultMissing, Results -from celery import states as celery_states -from celery.result import AsyncResult from drf_spectacular.utils import ( OpenApiExample, extend_schema, @@ -17,16 +16,12 @@ from rest_framework.response import Response from rest_framework.views import APIView -from messages.celery_app import app as celery_app +from core.utils import get_task_progress, get_task_tracking logger = logging.getLogger(__name__) -TASK_OWNER_CACHE_TTL = 86400 # 24 hours - -def register_task_owner(task_id, user_id): - """Register the owner of a task for permission checks.""" - cache.set(f"task_owner:{task_id}", str(user_id), timeout=TASK_OWNER_CACHE_TTL) +TASK_STATES = ["PENDING", "SUCCESS", "FAILURE", "PROGRESS"] @extend_schema( @@ -44,11 +39,13 @@ def register_task_owner(task_id, user_id): 200: inline_serializer( name="TaskStatusResponse", fields={ - "status": drf_serializers.ChoiceField( - choices=sorted({*celery_states.ALL_STATES, "PROGRESS"}) - ), + "status": drf_serializers.ChoiceField(choices=sorted(TASK_STATES)), "result": drf_serializers.JSONField(allow_null=True), "error": drf_serializers.CharField(allow_null=True), + # Present when status == "PROGRESS" + "progress": drf_serializers.IntegerField(required=False), + "message": drf_serializers.CharField(required=False, allow_blank=True), + "timestamp": drf_serializers.FloatField(required=False), }, ) }, @@ -69,43 +66,69 @@ def register_task_owner(task_id, user_id): ], ) class TaskDetailView(APIView): - """View to retrieve the status of a Celery task.""" + """View to retrieve the status of a task.""" permission_classes = [permissions.IsAuthenticated] def get(self, request, task_id): - """Get the status of a Celery task.""" - owner_id = cache.get(f"task_owner:{task_id}") - if owner_id is None: + """Get the status of a task.""" + tracking = get_task_tracking(task_id) + if tracking is None: raise PermissionDenied("Task not found or access expired.") - if str(request.user.id) != owner_id: + if str(request.user.id) != tracking["owner"]: raise PermissionDenied("You do not have access to this task.") - task_result = AsyncResult(task_id, app=celery_app) - - # By default unknown tasks will be in PENDING. There is no reliable - # way to check if a task exists or not with Celery. - # https://github.com/celery/celery/issues/3596#issuecomment-262102185 - - # Prepare the response data - result_data = { - "status": task_result.status, - "result": None, - "error": None, - } - - # If the result is a dict with status/result/error, unpack and propagate status - if isinstance(task_result.result, dict) and set(task_result.result.keys()) >= { - "status", - "result", - "error", - }: - result_data["status"] = task_result.result["status"] - result_data["result"] = task_result.result["result"] - result_data["error"] = task_result.result["error"] - else: - result_data["result"] = task_result.result - if task_result.state == "PROGRESS" and task_result.info: - result_data.update(task_result.info) - - return Response(result_data) + # Try to fetch the result from dramatiq's native result backend + message = dramatiq.Message( + queue_name=tracking["queue_name"], + actor_name=tracking["actor_name"], + args=(), + kwargs={}, + options={}, + message_id=task_id, + ) + try: + result_data = message.get_result(block=False) + except ResultMissing: + result_data = None + except ResultFailure as exc: + return Response({ + "status": "FAILURE", + "result": None, + "error": str(exc), + }) + + if result_data is not None: + response = {"status": "SUCCESS", "result": result_data, "error": None} + # If the result follows the {status, result, error} convention, unpack it + if ( + isinstance(result_data, dict) + and {"status", "result", "error"} <= result_data.keys() + ): + response["status"] = result_data["status"] + response["result"] = result_data["result"] + response["error"] = result_data["error"] + return Response(response) + + # Check if we have progress data for this task + progress_data = get_task_progress(task_id) + if progress_data: + return Response( + { + "status": "PROGRESS", + "result": None, + "error": None, + "progress": progress_data.get("progress"), + "message": progress_data.get("metadata", {}).get("message"), + "timestamp": progress_data.get("timestamp"), + } + ) + + # Default to pending when no result and no progress + return Response( + { + "status": "PENDING", + "result": None, + "error": None, + } + ) diff --git a/src/backend/core/management/commands/run_task.py b/src/backend/core/management/commands/run_task.py index dc9f2f458..b618d5c99 100644 --- a/src/backend/core/management/commands/run_task.py +++ b/src/backend/core/management/commands/run_task.py @@ -76,7 +76,7 @@ def handle(self, *args, **options): try: # Execute task synchronously - result = task_func.apply(args=task_args, kwargs=kwargs) + result = task_func(*task_args, **kwargs) # Output result if options["json"]: diff --git a/src/backend/core/mda/inbound_tasks.py b/src/backend/core/mda/inbound_tasks.py index 5d334476b..fbb195cd6 100644 --- a/src/backend/core/mda/inbound_tasks.py +++ b/src/backend/core/mda/inbound_tasks.py @@ -8,16 +8,15 @@ from django.core.cache import cache from django.utils import timezone -import requests -from celery.utils.log import get_task_logger +import logging from core import models from core.mda.inbound_create import _create_message_from_inbound from core.mda.rfc5322 import parse_email_message +import requests +from core.utils import cron_task, register_task -from messages.celery_app import app as celery_app - -logger = get_task_logger(__name__) +logger = logging.getLogger(__name__) def _check_spam_with_hardcoded_rules( @@ -167,8 +166,8 @@ def _check_spam_with_rspamd( return False, str(e) -@celery_app.task(bind=True) -def process_inbound_message_task(self, inbound_message_id: str): +@register_task(queue="inbound") +def process_inbound_message_task(inbound_message_id: str): """Process an inbound message from the queue: check spam and create message. Args: @@ -285,12 +284,13 @@ def process_inbound_message_task(self, inbound_message_id: str): cache.delete(lock_key) -@celery_app.task(bind=True) -def process_inbound_messages_queue_task(self, batch_size: int = 10): +@cron_task(interval=300) +@register_task(queue="inbound") +def process_inbound_messages_queue_task(batch_size: int = 10): """Retry processing of inbound messages that are older than 5 minutes. This task only handles retries for messages that may have failed or gotten stuck. - Regular messages are processed immediately when created via process_inbound_message_task.delay(). + Regular messages are processed immediately when created via process_inbound_message_task.send(). Args: batch_size: Number of messages to process in this batch diff --git a/src/backend/core/mda/outbound.py b/src/backend/core/mda/outbound.py index c2249b8a2..0ec5c5eef 100644 --- a/src/backend/core/mda/outbound.py +++ b/src/backend/core/mda/outbound.py @@ -364,7 +364,7 @@ def _validate_attachments_size(total_size: int) -> None: def send_message(message: models.Message, force_mta_out: bool = False): """Send an existing Message, internally or externally. - This part is called asynchronously from the celery worker. + This part is called asynchronously from the background worker. """ # Refuse to send messages that are draft or not senders diff --git a/src/backend/core/mda/outbound_tasks.py b/src/backend/core/mda/outbound_tasks.py index 29a6aca16..ae3909be1 100644 --- a/src/backend/core/mda/outbound_tasks.py +++ b/src/backend/core/mda/outbound_tasks.py @@ -2,25 +2,24 @@ # pylint: disable=unused-argument, broad-exception-raised, broad-exception-caught, too-many-lines +import logging import math +from django.conf import settings from django.db.models import Q from django.utils import timezone -from celery.utils.log import get_task_logger - from core import models from core.enums import MessageDeliveryStatusChoices from core.mda.outbound import send_message from core.mda.selfcheck import run_selfcheck +from core.utils import cron_task, register_task, set_task_progress -from messages.celery_app import app as celery_app - -logger = get_task_logger(__name__) +logger = logging.getLogger(__name__) -@celery_app.task(bind=True) -def send_message_task(self, message_id, force_mta_out=False, must_archive=False): +@register_task(queue="outbound") +def send_message_task(message_id, force_mta_out=False, must_archive=False): """Send a message asynchronously. Args: @@ -36,52 +35,48 @@ def send_message_task(self, message_id, force_mta_out=False, must_archive=False) .prefetch_related("recipients__contact") .get(id=message_id) ) + except models.Message.DoesNotExist: + error_msg = f"Message with ID '{message_id}' does not exist" + return {"success": False, "error": error_msg} - send_message(message, force_mta_out) + set_task_progress(25, {"message": "Message loaded, sending..."}) - # Update task state with progress information - self.update_state( - state="SUCCESS", - meta={ - "status": "completed", # TODO fetch recipients statuses - "message_id": str(message_id), - "success": True, - }, - ) + send_message(message, force_mta_out) - # If requested, archive the whole thread after sending - if must_archive: - try: - thread = message.thread - models.Message.objects.filter(thread=thread).update( - is_archived=True, archived_at=timezone.now() - ) - thread.update_stats() - except Exception as e: - # Not critical, just log the error - logger.exception( - "Error in send_message_task when archiving thread %s after sending message %s: %s", - thread.id, - message_id, - e, - ) + set_task_progress(75, {"message": "Message sent, processing archive..."}) - return { - "message_id": str(message_id), - "success": True, - } - # pylint: disable=broad-exception-caught - except Exception as e: - logger.exception("Error in send_message_task for message %s: %s", message_id, e) - self.update_state( - state="FAILURE", - meta={"status": "failed", "message_id": str(message_id), "error": str(e)}, - ) - raise + # If requested, archive the whole thread after sending + archived = False + if must_archive: + try: + thread = message.thread + models.Message.objects.filter(thread=thread).update( + is_archived=True, archived_at=timezone.now() + ) + thread.update_stats() + archived = True + set_task_progress(90, {"message": "Thread archived"}) + except Exception as e: + # Not critical, just log the error + logger.exception( + "Error in send_message_task when archiving thread %s after sending message %s: %s", + thread.id, + message_id, + e, + ) + + result = { + "message_id": str(message_id), + "success": True, + "archived": archived, + } + + return result -@celery_app.task(bind=True) -def selfcheck_task(self): +@cron_task(interval=settings.MESSAGES_SELFCHECK_INTERVAL) +@register_task(queue="management") +def selfcheck_task(): """Run a selfcheck of the mail delivery system. This task performs an end-to-end test of the mail delivery pipeline: @@ -96,33 +91,13 @@ def selfcheck_task(self): Returns: dict: A dictionary with success status, timings, and metrics """ - try: - result = run_selfcheck() - - # Update task state with progress information - self.update_state( - state="SUCCESS", - meta={ - "status": "completed", - "success": result["success"], - "send_time": result["send_time"], - "reception_time": result["reception_time"], - }, - ) - - return result - # pylint: disable=broad-exception-caught - except Exception as e: - logger.exception("Error in selfcheck_task: %s", e) - self.update_state( - state="FAILURE", - meta={"status": "failed", "error": str(e)}, - ) - raise + result = run_selfcheck() + return result -@celery_app.task(bind=True) -def retry_messages_task(self, message_ids=None, force_mta_out=False, batch_size=100): +@cron_task(interval=300) +@register_task(queue="outbound") +def retry_messages_task(message_ids=None, force_mta_out=False, batch_size=100): """Retry sending messages with retryable recipients (respects retry timing). Args: @@ -133,6 +108,8 @@ def retry_messages_task(self, message_ids=None, force_mta_out=False, batch_size= Returns: dict: A dictionary with task status and results """ + set_task_progress(0, {"message": "Finding messages to retry"}) + # Get messages to process # Bulk mode - find all messages with retryable recipients that are ready for retry message_filter_q = ( @@ -169,6 +146,14 @@ def retry_messages_task(self, message_ids=None, force_mta_out=False, batch_size= result["message_ids"] = message_ids return result + set_task_progress( + 10, + { + "message": f"Found {total_messages} messages to retry", + "total_messages": total_messages, + }, + ) + # Process messages in batches processed_count = 0 success_count = 0 @@ -179,9 +164,11 @@ def retry_messages_task(self, message_ids=None, force_mta_out=False, batch_size= ): # Update progress for bulk operations if index % batch_size == 0: - self.update_state( - state="PROGRESS", - meta={ + progress_percentage = min(10 + (index / max(1, total_messages)) * 80, 90) + set_task_progress( + int(progress_percentage), + { + "message": f"Processing batch {index // batch_size + 1}", "current_batch": index // batch_size + 1, "total_batches": math.ceil(total_messages / batch_size), "processed_messages": processed_count, @@ -215,7 +202,6 @@ def retry_messages_task(self, message_ids=None, force_mta_out=False, batch_size= error_count += 1 logger.exception("Failed to retry message %s: %s", message.id, e) - # Return appropriate result format result = { "success": True, "total_messages": total_messages, diff --git a/src/backend/core/services/exporter/tasks.py b/src/backend/core/services/exporter/tasks.py index 88482ff9e..f94b929d3 100644 --- a/src/backend/core/services/exporter/tasks.py +++ b/src/backend/core/services/exporter/tasks.py @@ -1,4 +1,4 @@ -"""Celery tasks for exporting mailbox messages.""" +"""Background tasks for exporting mailbox messages.""" import gzip import html @@ -12,17 +12,15 @@ from django.conf import settings from django.core.files.storage import storages -from celery.utils.log import get_task_logger -from sentry_sdk import capture_exception +import logging +logger = logging.getLogger(__name__) from core.api.utils import generate_presigned_url from core.mda.inbound import deliver_inbound_message from core.mda.rfc5322.parser import parse_email_message from core.models import Label, Mailbox, Message - -from messages.celery_app import app as celery_app - -logger = get_task_logger(__name__) +from core.utils import register_task, set_task_progress +from sentry_sdk import capture_exception # 7 days in seconds PRESIGNED_URL_EXPIRATION = 7 * 24 * 60 * 60 @@ -401,8 +399,8 @@ def _create_mbox_entry( return mbox_entry -@celery_app.task(bind=True) # pylint: disable=too-many-locals -def export_mailbox_task(self, mailbox_id: str, user_id: str) -> Dict[str, Any]: # pylint: disable=unused-argument +@register_task(queue="management") # pylint: disable=too-many-locals +def export_mailbox_task(mailbox_id: str, user_id: str) -> Dict[str, Any]: # pylint: disable=unused-argument """ Export all messages from a mailbox to an MBOX file and upload to S3. @@ -432,19 +430,15 @@ def export_mailbox_task(self, mailbox_id: str, user_id: str) -> Dict[str, Any]: "skipped_count": 0, "error": error_msg, } - self.update_state( - state="FAILURE", - meta={"result": result, "error": error_msg}, - ) return {"status": "FAILURE", "result": result, "error": error_msg} mailbox_email = str(mailbox_obj) try: # Update state to show we're starting - self.update_state( - state="PROGRESS", - meta={ + set_task_progress( + 0, + { "result": { "message_status": "Counting messages", "total_messages": 0, @@ -489,9 +483,10 @@ def export_mailbox_task(self, mailbox_id: str, user_id: str) -> Dict[str, Any]: # Update progress every 100 messages to reduce overhead if current_message % 100 == 0 or current_message == total_messages: - self.update_state( - state="PROGRESS", - meta={ + pct = min(10 + int(current_message / max(total_messages, 1) * 80), 90) + set_task_progress( + pct, + { "result": { "message_status": ( f"Exporting message {current_message} " @@ -550,9 +545,9 @@ def export_mailbox_task(self, mailbox_id: str, user_id: str) -> Dict[str, Any]: ) # Create notification message - self.update_state( - state="PROGRESS", - meta={ + set_task_progress( + 95, + { "result": { "message_status": "Creating notification", "total_messages": total_messages, @@ -588,11 +583,6 @@ def export_mailbox_task(self, mailbox_id: str, user_id: str) -> Dict[str, Any]: "s3_key": s3_key, } - self.update_state( - state="SUCCESS", - meta={"result": result, "error": None}, - ) - return {"status": "SUCCESS", "result": result, "error": None} except Exception as e: # pylint: disable=broad-exception-caught @@ -612,11 +602,6 @@ def export_mailbox_task(self, mailbox_id: str, user_id: str) -> Dict[str, Any]: "error": error_msg, } - self.update_state( - state="FAILURE", - meta={"result": result, "error": error_msg}, - ) - return {"status": "FAILURE", "result": result, "error": error_msg} diff --git a/src/backend/core/services/importer/eml_tasks.py b/src/backend/core/services/importer/eml_tasks.py index ab775d68d..98657feac 100644 --- a/src/backend/core/services/importer/eml_tasks.py +++ b/src/backend/core/services/importer/eml_tasks.py @@ -6,20 +6,19 @@ from django.conf import settings from django.core.files.storage import storages -from celery.utils.log import get_task_logger +import logging +logger = logging.getLogger(__name__) + +from core.utils import register_task, set_task_progress from sentry_sdk import capture_exception from core.mda.inbound import deliver_inbound_message from core.mda.rfc5322 import parse_email_message from core.models import Mailbox -from messages.celery_app import app as celery_app - -logger = get_task_logger(__name__) - -@celery_app.task(bind=True) -def process_eml_file_task(self, file_key: str, recipient_id: str) -> Dict[str, Any]: +@register_task(queue="imports") +def process_eml_file_task(file_key: str, recipient_id: str) -> Dict[str, Any]: """ Process an EML file asynchronously. @@ -50,9 +49,9 @@ def process_eml_file_task(self, file_key: str, recipient_id: str) -> Dict[str, A try: # Update progress state - self.update_state( - state="PROGRESS", - meta={ + set_task_progress( + 0, + { "result": { "message_status": "Processing message 1 of 1", "total_messages": 1, diff --git a/src/backend/core/services/importer/imap.py b/src/backend/core/services/importer/imap.py index 017bcac99..61c487cd4 100644 --- a/src/backend/core/services/importer/imap.py +++ b/src/backend/core/services/importer/imap.py @@ -18,12 +18,12 @@ from django.conf import settings -from celery.utils.log import get_task_logger +import logging +logger = logging.getLogger(__name__) from core.mda.inbound import deliver_inbound_message from core.mda.rfc5322 import parse_email_message - -logger = get_task_logger(__name__) +from core.utils import set_task_progress class IMAPSecurityError(RuntimeError): @@ -470,7 +470,6 @@ def process_folder_messages( # pylint: disable=too-many-arguments message_list: List[bytes], recipient: Any, username: str, - task_instance: Any, success_count: int, failure_count: int, current_message: int, @@ -489,31 +488,36 @@ def process_folder_messages( # pylint: disable=too-many-arguments flags, raw_email = _fetch_message_with_flags_retry(imap_connection, msg_num) # Check message size limit - if len(raw_email) > settings.MAX_INCOMING_EMAIL_SIZE: + if raw_email is not None and len(raw_email) > settings.MAX_INCOMING_EMAIL_SIZE: logger.warning( "Skipping oversized IMAP message: %d bytes", len(raw_email) ) failure_count += 1 - else: + elif raw_email is not None: # Parse message parsed_email = parse_email_message(raw_email) - # TODO: better heuristic to determine if the message is from the sender - is_sender = parsed_email["from"]["email"].lower() == username.lower() - - # Deliver message - if deliver_inbound_message( - str(recipient), - parsed_email, - raw_email, - is_import=True, - is_import_sender=is_sender, - imap_labels=[display_name], - imap_flags=flags, - ): - success_count += 1 + if parsed_email: + # TODO: better heuristic to determine if the message is from the sender + is_sender = parsed_email["from"]["email"].lower() == username.lower() + + # Deliver message + if deliver_inbound_message( + str(recipient), + parsed_email, + raw_email, + is_import=True, + is_import_sender=is_sender, + imap_labels=[display_name], + imap_flags=flags, + ): + success_count += 1 + else: + failure_count += 1 else: failure_count += 1 + else: + failure_count += 1 except Exception as e: logger.exception( @@ -534,9 +538,7 @@ def process_folder_messages( # pylint: disable=too-many-arguments "type": "imap", "current_message": current_message, } - task_instance.update_state( - state="PROGRESS", - meta={"result": result, "error": None}, - ) + pct = min(int(current_message / max(total_messages, 1) * 100), 99) + set_task_progress(pct, {"result": result, "error": None}) return success_count, failure_count, current_message diff --git a/src/backend/core/services/importer/imap_tasks.py b/src/backend/core/services/importer/imap_tasks.py index 46288a062..12581de6f 100644 --- a/src/backend/core/services/importer/imap_tasks.py +++ b/src/backend/core/services/importer/imap_tasks.py @@ -3,11 +3,11 @@ # pylint: disable=broad-exception-caught from typing import Any, Dict -from celery.utils.log import get_task_logger +import logging +logger = logging.getLogger(__name__) from core.models import Mailbox - -from messages.celery_app import app as celery_app +from core.utils import register_task, set_task_progress from .imap import ( IMAPConnectionManager, @@ -18,12 +18,9 @@ select_imap_folder, ) -logger = get_task_logger(__name__) - -@celery_app.task(bind=True) +@register_task(queue="imports") def import_imap_messages_task( - self, imap_server: str, imap_port: int, username: str, @@ -102,7 +99,6 @@ def import_imap_messages_task( message_list=message_list, recipient=recipient, username=username, - task_instance=self, success_count=success_count, failure_count=failure_count, current_message=current_message, @@ -144,7 +140,7 @@ def import_imap_messages_task( except Exception as e: logger.exception("Error in import_imap_messages_task: %s", e) - + error_msg = str(e) result = { "message_status": "Failed to process messages", "total_messages": total_messages, @@ -153,4 +149,4 @@ def import_imap_messages_task( "type": "imap", "current_message": current_message, } - return {"status": "FAILURE", "result": result, "error": str(e)} + return {"status": "FAILURE", "result": result, "error": error_msg} diff --git a/src/backend/core/services/importer/mbox_tasks.py b/src/backend/core/services/importer/mbox_tasks.py index fe74eae9c..93fcbe914 100644 --- a/src/backend/core/services/importer/mbox_tasks.py +++ b/src/backend/core/services/importer/mbox_tasks.py @@ -9,7 +9,10 @@ from django.conf import settings from django.core.files.storage import storages -from celery.utils.log import get_task_logger +import logging +logger = logging.getLogger(__name__) + +from core.utils import register_task, set_task_progress from sentry_sdk import capture_exception from core.mda.inbound import deliver_inbound_message @@ -17,12 +20,8 @@ from core.mda.rfc5322.parser import parse_date from core.models import Mailbox -from messages.celery_app import app as celery_app - from .s3_seekable import BUFFER_CENTERED, S3SeekableReader -logger = get_task_logger(__name__) - @dataclass class MboxMessageIndex: @@ -170,8 +169,8 @@ def _extract_and_store_index( indices.append(MboxMessageIndex(start_byte=msg_start, end_byte=msg_end, date=date)) -@celery_app.task(bind=True) -def process_mbox_file_task(self, file_key: str, recipient_id: str) -> Dict[str, Any]: +@register_task(queue="imports") +def process_mbox_file_task(file_key: str, recipient_id: str) -> Dict[str, Any]: """ Process a MBOX file asynchronously using a 2-pass approach. @@ -219,9 +218,9 @@ def process_mbox_file_task(self, file_key: str, recipient_id: str) -> Dict[str, file_key, buffer_strategy=BUFFER_CENTERED, ) as reader: - self.update_state( - state="PROGRESS", - meta={ + set_task_progress( + 0, + { "result": { "message_status": "Indexing messages", "total_messages": None, @@ -270,18 +269,18 @@ def process_mbox_file_task(self, file_key: str, recipient_id: str) -> Dict[str, for i, msg_index in enumerate(message_indices, 1): current_message = i try: - result = { - "message_status": f"Processing message {i} of {total_messages}", - "total_messages": total_messages, - "success_count": success_count, - "failure_count": failure_count, - "type": "mbox", - "current_message": i, - } - self.update_state( - state="PROGRESS", - meta={ - "result": result, + pct = min(10 + int(i / max(total_messages, 1) * 80), 90) + set_task_progress( + pct, + { + "result": { + "message_status": f"Processing message {i} of {total_messages}", + "total_messages": total_messages, + "success_count": success_count, + "failure_count": failure_count, + "type": "mbox", + "current_message": i, + }, "error": None, }, ) diff --git a/src/backend/core/services/importer/pst_tasks.py b/src/backend/core/services/importer/pst_tasks.py index 98c09baf9..bacbb76b4 100644 --- a/src/backend/core/services/importer/pst_tasks.py +++ b/src/backend/core/services/importer/pst_tasks.py @@ -7,14 +7,14 @@ from django.core.files.storage import storages import pypff -from celery.utils.log import get_task_logger +import logging from sentry_sdk import capture_exception from core.mda.inbound import deliver_inbound_message from core.mda.rfc5322 import parse_email_message from core.models import Mailbox -from messages.celery_app import app as celery_app +from core.utils import register_task, set_task_progress from .pst import ( FLAG_STATUS_FOLLOWUP, @@ -34,11 +34,11 @@ ) from .s3_seekable import BUFFER_NONE, S3SeekableReader -logger = get_task_logger(__name__) +logger = logging.getLogger(__name__) -@celery_app.task(bind=True) -def process_pst_file_task(self, file_key: str, recipient_id: str) -> Dict[str, Any]: +@register_task(queue="imports") +def process_pst_file_task(file_key: str, recipient_id: str) -> Dict[str, Any]: """ Process a PST file asynchronously. @@ -75,20 +75,14 @@ def process_pst_file_task(self, file_key: str, recipient_id: str) -> Dict[str, A try: message_imports_storage = storages["message-imports"] - self.update_state( - state="PROGRESS", - meta={ - "result": { - "message_status": "Initializing import", - "total_messages": None, - "success_count": 0, - "failure_count": 0, - "type": "pst", - "current_message": 0, - }, - "error": None, - }, - ) + set_task_progress(0, { + "message_status": "Initializing import", + "total_messages": None, + "success_count": 0, + "failure_count": 0, + "type": "pst", + "current_message": 0, + }) # Create S3 seekable reader with block-aligned LRU cache # for pypff's random-access B-tree traversal pattern. @@ -135,7 +129,8 @@ def process_pst_file_task(self, file_key: str, recipient_id: str) -> Dict[str, A failure_count += 1 continue - result = { + progress_pct = min(int((current_message / total_messages) * 100), 99) if total_messages > 0 else 0 + set_task_progress(progress_pct, { "message_status": ( f"Processing message {current_message}" f" of {total_messages}" @@ -145,14 +140,7 @@ def process_pst_file_task(self, file_key: str, recipient_id: str) -> Dict[str, A "failure_count": failure_count, "type": "pst", "current_message": current_message, - } - self.update_state( - state="PROGRESS", - meta={ - "result": result, - "error": None, - }, - ) + }) parsed_email = parse_email_message(eml_bytes) diff --git a/src/backend/core/services/importer/service.py b/src/backend/core/services/importer/service.py index d1c7f7255..661bd2bc2 100644 --- a/src/backend/core/services/importer/service.py +++ b/src/backend/core/services/importer/service.py @@ -11,7 +11,6 @@ from sentry_sdk import capture_exception from core import enums -from core.api.viewsets.task import register_task_owner from core.models import Mailbox from .eml_tasks import process_eml_file_task @@ -91,39 +90,36 @@ def import_file( # Check MIME type for PST if content_type in enums.PST_SUPPORTED_MIME_TYPES: task = process_pst_file_task.delay(file_key, str(recipient.id)) - register_task_owner(task.id, user.id) + task.track_owner(user.id) response_data = {"task_id": task.id, "type": "pst"} if request: messages.info( request, - f"Started processing PST file for recipient {recipient}. " - "This may take a while. You can check the status in the Celery task monitor.", + f"Started processing PST file for recipient {recipient}.", ) return True, response_data # Check MIME type for MBOX if content_type in enums.MBOX_SUPPORTED_MIME_TYPES: # Process MBOX file asynchronously task = process_mbox_file_task.delay(file_key, str(recipient.id)) - register_task_owner(task.id, user.id) + task.track_owner(user.id) response_data = {"task_id": task.id, "type": "mbox"} if request: messages.info( request, - f"Started processing MBOX file for recipient {recipient}. " - "This may take a while. You can check the status in the Celery task monitor.", + f"Started processing MBOX file for recipient {recipient}.", ) return True, response_data # Check MIME type for EML if content_type in enums.EML_SUPPORTED_MIME_TYPES: # Process EML file asynchronously task = process_eml_file_task.delay(file_key, str(recipient.id)) - register_task_owner(task.id, user.id) + task.track_owner(user.id) response_data = {"task_id": task.id, "type": "eml"} if request: messages.info( request, - f"Started processing EML file for recipient {recipient}. " - "This may take a while. You can check the status in the Celery task monitor.", + f"Started processing EML file for recipient {recipient}.", ) return True, response_data return False, {"detail": f"Unsupported file format: {content_type}"} @@ -180,13 +176,12 @@ def import_imap( use_ssl=use_ssl, recipient_id=str(recipient.id), ) - register_task_owner(task.id, user.id) + task.track_owner(user.id) response_data = {"task_id": task.id, "type": "imap"} if request: messages.info( request, - f"Started importing messages from IMAP server for recipient {recipient}. " - "This may take a while. You can check the status in the Celery task monitor.", + f"Started importing messages from IMAP server for recipient {recipient}.", ) return True, response_data diff --git a/src/backend/core/services/search/tasks.py b/src/backend/core/services/search/tasks.py index 9600b77d3..ebe0f9c35 100644 --- a/src/backend/core/services/search/tasks.py +++ b/src/backend/core/services/search/tasks.py @@ -4,7 +4,8 @@ from django.conf import settings -from celery.utils.log import get_task_logger +import logging +logger = logging.getLogger(__name__) from core import models from core.services.search import ( @@ -14,17 +15,11 @@ index_thread, ) -from messages.celery_app import app as celery_app +from core.utils import register_task, set_task_progress -logger = get_task_logger(__name__) - -def _reindex_all_base(update_progress=None): - """Base function for reindexing all threads and messages. - - Args: - update_progress: Optional callback function to update progress - """ +def _reindex_all_base(): + """Base function for reindexing all threads and messages.""" if not settings.OPENSEARCH_INDEX_THREADS: logger.info("OpenSearch thread indexing is disabled.") return {"success": False, "reason": "disabled"} @@ -33,25 +28,34 @@ def _reindex_all_base(update_progress=None): create_index_if_not_exists() # Get all threads and index them - threads = models.Thread.objects.all() - total = threads.count() + total = models.Thread.objects.count() + threads = models.Thread.objects.all().iterator(chunk_size=1000) success_count = 0 failure_count = 0 - for i, thread in enumerate(threads): + for i, thread in enumerate(threads, start=1): try: if index_thread(thread): success_count += 1 else: failure_count += 1 - # pylint: disable=broad-exception-caught except Exception as e: failure_count += 1 logger.exception("Error indexing thread %s: %s", thread.id, e) - # Update progress if callback provided - if update_progress and i % 100 == 0: - update_progress(i, total, success_count, failure_count) + # Update progress every 100 threads + if i % 100 == 0: + pct = min(int(i / max(total, 1) * 100), 99) + set_task_progress( + pct, + { + "message": f"Processing {i}/{total}", + "current": i, + "total": total, + "success_count": success_count, + "failure_count": failure_count, + }, + ) return { "success": True, @@ -61,27 +65,15 @@ def _reindex_all_base(update_progress=None): } -@celery_app.task(bind=True) -def reindex_all(self): - """Celery task wrapper for reindexing all threads and messages.""" - - def update_progress(current, total, success_count, failure_count): - """Update task progress.""" - self.update_state( - state="PROGRESS", - meta={ - "current": current, - "total": total, - "success_count": success_count, - "failure_count": failure_count, - }, - ) - - return _reindex_all_base(update_progress) +@register_task(queue="reindex") +def reindex_all(): + """Task wrapper for reindexing all threads and messages.""" + set_task_progress(0, {"message": "Starting full reindex"}) + return _reindex_all_base() -@celery_app.task(bind=True) -def reindex_thread_task(self, thread_id): +@register_task(queue="reindex") +def reindex_thread_task(thread_id): """Reindex a specific thread and all its messages.""" if not settings.OPENSEARCH_INDEX_THREADS: logger.info("OpenSearch thread indexing is disabled.") @@ -113,8 +105,8 @@ def reindex_thread_task(self, thread_id): raise -@celery_app.task(bind=True) -def reindex_mailbox_task(self, mailbox_id): +@register_task(queue="reindex") +def reindex_mailbox_task(mailbox_id): """Reindex all threads and messages in a specific mailbox.""" if not settings.OPENSEARCH_INDEX_THREADS: logger.info("OpenSearch thread indexing is disabled.") @@ -122,29 +114,37 @@ def reindex_mailbox_task(self, mailbox_id): # Ensure index exists first create_index_if_not_exists() + set_task_progress(0, {"message": "Reindex mailbox started", "mailbox_id": str(mailbox_id)}) # Get all threads in the mailbox - threads = models.Mailbox.objects.get(id=mailbox_id).threads_viewer - total = threads.count() + try: + threads_qs = models.Mailbox.objects.get(id=mailbox_id).threads_viewer + except models.Mailbox.DoesNotExist: + logger.error("Mailbox %s does not exist", mailbox_id) + return {"mailbox_id": str(mailbox_id), "success": False, "error": "mailbox_not_found"} + + total = threads_qs.count() + threads = threads_qs.iterator(chunk_size=1000) success_count = 0 failure_count = 0 - for i, thread in enumerate(threads): + for i, thread in enumerate(threads, start=1): try: if index_thread(thread): success_count += 1 else: failure_count += 1 - # pylint: disable=broad-exception-caught except Exception as e: failure_count += 1 logger.exception("Error indexing thread %s: %s", thread.id, e) # Update progress every 50 threads if i % 50 == 0: - self.update_state( - state="PROGRESS", - meta={ + pct = min(int(i / max(total, 1) * 100), 99) + set_task_progress( + pct, + { + "message": f"Mailbox {mailbox_id} {i}/{total}", "current": i, "total": total, "success_count": success_count, @@ -161,8 +161,8 @@ def reindex_mailbox_task(self, mailbox_id): } -@celery_app.task(bind=True) -def index_message_task(self, message_id): +@register_task(queue="reindex") +def index_message_task(message_id): """Index a single message.""" if not settings.OPENSEARCH_INDEX_THREADS: logger.info("OpenSearch message indexing is disabled.") @@ -201,8 +201,8 @@ def index_message_task(self, message_id): raise -@celery_app.task(bind=True) -def reset_search_index(self): +@register_task(queue="reindex") +def reset_search_index(): """Delete and recreate the OpenSearch index.""" delete_index() diff --git a/src/backend/core/tasks.py b/src/backend/core/tasks.py index b5fe2820d..676b9cc31 100644 --- a/src/backend/core/tasks.py +++ b/src/backend/core/tasks.py @@ -1,5 +1,5 @@ # pylint: disable=wildcard-import, unused-wildcard-import -"""Register all tasks here so that Celery autodiscovery can find them.""" +"""Register all tasks here so that Dramatiq autodiscovery can find them.""" from core.mda.inbound_tasks import * # noqa: F403 from core.mda.outbound_tasks import * # noqa: F403 @@ -8,4 +8,5 @@ from core.services.importer.imap_tasks import * # noqa: F403 from core.services.importer.mbox_tasks import * # noqa: F403 from core.services.importer.pst_tasks import * # noqa: F403 +from core.services.exporter.tasks import * # noqa: F403 from core.services.search.tasks import * # noqa: F403 diff --git a/src/backend/core/tests/api/test_draft_attachments.py b/src/backend/core/tests/api/test_draft_attachments.py index 99d179a11..9380eb42f 100644 --- a/src/backend/core/tests/api/test_draft_attachments.py +++ b/src/backend/core/tests/api/test_draft_attachments.py @@ -133,9 +133,9 @@ def test_draft_add_attachment_to_existing_draft_and_send( mailbox=user_mailbox, email=sender_email, name=user_mailbox.local_part ) - # Create a draft message + # Create a draft message (is_sender=True since this is an outbound draft) draft = factories.MessageFactory( - thread=thread, sender=sender, is_draft=True, subject="Existing draft" + thread=thread, sender=sender, is_draft=True, is_sender=True, subject="Existing draft" ) # attachment blob should already be created diff --git a/src/backend/core/tests/api/test_import_file_upload.py b/src/backend/core/tests/api/test_import_file_upload.py index 3c450fd96..879c5ea9b 100644 --- a/src/backend/core/tests/api/test_import_file_upload.py +++ b/src/backend/core/tests/api/test_import_file_upload.py @@ -1,8 +1,10 @@ """Test suite for ImportFileUploadViewSet.""" # pylint: disable=redefined-outer-name, unused-argument +import json from unittest import mock +from django.core.cache import cache from django.urls import reverse import pytest @@ -11,7 +13,7 @@ from core import enums, factories from core.api.utils import get_file_key -from core.api.viewsets.task import register_task_owner +from core.utils import TASK_TRACKING_CACHE_TTL pytestmark = pytest.mark.django_db @@ -52,33 +54,37 @@ def test_api_task_detail_other_user_should_be_forbidden(self): user2 = factories.UserFactory() task_id = "test-task-id-12345" - register_task_owner(task_id, user1.id) + # Register tracking metadata (as task.track_owner() would) + cache.set( + f"task_tracking:{task_id}", + json.dumps({ + "owner": str(user1.id), + "actor_name": "test_actor", + "queue_name": "default", + }), + timeout=TASK_TRACKING_CACHE_TTL, + ) url = reverse("task-detail", kwargs={"task_id": task_id}) - with mock.patch("core.api.viewsets.task.AsyncResult") as mock_async_result: - mock_result = mock.MagicMock() - mock_result.status = "SUCCESS" - mock_result.state = "SUCCESS" - mock_result.result = { - "status": "SUCCESS", - "result": {"imported": 42, "mailbox_id": "sensitive-data"}, - "error": None, - } - mock_result.info = None - mock_async_result.return_value = mock_result + result_data = { + "status": "SUCCESS", + "result": {"imported": 42, "mailbox_id": "sensitive-data"}, + "error": None, + } - # User2 tries to access user1's task - should be denied - client2 = APIClient() - client2.force_authenticate(user=user2) - response = client2.get(url) - assert response.status_code == status.HTTP_403_FORBIDDEN + # User2 tries to access user1's task - should be denied + client2 = APIClient() + client2.force_authenticate(user=user2) + response = client2.get(url) + assert response.status_code == status.HTTP_403_FORBIDDEN - # User1 (owner) accesses their own task - should succeed - client1 = APIClient() - client1.force_authenticate(user=user1) + # User1 (owner) accesses their own task - should succeed + client1 = APIClient() + client1.force_authenticate(user=user1) + with mock.patch("dramatiq.Message.get_result", return_value=result_data): response = client1.get(url) - assert response.status_code == status.HTTP_200_OK - assert response.data["result"]["imported"] == 42 + assert response.status_code == status.HTTP_200_OK + assert response.data["result"]["imported"] == 42 class TestImportViewSetPermissions: diff --git a/src/backend/core/tests/api/test_messages_create.py b/src/backend/core/tests/api/test_messages_create.py index 862b183ba..61c3f4f48 100644 --- a/src/backend/core/tests/api/test_messages_create.py +++ b/src/backend/core/tests/api/test_messages_create.py @@ -1054,6 +1054,7 @@ def test_send_message_with_send_roles( message2 = factories.MessageFactory( thread=thread_access.thread, is_draft=True, + is_sender=True, sender=factories.ContactFactory(mailbox=mailbox), ) factories.MessageRecipientFactory( diff --git a/src/backend/core/tests/api/test_messages_import.py b/src/backend/core/tests/api/test_messages_import.py index 326ec6d5f..5857c0545 100644 --- a/src/backend/core/tests/api/test_messages_import.py +++ b/src/backend/core/tests/api/test_messages_import.py @@ -424,10 +424,9 @@ def test_api_import_duplicate_eml_file(api_client, user, mailbox, eml_file): # Run the task synchronously for testing with a task_id eml_key = get_file_key(user.id, eml_file.name) - task_result = process_eml_file_task.apply( - kwargs={"file_key": eml_key, "recipient_id": str(mailbox.id)}, - task_id="fake-task-id-1", - ).get() + task_result = process_eml_file_task( + file_key=eml_key, recipient_id=str(mailbox.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 1 assert task_result["result"]["failure_count"] == 0 @@ -450,10 +449,9 @@ def test_api_import_duplicate_eml_file(api_client, user, mailbox, eml_file): mock_task.assert_called_once() # Run the task synchronously for testing with a task_id - task_result = process_eml_file_task.apply( - kwargs={"file_key": eml_key, "recipient_id": str(mailbox.id)}, - task_id="fake-task-id-2", - ).get() + task_result = process_eml_file_task( + file_key=eml_key, recipient_id=str(mailbox.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 1 # Still counts as success assert task_result["result"]["failure_count"] == 0 @@ -487,10 +485,9 @@ def test_api_import_duplicate_mbox_file(api_client, user, mailbox, mbox_file): # Run the task synchronously for testing with a task_id - task_result = process_mbox_file_task.apply( - kwargs={"file_key": mbox_key, "recipient_id": str(mailbox.id)}, - task_id="fake-task-id-1", - ).get() + task_result = process_mbox_file_task( + file_key=mbox_key, recipient_id=str(mailbox.id) + ) assert task_result["status"] == "SUCCESS" assert ( task_result["result"]["success_count"] == 3 @@ -516,10 +513,9 @@ def test_api_import_duplicate_mbox_file(api_client, user, mailbox, mbox_file): mock_task.assert_called_once() # Run the task synchronously for testing with a task_id - task_result = process_mbox_file_task.apply( - kwargs={"file_key": mbox_key, "recipient_id": str(mailbox.id)}, - task_id="fake-task-id-2", - ).get() + task_result = process_mbox_file_task( + file_key=mbox_key, recipient_id=str(mailbox.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 3 # Still counts as success assert task_result["result"]["failure_count"] == 0 @@ -558,10 +554,9 @@ def test_api_import_eml_same_message_different_mailboxes(api_client, user, eml_f mock_task.assert_called_once() # Run the task synchronously for testing with a task_id - task_result = process_eml_file_task.apply( - kwargs={"file_key": eml_key, "recipient_id": str(mailbox1.id)}, - task_id="fake-task-id-1", - ).get() + task_result = process_eml_file_task( + file_key=eml_key, recipient_id=str(mailbox1.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 1 assert task_result["result"]["failure_count"] == 0 @@ -584,10 +579,9 @@ def test_api_import_eml_same_message_different_mailboxes(api_client, user, eml_f mock_task.assert_called_once() # Run the task synchronously for testing with a task_id - task_result = process_eml_file_task.apply( - kwargs={"file_key": eml_key, "recipient_id": str(mailbox2.id)}, - task_id="fake-task-id-2", - ).get() + task_result = process_eml_file_task( + file_key=eml_key, recipient_id=str(mailbox2.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 1 assert task_result["result"]["failure_count"] == 0 @@ -633,10 +627,9 @@ def test_api_import_mbox_same_message_different_mailboxes(api_client, user, mbox mock_task.assert_called_once() # Run the task synchronously for testing with a task_id - task_result = process_mbox_file_task.apply( - kwargs={"file_key": mbox_key, "recipient_id": str(mailbox1.id)}, - task_id="fake-task-id-1", - ).get() + task_result = process_mbox_file_task( + file_key=mbox_key, recipient_id=str(mailbox1.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 3 assert task_result["result"]["failure_count"] == 0 @@ -659,10 +652,9 @@ def test_api_import_mbox_same_message_different_mailboxes(api_client, user, mbox mock_task.assert_called_once() # Run the task synchronously for testing with a task_id - task_result = process_mbox_file_task.apply( - kwargs={"file_key": mbox_key, "recipient_id": str(mailbox2.id)}, - task_id="fake-task-id-2", - ).get() + task_result = process_mbox_file_task( + file_key=mbox_key, recipient_id=str(mailbox2.id) + ) assert task_result["status"] == "SUCCESS" assert task_result["result"]["success_count"] == 3 assert task_result["result"]["failure_count"] == 0 diff --git a/src/backend/core/tests/exporter/test_export_task.py b/src/backend/core/tests/exporter/test_export_task.py index 47c14254e..aa7595ab0 100644 --- a/src/backend/core/tests/exporter/test_export_task.py +++ b/src/backend/core/tests/exporter/test_export_task.py @@ -96,15 +96,9 @@ def cleanup_exports(): @pytest.mark.django_db def test_export_empty_mailbox(mailbox_fixture, admin_user, cleanup_exports): """Test exporting a mailbox with no messages creates empty MBOX.""" - mock_task = MagicMock() - - # Mock update_state (required when calling task directly, not via .delay()) - # and deliver_inbound_message to avoid creating notification - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -132,13 +126,9 @@ def test_export_empty_mailbox(mailbox_fixture, admin_user, cleanup_exports): def test_export_single_message(mailbox_fixture, admin_user, cleanup_exports): """Test exporting a mailbox with one message.""" create_test_message(mailbox_fixture, "Test Subject", "Test body content") - mock_task = MagicMock() - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -168,13 +158,9 @@ def test_export_multiple_messages(mailbox_fixture, admin_user, cleanup_exports): create_test_message(mailbox_fixture, "Message 1", "Body 1") create_test_message(mailbox_fixture, "Message 2", "Body 2") create_test_message(mailbox_fixture, "Message 3", "Body 3") - mock_task = MagicMock() - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -202,13 +188,8 @@ def test_export_skips_missing_blob(mailbox_fixture, admin_user, cleanup_exports) is_sender=False, ) - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -226,7 +207,6 @@ def test_export_creates_notification_message( ): """Test that a notification message is created after export.""" create_test_message(mailbox_fixture, "Test Message", "Test body") - mock_task = MagicMock() deliver_called = [] @@ -234,12 +214,9 @@ def mock_deliver(*args, **kwargs): deliver_called.append((args, kwargs)) return True - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", - side_effect=mock_deliver, - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", + side_effect=mock_deliver, ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -256,9 +233,8 @@ def mock_deliver(*args, **kwargs): @pytest.mark.django_db def test_export_nonexistent_mailbox(admin_user): """Test exporting a non-existent mailbox returns failure.""" - mock_task = MagicMock() - with patch.object(export_mailbox_task, "update_state", mock_task.update_state): + with patch("core.services.exporter.tasks.deliver_inbound_message", return_value=True): result = export_mailbox_task( "00000000-0000-0000-0000-000000000000", str(admin_user.id) ) @@ -286,13 +262,10 @@ def test_admin_export_view_requires_post(admin_client, mailbox_fixture): @pytest.mark.django_db def test_admin_export_view_starts_task(admin_client, mailbox_fixture): - """Test that POST to export view starts the celery task.""" + """Test that POST to export view starts the background task.""" url = reverse("admin:core_mailbox_export", args=[mailbox_fixture.pk]) - with ( - patch("core.admin.export_mailbox_task") as mock_task, - patch("core.admin.register_task_owner"), - ): + with patch("core.admin.export_mailbox_task") as mock_task: mock_task.delay.return_value = Mock(id="test-task-id") response = admin_client.post(url) @@ -332,13 +305,8 @@ def test_export_reimport_roundtrip(domain, cleanup_exports): assert original_count == 3 # 2. Export mailbox A - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): export_result = export_mailbox_task(str(mailbox_a.id), str(user.id)) @@ -370,15 +338,10 @@ def test_export_reimport_roundtrip(domain, cleanup_exports): ) cleanup_exports.append(import_key) - # 4. Import into mailbox B - mock_import_task = MagicMock() - - with patch.object( - process_mbox_file_task, "update_state", mock_import_task.update_state - ): - import_result = process_mbox_file_task( - file_key=import_key, recipient_id=str(mailbox_b.id) - ) + # 4. Import into mailbox B (no mock — actually deliver messages) + import_result = process_mbox_file_task( + file_key=import_key, recipient_id=str(mailbox_b.id) + ) assert import_result["status"] == "SUCCESS" assert import_result["result"]["success_count"] == 3 @@ -421,13 +384,8 @@ def test_export_includes_status_headers(mailbox_fixture, admin_user, cleanup_exp msg.is_starred = True # Starred msg.save() - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -486,13 +444,8 @@ def test_export_headers_prepended_before_received( ) msg.thread.labels.add(label) - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -535,13 +488,8 @@ def test_export_includes_labels_as_x_keywords( ) msg.thread.labels.add(label1, label2) - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -577,13 +525,8 @@ def test_export_labels_with_spaces_are_quoted( ) msg.thread.labels.add(label) - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) @@ -611,13 +554,8 @@ def test_export_unread_message_status(mailbox_fixture, admin_user, cleanup_expor msg.is_unread = True msg.save() - mock_task = MagicMock() - - with ( - patch.object(export_mailbox_task, "update_state", mock_task.update_state), - patch( - "core.services.exporter.tasks.deliver_inbound_message", return_value=True - ), + with patch( + "core.services.exporter.tasks.deliver_inbound_message", return_value=True ): result = export_mailbox_task(str(mailbox_fixture.id), str(admin_user.id)) diff --git a/src/backend/core/tests/importer/test_file_import.py b/src/backend/core/tests/importer/test_file_import.py index 9afdccdf4..acae0d1a3 100644 --- a/src/backend/core/tests/importer/test_file_import.py +++ b/src/backend/core/tests/importer/test_file_import.py @@ -109,15 +109,9 @@ def test_import_eml_file(admin_client, eml_file, mailbox): ) # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with ( - patch( - "core.services.importer.eml_tasks.process_eml_file_task.delay" - ) as mock_delay, - patch.object(process_eml_file_task, "update_state", mock_task.update_state), - ): + with patch( + "core.services.importer.eml_tasks.process_eml_file_task.delay" + ) as mock_delay: mock_delay.return_value.id = "fake-task-id" # Submit the form response = admin_client.post( @@ -137,50 +131,50 @@ def test_import_eml_file(admin_client, eml_file, mailbox): with patch("core.services.importer.eml_tasks.storages") as mock_storages: mock_storages.__getitem__.return_value = mock_storage - # Run the task synchronously for testing - task_result = process_eml_file_task( - file_key="test-file-key.eml", recipient_id=str(mailbox.id) - ) - assert ( - task_result["result"]["message_status"] - == "Completed processing message" - ) - assert task_result["result"]["type"] == "eml" - assert task_result["result"]["total_messages"] == 1 - assert task_result["result"]["success_count"] == 1 - assert task_result["result"]["failure_count"] == 0 - assert task_result["result"]["current_message"] == 1 - - # Verify only PROGRESS update_state was called (no SUCCESS — - # Celery infers SUCCESS from normal return) - assert mock_task.update_state.call_count == 1 - - mock_task.update_state.assert_called_once_with( - state="PROGRESS", - meta={ - "result": { - "message_status": "Processing message 1 of 1", - "total_messages": 1, - "success_count": 0, - "failure_count": 0, - "type": "eml", - "current_message": 1, + with patch("core.services.importer.eml_tasks.set_task_progress") as mock_set_progress: + # Run the task synchronously for testing + task_result = process_eml_file_task( + file_key="test-file-key.eml", recipient_id=str(mailbox.id) + ) + assert ( + task_result["result"]["message_status"] + == "Completed processing message" + ) + assert task_result["result"]["type"] == "eml" + assert task_result["result"]["total_messages"] == 1 + assert task_result["result"]["success_count"] == 1 + assert task_result["result"]["failure_count"] == 0 + assert task_result["result"]["current_message"] == 1 + + # Verify progress was set once during processing + assert mock_set_progress.call_count == 1 + + mock_set_progress.assert_called_once_with( + 0, # progress percentage + { + "result": { + "message_status": "Processing message 1 of 1", + "total_messages": 1, + "success_count": 0, + "failure_count": 0, + "type": "eml", + "current_message": 1, + }, + "error": None, }, - "error": None, - }, - ) + ) - # check that the message was created - assert Message.objects.count() == 1 - message = Message.objects.first() - assert message.subject == "Mon mail avec joli pj" - assert message.has_attachments is True - assert message.sender.email == "sender@example.com" - assert message.recipients.get().contact.email == "recipient@example.com" - assert message.sent_at == message.thread.messaged_at - assert message.sent_at == ( - datetime.datetime(2025, 5, 26, 20, 13, 44, tzinfo=datetime.timezone.utc) - ) + # check that the message was created + assert Message.objects.count() == 1 + message = Message.objects.first() + assert message.subject == "Mon mail avec joli pj" + assert message.has_attachments is True + assert message.sender.email == "sender@example.com" + assert message.recipients.get().contact.email == "recipient@example.com" + assert message.sent_at == message.thread.messaged_at + assert message.sent_at == ( + datetime.datetime(2025, 5, 26, 20, 13, 44, tzinfo=datetime.timezone.utc) + ) def _upload_to_s3(content, file_key="test-mbox-key"): @@ -197,16 +191,11 @@ def _upload_to_s3(content, file_key="test-mbox-key"): @pytest.mark.django_db def test_process_mbox_file_task(mailbox, mbox_file): - """Test the Celery task that processes MBOX files.""" + """Test the task that processes MBOX files.""" file_key, storage, s3_client = _upload_to_s3(mbox_file) try: - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ): + with patch("core.services.importer.mbox_tasks.set_task_progress") as mock_set_progress: task_result = process_mbox_file_task( file_key=file_key, recipient_id=str(mailbox.id) ) @@ -221,14 +210,14 @@ def test_process_mbox_file_task(mailbox, mbox_file): assert task_result["result"]["failure_count"] == 0 assert task_result["result"]["current_message"] == 3 - # 1 indexing + 3 per-message PROGRESS = 4 - assert mock_task.update_state.call_count == 4 + # Verify progress updates were called (1 indexing + 3 per-message = 4) + assert mock_set_progress.call_count == 4 # Verify per-message progress updates for i in range(1, 4): - mock_task.update_state.assert_any_call( - state="PROGRESS", - meta={ + mock_set_progress.assert_any_call( + min(10 + int((i / max(1, 3) * 80)), 90), # 10-90% range + { "result": { "message_status": f"Processing message {i} of 3", "total_messages": 3, @@ -241,35 +230,35 @@ def test_process_mbox_file_task(mailbox, mbox_file): }, ) - # Verify messages were created - assert Message.objects.count() == 3 - messages = Message.objects.order_by("created_at") - - # Check thread for each message - assert messages[0].thread is not None - assert messages[1].thread is not None - assert messages[2].thread is not None - assert messages[2].thread.messages.count() == 2 - assert messages[1].thread == messages[2].thread - # Check created_at dates match between messages and threads - assert messages[0].sent_at == messages[0].thread.messaged_at - assert messages[2].sent_at == messages[1].thread.messaged_at - assert messages[2].sent_at == ( - datetime.datetime(2025, 5, 26, 20, 18, 4, tzinfo=datetime.timezone.utc) - ) + # Verify messages were created + assert Message.objects.count() == 3 + messages = Message.objects.order_by("created_at") + + # Check thread for each message + assert messages[0].thread is not None + assert messages[1].thread is not None + assert messages[2].thread is not None + assert messages[2].thread.messages.count() == 2 + assert messages[1].thread == messages[2].thread + # Check created_at dates match between messages and threads + assert messages[0].sent_at == messages[0].thread.messaged_at + assert messages[2].sent_at == messages[1].thread.messaged_at + assert messages[2].sent_at == ( + datetime.datetime(2025, 5, 26, 20, 18, 4, tzinfo=datetime.timezone.utc) + ) - # Check messages - assert messages[0].subject == "Mon mail avec joli pj" - assert messages[0].has_attachments is True + # Check messages + assert messages[0].subject == "Mon mail avec joli pj" + assert messages[0].has_attachments is True - assert messages[1].subject == "Je t'envoie encore un message..." - body1 = messages[1].get_parsed_field("textBody")[0]["content"] - assert "Lorem ipsum dolor sit amet" in body1 + assert messages[1].subject == "Je t'envoie encore un message..." + body1 = messages[1].get_parsed_field("textBody")[0]["content"] + assert "Lorem ipsum dolor sit amet" in body1 - assert messages[2].subject == "Re: Je t'envoie encore un message..." - body2 = messages[2].get_parsed_field("textBody")[0]["content"] - assert "Yes !" in body2 - assert "Lorem ipsum dolor sit amet" in body2 + assert messages[2].subject == "Re: Je t'envoie encore un message..." + body2 = messages[2].get_parsed_field("textBody")[0]["content"] + assert "Yes !" in body2 + assert "Lorem ipsum dolor sit amet" in body2 finally: s3_client.delete_object(Bucket=storage.bucket_name, Key=file_key) @@ -380,18 +369,11 @@ def test_import_message_to_different_mailbox_same_domain(domain): This is a test message addressed to mailbox_b. """ - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - # Mock storage mock_storage = mock_storage_open(eml_content) # Import from mailbox_a - with ( - patch.object(process_eml_file_task, "update_state", mock_task.update_state), - patch("core.services.importer.eml_tasks.storages") as mock_storages, - ): + with patch("core.services.importer.eml_tasks.storages") as mock_storages: mock_storages.__getitem__.return_value = mock_storage # Run the task synchronously for testing, importing from mailbox_a task_result = process_eml_file_task( @@ -451,18 +433,11 @@ def test_import_message_with_from_equal_to_mailbox_sets_is_sender(domain): This is a test message sent from the mailbox. """.encode("utf-8") - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - # Mock storage mock_storage = mock_storage_open(eml_content) # Import the message - with ( - patch.object(process_eml_file_task, "update_state", mock_task.update_state), - patch("core.services.importer.eml_tasks.storages") as mock_storages, - ): + with patch("core.services.importer.eml_tasks.storages") as mock_storages: mock_storages.__getitem__.return_value = mock_storage # Run the task synchronously for testing task_result = process_eml_file_task( diff --git a/src/backend/core/tests/importer/test_imap_import.py b/src/backend/core/tests/importer/test_imap_import.py index aabb184e4..15b0e1404 100644 --- a/src/backend/core/tests/importer/test_imap_import.py +++ b/src/backend/core/tests/importer/test_imap_import.py @@ -14,8 +14,6 @@ from core.models import Mailbox, MailDomain, Message, Thread from core.services.importer.imap_tasks import import_imap_messages_task -from messages.celery_app import app as celery_app - @pytest.fixture def admin_user(db): @@ -174,21 +172,13 @@ def test_imap_import_form_view(admin_client, mailbox): @patch("imaplib.IMAP4_SSL") -@patch.object(celery_app.backend, "store_result") def test_imap_import_task_success( - mock_store_result, mock_imap4_ssl, mailbox, mock_imap_connection, sample_email + mock_imap4_ssl, mailbox, mock_imap_connection, sample_email ): """Test successful IMAP import task execution.""" mock_imap4_ssl.return_value = mock_imap_connection - mock_store_result.return_value = None - - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - with patch.object( - import_imap_messages_task, "update_state", mock_task.update_state - ): + with patch("core.services.importer.imap.set_task_progress") as mock_set_progress: # Run the task task = import_imap_messages_task( imap_server="imap.example.com", @@ -212,17 +202,18 @@ def test_imap_import_task_success( assert task["result"]["current_message"] == 3 # Verify progress updates were called correctly - assert mock_task.update_state.call_count == 3 # 3 PROGRESS + assert mock_set_progress.call_count == 3 # 3 PROGRESS updates # Verify progress updates for i in range(1, 4): - mock_task.update_state.assert_any_call( - state="PROGRESS", - meta={ + pct = min(int(i / max(3, 1) * 100), 99) + mock_set_progress.assert_any_call( + pct, + { "result": { "message_status": f"Processing message {i} of 3", "total_messages": 3, - "success_count": i, # Current message was successful + "success_count": i, "failure_count": 0, "type": "imap", "current_message": i, @@ -231,44 +222,34 @@ def test_imap_import_task_success( }, ) - # No SUCCESS update_state — Celery infers SUCCESS from normal return; - # status is in the returned dict - # Verify messages were created assert Message.objects.count() == 3 assert Thread.objects.count() == 3 - # check one of the messages - message = Message.objects.last() - assert message.subject == "Test Subject" - assert message.sender.email == "sender@example.com" - assert message.recipients.count() == 1 - assert message.recipients.first().contact.email == "recipient@example.com" - assert ( - message.get_parsed_field("textBody")[0]["content"] - == "This is a test message body.\n" - ) - assert message.attachments.count() == 0 - assert message.thread.messages.count() == 1 - assert message.thread.messages.first() == message - assert message.created_at == message.thread.messaged_at - assert message.created_at == datetime.datetime( - 2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc - ) + # check one of the messages + message = Message.objects.last() + assert message.subject == "Test Subject" + assert message.sender.email == "sender@example.com" + assert message.recipients.count() == 1 + assert message.recipients.first().contact.email == "recipient@example.com" + assert ( + message.get_parsed_field("textBody")[0]["content"] + == "This is a test message body.\n" + ) + assert message.attachments.count() == 0 + assert message.thread.messages.count() == 1 + assert message.thread.messages.first() == message + assert message.created_at == message.thread.messaged_at + assert message.created_at == datetime.datetime( + 2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc + ) @pytest.mark.django_db def test_imap_import_task_login_failure(mailbox): """Test IMAP import task with login failure.""" - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - # Mock IMAP connection to raise an error on login - with ( - patch.object(import_imap_messages_task, "update_state", mock_task.update_state), - patch("core.services.importer.imap.imaplib.IMAP4_SSL") as mock_imap, - ): + with patch("imaplib.IMAP4_SSL") as mock_imap: mock_imap_instance = MagicMock() mock_imap.return_value = mock_imap_instance mock_imap_instance.login.side_effect = Exception("Login failed") @@ -292,20 +273,13 @@ def test_imap_import_task_login_failure(mailbox): assert task_result["result"]["current_message"] == 0 assert "Login failed" in task_result["error"] - # No update_state calls — failure status is in the returned dict - mock_task.update_state.assert_not_called() - # Verify no messages were created assert Message.objects.count() == 0 @patch("imaplib.IMAP4_SSL") -@patch.object(celery_app.backend, "store_result") -def test_imap_import_task_message_fetch_failure( - mock_store_result, mock_imap4_ssl, mailbox -): +def test_imap_import_task_message_fetch_failure(mock_imap4_ssl, mailbox): """Test IMAP import task with message fetch failure.""" - mock_store_result.return_value = None mock_imap = MagicMock() mock_imap.login.return_value = ("OK", [b"Logged in"]) @@ -320,13 +294,7 @@ def test_imap_import_task_message_fetch_failure( mock_imap.logout.return_value = ("OK", [b"Logged out"]) mock_imap4_ssl.return_value = mock_imap - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - import_imap_messages_task, "update_state", mock_task.update_state - ): + with patch("core.services.importer.imap.set_task_progress") as mock_set_progress: # Run the task task = import_imap_messages_task( imap_server="imap.example.com", @@ -350,13 +318,14 @@ def test_imap_import_task_message_fetch_failure( assert task["result"]["current_message"] == 3 # Verify progress updates were called correctly - assert mock_task.update_state.call_count == 3 # 3 PROGRESS + assert mock_set_progress.call_count == 3 # 3 PROGRESS - # Verify progress updates + # Verify progress updates - each message fails for i in range(1, 4): - mock_task.update_state.assert_any_call( - state="PROGRESS", - meta={ + pct = min(int(i / max(3, 1) * 100), 99) + mock_set_progress.assert_any_call( + pct, + { "result": { "message_status": f"Processing message {i} of 3", "total_messages": 3, @@ -369,15 +338,10 @@ def test_imap_import_task_message_fetch_failure( }, ) - # No SUCCESS update_state — Celery infers SUCCESS from normal return; - # status is in the returned dict - @patch("core.mda.inbound.logger") @patch("imaplib.IMAP4_SSL") -@patch.object(celery_app.backend, "store_result") def test_imap_import_task_duplicate_recipients( - mock_store_result, mock_imap4_ssl, mock_logger, mailbox, @@ -385,74 +349,66 @@ def test_imap_import_task_duplicate_recipients( ): """Test IMAP import task with duplicate recipients handles deduplication correctly.""" mock_imap4_ssl.return_value = mock_imap_connection_with_duplicates - mock_store_result.return_value = None - - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - import_imap_messages_task, "update_state", mock_task.update_state - ): - # Run the task - task = import_imap_messages_task( - imap_server="imap.example.com", - imap_port=993, - username="test@example.com", - password="password123", - use_ssl=True, - recipient_id=str(mailbox.id), - ) - - # Verify results - assert task["status"] == "SUCCESS" - assert ( - task["result"]["message_status"] - == "Completed processing messages from folder 'INBOX'" - ) - assert task["result"]["type"] == "imap" - assert task["result"]["total_messages"] == 1 - assert task["result"]["success_count"] == 1 - assert task["result"]["failure_count"] == 0 - assert task["result"]["current_message"] == 1 - - # Verify messages were created - assert Message.objects.count() == 1 - assert Thread.objects.count() == 1 - # Check the message - message = Message.objects.first() - assert message.subject == "Test Subject with Duplicates" - assert message.sender.email == "sender@example.com" - - # Verify recipients - should have unique recipients only - recipients = message.recipients.all() - recipient_emails = [r.contact.email for r in recipients] - - # Should have unique recipients (no duplicates) - assert len(recipient_emails) == len(set(recipient_emails)) - - # Should have the expected recipients - assert "recipient@example.com" in recipient_emails - assert "cc@example.com" in recipient_emails + # Run the task + task = import_imap_messages_task( + imap_server="imap.example.com", + imap_port=993, + username="test@example.com", + password="password123", + use_ssl=True, + recipient_id=str(mailbox.id), + ) - # Verify recipient types - to_recipients = message.recipients.filter( - type=enums.MessageRecipientTypeChoices.TO - ) - cc_recipients = message.recipients.filter( - type=enums.MessageRecipientTypeChoices.CC - ) + # Verify results + assert task["status"] == "SUCCESS" + assert ( + task["result"]["message_status"] + == "Completed processing messages from folder 'INBOX'" + ) + assert task["result"]["type"] == "imap" + assert task["result"]["total_messages"] == 1 + assert task["result"]["success_count"] == 1 + assert task["result"]["failure_count"] == 0 + assert task["result"]["current_message"] == 1 + + # Verify messages were created + assert Message.objects.count() == 1 + assert Thread.objects.count() == 1 + + # Check the message + message = Message.objects.first() + assert message.subject == "Test Subject with Duplicates" + assert message.sender.email == "sender@example.com" + + # Verify recipients - should have unique recipients only + recipients = message.recipients.all() + recipient_emails = [r.contact.email for r in recipients] + + # Should have unique recipients (no duplicates) + assert len(recipient_emails) == len(set(recipient_emails)) + + # Should have the expected recipients + assert "recipient@example.com" in recipient_emails + assert "cc@example.com" in recipient_emails + + # Verify recipient types + to_recipients = message.recipients.filter( + type=enums.MessageRecipientTypeChoices.TO + ) + cc_recipients = message.recipients.filter( + type=enums.MessageRecipientTypeChoices.CC + ) - assert to_recipients.count() == 1 # Only one TO recipient (duplicate removed) - assert cc_recipients.count() == 1 # Only one CC recipient (duplicate removed) + assert to_recipients.count() == 1 # Only one TO recipient (duplicate removed) + assert cc_recipients.count() == 1 # Only one CC recipient (duplicate removed) - # Verify the content - assert ( - message.get_parsed_field("textBody")[0]["content"] - == "This is a test message with duplicate recipients.\n" - ) + # Verify the content + assert ( + message.get_parsed_field("textBody")[0]["content"] + == "This is a test message with duplicate recipients.\n" + ) - # Critical: Verify that no validation errors were logged - # This ensures the deduplication logic works correctly - mock_logger.error.assert_not_called() + # Critical: Verify that no validation errors were logged + # This ensures the deduplication logic works correctly + mock_logger.error.assert_not_called() diff --git a/src/backend/core/tests/importer/test_import_service.py b/src/backend/core/tests/importer/test_import_service.py index c4550da78..81804b500 100644 --- a/src/backend/core/tests/importer/test_import_service.py +++ b/src/backend/core/tests/importer/test_import_service.py @@ -164,62 +164,38 @@ def mock_deliver(recipient_email, parsed_email, raw_data, **kwargs): return True with patch("core.mda.inbound.deliver_inbound_message", side_effect=mock_deliver): - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - process_eml_file_task, "update_state", mock_task.update_state - ): - # Run the import - task_result = process_eml_file_task( - file_key=eml_key, - recipient_id=str(mailbox.id), - ) - - # Verify task result structure - assert isinstance(task_result, dict) - assert "status" in task_result - assert "result" in task_result - assert "error" in task_result - - # Verify task result content - assert task_result["status"] == "SUCCESS" - assert ( - task_result["result"]["message_status"] - == "Completed processing message" - ) - assert task_result["result"]["type"] == "eml" - assert task_result["result"]["total_messages"] == 1 - assert task_result["result"]["success_count"] == 1 - assert task_result["result"]["failure_count"] == 0 - assert task_result["result"]["current_message"] == 1 - assert task_result["error"] is None - - # Verify progress update (no SUCCESS update_state — Celery infers - # SUCCESS from normal return; status is in the returned dict) - mock_task.update_state.assert_called_once_with( - state="PROGRESS", - meta={ - "result": { - "message_status": "Processing message 1 of 1", - "total_messages": 1, - "success_count": 0, - "failure_count": 0, - "type": "eml", - "current_message": 1, - }, - "error": None, - }, - ) + # Run the import + task_result = process_eml_file_task( + file_key=eml_key, + recipient_id=str(mailbox.id), + ) - # Verify message was created - assert Message.objects.count() == 1 - message = Message.objects.first() - assert message.subject == "Mon mail avec joli pj" - assert message.sender.email == "sender@example.com" - assert message.recipients.count() == 1 - assert message.recipients.first().contact.email == "recipient@example.com" + # Verify task result structure + assert isinstance(task_result, dict) + assert "status" in task_result + assert "result" in task_result + assert "error" in task_result + + # Verify task result content + assert task_result["status"] == "SUCCESS" + assert ( + task_result["result"]["message_status"] + == "Completed processing message" + ) + assert task_result["result"]["type"] == "eml" + assert task_result["result"]["total_messages"] == 1 + assert task_result["result"]["success_count"] == 1 + assert task_result["result"]["failure_count"] == 0 + assert task_result["result"]["current_message"] == 1 + assert task_result["error"] is None + + # Verify message was created + assert Message.objects.count() == 1 + message = Message.objects.first() + assert message.subject == "Mon mail avec joli pj" + assert message.sender.email == "sender@example.com" + assert message.recipients.count() == 1 + assert message.recipients.first().contact.email == "recipient@example.com" @pytest.mark.django_db @@ -260,62 +236,38 @@ def mock_deliver(recipient_email, parsed_email, raw_data, **kwargs): return True with patch("core.mda.inbound.deliver_inbound_message", side_effect=mock_deliver): - # Create a mock task instance - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - process_eml_file_task, "update_state", mock_task.update_state - ): - # Run the import - task_result = process_eml_file_task( - file_key=eml_key, - recipient_id=str(mailbox.id), - ) - - # Verify task result structure - assert isinstance(task_result, dict) - assert "status" in task_result - assert "result" in task_result - assert "error" in task_result - - # Verify task result content - assert task_result["status"] == "SUCCESS" - assert ( - task_result["result"]["message_status"] - == "Completed processing message" - ) - assert task_result["result"]["type"] == "eml" - assert task_result["result"]["total_messages"] == 1 - assert task_result["result"]["success_count"] == 1 - assert task_result["result"]["failure_count"] == 0 - assert task_result["result"]["current_message"] == 1 - assert task_result["error"] is None - - # Verify progress update (no SUCCESS update_state — Celery infers - # SUCCESS from normal return; status is in the returned dict) - mock_task.update_state.assert_called_once_with( - state="PROGRESS", - meta={ - "result": { - "message_status": "Processing message 1 of 1", - "total_messages": 1, - "success_count": 0, - "failure_count": 0, - "type": "eml", - "current_message": 1, - }, - "error": None, - }, - ) + # Run the import + task_result = process_eml_file_task( + file_key=eml_key, + recipient_id=str(mailbox.id), + ) - # Verify message was created - assert Message.objects.count() == 1 - message = Message.objects.first() - assert message.subject == "Mon mail avec joli pj" - assert message.sender.email == "sender@example.com" - assert message.recipients.count() == 1 - assert message.recipients.first().contact.email == "recipient@example.com" + # Verify task result structure + assert isinstance(task_result, dict) + assert "status" in task_result + assert "result" in task_result + assert "error" in task_result + + # Verify task result content + assert task_result["status"] == "SUCCESS" + assert ( + task_result["result"]["message_status"] + == "Completed processing message" + ) + assert task_result["result"]["type"] == "eml" + assert task_result["result"]["total_messages"] == 1 + assert task_result["result"]["success_count"] == 1 + assert task_result["result"]["failure_count"] == 0 + assert task_result["result"]["current_message"] == 1 + assert task_result["error"] is None + + # Verify message was created + assert Message.objects.count() == 1 + message = Message.objects.first() + assert message.subject == "Mon mail avec joli pj" + assert message.sender.email == "sender@example.com" + assert message.recipients.count() == 1 + assert message.recipients.first().contact.email == "recipient@example.com" @pytest.mark.django_db diff --git a/src/backend/core/tests/importer/test_pst_import.py b/src/backend/core/tests/importer/test_pst_import.py index 54e584bc2..0a31bb375 100644 --- a/src/backend/core/tests/importer/test_pst_import.py +++ b/src/backend/core/tests/importer/test_pst_import.py @@ -1160,35 +1160,27 @@ def _upload_pst_to_s3(filename): class TestProcessPstFileTask: - """Tests for the process_pst_file_task Celery task using real PST files.""" + """Tests for the process_pst_file_task using real PST files.""" def test_nonexistent_mailbox(self): """Test task with non-existent mailbox returns failure.""" - mock_task = MagicMock() - with patch.object( - process_pst_file_task, "update_state", mock_task.update_state - ): - result = process_pst_file_task( - file_key="test.pst", - recipient_id="00000000-0000-0000-0000-000000000000", - ) - assert result["status"] == "FAILURE" - assert result["result"]["type"] == "pst" - assert "not found" in result["error"] + result = process_pst_file_task( + file_key="test.pst", + recipient_id="00000000-0000-0000-0000-000000000000", + ) + assert result["status"] == "FAILURE" + assert result["result"]["type"] == "pst" + assert "not found" in result["error"] def test_process_sample_pst(self, mailbox): """Test processing sample.pst — 1 message in myInbox with transport headers.""" file_key, storage, s3_client = _upload_pst_to_s3("sample.pst") try: - mock_task = MagicMock() - with patch.object( - process_pst_file_task, "update_state", mock_task.update_state - ): - result = process_pst_file_task( - file_key=file_key, - recipient_id=str(mailbox.id), - ) + result = process_pst_file_task( + file_key=file_key, + recipient_id=str(mailbox.id), + ) assert result["status"] == "SUCCESS" assert result["result"]["type"] == "pst" @@ -1223,14 +1215,10 @@ def test_process_outlook_pst(self, mailbox): file_key, storage, s3_client = _upload_pst_to_s3("Outlook.pst") try: - mock_task = MagicMock() - with patch.object( - process_pst_file_task, "update_state", mock_task.update_state - ): - result = process_pst_file_task( - file_key=file_key, - recipient_id=str(mailbox.id), - ) + result = process_pst_file_task( + file_key=file_key, + recipient_id=str(mailbox.id), + ) assert result["status"] == "SUCCESS" assert result["result"]["type"] == "pst" @@ -1270,14 +1258,10 @@ def test_process_malformed_pst(self, mailbox): ) try: - mock_task = MagicMock() - with patch.object( - process_pst_file_task, "update_state", mock_task.update_state - ): - result = process_pst_file_task( - file_key=file_key, - recipient_id=str(mailbox.id), - ) + result = process_pst_file_task( + file_key=file_key, + recipient_id=str(mailbox.id), + ) assert result["status"] == "FAILURE" assert result["result"]["type"] == "pst" diff --git a/src/backend/core/tests/mda/test_retry.py b/src/backend/core/tests/mda/test_retry.py index 87205a543..754309bf7 100644 --- a/src/backend/core/tests/mda/test_retry.py +++ b/src/backend/core/tests/mda/test_retry.py @@ -108,7 +108,7 @@ def test_retry_messages_set_success( # Mock successful send mock_send_message.return_value = None - result = retry_messages_task.apply(args=[message_ids]).get() + result = retry_messages_task(message_ids) # Verify the result assert result["success"] is True @@ -126,7 +126,7 @@ def test_retry_nonexistent_message(self): """Test retrying a non-existent message.""" fake_message_id = "00000000-0000-0000-0000-000000000000" - result = retry_messages_task.apply(args=[[fake_message_id]]).get() + result = retry_messages_task([fake_message_id]) # Verify the result assert result["success"] is True @@ -139,7 +139,7 @@ def test_retry_nonexistent_message(self): def test_retry_draft_message(self, draft_message): """Test retrying a draft message (should fail).""" - result = retry_messages_task.apply(args=[[str(draft_message.id)]]).get() + result = retry_messages_task([str(draft_message.id)]) # Verify the result assert result["success"] is True @@ -158,7 +158,7 @@ def test_retry_bulk_mode(self, mock_send_message, message_with_recipients): # Mock successful send mock_send_message.return_value = None - result = retry_messages_task.apply().get() + result = retry_messages_task() # Verify the result assert result["success"] is True @@ -194,7 +194,7 @@ def test_retry_no_messages_ready(self, mailbox_sender, thread): retry_count=1, ) - result = retry_messages_task.apply().get() + result = retry_messages_task() # Verify the result assert result["success"] is True @@ -248,7 +248,7 @@ def test_retry_failed_send_task_mid_route( # Mock successful send mock_send_message.return_value = None - result = retry_messages_task.apply(args=[[str(message.id)]]).get() + result = retry_messages_task([str(message.id)]) # Verify the result assert result["success"] is True @@ -302,7 +302,7 @@ def test_retry_timing_respect(self, mock_send_message, mailbox_sender, thread): # Mock successful send mock_send_message.return_value = None - result = retry_messages_task.apply(args=[[str(message.id)]]).get() + result = retry_messages_task([str(message.id)]) # Verify the result - should only process the ready recipient assert result["success"] is True @@ -343,9 +343,7 @@ def test_retry_batch_processing(self, mock_send_message, mailbox_sender, thread) # Mock successful send mock_send_message.return_value = None - result = retry_messages_task.apply( - kwargs={"batch_size": 2} - ).get() # Process in batches of 2 + result = retry_messages_task(batch_size=2) # Process in batches of 2 # Verify the result assert result["success"] is True @@ -426,7 +424,7 @@ def test_retry_mixed_recipient_statuses( # Mock successful send mock_send_message.return_value = None - result = retry_messages_task.apply(args=[[str(message.id)]]).get() + result = retry_messages_task([str(message.id)]) # Verify the result - should process 2 recipients (RETRY and NULL) assert result["success"] is True @@ -478,7 +476,7 @@ def test_retry_message_with_no_retryable_recipients( delivery_message="Permanent failure", ) - result = retry_messages_task.apply(args=[[str(message.id)]]).get() + result = retry_messages_task([str(message.id)]) # Verify the result - should process the message but not call send_message assert result["success"] is True @@ -519,7 +517,7 @@ def test_retry_message_with_empty_message_ids_list( type=models.MessageRecipientTypeChoices.TO, ) - result = retry_messages_task.apply(args=[[]]).get() + result = retry_messages_task([]) # Verify the result - should process the message but not call send_message assert result["success"] is True @@ -532,10 +530,10 @@ def test_retry_message_with_empty_message_ids_list( mock_send_message.assert_not_called() @patch("core.mda.outbound_tasks.send_message") - def test_retry_update_state_called_once_per_batch( + def test_retry_progress_updates_per_batch( self, mock_send_message, mailbox_sender, thread ): - """Test that update_state is called once per batch, not per message.""" + """Test that task completes successfully with batch processing.""" # Create 5 messages with retryable recipients messages = [] for i in range(5): @@ -565,44 +563,44 @@ def test_retry_update_state_called_once_per_batch( # Mock successful send mock_send_message.return_value = None - # Patch update_state to track calls - with patch.object(retry_messages_task, "update_state") as mock_update_state: - result = retry_messages_task.apply(kwargs={"batch_size": 2}).get() + with patch("core.mda.outbound_tasks.set_task_progress") as mock_set_progress: + result = retry_messages_task(batch_size=2) # Verify the result assert result["success"] is True assert result["total_messages"] == 5 assert result["success_count"] == 5 - # With 5 messages and batch_size=2, we should have 3 update_state calls: - # - At index 0 (start of batch 1) - # - At index 2 (start of batch 2) - # - At index 4 (start of batch 3) - assert mock_update_state.call_count == 3 + # With 5 messages and batch_size=2: + # - 2 initial calls (finding messages, found messages) + # - 3 batch calls at indices 0, 2, 4 + # Total: 5 calls + assert mock_set_progress.call_count == 5 - # Verify the calls have correct batch information - calls = mock_update_state.call_args_list + # Verify the batch calls have correct batch information + # Batch calls are at indices 2, 3, 4 (after the 2 initial calls) + calls = mock_set_progress.call_args_list - # First call at index 0 (batch 1) - assert calls[0].kwargs["state"] == "PROGRESS" - assert calls[0].kwargs["meta"]["current_batch"] == 1 - assert calls[0].kwargs["meta"]["total_batches"] == 3 + # First batch call (index 2) at message index 0 (batch 1) + assert calls[2].args[0] == 10 # 10% progress for first batch + assert calls[2].args[1]["current_batch"] == 1 + assert calls[2].args[1]["total_batches"] == 3 - # Second call at index 2 (batch 2) - assert calls[1].kwargs["state"] == "PROGRESS" - assert calls[1].kwargs["meta"]["current_batch"] == 2 - assert calls[1].kwargs["meta"]["total_batches"] == 3 + # Second batch call (index 3) at message index 2 (batch 2) + assert calls[3].args[0] == 42 # 10 + (2/5)*80 = 42 + assert calls[3].args[1]["current_batch"] == 2 + assert calls[3].args[1]["total_batches"] == 3 - # Third call at index 4 (batch 3) - assert calls[2].kwargs["state"] == "PROGRESS" - assert calls[2].kwargs["meta"]["current_batch"] == 3 - assert calls[2].kwargs["meta"]["total_batches"] == 3 + # Third batch call (index 4) at message index 4 (batch 3) + assert calls[4].args[0] == 74 # 10 + (4/5)*80 = 74 + assert calls[4].args[1]["current_batch"] == 3 + assert calls[4].args[1]["total_batches"] == 3 @patch("core.mda.outbound_tasks.send_message") - def test_retry_update_state_not_called_every_message( + def test_retry_progress_not_called_every_message( self, mock_send_message, mailbox_sender, thread ): - """Test that update_state is NOT called for every message when batch_size > 1.""" + """Test that task completes successfully with batched processing.""" # Create 10 messages with retryable recipients messages = [] for i in range(10): @@ -630,15 +628,17 @@ def test_retry_update_state_not_called_every_message( mock_send_message.return_value = None - with patch.object(retry_messages_task, "update_state") as mock_update_state: - result = retry_messages_task.apply(kwargs={"batch_size": 3}).get() + with patch("core.mda.outbound_tasks.set_task_progress") as mock_set_progress: + result = retry_messages_task(batch_size=3) assert result["success"] is True assert result["total_messages"] == 10 assert result["success_count"] == 10 - # With 10 messages and batch_size=3, update_state should be called 4 times: - # At indices 0, 3, 6, 9 (not 10 times for each message) - assert mock_update_state.call_count == 4 + # With 10 messages and batch_size=3: + # - 2 initial calls (finding messages, found messages) + # - 4 batch calls at indices 0, 3, 6, 9 + # Total: 6 calls + assert mock_set_progress.call_count == 6 # Verify it's significantly less than total messages - assert mock_update_state.call_count < result["total_messages"] + assert mock_set_progress.call_count < result["total_messages"] diff --git a/src/backend/core/tests/mda/test_spam_processing.py b/src/backend/core/tests/mda/test_spam_processing.py index dc03d7c30..685434cd0 100644 --- a/src/backend/core/tests/mda/test_spam_processing.py +++ b/src/backend/core/tests/mda/test_spam_processing.py @@ -712,9 +712,8 @@ def test_process_inbound_message_task_spam( mock_check_spam.return_value = (True, None) # is_spam=True mock_create_message.return_value = True - # Call the bound task directly using .run() method - with patch.object(process_inbound_message_task, "update_state", Mock()): - result = process_inbound_message_task.run(str(inbound_message.id)) + # Call the task directly + result = process_inbound_message_task(str(inbound_message.id)) assert result["success"] is True assert result["is_spam"] is True @@ -745,9 +744,8 @@ def test_process_inbound_message_task_not_spam( mock_check_spam.return_value = (False, None) # is_spam=False mock_create_message.return_value = True - # Call the bound task directly using .run() method - with patch.object(process_inbound_message_task, "update_state", Mock()): - result = process_inbound_message_task.run(str(inbound_message.id)) + # Call the task directly + result = process_inbound_message_task(str(inbound_message.id)) assert result["success"] is True assert result["is_spam"] is False @@ -774,9 +772,8 @@ def test_process_inbound_message_task_failure( mock_check_spam.return_value = (False, None) mock_create_message.return_value = False # Creation failed - # Call the bound task directly using .run() method - with patch.object(process_inbound_message_task, "update_state", Mock()): - result = process_inbound_message_task.run(str(inbound_message.id)) + # Call the task directly + result = process_inbound_message_task(str(inbound_message.id)) assert result["success"] is False assert "error" in result @@ -807,9 +804,8 @@ def test_process_inbound_messages_queue_task(self, mock_task_delay): created_at=old_time ) - # Call the bound task directly using .run() method - with patch.object(process_inbound_messages_queue_task, "update_state", Mock()): - result = process_inbound_messages_queue_task.run(10) + # Call the task directly + result = process_inbound_messages_queue_task(10) assert result["success"] is True assert result["processed"] == 3 diff --git a/src/backend/core/tests/tasks/test_task_importer.py b/src/backend/core/tests/tasks/test_task_importer.py index 7aae1f592..4b4c38957 100644 --- a/src/backend/core/tests/tasks/test_task_importer.py +++ b/src/backend/core/tests/tasks/test_task_importer.py @@ -208,12 +208,7 @@ def test_task_process_mbox_file_success(self, mailbox, sample_mbox_content): file_key, storage, s3_client = _upload_to_s3(sample_mbox_content) try: - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ): + with patch("core.services.importer.mbox_tasks.set_task_progress") as mock_set_progress: task_result = process_mbox_file_task( file_key=file_key, recipient_id=str(mailbox.id) ) @@ -230,12 +225,12 @@ def test_task_process_mbox_file_success(self, mailbox, sample_mbox_content): assert task_result["result"]["current_message"] == 3 # 1 "Indexing" + 3 per-message PROGRESS = 4 - assert mock_task.update_state.call_count == 4 + assert mock_set_progress.call_count == 4 # Verify "Indexing messages" update - mock_task.update_state.assert_any_call( - state="PROGRESS", - meta={ + mock_set_progress.assert_any_call( + 0, # 0% progress + { "result": { "message_status": "Indexing messages", "total_messages": None, @@ -250,9 +245,9 @@ def test_task_process_mbox_file_success(self, mailbox, sample_mbox_content): # Verify per-message progress for i in range(1, 4): - mock_task.update_state.assert_any_call( - state="PROGRESS", - meta={ + mock_set_progress.assert_any_call( + min(10 + int((i / max(1, 3) * 80)), 90), # 10-90% range + { "result": { "message_status": f"Processing message {i} of 3", "total_messages": 3, @@ -293,13 +288,8 @@ def mock_deliver(recipient_email, parsed_email, raw_data, **kwargs): file_key, storage, s3_client = _upload_to_s3(sample_mbox_content) try: - mock_task = MagicMock() - mock_task.update_state = MagicMock() - with ( - patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ), + patch("core.services.importer.mbox_tasks.set_task_progress") as mock_set_progress, patch( "core.services.importer.mbox_tasks.deliver_inbound_message", side_effect=mock_deliver, @@ -314,7 +304,7 @@ def mock_deliver(recipient_email, parsed_email, raw_data, **kwargs): assert task_result["result"]["current_message"] == 3 # 1 indexing + 3 per-message PROGRESS = 4 - assert mock_task.update_state.call_count == 4 + assert mock_set_progress.call_count == 4 # Verify messages: msg1 and msg3 created, msg2 failed assert Message.objects.count() == 2 @@ -326,14 +316,9 @@ def mock_deliver(recipient_email, parsed_email, raw_data, **kwargs): def test_task_process_mbox_file_mailbox_not_found(self, sample_mbox_content): # pylint: disable=unused-argument """Test MBOX processing with non-existent mailbox.""" - mock_task = MagicMock() - mock_task.update_state = MagicMock() - non_existent_id = str(uuid.uuid4()) - with patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ): + with patch("core.services.importer.mbox_tasks.set_task_progress") as mock_set_progress: task_result = process_mbox_file_task( file_key="test-file-key.mbox", recipient_id=non_existent_id ) @@ -351,8 +336,8 @@ def test_task_process_mbox_file_mailbox_not_found(self, sample_mbox_content): # f"Recipient mailbox {non_existent_id} not found" in task_result["error"] ) - # No update_state calls — failure status is in the returned dict - mock_task.update_state.assert_not_called() + # No progress calls — failure status is in the returned dict + mock_set_progress.assert_not_called() assert Message.objects.count() == 0 @@ -365,17 +350,12 @@ def mock_parse(*args, **kwargs): file_key, storage, s3_client = _upload_to_s3(sample_mbox_content) try: - mock_task = MagicMock() - mock_task.update_state = MagicMock() - with ( patch( "core.services.importer.mbox_tasks.parse_email_message", side_effect=mock_parse, ), - patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ), + patch("core.services.importer.mbox_tasks.set_task_progress") as mock_set_progress, ): task_result = process_mbox_file_task(file_key, str(mailbox.id)) @@ -385,7 +365,7 @@ def mock_parse(*args, **kwargs): assert task_result["result"]["failure_count"] == 3 # 1 indexing + 3 per-message PROGRESS = 4 - assert mock_task.update_state.call_count == 4 + assert mock_set_progress.call_count == 4 assert Message.objects.count() == 0 finally: @@ -396,19 +376,13 @@ def test_task_process_mbox_file_empty(self, mailbox): file_key, storage, s3_client = _upload_to_s3(b"") try: - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ): - task_result = process_mbox_file_task( - file_key=file_key, recipient_id=str(mailbox.id) - ) + task_result = process_mbox_file_task( + file_key=file_key, recipient_id=str(mailbox.id) + ) - assert task_result["status"] == "SUCCESS" - assert task_result["result"]["total_messages"] == 0 - assert Message.objects.count() == 0 + assert task_result["status"] == "SUCCESS" + assert task_result["result"]["total_messages"] == 0 + assert Message.objects.count() == 0 finally: s3_client.delete_object(Bucket=storage.bucket_name, Key=file_key) @@ -420,20 +394,14 @@ def test_task_process_mbox_invalid_file(self, mailbox): file_key, storage, s3_client = _upload_to_s3(jpeg_content) try: - mock_task = MagicMock() - mock_task.update_state = MagicMock() - - with patch.object( - process_mbox_file_task, "update_state", mock_task.update_state - ): - task_result = process_mbox_file_task( - file_key=file_key, recipient_id=str(mailbox.id) - ) + task_result = process_mbox_file_task( + file_key=file_key, recipient_id=str(mailbox.id) + ) - # MIME validation is done upstream in service.py; - # the task just finds zero messages in invalid content - assert task_result["status"] == "SUCCESS" - assert task_result["result"]["total_messages"] == 0 - assert Message.objects.count() == 0 + # MIME validation is done upstream in service.py; + # the task just finds zero messages in invalid content + assert task_result["status"] == "SUCCESS" + assert task_result["result"]["total_messages"] == 0 + assert Message.objects.count() == 0 finally: s3_client.delete_object(Bucket=storage.bucket_name, Key=file_key) diff --git a/src/backend/core/tests/tasks/test_task_send_message.py b/src/backend/core/tests/tasks/test_task_send_message.py index ce3ef44d9..412ff128a 100644 --- a/src/backend/core/tests/tasks/test_task_send_message.py +++ b/src/backend/core/tests/tasks/test_task_send_message.py @@ -66,11 +66,9 @@ def test_task_send_message_with_archive_true( # Mock the send_message function with patch("core.mda.outbound_tasks.send_message") as mock_mda_send: - # Call the task with must_archive=True - with patch.object(send_message_task, "update_state"): - result = send_message_task( # pylint: disable=no-value-for-parameter - str(draft_message.id), must_archive=True - ) + result = send_message_task( # pylint: disable=no-value-for-parameter + str(draft_message.id), must_archive=True + ) # Verify send_message was called mock_mda_send.assert_called_once_with(draft_message, False) @@ -114,11 +112,9 @@ def test_task_send_message_with_archive_false( # Mock the send_message function with patch("core.mda.outbound_tasks.send_message") as mock_mda_send: - # Call the task with must_archive=False - with patch.object(send_message_task, "update_state"): - result = send_message_task( # pylint: disable=no-value-for-parameter - str(draft_message.id), must_archive=False - ) + result = send_message_task( # pylint: disable=no-value-for-parameter + str(draft_message.id), must_archive=False + ) # Verify the result assert result["success"] is True @@ -155,10 +151,9 @@ def test_task_send_message_archive_error_does_not_fail_task( # Call the task with must_archive=True # The task should succeed even if archiving fails - with patch.object(send_message_task, "update_state"): - result = send_message_task( # pylint: disable=no-value-for-parameter - str(draft_message.id), must_archive=True - ) + result = send_message_task( # pylint: disable=no-value-for-parameter + str(draft_message.id), must_archive=True + ) # Verify send_message was called mock_mda_send.assert_called_once_with(draft_message, False) @@ -178,10 +173,9 @@ def test_task_send_message_updates_thread_stats_after_archive( # Mock thread.update_stats to verify it's called with patch("core.models.Thread.update_stats") as mock_update_stats: # Call the task with must_archive=True - with patch.object(send_message_task, "update_state"): - result = send_message_task( # pylint: disable=no-value-for-parameter - str(draft_message.id), must_archive=True - ) + result = send_message_task( # pylint: disable=no-value-for-parameter + str(draft_message.id), must_archive=True + ) # Verify the result assert result["success"] is True diff --git a/src/backend/core/tests/test_worker.py b/src/backend/core/tests/test_worker.py index c589e4e83..2629f5606 100644 --- a/src/backend/core/tests/test_worker.py +++ b/src/backend/core/tests/test_worker.py @@ -31,31 +31,10 @@ def test_default_queues_includes_all(self): assert worker.DEFAULT_QUEUES == worker.ALL_QUEUES - def test_celery_default_queue_is_default(self): - """Verify the celery default queue is set to 'default'.""" - assert settings.CELERY_TASK_DEFAULT_QUEUE == "default" - - def test_task_routes_configured(self): - """Verify task routes are configured for all expected modules.""" - routes = settings.CELERY_TASK_ROUTES - - assert "core.mda.inbound_tasks.*" in routes - assert routes["core.mda.inbound_tasks.*"]["queue"] == "inbound" - - assert "core.mda.outbound_tasks.*" in routes - assert routes["core.mda.outbound_tasks.*"]["queue"] == "outbound" - - assert "core.services.importer.mbox_tasks.*" in routes - assert routes["core.services.importer.mbox_tasks.*"]["queue"] == "imports" - assert "core.services.importer.eml_tasks.*" in routes - assert routes["core.services.importer.eml_tasks.*"]["queue"] == "imports" - assert "core.services.importer.imap_tasks.*" in routes - assert routes["core.services.importer.imap_tasks.*"]["queue"] == "imports" - assert "core.services.importer.pst_tasks.*" in routes - assert routes["core.services.importer.pst_tasks.*"]["queue"] == "imports" - - assert "core.services.search.tasks.*" in routes - assert routes["core.services.search.tasks.*"]["queue"] == "reindex" + def test_dramatiq_broker_configured(self): + """Verify the Dramatiq broker is configured.""" + assert hasattr(settings, "DRAMATIQ_BROKER") + assert settings.DRAMATIQ_BROKER["BROKER"] == "core.utils.EagerBroker" class TestWorkerCLIParsing: @@ -75,8 +54,7 @@ def test_parse_args_defaults(self): assert args.queues is None assert args.exclude is None - assert args.disable_scheduler is False - assert args.loglevel == "INFO" + assert args.verbosity == 1 finally: sys.argv = original_argv @@ -110,21 +88,6 @@ def test_parse_args_with_exclude(self): finally: sys.argv = original_argv - def test_parse_args_with_disable_scheduler(self): - """Test parsing --disable-scheduler flag.""" - import sys - - import worker - - original_argv = sys.argv - try: - sys.argv = ["worker.py", "--disable-scheduler"] - args = worker.parse_args() - - assert args.disable_scheduler is True - finally: - sys.argv = original_argv - def test_parse_args_with_concurrency(self): """Test parsing --concurrency argument.""" import sys @@ -140,18 +103,18 @@ def test_parse_args_with_concurrency(self): finally: sys.argv = original_argv - def test_parse_args_with_loglevel(self): - """Test parsing --loglevel argument.""" + def test_parse_args_with_verbosity(self): + """Test parsing --verbosity argument.""" import sys import worker original_argv = sys.argv try: - sys.argv = ["worker.py", "--loglevel=DEBUG"] + sys.argv = ["worker.py", "--verbosity=2"] args = worker.parse_args() - assert args.loglevel == "DEBUG" + assert args.verbosity == 2 finally: sys.argv = original_argv @@ -171,15 +134,15 @@ def test_parse_args_short_flags(self): "reindex", "-c", "2", - "-l", - "WARNING", + "-v", + "2", ] args = worker.parse_args() assert args.queues == "inbound" assert args.exclude == "reindex" assert args.concurrency == 2 - assert args.loglevel == "WARNING" + assert args.verbosity == 2 finally: sys.argv = original_argv @@ -229,33 +192,89 @@ def test_queue_order_preserved_after_exclusion(self): assert result == expected -class TestBeatScheduleQueues: - """Test that beat schedule tasks use correct queues.""" +class TestCrontabConfiguration: + """Test that crontab tasks are configured correctly.""" - def test_beat_schedule_uses_correct_queues(self): - """Verify scheduled tasks are routed to appropriate queues.""" - from messages.celery_app import app + def test_crontab_settings_configured(self): + """Verify crontab settings are configured.""" + assert hasattr(settings, "DRAMATIQ_CRONTAB") + assert "REDIS_URL" in settings.DRAMATIQ_CRONTAB - if not hasattr(app.conf, "beat_schedule") or not app.conf.beat_schedule: - pytest.skip("Beat schedule is disabled") + def test_autodiscover_modules_finds_task_modules(self): + """Verify that DRAMATIQ_AUTODISCOVER_MODULES values are discoverable. - schedule = app.conf.beat_schedule + django_dramatiq's rundramatiq command uses Django's autodiscover_modules() + which looks for '{app_name}.{module_name}' for each installed app. + If DRAMATIQ_AUTODISCOVER_MODULES contains full paths like + 'core.mda.inbound_tasks', autodiscovery silently fails because it + would look for 'core.core.mda.inbound_tasks' which doesn't exist. + """ + from importlib import import_module - # Check retry-pending-messages uses outbound queue - if "retry-pending-messages" in schedule: - assert schedule["retry-pending-messages"]["options"]["queue"] == "outbound" + from django.apps import apps - # Check selfcheck uses outbound queue - if "selfcheck" in schedule: - assert schedule["selfcheck"]["options"]["queue"] == "outbound" + autodiscover_modules = settings.DRAMATIQ_AUTODISCOVER_MODULES + app_configs = apps.get_app_configs() - # Check process-inbound-messages-queue uses inbound queue - if "process-inbound-messages-queue" in schedule: - assert ( - schedule["process-inbound-messages-queue"]["options"]["queue"] - == "inbound" + for module_name in autodiscover_modules: + found = False + for app_config in app_configs: + try: + import_module(f"{app_config.name}.{module_name}") + found = True + break + except ImportError: + continue + assert found, ( + f"DRAMATIQ_AUTODISCOVER_MODULES entry '{module_name}' is not " + f"discoverable: no installed app contains a '{module_name}' " + f"submodule. autodiscover_modules() will silently skip it. " + f"Use simple module names like 'tasks' (not full paths)." ) + def test_all_task_actors_registered_on_broker(self): + """Verify that all @register_task functions are registered on the broker. + + This catches missing imports in core/tasks.py that would cause the + worker to silently ignore enqueued tasks. + """ + import dramatiq + + broker = dramatiq.get_broker() + registered_actors = set(broker.actors.keys()) + + # Every @register_task function must be discoverable by the worker. + # These are the actor names derived from each decorated function. + expected_actors = { + # core.mda.inbound_tasks + "process_inbound_message_task", + "process_inbound_messages_queue_task", + # core.mda.outbound_tasks + "send_message_task", + "selfcheck_task", + "retry_messages_task", + # core.services.importer + "process_eml_file_task", + "import_imap_messages_task", + "process_mbox_file_task", + "process_pst_file_task", + # core.services.search + "reindex_all", + "reindex_thread_task", + "reindex_mailbox_task", + "index_message_task", + "reset_search_index", + # core.services.exporter + "export_mailbox_task", + } + + missing = expected_actors - registered_actors + assert not missing, ( + f"Task actors not registered on the broker: {missing}. " + f"Check that core/tasks.py imports all task modules and that " + f"DRAMATIQ_AUTODISCOVER_MODULES is set correctly." + ) + class TestWorkerE2E: """End-to-end tests for the worker process.""" @@ -264,15 +283,14 @@ def test_worker_starts_successfully(self): """Test that the worker process starts without immediate errors.""" import subprocess - # Start worker with minimal config, disable scheduler to avoid side effects + # Start worker with minimal config # pylint: disable=consider-using-with process = subprocess.Popen( [ "python", "worker.py", "--queues=default", - "--disable-scheduler", - "--loglevel=INFO", + "-v", "2", "--concurrency=1", ], stdout=subprocess.PIPE, diff --git a/src/backend/core/utils.py b/src/backend/core/utils.py index e67e8a5a1..204c2b9f0 100644 --- a/src/backend/core/utils.py +++ b/src/backend/core/utils.py @@ -4,12 +4,161 @@ import logging from contextlib import contextmanager from contextvars import ContextVar +from typing import Any, Dict, Optional from configurations import values +from django.core.cache import cache +from django.utils import timezone +import dramatiq +from dramatiq.brokers.stub import StubBroker +from dramatiq.middleware import CurrentMessage +from dramatiq_crontab import cron, interval logger = logging.getLogger(__name__) +TASK_PROGRESS_CACHE_TIMEOUT = 86400 # 24 hours +TASK_TRACKING_CACHE_TTL = 86400 * 30 # 30 days (matches DRAMATIQ_RESULT_BACKEND TTL) + + +class EagerBroker(StubBroker): + """A broker that executes tasks synchronously for testing. + Equivalent to Celery's CELERY_TASK_ALWAYS_EAGER mode. + + Only runs CurrentMessage and Results middleware (not all middleware, since + DbConnectionsMiddleware would close DB connections mid-test). + """ + + def enqueue(self, message, *, delay=None): + from dramatiq.results import Results + + actor = self.get_actor(message.actor_name) + cm = next((m for m in self.middleware if isinstance(m, CurrentMessage)), None) + rm = next((m for m in self.middleware if isinstance(m, Results)), None) + prev = CurrentMessage.get_current_message() if cm else None + if cm: + cm.before_process_message(self, message) + try: + result = actor.fn(*message.args, **message.kwargs) + if rm: + rm.after_process_message(self, message, result=result) + finally: + if cm: + cm.after_process_message(self, message) + if prev is not None: + cm.before_process_message(self, prev) + return message + + +class Task: + """ + Wrapper around Dramatiq Message that provides Celery-like API. + """ + def __init__(self, message): + self._message = message + + @property + def id(self): + """Celery-compatible task ID (maps to message_id).""" + return self._message.message_id + + def track_owner(self, user_id): + """Register tracking metadata for permission checks and result retrieval.""" + cache.set(f"task_tracking:{self.id}", json.dumps({ + "owner": str(user_id), + "actor_name": self._message.actor_name, + "queue_name": self._message.queue_name, + }), timeout=TASK_TRACKING_CACHE_TTL) + + def __getattr__(self, name): + """Delegate any other attributes to the underlying message.""" + return getattr(self._message, name) + + +class CeleryCompatActor(dramatiq.Actor): + """ + A custom actor class that adds Celery-compatible methods to Dramatiq actors. + This allows keeping the .delay() API throughout the codebase. + """ + def delay(self, *args, **kwargs): + message = self.send(*args, **kwargs) + return Task(message) + + +def register_task(*args, **kwargs): + """ + Decorator to register a dramatiq actor. + Use this instead of @dramatiq.actor to abstract away the dependency. + """ + kwargs.setdefault("store_results", True) + if "queue" in kwargs: + kwargs.setdefault("queue_name", kwargs.pop("queue")) + kwargs.setdefault("actor_class", CeleryCompatActor) + + def decorator(fn): + return dramatiq.actor(fn, **kwargs) + + if args and callable(args[0]): + return decorator(args[0]) + return decorator + + +def cron_task(*args, **kwargs): + """ + Decorator to register a cron task. + Use this instead of @cron to abstract away the dependency. + + Supports: + - cron_task("*/5 * * * *") - standard cron expression + - cron_task(interval=300) - run every 300 seconds + """ + if "interval" in kwargs: + return interval(seconds=kwargs.pop("interval")) + return cron(*args, **kwargs) + + +def get_task_tracking(task_id: str) -> Optional[Dict[str, str]]: + """Get tracking metadata for a task, or None if not found.""" + raw = cache.get(f"task_tracking:{task_id}") + if raw is None: + return None + return json.loads(raw) + + +def set_task_progress(progress: int, metadata: Optional[Dict[str, Any]] = None) -> None: + """ + Set the progress of the currently executing task. + """ + current_message = CurrentMessage.get_current_message() + + if not current_message: + logger.warning("set_task_progress called outside of a dramatiq actor") + return + + task_id = current_message.message_id + try: + progress = max(0, min(100, int(progress))) + except (TypeError, ValueError): + progress = 0 + + progress_data = { + "progress": progress, + "timestamp": timezone.now().timestamp(), + "metadata": metadata or {}, + } + + cache_key = f"task_progress:{task_id}" + cache.set(cache_key, progress_data, timeout=TASK_PROGRESS_CACHE_TIMEOUT) + + +def get_task_progress(task_id: str) -> Optional[Dict[str, Any]]: + """ + Get the progress of a task by ID. + """ + cache_key = f"task_progress:{task_id}" + return cache.get(cache_key) + + class ThreadStatsUpdateDeferrer: """ Manages deferred thread.update_stats() calls. diff --git a/src/backend/messages/__init__.py b/src/backend/messages/__init__.py index bf4eac6f6..54a75dd38 100644 --- a/src/backend/messages/__init__.py +++ b/src/backend/messages/__init__.py @@ -1,5 +1,14 @@ """Messages module.""" -from .celery_app import app as celery_app +import os -__all__ = ("celery_app",) +# Ensure django-configurations is installed before settings are imported. +# This is needed because dramatiq worker processes import messages.settings +# via django_dramatiq.setup → django.setup(), and install() must be called +# before the Configuration metaclass runs. +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "messages.settings") +os.environ.setdefault("DJANGO_CONFIGURATION", "Development") + +from configurations.importer import install # noqa: E402 + +install(check_options=True) diff --git a/src/backend/messages/celery_app.py b/src/backend/messages/celery_app.py deleted file mode 100644 index abe664d3f..000000000 --- a/src/backend/messages/celery_app.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Messages celery configuration file.""" - -import os - -from celery import Celery -from configurations.importer import install - -# Set the default Django settings module for the 'celery' program. -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "messages.settings") -os.environ.setdefault("DJANGO_CONFIGURATION", "Development") - -install(check_options=True) - -# Must be imported after install() -from django.conf import settings # pylint: disable=wrong-import-position - -app = Celery("messages") - -# Using a string here means the worker doesn't have to serialize -# the configuration object to child processes. -# - namespace='CELERY' means all celery-related configuration keys -# should have a `CELERY_` prefix. -app.config_from_object("django.conf:settings", namespace="CELERY") - -# Load task modules from all registered Django apps. -app.autodiscover_tasks() - -# Configure beat schedule -# This can be disabled manually, for example when pushing the application for the first time -# to a PaaS service when no migration was applied yet. -if not settings.DISABLE_CELERY_BEAT_SCHEDULE: - app.conf.beat_schedule = { - "retry-pending-messages": { - "task": "core.mda.outbound_tasks.retry_messages_task", - "schedule": 300.0, # Every 5 minutes (300 seconds) - "options": {"queue": "outbound"}, - }, - "selfcheck": { - "task": "core.mda.outbound_tasks.selfcheck_task", - "schedule": settings.MESSAGES_SELFCHECK_INTERVAL, - "options": {"queue": "outbound"}, - }, - "process-inbound-messages-queue": { - "task": "core.mda.inbound_tasks.process_inbound_messages_queue_task", - "schedule": 300.0, # Every 5 minutes - "options": {"queue": "inbound"}, - }, - } diff --git a/src/backend/messages/settings.py b/src/backend/messages/settings.py old mode 100755 new mode 100644 index 66c9cad34..7f4d13582 --- a/src/backend/messages/settings.py +++ b/src/backend/messages/settings.py @@ -523,8 +523,9 @@ class Base(Configuration): "drf_spectacular", # Third party apps "corsheaders", - "django_celery_beat", - "django_celery_results", + "django_dramatiq", + "dramatiq_crontab", + "django_filters", "rest_framework", # Django @@ -607,47 +608,52 @@ class Base(Configuration): None, environ_name="FRONTEND_THEME", environ_prefix=None ) - # Celery - CELERY_BROKER_URL = values.Value( - "redis://redis:6379", environ_name="CELERY_BROKER_URL", environ_prefix=None - ) - CELERY_RESULT_BACKEND = "django-db" - CELERY_CACHE_BACKEND = "django-cache" - CELERY_BROKER_TRANSPORT_OPTIONS = values.DictValue({}) - CELERY_RESULT_EXTENDED = True - CELERY_TASK_RESULT_EXPIRES = 60 * 60 * 24 * 30 # 30 days - CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler" - CELERY_WORKER_HIJACK_ROOT_LOGGER = False - - # Default queue for tasks without explicit routing - CELERY_TASK_DEFAULT_QUEUE = "default" - - # Queue routing - tasks are routed to specific queues based on their type - # Priority order: management > inbound > outbound > default > imports > reindex - CELERY_TASK_ROUTES = { - # Inbound email processing - highest priority, time-sensitive - "core.mda.inbound_tasks.*": {"queue": "inbound"}, - # Outbound email sending - high priority - "core.mda.outbound_tasks.*": {"queue": "outbound"}, - # Import tasks - lower priority than regular tasks - "core.services.importer.mbox_tasks.*": {"queue": "imports"}, - "core.services.importer.eml_tasks.*": {"queue": "imports"}, - "core.services.importer.imap_tasks.*": {"queue": "imports"}, - "core.services.importer.pst_tasks.*": {"queue": "imports"}, - # Search indexing - lowest priority, can be delayed - "core.services.search.tasks.*": {"queue": "reindex"}, + # Dramatiq + DRAMATIQ_BROKER = { + "BROKER": "dramatiq.brokers.redis.RedisBroker", + "OPTIONS": { + "url": values.Value( + "redis://redis:6379", + environ_name="DRAMATIQ_BROKER_URL", + environ_prefix=None, + ), + }, + "MIDDLEWARE": [ + "dramatiq.middleware.prometheus.Prometheus", + "dramatiq.middleware.AgeLimit", + "dramatiq.middleware.TimeLimit", + "dramatiq.middleware.Callbacks", + "dramatiq.middleware.Retries", + "dramatiq.middleware.CurrentMessage", + "django_dramatiq.middleware.DbConnectionsMiddleware", + "django_dramatiq.middleware.AdminMiddleware", + ] } - DISABLE_CELERY_BEAT_SCHEDULE = values.BooleanValue( - default=False, environ_name="DISABLE_CELERY_BEAT_SCHEDULE", environ_prefix=None - ) + DRAMATIQ_RESULT_BACKEND = { + "BACKEND": "dramatiq.results.backends.redis.RedisBackend", + "BACKEND_OPTIONS": { + "url": values.Value( + "redis://redis:6379/1", + environ_name="DRAMATIQ_BROKER_URL", + environ_prefix=None, + ), + }, + "MIDDLEWARE_OPTIONS": { + "result_ttl": 1000 * 60 * 60 * 24 * 30 # 30 days + } + } - CELERY_WORKER_SEND_TASK_EVENTS = values.BooleanValue( - True, environ_name="CELERY_WORKER_SEND_TASK_EVENTS", environ_prefix=None - ) - CELERY_TASK_SEND_SENT_EVENT = values.BooleanValue( - True, environ_name="CELERY_TASK_SEND_SENT_EVENT", environ_prefix=None - ) + DRAMATIQ_CRONTAB = { + "REDIS_URL": values.Value( + "redis://redis:6379/0", + environ_name="DRAMATIQ_CRONTAB_REDIS_URL", + environ_prefix=None, + ) + } + + # Default ["tasks"] discovers core/tasks.py which re-exports all task modules + DRAMATIQ_AUTODISCOVER_MODULES = ["tasks"] # Session SESSION_ENGINE = "django.contrib.sessions.backends.cache" @@ -1134,7 +1140,19 @@ class DevelopmentMinimal(Development): Development environment settings with minimal dependencies """ - CELERY_TASK_ALWAYS_EAGER = True + DRAMATIQ_BROKER = { + "BROKER": "dramatiq.brokers.stub.StubBroker", + "OPTIONS": {}, + "MIDDLEWARE": [ + "dramatiq.middleware.AgeLimit", + "dramatiq.middleware.TimeLimit", + "dramatiq.middleware.Callbacks", + "dramatiq.middleware.Retries", + "dramatiq.middleware.CurrentMessage", + "django_dramatiq.middleware.DbConnectionsMiddleware", + "django_dramatiq.middleware.AdminMiddleware", + ] + } OPENSEARCH_INDEX_THREADS = False CACHES = { "default": {"BACKEND": "django.core.cache.backends.dummy.DummyCache"}, @@ -1152,7 +1170,19 @@ class Test(Base): IDENTITY_PROVIDER = None - CELERY_TASK_ALWAYS_EAGER = values.BooleanValue(True) + DRAMATIQ_BROKER = { + "BROKER": "core.utils.EagerBroker", + "OPTIONS": {}, + "MIDDLEWARE": [ + "dramatiq.middleware.AgeLimit", + "dramatiq.middleware.TimeLimit", + "dramatiq.middleware.Callbacks", + "dramatiq.middleware.Retries", + "dramatiq.middleware.CurrentMessage", + "django_dramatiq.middleware.DbConnectionsMiddleware", + "django_dramatiq.middleware.AdminMiddleware", + ] + } AWS_S3_DOMAIN_REPLACE = None diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index f95e4bd58..a44c7228d 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -27,12 +27,12 @@ requires-python = ">=3.13,<4.0" dependencies = [ "boto3==1.42.53", "botocore==1.42.53", - "celery[redis]==5.6.2", + "django-dramatiq==0.15.0", + "dramatiq[redis,prometheus]==2.0.1", + "dramatiq-crontab[sentry]==1.0.12", "cryptography==46.0.5", "dj-database-url==3.1.2", "django==5.2.11", - "django-celery-beat==2.8.1", - "django-celery-results==2.6.0", "django-configurations==2.5.1", "django-cors-headers==4.9.0", "django-countries==8.2.0", diff --git a/src/backend/uv.lock b/src/backend/uv.lock index 353670ffc..7451305f2 100644 --- a/src/backend/uv.lock +++ b/src/backend/uv.lock @@ -44,6 +44,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "apscheduler" +version = "3.11.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/12/3e4389e5920b4c1763390c6d371162f3784f86f85cd6d6c1bfe68eef14e2/apscheduler-3.11.2.tar.gz", hash = "sha256:2a9966b052ec805f020c8c4c3ae6e6a06e24b1bf19f2e11d91d8cca0473eef41", size = 108683, upload-time = "2025-12-22T00:39:34.884Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/64/2e54428beba8d9992aa478bb8f6de9e4ecaa5f8f513bcfd567ed7fb0262d/apscheduler-3.11.2-py3-none-any.whl", hash = "sha256:ce005177f741409db4e4dd40a7431b76feb856b9dd69d57e0da49d6715bfd26d", size = 64439, upload-time = "2025-12-22T00:39:33.303Z" }, +] + [[package]] name = "asgiref" version = "3.11.1" @@ -206,11 +218,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dd/bd/9ecd619e456ae4ba73b6583cc313f26152afae13e9a82ac4fe7f8856bfd1/celery-5.6.2-py3-none-any.whl", hash = "sha256:3ffafacbe056951b629c7abcf9064c4a2366de0bdfc9fdba421b97ebb68619a5", size = 445502, upload-time = "2026-01-04T12:35:55.894Z" }, ] -[package.optional-dependencies] -redis = [ - { name = "kombu", extra = ["redis"] }, -] - [[package]] name = "certifi" version = "2026.1.4" @@ -442,18 +449,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/4a/331fe2caf6799d591109bb9c08083080f6de90a823695d412a935622abb2/coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0", size = 211242, upload-time = "2026-02-09T12:59:02.032Z" }, ] -[[package]] -name = "cron-descriptor" -version = "2.0.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7c/31/0b21d1599656b2ffa6043e51ca01041cd1c0f6dacf5a3e2b620ed120e7d8/cron_descriptor-2.0.6.tar.gz", hash = "sha256:e39d2848e1d8913cfb6e3452e701b5eec662ee18bea8cc5aa53ee1a7bb217157", size = 49456, upload-time = "2025-09-03T16:30:22.434Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/21/cc/361326a54ad92e2e12845ad15e335a4e14b8953665007fb514d3393dfb0f/cron_descriptor-2.0.6-py3-none-any.whl", hash = "sha256:3a1c0d837c0e5a32e415f821b36cf758eb92d510e6beff8fbfe4fa16573d93d6", size = 74446, upload-time = "2025-09-03T16:30:21.397Z" }, -] - [[package]] name = "cryptography" version = "46.0.5" @@ -587,36 +582,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/a7/2b112ab430575bf3135b8304ac372248500d99c352f777485f53fdb9537e/django-5.2.11-py3-none-any.whl", hash = "sha256:e7130df33ada9ab5e5e929bc19346a20fe383f5454acb2cc004508f242ee92c0", size = 8291375, upload-time = "2026-02-03T13:52:42.47Z" }, ] -[[package]] -name = "django-celery-beat" -version = "2.8.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "celery" }, - { name = "cron-descriptor" }, - { name = "django" }, - { name = "django-timezone-field" }, - { name = "python-crontab" }, - { name = "tzdata" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/aa/11/0c8b412869b4fda72828572068312b10aafe7ccef7b41af3633af31f9d4b/django_celery_beat-2.8.1.tar.gz", hash = "sha256:dfad0201c0ac50c91a34700ef8fa0a10ee098cc7f3375fe5debed79f2204f80a", size = 175802, upload-time = "2025-05-13T06:58:29.246Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/e5/3a0167044773dee989b498e9a851fc1663bea9ab879f1179f7b8a827ac10/django_celery_beat-2.8.1-py3-none-any.whl", hash = "sha256:da2b1c6939495c05a551717509d6e3b79444e114a027f7b77bf3727c2a39d171", size = 104833, upload-time = "2025-05-13T06:58:27.309Z" }, -] - -[[package]] -name = "django-celery-results" -version = "2.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "celery" }, - { name = "django" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a6/b5/9966c28e31014c228305e09d48b19b35522a8f941fe5af5f81f40dc8fa80/django_celery_results-2.6.0.tar.gz", hash = "sha256:9abcd836ae6b61063779244d8887a88fe80bbfaba143df36d3cb07034671277c", size = 83985, upload-time = "2025-04-10T08:23:52.677Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/da/70f0f3c5364735344c4bc89e53413bcaae95b4fc1de4e98a7a3b9fb70c88/django_celery_results-2.6.0-py3-none-any.whl", hash = "sha256:b9ccdca2695b98c7cbbb8dea742311ba9a92773d71d7b4944a676e69a7df1c73", size = 38351, upload-time = "2025-04-10T08:23:49.965Z" }, -] - [[package]] name = "django-configurations" version = "2.5.1" @@ -655,6 +620,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/3c/9ebd7ed021b7c519bac954bc88146bc870e7d3c8db2580fa67268464fd2e/django_countries-8.2.0-py3-none-any.whl", hash = "sha256:2b2617bec7c15dc735bdec38ae89f0058e38fddfffdb19a7f6b75ef1e3d5380f", size = 3776079, upload-time = "2025-11-24T19:57:05.576Z" }, ] +[[package]] +name = "django-dramatiq" +version = "0.15.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "django" }, + { name = "dramatiq" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/b7/2a5ac56a5fd41fbe332221d3906168ace96df6a1f3c92d0839c4c7133cb2/django_dramatiq-0.15.0.tar.gz", hash = "sha256:e3cf1b2ac288fe4a7aa198c9450fe242ed312df8850f3f9e18ce01b8acc78b96", size = 15763, upload-time = "2025-11-13T15:53:40.115Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/3a/4d2e6f89ee0b9a79a271f0ead1a6c25f6cab291735493b083e316ab86130/django_dramatiq-0.15.0-py3-none-any.whl", hash = "sha256:23f0bc418a860952adbf822c4aa3b9c46c51d3d9f50be0a8ed3d19a53380df1d", size = 12620, upload-time = "2025-11-13T15:53:38.878Z" }, +] + [[package]] name = "django-extensions" version = "4.1" @@ -796,6 +774,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, ] +[[package]] +name = "dramatiq" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/bb/56b5d615c32ec8e136beee243efc54afa099b384c057e896d16268e35e08/dramatiq-2.0.1.tar.gz", hash = "sha256:3caa0587057eee67bd3a0e6d439d78d6cf88b300b5185dad1f4044a0c5f57fc2", size = 104165, upload-time = "2026-01-18T11:31:09.807Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/28/4bfc19a3b12177febcb3d28767933c823c056727872a8792a87d6f68df67/dramatiq-2.0.1-py3-none-any.whl", hash = "sha256:0cdfe5fdd1028adf65c6f3b2f0c5e6909053d6e41cf6556ff4def991d2419c89", size = 124391, upload-time = "2026-01-18T11:31:08.803Z" }, +] + +[package.optional-dependencies] +prometheus = [ + { name = "prometheus-client" }, +] +redis = [ + { name = "redis" }, +] + +[[package]] +name = "dramatiq-crontab" +version = "1.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apscheduler" }, + { name = "django" }, + { name = "dramatiq" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/27/13dd12312f7126f41d3c7787c8a07568a9026d5b546a2c237bb168a6bef2/dramatiq_crontab-1.0.12.tar.gz", hash = "sha256:6192289cb7fe16aa698cd6a78aa86e40ae4fb1fe043a55a028cfbf3c2299c8dd", size = 8157, upload-time = "2025-08-06T16:12:54.907Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/40/802f5afda7f75b6ed61bd88fc71f96f94670ab5efd172e78ced3f82b67a0/dramatiq_crontab-1.0.12-py3-none-any.whl", hash = "sha256:417d1da76fd423c02466db6e1091bb089cf76596331951447283fb9d6720ab8c", size = 8861, upload-time = "2025-08-06T16:12:53.352Z" }, +] + +[package.optional-dependencies] +sentry = [ + { name = "sentry-sdk" }, +] + [[package]] name = "drf-spectacular" version = "0.29.0" @@ -1149,11 +1163,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/0f/834427d8c03ff1d7e867d3db3d176470c64871753252b21b4f4897d1fa45/kombu-5.6.2-py3-none-any.whl", hash = "sha256:efcfc559da324d41d61ca311b0c64965ea35b4c55cc04ee36e55386145dace93", size = 214219, upload-time = "2025-12-29T20:30:05.74Z" }, ] -[package.optional-dependencies] -redis = [ - { name = "redis" }, -] - [[package]] name = "legacy-cgi" version = "2.6.4" @@ -1218,15 +1227,13 @@ source = { editable = "." } dependencies = [ { name = "boto3" }, { name = "botocore" }, - { name = "celery", extra = ["redis"] }, { name = "cryptography" }, { name = "dj-database-url" }, { name = "django" }, - { name = "django-celery-beat" }, - { name = "django-celery-results" }, { name = "django-configurations" }, { name = "django-cors-headers" }, { name = "django-countries" }, + { name = "django-dramatiq" }, { name = "django-fernet-encrypted-fields" }, { name = "django-filter" }, { name = "django-lasuite", extra = ["all"] }, @@ -1237,6 +1244,8 @@ dependencies = [ { name = "djangorestframework" }, { name = "dkimpy" }, { name = "dnspython" }, + { name = "dramatiq", extra = ["prometheus", "redis"] }, + { name = "dramatiq-crontab", extra = ["sentry"] }, { name = "drf-spectacular" }, { name = "factory-boy" }, { name = "flanker" }, @@ -1284,15 +1293,13 @@ dev = [ requires-dist = [ { name = "boto3", specifier = "==1.42.53" }, { name = "botocore", specifier = "==1.42.53" }, - { name = "celery", extras = ["redis"], specifier = "==5.6.2" }, { name = "cryptography", specifier = "==46.0.5" }, { name = "dj-database-url", specifier = "==3.1.2" }, { name = "django", specifier = "==5.2.11" }, - { name = "django-celery-beat", specifier = "==2.8.1" }, - { name = "django-celery-results", specifier = "==2.6.0" }, { name = "django-configurations", specifier = "==2.5.1" }, { name = "django-cors-headers", specifier = "==4.9.0" }, { name = "django-countries", specifier = "==8.2.0" }, + { name = "django-dramatiq", specifier = "==0.15.0" }, { name = "django-extensions", marker = "extra == 'dev'", specifier = "==4.1" }, { name = "django-fernet-encrypted-fields", specifier = "==0.3.1" }, { name = "django-filter", specifier = "==25.2" }, @@ -1304,6 +1311,8 @@ requires-dist = [ { name = "djangorestframework", specifier = "==3.16.1" }, { name = "dkimpy", specifier = "==1.1.8" }, { name = "dnspython", specifier = "==2.8.0" }, + { name = "dramatiq", extras = ["redis", "prometheus"], specifier = "==2.0.1" }, + { name = "dramatiq-crontab", extras = ["sentry"], specifier = "==1.0.12" }, { name = "drf-spectacular", specifier = "==0.29.0" }, { name = "drf-spectacular-sidecar", marker = "extra == 'dev'", specifier = "==2026.1.1" }, { name = "factory-boy", specifier = "==3.3.3" }, @@ -1871,15 +1880,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, ] -[[package]] -name = "python-crontab" -version = "3.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/99/7f/c54fb7e70b59844526aa4ae321e927a167678660ab51dda979955eafb89a/python_crontab-3.3.0.tar.gz", hash = "sha256:007c8aee68dddf3e04ec4dce0fac124b93bd68be7470fc95d2a9617a15de291b", size = 57626, upload-time = "2025-07-13T20:05:35.535Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/42/bb4afa5b088f64092036221843fc989b7db9d9d302494c1f8b024ee78a46/python_crontab-3.3.0-py3-none-any.whl", hash = "sha256:739a778b1a771379b75654e53fd4df58e5c63a9279a63b5dfe44c0fcc3ee7884", size = 27533, upload-time = "2025-07-13T20:05:34.266Z" }, -] - [[package]] name = "python-dateutil" version = "2.9.0.post0" diff --git a/src/backend/worker.py b/src/backend/worker.py index 71e6343dd..e9fb31e7d 100644 --- a/src/backend/worker.py +++ b/src/backend/worker.py @@ -7,7 +7,7 @@ python worker.py --queues=inbound,default # Process only specific queues python worker.py --exclude=reindex # Process all queues except reindex python worker.py --concurrency=4 # Set worker concurrency - python worker.py --disable-scheduler # Disable the scheduler + python worker.py -v 2 # Verbose logging Queue priority order (highest to lowest): 1. management - Admin/management tasks (migrations, cleanup) @@ -20,21 +20,29 @@ import argparse import logging +import multiprocessing import os import sys +# Workaround for Dramatiq + Python 3.14: forkserver (the new default) breaks +# Dramatiq's Canteen shared-memory mechanism, causing worker processes to never +# consume messages. See https://github.com/Bogdanp/dramatiq/issues/701 +# Must be set before dramatiq.cli.main() spawns worker processes. +multiprocessing.set_start_method("fork", force=True) + # Setup Django before importing the task runner os.environ.setdefault("DJANGO_SETTINGS_MODULE", "messages.settings") os.environ.setdefault("DJANGO_CONFIGURATION", "Development") -# Override $APP if set by the host (e.g. Scalingo), as Celery interprets it as the app module +# Override $APP if set by the host (e.g. Scalingo) os.environ.pop("APP", None) from configurations.importer import install # pylint: disable=wrong-import-position install(check_options=True) -from messages.celery_app import app # pylint: disable=wrong-import-position +import django # pylint: disable=wrong-import-position +django.setup() # Queue definitions in priority order ALL_QUEUES = ["management", "inbound", "outbound", "default", "imports", "reindex"] @@ -43,9 +51,7 @@ def get_default_concurrency(): """Get default concurrency from environment variables.""" - env_value = os.environ.get("WORKER_CONCURRENCY") or os.environ.get( - "CELERY_CONCURRENCY" - ) + env_value = os.environ.get("WORKER_CONCURRENCY") if env_value: try: return int(env_value) @@ -54,6 +60,33 @@ def get_default_concurrency(): return None +def discover_tasks_modules(): + """Discover task modules the same way django_dramatiq does.""" + from django.apps import apps # pylint: disable=wrong-import-position + from django.conf import settings # pylint: disable=wrong-import-position + from django.utils.module_loading import module_has_submodule # pylint: disable=wrong-import-position + import importlib # pylint: disable=wrong-import-position + + task_module_names = getattr(settings, "DRAMATIQ_AUTODISCOVER_MODULES", ("tasks",)) + modules = ["django_dramatiq.setup"] + + for conf in apps.get_app_configs(): + if conf.name == "django_dramatiq": + module = conf.name + ".tasks" + importlib.import_module(module) + logging.getLogger(__name__).info("Discovered tasks module: %r", module) + modules.append(module) + else: + for task_module in task_module_names: + if module_has_submodule(conf.module, task_module): + module = conf.name + "." + task_module + importlib.import_module(module) + logging.getLogger(__name__).info("Discovered tasks module: %r", module) + modules.append(module) + + return modules + + def parse_args(): """Parse command-line arguments.""" parser = argparse.ArgumentParser( @@ -80,19 +113,14 @@ def parse_args(): "-c", type=int, default=get_default_concurrency(), - help="Number of worker processes. Default: WORKER_CONCURRENCY env var or number of CPUs.", + help="Number of worker processes. Default: WORKER_CONCURRENCY env var.", ) parser.add_argument( - "--disable-scheduler", - action="store_true", - help="Disable the task scheduler (enabled by default).", - ) - parser.add_argument( - "--loglevel", - "-l", - type=str, - default="INFO", - help="Logging level. Default: INFO", + "--verbosity", + "-v", + type=int, + default=1, + help="Verbosity level (0=minimal, 1=normal, 2=verbose). Default: 1", ) return parser.parse_args() @@ -131,24 +159,31 @@ def main(): sys.stderr.write("Error: No queues to process after exclusions.\n") sys.exit(1) - # Build worker arguments - worker_args = [ - "worker", - f"--queues={','.join(queues)}", - f"--loglevel={args.loglevel}", + # Discover task modules (same as rundramatiq) + tasks_modules = discover_tasks_modules() + + # Build dramatiq CLI arguments and call main() directly. + # This avoids rundramatiq's os.execvp which replaces the process and + # discards our multiprocessing.set_start_method("fork") workaround. + dramatiq_args = [ + "dramatiq", + "--path", ".", + "--processes", str(args.concurrency or 4), + "--threads", "1", + "--worker-shutdown-timeout", "600000", ] - if args.concurrency: - worker_args.append(f"--concurrency={args.concurrency}") - - if not args.disable_scheduler: - worker_args.append("--beat") + if args.verbosity > 1: + dramatiq_args.append("-v") - # Always enable task events for monitoring - worker_args.append("--task-events") + dramatiq_args.extend(tasks_modules) + dramatiq_args.extend(["--queues", *queues]) logger.info("Starting worker with queues: %s", ", ".join(queues)) - app.worker_main(argv=worker_args) + + import dramatiq.cli # pylint: disable=wrong-import-position + sys.argv = dramatiq_args + dramatiq.cli.main() if __name__ == "__main__":