-
Notifications
You must be signed in to change notification settings - Fork 6
Open
Description
The logits returned with classifications are just a copy of the softmax scores! This was fixed in the `ood` branch but never merged into `main`. Noticed by @rhine3.
The bug can be seen here:
https://github.com/RolnickLab/ami-data-companion/blob/main/trapdata/api/models/classification.py#L78
One fix can be found here:
ami-data-companion/trapdata/api/models/classification.py
Lines 73 to 128 in 6a1b16f
def post_process_batch(
    self, logits: torch.Tensor, features: torch.Tensor | None = None
) -> list[ClassifierResult]:
    """
    Return labels, softmax scores, raw logits, and an OOD score for each image.

    Almost like the base class method, but the original (pre-softmax) logits
    are returned alongside the softmax scores, together with optional
    feature vectors.

    Args:
        logits: Raw model outputs. Softmax is taken over dim 1, so this is
            presumably shaped (batch, num_classes) -- TODO confirm.
        features: Optional per-image feature vectors, indexed by batch
            position in the same way as ``logits``.

    Returns:
        One ``ClassifierResult`` per row of ``logits``.
    """
    # Softmax produces a new tensor; ``logits`` itself is left untouched so the
    # raw values can be returned below instead of a copy of the scores.
    predictions = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()

    # Max probability (optionally relative to a class prior) is a confidence
    # that the sample is in-distribution.
    if self.class_prior is None:
        confidence = np.max(predictions, axis=-1)
    else:
        confidence = np.max(predictions - self.class_prior, axis=-1)
    # Invert so that higher scores indicate more likelihood that it is OOD.
    ood_scores = 1 - confidence

    features = features.cpu() if features is not None else None
    logits = logits.cpu()

    # Every row covers the same classes in the same order, so resolve the
    # label names once rather than once per image.
    labels = [
        self.category_map[class_idx] for class_idx in range(predictions.shape[1])
    ]

    batch_results = []
    for row, pred in enumerate(predictions):
        feature_vector = (
            features[row].float().tolist() if features is not None else None
        )
        result = ClassifierResult(
            features=feature_vector,
            # Give each result its own list so results never share state.
            labels=list(labels),
            logits=logits[row].float().tolist(),
            scores=pred,
            ood_score=ood_scores[row],
        )
        batch_results.append(result)
    logger.debug(f"Post-processing result batch with {len(batch_results)} entries.")
    return batch_results
def predict_batch(self, batch, return_features: bool = False):
    """
    Move ``batch`` to ``self.device`` and run the model on it.

    Returns a ``(logits, features)`` tuple. ``features`` comes from
    ``self.get_features`` when ``return_features`` is true and is ``None``
    otherwise.
    """
    inputs = batch.to(self.device, non_blocking=True)
    # NOTE(review): get_features and model are invoked separately, so if they
    # share a backbone the input may be processed twice -- confirm.
    feats = self.get_features(inputs) if return_features else None
    return self.model(inputs), feats
Metadata
Metadata
Assignees
Labels
No labels