2,671 changes: 1,470 additions & 1,201 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ pyobjus = [
{ version = "^1.2.1", platform = "darwin" },
{ version = "^1.2.1", platform = "linux" },
]

scikit-learn = "^1.3.0"
Collaborator:
I think we should make these optional dependencies and just use numpy in the tests, unless we need to use them in the core app.

[tool.poetry.extras]
dev = ["plotly", "scikit-learn"]

# [tool.poetry.group.dev.dependencies] # Can't install these dev deps with pip, so they're in the main deps
black = "^23.3.0"
flake8 = "^6.0.0"
76 changes: 65 additions & 11 deletions trapdata/api/models/classification.py
@@ -3,7 +3,10 @@

import numpy as np
import torch
import torch.utils.data
from sentry_sdk import start_transaction

from trapdata import logger
from trapdata.common.logs import logger
from trapdata.ml.models.classification import (
GlobalMothSpeciesClassifier,
@@ -17,6 +20,7 @@
TuringCostaRicaSpeciesClassifier,
UKDenmarkMothSpeciesClassifier2024,
)
from trapdata.ml.utils import StopWatch

from ..datasets import ClassificationImageDataset
from ..schemas import (
@@ -60,28 +64,43 @@ def get_dataset(self):
batch_size=self.batch_size,
)

def post_process_batch(self, logits: torch.Tensor):
def post_process_batch(
self, logits: torch.Tensor, features: torch.Tensor | None = None
):
"""
Return the labels, softmax/calibrated scores, and the original logits for
each image in the batch.

Almost like the base class method, but we need to return the logits as well.
each image in the batch, along with optional feature vectors.
"""
predictions = torch.nn.functional.softmax(logits, dim=1)
predictions = predictions.cpu().numpy()

features = features.cpu() if features is not None else None
batch_results = []
for pred in predictions:
# Get all class indices and their corresponding scores
for i, pred in enumerate(predictions):
class_indices = np.arange(len(pred))
scores = pred
labels = [self.category_map[i] for i in class_indices]
batch_results.append(list(zip(labels, scores, pred)))
preds = list(zip(labels, scores, pred))

logger.debug(f"Post-processing result batch: {batch_results}")
if features is not None:
batch_results.append((preds, features[i].tolist()))
else:
batch_results.append((preds, None))

logger.debug(f"Post-processing result batch with {len(batch_results)} entries.")
return batch_results

def predict_batch(self, batch, return_features: bool = False):
batch_input = batch.to(self.device, non_blocking=True)

if return_features:
features = self.get_features(batch_input)
logits = self.model(batch_input)
return logits, features

logits = self.model(batch_input)
return logits, None
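
A minimal sketch of how these two methods compose (not part of the diff; `classifier` stands in for an instance of this class and `batch` for one preprocessed tensor from its dataloader):

logits, features = classifier.predict_batch(batch, return_features=True)
for preds, feature_vec in classifier.post_process_batch(logits, features=features):
    # preds is a list of (label, score, raw_softmax_value) tuples, one per class;
    # feature_vec is a flat list of floats, or None when features were not requested.
    best_label, best_score, _ = max(preds, key=lambda t: t[1])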

def get_best_label(self, predictions):
"""
Convenience method to get the best label from the predictions, which are a list of tuples
@@ -105,16 +124,18 @@ def save_results(
) -> list[DetectionResponse]:
image_ids = metadata[0]
detection_idxes = metadata[1]
for image_id, detection_idx, predictions in zip(
for image_id, detection_idx, (predictions, features_vec) in zip(
image_ids, detection_idxes, batch_output
):
detection = self.detections[detection_idx]
assert detection.source_image_id == image_id
_labels, scores, logits = zip(*predictions)

classification = ClassificationResponse(
classification=self.get_best_label(predictions),
scores=scores,
logits=logits,
features=features_vec,
inference_time=seconds_per_item,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
@@ -140,12 +161,45 @@ def update_classification(
f"Total classifications: {len(detection.classifications)}"
)

def run(self) -> list[DetectionResponse]:
@torch.no_grad()
def run(self):
logger.info(
f"Starting {self.__class__.__name__} run with {len(self.results)} "
"detections"
)
super().run()
torch.cuda.empty_cache()

for i, batch in enumerate(self.dataloader):
if not batch:
logger.info(f"Batch {i+1} is empty, skipping")
continue

item_ids, batch_input = batch

logger.info(
f"Processing batch {i+1}, about {len(self.dataloader)} remaining"
)

with StopWatch() as batch_time:
with start_transaction(op="inference_batch", name=self.name):
logits, features = self.predict_batch(
batch_input, return_features=True
)

seconds_per_item = batch_time.duration / len(logits)

batch_output = list(self.post_process_batch(logits, features=features))
if isinstance(item_ids, (np.ndarray, torch.Tensor)):
item_ids = item_ids.tolist()

logger.info(f"Saving results from {len(item_ids)} items")
self.save_results(
item_ids,
batch_output,
seconds_per_item=seconds_per_item,
)
logger.info(f"{self.name} Batch -- Done")

logger.info(
f"Finished {self.__class__.__name__} run. "
f"Processed {len(self.results)} detections"
8 changes: 8 additions & 0 deletions trapdata/api/schemas.py
@@ -1,6 +1,7 @@
# Can these be imported from the OpenAPI spec yaml?
import datetime
import pathlib
from typing import Optional

import PIL.Image
import pydantic
@@ -100,6 +101,13 @@ class ClassificationResponse(pydantic.BaseModel):
),
repr=False, # Too long to display in the repr
)
features: Optional[list[float]] = pydantic.Field(
default=None,
description=(
"Intermediate features extracted from the model before the classification head"
),
repr=False,
)
inference_time: float | None = None
algorithm: AlgorithmReference
terminal: bool = True
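
For illustration, a hedged construction of the updated response model; field values are made up, and AlgorithmReference is assumed to be importable from this same module, mirroring the save_results call in classification.py:

import datetime

from trapdata.api.schemas import AlgorithmReference, ClassificationResponse

# Illustrative values only: a terminal classification carrying a 2048-dim
# feature vector; `features` defaults to None and is hidden from the repr.
response = ClassificationResponse(
    classification="Noctuidae",
    scores=[0.9, 0.1],
    logits=[2.2, -0.1],
    features=[0.0] * 2048,
    inference_time=0.05,
    algorithm=AlgorithmReference(name="Global species classifier", key="global_moths_2024"),
    timestamp=datetime.datetime.now(),
)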
187 changes: 187 additions & 0 deletions trapdata/api/tests/test_features_extraction.py
@@ -0,0 +1,187 @@
import os
import pathlib
import unittest
from unittest import TestCase

import numpy as np
from fastapi.testclient import TestClient
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

from trapdata.api.api import PipelineChoice, PipelineRequest, PipelineResponse, app
from trapdata.api.schemas import SourceImageRequest
from trapdata.api.tests.image_server import StaticFileTestServer
from trapdata.ml.models.tracking import cosine_similarity
from trapdata.tests import TEST_IMAGES_BASE_PATH


class TestFeatureExtractionAPI(TestCase):
@classmethod
def setUpClass(cls):
cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH)
cls.file_server = StaticFileTestServer(cls.test_images_dir)
cls.client = TestClient(app)

@classmethod
def tearDownClass(cls):
cls.file_server.stop()

def get_local_test_images(self, num=1):
image_paths = [
"panama/01-20231110214539-snapshot.jpg",
"panama/01-20231111032659-snapshot.jpg",
"panama/01-20231111015309-snapshot.jpg",
]
return [
SourceImageRequest(id="0", url=self.file_server.get_url(image_path))
for image_path in image_paths[:num]
]

def get_pipeline_response(self, pipeline_slug="global_moths_2024", num_images=1):
"""
Utility method to send a pipeline request and return the parsed response.
"""
test_images = self.get_local_test_images(num=num_images)
pipeline_request = PipelineRequest(
pipeline=PipelineChoice[pipeline_slug],
source_images=test_images,
)

with self.file_server:
response = self.client.post("/process", json=pipeline_request.model_dump())
assert response.status_code == 200
return PipelineResponse(**response.json())

def test_feature_extraction_from_pipeline(self):
"""
Run a local image through the pipeline and validate extracted features.
"""
pipeline_response = self.get_pipeline_response()

self.assertTrue(pipeline_response.detections, "No detections returned")
for detection in pipeline_response.detections:
for classification in detection.classifications:
if classification.terminal:
features = classification.features
self.assertIsNotNone(features, "Features should not be None")
assert features # This is for type checking
self.assertIsInstance(features, list, "Features should be a list")
self.assertTrue(
all(isinstance(x, float) for x in features),
"All features should be floats",
)
self.assertEqual(
len(features), 2048, "Feature vector should be 2048 dims"
)

def test_cosine_similarity_of_extracted_features(self):
"""
Run the pipeline and compare features using cosine similarity to validate
output.
"""
pipeline_response = self.get_pipeline_response(num_images=1)

# Extract all terminal classification features
feature_vectors = []
for detection in pipeline_response.detections:
for classification in detection.classifications:
if classification.terminal and classification.features:
feature_vectors.append(classification.features)

self.assertGreater(
len(feature_vectors), 1, "Need at least two features to compare"
)

for _i, vec1 in enumerate(feature_vectors):
sims = []
for _j, vec2 in enumerate(feature_vectors):
sim = cosine_similarity(vec1, vec2)
sims.append(round(sim, 4))

# Confirm that similarity with itself is 1.0
for i, vec in enumerate(feature_vectors):
self_sim = cosine_similarity(vec, vec)
self.assertAlmostEqual(
self_sim, 1.0, places=5, msg=f"Self similarity at index {i} not 1.0"
)
# Confirm that a feature is most similar to itself

for ref_index, ref_vec in enumerate(feature_vectors):
similarities = [
(i, cosine_similarity(ref_vec, other_vec))
for i, other_vec in enumerate(feature_vectors)
]
similarities.sort(key=lambda x: x[1], reverse=True)
most_similar_index = similarities[0][0]
self.assertEqual(
most_similar_index,
ref_index,
f"Expected most similar vector to be at index {ref_index}, "
"got {most_similar_index}",
)
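
For reference, a standard cosine similarity consistent with how these assertions use it; the actual implementation lives in trapdata.ml.models.tracking and may differ in detail:

import numpy as np

def cosine_similarity(a: list[float], b: list[float]) -> float:
    # cos(theta) = (a . b) / (|a| * |b|); equals 1.0 for identical directions.
    a_arr = np.asarray(a, dtype=float)
    b_arr = np.asarray(b, dtype=float)
    return float(np.dot(a_arr, b_arr) / (np.linalg.norm(a_arr) * np.linalg.norm(b_arr)))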

def get_detection_crop(self, local_image_path: str, bbox) -> Image.Image | None:
"""
Given a local image path and a bounding box, return a cropped and resized image.
"""

try:
if not os.path.exists(local_image_path):
print(f"File not found: {local_image_path}")
return None

img = Image.open(local_image_path).convert("RGB")
x1, y1, x2, y2 = map(int, [bbox.x1, bbox.y1, bbox.x2, bbox.y2])
crop = img.crop((x1, y1, x2, y2)).resize((64, 64))
return crop
except Exception as e:
print(f"Failed to load or crop image: {e}")
return None

@unittest.skip("Skipping visualization test")
def test_feature_clustering_visualization(self):

source_images = self.get_local_test_images(num=3)
pipeline_response = self.get_pipeline_response(num_images=len(source_images))
image_id_to_url = {img.id: img.url for img in source_images}

features = []
labels = []

for detection in pipeline_response.detections:
source_url = image_id_to_url.get(detection.source_image_id)
if not source_url or not detection.bbox:
continue

for classification in detection.classifications:
if classification.features:
features.append(classification.features)
print(f"Classification: {classification.classification}")

labels.append(classification.classification)

if len(features) < 2:
print("Not enough data for clustering.")
return

# Reduce to 3D using PCA
features_np = np.array(features)
reduced = PCA(n_components=3).fit_transform(features_np)
cluster_labels = KMeans(
n_clusters=min(8, len(features)), random_state=42
).fit_predict(features_np)

import plotly.express as px # type: ignore[import]

fig = px.scatter_3d(
x=reduced[:, 0],
y=reduced[:, 1],
z=reduced[:, 2],
color=cluster_labels.astype(str),
hover_name=labels,
title="3D Clustering of Classification Feature Vectors (K-Means + PCA)",
)

fig.update_traces(marker={"size": 6})
fig.write_html("feature_clustering_3d_pca.html")
1 change: 0 additions & 1 deletion trapdata/common/constants.py
@@ -4,7 +4,6 @@
NEGATIVE_BINARY_LABEL = "nonmoth"
NULL_DETECTION_LABELS = [NEGATIVE_BINARY_LABEL]
TRACKING_COST_THRESHOLD = 1.0

POSITIVE_COLOR = [0, 100 / 255, 1, 1] # Blue
# POSITIVE_COLOR = [1, 0, 162 / 255, 1] # Pink
# NEUTRAL_COLOR = [1, 1, 1, 0.5] # White
9 changes: 9 additions & 0 deletions trapdata/ml/models/base.py
@@ -199,6 +199,15 @@ def get_model(self) -> torch.nn.Module:
"""
raise NotImplementedError

def get_features(
self, batch_input: torch.Tensor
) -> torch.Tensor | None:
"""
Default get_features method for models that don't implement feature extraction.
Subclasses override this to return a [batch, feature_dim] tensor.
"""
return None

def get_transforms(self) -> torchvision.transforms.Compose:
"""
This method must be implemented by a subclass.
10 changes: 10 additions & 0 deletions trapdata/ml/models/classification.py
@@ -287,6 +287,16 @@ def get_model(self):
model.eval()
return model

def get_features(self, batch_input: torch.Tensor) -> torch.Tensor:
Collaborator:
Nice work on this method of extracting features! It seems more flexible than our current feature extractor. Perhaps we should add a comment in both feature extractors that the other one exists. And eventually update the old one to use this code.

logger.debug(
f"[{self.name}] get_features called with input shape: {batch_input.shape}"
)
features = self.model.forward_features(batch_input) # [B, 2048, 4, 4]
# Flatten the features vector
features = torch.nn.functional.adaptive_avg_pool2d(features, output_size=(1, 1))
features = features.view(features.size(0), -1)
return features
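
A small sketch of the pooling step above, assuming forward_features returns a [B, 2048, 4, 4] map as the inline comment states:

import torch

# Stand-in for model.forward_features output on a batch of 8 crops.
feature_map = torch.randn(8, 2048, 4, 4)
pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, output_size=(1, 1))  # [8, 2048, 1, 1]
flattened = pooled.view(pooled.size(0), -1)  # [8, 2048]: the 2048-dim vectors the tests expect
assert flattened.shape == (8, 2048)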


class BinaryClassifier(Resnet50ClassifierLowRes):
stage = 2