diff --git a/ami/jobs/models.py b/ami/jobs/models.py
index ac0078d76..636085145 100644
--- a/ami/jobs/models.py
+++ b/ami/jobs/models.py
@@ -3,6 +3,7 @@
 import random
 import time
 import typing
+from contextlib import contextmanager
 from dataclasses import dataclass
 
 import pydantic
@@ -24,6 +25,73 @@
 logger = logging.getLogger(__name__)
 
 
+# ==============================================================================
+# CONCURRENCY PROTECTION
+# ==============================================================================
+# This module uses row-level locking so that concurrent updates from multiple
+# Celery workers do not overwrite each other's changes.
+#
+# Key components:
+# - atomic_job_update(): Context manager that locks a job row for safe updates
+# - JobLogHandler.emit(): Uses locking when writing logs
+# - Job.save(): Automatically uses locking and validates logs to prevent overwrites
+#
+# When multiple workers (1-10) are processing the same job, they may all need
+# to update the logs, progress, and status fields. Without locking, last-write-wins
+# behavior causes lost updates; locking ensures every update is preserved.
+# ==============================================================================
+
+
+@contextmanager
+def atomic_job_update(job_id: int, timeout: int | None = None):
+    """
+    Context manager for safely updating job fields with row-level locking.
+
+    This ensures that concurrent updates to the same job (from multiple
+    Celery workers or tasks) don't overwrite each other's changes. The row
+    stays locked until the surrounding transaction ends; callers are
+    responsible for saving their changes before the context exits.
+
+    Args:
+        job_id: The ID of the job to lock and update
+        timeout: Optional timeout in seconds to wait for the lock.
+            If None (default), waits indefinitely.
+
+    Yields:
+        Job: The locked job instance, safe to modify
+
+    Example:
+        with atomic_job_update(job.pk) as locked_job:
+            locked_job.logs.stdout.insert(0, "New log message")
+            locked_job.progress.update_stage("process", progress=0.5)
+            locked_job.save(update_fields=["logs", "progress"])
+
+    Raises:
+        Job.DoesNotExist: If the job doesn't exist
+        DatabaseError: If the lock cannot be acquired within timeout
+    """
+    # Imported lazily: Job is defined further down in this module
+    from ami.jobs.models import Job
+
+    with transaction.atomic():
+        # Use select_for_update to lock the row.
+        # nowait=False and skip_locked=False mean we block until the lock is
+        # available rather than erroring out or skipping the row.
+        query = Job.objects.select_for_update(nowait=False, skip_locked=False)
+
+        if timeout is not None:
+            # Set a PostgreSQL statement timeout (in milliseconds) for this transaction
+            from django.db import connection
+
+            with connection.cursor() as cursor:
+                cursor.execute(f"SET LOCAL statement_timeout = {timeout * 1000}")
+
+        job = query.get(pk=job_id)
+        yield job
+        # Nothing is saved automatically on exit; the lock is released when
+        # the transaction commits, so call save() inside the block.
+
+
 class JobState(str, OrderedEnum):
     """
     These come from Celery, except for CREATED, which is a custom state.
@@ -255,9 +323,12 @@ class JobLogs(pydantic.BaseModel):
 class JobLogHandler(logging.Handler):
     """
     Class for handling logs from a job and writing them to the job instance.
+
+    Uses row-level locking to prevent concurrent log writes from overwriting
+    each other when multiple Celery workers are updating the same job.
""" - max_log_length = 1000 + max_log_length = 10000 # Allow ~100 messages per batch × hundreds of batches def __init__(self, job: "Job", *args, **kwargs): self.job = job @@ -267,23 +338,29 @@ def emit(self, record: logging.LogRecord): # Log to the current app logger logger.log(record.levelno, self.format(record)) - # Write to the logs field on the job instance + # Write to the logs field on the job instance with atomic locking timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") msg = f"[{timestamp}] {record.levelname} {self.format(record)}" - if msg not in self.job.logs.stdout: - self.job.logs.stdout.insert(0, msg) - # Write a simpler copy of any errors to the errors field - if record.levelno >= logging.ERROR: - if record.message not in self.job.logs.stderr: - self.job.logs.stderr.insert(0, record.message) + try: + # Use atomic update to prevent race conditions + with atomic_job_update(self.job.pk) as job: + if msg not in job.logs.stdout: + job.logs.stdout.insert(0, msg) - if len(self.job.logs.stdout) > self.max_log_length: - self.job.logs.stdout = self.job.logs.stdout[: self.max_log_length] + # Write a simpler copy of any errors to the errors field + if record.levelno >= logging.ERROR: + if record.message not in job.logs.stderr: + job.logs.stderr.insert(0, record.message) - # @TODO consider saving logs to the database periodically rather than on every log - try: - self.job.save(update_fields=["logs"], update_progress=False) + if len(job.logs.stdout) > self.max_log_length: + job.logs.stdout = job.logs.stdout[: self.max_log_length] + + if len(job.logs.stderr) > self.max_log_length: + job.logs.stderr = job.logs.stderr[: self.max_log_length] + + # Save with only the logs field to minimize lock time + job.save(update_fields=["logs"]) except Exception as e: logger.error(f"Failed to save logs for job #{self.job.pk}: {e}") pass @@ -325,7 +402,7 @@ def run(cls, job: "Job"): job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None - job.save() + job.save(update_fields=["status", "progress", "started_at", "finished_at"]) # Keep track of sub-tasks for saving results, pair with batch number save_tasks: list[tuple[int, AsyncResult]] = [] @@ -345,7 +422,8 @@ def run(cls, job: "Job"): progress=i / job.delay, mood="😵‍💫", ) - job.save() + # Only save progress to avoid overwriting logs + job.save(update_fields=["progress"]) last_update = time.time() job.progress.update_stage( @@ -354,7 +432,8 @@ def run(cls, job: "Job"): progress=1, mood="🥳", ) - job.save() + # Only save progress to avoid overwriting logs + job.save(update_fields=["progress"]) if not job.pipeline: raise ValueError("No pipeline specified to process images in ML job") @@ -398,8 +477,8 @@ def run(cls, job: "Job"): progress=1, ) - # End image collection stage - job.save() + # End image collection stage - only save progress to avoid overwriting logs + job.save(update_fields=["progress"]) total_captures = 0 total_detections = 0 @@ -461,7 +540,8 @@ def run(cls, job: "Job"): detections=total_detections, classifications=total_classifications, ) - job.save() + # Only save progress field to avoid overwriting logs from JobLogHandler + job.save(update_fields=["progress"]) # Stop processing if any save tasks have failed # Otherwise, calculate the percent of images that have failed to save @@ -495,7 +575,8 @@ def run(cls, job: "Job"): FAILURE_THRESHOLD = 0.5 if image_count and (percent_successful < FAILURE_THRESHOLD): job.progress.update_stage("process", status=JobState.FAILURE) - job.save() + # 
Only save progress to avoid overwriting logs + job.save(update_fields=["progress"]) raise Exception(f"Failed to process more than {int(FAILURE_THRESHOLD * 100)}% of images") job.progress.update_stage( @@ -510,7 +591,8 @@ def run(cls, job: "Job"): ) job.update_status(JobState.SUCCESS, save=False) job.finished_at = datetime.datetime.now() - job.save() + # Save all final fields at once, excluding logs + job.save(update_fields=["status", "progress", "finished_at"]) class DataStorageSyncJob(JobType): @@ -530,7 +612,8 @@ def run(cls, job: "Job"): job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None - job.save() + # Only save specific fields to avoid overwriting logs + job.save(update_fields=["progress", "status", "started_at", "finished_at"]) if not job.deployment: raise ValueError("No deployment provided for data storage sync job") @@ -542,7 +625,8 @@ def run(cls, job: "Job"): progress=0, total_files=0, ) - job.save() + # Only save progress to avoid overwriting logs + job.save(update_fields=["progress"]) job.deployment.sync_captures(job=job) @@ -553,10 +637,12 @@ def run(cls, job: "Job"): progress=1, ) job.update_status(JobState.SUCCESS) - job.save() + # Save status and progress to avoid overwriting logs + job.save(update_fields=["status", "progress"]) job.finished_at = datetime.datetime.now() - job.save() + # Only save finished_at to avoid overwriting logs + job.save(update_fields=["finished_at"]) class SourceImageCollectionPopulateJob(JobType): @@ -575,7 +661,8 @@ def run(cls, job: "Job"): job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None - job.save() + # Only save specific fields to avoid overwriting logs + job.save(update_fields=["progress", "status", "started_at", "finished_at"]) if not job.source_image_collection: raise ValueError("No source image collection provided") @@ -590,11 +677,13 @@ def run(cls, job: "Job"): progress=0.10, captures_added=0, ) - job.save() + # Only save progress to avoid overwriting logs + job.save(update_fields=["progress", "status", "started_at", "finished_at"]) job.source_image_collection.populate_sample(job=job) job.logger.info(f"Finished populating source image collection {job.source_image_collection}") - job.save() + # Only save progress to avoid overwriting logs + job.save(update_fields=["progress"]) captures_added = job.source_image_collection.images.count() job.logger.info(f"Added {captures_added} captures to source image collection {job.source_image_collection}") @@ -607,7 +696,8 @@ def run(cls, job: "Job"): ) job.finished_at = datetime.datetime.now() job.update_status(JobState.SUCCESS, save=False) - job.save() + # Save final fields to avoid overwriting logs + job.save(update_fields=["progress", "status", "finished_at"]) class DataExportJob(JobType): @@ -630,7 +720,8 @@ def run(cls, job: "Job"): job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None - job.save() + # Only save specific fields to avoid overwriting logs + job.save(update_fields=["progress", "status", "started_at", "finished_at"]) job.logger.info(f"Starting export for project {job.project}") @@ -643,7 +734,9 @@ def run(cls, job: "Job"): job.progress.add_stage_param(stage.key, "File URL", f"{file_url}") job.progress.update_stage(stage.key, status=JobState.SUCCESS, progress=1) job.finished_at = datetime.datetime.now() - job.update_status(JobState.SUCCESS, save=True) + job.update_status(JobState.SUCCESS, save=False) + # Save final fields to avoid 
overwriting logs + job.save(update_fields=["progress", "status", "finished_at"]) class PostProcessingJob(JobType): @@ -929,22 +1022,129 @@ def update_progress(self, save=True): self.progress.summary.progress = total_progress if save: - self.save(update_progress=False) + # Only update progress field to avoid concurrency issues + self.save(update_fields=["progress"]) + + def _validate_log_lengths(self, update_fields: list[str] | None) -> None: + """ + Validate that logs aren't getting shorter due to stale in-memory data. + + This is a safety check to catch bugs where concurrent updates might + overwrite logs. If logs would get shorter, automatically refreshes + them from the database to prevent data loss. + + Args: + update_fields: List of fields being updated, or None for all fields + + Note: + This can be easily disabled by commenting out the call in save() + if the validation overhead becomes an issue. However, the check + is very cheap (just a length comparison) and provides valuable + protection against data loss. + """ + if self.pk is None or not update_fields or "logs" in update_fields: + # Skip validation for new jobs, when updating logs explicitly, + # or when not using update_fields + return + + try: + # Get current log lengths from database + current_job = Job.objects.only("logs").get(pk=self.pk) + current_stdout_len = len(current_job.logs.stdout) + current_stderr_len = len(current_job.logs.stderr) + new_stdout_len = len(self.logs.stdout) + new_stderr_len = len(self.logs.stderr) + + # If logs would get shorter, it means we have stale in-memory data + if new_stdout_len < current_stdout_len or new_stderr_len < current_stderr_len: + logger.error( + f"CRITICAL: Job #{self.pk} attempted to save with stale logs! " + f"stdout: {current_stdout_len} -> {new_stdout_len}, " + f"stderr: {current_stderr_len} -> {new_stderr_len}. " + f"update_fields={update_fields}. This would lose log data!" + ) + # Refresh logs from database to prevent data loss + self.logs = current_job.logs + logger.warning(f"Refreshed logs for job #{self.pk} from database to prevent data loss") + except Job.DoesNotExist: + # Job might have been deleted, let it fail naturally + pass + except Exception as e: + # Don't let validation break the save, but log it + logger.warning(f"Failed to validate log lengths for job #{self.pk}: {e}") def duration(self) -> datetime.timedelta | None: if self.started_at and self.finished_at: return self.finished_at - self.started_at return None - def save(self, update_progress=True, *args, **kwargs): + def save(self, update_progress=True, use_locking=True, *args, **kwargs): """ Create the job stages if they don't exist. + + This method automatically uses row-level locking when updating existing jobs + to prevent concurrent workers from overwriting each other's changes. 
+ + Args: + update_progress: Whether to recalculate progress summary (default: True) + use_locking: Whether to use SELECT FOR UPDATE locking for existing jobs (default: True) + *args, **kwargs: Additional arguments passed to Django's save() + + Special handling: + - If 'update_fields' is specified, only those fields will be saved + - If 'update_fields' includes 'logs', locking is automatically disabled to avoid + conflicts with JobLogHandler which manages its own locks + - For new jobs (pk=None), locking is skipped """ + is_new = self.pk is None + update_fields = kwargs.get("update_fields") + + # Don't use locking if explicitly disabled or if this is a new job + # or if we're only updating logs (JobLogHandler manages its own locks) + should_lock = use_locking and not is_new and not (update_fields and update_fields == ["logs"]) + + # Update progress/setup before saving if self.pk and self.progress.stages and update_progress: self.update_progress(save=False) else: self.setup(save=False) - super().save(*args, **kwargs) + + # Safety check: Ensure logs never get shorter (unless explicitly updating logs) + # This can be disabled by commenting out this line if needed + self._validate_log_lengths(update_fields) + + if should_lock: + # Refresh from database with row-level lock to prevent concurrent overwrites + # This ensures we have the latest data before saving our changes + try: + with atomic_job_update(self.pk) as locked_job: + # Copy our in-memory changes to the locked instance + if update_fields: + # Only update specified fields + for field_name in update_fields: + setattr(locked_job, field_name, getattr(self, field_name)) + else: + # Update all non-log fields to preserve concurrent log writes + for field in self._meta.fields: + if field.name != "logs": # Never overwrite logs unless explicitly specified + setattr(locked_job, field.name, getattr(self, field.name)) + + # Save the locked instance with our changes + super(Job, locked_job).save(*args, **kwargs) + + # Update our instance with the saved state + self.pk = locked_job.pk + for field in self._meta.fields: + setattr(self, field.name, getattr(locked_job, field.name)) + except Exception as e: + logger.error(f"Failed to save job #{self.pk} with locking: {e}") + # Fall back to normal save - better to save without lock than to fail completely + logger.warning(f"Falling back to unlocked save for job #{self.pk}") + super().save(*args, **kwargs) + else: + # New job or locking disabled - use normal save + super().save(*args, **kwargs) + logger.debug(f"Saved job {self}") if self.progress.summary.status != self.status: logger.warning(f"Job {self} status mismatches progress: {self.progress.summary.status} != {self.status}") diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index b12271178..238d7e9a4 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -50,7 +50,7 @@ def update_job_status(sender, task_id, task, *args, **kwargs): task = AsyncResult(task_id) # I'm not sure if this is reliable job.update_status(task.status, save=False) - job.save() + job.save(update_fields=["status", "progress"]) @task_failure.connect(sender=run_job, retry=False) @@ -62,4 +62,93 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}') - job.save() + job.save(update_fields=["status", "progress"]) + + +@celery_app.task(soft_time_limit=300, time_limit=360) +def check_unfinished_jobs(): + """ + Periodic task to check the status of all unfinished jobs. 
+ + This task prevents duplicate execution using cache-based locking and + checks jobs that haven't been verified recently to ensure their Celery + tasks are still active and their statuses are accurate. + """ + import datetime + + from ami.jobs.models import Job, JobState + + # Configuration thresholds (TODO: make these configurable via settings) + LOCK_TIMEOUT_SECONDS = 300 # 5 minutes - how long the lock is held + MAX_JOBS_PER_RUN = 100 # Maximum number of jobs to check in one run + MIN_CHECK_INTERVAL_MINUTES = 2 # Minimum time between checks for the same job + + # Use cache-based locking to prevent duplicate checks + lock_id = "check_unfinished_jobs_lock" + + # Try to acquire lock + if not cache.add(lock_id, "locked", LOCK_TIMEOUT_SECONDS): + task_logger.info("check_unfinished_jobs is already running, skipping this execution") + return {"status": "skipped", "reason": "already_running"} + + try: + task_logger.info("Starting check_unfinished_jobs task") + + # Get all jobs that are not in final states + unfinished_jobs = Job.objects.filter(status__in=JobState.running_states()).order_by("scheduled_at") + + total_jobs = unfinished_jobs.count() + task_logger.info(f"Found {total_jobs} unfinished jobs to check") + + if total_jobs == 0: + return {"status": "success", "checked": 0, "updated": 0} + + # Avoid checking too many jobs at once + if total_jobs > MAX_JOBS_PER_RUN: + task_logger.warning(f"Limiting check to {MAX_JOBS_PER_RUN} jobs (out of {total_jobs})") + unfinished_jobs = unfinished_jobs[:MAX_JOBS_PER_RUN] + + # Only check jobs that haven't been checked recently + now = datetime.datetime.now() + min_check_interval = datetime.timedelta(minutes=MIN_CHECK_INTERVAL_MINUTES) + + jobs_to_check = [] + for job in unfinished_jobs: + if job.last_checked_at is None: + jobs_to_check.append(job) + else: + time_since_check = now - job.last_checked_at + if time_since_check >= min_check_interval: + jobs_to_check.append(job) + + task_logger.info(f"Checking {len(jobs_to_check)} jobs that need status verification") + + checked_count = 0 + updated_count = 0 + error_count = 0 + + for job in jobs_to_check: + try: + task_logger.debug(f"Checking job {job.pk}: {job.name} (status: {job.status})") + status_changed = job.check_status(force=False, save=True) + checked_count += 1 + if status_changed: + updated_count += 1 + task_logger.info(f"Updated job {job.pk} status to {job.status}") + except Exception as e: + error_count += 1 + task_logger.error(f"Error checking job {job.pk}: {e}", exc_info=True) + + result = { + "status": "success", + "total_unfinished": total_jobs, + "checked": checked_count, + "updated": updated_count, + "errors": error_count, + } + task_logger.info(f"Completed check_unfinished_jobs: {result}") + return result + + finally: + # Always release the lock + cache.delete(lock_id) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 64fbf23a2..85aee8ebb 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -1,7 +1,10 @@ # from rich import print +import datetime import logging +from unittest.mock import MagicMock, patch from django.test import TestCase +from django.utils import timezone from guardian.shortcuts import assign_perm from rest_framework.test import APIRequestFactory, APITestCase @@ -198,3 +201,529 @@ def test_cancel_job(self): # This cannot be tested until we have a way to cancel jobs # and a way to run async tasks in tests. pass + + +class TestJobStatusChecking(TestCase): + """ + Test the job status checking functionality. 
+ """ + + def setUp(self): + self.project = Project.objects.create(name="Status Check Test Project") + self.pipeline = Pipeline.objects.create( + name="Test ML pipeline", + description="Test ML pipeline", + ) + self.pipeline.projects.add(self.project) + self.source_image_collection = SourceImageCollection.objects.create( + name="Test collection", + project=self.project, + ) + + def test_check_status_no_task_id_recently_scheduled(self): + """Test that recently scheduled jobs without task_id are not marked as failed.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - no task_id", + scheduled_at=timezone.now(), + ) + + status_changed = job.check_status() + + self.assertFalse(status_changed) + self.assertEqual(job.status, JobState.CREATED.value) + self.assertIsNotNone(job.last_checked_at) + + def test_check_status_no_task_id_old_scheduled(self): + """Test that old scheduled jobs without task_id are marked as failed.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - stale no task_id", + scheduled_at=timezone.now() - datetime.timedelta(minutes=10), + ) + + status_changed = job.check_status() + + self.assertTrue(status_changed) + self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + self.assertIsNotNone(job.last_checked_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_with_matching_status(self, mock_async_result): + """Test that jobs with matching Celery status are not changed.""" + mock_task = MagicMock() + mock_task.status = JobState.STARTED.value + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - matching status", + task_id="test-task-id-123", + status=JobState.STARTED.value, + started_at=timezone.now(), + ) + + status_changed = job.check_status() + + self.assertFalse(status_changed) + self.assertEqual(job.status, JobState.STARTED.value) + self.assertIsNotNone(job.last_checked_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_with_mismatched_status(self, mock_async_result): + """Test that jobs with mismatched Celery status are updated.""" + mock_task = MagicMock() + mock_task.status = JobState.FAILURE.value + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - mismatched status", + task_id="test-task-id-456", + status=JobState.STARTED.value, + started_at=timezone.now(), + ) + + status_changed = job.check_status() + + self.assertTrue(status_changed) + self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + self.assertIsNotNone(job.last_checked_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_stale_running_job(self, mock_async_result): + """Test that jobs running for too long are marked as failed.""" + mock_task = MagicMock() + mock_task.status = JobState.STARTED.value + mock_async_result.return_value = mock_task + + # Create job that started longer than MAX_JOB_RUNTIME_SECONDS ago + stale_time = datetime.timedelta(seconds=Job.MAX_JOB_RUNTIME_SECONDS + 3600) # 1 hour past limit + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - stale running", + task_id="test-task-id-789", + status=JobState.STARTED.value, + started_at=timezone.now() - stale_time, + ) + + status_changed = job.check_status() + + self.assertTrue(status_changed) + 
self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + # Verify task was attempted to be revoked + mock_task.revoke.assert_called_once_with(terminate=True) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_stuck_pending(self, mock_async_result): + """Test that jobs stuck in PENDING for too long are marked as failed.""" + mock_task = MagicMock() + mock_task.status = JobState.PENDING.value + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - stuck pending", + task_id="test-task-id-pending", + status=JobState.PENDING.value, + scheduled_at=timezone.now() - datetime.timedelta(minutes=15), + ) + + status_changed = job.check_status() + + self.assertTrue(status_changed) + self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + + def test_check_status_does_not_check_completed_jobs(self): + """Test that completed jobs are not checked unless forced.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - completed", + task_id="test-task-id-completed", + status=JobState.SUCCESS.value, + finished_at=timezone.now(), + ) + + status_changed = job.check_status(force=False) + + self.assertFalse(status_changed) + self.assertEqual(job.status, JobState.SUCCESS.value) + self.assertIsNotNone(job.last_checked_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_forces_check_on_completed_jobs(self, mock_async_result): + """Test that force=True checks even completed jobs.""" + mock_task = MagicMock() + mock_task.status = JobState.SUCCESS.value + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - force check", + task_id="test-task-id-force", + status=JobState.SUCCESS.value, + finished_at=timezone.now(), + ) + + status_changed = job.check_status(force=True) + + # Status shouldn't change since Celery status matches + self.assertFalse(status_changed) + self.assertIsNotNone(job.last_checked_at) + + @patch("ami.jobs.tasks.cache") + @patch("ami.jobs.models.Job.objects") + def test_check_unfinished_jobs_with_lock(self, mock_job_objects, mock_cache): + """Test that check_unfinished_jobs uses locking to prevent duplicates.""" + from ami.jobs.tasks import check_unfinished_jobs + + # Simulate lock already acquired + mock_cache.add.return_value = False + + result = check_unfinished_jobs() + + self.assertEqual(result["status"], "skipped") + self.assertEqual(result["reason"], "already_running") + mock_cache.add.assert_called_once() + mock_cache.delete.assert_not_called() + + @patch("ami.jobs.tasks.cache") + def test_check_unfinished_jobs_processes_jobs(self, mock_cache): + """Test that check_unfinished_jobs processes unfinished jobs.""" + from ami.jobs.tasks import check_unfinished_jobs + + # Allow lock to be acquired + mock_cache.add.return_value = True + + # Create some unfinished jobs + job1 = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Unfinished job 1", + status=JobState.STARTED.value, + task_id="test-task-1", + started_at=timezone.now(), + last_checked_at=timezone.now() - datetime.timedelta(minutes=5), + ) + + job2 = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Unfinished job 2", + status=JobState.PENDING.value, + task_id="test-task-2", + scheduled_at=timezone.now() - datetime.timedelta(minutes=3), + ) + + # Create a completed 
job (should not be checked) + Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Completed job", + status=JobState.SUCCESS.value, + finished_at=timezone.now(), + ) + + with patch("ami.jobs.models.AsyncResult") as mock_async_result: + mock_task = MagicMock() + mock_task.status = JobState.STARTED.value + mock_async_result.return_value = mock_task + + result = check_unfinished_jobs() + + self.assertEqual(result["status"], "success") + self.assertEqual(result["total_unfinished"], 2) + self.assertGreaterEqual(result["checked"], 1) + + # Verify lock was released + mock_cache.delete.assert_called_once() + + # Verify jobs were checked + job1.refresh_from_db() + job2.refresh_from_db() + self.assertIsNotNone(job1.last_checked_at) + self.assertIsNotNone(job2.last_checked_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_task_disappeared_with_retry(self, mock_async_result): + """Test that jobs with disappeared tasks are retried if they just started.""" + mock_task = MagicMock() + mock_task.status = None # Task not found + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - task disappeared", + task_id="test-task-disappeared", + status=JobState.STARTED.value, + started_at=timezone.now() - datetime.timedelta(minutes=2), # Started 2 mins ago + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + # Mock the retry method + with patch.object(job, "retry") as mock_retry: + status_changed = job.check_status(auto_retry=True) + + # Should attempt retry + mock_retry.assert_called_once_with(async_task=True) + self.assertTrue(status_changed) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_task_disappeared_no_retry_old_job(self, mock_async_result): + """Test that old jobs with disappeared tasks are marked failed, not retried.""" + mock_task = MagicMock() + mock_task.status = None # Task not found + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - old disappeared task", + task_id="test-task-old-disappeared", + status=JobState.STARTED.value, + started_at=timezone.now() - datetime.timedelta(minutes=10), # Started 10 mins ago + ) + + status_changed = job.check_status(auto_retry=True) + + # Should not retry, just mark as failed + self.assertTrue(status_changed) + self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_task_disappeared_auto_retry_disabled(self, mock_async_result): + """Test that auto_retry=False prevents automatic retry.""" + mock_task = MagicMock() + mock_task.status = None # Task not found + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - no auto retry", + task_id="test-task-no-retry", + status=JobState.STARTED.value, + started_at=timezone.now() - datetime.timedelta(minutes=2), + ) + + status_changed = job.check_status(auto_retry=False) + + # Should not retry, just mark as failed + self.assertTrue(status_changed) + self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + + @patch("ami.jobs.models.AsyncResult") + def test_check_status_task_pending_but_job_running(self, mock_async_result): + """Test that PENDING status from Celery when job thinks it's running indicates disappeared task.""" + 
mock_task = MagicMock() + mock_task.status = "PENDING" # Celery returns PENDING for unknown tasks + mock_async_result.return_value = mock_task + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - pending but should be running", + task_id="test-task-fake-pending", + status=JobState.STARTED.value, + started_at=timezone.now() - datetime.timedelta(minutes=2), + ) + + status_changed = job.check_status(auto_retry=False) + + # Should detect this as a disappeared task + self.assertTrue(status_changed) + self.assertEqual(job.status, JobState.FAILURE.value) + + +class TestJobConcurrency(TestCase): + """Test concurrent updates to jobs from multiple workers.""" + + def setUp(self): + self.project = Project.objects.create(name="Test project") + self.pipeline = Pipeline.objects.create( + name="Test ML pipeline", + description="Test ML pipeline", + ) + self.pipeline.projects.add(self.project) + + def test_atomic_job_update_context_manager(self): + """Test that atomic_job_update locks and updates the job.""" + from ami.jobs.models import atomic_job_update + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - atomic update", + pipeline=self.pipeline, + ) + + # Use the context manager to update the job + with atomic_job_update(job.pk) as locked_job: + locked_job.logs.stdout.insert(0, "Test log message") + locked_job.save(update_fields=["logs"], update_progress=False) + + # Refresh from DB and verify the update persisted + job.refresh_from_db() + self.assertIn("Test log message", job.logs.stdout) + + def test_concurrent_log_writes(self): + """Test that concurrent log writes don't overwrite each other.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - concurrent logs", + pipeline=self.pipeline, + ) + + # Simulate multiple workers adding logs + messages = [f"Log message {i}" for i in range(5)] + + for msg in messages: + # Use the logger which uses JobLogHandler with atomic updates + job.logger.info(msg) + + # Refresh from DB + job.refresh_from_db() + + # All messages should be present (no overwrites) + for msg in messages: + # Messages are formatted with timestamps and log levels + self.assertTrue(any(msg in log for log in job.logs.stdout), f"Message '{msg}' not found in logs") + + def test_log_handler_with_atomic_update(self): + """Test that JobLogHandler properly uses atomic updates.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - log handler", + pipeline=self.pipeline, + ) + + # Get the logger (which adds JobLogHandler) + job_logger = job.logger + + # Add multiple log messages + job_logger.info("Info message") + job_logger.warning("Warning message") + job_logger.error("Error message") + + # Refresh from DB + job.refresh_from_db() + + # Verify all logs are present + self.assertTrue(any("Info message" in log for log in job.logs.stdout)) + self.assertTrue(any("Warning message" in log for log in job.logs.stdout)) + self.assertTrue(any("Error message" in log for log in job.logs.stdout)) + + # Verify error also appears in stderr + self.assertTrue(any("Error message" in err for err in job.logs.stderr)) + + def test_max_log_length_enforcement(self): + """Test that log length limits are enforced with atomic updates.""" + import logging + + from ami.jobs.models import JobLogHandler + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - max logs", + pipeline=self.pipeline, + ) + + # 
Temporarily suppress log output to avoid spamming test results + job_logger = job.logger + original_level = job_logger.level + job_logger.setLevel(logging.CRITICAL) + + try: + # Add more logs than the max + max_logs = JobLogHandler.max_log_length + for i in range(max_logs + 10): + job.logger.info(f"Message {i}") + + # Refresh from DB + job.refresh_from_db() + + # Should not exceed max length + self.assertLessEqual(len(job.logs.stdout), max_logs) + self.assertLessEqual(len(job.logs.stderr), max_logs) + finally: + # Restore original log level + job_logger.setLevel(original_level) + + def test_log_length_never_decreases(self): + """Test that the save method prevents logs from getting shorter.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - log safety", + pipeline=self.pipeline, + ) + + # Add some logs + job.logger.info("Log message 1") + job.logger.info("Log message 2") + job.logger.info("Log message 3") + + job.refresh_from_db() + initial_log_count = len(job.logs.stdout) + self.assertGreaterEqual(initial_log_count, 3) + + # Simulate stale in-memory job with fewer logs (like what happens with concurrent workers) + stale_job = Job.objects.get(pk=job.pk) + stale_job.logs.stdout = stale_job.logs.stdout[:1] # Artificially reduce logs to just 1 + + # Try to save with update_fields that doesn't include logs + # The safety check should prevent logs from being overwritten + stale_job.status = JobState.STARTED + stale_job.save(update_fields=["status", "progress"]) + + # Verify logs weren't reduced + stale_job.refresh_from_db() + final_log_count = len(stale_job.logs.stdout) + self.assertEqual( + final_log_count, + initial_log_count, + "Logs should never decrease in length when not explicitly updating logs", + ) + + def test_log_can_be_explicitly_updated(self): + """Test that logs CAN be updated when explicitly included in update_fields.""" + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job - explicit log update", + pipeline=self.pipeline, + ) + + # Add initial logs + job.logger.info("Log message 1") + job.logger.info("Log message 2") + + job.refresh_from_db() + + # Explicitly update logs (like JobLogHandler does) + job.logs.stdout = ["New log only"] + job.save(update_fields=["logs"]) + + # Verify logs were updated as requested + job.refresh_from_db() + self.assertEqual(len(job.logs.stdout), 1) + self.assertEqual(job.logs.stdout[0], "New log only")
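
For reference, the calling pattern for atomic_job_update() mirrors what JobLogHandler.emit() does in the diff above: mutate the locked instance and save it explicitly before the context exits, since the context manager only takes the row lock and does not save anything itself. A minimal sketch (append_worker_log is an illustrative helper, not part of this change; the timeout relies on PostgreSQL's statement_timeout, as in the context manager above):

from ami.jobs.models import atomic_job_update


def append_worker_log(job_id: int, message: str) -> None:
    # Hold the row lock only for as long as it takes to mutate and save the
    # logs field; other workers block on select_for_update() until we commit.
    with atomic_job_update(job_id, timeout=30) as job:
        job.logs.stdout.insert(0, message)
        # A logs-only save skips Job.save()'s own locking path, which is fine
        # here because this block already holds the row lock.
        job.save(update_fields=["logs"])

Keeping the critical section this small is what keeps lock contention low when several workers report on the same job.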
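
The many job.save(update_fields=[...]) changes above all rely on the same Django behavior: when update_fields is given, the UPDATE statement only touches the named columns, so a worker that only changed progress cannot clobber logs written concurrently by another worker. A short illustration (the job id is placeholder data):

from ami.jobs.models import Job, JobState

job = Job.objects.get(pk=123)  # placeholder id
job.update_status(JobState.STARTED, save=False)
job.progress.update_stage("process", progress=0.75)
# Issues roughly: UPDATE ... SET status = %s, progress = %s WHERE id = 123
# The logs column is left untouched, whatever its current value in the database.
job.save(update_fields=["status", "progress"])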
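
check_unfinished_jobs() guards against overlapping runs with a cache-based mutex: cache.add() only stores the key if it does not already exist, so it doubles as an atomic "try-lock" on shared backends such as Redis or Memcached (with a local-memory cache the guard is per-process only). The pattern in isolation — run_exclusively is an illustrative wrapper, while the lock key and timeout come from the task above:

from django.core.cache import cache

LOCK_ID = "check_unfinished_jobs_lock"
LOCK_TIMEOUT_SECONDS = 300  # the lock expires on its own if the worker dies mid-run


def run_exclusively(work) -> bool:
    # cache.add() returns False when the key already exists, i.e. another
    # worker holds the lock, so only one run proceeds at a time.
    if not cache.add(LOCK_ID, "locked", LOCK_TIMEOUT_SECONDS):
        return False
    try:
        work()
        return True
    finally:
        # Release in a finally block so a failed run does not block the next
        # scheduled run until the timeout expires.
        cache.delete(LOCK_ID)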