diff --git a/ami/main/api/views.py b/ami/main/api/views.py index ac6df634a..4d68dcba3 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -1475,7 +1475,11 @@ def get_queryset(self) -> QuerySet: qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved) if self.action == "retrieve": qs = self.get_taxa_observed( - qs, project, include_unobserved=include_unobserved, apply_default_filters=False + qs, + project, + include_unobserved=include_unobserved, + apply_default_score_filter=True, + apply_default_taxa_filter=False, ) qs = qs.prefetch_related( Prefetch( @@ -1493,7 +1497,12 @@ def get_queryset(self) -> QuerySet: return qs def get_taxa_observed( - self, qs: QuerySet, project: Project, include_unobserved=False, apply_default_filters=True + self, + qs: QuerySet, + project: Project, + include_unobserved=False, + apply_default_score_filter=True, + apply_default_taxa_filter=True, ) -> QuerySet: """ If a project is passed, only return taxa that have been observed. @@ -1511,15 +1520,21 @@ def get_taxa_observed( # Respects apply_defaults flag: build_occurrence_default_filters_q checks it internally from ami.main.models_future.filters import build_occurrence_default_filters_q - default_filters_q = build_occurrence_default_filters_q(project, self.request, occurrence_accessor="") + default_filters_q = build_occurrence_default_filters_q( + project, + self.request, + occurrence_accessor="", + apply_default_score_filter=apply_default_score_filter, + apply_default_taxa_filter=apply_default_taxa_filter, + ) # Combine base occurrence filters with default filters base_filter = models.Q( occurrence_filters, determination_id=models.OuterRef("id"), ) - if apply_default_filters: - base_filter = base_filter & default_filters_q + + base_filter = base_filter & default_filters_q # Count occurrences - uses composite index (determination_id, project_id, event_id, determination_score) occurrences_count_subquery = models.Subquery( diff --git a/ami/main/models.py b/ami/main/models.py index f672c2832..ad4644eaf 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -150,6 +150,17 @@ def get_or_create_default_collection(project: "Project") -> "SourceImageCollecti return collection +def get_project_default_filters(): + """ + Read default taxa names from Django settings (read from environment variables) + and return corresponding Taxon objects. + """ + include_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_INCLUDE_TAXA)) + exclude_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_EXCLUDE_TAXA)) + + return {"default_include_taxa": include_taxa, "default_exclude_taxa": exclude_taxa} + + def get_or_create_default_project(user: User) -> "Project": """ Create a default project for a user. @@ -158,8 +169,20 @@ def get_or_create_default_project(user: User) -> "Project": when the project is saved for the first time. If the project already exists, it will be returned without modification. """ - project, _created = Project.objects.get_or_create(name="Scratch Project", owner=user, create_defaults=True) - logger.info(f"Created default project for user {user}") + project, created = Project.objects.get_or_create(name="Scratch Project", owner=user) + if created: + logger.info(f"Created default project for user {user}") + defaults = get_project_default_filters() + + if defaults["default_include_taxa"]: + project.default_filters_include_taxa.set(defaults["default_include_taxa"]) + logger.info(f"Set {len(defaults['default_include_taxa'])} default include taxa for project {project}") + if defaults["default_exclude_taxa"]: + project.default_filters_exclude_taxa.set(defaults["default_exclude_taxa"]) + logger.info(f"Set {len(defaults['default_exclude_taxa'])} default exclude taxa for project {project}") + project.save() + else: + logger.info(f"Loaded existing default project for user {user}") return project @@ -678,6 +701,18 @@ def get_first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.dat ) return (first, last) + def get_detections_count(self) -> int | None: + """Return detections count filtered by project default filters""" + + qs = Detection.objects.filter(source_image__deployment=self) + filter_q = build_occurrence_default_filters_q( + project=self.project, + request=None, + occurrence_accessor="occurrence", + ) + + return qs.filter(filter_q).distinct().count() + def first_date(self) -> datetime.date | None: return self.first_capture_timestamp.date() if self.first_capture_timestamp else None @@ -883,7 +918,7 @@ def update_calculated_fields(self, save=False): self.events_count = self.events.count() self.captures_count = self.data_source_total_files or self.captures.count() - self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count() + self.detections_count = self.get_detections_count() occ_qs = self.occurrences.filter(event__isnull=False).apply_default_filters( # type: ignore project=self.project, request=None, @@ -1048,7 +1083,15 @@ def get_captures_count(self) -> int: return self.captures.distinct().count() def get_detections_count(self) -> int | None: - return Detection.objects.filter(Q(source_image__event=self)).count() + """Return detections count filtered by project default filters""" + qs = Detection.objects.filter(source_image__event=self) + filter_q = build_occurrence_default_filters_q( + project=self.project, + request=None, + occurrence_accessor="occurrence", + ) + + return qs.filter(filter_q).distinct().count() def get_occurrences_count(self, classification_threshold: float = 0) -> int: """ @@ -1753,7 +1796,20 @@ def size_display(self) -> str: return filesizeformat(self.size) def get_detections_count(self) -> int: - return self.detections.distinct().count() + """ + Return detections count filtered by project default filters. + """ + project = self.project + if not project: + return self.detections.distinct().count() + + q = build_occurrence_default_filters_q( + project=project, + request=None, + occurrence_accessor="occurrence", + ) + + return self.detections.filter(q).distinct().count() def get_base_url(self) -> str | None: """ diff --git a/ami/main/models_future/filters.py b/ami/main/models_future/filters.py index 6689065c2..8b8782dce 100644 --- a/ami/main/models_future/filters.py +++ b/ami/main/models_future/filters.py @@ -129,6 +129,8 @@ def build_occurrence_default_filters_q( project: "Project | None" = None, request: "Request | None" = None, occurrence_accessor: str = "", + apply_default_score_filter: bool = True, + apply_default_taxa_filter: bool = True, ) -> Q: """ Build a Q filter that applies default filters (score threshold + taxa) for Occurrence relationships. @@ -194,19 +196,19 @@ def build_occurrence_default_filters_q( return Q() filter_q = Q() - - # Build score threshold filter - score_threshold = get_default_classification_threshold(project, request) - filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor) - - # Build taxa inclusion/exclusion filter - # For taxa filtering, we need to append "__determination" to the occurrence accessor - prefix = f"{occurrence_accessor}__" if occurrence_accessor else "" - taxon_accessor = f"{prefix}determination" - include_taxa = project.default_filters_include_taxa.all() - exclude_taxa = project.default_filters_exclude_taxa.all() - taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor) - if taxa_q: - filter_q &= taxa_q + if apply_default_score_filter: + # Build score threshold filter + score_threshold = get_default_classification_threshold(project, request) + filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor) + if apply_default_taxa_filter: + # Build taxa inclusion/exclusion filter + # For taxa filtering, we need to append "__determination" to the occurrence accessor + prefix = f"{occurrence_accessor}__" if occurrence_accessor else "" + taxon_accessor = f"{prefix}determination" + include_taxa = project.default_filters_include_taxa.all() + exclude_taxa = project.default_filters_exclude_taxa.all() + taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor) + if taxa_q: + filter_q &= taxa_q return filter_q diff --git a/ami/main/signals.py b/ami/main/signals.py index a81ee13b0..f99620cf7 100644 --- a/ami/main/signals.py +++ b/ami/main/signals.py @@ -5,9 +5,10 @@ from django.dispatch import receiver from guardian.shortcuts import assign_perm +from ami.main.models import Project from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project -from .models import Project, User +from .models import User logger = logging.getLogger(__name__) @@ -110,3 +111,81 @@ def delete_project_groups(sender, instance, **kwargs): prefix = f"{instance.pk}_" # Find and delete all groups that start with {project_id}_ Group.objects.filter(name__startswith=prefix).delete() + + +# ============================================================================ +# Project Default Filters Update Signals +# ============================================================================ +# These signals handle efficient updates to calculated fields for project-related +# objects (such as Deployments and Events) whenever a project's default filter +# values change. +# +# Specifically, they trigger recalculation of cached counts when: +# - The project's default score threshold is updated +# - The project's default include taxa are modified +# - The project's default exclude taxa are modified +# +# This ensures that cached counts (e.g., occurrences_count, taxa_count) remain +# accurate and consistent with the active filter configuration for each project. +# ============================================================================ + + +def refresh_cached_counts_for_project(project: Project): + """ + Refresh cached counts for Deployments and Events belonging to a project. + """ + logger.info(f"Refreshing cached counts for project {project.pk} ({project.name})") + project.update_related_calculated_fields() + + +@receiver(pre_save, sender=Project) +def cache_old_threshold(sender, instance, **kwargs): + """ + Cache the previous default score threshold before saving the Project. + + We do this because: + - In post_save, the instance already contains the NEW value. + - To detect whether the threshold actually changed, we must read the OLD + value from the database before the update happens. + - This allows us to accurately detect threshold changes and then trigger + recalculation of cached filtered counts (Events, Deployments, etc.). + + The cached value is stored on the instance as `_old_threshold` so it can be + safely accessed in the post_save handler. + """ + if instance.pk: + instance._old_threshold = Project.objects.get(pk=instance.pk).default_filters_score_threshold + else: + instance._old_threshold = None + + +@receiver(post_save, sender=Project) +def threshold_updated(sender, instance, **kwargs): + """ + After saving the Project, compare the previously cached threshold with the new value. + If the default score threshold changed, we refresh all cached counts using the new filters. + + This two-step (pre_save + post_save) pattern is required because: + - post_save instances already contain the updated value + - so the old threshold would be lost without caching it in pre_save + """ + old_threshold = instance._old_threshold + new_threshold = instance.default_filters_score_threshold + if old_threshold is not None and old_threshold != new_threshold: + refresh_cached_counts_for_project(instance) + + +@receiver(m2m_changed, sender=Project.default_filters_include_taxa.through) +def include_taxa_updated(sender, instance: Project, action, **kwargs): + """Refresh cached counts when include taxa are modified.""" + if action in ["post_add", "post_remove", "post_clear"]: + logger.info(f"Include taxa updated for project {instance.pk} (action={action})") + refresh_cached_counts_for_project(instance) + + +@receiver(m2m_changed, sender=Project.default_filters_exclude_taxa.through) +def exclude_taxa_updated(sender, instance: Project, action, **kwargs): + """Refresh cached counts when exclude taxa are modified.""" + if action in ["post_add", "post_remove", "post_clear"]: + logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})") + refresh_cached_counts_for_project(instance) diff --git a/config/settings/base.py b/config/settings/base.py index 03124d41a..4740ca7e7 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -448,3 +448,6 @@ "DEFAULT_PROCESSING_SERVICE_ENDPOINT", default=None # type: ignore[no-untyped-call] ) DEFAULT_PIPELINES_ENABLED = env.list("DEFAULT_PIPELINES_ENABLED", default=None) # type: ignore[no-untyped-call] +# Default taxa filters +DEFAULT_INCLUDE_TAXA = env.list("DEFAULT_INCLUDE_TAXA", default=[]) # type: ignore[no-untyped-call] +DEFAULT_EXCLUDE_TAXA = env.list("DEFAULT_EXCLUDE_TAXA", default=[]) # type: ignore[no-untyped-call]