Skip to content
Open
25 changes: 20 additions & 5 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,11 @@ def get_queryset(self) -> QuerySet:
qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)
if self.action == "retrieve":
qs = self.get_taxa_observed(
qs, project, include_unobserved=include_unobserved, apply_default_filters=False
qs,
project,
include_unobserved=include_unobserved,
apply_default_score_filter=True,
apply_default_taxa_filter=False,
)
qs = qs.prefetch_related(
Prefetch(
Expand All @@ -1493,7 +1497,12 @@ def get_queryset(self) -> QuerySet:
return qs

def get_taxa_observed(
self, qs: QuerySet, project: Project, include_unobserved=False, apply_default_filters=True
self,
qs: QuerySet,
project: Project,
include_unobserved=False,
apply_default_score_filter=True,
apply_default_taxa_filter=True,
) -> QuerySet:
"""
If a project is passed, only return taxa that have been observed.
Expand All @@ -1511,15 +1520,21 @@ def get_taxa_observed(
# Respects apply_defaults flag: build_occurrence_default_filters_q checks it internally
from ami.main.models_future.filters import build_occurrence_default_filters_q

default_filters_q = build_occurrence_default_filters_q(project, self.request, occurrence_accessor="")
default_filters_q = build_occurrence_default_filters_q(
project,
self.request,
occurrence_accessor="",
apply_default_score_filter=apply_default_score_filter,
apply_default_taxa_filter=apply_default_taxa_filter,
)

# Combine base occurrence filters with default filters
base_filter = models.Q(
occurrence_filters,
determination_id=models.OuterRef("id"),
)
if apply_default_filters:
base_filter = base_filter & default_filters_q

base_filter = base_filter & default_filters_q

# Count occurrences - uses composite index (determination_id, project_id, event_id, determination_score)
occurrences_count_subquery = models.Subquery(
Expand Down
66 changes: 61 additions & 5 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def get_or_create_default_collection(project: "Project") -> "SourceImageCollecti
return collection


def get_project_default_filters():
"""
Read default taxa names from Django settings (read from environment variables)
and return corresponding Taxon objects.
"""
include_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_INCLUDE_TAXA))
exclude_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_EXCLUDE_TAXA))

return {"default_include_taxa": include_taxa, "default_exclude_taxa": exclude_taxa}


def get_or_create_default_project(user: User) -> "Project":
"""
Create a default project for a user.
Expand All @@ -158,8 +169,20 @@ def get_or_create_default_project(user: User) -> "Project":
when the project is saved for the first time.
If the project already exists, it will be returned without modification.
"""
project, _created = Project.objects.get_or_create(name="Scratch Project", owner=user, create_defaults=True)
logger.info(f"Created default project for user {user}")
project, created = Project.objects.get_or_create(name="Scratch Project", owner=user)
if created:
logger.info(f"Created default project for user {user}")
defaults = get_project_default_filters()

if defaults["default_include_taxa"]:
project.default_filters_include_taxa.set(defaults["default_include_taxa"])
logger.info(f"Set {len(defaults['default_include_taxa'])} default include taxa for project {project}")
if defaults["default_exclude_taxa"]:
project.default_filters_exclude_taxa.set(defaults["default_exclude_taxa"])
logger.info(f"Set {len(defaults['default_exclude_taxa'])} default exclude taxa for project {project}")
project.save()
else:
logger.info(f"Loaded existing default project for user {user}")
return project


Expand Down Expand Up @@ -678,6 +701,18 @@ def get_first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.dat
)
return (first, last)

def get_detections_count(self) -> int | None:
"""Return detections count filtered by project default filters"""

qs = Detection.objects.filter(source_image__deployment=self)
filter_q = build_occurrence_default_filters_q(
project=self.project,
request=None,
occurrence_accessor="occurrence",
)

return qs.filter(filter_q).distinct().count()

def first_date(self) -> datetime.date | None:
return self.first_capture_timestamp.date() if self.first_capture_timestamp else None

Expand Down Expand Up @@ -883,7 +918,7 @@ def update_calculated_fields(self, save=False):

self.events_count = self.events.count()
self.captures_count = self.data_source_total_files or self.captures.count()
self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count()
self.detections_count = self.get_detections_count()
occ_qs = self.occurrences.filter(event__isnull=False).apply_default_filters( # type: ignore
project=self.project,
request=None,
Expand Down Expand Up @@ -1048,7 +1083,15 @@ def get_captures_count(self) -> int:
return self.captures.distinct().count()

def get_detections_count(self) -> int | None:
return Detection.objects.filter(Q(source_image__event=self)).count()
"""Return detections count filtered by project default filters"""
qs = Detection.objects.filter(source_image__event=self)
filter_q = build_occurrence_default_filters_q(
project=self.project,
request=None,
occurrence_accessor="occurrence",
)

return qs.filter(filter_q).distinct().count()

def get_occurrences_count(self, classification_threshold: float = 0) -> int:
"""
Expand Down Expand Up @@ -1753,7 +1796,20 @@ def size_display(self) -> str:
return filesizeformat(self.size)

def get_detections_count(self) -> int:
return self.detections.distinct().count()
"""
Return detections count filtered by project default filters.
"""
project = self.project
if not project:
return self.detections.distinct().count()

q = build_occurrence_default_filters_q(
project=project,
request=None,
occurrence_accessor="occurrence",
)

return self.detections.filter(q).distinct().count()

def get_base_url(self) -> str | None:
"""
Expand Down
30 changes: 16 additions & 14 deletions ami/main/models_future/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def build_occurrence_default_filters_q(
project: "Project | None" = None,
request: "Request | None" = None,
occurrence_accessor: str = "",
apply_default_score_filter: bool = True,
apply_default_taxa_filter: bool = True,
) -> Q:
"""
Build a Q filter that applies default filters (score threshold + taxa) for Occurrence relationships.
Expand Down Expand Up @@ -194,19 +196,19 @@ def build_occurrence_default_filters_q(
return Q()

filter_q = Q()

# Build score threshold filter
score_threshold = get_default_classification_threshold(project, request)
filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)

# Build taxa inclusion/exclusion filter
# For taxa filtering, we need to append "__determination" to the occurrence accessor
prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
taxon_accessor = f"{prefix}determination"
include_taxa = project.default_filters_include_taxa.all()
exclude_taxa = project.default_filters_exclude_taxa.all()
taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
if taxa_q:
filter_q &= taxa_q
if apply_default_score_filter:
# Build score threshold filter
score_threshold = get_default_classification_threshold(project, request)
filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)
if apply_default_taxa_filter:
# Build taxa inclusion/exclusion filter
# For taxa filtering, we need to append "__determination" to the occurrence accessor
prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
taxon_accessor = f"{prefix}determination"
include_taxa = project.default_filters_include_taxa.all()
exclude_taxa = project.default_filters_exclude_taxa.all()
taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
if taxa_q:
filter_q &= taxa_q

return filter_q
81 changes: 80 additions & 1 deletion ami/main/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from django.dispatch import receiver
from guardian.shortcuts import assign_perm

from ami.main.models import Project
from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project

from .models import Project, User
from .models import User

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,3 +111,81 @@ def delete_project_groups(sender, instance, **kwargs):
prefix = f"{instance.pk}_"
# Find and delete all groups that start with {project_id}_
Group.objects.filter(name__startswith=prefix).delete()


# ============================================================================
# Project Default Filters Update Signals
# ============================================================================
# These signals handle efficient updates to calculated fields for project-related
# objects (such as Deployments and Events) whenever a project's default filter
# values change.
#
# Specifically, they trigger recalculation of cached counts when:
# - The project's default score threshold is updated
# - The project's default include taxa are modified
# - The project's default exclude taxa are modified
#
# This ensures that cached counts (e.g., occurrences_count, taxa_count) remain
# accurate and consistent with the active filter configuration for each project.
# ============================================================================


def refresh_cached_counts_for_project(project: Project):
"""
Refresh cached counts for Deployments and Events belonging to a project.
"""
logger.info(f"Refreshing cached counts for project {project.pk} ({project.name})")
project.update_related_calculated_fields()


@receiver(pre_save, sender=Project)
def cache_old_threshold(sender, instance, **kwargs):
"""
Cache the previous default score threshold before saving the Project.

We do this because:
- In post_save, the instance already contains the NEW value.
- To detect whether the threshold actually changed, we must read the OLD
value from the database before the update happens.
- This allows us to accurately detect threshold changes and then trigger
recalculation of cached filtered counts (Events, Deployments, etc.).

The cached value is stored on the instance as `_old_threshold` so it can be
safely accessed in the post_save handler.
"""
if instance.pk:
instance._old_threshold = Project.objects.get(pk=instance.pk).default_filters_score_threshold
else:
instance._old_threshold = None


@receiver(post_save, sender=Project)
def threshold_updated(sender, instance, **kwargs):
"""
After saving the Project, compare the previously cached threshold with the new value.
If the default score threshold changed, we refresh all cached counts using the new filters.

This two-step (pre_save + post_save) pattern is required because:
- post_save instances already contain the updated value
- so the old threshold would be lost without caching it in pre_save
"""
old_threshold = instance._old_threshold
new_threshold = instance.default_filters_score_threshold
if old_threshold is not None and old_threshold != new_threshold:
refresh_cached_counts_for_project(instance)


@receiver(m2m_changed, sender=Project.default_filters_include_taxa.through)
def include_taxa_updated(sender, instance: Project, action, **kwargs):
"""Refresh cached counts when include taxa are modified."""
if action in ["post_add", "post_remove", "post_clear"]:
logger.info(f"Include taxa updated for project {instance.pk} (action={action})")
refresh_cached_counts_for_project(instance)


@receiver(m2m_changed, sender=Project.default_filters_exclude_taxa.through)
def exclude_taxa_updated(sender, instance: Project, action, **kwargs):
"""Refresh cached counts when exclude taxa are modified."""
if action in ["post_add", "post_remove", "post_clear"]:
logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})")
refresh_cached_counts_for_project(instance)
3 changes: 3 additions & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,6 @@
"DEFAULT_PROCESSING_SERVICE_ENDPOINT", default=None # type: ignore[no-untyped-call]
)
DEFAULT_PIPELINES_ENABLED = env.list("DEFAULT_PIPELINES_ENABLED", default=None) # type: ignore[no-untyped-call]
# Default taxa filters
DEFAULT_INCLUDE_TAXA = env.list("DEFAULT_INCLUDE_TAXA", default=[]) # type: ignore[no-untyped-call]
DEFAULT_EXCLUDE_TAXA = env.list("DEFAULT_EXCLUDE_TAXA", default=[]) # type: ignore[no-untyped-call]