diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index ca823661f..db3919d59 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -61,4 +61,4 @@ body:
description: >
(Optional) We encourage you to submit a [Pull Request](https://github.com/microsoft/CameraTraps/pulls) (PR) to help contribute Pytorch-Wildlife for everyone, especially if you have a good understanding of how to implement a fix or feature.
options:
- - label: Yes I'd like to help by submitting a PR!
\ No newline at end of file
+ - label: Yes I'd like to help by submitting a PR!
diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml
index f26991c4c..576da754b 100644
--- a/.github/ISSUE_TEMPLATE/feature-request.yml
+++ b/.github/ISSUE_TEMPLATE/feature-request.yml
@@ -46,4 +46,4 @@ body:
description: >
(Optional) We encourage you to submit a [Pull Request](https://github.com/microsoft/CameraTraps/pulls) (PR) to help contribute Pytorch-Wildlife for everyone, especially if you have a good understanding of how to implement a fix or feature.
options:
- - label: Yes I'd like to help by submitting a PR!
\ No newline at end of file
+ - label: Yes I'd like to help by submitting a PR!
diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml
index bd601f8f9..040de4107 100644
--- a/.github/ISSUE_TEMPLATE/question.yml
+++ b/.github/ISSUE_TEMPLATE/question.yml
@@ -30,4 +30,4 @@ body:
- type: textarea
attributes:
label: Additional
- description: Anything else you would like to share?
\ No newline at end of file
+ description: Anything else you would like to share?
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..67de664d4
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,51 @@
+# Pre-commit hooks configuration for PyTorch Wildlife
+# Installation: pip install pre-commit
+# Setup: pre-commit install
+# Run: pre-commit run --all-files
+
+repos:
+ # Code formatting
+ - repo: https://github.com/psf/black
+ rev: 23.12.1
+ hooks:
+ - id: black
+ language_version: python3
+ args: ['--line-length=100']
+
+ # Import sorting
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ args: ['--profile=black', '--line-length=100']
+
+ # Linting
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.1.0
+ hooks:
+ - id: flake8
+ args: ['--max-line-length=100', '--extend-ignore=E203,W503']
+ additional_dependencies: ['flake8-docstrings', 'flake8-bugbear']
+
+ # Built-in pre-commit hooks
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: trailing-whitespace
+ args: ['--markdown-linebreak-ext=md']
+ - id: end-of-file-fixer
+ - id: check-yaml
+ args: ['--unsafe']
+ - id: check-added-large-files
+ args: ['--maxkb=1000']
+ - id: check-merge-conflict
+ - id: debug-statements
+
+ # Type checking (optional but recommended)
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v1.7.1
+ hooks:
+ - id: mypy
+ additional_dependencies: []  # note: avoid 'types-all' (unmaintained, fails to install); list specific type stubs here as needed
+ args: ['--ignore-missing-imports', '--no-warn-unused-ignores']
+ stages: [manual]
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
new file mode 100644
index 000000000..da04c597a
--- /dev/null
+++ b/DEVELOPMENT.md
@@ -0,0 +1,211 @@
+# Development Guide for PyTorch Wildlife
+
+This guide covers setting up your development environment and following our code quality standards.
+
+## Setup
+
+### Prerequisites
+- Python >= 3.8
+- Git
+
+### Virtual Environment
+
+We recommend using a virtual environment:
+
+```bash
+# Create virtual environment
+python -m venv venv
+
+# Activate it
+source venv/bin/activate # On Windows: venv\Scripts\activate
+```
+
+### Install Dependencies
+
+```bash
+# Install package in development mode with all dependencies
+pip install -e ".[dev]"
+
+# Or install core dependencies + development tools
+pip install -r requirements.txt
+pip install pre-commit black flake8 isort
+```
+
+## Code Quality
+
+We use automated tools to maintain consistent code quality across the project.
+
+### Pre-commit Hooks
+
+Pre-commit hooks automatically check your code before each commit, catching common issues early.
+
+**Setup (one-time):**
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+Now every time you commit, hooks will run automatically. If any issues are found, the commit will be blocked until you fix them.
+
+**Running Manually:**
+
+```bash
+# Check all files
+pre-commit run --all-files
+
+# Check specific file
+pre-commit run --files path/to/file.py
+
+# Skip a specific hook
+SKIP=flake8 pre-commit run --all-files
+
+# Update hook versions
+pre-commit autoupdate
+```
+
+### What Each Hook Does
+
+1. **black** - Automatic code formatting (PEP 8 compliant)
+ - Enforces consistent style across the codebase
+ - Max line length: 100 characters
+
+2. **isort** - Import organization
+ - Sorts imports alphabetically
+ - Groups: stdlib → third-party → local
+ - Compatible with black
+
+3. **flake8** - Linting
+ - Checks for common style issues
+ - Detects unused variables/imports
+ - Catches potential bugs
+
+4. **trailing-whitespace** - Removes trailing spaces
+
+5. **end-of-file-fixer** - Ensures files end with a newline
+
+6. **check-yaml** - Validates YAML syntax
+
+7. **check-merge-conflict** - Prevents committed merge conflict markers
+
+### Fixing Code Issues
+
+If pre-commit finds issues, you can fix them automatically:
+
+```bash
+# Black and isort will auto-fix formatting/imports
+black .
+isort .
+
+# Flake8 requires manual fixes (shows issues, doesn't auto-fix)
+flake8 .
+```
+
+After fixing, commit again:
+
+```bash
+git add .
+git commit -m "feat: Your message here"
+```
+
+### Type Checking (Optional)
+
+For advanced type checking (optional, not enforced on commit):
+
+```bash
+pre-commit run mypy --all-files
+```
+
+## Testing
+
+Run the test suite before submitting PRs:
+
+```bash
+pytest tests/
+```
+
+## Committing Code
+
+Follow these guidelines:
+
+1. **Create a feature branch:**
+ ```bash
+ git checkout -b issue-600-add-precommit-hooks
+ ```
+
+2. **Make changes and test:**
+ ```bash
+ # Edit files...
+ pre-commit run --all-files # Run checks
+ pytest tests/ # Run tests
+ ```
+
+3. **Commit with clear messages:**
+ ```bash
+ git commit -m "feat: Add pre-commit hooks for code quality"
+ ```
+
+4. **Push and create PR:**
+ ```bash
+ git push origin issue-600-add-precommit-hooks
+ ```
+
+## Commit Message Format
+
+We follow conventional commits for clear, organized history:
+
+```
+type: short description
+
+Optional longer description explaining the change.
+
+Fixes #123
+```
+
+### Types
+- `feat:` - New feature
+- `fix:` - Bug fix
+- `docs:` - Documentation
+- `style:` - Code style (formatting, etc.)
+- `refactor:` - Code refactoring without functionality change
+- `test:` - Test additions/changes
+- `chore:` - Maintenance tasks
+
+Example:
+```
+feat: Add pre-commit hooks for code quality
+
+Implements black, flake8, and isort for consistent code formatting.
+Developers can now use pre-commit hooks to catch issues before commit.
+
+Fixes #600
+```
+
+## Troubleshooting
+
+**Pre-commit is slow:**
+- Large repositories can be slow on first run. Subsequent runs are faster.
+- You can run specific hooks: `pre-commit run black --all-files`
+
+**Hook keeps failing on the same issue:**
+- Some hooks (black, isort) auto-fix. Run them to fix automatically.
+- For flake8: read the output and fix manually.
+
+**Need to bypass hooks temporarily:**
+- For emergencies only: `git commit --no-verify`
+- But please fix issues before pushing!
+
+**Confused about a hook error:**
+- Run the tool directly: `black path/to/file.py`
+- Check the hook documentation in `.pre-commit-config.yaml`
+
+## Resources
+
+- [Pre-commit Documentation](https://pre-commit.com/)
+- [Black Documentation](https://black.readthedocs.io/)
+- [Flake8 Documentation](https://flake8.pycqa.org/)
+- [isort Documentation](https://pycqa.github.io/isort/)
+
+## Questions?
+
+If you have questions about the development setup, open an issue or reach out on Discord: https://discord.gg/TeEVxzaYtm
diff --git a/Dockerfile b/Dockerfile
index b99a41062..5b06bafcd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -23,4 +23,3 @@ RUN rm -rf /tmp/*
RUN pip install --no-cache-dir PytorchWildlife
EXPOSE 80
-
diff --git a/MANIFEST.in b/MANIFEST.in
index 349ed67ad..f11cb4320 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1 @@
-global-include *.yml
\ No newline at end of file
+global-include *.yml
diff --git a/PW_FT_classification/__init__.py b/PW_FT_classification/__init__.py
index c18128e6d..4f9811159 100644
--- a/PW_FT_classification/__init__.py
+++ b/PW_FT_classification/__init__.py
@@ -1 +1 @@
-from batch_detection_cropping import *
\ No newline at end of file
+from batch_detection_cropping import *
diff --git a/PW_FT_classification/configs/config.yaml b/PW_FT_classification/configs/config.yaml
index e0c4bb19b..cea9ce95f 100644
--- a/PW_FT_classification/configs/config.yaml
+++ b/PW_FT_classification/configs/config.yaml
@@ -38,4 +38,3 @@ weight_decay_classifier: 0.0005
## lr_scheduler
step_size: 10
gamma: 0.1
-
diff --git a/PW_FT_classification/main.py b/PW_FT_classification/main.py
index ec66333a1..c72b4dcfb 100644
--- a/PW_FT_classification/main.py
+++ b/PW_FT_classification/main.py
@@ -1,40 +1,43 @@
# %%
# Importing libraries
import os
-import yaml
-import typer
-from munch import Munch
+
+import pytorch_lightning as pl
+
# %%
import torch
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.loggers import CSVLogger, CometLogger, TensorBoardLogger, WandbLogger
+import typer
+import yaml
+from munch import Munch
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.loggers import CometLogger, CSVLogger, TensorBoardLogger, WandbLogger
+
# %%
-from src import algorithms
-from src import datasets
+from src import algorithms, datasets
+
# %%
-from src.utils import batch_detection_cropping
-from src.utils import data_splitting
+from src.utils import batch_detection_cropping, data_splitting
app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_show_locals=False)
+
+
# %%
@app.command()
def main(
- config:str='./configs/config.yaml',
- project:str='Custom-classification',
- gpus:str='0',
- logger_type:str='csv',
- evaluate:str=None,
- np_threads:str='32',
- session:int=0,
- seed:int=0,
- dev:bool=False,
- val:bool=False,
- test:bool=False,
- predict:bool=False,
- predict_root:str=""
- ):
+ config: str = "./configs/config.yaml",
+ project: str = "Custom-classification",
+ gpus: str = "0",
+ logger_type: str = "csv",
+ evaluate: str = None,
+ np_threads: str = "32",
+ session: int = 0,
+ seed: int = 0,
+ dev: bool = False,
+ val: bool = False,
+ test: bool = False,
+ predict: bool = False,
+ predict_root: str = "",
+):
"""
Main function for training or evaluating a ResNet model (50 or 18) using PyTorch Lightning.
It loads configurations, initializes the model, logger, and other components based on provided arguments.
@@ -56,7 +59,7 @@ def main(
# GPU configuration: set up GPUs based on availability and user specification
gpus = gpus if torch.cuda.is_available() else None
- gpus = [int(i) for i in gpus.split(',')]
+ gpus = [int(i) for i in gpus.split(",")]
# Environment variable setup for numpy multi-threading. It is important to avoid cpu and ram issues.
os.environ["OMP_NUM_THREADS"] = str(np_threads)
@@ -81,87 +84,113 @@ def main(
# Replace annotation dir from config with the directory containing the split files
conf.annotation_dir = os.path.dirname(conf.split_path)
# Split the data according to the split type
- if conf.split_type == 'location':
- data_splitting.split_by_location(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size)
- elif conf.split_type == 'sequence':
- data_splitting.split_by_seq(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size)
- elif conf.split_type == 'random':
- data_splitting.create_splits(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size)
+ if conf.split_type == "location":
+ data_splitting.split_by_location(
+ conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size
+ )
+ elif conf.split_type == "sequence":
+ data_splitting.split_by_seq(
+ conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size
+ )
+ elif conf.split_type == "random":
+ data_splitting.create_splits(
+ conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size
+ )
else:
- raise ValueError('Invalid split type: {}. Available options: random, location, sequence.'.format(conf.split_type))
-
+ raise ValueError(
+ "Invalid split type: {}. Available options: random, location, sequence.".format(
+ conf.split_type
+ )
+ )
+
if not conf.predict:
# Get the path to the annotation files, and we only want to do this if we are not predicting
if conf.test:
- test_annotations = os.path.join(conf.dataset_root, 'test_annotations.csv')
+ test_annotations = os.path.join(conf.dataset_root, "test_annotations.csv")
# Crop test data (most likely we don't need this)
- batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), test_annotations)
+ batch_detection_cropping.batch_detection_cropping(
+ conf.dataset_root,
+ os.path.join(conf.dataset_root, "cropped_resized"),
+ test_annotations,
+ )
else:
- train_annotations = os.path.join(conf.dataset_root, 'train_annotations.csv')
- val_annotations = os.path.join(conf.dataset_root, 'val_annotations.csv')
+ train_annotations = os.path.join(conf.dataset_root, "train_annotations.csv")
+ val_annotations = os.path.join(conf.dataset_root, "val_annotations.csv")
# Crop training data
- batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), train_annotations)
+ batch_detection_cropping.batch_detection_cropping(
+ conf.dataset_root,
+ os.path.join(conf.dataset_root, "cropped_resized"),
+ train_annotations,
+ )
# Crop validation data
- batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), val_annotations)
+ batch_detection_cropping.batch_detection_cropping(
+ conf.dataset_root,
+ os.path.join(conf.dataset_root, "cropped_resized"),
+ val_annotations,
+ )
# Dataset and algorithm loading based on the configuration
dataset = datasets.__dict__[conf.dataset_name](conf=conf)
- learner = algorithms.__dict__[conf.algorithm](conf=conf,
- train_class_counts=dataset.train_class_counts,
- id_to_labels=dataset.id_to_labels)
+ learner = algorithms.__dict__[conf.algorithm](
+ conf=conf, train_class_counts=dataset.train_class_counts, id_to_labels=dataset.id_to_labels
+ )
# Logger setup based on the specified logger type
- log_folder = 'log_dev' if dev else 'log'
+ log_folder = "log_dev" if dev else "log"
logger = None
- if logger_type == 'csv':
+ if logger_type == "csv":
logger = CSVLogger(
- save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm),
+ save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm),
prefix=project,
- name='{}_{}'.format(conf.algorithm, conf.conf_id),
- version=session
+ name="{}_{}".format(conf.algorithm, conf.conf_id),
+ version=session,
)
- elif logger_type == 'tensorboard':
+ elif logger_type == "tensorboard":
logger = TensorBoardLogger(
- save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm),
+ save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm),
prefix=project,
- name='{}_{}'.format(conf.algorithm, conf.conf_id),
- version=session
+ name="{}_{}".format(conf.algorithm, conf.conf_id),
+ version=session,
)
- elif logger_type == 'comet':
+ elif logger_type == "comet":
logger = CometLogger(
api_key=os.environ.get("COMET_API_KEY"),
- save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm),
- project_name=project,
- experiment_name='{}_{}_{}'.format(conf.algorithm, conf.conf_id, session),
+ save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm),
+ project_name=project,
+ experiment_name="{}_{}_{}".format(conf.algorithm, conf.conf_id, session),
)
- elif logger_type == 'wandb':
+ elif logger_type == "wandb":
logger = WandbLogger(
- save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm),
- project=project,
- name='{}_{}_{}'.format(conf.algorithm, conf.conf_id, session),
+ save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm),
+ project=project,
+ name="{}_{}_{}".format(conf.algorithm, conf.conf_id, session),
)
# Callbacks for model checkpointing and learning rate monitoring
- weights_folder = 'weights_dev' if dev else 'weights'
+ weights_folder = "weights_dev" if dev else "weights"
checkpoint_callback = ModelCheckpoint(
- monitor='valid_mac_acc', mode='max', dirpath='./{}/{}/{}'.format(weights_folder, conf.log_dir, conf.algorithm),
- save_top_k=1, filename='{}-{}'.format(conf.conf_id, session) + '-{epoch:02d}-{valid_mac_acc:.2f}', verbose=True
+ monitor="valid_mac_acc",
+ mode="max",
+ dirpath="./{}/{}/{}".format(weights_folder, conf.log_dir, conf.algorithm),
+ save_top_k=1,
+ filename="{}-{}".format(conf.conf_id, session) + "-{epoch:02d}-{valid_mac_acc:.2f}",
+ verbose=True,
)
- lr_monitor = LearningRateMonitor(logging_interval='step')
+ lr_monitor = LearningRateMonitor(logging_interval="step")
# Trainer configuration in PyTorch Lightning
trainer = pl.Trainer(
max_epochs=conf.num_epochs,
- check_val_every_n_epoch=1,
- log_every_n_steps = conf.log_interval,
- accelerator='gpu',
+ check_val_every_n_epoch=1,
+ log_every_n_steps=conf.log_interval,
+ accelerator="gpu",
devices=gpus,
logger=None if evaluate is not None else logger,
callbacks=[lr_monitor, checkpoint_callback],
- strategy='auto',
+ strategy="auto",
num_sanity_val_steps=0,
- profiler=None
+ profiler=None,
)
# Training, validation, or evaluation execution based on the mode
if evaluate is not None:
@@ -172,11 +201,13 @@ def main(
elif test:
trainer.test(learner, dataloaders=[dataset.test_dataloader()], ckpt_path=evaluate)
else:
- print('Invalid mode for evaluation.')
+ print("Invalid mode for evaluation.")
else:
trainer.fit(learner, datamodule=dataset)
+
+
# %%
-if __name__ == '__main__':
+if __name__ == "__main__":
app()
# %%
diff --git a/PW_FT_classification/requirements.txt b/PW_FT_classification/requirements.txt
index c20d78a7c..0dbc206c2 100644
--- a/PW_FT_classification/requirements.txt
+++ b/PW_FT_classification/requirements.txt
@@ -2,4 +2,4 @@ PytorchWildlife
scikit_learn
lightning
munch
-typer
\ No newline at end of file
+typer
diff --git a/PW_FT_classification/src/algorithms/__init__.py b/PW_FT_classification/src/algorithms/__init__.py
index fa84a4a0c..b79897b14 100644
--- a/PW_FT_classification/src/algorithms/__init__.py
+++ b/PW_FT_classification/src/algorithms/__init__.py
@@ -1,3 +1,2 @@
from . import utils
from .plain import *
-
diff --git a/PW_FT_classification/src/algorithms/plain.py b/PW_FT_classification/src/algorithms/plain.py
index be3212ae0..4c604eaaa 100644
--- a/PW_FT_classification/src/algorithms/plain.py
+++ b/PW_FT_classification/src/algorithms/plain.py
@@ -1,21 +1,19 @@
-import os
-import numpy as np
import json
-from datetime import datetime
-from tqdm import tqdm
+import os
import random
+from datetime import datetime
+import numpy as np
+import pytorch_lightning as pl
import torch
import torch.optim as optim
-import pytorch_lightning as pl
+from src import models
+from tqdm import tqdm
from .utils import acc
-from src import models
+__all__ = ["Plain"]
-__all__ = [
- 'Plain'
-]
class Plain(pl.LightningModule):
"""
@@ -25,7 +23,7 @@ class Plain(pl.LightningModule):
and training/validation/testing steps for the training process.
"""
- name = 'Plain'
+ name = "Plain"
def __init__(self, conf, train_class_counts, id_to_labels, **kwargs):
"""
@@ -39,11 +37,12 @@ def __init__(self, conf, train_class_counts, id_to_labels, **kwargs):
"""
super().__init__()
self.hparams.update(conf.__dict__)
- self.save_hyperparameters(ignore=['conf', 'train_class_counts'])
+ self.save_hyperparameters(ignore=["conf", "train_class_counts"])
self.train_class_counts = train_class_counts
self.id_to_labels = id_to_labels
- self.net = models.__dict__[self.hparams.model_name](num_cls=self.hparams.num_classes,
- num_layers=self.hparams.num_layers)
+ self.net = models.__dict__[self.hparams.model_name](
+ num_cls=self.hparams.num_classes, num_layers=self.hparams.num_layers
+ )
def configure_optimizers(self):
"""
@@ -55,19 +54,25 @@ def configure_optimizers(self):
# Define parameters for the optimizer
net_optim_params_list = [
# Optimizer parameters for feature extraction
- {'params': self.net.feature.parameters(),
- 'lr': self.hparams.lr_feature,
- 'momentum': self.hparams.momentum_feature,
- 'weight_decay': self.hparams.weight_decay_feature},
+ {
+ "params": self.net.feature.parameters(),
+ "lr": self.hparams.lr_feature,
+ "momentum": self.hparams.momentum_feature,
+ "weight_decay": self.hparams.weight_decay_feature,
+ },
# Optimizer parameters for the classifier
- {'params': self.net.classifier.parameters(),
- 'lr': self.hparams.lr_classifier,
- 'momentum': self.hparams.momentum_classifier,
- 'weight_decay': self.hparams.weight_decay_classifier}
+ {
+ "params": self.net.classifier.parameters(),
+ "lr": self.hparams.lr_classifier,
+ "momentum": self.hparams.momentum_classifier,
+ "weight_decay": self.hparams.weight_decay_classifier,
+ },
]
# Setup optimizer and optimizer scheduler
optimizer = torch.optim.SGD(net_optim_params_list)
- scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma)
+ scheduler = optim.lr_scheduler.StepLR(
+ optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
+ )
return [optimizer], [scheduler]
def on_train_start(self):
@@ -90,14 +95,14 @@ def training_step(self, batch, batch_idx):
Tensor: The loss for the current training step.
"""
data, label_ids = batch[0], batch[1]
-
+
# Forward pass
feats = self.net.feature(data)
logits = self.net.classifier(feats)
# Calculate loss
loss = self.net.criterion_cls(logits, label_ids)
self.log("train_loss", loss)
-
+
return loss
def on_validation_start(self):
@@ -119,9 +124,8 @@ def validation_step(self, batch, batch_idx):
feats = self.net.feature(data)
logits = self.net.classifier(feats)
preds = logits.argmax(dim=1)
-
- self.val_st_outs.append((preds.detach().cpu().numpy(),
- label_ids.detach().cpu().numpy()))
+
+ self.val_st_outs.append((preds.detach().cpu().numpy(), label_ids.detach().cpu().numpy()))
def on_validation_epoch_end(self):
"""
@@ -150,14 +154,17 @@ def test_step(self, batch, batch_idx):
feats = self.net.feature(data)
logits = self.net.classifier(feats)
preds = logits.argmax(dim=1)
-
- self.te_st_outs.append((preds.detach().cpu().numpy(),
- label_ids.detach().cpu().numpy(),
- feats.detach().cpu().numpy(),
- logits.detach().cpu().numpy(),
- labels, file_ids
- ))
-
+
+ self.te_st_outs.append(
+ (
+ preds.detach().cpu().numpy(),
+ label_ids.detach().cpu().numpy(),
+ feats.detach().cpu().numpy(),
+ logits.detach().cpu().numpy(),
+ labels,
+ file_ids,
+ )
+ )
def on_test_epoch_end(self):
"""
@@ -172,14 +179,23 @@ def on_test_epoch_end(self):
total_file_ids = np.concatenate([x[5] for x in self.te_st_outs], axis=0)
# Calculate the metrics and save the output
- self.eval_logging(total_preds[total_label_ids != -1],
- total_label_ids[total_label_ids != -1],
- print_class_acc=False)
-
- output_path = self.hparams.evaluate.replace('.ckpt', 'eval.npz')
- np.savez(output_path, preds=total_preds, label_ids=total_label_ids, feats=total_feats,
- logits=total_logits, labels=total_labels, file_ids=total_file_ids)
- print('Test output saved to {}.'.format(output_path))
+ self.eval_logging(
+ total_preds[total_label_ids != -1],
+ total_label_ids[total_label_ids != -1],
+ print_class_acc=False,
+ )
+
+ output_path = self.hparams.evaluate.replace(".ckpt", "eval.npz")
+ np.savez(
+ output_path,
+ preds=total_preds,
+ label_ids=total_label_ids,
+ feats=total_feats,
+ logits=total_logits,
+ labels=total_labels,
+ file_ids=total_file_ids,
+ )
+ print("Test output saved to {}.".format(output_path))
def on_predict_start(self):
"""
@@ -201,14 +217,16 @@ def predict_step(self, batch, batch_idx):
logits = self.net.classifier(feats)
preds = logits.argmax(dim=1)
probs = torch.softmax(logits, dim=1).max(dim=1)[0]
-
- self.pr_st_outs.append((preds.detach().cpu().numpy(),
- feats.detach().cpu().numpy(),
- logits.detach().cpu().numpy(),
- probs.detach().cpu().numpy(),
- file_ids
- ))
-
+
+ self.pr_st_outs.append(
+ (
+ preds.detach().cpu().numpy(),
+ feats.detach().cpu().numpy(),
+ logits.detach().cpu().numpy(),
+ probs.detach().cpu().numpy(),
+ file_ids,
+ )
+ )
def on_predict_epoch_end(self):
"""
@@ -223,25 +241,31 @@ def on_predict_epoch_end(self):
json_output = []
for i in range(len(total_preds)):
- json_output.append({
- "marker_id": "",
- "survey_pic_id": total_file_ids[i],
- "marker_confidence": float(total_probs[i]),
- "marker_gear_type": "ghostnet" if total_preds[i] == 1 else "neg",
- "marker_bounding_polygon": "",
- "marker_status": "unverified",
- "marker_ai_model": ""
- })
-
- output_path_full = self.hparams.evaluate.replace('.ckpt', '_predict.npz')
- np.savez(output_path_full, preds=total_preds, feats=total_feats,
- logits=total_logits, file_ids=total_file_ids)
- print('Predict output saved to {}.'.format(output_path_full))
-
- output_path_json = self.hparams.evaluate.replace('.ckpt', '_predict.json')
- json.dump(json_output, open(output_path_json, 'w'))
- print('Predict output json saved to {}.'.format(output_path_json))
-
+ json_output.append(
+ {
+ "marker_id": "",
+ "survey_pic_id": total_file_ids[i],
+ "marker_confidence": float(total_probs[i]),
+ "marker_gear_type": "ghostnet" if total_preds[i] == 1 else "neg",
+ "marker_bounding_polygon": "",
+ "marker_status": "unverified",
+ "marker_ai_model": "",
+ }
+ )
+
+ output_path_full = self.hparams.evaluate.replace(".ckpt", "_predict.npz")
+ np.savez(
+ output_path_full,
+ preds=total_preds,
+ feats=total_feats,
+ logits=total_logits,
+ file_ids=total_file_ids,
+ )
+ print("Predict output saved to {}.".format(output_path_full))
+
+ output_path_json = self.hparams.evaluate.replace(".ckpt", "_predict.json")
+ json.dump(json_output, open(output_path_json, "w"))
+ print("Predict output json saved to {}.".format(output_path_json))
def eval_logging(self, preds, labels, print_class_acc=False):
"""
@@ -259,27 +283,32 @@ def eval_logging(self, preds, labels, print_class_acc=False):
self.log("valid_mic_acc", mic_acc * 100)
if print_class_acc:
-
if self.train_class_counts:
- acc_list = [(class_acc[i], unique_eval_labels[i],
- self.id_to_labels[unique_eval_labels[i]],
- self.train_class_counts[unique_eval_labels[i]])
- for i in range(len(class_acc))]
-
- print('\n')
+ acc_list = [
+ (
+ class_acc[i],
+ unique_eval_labels[i],
+ self.id_to_labels[unique_eval_labels[i]],
+ self.train_class_counts[unique_eval_labels[i]],
+ )
+ for i in range(len(class_acc))
+ ]
+
+ print("\n")
for i in range(len(class_acc)):
- info = '{:>20} ({:<3}, tr {:>3}) Acc: '.format(acc_list[i][2],
- acc_list[i][1],
- acc_list[i][3])
- info += '{:.2f}'.format(acc_list[i][0] * 100)
+ info = "{:>20} ({:<3}, tr {:>3}) Acc: ".format(
+ acc_list[i][2], acc_list[i][1], acc_list[i][3]
+ )
+ info += "{:.2f}".format(acc_list[i][0] * 100)
print(info)
else:
- acc_list = [(class_acc[i], unique_eval_labels[i],
- self.id_to_labels[unique_eval_labels[i]])
- for i in range(len(class_acc))]
+ acc_list = [
+ (class_acc[i], unique_eval_labels[i], self.id_to_labels[unique_eval_labels[i]])
+ for i in range(len(class_acc))
+ ]
- print('\n')
+ print("\n")
for i in range(len(class_acc)):
- info = '{:>20} ({:<3}) Acc: '.format(acc_list[i][2], acc_list[i][1])
- info += '{:.2f}'.format(acc_list[i][0] * 100)
+ info = "{:>20} ({:<3}) Acc: ".format(acc_list[i][2], acc_list[i][1])
+ info += "{:.2f}".format(acc_list[i][0] * 100)
print(info)
diff --git a/PW_FT_classification/src/algorithms/utils.py b/PW_FT_classification/src/algorithms/utils.py
index 11a14da33..60804fb26 100644
--- a/PW_FT_classification/src/algorithms/utils.py
+++ b/PW_FT_classification/src/algorithms/utils.py
@@ -1,5 +1,6 @@
from sklearn.metrics import confusion_matrix
+
def acc(preds, labels):
"""
Calculate the accuracy metrics based on predictions and true labels.
diff --git a/PW_FT_classification/src/datasets/__init__.py b/PW_FT_classification/src/datasets/__init__.py
index a04dcce14..8ef4d4c1a 100644
--- a/PW_FT_classification/src/datasets/__init__.py
+++ b/PW_FT_classification/src/datasets/__init__.py
@@ -1 +1 @@
-from .custom import *
\ No newline at end of file
+from .custom import *
diff --git a/PW_FT_classification/src/datasets/custom.py b/PW_FT_classification/src/datasets/custom.py
index cbac698a2..38152e741 100644
--- a/PW_FT_classification/src/datasets/custom.py
+++ b/PW_FT_classification/src/datasets/custom.py
@@ -1,29 +1,33 @@
# Import necessary libraries
import os
from glob import glob
+
import numpy as np
import pandas as pd
+import pytorch_lightning as pl
import torch
from PIL import Image
+from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
-from torch.utils.data import Dataset, DataLoader
-import pytorch_lightning as pl
# Exportable class names for external use
-__all__ = [
- 'Custom_Crop'
-]
-
-# Define the allowed image extensions
-IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
-
-def has_file_allowed_extension(filename: str, extensions: tuple) -> bool:
- """Checks if a file is an allowed extension."""
- return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
-
-def is_image_file(filename: str) -> bool:
- """Checks if a file is an allowed image extension."""
- return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+__all__ = ["Custom_Crop"]
+
+# Define the allowed image extensions
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
+def has_file_allowed_extension(filename: str, extensions: tuple) -> bool:
+ """Checks if a file is an allowed extension."""
+ return filename.lower().endswith(
+ extensions if isinstance(extensions, str) else tuple(extensions)
+ )
+
+
+def is_image_file(filename: str) -> bool:
+ """Checks if a file is an allowed image extension."""
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
# Define normalization mean and standard deviation for image preprocessing
mean = [0.485, 0.456, 0.406]
@@ -31,21 +35,22 @@ def is_image_file(filename: str) -> bool:
# Define data transformations for training and validation datasets
data_transforms = {
- 'train': transforms.Compose([
- transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0), ratio=(0.8, 1.2)),
- transforms.RandomHorizontalFlip(p=0.5),
- transforms.RandomVerticalFlip(p=0.5),
- transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
- transforms.ToTensor(),
- transforms.Normalize(mean, std)
- ]),
- 'val': transforms.Compose([
- transforms.Resize((224, 224)),
- transforms.ToTensor(),
- transforms.Normalize(mean, std)
- ]),
+ "train": transforms.Compose(
+ [
+ transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0), ratio=(0.8, 1.2)),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomVerticalFlip(p=0.5),
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ ]
+ ),
+ "val": transforms.Compose(
+ [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean, std)]
+ ),
}
+
class Custom_Base_DS(Dataset):
"""
Base dataset class for handling custom datasets.
@@ -80,13 +85,18 @@ def load_data(self):
if self.predict:
# Load data for prediction
# self.data = glob(os.path.join(self.img_root,"*.{}".format(self.extension)))
- self.data = [os.path.join(dp, f) for dp, dn, filenames in os.walk(self.img_root) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename
+ self.data = [
+ os.path.join(dp, f)
+ for dp, dn, filenames in os.walk(self.img_root)
+ for f in filenames
+ if is_image_file(f)
+ ] # dp: directory path, dn: directory name, f: filename
else:
# Load data for training/validation
- self.data = list(self.ann['path'])
- self.label_ids = list(self.ann['classification'])
- self.labels = list(self.ann['label'])
- print('Number of images loaded: ', len(self.data))
+ self.data = list(self.ann["path"])
+ self.label_ids = list(self.ann["classification"])
+ self.labels = list(self.ann["label"])
+ print("Number of images loaded: ", len(self.data))
def class_counts_cal(self):
"""
@@ -120,8 +130,8 @@ def __getitem__(self, index):
file_id = self.data[index]
file_dir = os.path.join(self.img_root, file_id) if not self.predict else file_id
- with open(file_dir, 'rb') as f:
- sample = Image.open(f).convert('RGB')
+ with open(file_dir, "rb") as f:
+ sample = Image.open(f).convert("RGB")
if self.transform is not None:
sample = self.transform(sample)
@@ -142,7 +152,7 @@ class Custom_Crop_DS(Custom_Base_DS):
Inherits from Custom_Base_DS and includes specific handling for cropped data.
"""
- def __init__(self, rootdir, dset='train', transform=None):
+ def __init__(self, rootdir, dset="train", transform=None):
"""
Initialize the Custom_Crop_DS with the dataset directory, type, and transformations.
@@ -151,12 +161,17 @@ def __init__(self, rootdir, dset='train', transform=None):
dset (str): Type of dataset (train, val, test, predict).
transform (callable, optional): Transformations to be applied to each data sample.
"""
- self.predict = dset == 'predict'
+ self.predict = dset == "predict"
super().__init__(rootdir=rootdir, transform=transform, predict=self.predict)
- self.img_root = rootdir if self.predict else os.path.join(self.rootdir, 'cropped_resized')
+ self.img_root = rootdir if self.predict else os.path.join(self.rootdir, "cropped_resized")
if not self.predict:
- self.ann = pd.read_csv(os.path.join(self.rootdir, 'cropped_resized', '{}_annotations_cropped.csv'
- .format('test' if dset == 'test' else dset)))
+ self.ann = pd.read_csv(
+ os.path.join(
+ self.rootdir,
+ "cropped_resized",
+ "{}_annotations_cropped.csv".format("test" if dset == "test" else dset),
+ )
+ )
self.load_data()
@@ -178,27 +193,41 @@ def __init__(self, conf):
"""
super().__init__()
self._log_hyperparams = True
- self.id_to_labels = None # We don't need this for evaluations. We should save this in model weights in the future
+ self.id_to_labels = None # We don't need this for evaluations. We should save this in model weights in the future
self.train_class_counts = None
self.conf = conf
- print('Loading datasets...')
+ print("Loading datasets...")
# Load datasets for different modes (training, validation, testing, prediction)
if self.conf.predict:
- self.dset_pr = self.ds(rootdir=self.conf.predict_root, dset='predict', transform=data_transforms['val'])
+ self.dset_pr = self.ds(
+ rootdir=self.conf.predict_root, dset="predict", transform=data_transforms["val"]
+ )
elif self.conf.test:
- self.dset_te = self.ds(rootdir=self.conf.dataset_root, dset='test', transform=data_transforms['val'])
- self.id_to_labels = {i: l for i, l in np.unique(pd.Series(zip(self.dset_te.label_ids, self.dset_te.labels)))}
+ self.dset_te = self.ds(
+ rootdir=self.conf.dataset_root, dset="test", transform=data_transforms["val"]
+ )
+ self.id_to_labels = {
+ i: l
+ for i, l in np.unique(pd.Series(zip(self.dset_te.label_ids, self.dset_te.labels)))
+ }
else:
- self.dset_tr = self.ds(rootdir=self.conf.dataset_root, dset='train', transform=data_transforms['train'])
- self.dset_val = self.ds(rootdir=self.conf.dataset_root, dset='val', transform=data_transforms['val'])
-
- self.id_to_labels = {i: l for i, l in np.unique(pd.Series(zip(self.dset_tr.label_ids, self.dset_tr.labels)))}
+ self.dset_tr = self.ds(
+ rootdir=self.conf.dataset_root, dset="train", transform=data_transforms["train"]
+ )
+ self.dset_val = self.ds(
+ rootdir=self.conf.dataset_root, dset="val", transform=data_transforms["val"]
+ )
+
+ self.id_to_labels = {
+ i: l
+ for i, l in np.unique(pd.Series(zip(self.dset_tr.label_ids, self.dset_tr.labels)))
+ }
# Calculate class counts and label mappings
self.unique_label_ids, self.train_class_counts = self.dset_tr.class_counts_cal()
- print('Datasets loaded.')
+ print("Datasets loaded.")
def train_dataloader(self):
"""
@@ -207,7 +236,14 @@ def train_dataloader(self):
Returns:
DataLoader: DataLoader for the training dataset.
"""
- return DataLoader(self.dset_tr, batch_size=self.conf.batch_size, shuffle=True, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False)
+ return DataLoader(
+ self.dset_tr,
+ batch_size=self.conf.batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=self.conf.num_workers,
+ drop_last=False,
+ )
def val_dataloader(self):
"""
@@ -216,7 +252,14 @@ def val_dataloader(self):
Returns:
DataLoader: DataLoader for the validation dataset.
"""
- return DataLoader(self.dset_val, batch_size=self.conf.batch_size, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False)
+ return DataLoader(
+ self.dset_val,
+ batch_size=self.conf.batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=self.conf.num_workers,
+ drop_last=False,
+ )
def test_dataloader(self):
"""
@@ -225,7 +268,14 @@ def test_dataloader(self):
Returns:
DataLoader: DataLoader for the testing dataset.
"""
- return DataLoader(self.dset_te, batch_size=256, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False)
+ return DataLoader(
+ self.dset_te,
+ batch_size=256,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=self.conf.num_workers,
+ drop_last=False,
+ )
def predict_dataloader(self):
"""
@@ -234,7 +284,14 @@ def predict_dataloader(self):
Returns:
DataLoader: DataLoader for the prediction dataset.
"""
- return DataLoader(self.dset_pr, batch_size=64, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False)
+ return DataLoader(
+ self.dset_pr,
+ batch_size=64,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=self.conf.num_workers,
+ drop_last=False,
+ )
class Custom_Crop(Custom_Base):
diff --git a/PW_FT_classification/src/models/__init__.py b/PW_FT_classification/src/models/__init__.py
index 70be09a93..9e8b093ac 100644
--- a/PW_FT_classification/src/models/__init__.py
+++ b/PW_FT_classification/src/models/__init__.py
@@ -1 +1 @@
-from .plain_resnet import *
\ No newline at end of file
+from .plain_resnet import *
diff --git a/PW_FT_classification/src/models/plain_resnet.py b/PW_FT_classification/src/models/plain_resnet.py
index d664c91b7..e8fd61113 100644
--- a/PW_FT_classification/src/models/plain_resnet.py
+++ b/PW_FT_classification/src/models/plain_resnet.py
@@ -1,26 +1,25 @@
-import os
import copy
+import os
from collections import OrderedDict
+
import torch
import torch.nn as nn
-from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.resnet import *
-
+from torchvision.models.resnet import BasicBlock, Bottleneck
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_state_dict_from_url
# Exportable class names for external use
-__all__ = [
- 'PlainResNetClassifier'
-]
+__all__ = ["PlainResNetClassifier"]
model_urls = {
- 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
- 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth'
+ "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
+ "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
}
+
class ResNetBackbone(ResNet):
"""
Custom ResNet backbone class for feature extraction.
@@ -93,7 +92,7 @@ class PlainResNetClassifier(nn.Module):
Extends nn.Module and provides a complete ResNet-based classifier, including feature extraction and classification layers.
"""
- name = 'PlainResNetClassifier'
+ name = "PlainResNetClassifier"
def __init__(self, num_cls=10, num_layers=18):
"""
@@ -123,17 +122,19 @@ def setup_net(self):
if self.num_layers == 18:
block = BasicBlock
layers = [2, 2, 2, 2]
- #self.pretrained_weights = ResNet18_Weights.IMAGENET1K_V1
- self.pretrained_weights = state_dict = load_state_dict_from_url(model_urls['resnet18'],
- progress=True)
+ # self.pretrained_weights = ResNet18_Weights.IMAGENET1K_V1
+ self.pretrained_weights = state_dict = load_state_dict_from_url(
+ model_urls["resnet18"], progress=True
+ )
elif self.num_layers == 50:
block = Bottleneck
layers = [3, 4, 6, 3]
- #self.pretrained_weights = ResNet50_Weights.IMAGENET1K_V1
- self.pretrained_weights = state_dict = load_state_dict_from_url(model_urls['resnet50'],
- progress=True)
+ # self.pretrained_weights = ResNet50_Weights.IMAGENET1K_V1
+ self.pretrained_weights = state_dict = load_state_dict_from_url(
+ model_urls["resnet50"], progress=True
+ )
else:
- raise Exception('ResNet Type not supported.')
+ raise Exception("ResNet Type not supported.")
# Constructing the feature extractor and classifier
self.feature = ResNetBackbone(block, layers, **kwargs)
@@ -151,10 +152,14 @@ def feat_init(self):
Initialize the feature extractor with pre-trained weights.
"""
# Load pre-trained weights and adjust for the current model
- #init_weights = self.pretrained_weights.get_state_dict(progress=True)
+ # init_weights = self.pretrained_weights.get_state_dict(progress=True)
init_weights = self.pretrained_weights
- init_weights = OrderedDict({k.replace('module.', '').replace('feature.', ''): init_weights[k]
- for k in init_weights})
+ init_weights = OrderedDict(
+ {
+ k.replace("module.", "").replace("feature.", ""): init_weights[k]
+ for k in init_weights
+ }
+ )
# Load the weights into the feature extractor
self.feature.load_state_dict(init_weights, strict=False)
@@ -165,5 +170,5 @@ def feat_init(self):
missing_keys = self_keys - load_keys
unused_keys = load_keys - self_keys
- print('missing keys: {}'.format(sorted(list(missing_keys))))
- print('unused_keys: {}'.format(sorted(list(unused_keys))))
+ print("missing keys: {}".format(sorted(list(missing_keys))))
+ print("unused_keys: {}".format(sorted(list(unused_keys))))
diff --git a/PW_FT_classification/src/utils/batch_detection_cropping.py b/PW_FT_classification/src/utils/batch_detection_cropping.py
index a89301486..6332a3942 100644
--- a/PW_FT_classification/src/utils/batch_detection_cropping.py
+++ b/PW_FT_classification/src/utils/batch_detection_cropping.py
@@ -3,16 +3,20 @@
""" Demo for batch detection, cropping and resizing"""
-#%%
-# PyTorch imports
+# %%
+# PyTorch imports
import torch
-# Importing the model, dataset, transformations and utility functions from PytorchWildlife
-from PytorchWildlife.models import detection as pw_detection
-from PytorchWildlife.data import transforms as pw_trans
-from PytorchWildlife.data import datasets as pw_data
+
# Importing the utility function for saving cropped images
from src.utils import utils
+from PytorchWildlife.data import datasets as pw_data
+from PytorchWildlife.data import transforms as pw_trans
+
+# Importing the model, dataset, transformations and utility functions from PytorchWildlife
+from PytorchWildlife.models import detection as pw_detection
+
+
def batch_detection_cropping(folder_path, output_path, annotation_file):
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -29,5 +33,4 @@ def batch_detection_cropping(folder_path, output_path, annotation_file):
return crop_annotation_path
-
# %%
diff --git a/PW_FT_classification/src/utils/data_splitting.py b/PW_FT_classification/src/utils/data_splitting.py
index f667c6247..f5b37a855 100644
--- a/PW_FT_classification/src/utils/data_splitting.py
+++ b/PW_FT_classification/src/utils/data_splitting.py
@@ -1,20 +1,22 @@
## DATA SPLITTING
+import os
+
import pandas as pd
from sklearn.model_selection import train_test_split
-import os
from tqdm import tqdm
+
def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1):
"""
Create stratified training, validation, and testing splits.
-
+
Args:
- csv_path (str): Path to the csv containing the annotations.
- output_folder (str): Destination directory to save the annotation split csv files.
- test_size (float): Proportion of the dataset to include in the test split.
- val_size (float): Proportion of the training dataset to include in the validation split.
-
+
Returns:
- A tuple of DataFrames: (train_set, val_set, test_set)
- Saves the splits into separate csv files in the output_folder.
@@ -22,41 +24,46 @@ def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1):
# Load the data from the csv file
data = pd.read_csv(csv_path)
# Separate the features and the targets
- X = data[['path','label']]
- y = data['classification']
-
+ X = data[["path", "label"]]
+ y = data["classification"]
+
# First split to separate out the test set
- X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=test_size, stratify=y, random_state=42)
-
+ X_temp, X_test, y_temp, y_test = train_test_split(
+ X, y, test_size=test_size, stratify=y, random_state=42
+ )
+
# Adjust val_size to account for the initial split
val_size_adjusted = val_size / (1 - test_size)
-
+
# Second split to separate out the validation set
- X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=val_size_adjusted, stratify=y_temp, random_state=42)
-
+ X_train, X_val, y_train, y_val = train_test_split(
+ X_temp, y_temp, test_size=val_size_adjusted, stratify=y_temp, random_state=42
+ )
+
# Combine features, labels, and classification back into dataframes
train_set = pd.concat([X_train.reset_index(drop=True), y_train.reset_index(drop=True)], axis=1)
val_set = pd.concat([X_val.reset_index(drop=True), y_val.reset_index(drop=True)], axis=1)
test_set = pd.concat([X_test.reset_index(drop=True), y_test.reset_index(drop=True)], axis=1)
-
+
# Create the output directory in case that it does not exist
os.makedirs(output_folder, exist_ok=True)
# Save the splits to new CSV files
- train_set.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False)
- val_set.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False)
- test_set.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False)
+ train_set.to_csv(os.path.join(output_folder, "train_annotations.csv"), index=False)
+ val_set.to_csv(os.path.join(output_folder, "val_annotations.csv"), index=False)
+ test_set.to_csv(os.path.join(output_folder, "test_annotations.csv"), index=False)
# Return the dataframes
return train_set, val_set, test_set
+
def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None):
"""
Splits the dataset into train, validation, and test sets based on location, ensuring that:
1. All images from the same location are in the same split.
2. The split is random among the locations.
3. Saves the split datasets into CSV files.
-
+
Parameters:
- csv_path: Path to the csv containing the annotations.
- train_size, val_size, test_size: float, proportions of the dataset to include in the train, validation, and test splits.
@@ -67,27 +74,31 @@ def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, ra
# Calculate train size based on val and test size
train_size = 1.0 - val_size - test_size
-
+
# Get unique locations
- unique_locations = data['Location'].unique()
+ unique_locations = data["Location"].unique()
# Split locations into train and temp (temporary holding for val and test)
- train_locs, temp_locs = train_test_split(unique_locations, train_size=train_size, random_state=random_state)
-
+ train_locs, temp_locs = train_test_split(
+ unique_locations, train_size=train_size, random_state=random_state
+ )
+
# Adjust the proportions for val and test based on the remaining locations
temp_size = val_size / (val_size + test_size)
- val_locs, test_locs = train_test_split(temp_locs, train_size=temp_size, random_state=random_state)
-
+ val_locs, test_locs = train_test_split(
+ temp_locs, train_size=temp_size, random_state=random_state
+ )
+
# Allocate images to train, validation, and test sets based on their location
- train_data = data[data['Location'].isin(train_locs)]
- val_data = data[data['Location'].isin(val_locs)]
- test_data = data[data['Location'].isin(test_locs)]
-
+ train_data = data[data["Location"].isin(train_locs)]
+ val_data = data[data["Location"].isin(val_locs)]
+ test_data = data[data["Location"].isin(test_locs)]
+
# Save the datasets to CSV files
- train_data.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False)
- val_data.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False)
- test_data.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False)
-
+ train_data.to_csv(os.path.join(output_folder, "train_annotations.csv"), index=False)
+ val_data.to_csv(os.path.join(output_folder, "val_annotations.csv"), index=False)
+ test_data.to_csv(os.path.join(output_folder, "test_annotations.csv"), index=False)
+
# Return the split datasets
return train_data, val_data, test_data
@@ -98,7 +109,7 @@ def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_
1. All images from the same sequence are in the same split.
2. The split is random among the sequences.
3. Saves the split datasets into CSV files.
-
+
Parameters:
- csv_path: Path to the csv containing the annotations.
- train_size, val_size, test_size: float, proportions of the dataset to include in the train, validation, and test splits.
@@ -108,40 +119,44 @@ def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_
data = pd.read_csv(csv_path)
# Convert 'Photo_Time' from string to datetime
- data['Photo_Time'] = pd.to_datetime(data['Photo_Time'])
+ data["Photo_Time"] = pd.to_datetime(data["Photo_Time"])
# Calculate train size based on val and test size
train_size = 1 - val_size - test_size
-
+
# Sort by 'Photo_Time' to ensure chronological order
- data = data.sort_values(by=['Photo_Time']).reset_index(drop=True)
+ data = data.sort_values(by=["Photo_Time"]).reset_index(drop=True)
# Group photos into sequences based on a 30-second interval
- time_groups = data.groupby(pd.Grouper(key='Photo_Time', freq='30S'))
+ time_groups = data.groupby(pd.Grouper(key="Photo_Time", freq="30S"))
# Assign unique sequence IDs to each group
for s, i in tqdm(enumerate(time_groups.indices.values())):
- data.loc[i, 'Seq_ID'] = int(s)
+ data.loc[i, "Seq_ID"] = int(s)
# Get unique sequence IDs
- unique_seq_ids = data['Seq_ID'].unique()
-
+ unique_seq_ids = data["Seq_ID"].unique()
+
# Split sequence IDs into train and temp (temporary holding for val and test)
- train_seq_ids, temp_seq_ids = train_test_split(unique_seq_ids, train_size=train_size, random_state=random_state)
-
+ train_seq_ids, temp_seq_ids = train_test_split(
+ unique_seq_ids, train_size=train_size, random_state=random_state
+ )
+
# Adjust the proportions for val and test based on the remaining sequences
temp_size = val_size / (val_size + test_size)
- val_seq_ids, test_seq_ids = train_test_split(temp_seq_ids, train_size=temp_size, random_state=random_state)
-
+ val_seq_ids, test_seq_ids = train_test_split(
+ temp_seq_ids, train_size=temp_size, random_state=random_state
+ )
+
# Allocate images to train, validation, and test sets based on their sequence ID
- train_data = data[data['Seq_ID'].isin(train_seq_ids)]
- val_data = data[data['Seq_ID'].isin(val_seq_ids)]
- test_data = data[data['Seq_ID'].isin(test_seq_ids)]
-
+ train_data = data[data["Seq_ID"].isin(train_seq_ids)]
+ val_data = data[data["Seq_ID"].isin(val_seq_ids)]
+ test_data = data[data["Seq_ID"].isin(test_seq_ids)]
+
# Save the datasets to CSV files
- train_data.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False)
- val_data.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False)
- test_data.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False)
+ train_data.to_csv(os.path.join(output_folder, "train_annotations.csv"), index=False)
+ val_data.to_csv(os.path.join(output_folder, "val_annotations.csv"), index=False)
+ test_data.to_csv(os.path.join(output_folder, "test_annotations.csv"), index=False)
# Return the split datasets
return train_data, val_data, test_data
diff --git a/PW_FT_classification/src/utils/utils.py b/PW_FT_classification/src/utils/utils.py
index 0b359a541..7707dc9b4 100644
--- a/PW_FT_classification/src/utils/utils.py
+++ b/PW_FT_classification/src/utils/utils.py
@@ -1,9 +1,11 @@
import os
-import pandas as pd
+
import cv2
-import supervision as sv
-from PIL import Image
import numpy as np
+import pandas as pd
+import supervision as sv
+from PIL import Image
+
def save_crop_images(results, output_dir, original_csv_path, overwrite=False):
"""
@@ -29,44 +31,51 @@ def save_crop_images(results, output_dir, original_csv_path, overwrite=False):
# Prepare a list to store new records for the new CSV
new_records = []
-
+
os.makedirs(output_dir, exist_ok=True)
with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
for entry in results:
# Process the data if the name of the file is in the dataframe
- if os.path.basename(entry["img_id"]) in original_df['path'].values:
- for i, (xyxy, cat) in enumerate(zip(entry["detections"].xyxy, entry["detections"].class_id)):
+ if os.path.basename(entry["img_id"]) in original_df["path"].values:
+ for i, (xyxy, cat) in enumerate(
+ zip(entry["detections"].xyxy, entry["detections"].class_id)
+ ):
cropped_img = sv.crop_image(
image=np.array(Image.open(entry["img_id"]).convert("RGB")), xyxy=xyxy
)
new_img_name = "{}_{}_{}".format(
- int(cat), i, entry["img_id"].rsplit(os.sep, 1)[1])
+ int(cat), i, entry["img_id"].rsplit(os.sep, 1)[1]
+ )
sink.save_image(
- image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR),
- image_name=new_img_name
- ),
-
+ image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), image_name=new_img_name
+ ),
+
# Save the crop into a new csv
- image_name = entry['img_id']
-
- classification_id = original_df[original_df['path'].str.endswith(image_name.split(os.sep)[-1])]['classification'].values[0]
- classification_name = original_df[original_df['path'].str.endswith(image_name.split(os.sep)[-1])]['label'].values[0]
+ image_name = entry["img_id"]
+
+ classification_id = original_df[
+ original_df["path"].str.endswith(image_name.split(os.sep)[-1])
+ ]["classification"].values[0]
+ classification_name = original_df[
+ original_df["path"].str.endswith(image_name.split(os.sep)[-1])
+ ]["label"].values[0]
# Add record to the new CSV data
- new_records.append({
- 'path': new_img_name,
- 'classification': classification_id,
- 'label': classification_name
- })
+ new_records.append(
+ {
+ "path": new_img_name,
+ "classification": classification_id,
+ "label": classification_name,
+ }
+ )
# Create a DataFrame from the new records
new_df = pd.DataFrame(new_records)
# Define the path for the new CSV file
- new_file_name = "{}_cropped.csv".format(original_csv_path.split(os.sep)[-1].split('.')[0])
+ new_file_name = "{}_cropped.csv".format(original_csv_path.split(os.sep)[-1].split(".")[0])
new_csv_path = os.path.join(output_dir, new_file_name)
# Save the new DataFrame to CSV
new_df.to_csv(new_csv_path, index=False)
return new_csv_path
-
diff --git a/PW_FT_detection/config.yaml b/PW_FT_detection/config.yaml
index d4c9f5485..05357a0ad 100644
--- a/PW_FT_detection/config.yaml
+++ b/PW_FT_detection/config.yaml
@@ -25,7 +25,3 @@ save_json: True
plot: True
device_val: 0
batch_size_val: 12
-
-
-
-
diff --git a/PW_FT_detection/main.py b/PW_FT_detection/main.py
index d3f7cd140..d6afac3bb 100644
--- a/PW_FT_detection/main.py
+++ b/PW_FT_detection/main.py
@@ -1,79 +1,79 @@
-from ultralytics import YOLO, RTDETR
-from utils import get_model_path
-from munch import Munch
-import yaml
import os
-def main(config:str='./config.yaml'):
+import yaml
+from munch import Munch
+from ultralytics import RTDETR, YOLO
+from utils import get_model_path
+
+def main(config: str = "./config.yaml"):
# Load and set configurations from the YAML file
with open(config) as f:
cfg = Munch(yaml.load(f, Loader=yaml.FullLoader))
- if cfg.resume:
- model_path = cfg.weights
- else:
- model_path = get_model_path(cfg.model_name)
-
- if cfg.model == "YOLO":
- model = YOLO(model_path)
- elif cfg.model == "RTDETR":
- model = RTDETR(model_path)
- else:
- raise ValueError("Model not supported")
+ if cfg.resume:
+ model_path = cfg.weights
+ else:
+ model_path = get_model_path(cfg.model_name)
+
+ if cfg.model == "YOLO":
+ model = YOLO(model_path)
+ elif cfg.model == "RTDETR":
+ model = RTDETR(model_path)
+ else:
+ raise ValueError("Model not supported")
with open(cfg.data) as f:
data = yaml.safe_load(f)
if not os.path.isabs(data["path"]):
data["path"] = os.path.abspath(data["path"])
- with open(cfg.data, 'w') as f:
+ with open(cfg.data, "w") as f:
yaml.dump(data, f)
-
+
model.info()
if cfg.task == "train":
-
results = model.train(
data=cfg.data,
- epochs=cfg.epochs,
+ epochs=cfg.epochs,
imgsz=cfg.imgsz,
- device=cfg.device_train,
- save_period=cfg.save_period,
- workers=cfg.workers,
- batch=cfg.batch_size_train,
+ device=cfg.device_train,
+ save_period=cfg.save_period,
+ workers=cfg.workers,
+ batch=cfg.batch_size_train,
val=cfg.val,
project=f"runs/train_{cfg.exp_name}",
name="exp",
patience=cfg.patience,
- resume=cfg.resume
- )
+ resume=cfg.resume,
+ )
if cfg.task == "validation":
-
metrics = model.val(
data=cfg.data,
- save_json=cfg.save_json,
- plots=cfg.plots,
- device=cfg.device_val,
- project=f'runs/val_{cfg.exp_name}',
+ save_json=cfg.save_json,
+ plots=cfg.plots,
+ device=cfg.device_val,
+ project=f"runs/val_{cfg.exp_name}",
name="exp",
- batch=cfg.batch_size_val)
+ batch=cfg.batch_size_val,
+ )
- metrics.box.map # map50-95
+ metrics.box.map # map50-95
metrics.box.map50 # map50
metrics.box.map75 # map75
- metrics.box.maps # a list contains map50-95 of each category
+ metrics.box.maps # a list contains map50-95 of each category
if cfg.task == "inference":
-
results = model(cfg.test_data)
save_path = os.path.join("inference_results", cfg.exp_name)
os.makedirs(save_path, exist_ok=True)
for i in range(len(results)):
results[i][0].boxes
- results[i].save(filename=os.path.join(save_path,f"inference_{i}.jpg"))
+ results[i].save(filename=os.path.join(save_path, f"inference_{i}.jpg"))
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/PW_FT_detection/requirements.txt b/PW_FT_detection/requirements.txt
index f98f2933b..bc470a5c0 100644
--- a/PW_FT_detection/requirements.txt
+++ b/PW_FT_detection/requirements.txt
@@ -1,4 +1,4 @@
PytorchWildlife
ultralytics
munch
-wget
\ No newline at end of file
+wget
diff --git a/PW_FT_detection/utils.py b/PW_FT_detection/utils.py
index 9cc6b262f..d20cd00eb 100644
--- a/PW_FT_detection/utils.py
+++ b/PW_FT_detection/utils.py
@@ -1,11 +1,12 @@
import os
-import wget
+
import torch
+import wget
-def get_model_path(model):
- if model == "MDV6-yolov9-c":
- url = "https://zenodo.org/records/14567879/files/MDV6b-yolov9c.pt?download=1"
+def get_model_path(model):
+ if model == "MDV6-yolov9-c":
+ url = "https://zenodo.org/records/14567879/files/MDV6b-yolov9c.pt?download=1"
model_name = "MDV6b-yolov9c.pt"
elif model == "MDV6-yolov9-e":
url = "https://zenodo.org/records/14567879/files/MDV6-yolov9e.pt?download=1"
@@ -20,12 +21,14 @@ def get_model_path(model):
url = "https://zenodo.org/records/14567879/files/MDV6b-rtdetrl.pt?download=1"
model_name = "MDV6b-rtdetrl.pt"
else:
- raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c')
+ raise ValueError(
+ "Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c"
+ )
if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", model_name)):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
model_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
model_path = os.path.join(torch.hub.get_dir(), "checkpoints", model_name)
-
- return model_path
\ No newline at end of file
+
+ return model_path
diff --git a/PytorchWildlife/__init__.py b/PytorchWildlife/__init__.py
index a59275f7a..1c4f1c9ed 100644
--- a/PytorchWildlife/__init__.py
+++ b/PytorchWildlife/__init__.py
@@ -1,4 +1,5 @@
import importlib.metadata as importlib_metadata
+
try:
# This will read version from pyproject.toml
__version__ = importlib_metadata.version(__package__ or __name__)
@@ -7,4 +8,4 @@
from .data import *
from .models import *
-from .utils import *
\ No newline at end of file
+from .utils import *
diff --git a/PytorchWildlife/data/__init__.py b/PytorchWildlife/data/__init__.py
index 82ef5ae2d..cfa56d0b9 100644
--- a/PytorchWildlife/data/__init__.py
+++ b/PytorchWildlife/data/__init__.py
@@ -1,2 +1,2 @@
from .datasets import *
-from .transforms import *
\ No newline at end of file
+from .transforms import *
diff --git a/PytorchWildlife/data/datasets.py b/PytorchWildlife/data/datasets.py
index ce280f02f..de0629103 100644
--- a/PytorchWildlife/data/datasets.py
+++ b/PytorchWildlife/data/datasets.py
@@ -3,10 +3,11 @@
import os
from glob import glob
-from PIL import Image, ImageFile
+
import numpy as np
import supervision as sv
import torch
+from PIL import Image, ImageFile
from torch.utils.data import Dataset
# To handle truncated images during loading
@@ -15,23 +16,28 @@
# Making the DetectionImageFolder class available for import from this module
__all__ = [
"DetectionImageFolder",
- ]
-
-# Define the allowed image extensions
-IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
-
-def has_file_allowed_extension(filename: str, extensions: tuple) -> bool:
- """Checks if a file is an allowed extension."""
- return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
-
-def is_image_file(filename: str) -> bool:
- """Checks if a file is an allowed image extension."""
- return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+]
+
+# Define the allowed image extensions
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
+def has_file_allowed_extension(filename: str, extensions: tuple) -> bool:
+ """Checks if a file is an allowed extension."""
+ return filename.lower().endswith(
+ extensions if isinstance(extensions, str) else tuple(extensions)
+ )
+
+
+def is_image_file(filename: str) -> bool:
+ """Checks if a file is an allowed image extension."""
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
class ImageFolder(Dataset):
"""
A PyTorch Dataset for loading images from a specified directory.
- Each item in the dataset is a tuple containing the image data,
+ Each item in the dataset is a tuple containing the image data,
the image's path, and the original size of the image.
"""
@@ -46,7 +52,12 @@ def __init__(self, image_dir, transform=None):
super(ImageFolder, self).__init__()
self.image_dir = image_dir
self.transform = transform
- self.images = [os.path.join(dp, f) for dp, dn, filenames in os.walk(image_dir) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename
+ self.images = [
+ os.path.join(dp, f)
+ for dp, dn, filenames in os.walk(image_dir)
+ for f in filenames
+ if is_image_file(f)
+ ] # dp: directory path, dn: directory name, f: filename
def __getitem__(self, idx) -> tuple:
"""
@@ -59,7 +70,7 @@ def __getitem__(self, idx) -> tuple:
tuple: Contains the image data, the image's path, and its original size.
"""
pass
-
+
def __len__(self) -> int:
"""
Returns the total number of images in the dataset.
@@ -69,10 +80,11 @@ def __len__(self) -> int:
"""
return len(self.images)
+
class ClassificationImageFolder(ImageFolder):
"""
A PyTorch Dataset for loading images from a specified directory.
- Each item in the dataset is a tuple containing the image data,
+ Each item in the dataset is a tuple containing the image data,
the image's path, and the original size of the image.
"""
@@ -101,7 +113,7 @@ def __getitem__(self, idx) -> tuple:
# Load and convert image to RGB
img = Image.open(img_path).convert("RGB")
-
+
# Apply transformation if specified
if self.transform:
img = self.transform(img)
@@ -112,7 +124,7 @@ def __getitem__(self, idx) -> tuple:
class DetectionImageFolder(ImageFolder):
"""
A PyTorch Dataset for loading images from a specified directory.
- Each item in the dataset is a tuple containing the image data,
+ Each item in the dataset is a tuple containing the image data,
the image's path, and the original size of the image.
"""
@@ -142,23 +154,23 @@ def __getitem__(self, idx) -> tuple:
# Load and convert image to RGB
img = Image.open(img_path).convert("RGB")
img_size_ori = img.size[::-1]
-
+
# Apply transformation if specified
if self.transform:
img = self.transform(img)
return img, img_path, torch.tensor(img_size_ori)
-
+
# TODO: Under development for efficiency improvement
class DetectionCrops(Dataset):
-
def __init__(self, detection_results, transform=None, path_head=None, animal_cls_id=0):
-
self.detection_results = detection_results
self.transform = transform
self.path_head = path_head
- self.animal_cls_id = animal_cls_id # This determines which detection class id represents animals.
+ self.animal_cls_id = (
+ animal_cls_id # This determines which detection class id represents animals.
+ )
self.img_ids = []
self.xyxys = []
@@ -188,11 +200,10 @@ def __getitem__(self, idx) -> tuple:
xyxy = self.xyxys[idx]
img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id
-
+
# Load and crop image with supervision
- img = sv.crop_image(np.array(Image.open(img_path).convert("RGB")),
- xyxy=xyxy)
-
+ img = sv.crop_image(np.array(Image.open(img_path).convert("RGB")), xyxy=xyxy)
+
# Apply transformation if specified
if self.transform:
img = self.transform(Image.fromarray(img))
@@ -200,4 +211,4 @@ def __getitem__(self, idx) -> tuple:
return img, img_path
def __len__(self) -> int:
- return len(self.img_ids)
\ No newline at end of file
+ return len(self.img_ids)
diff --git a/PytorchWildlife/data/transforms.py b/PytorchWildlife/data/transforms.py
index c9bceab28..ba1e27e11 100644
--- a/PytorchWildlife/data/transforms.py
+++ b/PytorchWildlife/data/transforms.py
@@ -3,23 +3,28 @@
import numpy as np
import torch
-from torchvision import transforms
-import torchvision.transforms as T
import torch.nn.functional as F
+import torchvision.transforms as T
from PIL import Image
+from torchvision import transforms
# Making the provided classes available for import from this module
-__all__ = [
- "MegaDetector_v5_Transform",
- "Classification_Inference_Transform"
-]
-
-
-def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True, stride=32) -> torch.Tensor:
+__all__ = ["MegaDetector_v5_Transform", "Classification_Inference_Transform"]
+
+
+def letterbox(
+ im,
+ new_shape=(640, 640),
+ color=(114, 114, 114),
+ auto=False,
+ scaleFill=False,
+ scaleup=True,
+ stride=32,
+) -> torch.Tensor:
"""
Resize and pad an image to a desired shape while keeping the aspect ratio unchanged.
- This function is commonly used in object detection tasks to prepare images for models like YOLOv5.
+ This function is commonly used in object detection tasks to prepare images for models like YOLOv5.
It resizes the image to fit into the new shape with the correct aspect ratio and then pads the rest.
Args:
@@ -64,19 +69,26 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scale
dw /= 2
dh /= 2
-
+
# Resize image
if shape[::-1] != new_unpad:
- resize_transform = T.Resize(new_unpad[::-1], interpolation=T.InterpolationMode.BILINEAR,
- antialias=False)
+ resize_transform = T.Resize(
+ new_unpad[::-1], interpolation=T.InterpolationMode.BILINEAR, antialias=False
+ )
im = resize_transform(im)
# Pad image
- padding = (int(round(dw - 0.1)), int(round(dw + 0.1)), int(round(dh + 0.1)), int(round(dh - 0.1)))
- im = F.pad(im*255.0, padding, value=114)/255.0
+ padding = (
+ int(round(dw - 0.1)),
+ int(round(dw + 0.1)),
+ int(round(dh + 0.1)),
+ int(round(dh - 0.1)),
+ )
+ im = F.pad(im * 255.0, padding, value=114) / 255.0
return im
+
class MegaDetector_v5_Transform:
"""
A transformation class to preprocess images for the MegaDetector v5 model.
@@ -113,16 +125,18 @@ def __call__(self, np_img) -> torch.Tensor:
np_img = torch.from_numpy(np_img).float()
np_img /= 255.0
- # Resize and pad the image using a customized letterbox function.
+ # Resize and pad the image using a customized letterbox function.
img = letterbox(np_img, new_shape=self.target_size, stride=self.stride, auto=False)
return img
+
class Classification_Inference_Transform:
"""
A transformation class to preprocess images for classification inference.
This includes resizing, normalization, and conversion to a tensor.
"""
+
# Normalization constants
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
@@ -135,12 +149,14 @@ def __init__(self, target_size=224, **kwargs):
target_size (int): Desired size for the height and width after resizing.
"""
# Define the sequence of transformations
- self.trans = transforms.Compose([
- # transforms.Resize((target_size, target_size)),
- transforms.Resize((target_size, target_size), **kwargs),
- transforms.ToTensor(),
- transforms.Normalize(self.mean, self.std)
- ])
+ self.trans = transforms.Compose(
+ [
+ # transforms.Resize((target_size, target_size)),
+ transforms.Resize((target_size, target_size), **kwargs),
+ transforms.ToTensor(),
+ transforms.Normalize(self.mean, self.std),
+ ]
+ )
def __call__(self, img) -> torch.Tensor:
"""
diff --git a/PytorchWildlife/models/__init__.py b/PytorchWildlife/models/__init__.py
index 89e334c07..b134b2934 100644
--- a/PytorchWildlife/models/__init__.py
+++ b/PytorchWildlife/models/__init__.py
@@ -1,2 +1,2 @@
from .classification import *
-from .detection import *
\ No newline at end of file
+from .detection import *
diff --git a/PytorchWildlife/models/classification/__init__.py b/PytorchWildlife/models/classification/__init__.py
index e450ab78b..2963d6ec7 100644
--- a/PytorchWildlife/models/classification/__init__.py
+++ b/PytorchWildlife/models/classification/__init__.py
@@ -1,3 +1,3 @@
+from .base_classifier import *
from .resnet_base import *
from .timm_base import *
-from .base_classifier import *
\ No newline at end of file
diff --git a/PytorchWildlife/models/classification/base_classifier.py b/PytorchWildlife/models/classification/base_classifier.py
index c3a60ed36..5dc57f0cf 100644
--- a/PytorchWildlife/models/classification/base_classifier.py
+++ b/PytorchWildlife/models/classification/base_classifier.py
@@ -1,4 +1,3 @@
-
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
@@ -7,10 +6,12 @@
# Making the PlainResNetInference class available for import from this module
__all__ = ["BaseClassifierInference"]
+
class BaseClassifierInference(nn.Module):
"""
Inference module for the PlainResNet Classifier.
"""
+
def __init__(self):
super(BaseClassifierInference, self).__init__()
pass
diff --git a/PytorchWildlife/models/classification/resnet_base/__init__.py b/PytorchWildlife/models/classification/resnet_base/__init__.py
index 10c6d0ba2..bb2ae189b 100644
--- a/PytorchWildlife/models/classification/resnet_base/__init__.py
+++ b/PytorchWildlife/models/classification/resnet_base/__init__.py
@@ -1,5 +1,5 @@
+from .amazon import *
from .base_classifier import *
+from .custom_weights import *
from .opossum import *
-from .amazon import *
from .serengeti import *
-from .custom_weights import *
\ No newline at end of file
diff --git a/PytorchWildlife/models/classification/resnet_base/amazon.py b/PytorchWildlife/models/classification/resnet_base/amazon.py
index 8b228b04e..92d237322 100644
--- a/PytorchWildlife/models/classification/resnet_base/amazon.py
+++ b/PytorchWildlife/models/classification/resnet_base/amazon.py
@@ -2,11 +2,10 @@
# Licensed under the MIT License.
import torch
+
from .base_classifier import PlainResNetInference
-__all__ = [
- "AI4GAmazonRainforest"
-]
+__all__ = ["AI4GAmazonRainforest"]
class AI4GAmazonRainforest(PlainResNetInference):
@@ -14,48 +13,48 @@ class AI4GAmazonRainforest(PlainResNetInference):
Amazon Ranforest Animal Classifier that inherits from PlainResNetInference.
This classifier is specialized for recognizing 36 different animals in the Amazon Rainforest.
"""
-
+
# Image size for the Opossum classifier
IMAGE_SIZE = 224
-
+
# Class names for prediction
CLASS_NAMES = {
- 0: 'Dasyprocta',
- 1: 'Bos',
- 2: 'Pecari',
- 3: 'Mazama',
- 4: 'Cuniculus',
- 5: 'Leptotila',
- 6: 'Human',
- 7: 'Aramides',
- 8: 'Tinamus',
- 9: 'Eira',
- 10: 'Crax',
- 11: 'Procyon',
- 12: 'Capra',
- 13: 'Dasypus',
- 14: 'Sciurus',
- 15: 'Crypturellus',
- 16: 'Tamandua',
- 17: 'Proechimys',
- 18: 'Leopardus',
- 19: 'Equus',
- 20: 'Columbina',
- 21: 'Nyctidromus',
- 22: 'Ortalis',
- 23: 'Emballonura',
- 24: 'Odontophorus',
- 25: 'Geotrygon',
- 26: 'Metachirus',
- 27: 'Catharus',
- 28: 'Cerdocyon',
- 29: 'Momotus',
- 30: 'Tapirus',
- 31: 'Canis',
- 32: 'Furnarius',
- 33: 'Didelphis',
- 34: 'Sylvilagus',
- 35: 'Unknown'
+ 0: "Dasyprocta",
+ 1: "Bos",
+ 2: "Pecari",
+ 3: "Mazama",
+ 4: "Cuniculus",
+ 5: "Leptotila",
+ 6: "Human",
+ 7: "Aramides",
+ 8: "Tinamus",
+ 9: "Eira",
+ 10: "Crax",
+ 11: "Procyon",
+ 12: "Capra",
+ 13: "Dasypus",
+ 14: "Sciurus",
+ 15: "Crypturellus",
+ 16: "Tamandua",
+ 17: "Proechimys",
+ 18: "Leopardus",
+ 19: "Equus",
+ 20: "Columbina",
+ 21: "Nyctidromus",
+ 22: "Ortalis",
+ 23: "Emballonura",
+ 24: "Odontophorus",
+ 25: "Geotrygon",
+ 26: "Metachirus",
+ 27: "Catharus",
+ 28: "Cerdocyon",
+ 29: "Momotus",
+ 30: "Tapirus",
+ 31: "Canis",
+ 32: "Furnarius",
+ 33: "Didelphis",
+ 34: "Sylvilagus",
+ 35: "Unknown",
}
def __init__(self, weights=None, device="cpu", pretrained=True, version="v2"):
@@ -71,29 +70,34 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version="v2"):
# If pretrained, use the provided URL to fetch the weights
if pretrained:
- if version == 'v1':
+ if version == "v1":
url = "https://zenodo.org/records/10042023/files/AI4GAmazonClassification_v0.0.0.ckpt?download=1"
- elif version == 'v2':
- url = "https://zenodo.org/records/14252214/files/AI4GAmazonDeforestationv2?download=1"
+ elif version == "v2":
+ url = (
+ "https://zenodo.org/records/14252214/files/AI4GAmazonDeforestationv2?download=1"
+ )
else:
url = None
- super(AI4GAmazonRainforest, self).__init__(weights=weights, device=device,
- num_cls=36, num_layers=50, url=url)
+ super(AI4GAmazonRainforest, self).__init__(
+ weights=weights, device=device, num_cls=36, num_layers=50, url=url
+ )
- def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]:
+ def results_generation(
+ self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None
+ ) -> list[dict]:
"""
Generate results for classification.
Args:
logits (torch.Tensor): Output tensor from the model.
img_ids (str): Image identifier.
- id_strip (str): stiping string for better image id saving.
+ id_strip (str): stripping string for better image id saving.
Returns:
dict: Dictionary containing image ID, prediction, and confidence score.
"""
-
+
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
@@ -108,5 +112,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)
-
+
return results
diff --git a/PytorchWildlife/models/classification/resnet_base/base_classifier.py b/PytorchWildlife/models/classification/resnet_base/base_classifier.py
index 965ff4389..fc18dfbe4 100644
--- a/PytorchWildlife/models/classification/resnet_base/base_classifier.py
+++ b/PytorchWildlife/models/classification/resnet_base/base_classifier.py
@@ -1,20 +1,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import numpy as np
-from PIL import Image
-from tqdm import tqdm
from collections import OrderedDict
+import numpy as np
import torch
import torch.nn as nn
-from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
+from PIL import Image
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
+from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
+from tqdm import tqdm
-from ..base_classifier import BaseClassifierInference
+from ....data import datasets as pw_data
from ....data import transforms as pw_trans
-from ....data import datasets as pw_data
+from ..base_classifier import BaseClassifierInference
# Making the PlainResNetInference class available for import from this module
__all__ = ["PlainResNetInference"]
@@ -24,6 +24,7 @@ class ResNetBackbone(ResNet):
"""
Custom ResNet Backbone that extracts features from input images.
"""
+
def _forward_impl(self, x):
# Following the ResNet structure to extract features
x = self.conv1(x)
@@ -45,6 +46,7 @@ class PlainResNetClassifier(nn.Module):
"""
Basic ResNet Classifier that uses a custom ResNet backbone.
"""
+
name = "PlainResNetClassifier"
def __init__(self, num_cls=1, num_layers=50):
@@ -88,8 +90,12 @@ def feat_init(self):
Initialize the features using pretrained weights.
"""
init_weights = self.pretrained_weights.get_state_dict(progress=True)
- init_weights = OrderedDict({k.replace("module.", "").replace("feature.", ""): init_weights[k]
- for k in init_weights})
+ init_weights = OrderedDict(
+ {
+ k.replace("module.", "").replace("feature.", ""): init_weights[k]
+ for k in init_weights
+ }
+ )
self.feature.load_state_dict(init_weights, strict=False)
# Print missing and unused keys for debugging purposes
load_keys = set(init_weights.keys())
@@ -104,8 +110,12 @@ class PlainResNetInference(BaseClassifierInference):
"""
Inference module for the PlainResNet Classifier.
"""
+
IMAGE_SIZE = None
- def __init__(self, num_cls=36, num_layers=50, weights=None, device="cpu", url=None, transform=None):
+
+ def __init__(
+ self, num_cls=36, num_layers=50, weights=None, device="cpu", url=None, transform=None
+ ):
super(PlainResNetInference, self).__init__()
self.device = device
self.net = PlainResNetClassifier(num_cls=num_cls, num_layers=num_layers)
@@ -122,16 +132,20 @@ def __init__(self, num_cls=36, num_layers=50, weights=None, device="cpu", url=No
if transform:
self.transform = transform
else:
- self.transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE)
+ self.transform = pw_trans.Classification_Inference_Transform(
+ target_size=self.IMAGE_SIZE
+ )
- def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]:
+ def results_generation(
+ self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None
+ ) -> list[dict]:
"""
- Process logits to produce final results.
+ Process logits to produce final results.
Args:
logits (torch.Tensor): Logits from the network.
img_ids (list[str]): List of image paths.
- id_strip (str): Stripping string for better image ID saving.
+ id_strip (str): Stripping string for better image ID saving.
Returns:
list[dict]: List of dictionaries containing the results.
@@ -158,26 +172,19 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip=
"""
if data_path:
- dataset = pw_data.ImageFolder(
- data_path,
- transform=self.transform,
- path_head='.'
- )
+ dataset = pw_data.ImageFolder(data_path, transform=self.transform, path_head=".")
elif det_results:
- dataset = pw_data.DetectionCrops(
- det_results,
- transform=self.transform,
- path_head='.'
- )
+ dataset = pw_data.DetectionCrops(det_results, transform=self.transform, path_head=".")
else:
raise Exception("Need data for inference.")
- dataloader = DataLoader(dataset, batch_size=32, shuffle=False,
- pin_memory=True, num_workers=4, drop_last=False)
+ dataloader = DataLoader(
+ dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4, drop_last=False
+ )
total_logits = []
total_paths = []
- with tqdm(total=len(dataloader)) as pbar:
+ with tqdm(total=len(dataloader)) as pbar:
for batch in dataloader:
imgs, paths = batch
imgs = imgs.to(self.device)
diff --git a/PytorchWildlife/models/classification/resnet_base/custom_weights.py b/PytorchWildlife/models/classification/resnet_base/custom_weights.py
index 7247abcf5..8027e2b70 100644
--- a/PytorchWildlife/models/classification/resnet_base/custom_weights.py
+++ b/PytorchWildlife/models/classification/resnet_base/custom_weights.py
@@ -2,11 +2,10 @@
# Licensed under the MIT License.
import torch
+
from .base_classifier import PlainResNetInference
-__all__ = [
- "CustomWeights"
-]
+__all__ = ["CustomWeights"]
class CustomWeights(PlainResNetInference):
@@ -14,11 +13,10 @@ class CustomWeights(PlainResNetInference):
Custom Weight Classifier that inherits from PlainResNetInference.
This classifier can load any model that was based on the PytorchWildlife finetuning tool.
"""
-
+
# Image size for the classifier
IMAGE_SIZE = 224
-
def __init__(self, weights=None, class_names=None, device="cpu"):
"""
Initialize the CustomWeights Classifier.
@@ -30,10 +28,13 @@ def __init__(self, weights=None, class_names=None, device="cpu"):
"""
self.CLASS_NAMES = class_names
self.num_cls = len(self.CLASS_NAMES)
- super(CustomWeights, self).__init__(weights=weights, device=device,
- num_cls=self.num_cls, num_layers=50, url=None)
+ super(CustomWeights, self).__init__(
+ weights=weights, device=device, num_cls=self.num_cls, num_layers=50, url=None
+ )
- def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]:
+ def results_generation(
+ self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None
+ ) -> list[dict]:
"""
Generate results for classification.
@@ -45,7 +46,7 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
Returns:
list[dict]: List of dictionaries containing image ID, prediction, and confidence score.
"""
-
+
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
@@ -60,5 +61,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)
-
+
return results
diff --git a/PytorchWildlife/models/classification/resnet_base/opossum.py b/PytorchWildlife/models/classification/resnet_base/opossum.py
index 453533394..5217e18aa 100644
--- a/PytorchWildlife/models/classification/resnet_base/opossum.py
+++ b/PytorchWildlife/models/classification/resnet_base/opossum.py
@@ -2,11 +2,10 @@
# Licensed under the MIT License.
import torch
+
from .base_classifier import PlainResNetInference
-__all__ = [
- "AI4GOpossum"
-]
+__all__ = ["AI4GOpossum"]
class AI4GOpossum(PlainResNetInference):
@@ -14,15 +13,12 @@ class AI4GOpossum(PlainResNetInference):
Opossum Classifier that inherits from PlainResNetInference.
This classifier is specialized for distinguishing between Opossums and Non-opossums.
"""
-
+
# Image size for the Opossum classifier
IMAGE_SIZE = 224
-
+
# Class names for prediction
- CLASS_NAMES = {
- 0: "Non-opossum",
- 1: "Opossum"
- }
+ CLASS_NAMES = {0: "Non-opossum", 1: "Opossum"}
def __init__(self, weights=None, device="cpu", pretrained=True):
"""
@@ -40,17 +36,20 @@ def __init__(self, weights=None, device="cpu", pretrained=True):
else:
url = None
- super(AI4GOpossum, self).__init__(weights=weights, device=device,
- num_cls=1, num_layers=50, url=url)
+ super(AI4GOpossum, self).__init__(
+ weights=weights, device=device, num_cls=1, num_layers=50, url=url
+ )
- def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]:
+ def results_generation(
+ self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None
+ ) -> list[dict]:
"""
Generate results for classification.
Args:
logits (torch.Tensor): Output tensor from the model.
img_ids (list): List of image identifier.
- id_strip (str): stiping string for better image id saving.
+ id_strip (str): stripping string for better image id saving.
Returns:
dict: Dictionary containing image ID, prediction, and confidence score.
@@ -66,5 +65,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
r["class_id"] = pred
r["confidence"] = prob.item() if pred == 1 else (1 - prob.item())
results.append(r)
-
+
return results
diff --git a/PytorchWildlife/models/classification/resnet_base/serengeti.py b/PytorchWildlife/models/classification/resnet_base/serengeti.py
index e0820a02f..63d0ee697 100644
--- a/PytorchWildlife/models/classification/resnet_base/serengeti.py
+++ b/PytorchWildlife/models/classification/resnet_base/serengeti.py
@@ -2,11 +2,10 @@
# Licensed under the MIT License.
import torch
+
from .base_classifier import PlainResNetInference
-__all__ = [
- "AI4GSnapshotSerengeti"
-]
+__all__ = ["AI4GSnapshotSerengeti"]
class AI4GSnapshotSerengeti(PlainResNetInference):
@@ -14,22 +13,22 @@ class AI4GSnapshotSerengeti(PlainResNetInference):
Snapshot Serengeti Animal Classifier that inherits from PlainResNetInference.
This classifier is specialized for recognizing 9 different animals and has 1 'other' class.
"""
-
+
# Image size for the Opossum classifier
IMAGE_SIZE = 224
-
+
# Class names for prediction
CLASS_NAMES = {
- 0: 'wildebeest',
- 1: 'guineafowl',
- 2: 'zebra',
- 3: 'buffalo',
- 4: 'gazellethomsons',
- 5: 'gazellegrants',
- 6: 'warthog',
- 7: 'impala',
- 8: 'hyenaspotted',
- 9: 'other'
+ 0: "wildebeest",
+ 1: "guineafowl",
+ 2: "zebra",
+ 3: "buffalo",
+ 4: "gazellethomsons",
+ 5: "gazellegrants",
+ 6: "warthog",
+ 7: "impala",
+ 8: "hyenaspotted",
+ 9: "other",
}
def __init__(self, weights=None, device="cpu", pretrained=True):
@@ -48,22 +47,25 @@ def __init__(self, weights=None, device="cpu", pretrained=True):
else:
url = None
- super(AI4GSnapshotSerengeti, self).__init__(weights=weights, device=device,
- num_cls=10, num_layers=18, url=url)
+ super(AI4GSnapshotSerengeti, self).__init__(
+ weights=weights, device=device, num_cls=10, num_layers=18, url=url
+ )
- def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]:
+ def results_generation(
+ self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None
+ ) -> list[dict]:
"""
Generate results for classification.
Args:
logits (torch.Tensor): Output tensor from the model.
img_ids (str): Image identifier.
- id_strip (str): stiping string for better image id saving.
+ id_strip (str): stripping string for better image id saving.
Returns:
dict: Dictionary containing image ID, prediction, and confidence score.
"""
-
+
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
@@ -78,5 +80,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)
-
+
return results
diff --git a/PytorchWildlife/models/classification/timm_base/DFNE.py b/PytorchWildlife/models/classification/timm_base/DFNE.py
index 1cb6d71b0..40674d6c1 100644
--- a/PytorchWildlife/models/classification/timm_base/DFNE.py
+++ b/PytorchWildlife/models/classification/timm_base/DFNE.py
@@ -10,47 +10,53 @@
from .base_classifier import TIMM_BaseClassifierInference
-__all__ = [
- "DFNE"
-]
+__all__ = ["DFNE"]
+
class DFNE(TIMM_BaseClassifierInference):
"""
Base detector class for dinov2 classifier. This class provides utility methods
- for loading the model, performing single and batch image classifications, and
- formatting results. Make sure the appropriate file for the model weights has been
+ for loading the model, performing single and batch image classifications, and
+ formatting results. Make sure the appropriate file for the model weights has been
downloaded to the "models" folder before running DFNE.
"""
+
BACKBONE = "vit_large_patch14_dinov2.lvd142m"
MODEL_NAME = "dfne_weights_v1_0.pth"
IMAGE_SIZE = 182
CLASS_NAMES = {
- 0: "American Marten",
- 1: "Bird sp.",
- 2: "Black Bear",
- 3: "Bobcat",
- 4: "Coyote",
- 5: "Domestic Cat",
- 6: "Domestic Cow",
- 7: "Domestic Dog",
- 8: "Fisher",
- 9: "Gray Fox",
- 10: "Gray Squirrel",
- 11: "Human",
- 12: "Moose",
- 13: "Mouse sp.",
- 14: "Opossum",
- 15: "Raccoon",
- 16: "Red Fox",
- 17: "Red Squirrel",
- 18: "Skunk",
- 19: "Snowshoe Hare",
- 20: "White-tailed Deer",
- 21: "Wild Boar",
- 22: "Wild Turkey",
- 23: "no-species"
- }
+ 0: "American Marten",
+ 1: "Bird sp.",
+ 2: "Black Bear",
+ 3: "Bobcat",
+ 4: "Coyote",
+ 5: "Domestic Cat",
+ 6: "Domestic Cow",
+ 7: "Domestic Dog",
+ 8: "Fisher",
+ 9: "Gray Fox",
+ 10: "Gray Squirrel",
+ 11: "Human",
+ 12: "Moose",
+ 13: "Mouse sp.",
+ 14: "Opossum",
+ 15: "Raccoon",
+ 16: "Red Fox",
+ 17: "Red Squirrel",
+ 18: "Skunk",
+ 19: "Snowshoe Hare",
+ 20: "White-tailed Deer",
+ 21: "Wild Boar",
+ 22: "Wild Turkey",
+ 23: "no-species",
+ }
def __init__(self, weights=None, device="cpu", transform=None):
- url = 'https://prod-is-usgs-sb-prod-publish.s3.amazonaws.com/67ae17fcd34e3f09c0e0f002/dfne_weights_v1_0.pth'
- super(DFNE, self).__init__(weights=weights, device=device, url=url, transform=transform, weights_key='model_state_dict')
\ No newline at end of file
+ url = "https://prod-is-usgs-sb-prod-publish.s3.amazonaws.com/67ae17fcd34e3f09c0e0f002/dfne_weights_v1_0.pth"
+ super(DFNE, self).__init__(
+ weights=weights,
+ device=device,
+ url=url,
+ transform=transform,
+ weights_key="model_state_dict",
+ )
diff --git a/PytorchWildlife/models/classification/timm_base/Deepfaune.py b/PytorchWildlife/models/classification/timm_base/Deepfaune.py
index a024729b7..8c931838a 100644
--- a/PytorchWildlife/models/classification/timm_base/Deepfaune.py
+++ b/PytorchWildlife/models/classification/timm_base/Deepfaune.py
@@ -9,39 +9,186 @@
# Import libraries
from torchvision.transforms.functional import InterpolationMode
-from .base_classifier import TIMM_BaseClassifierInference
+
from ....data import transforms as pw_trans
+from .base_classifier import TIMM_BaseClassifierInference
+
+__all__ = ["DeepfauneClassifier"]
-__all__ = [
- "DeepfauneClassifier"
-]
class DeepfauneClassifier(TIMM_BaseClassifierInference):
"""
Base detector class for dinov2 classifier. This class provides utility methods
- for loading the model, performing single and batch image classifications, and
- formatting results. Make sure the appropriate file for the model weights has been
+ for loading the model, performing single and batch image classifications, and
+ formatting results. Make sure the appropriate file for the model weights has been
downloaded to the "models" folder before running DFNE.
"""
+
BACKBONE = "vit_large_patch14_dinov2.lvd142m"
MODEL_NAME = "deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt"
IMAGE_SIZE = 182
- CLASS_NAMES={
- 'fr': ['bison', 'blaireau', 'bouquetin', 'castor', 'cerf', 'chamois', 'chat', 'chevre', 'chevreuil', 'chien', 'daim', 'ecureuil', 'elan', 'equide', 'genette', 'glouton', 'herisson', 'lagomorphe', 'loup', 'loutre', 'lynx', 'marmotte', 'micromammifere', 'mouflon', 'mouton', 'mustelide', 'oiseau', 'ours', 'ragondin', 'raton laveur', 'renard', 'renne', 'sanglier', 'vache'],
- 'en': ['bison', 'badger', 'ibex', 'beaver', 'red deer', 'chamois', 'cat', 'goat', 'roe deer', 'dog', 'fallow deer', 'squirrel', 'moose', 'equid', 'genet', 'wolverine', 'hedgehog', 'lagomorph', 'wolf', 'otter', 'lynx', 'marmot', 'micromammal', 'mouflon', 'sheep', 'mustelid', 'bird', 'bear', 'nutria', 'raccoon', 'fox', 'reindeer', 'wild boar', 'cow'],
- 'it': ['bisonte', 'tasso', 'stambecco', 'castoro', 'cervo', 'camoscio', 'gatto', 'capra', 'capriolo', 'cane', 'daino', 'scoiattolo', 'alce', 'equide', 'genetta', 'ghiottone', 'riccio', 'lagomorfo', 'lupo', 'lontra', 'lince', 'marmotta', 'micromammifero', 'muflone', 'pecora', 'mustelide', 'uccello', 'orso', 'nutria', 'procione', 'volpe', 'renna', 'cinghiale', 'mucca'],
- 'de': ['Bison', 'Dachs', 'Steinbock', 'Biber', 'Rothirsch', 'Gämse', 'Katze', 'Ziege', 'Rehwild', 'Hund', 'Damwild', 'Eichhörnchen', 'Elch', 'Equide', 'Ginsterkatze', 'Vielfraß', 'Igel', 'Lagomorpha', 'Wolf', 'Otter', 'Luchs', 'Murmeltier', 'Kleinsäuger', 'Mufflon', 'Schaf', 'Marder', 'Vogel', 'Bär', 'Nutria', 'Waschbär', 'Fuchs', 'Rentier', 'Wildschwein', 'Kuh'],
+ CLASS_NAMES = {
+ "fr": [
+ "bison",
+ "blaireau",
+ "bouquetin",
+ "castor",
+ "cerf",
+ "chamois",
+ "chat",
+ "chevre",
+ "chevreuil",
+ "chien",
+ "daim",
+ "ecureuil",
+ "elan",
+ "equide",
+ "genette",
+ "glouton",
+ "herisson",
+ "lagomorphe",
+ "loup",
+ "loutre",
+ "lynx",
+ "marmotte",
+ "micromammifere",
+ "mouflon",
+ "mouton",
+ "mustelide",
+ "oiseau",
+ "ours",
+ "ragondin",
+ "raton laveur",
+ "renard",
+ "renne",
+ "sanglier",
+ "vache",
+ ],
+ "en": [
+ "bison",
+ "badger",
+ "ibex",
+ "beaver",
+ "red deer",
+ "chamois",
+ "cat",
+ "goat",
+ "roe deer",
+ "dog",
+ "fallow deer",
+ "squirrel",
+ "moose",
+ "equid",
+ "genet",
+ "wolverine",
+ "hedgehog",
+ "lagomorph",
+ "wolf",
+ "otter",
+ "lynx",
+ "marmot",
+ "micromammal",
+ "mouflon",
+ "sheep",
+ "mustelid",
+ "bird",
+ "bear",
+ "nutria",
+ "raccoon",
+ "fox",
+ "reindeer",
+ "wild boar",
+ "cow",
+ ],
+ "it": [
+ "bisonte",
+ "tasso",
+ "stambecco",
+ "castoro",
+ "cervo",
+ "camoscio",
+ "gatto",
+ "capra",
+ "capriolo",
+ "cane",
+ "daino",
+ "scoiattolo",
+ "alce",
+ "equide",
+ "genetta",
+ "ghiottone",
+ "riccio",
+ "lagomorfo",
+ "lupo",
+ "lontra",
+ "lince",
+ "marmotta",
+ "micromammifero",
+ "muflone",
+ "pecora",
+ "mustelide",
+ "uccello",
+ "orso",
+ "nutria",
+ "procione",
+ "volpe",
+ "renna",
+ "cinghiale",
+ "mucca",
+ ],
+ "de": [
+ "Bison",
+ "Dachs",
+ "Steinbock",
+ "Biber",
+ "Rothirsch",
+ "Gämse",
+ "Katze",
+ "Ziege",
+ "Rehwild",
+ "Hund",
+ "Damwild",
+ "Eichhörnchen",
+ "Elch",
+ "Equide",
+ "Ginsterkatze",
+ "Vielfraß",
+ "Igel",
+ "Lagomorpha",
+ "Wolf",
+ "Otter",
+ "Luchs",
+ "Murmeltier",
+ "Kleinsäuger",
+ "Mufflon",
+ "Schaf",
+ "Marder",
+ "Vogel",
+ "Bär",
+ "Nutria",
+ "Waschbär",
+ "Fuchs",
+ "Rentier",
+ "Wildschwein",
+ "Kuh",
+ ],
}
-
- def __init__(self, weights=None, device="cpu", transform=None, class_name_lang='en'):
- url = 'https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt'
+ def __init__(self, weights=None, device="cpu", transform=None, class_name_lang="en"):
+ url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt"
self.CLASS_NAMES = {i: c for i, c in enumerate(self.CLASS_NAMES[class_name_lang])}
if transform is None:
- transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE,
- interpolation=InterpolationMode.BICUBIC,
- max_size=None,
- antialias=None)
- super(DeepfauneClassifier, self).__init__(weights=weights, device=device, url=url, transform=transform,
- weights_key='state_dict', weights_prefix='base_model.')
-
\ No newline at end of file
+ transform = pw_trans.Classification_Inference_Transform(
+ target_size=self.IMAGE_SIZE,
+ interpolation=InterpolationMode.BICUBIC,
+ max_size=None,
+ antialias=None,
+ )
+ super(DeepfauneClassifier, self).__init__(
+ weights=weights,
+ device=device,
+ url=url,
+ transform=transform,
+ weights_key="state_dict",
+ weights_prefix="base_model.",
+ )
diff --git a/PytorchWildlife/models/classification/timm_base/__init__.py b/PytorchWildlife/models/classification/timm_base/__init__.py
index 1938f933a..806a29c9f 100644
--- a/PytorchWildlife/models/classification/timm_base/__init__.py
+++ b/PytorchWildlife/models/classification/timm_base/__init__.py
@@ -1,3 +1,3 @@
from .base_classifier import *
from .Deepfaune import *
-from .DFNE import *
\ No newline at end of file
+from .DFNE import *
diff --git a/PytorchWildlife/models/classification/timm_base/base_classifier.py b/PytorchWildlife/models/classification/timm_base/base_classifier.py
index 682333b17..452faea39 100644
--- a/PytorchWildlife/models/classification/timm_base/base_classifier.py
+++ b/PytorchWildlife/models/classification/timm_base/base_classifier.py
@@ -3,28 +3,27 @@
# Import libraries
import os
-import wget
-import numpy as np
-import pandas as pd
-from tqdm import tqdm
-from PIL import Image
from collections import OrderedDict
+import numpy as np
+import pandas as pd
+import timm
import torch
+import wget
+from PIL import Image
from torch.utils.data import DataLoader
+from tqdm import tqdm
-import timm
-
-from ..base_classifier import BaseClassifierInference
+from ....data import datasets as pw_data
from ....data import transforms as pw_trans
-from ....data import datasets as pw_data
+from ..base_classifier import BaseClassifierInference
class TIMM_BaseClassifierInference(BaseClassifierInference):
"""
Base detector class for dinov2 classifier. This class provides utility methods
- for loading the model, performing single and batch image classifications, and
- formatting results. Make sure the appropriate file for the model weights has been
+ for loading the model, performing single and batch image classifications, and
+ formatting results. Make sure the appropriate file for the model weights has been
downloaded to the "models" folder before running DFNE.
"""
@@ -32,21 +31,28 @@ class TIMM_BaseClassifierInference(BaseClassifierInference):
MODEL_NAME = None
IMAGE_SIZE = None
- def __init__(self, weights=None, device="cpu", url=None, transform=None,
- weights_key='model_state_dict', weights_prefix=''):
+ def __init__(
+ self,
+ weights=None,
+ device="cpu",
+ url=None,
+ transform=None,
+ weights_key="model_state_dict",
+ weights_prefix="",
+ ):
"""
Initialize the model.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
- weights_key (str, optional):
+ weights_key (str, optional):
Key to fetch the model weights. Defaults to None.
- weights_prefix (str, optional):
+ weights_prefix (str, optional):
prefix of model weight keys. Defaults to None.
"""
super(TIMM_BaseClassifierInference, self).__init__()
@@ -55,30 +61,36 @@ def __init__(self, weights=None, device="cpu", url=None, transform=None,
if transform:
self.transform = transform
else:
- self.transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE)
+ self.transform = pw_trans.Classification_Inference_Transform(
+ target_size=self.IMAGE_SIZE
+ )
self._load_model(weights, url, weights_key, weights_prefix)
- def _load_model(self, weights=None, url=None, weights_key='model_state_dict', weights_prefix=''):
+ def _load_model(
+ self, weights=None, url=None, weights_key="model_state_dict", weights_prefix=""
+ ):
"""
Load TIMM based model weights
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. (defaults to None)
- url (str, optional):
+ url (str, optional):
url to the model weights. (defaults to None)
"""
self.predictor = timm.create_model(
- self.BACKBONE,
- pretrained = False,
- num_classes = len(self.CLASS_NAMES),
- dynamic_img_size = True
+ self.BACKBONE,
+ pretrained=False,
+ num_classes=len(self.CLASS_NAMES),
+ dynamic_img_size=True,
)
if url:
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
@@ -86,22 +98,27 @@ def _load_model(self, weights=None, url=None, weights_key='model_state_dict', we
elif weights is None:
raise Exception("Need weights for inference.")
- checkpoint = torch.load(
- f = weights,
- map_location = self.device,
- weights_only = False
- )[weights_key]
+ checkpoint = torch.load(f=weights, map_location=self.device, weights_only=False)[
+ weights_key
+ ]
- checkpoint = OrderedDict({k.replace("{}".format(weights_prefix), ""): checkpoint[k]
- for k in checkpoint})
+ checkpoint = OrderedDict(
+ {k.replace("{}".format(weights_prefix), ""): checkpoint[k] for k in checkpoint}
+ )
self.predictor.load_state_dict(checkpoint)
- print("Model loaded from {}".format(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)))
+ print(
+ "Model loaded from {}".format(
+ os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
+ )
+ )
self.predictor.to(self.device)
self.eval()
- def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]:
+ def results_generation(
+ self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None
+ ) -> list[dict]:
"""
Generate results for classification.
@@ -113,7 +130,7 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
Returns:
list[dict]: List of dictionaries containing image ID, prediction, and confidence score.
"""
-
+
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
@@ -128,17 +145,17 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip:
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)
-
+
return results
def single_image_classification(self, img, img_id=None, id_strip=None):
"""
Perform classification on a single image.
-
+
Args:
- img (str or ndarray):
+ img (str or ndarray):
Image path or ndarray of images.
- img_id (str, optional):
+ img_id (str, optional):
Image path or identifier.
id_strip (str, optional):
Whether to strip stings in id. Defaults to None.
@@ -155,14 +172,21 @@ def single_image_classification(self, img, img_id=None, id_strip=None):
logits = self.predictor(img.unsqueeze(0).to(self.device))
return self.results_generation(logits.cpu(), [img_id], id_strip=id_strip)[0]
- def batch_image_classification(self, data_path=None, det_results=None, id_strip=None,
- batch_size=32, num_workers=0, **kwargs):
+ def batch_image_classification(
+ self,
+ data_path=None,
+ det_results=None,
+ id_strip=None,
+ batch_size=32,
+ num_workers=0,
+ **kwargs
+ ):
"""
Perform classification on a batch of images.
-
+
Args:
- data_path (str):
- Path containing all images for inference. Defaults to None.
+ data_path (str):
+ Path containing all images for inference. Defaults to None.
det_results (dict):
Dirct outputs from detectors. Defaults to None.
id_strip (str, optional):
@@ -177,27 +201,26 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip=
"""
if data_path:
- dataset = pw_data.ImageFolder(
- data_path,
- transform=self.transform,
- path_head='.'
- )
+ dataset = pw_data.ImageFolder(data_path, transform=self.transform, path_head=".")
elif det_results:
- dataset = pw_data.DetectionCrops(
- det_results,
- transform=self.transform,
- path_head='.'
- )
+ dataset = pw_data.DetectionCrops(det_results, transform=self.transform, path_head=".")
else:
raise Exception("Need data for inference.")
- dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers,
- shuffle=False, pin_memory=True, drop_last=False, **kwargs)
-
+ dataloader = DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ shuffle=False,
+ pin_memory=True,
+ drop_last=False,
+ **kwargs
+ )
+
total_logits = []
total_paths = []
- with tqdm(total=len(dataloader)) as pbar:
+ with tqdm(total=len(dataloader)) as pbar:
for batch in dataloader:
imgs, paths = batch
imgs = imgs.to(self.device)
diff --git a/PytorchWildlife/models/detection/__init__.py b/PytorchWildlife/models/detection/__init__.py
index 7d42c184d..63b2c1840 100644
--- a/PytorchWildlife/models/detection/__init__.py
+++ b/PytorchWildlife/models/detection/__init__.py
@@ -1,4 +1,4 @@
-from .ultralytics_based import *
from .herdnet import *
+from .rtdetr_apache import *
+from .ultralytics_based import *
from .yolo_mit import *
-from .rtdetr_apache import *
\ No newline at end of file
diff --git a/PytorchWildlife/models/detection/base_detector.py b/PytorchWildlife/models/detection/base_detector.py
index 00cbfc5cc..63300e1c3 100644
--- a/PytorchWildlife/models/detection/base_detector.py
+++ b/PytorchWildlife/models/detection/base_detector.py
@@ -6,12 +6,13 @@
# Importing basic libraries
from torch import nn
+
class BaseDetector(nn.Module):
"""
Base detector class. This class provides utility methods for
loading the model, generating results, and performing single and batch image detections.
"""
-
+
# Placeholder class-level attributes to be defined in derived classes
IMAGE_SIZE = None
STRIDE = None
@@ -21,29 +22,28 @@ class BaseDetector(nn.Module):
def __init__(self, weights=None, device="cpu", url=None):
"""
Initialize the base detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
"""
super(BaseDetector, self).__init__()
self.device = device
-
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
@@ -64,20 +64,22 @@ def results_generation(self, preds, img_id: str, id_strip: str = None) -> dict:
"""
pass
- def single_image_detection(self, img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None) -> dict:
+ def single_image_detection(
+ self, img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None
+ ) -> dict:
"""
Perform detection on a single image.
-
+
Args:
- img (str or ndarray):
+ img (str or ndarray):
Image path or ndarray of images.
- img_size (tuple):
+ img_size (tuple):
Original image size.
- img_path (str):
+ img_path (str):
Image path or identifier.
- conf_thres (float, optional):
+ conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
Returns:
@@ -85,7 +87,9 @@ def single_image_detection(self, img, img_size=None, img_path=None, conf_thres=0
"""
pass
- def batch_image_detection(self, dataloader, conf_thres: float = 0.2, id_strip: str = None) -> list[dict]:
+ def batch_image_detection(
+ self, dataloader, conf_thres: float = 0.2, id_strip: str = None
+ ) -> list[dict]:
"""
Perform detection on a batch of images.
diff --git a/PytorchWildlife/models/detection/herdnet/__init__.py b/PytorchWildlife/models/detection/herdnet/__init__.py
index b98d2f565..bf695c132 100644
--- a/PytorchWildlife/models/detection/herdnet/__init__.py
+++ b/PytorchWildlife/models/detection/herdnet/__init__.py
@@ -1 +1 @@
-from .herdnet import *
\ No newline at end of file
+from .herdnet import *
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py b/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py
index 74c0b9052..34d34cd52 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py b/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py
index 649204b5d..e041cb081 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py b/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py
index 2eb05ba12..5c41bd96c 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
@@ -15,41 +14,39 @@
import os
+from typing import Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy
+import pandas
import PIL
import torch
-import pandas
-import numpy
-import matplotlib.pyplot as plt
import torchvision
from torchvision.utils import make_grid, save_image
-
-from typing import Union, Tuple
-
from tqdm import tqdm
from .types import BoundingBox
-__all__ = ['ImageToPatches']
+__all__ = ["ImageToPatches"]
+
class ImageToPatches:
- ''' Class to make patches from a tensor image '''
+ """Class to make patches from a tensor image"""
def __init__(
- self,
- image: Union[PIL.Image.Image, torch.Tensor],
- size: Tuple[int,int],
- overlap: int = 0
- ) -> None:
- '''
+ self, image: Union[PIL.Image.Image, torch.Tensor], size: Tuple[int, int], overlap: int = 0
+ ) -> None:
+ """
Args:
image (PIL.Image.Image or torch.Tensor): image, if tensor: (C,H,W)
size (tuple): patches size (height, width), in pixels
- overlap (int, optional): overlap between patches, in pixels.
+ overlap (int, optional): overlap between patches, in pixels.
Defaults to 0.
- '''
+ """
- assert isinstance(image, (PIL.Image.Image, torch.Tensor)), \
- 'image must be a PIL.Image.Image or a torch.Tensor instance'
+ assert isinstance(
+ image, (PIL.Image.Image, torch.Tensor)
+ ), "image must be a PIL.Image.Image or a torch.Tensor instance"
self.image = image
if isinstance(self.image, PIL.Image.Image):
@@ -57,32 +54,34 @@ def __init__(
self.size = size
self.overlap = overlap
-
+
def make_patches(self) -> torch.Tensor:
- ''' Make patches from the image
+ """Make patches from the image
- When the image division is not perfect, a zero-padding is performed
+ When the image division is not perfect, a zero-padding is performed
so that the patches have the same size.
Returns:
torch.Tensor:
patches of shape (B,C,H,W)
- '''
+ """
# patches' height & width
- height = min(self.image.size(1),self.size[0])
- width = min(self.image.size(2),self.size[1])
+ height = min(self.image.size(1), self.size[0])
+ width = min(self.image.size(2), self.size[1])
- # unfold on height
+ # unfold on height
height_fold = self.image.unfold(1, height, height - self.overlap)
# if non-perfect division on height
residual = self._img_residual(self.image.size(1), height, self.overlap)
if residual != 0:
# get the residual patch and add it to the fold
- remaining_height = torch.zeros(3, 1, self.image.size(2), height) # padding
- remaining_height[:,:,:,:residual] = self.image[:,-residual:,:].permute(0,2,1).unsqueeze(1)
+ remaining_height = torch.zeros(3, 1, self.image.size(2), height) # padding
+ remaining_height[:, :, :, :residual] = (
+ self.image[:, -residual:, :].permute(0, 2, 1).unsqueeze(1)
+ )
- height_fold = torch.cat((height_fold,remaining_height),dim=1)
+ height_fold = torch.cat((height_fold, remaining_height), dim=1)
# unfold on width
fold = height_fold.unfold(2, width, width - self.overlap)
@@ -90,20 +89,22 @@ def make_patches(self) -> torch.Tensor:
# if non-perfect division on width, the same
residual = self._img_residual(self.image.size(2), width, self.overlap)
if residual != 0:
- remaining_width = torch.zeros(3, fold.shape[1], 1, height, width) # padding
- remaining_width[:,:,:,:,:residual] = height_fold[:,:,-residual:,:].permute(0,1,3,2).unsqueeze(2)
+ remaining_width = torch.zeros(3, fold.shape[1], 1, height, width) # padding
+ remaining_width[:, :, :, :, :residual] = (
+ height_fold[:, :, -residual:, :].permute(0, 1, 3, 2).unsqueeze(2)
+ )
- fold = torch.cat((fold,remaining_width),dim=2)
+ fold = torch.cat((fold, remaining_width), dim=2)
- self._nrow , self._ncol = fold.shape[2] , fold.shape[1]
+ self._nrow, self._ncol = fold.shape[2], fold.shape[1]
# reshaping
- patches = fold.permute(1,2,0,3,4).reshape(-1,self.image.size(0),height,width)
+ patches = fold.permute(1, 2, 0, 3, 4).reshape(-1, self.image.size(0), height, width)
return patches
-
+
def get_limits(self) -> dict:
- ''' Get patches limits within the image frame
+ """Get patches limits within the image frame
When the image division is not perfect, the zero-padding is not
considered here. Hence, the limits are the true limits of patches
@@ -112,69 +113,64 @@ def get_limits(self) -> dict:
Returns:
dict:
a dict containing int as key and BoundingBox as value
- '''
+ """
# patches' height & width
- height = min(self.image.size(1),self.size[0])
- width = min(self.image.size(2),self.size[1])
+ height = min(self.image.size(1), self.size[0])
+ width = min(self.image.size(2), self.size[1])
# lists of pixels numbers
- y_pixels = torch.tensor(list(range(0,self.image.size(1)+1)))
- x_pixels = torch.tensor(list(range(0,self.image.size(2)+1)))
+ y_pixels = torch.tensor(list(range(0, self.image.size(1) + 1)))
+ x_pixels = torch.tensor(list(range(0, self.image.size(2) + 1)))
# cut into patches to get limits
- y_pixels_fold = y_pixels.unfold(0, height+1, height-self.overlap)
+ y_pixels_fold = y_pixels.unfold(0, height + 1, height - self.overlap)
y_mina = [int(patch[0]) for patch in y_pixels_fold]
y_maxa = [int(patch[-1]) for patch in y_pixels_fold]
- x_pixels_fold = x_pixels.unfold(0, width+1, width-self.overlap)
+ x_pixels_fold = x_pixels.unfold(0, width + 1, width - self.overlap)
x_mina = [int(patch[0]) for patch in x_pixels_fold]
x_maxa = [int(patch[-1]) for patch in x_pixels_fold]
# if non-perfect division on height
residual = self._img_residual(self.image.size(1), height, self.overlap)
if residual != 0:
- remaining_y = y_pixels[-residual-1:].unsqueeze(0)[0]
+ remaining_y = y_pixels[-residual - 1 :].unsqueeze(0)[0]
y_mina.append(int(remaining_y[0]))
y_maxa.append(int(remaining_y[-1]))
- # if non-perfect division on width
+ # if non-perfect division on width
residual = self._img_residual(self.image.size(2), width, self.overlap)
if residual != 0:
- remaining_x = x_pixels[-residual-1:].unsqueeze(0)[0]
+ remaining_x = x_pixels[-residual - 1 :].unsqueeze(0)[0]
x_mina.append(int(remaining_x[0]))
x_maxa.append(int(remaining_x[-1]))
-
+
i = 0
patches_limits = {}
- for y_min , y_max in zip(y_mina,y_maxa):
- for x_min , x_max in zip(x_mina,x_maxa):
- patches_limits[i] = BoundingBox(x_min,y_min,x_max,y_max)
+ for y_min, y_max in zip(y_mina, y_maxa):
+ for x_min, x_max in zip(x_mina, x_maxa):
+ patches_limits[i] = BoundingBox(x_min, y_min, x_max, y_max)
i += 1
-
+
return patches_limits
-
+
def show(self) -> None:
- ''' Show the grid of patches '''
+ """Show the grid of patches"""
- grid = make_grid(
- self.make_patches(),
- padding=50,
- nrow=self._nrow
- ).permute(1,2,0).numpy()
+ grid = make_grid(self.make_patches(), padding=50, nrow=self._nrow).permute(1, 2, 0).numpy()
plt.imshow(grid)
plt.show()
return grid
-
- def _img_residual(self, ims: int, ks: int, overlap: int) -> int:
+ def _img_residual(self, ims: int, ks: int, overlap: int) -> int:
ims, stride = int(ims), int(ks - overlap)
n = ims // stride
end = n * stride + overlap
-
+
residual = ims % stride
if end > ims:
@@ -182,6 +178,6 @@ def _img_residual(self, ims: int, ks: int, overlap: int) -> int:
residual = ims - (n * stride)
return residual
-
+
def __len__(self) -> int:
return len(self.get_limits())
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py b/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py
index bd650ac7b..a5e0ad96d 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
@@ -13,116 +12,117 @@
__license__ = "MIT License"
__version__ = "0.2.1"
-from typing import Union, Tuple
+from typing import Tuple, Union
+
+__all__ = ["Point", "BoundingBox"]
-__all__ = ['Point', 'BoundingBox']
class Point:
- ''' Class to define a Point object in a 2D Cartesian
+ """Class to define a Point object in a 2D Cartesian
coordinate system.
- '''
+ """
- def __init__(self, x: Union[int,float], y: Union[int,float]) -> None:
- '''
+ def __init__(self, x: Union[int, float], y: Union[int, float]) -> None:
+ """
Args:
x (int, float): x coordinate
y (int, float): y coordinate
- '''
+ """
- assert x >= 0 and y >= 0, f'Coordinates must be positives, got x={x} and y={y}'
+ assert x >= 0 and y >= 0, f"Coordinates must be positives, got x={x} and y={y}"
self.x = x
self.y = y
- self.area = 1 # always 1 pixel
-
+ self.area = 1 # always 1 pixel
+
# @property
# def area(self) -> int:
# ''' To get area '''
# return 1 # always 1 pixel
-
+
@property
- def get_tuple(self) -> Tuple[Union[int,float],Union[int,float]]:
- ''' To get point's coordinates in tuple '''
- return (self.x,self.y)
+ def get_tuple(self) -> Tuple[Union[int, float], Union[int, float]]:
+ """To get point's coordinates in tuple"""
+ return (self.x, self.y)
@property
def atype(self) -> str:
- ''' To get annotation type string '''
- return 'Point'
-
+ """To get annotation type string"""
+ return "Point"
+
def __repr__(self) -> str:
- return f'Point(x: {self.x}, y: {self.y})'
-
+ return f"Point(x: {self.x}, y: {self.y})"
+
def __eq__(self, other) -> bool:
- return all([
- self.x == other.x,
- self.y == other.y
- ])
+ return all([self.x == other.x, self.y == other.y])
+
class BoundingBox:
- ''' Class to define a BoundingBox object in a 2D Cartesian
+ """Class to define a BoundingBox object in a 2D Cartesian
coordinate system.
- '''
+ """
def __init__(
- self,
- x_min: Union[int,float],
- y_min: Union[int,float],
- x_max: Union[int,float],
- y_max: Union[int,float]
- ) -> None:
- '''
+ self,
+ x_min: Union[int, float],
+ y_min: Union[int, float],
+ x_max: Union[int, float],
+ y_max: Union[int, float],
+ ) -> None:
+ """
Args:
x_min (int, float): x bbox top-left coordinate
y_min (int, float): y bbox top-left coordinate
x_max (int, float): x bbox bottom-right coordinate
y_max (int, float): y bbox bottom-right coordinate
- '''
+ """
- assert all([c >= 0 for c in [x_min,y_min,x_max,y_max]]), \
- f'Coordinates must be positives, got x_min={x_min}, y_min={y_min}, ' \
- f'x_max={x_max} and y_max={y_max}'
+ assert all([c >= 0 for c in [x_min, y_min, x_max, y_max]]), (
+ f"Coordinates must be positives, got x_min={x_min}, y_min={y_min}, "
+ f"x_max={x_max} and y_max={y_max}"
+ )
- assert x_max >= x_min and y_max >= y_min, \
- 'Wrong bounding box coordinates.'
+ assert x_max >= x_min and y_max >= y_min, "Wrong bounding box coordinates."
self.x_min = x_min
self.y_min = y_min
self.x_max = x_max
self.y_max = y_max
-
+
@property
- def area(self) -> Union[int,float]:
- ''' To get bbox area '''
+ def area(self) -> Union[int, float]:
+ """To get bbox area"""
return max(0, self.width) * max(0, self.height)
-
+
@property
- def width(self) -> Union[int,float]:
- ''' To get bbox width '''
+ def width(self) -> Union[int, float]:
+ """To get bbox width"""
return max(0, self.x_max - self.x_min)
-
+
@property
- def height(self) -> Union[int,float]:
- ''' To get bbox height '''
+ def height(self) -> Union[int, float]:
+ """To get bbox height"""
return max(0, self.y_max - self.y_min)
-
+
@property
- def get_tuple(self) -> Tuple[Union[int,float],...]:
- ''' To get bbox coordinates in tuple type '''
- return (self.x_min,self.y_min,self.x_max,self.y_max)
-
+ def get_tuple(self) -> Tuple[Union[int, float], ...]:
+ """To get bbox coordinates in tuple type"""
+ return (self.x_min, self.y_min, self.x_max, self.y_max)
+
@property
def atype(self) -> str:
- ''' To get annotation type string '''
- return 'BoundingBox'
+ """To get annotation type string"""
+ return "BoundingBox"
def __repr__(self) -> str:
- return f'BoundingBox(x_min: {self.x_min}, y_min: {self.y_min}, x_max: {self.x_max}, y_max: {self.y_max})'
+ return f"BoundingBox(x_min: {self.x_min}, y_min: {self.y_min}, x_max: {self.x_max}, y_max: {self.y_max})"
def __eq__(self, other) -> bool:
- return all([
- self.x_min == other.x_min,
- self.y_min == other.y_min,
- self.x_max == other.x_max,
- self.y_max == other.y_max
- ])
\ No newline at end of file
+ return all(
+ [
+ self.x_min == other.x_min,
+ self.y_min == other.y_min,
+ self.x_max == other.x_max,
+ self.y_max == other.y_max,
+ ]
+ )
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py b/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py
index cb7dc2dc4..bd5ac7932 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
@@ -13,5 +12,5 @@
__license__ = "MIT License"
__version__ = "0.2.1"
-from .stitchers import *
from .lmds import *
+from .stitchers import *
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py b/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py
index cda1fe9fd..1e3d9d458 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
@@ -13,56 +12,52 @@
__license__ = "MIT License"
__version__ = "0.2.1"
-import torch
-import numpy
+from typing import List, Tuple
+import numpy
+import torch
import torch.nn.functional as F
-from typing import Tuple, List
-
-__all__ = ['LMDS', 'HerdNetLMDS']
+__all__ = ["LMDS", "HerdNetLMDS"]
class LMDS:
- ''' Local Maxima Detection Strategy
+ """Local Maxima Detection Strategy
Adapted and enhanced from https://github.com/dk-liang/FIDTM (author: dklinag)
- available under the MIT license '''
+ available under the MIT license"""
def __init__(
- self,
- kernel_size: tuple = (3,3),
- adapt_ts: float = 100.0/255.0,
- neg_ts: float = 0.1
- ) -> None:
- '''
+ self, kernel_size: tuple = (3, 3), adapt_ts: float = 100.0 / 255.0, neg_ts: float = 0.1
+ ) -> None:
+ """
Args:
kernel_size (tuple, optional): size of the kernel used to select local
maxima. Defaults to (3,3) (as in the paper).
adapt_ts (float, optional): adaptive threshold to select final points
from candidates. Defaults to 100.0/255.0 (as in the paper).
- neg_ts (float, optional): negative sample threshold used to define if
+ neg_ts (float, optional): negative sample threshold used to define if
an image is a negative sample or not. Defaults to 0.1 (as in the paper).
- '''
+ """
- assert kernel_size[0] == kernel_size[1], \
- f'The kernel shape must be a square, got {kernel_size[0]}x{kernel_size[1]}'
- assert not kernel_size[0] % 2 == 0, \
- f'The kernel size must be odd, got {kernel_size[0]}'
+ assert (
+ kernel_size[0] == kernel_size[1]
+ ), f"The kernel shape must be a square, got {kernel_size[0]}x{kernel_size[1]}"
+ assert not kernel_size[0] % 2 == 0, f"The kernel size must be odd, got {kernel_size[0]}"
self.kernel_size = tuple(kernel_size)
self.adapt_ts = adapt_ts
self.neg_ts = neg_ts
- def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]:
- '''
+ def __call__(self, est_map: torch.Tensor) -> Tuple[list, list, list, list]:
+ """
Args:
est_map (torch.Tensor): the estimated FIDT map
-
+
Returns:
Tuple[list,list,list,list]
counts, labels, scores and locations per batch
- '''
+ """
batch_size, classes = est_map.shape[:2]
b_counts, b_labels, b_scores, b_locs = [], [], [], []
@@ -72,7 +67,7 @@ def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]:
for c in range(classes):
count, loc, score = self._lmds(est_map[b][c])
counts.append(count)
- labels = [*labels, *[c+1]*count]
+ labels = [*labels, *[c + 1] * count]
scores = [*scores, *score]
locs = [*locs, *loc]
@@ -82,36 +77,36 @@ def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]:
b_locs.append(locs)
return b_counts, b_locs, b_labels, b_scores
-
+
def _local_max(self, est_map: torch.Tensor) -> torch.Tensor:
- ''' Shape: est_map = [B,C,H,W] '''
+ """Shape: est_map = [B,C,H,W]"""
pad = int(self.kernel_size[0] / 2)
- keep = torch.nn.functional.max_pool2d(est_map, kernel_size=self.kernel_size, stride=1, padding=pad)
+ keep = torch.nn.functional.max_pool2d(
+ est_map, kernel_size=self.kernel_size, stride=1, padding=pad
+ )
keep = (keep == est_map).float()
est_map = keep * est_map
return est_map
-
+
def _get_locs_and_scores(
- self,
- locs_map: torch.Tensor,
- scores_map: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- ''' Shapes: locs_map = [H,W] and scores_map = [H,W] '''
+ self, locs_map: torch.Tensor, scores_map: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Shapes: locs_map = [H,W] and scores_map = [H,W]"""
locs_map = locs_map.data.cpu().numpy()
scores_map = scores_map.data.cpu().numpy()
locs = []
scores = []
- for i, j in numpy.argwhere(locs_map == 1):
- locs.append((i,j))
+ for i, j in numpy.argwhere(locs_map == 1):
+ locs.append((i, j))
scores.append(scores_map[i][j])
-
+
return torch.Tensor(locs), torch.Tensor(scores)
-
+
def _lmds(self, est_map: torch.Tensor) -> Tuple[int, list, list]:
- ''' Shape: est_map = [H,W] '''
+ """Shape: est_map = [H,W]"""
est_map_max = torch.max(est_map).item()
@@ -132,22 +127,21 @@ def _lmds(self, est_map: torch.Tensor) -> Tuple[int, list, list]:
# locations and scores
locs, scores = self._get_locs_and_scores(
- est_map.squeeze(0).squeeze(0),
- scores_map.squeeze(0).squeeze(0)
- )
+ est_map.squeeze(0).squeeze(0), scores_map.squeeze(0).squeeze(0)
+ )
return count, locs.tolist(), scores.tolist()
-class HerdNetLMDS(LMDS):
+class HerdNetLMDS(LMDS):
def __init__(
- self,
- up: bool = True,
- kernel_size: tuple = (3,3),
- adapt_ts: float = 0.3,
- neg_ts: float = 0.1
- ) -> None:
- '''
+ self,
+ up: bool = True,
+ kernel_size: tuple = (3, 3),
+ adapt_ts: float = 0.3,
+ neg_ts: float = 0.1,
+ ) -> None:
+ """
Args:
up (bool, optional): set to False to disable class maps upsampling.
Defaults to True.
@@ -155,14 +149,14 @@ def __init__(
maxima. Defaults to (3,3) (as in the paper).
adapt_ts (float, optional): adaptive threshold to select final points
from candidates. Defaults to 0.3.
- neg_ts (float, optional): negative sample threshold used to define if
+ neg_ts (float, optional): negative sample threshold used to define if
an image is a negative sample or not. Defaults to 0.1 (as in the paper).
- '''
+ """
super().__init__(kernel_size=kernel_size, adapt_ts=adapt_ts, neg_ts=neg_ts)
self.up = up
-
+
def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, list]:
"""
Args:
@@ -176,14 +170,14 @@ def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list,
"""
heatmap, clsmap = outputs
-
+
# upsample class map
if self.up:
scale_factor = 16
- clsmap = F.interpolate(clsmap, scale_factor=scale_factor, mode='nearest')
+ clsmap = F.interpolate(clsmap, scale_factor=scale_factor, mode="nearest")
# softmax
- cls_scores = torch.softmax(clsmap, dim=1)[:,1:,:,:]
+ cls_scores = torch.softmax(clsmap, dim=1)[:, 1:, :, :]
# cat to heatmap
outmaps = torch.cat([heatmap, cls_scores], dim=1)
@@ -193,10 +187,9 @@ def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list,
b_counts, b_labels, b_scores, b_locs, b_dscores = [], [], [], [], []
for b in range(batch_size):
-
_, locs, _ = self._lmds(heatmap[b][0])
- cls_idx = torch.argmax(clsmap[b,1:,:,:], dim=0)
+ cls_idx = torch.argmax(clsmap[b, 1:, :, :], dim=0)
classes = torch.add(cls_idx, 1)
h_idx = torch.Tensor([l[0] for l in locs]).long()
@@ -216,4 +209,4 @@ def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list,
b_counts.append(counts)
b_dscores.append(dscores)
- return b_counts, b_locs, b_labels, b_scores, b_dscores
\ No newline at end of file
+ return b_counts, b_locs, b_labels, b_scores, b_dscores
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py b/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py
index f22ddcd17..bda49b8c1 100644
--- a/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py
+++ b/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
@@ -14,20 +13,20 @@
__version__ = "0.2.1"
-import torch
-import torchvision
+from typing import List, Tuple
-import torch.nn.functional as F
import numpy as np
-
-from typing import List, Tuple
-from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from ..data import ImageToPatches
+
class Stitcher(ImageToPatches):
- ''' Class to stitch detections of patches into original image
- coordinates system
+ """Class to stitch detections of patches into original image
+ coordinates system
This algorithm works as follow:
1) Cut original image into patches
@@ -35,43 +34,44 @@ class Stitcher(ImageToPatches):
3) Patch the detections maps into the coordinate system of the original image
Optional:
4) Upsample the patched detection map
- '''
+ """
def __init__(
self,
- model: torch.nn.Module,
- size: Tuple[int,int],
+ model: torch.nn.Module,
+ size: Tuple[int, int],
overlap: int = 100,
batch_size: int = 1,
down_ratio: int = 1,
up: bool = False,
- reduction: str = 'sum',
- device_name: str = 'cuda',
- ) -> None:
- '''
+ reduction: str = "sum",
+ device_name: str = "cuda",
+ ) -> None:
+ """
Args:
model (torch.nn.Module): CNN detection model, that takes as inputs image and returns
output and dict (i.e. wrapped by LossWrapper)
size (tuple): patches size (height, width), in pixels
- overlap (int, optional): overlap between patches, in pixels.
- Defaults to 100.
- batch_size (int, optional): batch size used for inference over patches.
+ overlap (int, optional): overlap between patches, in pixels.
+ Defaults to 100.
+ batch_size (int, optional): batch size used for inference over patches.
Defaults to 1.
- down_ratio (int, optional): downsample ratio. Set to 1 to get output of the same
+ down_ratio (int, optional): downsample ratio. Set to 1 to get output of the same
size as input (i.e. no downsample). Defaults to 1.
up (bool, optional): set to True to upsample the patched map. Defaults to False.
reduction (str, optional): specifies the reduction to apply on overlapping areas.
Possible values are 'sum', 'mean', 'max'. Defaults to 'sum'.
- device_name (str, optional): the device name on which tensors will be allocated
+ device_name (str, optional): the device name on which tensors will be allocated
('cpu' or 'cuda'). Defaults to 'cuda'.
- '''
+ """
- assert isinstance(model, torch.nn.Module), \
- 'model argument must be an instance of nn.Module()'
-
- assert reduction in ['sum', 'mean', 'max'], \
- 'reduction argument possible values are \'sum\', \'mean\' and \'max\' ' \
- f'got \'{reduction}\''
+ assert isinstance(
+ model, torch.nn.Module
+ ), "model argument must be an instance of nn.Module()"
+
+ assert reduction in ["sum", "mean", "max"], (
+ "reduction argument possible values are 'sum', 'mean' and 'max' " f"got '{reduction}'"
+ )
self.model = model
self.size = size
@@ -84,23 +84,20 @@ def __init__(
self.model.to(self.device)
- def __call__(
- self,
- image: torch.Tensor
- ) -> torch.Tensor:
- ''' Apply the stitching algorithm to the image
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
+ """Apply the stitching algorithm to the image
Args:
image (torch.Tensor): image of shape [C,H,W]
-
+
Returns:
torch.Tensor
the detections into the coordinate system of the original image
- '''
+ """
super(Stitcher, self).__init__(image, self.size, self.overlap)
-
- self.image = image.to(torch.device('cpu'))
+
+ self.image = image.to(torch.device("cpu"))
# step 1 - get patches and limits
patches = self.make_patches()
@@ -114,23 +111,20 @@ def __call__(
# (step 4 - upsample)
if self.up:
- patched_map = F.interpolate(patched_map, scale_factor=self.down_ratio,
- mode='bilinear', align_corners=True)
+ patched_map = F.interpolate(
+ patched_map, scale_factor=self.down_ratio, mode="bilinear", align_corners=True
+ )
return patched_map
-
@torch.no_grad()
def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]:
-
self.model.eval()
dataset = TensorDataset(patches)
dataloader = DataLoader(
- dataset,
- batch_size=self.batch_size,
- sampler=SequentialSampler(dataset)
- )
+ dataset, batch_size=self.batch_size, sampler=SequentialSampler(dataset)
+ )
maps = []
for patch in dataloader:
@@ -141,89 +135,84 @@ def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]:
return maps
def _patch_maps(self, maps: List[torch.Tensor]) -> torch.Tensor:
-
_, h, w = self.image.shape
dh, dw = h // self.down_ratio, w // self.down_ratio
kernel_size = np.array(self.size) // self.down_ratio
stride = kernel_size - self.overlap // self.down_ratio
output_size = (
- self._ncol * kernel_size[0] - ((self._ncol-1) * self.overlap // self.down_ratio),
- self._nrow * kernel_size[1] - ((self._nrow-1) * self.overlap // self.down_ratio)
- )
+ self._ncol * kernel_size[0] - ((self._ncol - 1) * self.overlap // self.down_ratio),
+ self._nrow * kernel_size[1] - ((self._nrow - 1) * self.overlap // self.down_ratio),
+ )
maps = torch.cat(maps, dim=0)
- if self.reduction == 'max':
- out_map = self._max_fold(maps, output_size=output_size,
- kernel_size=tuple(kernel_size), stride=tuple(stride))
+ if self.reduction == "max":
+ out_map = self._max_fold(
+ maps, output_size=output_size, kernel_size=tuple(kernel_size), stride=tuple(stride)
+ )
else:
n_patches = maps.shape[0]
- maps = maps.permute(1,2,3,0).contiguous().view(1, -1, n_patches)
- out_map = F.fold(maps, output_size=output_size,
- kernel_size=tuple(kernel_size), stride=tuple(stride))
+ maps = maps.permute(1, 2, 3, 0).contiguous().view(1, -1, n_patches)
+ out_map = F.fold(
+ maps, output_size=output_size, kernel_size=tuple(kernel_size), stride=tuple(stride)
+ )
- out_map = out_map[:,:, 0:dh, 0:dw]
+ out_map = out_map[:, :, 0:dh, 0:dw]
return out_map
-
- def _reduce(self, map: torch.Tensor) -> torch.Tensor:
+ def _reduce(self, map: torch.Tensor) -> torch.Tensor:
dh = self.image.shape[1] // self.down_ratio
dw = self.image.shape[2] // self.down_ratio
- ones = torch.ones(self.image.shape[0],dh,dw)
+ ones = torch.ones(self.image.shape[0], dh, dw)
- if self.reduction == 'mean':
- ones_patches = ImageToPatches(ones,
- np.array(self.size)//self.down_ratio,
- self.overlap//self.down_ratio
- ).make_patches()
+ if self.reduction == "mean":
+ ones_patches = ImageToPatches(
+ ones, np.array(self.size) // self.down_ratio, self.overlap // self.down_ratio
+ ).make_patches()
- ones_patches = [p.unsqueeze(0).unsqueeze(0) for p in ones_patches[:,1,:,:]]
+ ones_patches = [p.unsqueeze(0).unsqueeze(0) for p in ones_patches[:, 1, :, :]]
norm_map = self._patch_maps(ones_patches)
-
+
else:
- norm_map = ones[1,:,:]
-
+ norm_map = ones[1, :, :]
+
return torch.div(map.to(self.device), norm_map.to(self.device))
-
- def _max_fold(self, maps: torch.Tensor, output_size: tuple,
- kernel_size: tuple, stride: tuple
- ) -> torch.Tensor:
-
+
+ def _max_fold(
+ self, maps: torch.Tensor, output_size: tuple, kernel_size: tuple, stride: tuple
+ ) -> torch.Tensor:
output = torch.zeros((1, maps.shape[1], *output_size))
- fn = lambda x: [[i, i+kernel_size[x]] for i in range(0, output_size[x], stride[x])][:-1]
+ fn = lambda x: [[i, i + kernel_size[x]] for i in range(0, output_size[x], stride[x])][:-1]
locs = [[*h, *w] for h in fn(0) for w in fn(1)]
for loc, m in zip(locs, maps):
patch = torch.zeros(output.shape)
- patch[:,:, loc[0]:loc[1], loc[2]:loc[3]] = m
+ patch[:, :, loc[0] : loc[1], loc[2] : loc[3]] = m
output = torch.max(output, patch)
return output
-class HerdNetStitcher(Stitcher):
+class HerdNetStitcher(Stitcher):
@torch.no_grad()
def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]:
-
self.model.eval()
dataset = TensorDataset(patches)
dataloader = DataLoader(
- dataset,
- batch_size=self.batch_size,
- sampler=SequentialSampler(dataset)
- )
+ dataset, batch_size=self.batch_size, sampler=SequentialSampler(dataset)
+ )
maps = []
for patch in dataloader:
patch = patch[0].to(self.device)
- #outputs = self.model(patch)[0]
- outputs = self.model(patch) # LossWrapper is not used
+ # outputs = self.model(patch)[0]
+ outputs = self.model(patch) # LossWrapper is not used
heatmap = outputs[0]
scale_factor = 16
- clsmap = F.interpolate(outputs[1], scale_factor=scale_factor, mode='nearest')
+ clsmap = F.interpolate(outputs[1], scale_factor=scale_factor, mode="nearest")
# cat
outmaps = torch.cat([heatmap, clsmap], dim=1)
maps = [*maps, *outmaps.unsqueeze(0)]
diff --git a/PytorchWildlife/models/detection/herdnet/dla.py b/PytorchWildlife/models/detection/herdnet/dla.py
index dd87cccff..8b08559ce 100644
--- a/PytorchWildlife/models/detection/herdnet/dla.py
+++ b/PytorchWildlife/models/detection/herdnet/dla.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
MIT License
Copyright (c) 2019 Xingyi Zhou
@@ -23,35 +22,40 @@
from os.path import join
from posixpath import basename
+import numpy as np
import torch
-from torch import nn
import torch.utils.model_zoo as model_zoo
-
-import numpy as np
+from torch import nn
BatchNorm = nn.BatchNorm2d
-def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'):
- return join('http://dl.yf.io/dla/models', data, '{}-{}.pth'.format(name, hash))
+
+def get_model_url(data="imagenet", name="dla34", hash="ba72cf86"):
+ return join("http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash))
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, dilation=1):
super(BasicBlock, self).__init__()
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
- stride=stride, padding=dilation,
- bias=False, dilation=dilation)
+ self.conv1 = nn.Conv2d(
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
self.bn1 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
- stride=1, padding=dilation,
- bias=False, dilation=dilation)
+ self.conv2 = nn.Conv2d(
+ planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation
+ )
self.bn2 = BatchNorm(planes)
self.stride = stride
@@ -79,15 +83,19 @@ def __init__(self, inplanes, planes, stride=1, dilation=1):
super(Bottleneck, self).__init__()
expansion = Bottleneck.expansion
bottle_planes = planes // expansion
- self.conv1 = nn.Conv2d(inplanes, bottle_planes,
- kernel_size=1, bias=False)
+ self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm(bottle_planes)
- self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3,
- stride=stride, padding=dilation,
- bias=False, dilation=dilation)
+ self.conv2 = nn.Conv2d(
+ bottle_planes,
+ bottle_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
self.bn2 = BatchNorm(bottle_planes)
- self.conv3 = nn.Conv2d(bottle_planes, planes,
- kernel_size=1, bias=False)
+ self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
self.bn3 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
@@ -121,15 +129,20 @@ def __init__(self, inplanes, planes, stride=1, dilation=1):
super(BottleneckX, self).__init__()
cardinality = BottleneckX.cardinality
bottle_planes = planes * cardinality // 32
- self.conv1 = nn.Conv2d(inplanes, bottle_planes,
- kernel_size=1, bias=False)
+ self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm(bottle_planes)
- self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3,
- stride=stride, padding=dilation, bias=False,
- dilation=dilation, groups=cardinality)
+ self.conv2 = nn.Conv2d(
+ bottle_planes,
+ bottle_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ groups=cardinality,
+ )
self.bn2 = BatchNorm(bottle_planes)
- self.conv3 = nn.Conv2d(bottle_planes, planes,
- kernel_size=1, bias=False)
+ self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
self.bn3 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
@@ -159,8 +172,8 @@ class Root(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, residual):
super(Root, self).__init__()
self.conv = nn.Conv2d(
- in_channels, out_channels, 1,
- stride=1, bias=False, padding=(kernel_size - 1) // 2)
+ in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2
+ )
self.bn = BatchNorm(out_channels)
self.relu = nn.ReLU(inplace=True)
self.residual = residual
@@ -177,31 +190,51 @@ def forward(self, *x):
class Tree(nn.Module):
- def __init__(self, levels, block, in_channels, out_channels, stride=1,
- level_root=False, root_dim=0, root_kernel_size=1,
- dilation=1, root_residual=False):
+ def __init__(
+ self,
+ levels,
+ block,
+ in_channels,
+ out_channels,
+ stride=1,
+ level_root=False,
+ root_dim=0,
+ root_kernel_size=1,
+ dilation=1,
+ root_residual=False,
+ ):
super(Tree, self).__init__()
if root_dim == 0:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
if levels == 1:
- self.tree1 = block(in_channels, out_channels, stride,
- dilation=dilation)
- self.tree2 = block(out_channels, out_channels, 1,
- dilation=dilation)
+ self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
+ self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
else:
- self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
- stride, root_dim=0,
- root_kernel_size=root_kernel_size,
- dilation=dilation, root_residual=root_residual)
- self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
- root_dim=root_dim + out_channels,
- root_kernel_size=root_kernel_size,
- dilation=dilation, root_residual=root_residual)
+ self.tree1 = Tree(
+ levels - 1,
+ block,
+ in_channels,
+ out_channels,
+ stride,
+ root_dim=0,
+ root_kernel_size=root_kernel_size,
+ dilation=dilation,
+ root_residual=root_residual,
+ )
+ self.tree2 = Tree(
+ levels - 1,
+ block,
+ out_channels,
+ out_channels,
+ root_dim=root_dim + out_channels,
+ root_kernel_size=root_kernel_size,
+ dilation=dilation,
+ root_residual=root_residual,
+ )
if levels == 1:
- self.root = Root(root_dim, out_channels, root_kernel_size,
- root_residual)
+ self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
@@ -211,9 +244,8 @@ def __init__(self, levels, block, in_channels, out_channels, stride=1,
self.downsample = nn.MaxPool2d(stride, stride=stride)
if in_channels != out_channels:
self.project = nn.Sequential(
- nn.Conv2d(in_channels, out_channels,
- kernel_size=1, stride=1, bias=False),
- BatchNorm(out_channels)
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+ BatchNorm(out_channels),
)
def forward(self, x, residual=None, children=None):
@@ -233,40 +265,74 @@ def forward(self, x, residual=None, children=None):
class DLA(nn.Module):
- def __init__(self, levels, channels, num_classes=1000,
- block=BasicBlock, residual_root=False, return_levels=False,
- pool_size=7, linear_root=False):
+ def __init__(
+ self,
+ levels,
+ channels,
+ num_classes=1000,
+ block=BasicBlock,
+ residual_root=False,
+ return_levels=False,
+ pool_size=7,
+ linear_root=False,
+ ):
super(DLA, self).__init__()
self.channels = channels
self.return_levels = return_levels
self.num_classes = num_classes
self.base_layer = nn.Sequential(
- nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
- padding=3, bias=False),
+ nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
BatchNorm(channels[0]),
- nn.ReLU(inplace=True))
- self.level0 = self._make_conv_level(
- channels[0], channels[0], levels[0])
- self.level1 = self._make_conv_level(
- channels[0], channels[1], levels[1], stride=2)
- self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
- level_root=False,
- root_residual=residual_root)
- self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
- level_root=True, root_residual=residual_root)
- self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
- level_root=True, root_residual=residual_root)
- self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
- level_root=True, root_residual=residual_root)
+ nn.ReLU(inplace=True),
+ )
+ self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
+ self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
+ self.level2 = Tree(
+ levels[2],
+ block,
+ channels[1],
+ channels[2],
+ 2,
+ level_root=False,
+ root_residual=residual_root,
+ )
+ self.level3 = Tree(
+ levels[3],
+ block,
+ channels[2],
+ channels[3],
+ 2,
+ level_root=True,
+ root_residual=residual_root,
+ )
+ self.level4 = Tree(
+ levels[4],
+ block,
+ channels[3],
+ channels[4],
+ 2,
+ level_root=True,
+ root_residual=residual_root,
+ )
+ self.level5 = Tree(
+ levels[5],
+ block,
+ channels[4],
+ channels[5],
+ 2,
+ level_root=True,
+ root_residual=residual_root,
+ )
self.avgpool = nn.AvgPool2d(pool_size)
- self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1,
- stride=1, padding=0, bias=True)
+ self.fc = nn.Conv2d(
+ channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True
+ )
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- m.weight.data.normal_(0, math.sqrt(2. / n))
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif isinstance(m, BatchNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
@@ -276,8 +342,7 @@ def _make_level(self, block, inplanes, planes, blocks, stride=1):
if stride != 1 or inplanes != planes:
downsample = nn.Sequential(
nn.MaxPool2d(stride, stride=stride),
- nn.Conv2d(inplanes, planes,
- kernel_size=1, stride=1, bias=False),
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
BatchNorm(planes),
)
@@ -291,12 +356,21 @@ def _make_level(self, block, inplanes, planes, blocks, stride=1):
def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
modules = []
for i in range(convs):
- modules.extend([
- nn.Conv2d(inplanes, planes, kernel_size=3,
- stride=stride if i == 0 else 1,
- padding=dilation, bias=False, dilation=dilation),
- BatchNorm(planes),
- nn.ReLU(inplace=True)])
+ modules.extend(
+ [
+ nn.Conv2d(
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ ),
+ BatchNorm(planes),
+ nn.ReLU(inplace=True),
+ ]
+ )
inplanes = planes
return nn.Sequential(*modules)
@@ -304,7 +378,7 @@ def forward(self, x):
y = []
x = self.base_layer(x)
for i in range(6):
- x = getattr(self, 'level{}'.format(i))(x)
+ x = getattr(self, "level{}".format(i))(x)
y.append(x)
if self.return_levels:
return y
@@ -315,113 +389,121 @@ def forward(self, x):
return x
- def load_pretrained_model(self, data='imagenet', name='dla34', hash='ba72cf86'):
+ def load_pretrained_model(self, data="imagenet", name="dla34", hash="ba72cf86"):
fc = self.fc
- if name.endswith('.pth'):
+ if name.endswith(".pth"):
model_weights = torch.load(data + name)
else:
model_url = get_model_url(data, name, hash)
model_weights = model_zoo.load_url(model_url)
num_classes = len(model_weights[list(model_weights.keys())[-1]])
self.fc = nn.Conv2d(
- self.channels[-1], num_classes,
- kernel_size=1, stride=1, padding=0, bias=True)
+ self.channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True
+ )
self.load_state_dict(model_weights)
self.fc = fc
def dla34(pretrained, **kwargs): # DLA-34
- model = DLA([1, 1, 1, 2, 2, 1],
- [16, 32, 64, 128, 256, 512],
- block=BasicBlock, **kwargs)
+ model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs)
if pretrained:
- model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86')
+ model.load_pretrained_model(data="imagenet", name="dla34", hash="ba72cf86")
return model
def dla46_c(pretrained=None, **kwargs): # DLA-46-C
Bottleneck.expansion = 2
- model = DLA([1, 1, 1, 2, 2, 1],
- [16, 32, 64, 64, 128, 256],
- block=Bottleneck, **kwargs)
+ model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs)
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla46_c')
+ model.load_pretrained_model(pretrained, "dla46_c")
return model
def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C
BottleneckX.expansion = 2
- model = DLA([1, 1, 1, 2, 2, 1],
- [16, 32, 64, 64, 128, 256],
- block=BottleneckX, **kwargs)
+ model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs)
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla46x_c')
+ model.load_pretrained_model(pretrained, "dla46x_c")
return model
def dla60x_c(pretrained, **kwargs): # DLA-X-60-C
BottleneckX.expansion = 2
- model = DLA([1, 1, 1, 2, 3, 1],
- [16, 32, 64, 64, 128, 256],
- block=BottleneckX, **kwargs)
+ model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs)
if pretrained:
- model.load_pretrained_model(data='imagenet', name='dla60x_c', hash='b870c45c')
+ model.load_pretrained_model(data="imagenet", name="dla60x_c", hash="b870c45c")
return model
def dla60(pretrained=None, **kwargs): # DLA-60
Bottleneck.expansion = 2
- model = DLA([1, 1, 1, 2, 3, 1],
- [16, 32, 128, 256, 512, 1024],
- block=Bottleneck, **kwargs)
+ model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=Bottleneck, **kwargs)
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla60')
+ model.load_pretrained_model(pretrained, "dla60")
return model
def dla60x(pretrained=None, **kwargs): # DLA-X-60
BottleneckX.expansion = 2
- model = DLA([1, 1, 1, 2, 3, 1],
- [16, 32, 128, 256, 512, 1024],
- block=BottleneckX, **kwargs)
+ model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs)
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla60x')
+ model.load_pretrained_model(pretrained, "dla60x")
return model
def dla102(pretrained=None, **kwargs): # DLA-102
Bottleneck.expansion = 2
- model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
- block=Bottleneck, residual_root=True, **kwargs)
+ model = DLA(
+ [1, 1, 1, 3, 4, 1],
+ [16, 32, 128, 256, 512, 1024],
+ block=Bottleneck,
+ residual_root=True,
+ **kwargs
+ )
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla102')
+ model.load_pretrained_model(pretrained, "dla102")
return model
def dla102x(pretrained=None, **kwargs): # DLA-X-102
BottleneckX.expansion = 2
- model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
- block=BottleneckX, residual_root=True, **kwargs)
+ model = DLA(
+ [1, 1, 1, 3, 4, 1],
+ [16, 32, 128, 256, 512, 1024],
+ block=BottleneckX,
+ residual_root=True,
+ **kwargs
+ )
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla102x')
+ model.load_pretrained_model(pretrained, "dla102x")
return model
def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64
BottleneckX.cardinality = 64
- model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
- block=BottleneckX, residual_root=True, **kwargs)
+ model = DLA(
+ [1, 1, 1, 3, 4, 1],
+ [16, 32, 128, 256, 512, 1024],
+ block=BottleneckX,
+ residual_root=True,
+ **kwargs
+ )
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla102x2')
+ model.load_pretrained_model(pretrained, "dla102x2")
return model
def dla169(pretrained=None, **kwargs): # DLA-169
Bottleneck.expansion = 2
- model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
- block=Bottleneck, residual_root=True, **kwargs)
+ model = DLA(
+ [1, 1, 2, 3, 5, 1],
+ [16, 32, 128, 256, 512, 1024],
+ block=Bottleneck,
+ residual_root=True,
+ **kwargs
+ )
if pretrained is not None:
- model.load_pretrained_model(pretrained, 'dla169')
+ model.load_pretrained_model(pretrained, "dla169")
return model
@@ -436,11 +518,10 @@ def forward(self, x):
def fill_up_weights(up):
w = up.weight.data
f = math.ceil(w.size(2) / 2)
- c = (2 * f - 1 - f % 2) / (2. * f)
+ c = (2 * f - 1 - f % 2) / (2.0 * f)
for i in range(w.size(2)):
for j in range(w.size(3)):
- w[0, 0, i, j] = \
- (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
+ w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, w.size(0)):
w[c, 0, :, :] = w[0, 0, :, :]
@@ -455,50 +536,64 @@ def __init__(self, node_kernel, out_dim, channels, up_factors):
proj = Identity()
else:
proj = nn.Sequential(
- nn.Conv2d(c, out_dim,
- kernel_size=1, stride=1, bias=False),
+ nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
BatchNorm(out_dim),
- nn.ReLU(inplace=True))
+ nn.ReLU(inplace=True),
+ )
f = int(up_factors[i])
if f == 1:
up = Identity()
else:
up = nn.ConvTranspose2d(
- out_dim, out_dim, f * 2, stride=f, padding=f // 2,
- output_padding=0, groups=out_dim, bias=False)
+ out_dim,
+ out_dim,
+ f * 2,
+ stride=f,
+ padding=f // 2,
+ output_padding=0,
+ groups=out_dim,
+ bias=False,
+ )
fill_up_weights(up)
- setattr(self, 'proj_' + str(i), proj)
- setattr(self, 'up_' + str(i), up)
+ setattr(self, "proj_" + str(i), proj)
+ setattr(self, "up_" + str(i), up)
for i in range(1, len(channels)):
node = nn.Sequential(
- nn.Conv2d(out_dim * 2, out_dim,
- kernel_size=node_kernel, stride=1,
- padding=node_kernel // 2, bias=False),
+ nn.Conv2d(
+ out_dim * 2,
+ out_dim,
+ kernel_size=node_kernel,
+ stride=1,
+ padding=node_kernel // 2,
+ bias=False,
+ ),
BatchNorm(out_dim),
- nn.ReLU(inplace=True))
- setattr(self, 'node_' + str(i), node)
+ nn.ReLU(inplace=True),
+ )
+ setattr(self, "node_" + str(i), node)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- m.weight.data.normal_(0, math.sqrt(2. / n))
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif isinstance(m, BatchNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, layers):
- assert len(self.channels) == len(layers), \
- '{} vs {} layers'.format(len(self.channels), len(layers))
+ assert len(self.channels) == len(layers), "{} vs {} layers".format(
+ len(self.channels), len(layers)
+ )
layers = list(layers)
for i, l in enumerate(layers):
- upsample = getattr(self, 'up_' + str(i))
- project = getattr(self, 'proj_' + str(i))
+ upsample = getattr(self, "up_" + str(i))
+ project = getattr(self, "proj_" + str(i))
layers[i] = upsample(project(l))
x = layers[0]
y = []
for i in range(1, len(layers)):
- node = getattr(self, 'node_' + str(i))
+ node = getattr(self, "node_" + str(i))
x = node(torch.cat([x, layers[i]], 1))
y.append(x)
return x, y
@@ -514,21 +609,24 @@ def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
scales = np.array(scales, dtype=int)
for i in range(len(channels) - 1):
j = -i - 2
- setattr(self, 'ida_{}'.format(i),
- IDAUp(3, channels[j], in_channels[j:],
- scales[j:] // scales[j]))
- scales[j + 1:] = scales[j]
- in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]
+ setattr(
+ self,
+ "ida_{}".format(i),
+ IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]),
+ )
+ scales[j + 1 :] = scales[j]
+ in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]]
def forward(self, layers):
layers = list(layers)
assert len(layers) > 1
for i in range(len(layers) - 1):
- ida = getattr(self, 'ida_{}'.format(i))
- x, y = ida(layers[-i - 2:])
- layers[-i - 1:] = y
+ ida = getattr(self, "ida_{}".format(i))
+ x, y = ida(layers[-i - 2 :])
+ layers[-i - 1 :] = y
return x
+
def fill_fc_weights(layers):
for m in layers.modules():
if isinstance(m, nn.Conv2d):
@@ -536,37 +634,41 @@ def fill_fc_weights(layers):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
+
class DLASeg(nn.Module):
- def __init__(self, base_name, heads,
- pretrained=True, down_ratio=4, head_conv=256):
+ def __init__(self, base_name, heads, pretrained=True, down_ratio=4, head_conv=256):
super(DLASeg, self).__init__()
self.heads = heads
self.first_level = int(np.log2(down_ratio))
- self.base = globals()[base_name](
- pretrained=pretrained, return_levels=True)
+ self.base = globals()[base_name](pretrained=pretrained, return_levels=True)
channels = self.base.channels
- scales = [2 ** i for i in range(len(channels[self.first_level:]))]
- self.dla_up = DLAUp(channels[self.first_level:], scales=scales)
+ scales = [2**i for i in range(len(channels[self.first_level :]))]
+ self.dla_up = DLAUp(channels[self.first_level :], scales=scales)
for head in self.heads:
classes = self.heads[head]
if head_conv > 0:
fc = nn.Sequential(
- nn.Conv2d(channels[self.first_level], head_conv,
- kernel_size=3, padding=1, bias=True),
- nn.ReLU(inplace=True),
- nn.Conv2d(head_conv, classes,
- kernel_size=1, stride=1,
- padding=0, bias=True))
- if 'hm' in head:
+ nn.Conv2d(
+ channels[self.first_level], head_conv, kernel_size=3, padding=1, bias=True
+ ),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True),
+ )
+ if "hm" in head:
fc[-1].bias.data.fill_(-2.19)
else:
fill_fc_weights(fc)
else:
- fc = nn.Conv2d(channels[self.first_level], classes,
- kernel_size=1, stride=1,
- padding=0, bias=True)
- if 'hm' in head:
+ fc = nn.Conv2d(
+ channels[self.first_level],
+ classes,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+ if "hm" in head:
fc.bias.data.fill_(-2.19)
else:
fill_fc_weights(fc)
@@ -574,7 +676,7 @@ def __init__(self, base_name, heads,
def forward(self, x):
x = self.base(x)
- x = self.dla_up(x[self.first_level:])
+ x = self.dla_up(x[self.first_level :])
# x = self.fc(x)
# y = self.softmax(self.up(x))
ret = {}
@@ -582,9 +684,13 @@ def forward(self, x):
ret[head] = self.__getattr__(head)(x)
return [ret]
+
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
- model = DLASeg('dla{}'.format(num_layers), heads,
- pretrained=True,
- down_ratio=down_ratio,
- head_conv=head_conv)
- return model
\ No newline at end of file
+ model = DLASeg(
+ "dla{}".format(num_layers),
+ heads,
+ pretrained=True,
+ down_ratio=down_ratio,
+ head_conv=head_conv,
+ )
+ return model
diff --git a/PytorchWildlife/models/detection/herdnet/herdnet.py b/PytorchWildlife/models/detection/herdnet/herdnet.py
index 4e2a41963..43edd8382 100644
--- a/PytorchWildlife/models/detection/herdnet/herdnet.py
+++ b/PytorchWildlife/models/detection/herdnet/herdnet.py
@@ -1,56 +1,65 @@
-from ..base_detector import BaseDetector
-from ..herdnet.animaloc.eval import HerdNetStitcher, HerdNetLMDS
-from ....data import datasets as pw_data
-from .model import HerdNet as HerdNetArch
+import os
+import cv2
+import numpy as np
+import supervision as sv
import torch
-from torch.hub import load_state_dict_from_url
-from torch.utils.data import DataLoader
import torchvision.transforms as transforms
-
-import numpy as np
+import wget
from PIL import Image
+from torch.hub import load_state_dict_from_url
+from torch.utils.data import DataLoader
from tqdm import tqdm
-import supervision as sv
-import os
-import wget
-import cv2
-
-class ResizeIfSmaller:
- def __init__(self, min_size, interpolation=Image.BILINEAR):
- self.min_size = min_size
- self.interpolation = interpolation
-
+
+from ....data import datasets as pw_data
+from ..base_detector import BaseDetector
+from ..herdnet.animaloc.eval import HerdNetLMDS, HerdNetStitcher
+from .model import HerdNet as HerdNetArch
+
+
+class ResizeIfSmaller:
+ def __init__(self, min_size, interpolation=Image.BILINEAR):
+ self.min_size = min_size
+ self.interpolation = interpolation
+
def __call__(self, img):
if isinstance(img, np.ndarray):
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
- assert isinstance(img, Image.Image), "Image should be a PIL Image"
+ assert isinstance(img, Image.Image), "Image should be a PIL Image"
width, height = img.size
- if height < self.min_size or width < self.min_size:
- ratio = max(self.min_size / height, self.min_size / width)
- new_height = int(height * ratio)
- new_width = int(width * ratio)
- img = img.resize((new_width, new_height), self.interpolation)
- return img
+ if height < self.min_size or width < self.min_size:
+ ratio = max(self.min_size / height, self.min_size / width)
+ new_height = int(height * ratio)
+ new_width = int(width * ratio)
+ img = img.resize((new_width, new_height), self.interpolation)
+ return img
+
class HerdNet(BaseDetector):
"""
HerdNet detector class. This class provides utility methods for
loading the model, generating results, and performing single and batch image detections.
"""
-
- def __init__(self, weights=None, device="cpu", version='general' ,url="https://zenodo.org/records/13899852/files/20220413_HerdNet_General_dataset_2022.pth?download=1", transform=None):
+
+ def __init__(
+ self,
+ weights=None,
+ device="cpu",
+ version="general",
+ url="https://zenodo.org/records/13899852/files/20220413_HerdNet_General_dataset_2022.pth?download=1",
+ transform=None,
+ ):
"""
Initialize the HerdNet detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
version (str, optional):
Version name based on what dataset the model is trained on. It should be either 'general' or 'ennedi'. Defaults to 'general'.
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
transform (torchvision.transforms.Compose, optional):
Image transformation for inference. Defaults to None.
@@ -58,43 +67,45 @@ def __init__(self, weights=None, device="cpu", version='general' ,url="https://z
super(HerdNet, self).__init__(weights=weights, device=device, url=url)
# Assert that the dataset is either 'general' or 'ennedi'
version = version.lower()
- assert version in ['general', 'ennedi'], "Dataset should be either 'general' or 'ennedi'"
- if version == 'ennedi':
+ assert version in ["general", "ennedi"], "Dataset should be either 'general' or 'ennedi'"
+ if version == "ennedi":
url = "https://zenodo.org/records/13914287/files/20220329_HerdNet_Ennedi_dataset_2023.pth?download=1"
self._load_model(weights, device, url)
- self.stitcher = HerdNetStitcher( # This module enables patch-based inference
- model = self.model,
- size = (512,512),
- overlap = 160,
- down_ratio = 2,
- up = True,
- reduction = 'mean',
- device_name = device
- )
-
- self.lmds_kwargs: dict = {'kernel_size': (3, 3), 'adapt_ts': 0.2, 'neg_ts': 0.1}
- self.lmds = HerdNetLMDS(up=False, **self.lmds_kwargs) # Local Maxima Detection Strategy
+ self.stitcher = HerdNetStitcher( # This module enables patch-based inference
+ model=self.model,
+ size=(512, 512),
+ overlap=160,
+ down_ratio=2,
+ up=True,
+ reduction="mean",
+ device_name=device,
+ )
+
+ self.lmds_kwargs: dict = {"kernel_size": (3, 3), "adapt_ts": 0.2, "neg_ts": 0.1}
+ self.lmds = HerdNetLMDS(up=False, **self.lmds_kwargs) # Local Maxima Detection Strategy
if not transform:
- self.transforms = transforms.Compose([
- ResizeIfSmaller(512),
- transforms.ToTensor(),
- transforms.Normalize(mean=self.img_mean, std=self.img_std)
- ])
+ self.transforms = transforms.Compose(
+ [
+ ResizeIfSmaller(512),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=self.img_mean, std=self.img_std),
+ ]
+ )
else:
self.transforms = transform
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load the HerdNet model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
@@ -102,7 +113,9 @@ def _load_model(self, weights=None, device="cpu", url=None):
if weights:
checkpoint = torch.load(weights, map_location=torch.device(device))
elif url:
- filename = url.split('/')[-1][:-11] # Splitting the URL to get the filename and removing the '?download=1' part
+ filename = url.split("/")[-1][
+ :-11
+ ] # Splitting the URL to get the filename and removing the '?download=1' part
if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", filename)):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
@@ -111,26 +124,30 @@ def _load_model(self, weights=None, device="cpu", url=None):
checkpoint = torch.load(weights, map_location=torch.device(device))
else:
raise Exception("Need weights for inference.")
-
+
# Load the class names and other metadata from the checkpoint
self.CLASS_NAMES = checkpoint["classes"]
self.num_classes = len(self.CLASS_NAMES) + 1
- self.img_mean = checkpoint['mean']
- self.img_std = checkpoint['std']
+ self.img_mean = checkpoint["mean"]
+ self.img_std = checkpoint["std"]
# Load the model architecture
self.model = HerdNetArch(num_classes=self.num_classes, pretrained=False)
# Load checkpoint into model
- state_dict = checkpoint['model_state_dict']
+ state_dict = checkpoint["model_state_dict"]
# Remove 'model.' prefix from the state_dict keys if the key starts with 'model.'
- new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
- # Load the new state_dict
+ new_state_dict = {
+ k.replace("model.", ""): v for k, v in state_dict.items() if k.startswith("model.")
+ }
+ # Load the new state_dict
self.model.load_state_dict(new_state_dict, strict=True)
print(f"Model loaded from {weights}")
- def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None) -> dict:
+ def results_generation(
+ self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None
+ ) -> dict:
"""
Generate results for detection based on model predictions.
@@ -151,52 +168,63 @@ def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id:
results = {"img": img}
results["detections"] = sv.Detections(
- xyxy=preds[:, :4],
- confidence=preds[:, 4],
- class_id=preds[:, 5].astype(int)
+ xyxy=preds[:, :4], confidence=preds[:, 4], class_id=preds[:, 5].astype(int)
)
results["labels"] = [
f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
- for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id)
+ for confidence, class_id in zip(
+ results["detections"].confidence, results["detections"].class_id
+ )
]
return results
-
- def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_conf_thres=0.2, id_strip=None) -> dict:
+
+ def single_image_detection(
+ self, img, img_path=None, det_conf_thres=0.2, clf_conf_thres=0.2, id_strip=None
+ ) -> dict:
"""
Perform detection on a single image.
Args:
- img (str or np.ndarray):
+ img (str or np.ndarray):
Image for inference.
- img_path (str, optional):
+ img_path (str, optional):
Path to the image. Defaults to None.
det_conf_thres (float, optional):
Confidence threshold for detections. Defaults to 0.2.
clf_conf_thres (float, optional):
Confidence threshold for classification. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
Returns:
dict: Detection results for the image.
"""
- if isinstance(img, str):
- img_path = img_path or img
- img = np.array(Image.open(img_path).convert("RGB"))
- if self.transforms:
+ if isinstance(img, str):
+ img_path = img_path or img
+ img = np.array(Image.open(img_path).convert("RGB"))
+ if self.transforms:
img_tensor = self.transforms(img)
- preds = self.stitcher(img_tensor)
- heatmap, clsmap = preds[:,:1,:,:], preds[:,1:,:,:]
+ preds = self.stitcher(img_tensor)
+ heatmap, clsmap = preds[:, :1, :, :], preds[:, 1:, :, :]
counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap))
- preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres)
+ preds_array = self.process_lmds_results(
+ counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres
+ )
if img_path:
results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip)
else:
results_dict = self.results_generation(preds_array, img=img)
return results_dict
- def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.2, clf_conf_thres: float = 0.2, batch_size: int = 1, id_strip: str = None) -> list[dict]:
+ def batch_image_detection(
+ self,
+ data_path: str,
+ det_conf_thres: float = 0.2,
+ clf_conf_thres: float = 0.2,
+ batch_size: int = 1,
+ id_strip: str = None,
+ ) -> list[dict]:
"""
Perform detection on a batch of images.
@@ -210,32 +238,51 @@ def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.2, clf
Returns:
list[dict]: List of detection results for all images.
"""
- dataset = pw_data.DetectionImageFolder(
- data_path,
- transform=self.transforms
- )
+ dataset = pw_data.DetectionImageFolder(data_path, transform=self.transforms)
# Creating a Dataloader for batching and parallel processing of the images
- loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
- pin_memory=True, num_workers=0, drop_last=False) # TODO: discuss. why is num_workers 0?
-
+ loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=0,
+ drop_last=False,
+ ) # TODO: discuss. why is num_workers 0?
+
results = []
with tqdm(total=len(loader)) as pbar:
for batch_index, (imgs, paths, sizes) in enumerate(loader):
imgs = imgs.to(self.device)
predictions = self.stitcher(imgs[0]).detach().cpu()
- heatmap, clsmap = predictions[:,:1,:,:], predictions[:,1:,:,:]
+ heatmap, clsmap = predictions[:, :1, :, :], predictions[:, 1:, :, :]
counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap))
- preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres)
- results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip)
+ preds_array = self.process_lmds_results(
+ counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres
+ )
+ results_dict = self.results_generation(
+ preds_array, img_id=paths[0], id_strip=id_strip
+ )
pbar.update(1)
sizes = sizes.numpy()
- normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] # TODO: Check if this is correct due to xy swapping
- results_dict['normalized_coords'] = normalized_coords
+ normalized_coords = [
+ [x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]]
+ for x1, y1, x2, y2 in preds_array[:, :4]
+ ] # TODO: Check if this is correct due to xy swapping
+ results_dict["normalized_coords"] = normalized_coords
results.append(results_dict)
return results
- def process_lmds_results(self, counts: list, locs: list, labels: list, scores: list, dscores: list, det_conf_thres: float = 0.2, clf_conf_thres: float = 0.2) -> np.ndarray:
+ def process_lmds_results(
+ self,
+ counts: list,
+ locs: list,
+ labels: list,
+ scores: list,
+ dscores: list,
+ det_conf_thres: float = 0.2,
+ clf_conf_thres: float = 0.2,
+ ) -> np.ndarray:
"""
Process the results from the Local Maxima Detection Strategy.
@@ -251,29 +298,31 @@ def process_lmds_results(self, counts: list, locs: list, labels: list, scores: l
Returns:
numpy.ndarray: Processed detection results.
"""
- # Flatten the lists since we know its a single image
- counts = counts[0]
- locs = locs[0]
- labels = labels[0]
+ # Flatten the lists since we know its a single image
+ counts = counts[0]
+ locs = locs[0]
+ labels = labels[0]
scores = scores[0]
- dscores = dscores[0]
-
- # Calculate the total number of detections
- total_detections = sum(counts)
-
- # Pre-allocate based on total possible detections
- preds_array = np.empty((total_detections, 6)) #xyxy, confidence, class_id format
+ dscores = dscores[0]
+
+ # Calculate the total number of detections
+ total_detections = sum(counts)
+
+ # Pre-allocate based on total possible detections
+ preds_array = np.empty((total_detections, 6)) # xyxy, confidence, class_id format
detection_idx = 0
- valid_detections_idx = 0 # Index for valid detections after applying the confidence threshold
- # Loop through each species
- for specie_idx in range(len(counts)):
- count = counts[specie_idx]
- if count == 0:
- continue
-
- # Get the detections for this species
+ valid_detections_idx = (
+ 0 # Index for valid detections after applying the confidence threshold
+ )
+ # Loop through each species
+ for specie_idx in range(len(counts)):
+ count = counts[specie_idx]
+ if count == 0:
+ continue
+
+ # Get the detections for this species
species_locs = np.array(locs[detection_idx : detection_idx + count])
- species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs
+ species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs
species_scores = np.array(scores[detection_idx : detection_idx + count])
species_dscores = np.array(dscores[detection_idx : detection_idx + count])
species_labels = np.array(labels[detection_idx : detection_idx + count])
@@ -281,30 +330,40 @@ def process_lmds_results(self, counts: list, locs: list, labels: list, scores: l
# Apply the confidence threshold
valid_detections_by_clf_score = species_scores > clf_conf_thres
valid_detections_by_det_score = species_dscores > det_conf_thres
- valid_detections = np.logical_and(valid_detections_by_clf_score, valid_detections_by_det_score)
+ valid_detections = np.logical_and(
+ valid_detections_by_clf_score, valid_detections_by_det_score
+ )
valid_detections_count = np.sum(valid_detections)
valid_detections_idx += valid_detections_count
# Fill the preds_array with the valid detections
if valid_detections_count > 0:
- preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, :2] = species_locs[valid_detections] - 1
- preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4] = species_locs[valid_detections] + 1
- preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 4] = species_scores[valid_detections]
- preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 5] = species_labels[valid_detections]
-
- detection_idx += count # Move to the next species
-
- preds_array = preds_array[:valid_detections_idx] # Remove the empty rows
-
+ preds_array[
+ valid_detections_idx - valid_detections_count : valid_detections_idx, :2
+ ] = (species_locs[valid_detections] - 1)
+ preds_array[
+ valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4
+ ] = (species_locs[valid_detections] + 1)
+ preds_array[
+ valid_detections_idx - valid_detections_count : valid_detections_idx, 4
+ ] = species_scores[valid_detections]
+ preds_array[
+ valid_detections_idx - valid_detections_count : valid_detections_idx, 5
+ ] = species_labels[valid_detections]
+
+ detection_idx += count # Move to the next species
+
+ preds_array = preds_array[:valid_detections_idx] # Remove the empty rows
+
return preds_array
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the model.
-
+
Args:
- input (torch.Tensor):
+ input (torch.Tensor):
Input tensor for the model.
-
+
Returns:
torch.Tensor: Model output.
"""
diff --git a/PytorchWildlife/models/detection/herdnet/model.py b/PytorchWildlife/models/detection/herdnet/model.py
index f26bb6ecd..fa31ce425 100644
--- a/PytorchWildlife/models/detection/herdnet/model.py
+++ b/PytorchWildlife/models/detection/herdnet/model.py
@@ -1,5 +1,4 @@
-__copyright__ = \
- """
+__copyright__ = """
Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.
@@ -14,47 +13,52 @@
__version__ = "0.2.1"
-import torch
+from typing import Optional
-import torch.nn as nn
import numpy as np
+import torch
+import torch.nn as nn
import torchvision.transforms as T
-from typing import Optional
-
from . import dla as dla_modules
+
class HerdNet(nn.Module):
- ''' HerdNet architecture '''
+ """HerdNet architecture"""
def __init__(
self,
num_layers: int = 34,
num_classes: int = 2,
- pretrained: bool = True,
- down_ratio: Optional[int] = 2,
- head_conv: int = 64
- ):
- '''
+ pretrained: bool = True,
+ down_ratio: Optional[int] = 2,
+ head_conv: int = 64,
+ ):
+ """
Args:
num_layers (int, optional): number of layers of DLA. Defaults to 34.
- num_classes (int, optional): number of output classes, background included.
+ num_classes (int, optional): number of output classes, background included.
Defaults to 2.
pretrained (bool, optional): set False to disable pretrained DLA encoder parameters
from ImageNet. Defaults to True.
- down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16.
+ down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16.
Set to 1 to get output of the same size as input (i.e. no downsample).
Defaults to 2.
- head_conv (int, optional): number of supplementary convolutional layers at the end
+ head_conv (int, optional): number of supplementary convolutional layers at the end
of decoder. Defaults to 64.
- '''
+ """
super(HerdNet, self).__init__()
- assert down_ratio in [1, 2, 4, 8, 16], \
- f'Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}'
-
- base_name = 'dla{}'.format(num_layers)
+ assert down_ratio in [
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ ], f"Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}"
+
+ base_name = "dla{}".format(num_layers)
self.down_ratio = down_ratio
self.num_classes = num_classes
@@ -64,58 +68,45 @@ def __init__(
# backbone
base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True)
- setattr(self, 'base_0', base)
- setattr(self, 'channels_0', base.channels)
+ setattr(self, "base_0", base)
+ setattr(self, "channels_0", base.channels)
channels = self.channels_0
- scales = [2 ** i for i in range(len(channels[self.first_level:]))]
- self.dla_up = dla_modules.DLAUp(channels[self.first_level:], scales=scales)
+ scales = [2**i for i in range(len(channels[self.first_level :]))]
+ self.dla_up = dla_modules.DLAUp(channels[self.first_level :], scales=scales)
# self.cls_dla_up = dla_modules.DLAUp(channels[-3:], scales=scales[:3])
# bottleneck conv
self.bottleneck_conv = nn.Conv2d(
- channels[-1], channels[-1],
- kernel_size=1, stride=1,
- padding=0, bias=True
+ channels[-1], channels[-1], kernel_size=1, stride=1, padding=0, bias=True
)
# localization head
self.loc_head = nn.Sequential(
- nn.Conv2d(channels[self.first_level], head_conv,
- kernel_size=3, padding=1, bias=True),
+ nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
- nn.Conv2d(
- head_conv, 1,
- kernel_size=1, stride=1,
- padding=0, bias=True
- ),
- nn.Sigmoid()
- )
+ nn.Conv2d(head_conv, 1, kernel_size=1, stride=1, padding=0, bias=True),
+ nn.Sigmoid(),
+ )
self.loc_head[-2].bias.data.fill_(0.00)
# classification head
self.cls_head = nn.Sequential(
- nn.Conv2d(channels[-1], head_conv,
- kernel_size=3, padding=1, bias=True),
+ nn.Conv2d(channels[-1], head_conv, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
- nn.Conv2d(
- head_conv, self.num_classes,
- kernel_size=1, stride=1,
- padding=0, bias=True
- )
- )
+ nn.Conv2d(head_conv, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True),
+ )
self.cls_head[-1].bias.data.fill_(0.00)
-
- def forward(self, input: torch.Tensor):
- encode = self.base_0(input)
+ def forward(self, input: torch.Tensor):
+ encode = self.base_0(input)
bottleneck = self.bottleneck_conv(encode[-1])
encode[-1] = bottleneck
- decode_hm = self.dla_up(encode[self.first_level:])
+ decode_hm = self.dla_up(encode[self.first_level :])
# decode_cls = self.cls_dla_up(encode[-3:])
heatmap = self.loc_head(decode_hm)
@@ -123,29 +114,27 @@ def forward(self, input: torch.Tensor):
# clsmap = self.cls_head(decode_cls)
return heatmap, clsmap
-
+
def freeze(self, layers: list) -> None:
- ''' Freeze all layers mentioned in the input list '''
+ """Freeze all layers mentioned in the input list"""
for layer in layers:
self._freeze_layer(layer)
-
+
def _freeze_layer(self, layer_name: str) -> None:
for param in getattr(self, layer_name).parameters():
param.requires_grad = False
-
+
def reshape_classes(self, num_classes: int) -> None:
- ''' Reshape architecture according to a new number of classes.
+ """Reshape architecture according to a new number of classes.
Arg:
num_classes (int): new number of classes
- '''
-
+ """
+
self.cls_head[-1] = nn.Conv2d(
- self.head_conv, num_classes,
- kernel_size=1, stride=1,
- padding=0, bias=True
- )
+ self.head_conv, num_classes, kernel_size=1, stride=1, padding=0, bias=True
+ )
self.cls_head[-1].bias.data.fill_(0.00)
- self.num_classes = num_classes
\ No newline at end of file
+ self.num_classes = num_classes
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/__init__.py
index 77db993c5..11866034d 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/__init__.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/__init__.py
@@ -1,2 +1,2 @@
+from .megadetectorv6_apache import *
from .rtdetr_apache_base import *
-from .megadetectorv6_apache import *
\ No newline at end of file
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py b/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py
index 2e2cf10dd..2741165c5 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py
@@ -1,29 +1,23 @@
-
from .rtdetr_apache_base import RTDETRApacheBase
-__all__ = [
- 'MegaDetectorV6Apache'
-]
+__all__ = ["MegaDetectorV6Apache"]
+
class MegaDetectorV6Apache(RTDETRApacheBase):
"""
- MegaDetectorV6 is a specialized class derived from the RTDETRApacheBase class
+ MegaDetectorV6 is a specialized class derived from the RTDETRApacheBase class
that is specifically designed for detecting animals, persons, and vehicles.
-
+
Attributes:
CLASS_NAMES (dict): Mapping of class IDs to their respective names.
"""
-
- CLASS_NAMES = {
- 0: "animal",
- 1: "person",
- 2: "vehicle"
- }
-
- def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-rtdetr-x-apache'):
+
+ CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
+
+ def __init__(self, weights=None, device="cpu", pretrained=True, version="MDV6-rtdetr-x-apache"):
"""
Initializes the MegaDetectorV6 model with the option to load pretrained weights.
-
+
Args:
weights (str, optional): Path to the weights file.
device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu".
@@ -39,6 +33,6 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-rt
url = "https://zenodo.org/records/15398270/files/MDV6-apa-rtdetr-e.pth?download=1"
self.MODEL_NAME = "MDV6-apa-rtdetr-e.pth"
else:
- raise ValueError('Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e')
+ raise ValueError("Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e")
- super(MegaDetectorV6Apache, self).__init__(weights=weights, device=device, url=url)
\ No newline at end of file
+ super(MegaDetectorV6Apache, self).__init__(weights=weights, device=device, url=url)
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py
index 4e5004e38..14d44a082 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py
@@ -5,59 +5,66 @@
# Importing basic libraries
import os
+import sys
+from pathlib import Path
+
import supervision as sv
-import wget
import torch
-import torch.nn as nn
+import torch.nn as nn
import torchvision.transforms as T
+import wget
+from PIL import Image
-from ..base_detector import BaseDetector
from ....data import datasets as pw_data
-from PIL import Image
+from ..base_detector import BaseDetector
-import sys
-from pathlib import Path
project_root = Path(__file__).resolve().parent
sys.path.append(str(project_root))
from rtdetrv2_pytorch.src.core import YAMLConfig
+
class RTDETRApacheBase(BaseDetector):
"""
Base detector class for RTDETRApacheBase framework. This class provides utility methods for
loading the model, generating results, and performing single and batch image detections.
"""
+
def __init__(self, weights=None, device="cpu", url=None):
"""
Initialize the RT-DETR apache detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
"""
- self.transform = T.Compose([
- T.Resize((640, 640)),
- T.ToTensor(),
- ])
+ self.transform = T.Compose(
+ [
+ T.Resize((640, 640)),
+ T.ToTensor(),
+ ]
+ )
self.weights = weights
self.device = device
self.url = url
- super(RTDETRApacheBase, self).__init__(weights=self.weights, device=self.device, url=self.url)
+ super(RTDETRApacheBase, self).__init__(
+ weights=self.weights, device=self.device, url=self.url
+ )
self._load_model(weights=self.weights, device=self.device, url=self.url)
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load the RT-DETR apache model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
@@ -65,7 +72,9 @@ def _load_model(self, weights=None, device="cpu", url=None):
if weights:
resume = weights
elif url:
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
resume = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
@@ -74,45 +83,59 @@ def _load_model(self, weights=None, device="cpu", url=None):
raise Exception("Need weights for inference.")
if self.MODEL_NAME == "MDV6-apa-rtdetr-c.pth":
- config = os.path.join(project_root, "rtdetrv2_pytorch", "configs", "rtdetrv2", "rtdetrv2_r18vd_120e_megadetector.yml")
+ config = os.path.join(
+ project_root,
+ "rtdetrv2_pytorch",
+ "configs",
+ "rtdetrv2",
+ "rtdetrv2_r18vd_120e_megadetector.yml",
+ )
elif self.MODEL_NAME == "MDV6-apa-rtdetr-e.pth":
- config = os.path.join(project_root, "rtdetrv2_pytorch", "configs", "rtdetrv2", "rtdetrv2_r101vd_6x_megadetector.yml")
+ config = os.path.join(
+ project_root,
+ "rtdetrv2_pytorch",
+ "configs",
+ "rtdetrv2",
+ "rtdetrv2_r101vd_6x_megadetector.yml",
+ )
else:
- raise ValueError('Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e')
-
+ raise ValueError("Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e")
+
cfg = YAMLConfig(config, resume=resume)
-
- checkpoint = torch.load(resume, map_location='cpu')
- if 'ema' in checkpoint:
- state = checkpoint['ema']['module']
+
+ checkpoint = torch.load(resume, map_location="cpu")
+ if "ema" in checkpoint:
+ state = checkpoint["ema"]["module"]
else:
- state = checkpoint['model']
+ state = checkpoint["model"]
cfg.model.load_state_dict(state)
class Model(nn.Module):
- def __init__(self, ) -> None:
+ def __init__(
+ self,
+ ) -> None:
super().__init__()
self.model = cfg.model.deploy()
self.postprocessor = cfg.postprocessor.deploy()
-
+
def forward(self, images, orig_target_sizes):
outputs = self.model(images)
outputs = self.postprocessor(outputs, orig_target_sizes)
return outputs
-
+
self.model = Model().to(self.device)
def results_generation(self, preds, img_id, id_strip=None):
"""
Generate results for detection based on model predictions.
-
+
Args:
- preds (List[torch.Tensor]):
+ preds (List[torch.Tensor]):
Model predictions.
- img_id (str):
+ img_id (str):
Image identifier.
- id_strip (str, optional):
+ id_strip (str, optional):
Strip specific characters from img_id. Defaults to None.
Returns:
@@ -123,32 +146,27 @@ def results_generation(self, preds, img_id, id_strip=None):
confidence = preds[2].detach().cpu().numpy()
results = {"img_id": str(img_id).strip(id_strip)}
- results["detections"] = sv.Detections(
- xyxy=xyxy,
- confidence=confidence,
- class_id=class_id
- )
+ results["detections"] = sv.Detections(xyxy=xyxy, confidence=confidence, class_id=class_id)
results["labels"] = [
- f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
- for _, _, confidence, class_id, _, _ in results["detections"]
+ f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
+ for _, _, confidence, class_id, _, _ in results["detections"]
]
-
+
return results
-
def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None):
"""
Perform detection on a single image.
-
+
Args:
- img (str or ndarray):
+ img (str or ndarray):
Image path or ndarray of images.
- img_path (str, optional):
+ img_path (str, optional):
Image path or identifier.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
Returns:
@@ -157,7 +175,7 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri
if type(img) == str:
if img_path is None:
img_path = img
- im_pil = Image.open(img_path).convert('RGB')
+ im_pil = Image.open(img_path).convert("RGB")
else:
im_pil = Image.fromarray(img)
@@ -170,21 +188,21 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri
lab = labels[0][scr > det_conf_thres]
box = boxes[0][scr > det_conf_thres]
scrs = scores[0][scr > det_conf_thres]
-
+
return self.results_generation([lab, box, scrs], img_path, id_strip)
def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id_strip=None):
"""
Perform detection on a batch of images.
-
+
Args:
- data_path (str):
+ data_path (str):
Path containing all images for inference.
batch_size (int, optional):
Batch size for inference. Defaults to 16.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
extension (str, optional):
Image extension to search for. Defaults to "JPG"
@@ -196,10 +214,10 @@ def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id
data_path,
transform=self.transform,
)
-
+
results = []
for i in range(len(dataset)):
- im_pil = Image.open(dataset.images[i]).convert('RGB')
+ im_pil = Image.open(dataset.images[i]).convert("RGB")
w, h = im_pil.size
orig_size = torch.tensor([w, h])[None].to(self.device)
im_data = self.transform(im_pil)[None].to(self.device)
@@ -210,16 +228,16 @@ def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id
lab = labels[0][scr > det_conf_thres]
box = boxes[0][scr > det_conf_thres]
scrs = scores[0][scr > det_conf_thres]
-
+
res = self.results_generation([lab, box, scrs], dataset.images[i], id_strip)
# Normalize the coordinates for timelapse compatibility
size = orig_size[0].cpu().numpy()
- normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy]
+ normalized_coords = [
+ [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]]
+ for x1, y1, x2, y2 in res["detections"].xyxy
+ ]
res["normalized_coords"] = normalized_coords
results.append(res)
return results
-
-
-
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml
index 0e4f046a6..4c4a6f420 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml
@@ -1,3 +1,3 @@
task: detection
num_classes: 3
-remap_mscoco_category: False
\ No newline at end of file
+remap_mscoco_category: False
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml
index a5c14909b..d7936f6eb 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml
@@ -80,4 +80,3 @@ RTDETRCriterionv2:
weight_dict: {cost_class: 2, cost_bbox: 5, cost_giou: 2}
alpha: 0.25
gamma: 2.0
-
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml
index c703545f0..0ad1881dc 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml
@@ -18,4 +18,4 @@ HybridEncoder:
RTDETRTransformerv2:
- num_layers: 3
\ No newline at end of file
+ num_layers: 3
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py
index 8b850c549..4f00ea85f 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py
@@ -2,5 +2,4 @@
"""
# for register purpose
-from . import backbone
-from . import rtdetr
\ No newline at end of file
+from . import backbone, rtdetr
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py
index 53ab01265..7c4cdb60c 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py
@@ -1,4 +1,4 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-from .presnet import PResNet
\ No newline at end of file
+from .presnet import PResNet
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py
index e3f54ea2e..238cb9b91 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py
@@ -1,7 +1,7 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch
+import torch
import torch.nn as nn
@@ -12,6 +12,7 @@ class FrozenBatchNorm2d(nn.Module):
without which any other models than torchvision.models.resnet[18,34,50,101]
produce nans.
"""
+
def __init__(self, num_features, eps=1e-5):
super(FrozenBatchNorm2d, self).__init__()
n = num_features
@@ -20,17 +21,18 @@ def __init__(self, num_features, eps=1e-5):
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
self.eps = eps
- self.num_features = n
+ self.num_features = n
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
- missing_keys, unexpected_keys, error_msgs):
- num_batches_tracked_key = prefix + 'num_batches_tracked'
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
- state_dict, prefix, local_metadata, strict,
- missing_keys, unexpected_keys, error_msgs)
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
def forward(self, x):
# move reshapes to the beginning
@@ -44,43 +46,41 @@ def forward(self, x):
return x * scale + bias
def extra_repr(self):
- return (
- "{num_features}, eps={eps}".format(**self.__dict__)
- )
+ return "{num_features}, eps={eps}".format(**self.__dict__)
-def get_activation(act: str, inplace: bool=True):
- """get activation
- """
+
+def get_activation(act: str, inplace: bool = True):
+ """get activation"""
if act is None:
return nn.Identity()
elif isinstance(act, nn.Module):
- return act
+ return act
act = act.lower()
-
- if act == 'silu' or act == 'swish':
+
+ if act == "silu" or act == "swish":
m = nn.SiLU()
- elif act == 'relu':
+ elif act == "relu":
m = nn.ReLU()
- elif act == 'leaky_relu':
+ elif act == "leaky_relu":
m = nn.LeakyReLU()
- elif act == 'silu':
+ elif act == "silu":
m = nn.SiLU()
-
- elif act == 'gelu':
+
+ elif act == "gelu":
m = nn.GELU()
- elif act == 'hardsigmoid':
+ elif act == "hardsigmoid":
m = nn.Hardsigmoid()
else:
- raise RuntimeError('')
+ raise RuntimeError("")
- if hasattr(m, 'inplace'):
+ if hasattr(m, "inplace"):
m.inplace = inplace
-
- return m
+
+ return m
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py
index 7401c0ef6..a586630db 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py
@@ -1,17 +1,15 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
from collections import OrderedDict
-from .common import get_activation, FrozenBatchNorm2d
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
from ..core import register
+from .common import FrozenBatchNorm2d, get_activation
-
-__all__ = ['PResNet']
+__all__ = ["PResNet"]
ResNet_cfg = {
@@ -23,10 +21,10 @@
donwload_url = {
- 18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth',
- 34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth',
- 50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth',
- 101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth',
+ 18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth",
+ 34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth",
+ 50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth",
+ 101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth",
}
@@ -34,14 +32,15 @@ class ConvNormLayer(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
super().__init__()
self.conv = nn.Conv2d(
- ch_in,
- ch_out,
- kernel_size,
- stride,
- padding=(kernel_size-1)//2 if padding is None else padding,
- bias=bias)
+ ch_in,
+ ch_out,
+ kernel_size,
+ stride,
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
+ bias=bias,
+ )
self.norm = nn.BatchNorm2d(ch_out)
- self.act = get_activation(act)
+ self.act = get_activation(act)
def forward(self, x):
return self.act(self.norm(self.conv(x)))
@@ -50,24 +49,27 @@ def forward(self, x):
class BasicBlock(nn.Module):
expansion = 1
- def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
+ def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
super().__init__()
self.shortcut = shortcut
if not shortcut:
- if variant == 'd' and stride == 2:
- self.short = nn.Sequential(OrderedDict([
- ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
- ('conv', ConvNormLayer(ch_in, ch_out, 1, 1))
- ]))
+ if variant == "d" and stride == 2:
+ self.short = nn.Sequential(
+ OrderedDict(
+ [
+ ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
+ ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)),
+ ]
+ )
+ )
else:
self.short = ConvNormLayer(ch_in, ch_out, 1, stride)
self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
- self.act = nn.Identity() if act is None else get_activation(act)
-
+ self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
out = self.branch2a(x)
@@ -76,7 +78,7 @@ def forward(self, x):
short = x
else:
short = self.short(x)
-
+
out = out + short
out = self.act(out)
@@ -86,15 +88,15 @@ def forward(self, x):
class BottleNeck(nn.Module):
expansion = 4
- def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
+ def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
super().__init__()
- if variant == 'a':
+ if variant == "a":
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
- width = ch_out
+ width = ch_out
self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
@@ -102,15 +104,19 @@ def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
self.shortcut = shortcut
if not shortcut:
- if variant == 'd' and stride == 2:
- self.short = nn.Sequential(OrderedDict([
- ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
- ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1))
- ]))
+ if variant == "d" and stride == 2:
+ self.short = nn.Sequential(
+ OrderedDict(
+ [
+ ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
+ ("conv", ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)),
+ ]
+ )
+ )
else:
self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
- self.act = nn.Identity() if act is None else get_activation(act)
+ self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
out = self.branch2a(x)
@@ -129,19 +135,20 @@ def forward(self, x):
class Blocks(nn.Module):
- def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'):
+ def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
super().__init__()
self.blocks = nn.ModuleList()
for i in range(count):
self.blocks.append(
block(
- ch_in,
+ ch_in,
ch_out,
- stride=2 if i == 0 and stage_num != 2 else 1,
+ stride=2 if i == 0 and stage_num != 2 else 1,
shortcut=False if i == 0 else True,
variant=variant,
- act=act)
+ act=act,
+ )
)
if i == 0:
@@ -157,20 +164,21 @@ def forward(self, x):
@register()
class PResNet(nn.Module):
def __init__(
- self,
- depth,
- variant='d',
- num_stages=4,
- return_idx=[0, 1, 2, 3],
- act='relu',
- freeze_at=-1,
- freeze_norm=True,
- pretrained=False):
+ self,
+ depth,
+ variant="d",
+ num_stages=4,
+ return_idx=[0, 1, 2, 3],
+ act="relu",
+ freeze_at=-1,
+ freeze_norm=True,
+ pretrained=False,
+ ):
super().__init__()
block_nums = ResNet_cfg[depth]
ch_in = 64
- if variant in ['c', 'd']:
+ if variant in ["c", "d"]:
conv_def = [
[3, ch_in // 2, 3, 2, "conv1_1"],
[ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
@@ -179,9 +187,14 @@ def __init__(
else:
conv_def = [[3, ch_in, 7, 2, "conv1_1"]]
- self.conv1 = nn.Sequential(OrderedDict([
- (name, ConvNormLayer(cin, cout, k, s, act=act)) for cin, cout, k, s, name in conv_def
- ]))
+ self.conv1 = nn.Sequential(
+ OrderedDict(
+ [
+ (name, ConvNormLayer(cin, cout, k, s, act=act))
+ for cin, cout, k, s, name in conv_def
+ ]
+ )
+ )
ch_out_list = [64, 128, 256, 512]
block = BottleNeck if depth >= 50 else BasicBlock
@@ -193,7 +206,9 @@ def __init__(
for i in range(num_stages):
stage_num = i + 2
self.res_layers.append(
- Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant)
+ Blocks(
+ block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant
+ )
)
ch_in = _out_channels[i]
@@ -210,12 +225,12 @@ def __init__(
self._freeze_norm(self)
if pretrained:
- if isinstance(pretrained, bool) or 'http' in pretrained:
- state = torch.hub.load_state_dict_from_url(donwload_url[depth], map_location='cpu')
+ if isinstance(pretrained, bool) or "http" in pretrained:
+ state = torch.hub.load_state_dict_from_url(donwload_url[depth], map_location="cpu")
else:
- state = torch.load(pretrained, map_location='cpu')
+ state = torch.load(pretrained, map_location="cpu")
self.load_state_dict(state)
- print(f'Load PResNet{depth} state_dict')
+ print(f"Load PResNet{depth} state_dict")
def _freeze_parameters(self, m: nn.Module):
for p in m.parameters():
@@ -240,5 +255,3 @@ def forward(self, x):
if idx in self.return_idx:
outs.append(x)
return outs
-
-
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py
index e1078b225..fe463ca0a 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py
@@ -1,7 +1,7 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-from .workspace import *
-from .yaml_utils import *
from ._config import BaseConfig
+from .workspace import *
from .yaml_config import YAMLConfig
+from .yaml_utils import *
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py
index 91d720ca2..33f6461a9 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py
@@ -1,76 +1,81 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch.nn as nn
-from torch.utils.data import Dataset, DataLoader
+from typing import Callable
+
+import torch.nn as nn
+from torch.cuda.amp.grad_scaler import GradScaler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
-from torch.cuda.amp.grad_scaler import GradScaler
+from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
-from typing import Callable
-
-__all__ = ['BaseConfig', ]
+__all__ = [
+ "BaseConfig",
+]
class BaseConfig(object):
-
def __init__(self) -> None:
super().__init__()
- self.task :str = None
+ self.task: str = None
- # instance / function
- self._model :nn.Module = None
- self._postprocessor :nn.Module = None
- self._criterion :nn.Module = None
- self._optimizer :Optimizer = None
- self._lr_scheduler :LRScheduler = None
- self._lr_warmup_scheduler: LRScheduler = None
- self._train_dataloader :DataLoader = None
- self._val_dataloader :DataLoader = None
- self._ema :nn.Module = None
- self._scaler :GradScaler = None
- self._train_dataset :Dataset = None
- self._val_dataset :Dataset = None
- self._collate_fn :Callable = None
- self._evaluator :Callable[[nn.Module, DataLoader, str], ] = None
+ # instance / function
+ self._model: nn.Module = None
+ self._postprocessor: nn.Module = None
+ self._criterion: nn.Module = None
+ self._optimizer: Optimizer = None
+ self._lr_scheduler: LRScheduler = None
+ self._lr_warmup_scheduler: LRScheduler = None
+ self._train_dataloader: DataLoader = None
+ self._val_dataloader: DataLoader = None
+ self._ema: nn.Module = None
+ self._scaler: GradScaler = None
+ self._train_dataset: Dataset = None
+ self._val_dataset: Dataset = None
+ self._collate_fn: Callable = None
+ self._evaluator: Callable[[nn.Module, DataLoader, str],] = None
self._writer: SummaryWriter = None
-
- # dataset
- self.num_workers :int = 0
- self.batch_size :int = None
- self._train_batch_size :int = None
- self._val_batch_size :int = None
- self._train_shuffle: bool = None
- self._val_shuffle: bool = None
+
+ # dataset
+ self.num_workers: int = 0
+ self.batch_size: int = None
+ self._train_batch_size: int = None
+ self._val_batch_size: int = None
+ self._train_shuffle: bool = None
+ self._val_shuffle: bool = None
# runtime
- self.resume :str = None
- self.tuning :str = None
+ self.resume: str = None
+ self.tuning: str = None
- self.epoches :int = None
- self.last_epoch :int = -1
+ self.epoches: int = None
+ self.last_epoch: int = -1
- self.use_amp :bool = False
- self.use_ema :bool = False
- self.ema_decay :float = 0.9999
+ self.use_amp: bool = False
+ self.use_ema: bool = False
+ self.ema_decay: float = 0.9999
self.ema_warmups: int = 2000
- self.sync_bn :bool = False
- self.clip_max_norm : float = 0.
- self.find_unused_parameters :bool = None
+ self.sync_bn: bool = False
+ self.clip_max_norm: float = 0.0
+ self.find_unused_parameters: bool = None
- self.seed :int = None
- self.print_freq :int = None
- self.checkpoint_freq :int = 1
- self.output_dir :str = None
- self.summary_dir :str = None
- self.device : str = ''
+ self.seed: int = None
+ self.print_freq: int = None
+ self.checkpoint_freq: int = 1
+ self.output_dir: str = None
+ self.summary_dir: str = None
+ self.device: str = ""
@property
- def model(self, ) -> nn.Module:
- return self._model
+ def model(
+ self,
+ ) -> nn.Module:
+ return self._model
@property
- def postprocessor(self, ) -> nn.Module:
- return self._postprocessor
\ No newline at end of file
+ def postprocessor(
+ self,
+ ) -> nn.Module:
+ return self._postprocessor
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py
index ea0e41af6..b96a83381 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py
@@ -1,171 +1,168 @@
""""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import inspect
-import importlib
import functools
+import importlib
import inspect
from collections import defaultdict
-from typing import Any, Dict, Optional, List
-
+from typing import Any, Dict, List, Optional
GLOBAL_CONFIG = defaultdict(dict)
-def register(dct :Any=GLOBAL_CONFIG, name=None, force=False):
+def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
"""
- dct:
- if dct is Dict, register foo into dct as key-value pair
- if dct is Clas, register as modules attibute
- force
- whether force register.
+ dct:
+ if dct is Dict, register foo into dct as key-value pair
+ if dct is Clas, register as modules attibute
+ force
+ whether force register.
"""
+
def decorator(foo):
register_name = foo.__name__ if name is None else name
if not force:
if inspect.isclass(dct):
- assert not hasattr(dct, foo.__name__), \
- f'module {dct.__name__} has {foo.__name__}'
+ assert not hasattr(dct, foo.__name__), f"module {dct.__name__} has {foo.__name__}"
else:
- assert foo.__name__ not in dct, \
- f'{foo.__name__} has been already registered'
+ assert foo.__name__ not in dct, f"{foo.__name__} has been already registered"
if inspect.isfunction(foo):
+
@functools.wraps(foo)
def wrap_func(*args, **kwargs):
return foo(*args, **kwargs)
+
if isinstance(dct, dict):
dct[foo.__name__] = wrap_func
elif inspect.isclass(dct):
setattr(dct, foo.__name__, wrap_func)
else:
- raise AttributeError('')
+ raise AttributeError("")
return wrap_func
elif inspect.isclass(foo):
- dct[register_name] = extract_schema(foo)
+ dct[register_name] = extract_schema(foo)
else:
- raise ValueError(f'Do not support {type(foo)} register')
+ raise ValueError(f"Do not support {type(foo)} register")
return foo
return decorator
-
def extract_schema(module: type):
"""
Args:
module (type),
Return:
- Dict,
+ Dict,
"""
argspec = inspect.getfullargspec(module.__init__)
- arg_names = [arg for arg in argspec.args if arg != 'self']
+ arg_names = [arg for arg in argspec.args if arg != "self"]
num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0
num_requires = len(arg_names) - num_defualts
schame = dict()
- schame['_name'] = module.__name__
- schame['_pymodule'] = importlib.import_module(module.__module__)
- schame['_inject'] = getattr(module, '__inject__', [])
- schame['_share'] = getattr(module, '__share__', [])
- schame['_kwargs'] = {}
+ schame["_name"] = module.__name__
+ schame["_pymodule"] = importlib.import_module(module.__module__)
+ schame["_inject"] = getattr(module, "__inject__", [])
+ schame["_share"] = getattr(module, "__share__", [])
+ schame["_kwargs"] = {}
for i, name in enumerate(arg_names):
- if name in schame['_share']:
- assert i >= num_requires, 'share config must have default value.'
+ if name in schame["_share"]:
+ assert i >= num_requires, "share config must have default value."
value = argspec.defaults[i - num_requires]
-
+
elif i >= num_requires:
value = argspec.defaults[i - num_requires]
else:
- value = None
+ value = None
schame[name] = value
- schame['_kwargs'][name] = value
-
+ schame["_kwargs"][name] = value
+
return schame
def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
- """
- """
- assert type(type_or_name) in (type, str), 'create should be modules or name.'
+ """ """
+ assert type(type_or_name) in (type, str), "create should be modules or name."
name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__
if name in global_cfg:
- if hasattr(global_cfg[name], '__dict__'):
+ if hasattr(global_cfg[name], "__dict__"):
return global_cfg[name]
else:
- raise ValueError('The module {} is not registered'.format(name))
+ raise ValueError("The module {} is not registered".format(name))
cfg = global_cfg[name]
- if isinstance(cfg, dict) and 'type' in cfg:
- _cfg: dict = global_cfg[cfg['type']]
+ if isinstance(cfg, dict) and "type" in cfg:
+ _cfg: dict = global_cfg[cfg["type"]]
# clean args
- _keys = [k for k in _cfg.keys() if not k.startswith('_')]
+ _keys = [k for k in _cfg.keys() if not k.startswith("_")]
for _arg in _keys:
del _cfg[_arg]
- _cfg.update(_cfg['_kwargs']) # restore default args
- _cfg.update(cfg) # load config args
- _cfg.update(kwargs) # TODO recive extra kwargs
- name = _cfg.pop('type') # pop extra key `type` (from cfg)
-
+ _cfg.update(_cfg["_kwargs"]) # restore default args
+ _cfg.update(cfg) # load config args
+ _cfg.update(kwargs) # TODO recive extra kwargs
+ name = _cfg.pop("type") # pop extra key `type` (from cfg)
+
return create(name, global_cfg)
-
- module = getattr(cfg['_pymodule'], name)
+
+ module = getattr(cfg["_pymodule"], name)
module_kwargs = {}
module_kwargs.update(cfg)
-
+
# shared var
- for k in cfg['_share']:
+ for k in cfg["_share"]:
if k in global_cfg:
module_kwargs[k] = global_cfg[k]
else:
module_kwargs[k] = cfg[k]
# inject
- for k in cfg['_inject']:
+ for k in cfg["_inject"]:
_k = cfg[k]
if _k is None:
continue
- if isinstance(_k, str):
+ if isinstance(_k, str):
if _k not in global_cfg:
- raise ValueError(f'Missing inject config of {_k}.')
+ raise ValueError(f"Missing inject config of {_k}.")
_cfg = global_cfg[_k]
-
+
if isinstance(_cfg, dict):
- module_kwargs[k] = create(_cfg['_name'], global_cfg)
+ module_kwargs[k] = create(_cfg["_name"], global_cfg)
else:
- module_kwargs[k] = _cfg
+ module_kwargs[k] = _cfg
elif isinstance(_k, dict):
- if 'type' not in _k.keys():
- raise ValueError(f'Missing inject for `type` style.')
+ if "type" not in _k.keys():
+ raise ValueError(f"Missing inject for `type` style.")
- _type = str(_k['type'])
+ _type = str(_k["type"])
if _type not in global_cfg:
- raise ValueError(f'Missing {_type} in inspect stage.')
+ raise ValueError(f"Missing {_type} in inspect stage.")
_cfg: dict = global_cfg[_type]
- _keys = [k for k in _cfg.keys() if not k.startswith('_')]
+ _keys = [k for k in _cfg.keys() if not k.startswith("_")]
for _arg in _keys:
del _cfg[_arg]
- _cfg.update(_cfg['_kwargs']) # restore default values
- _cfg.update(_k) # load config args
- name = _cfg.pop('type') # pop extra key (`type` from _k)
+ _cfg.update(_cfg["_kwargs"]) # restore default values
+ _cfg.update(_k) # load config args
+ name = _cfg.pop("type") # pop extra key (`type` from _k)
module_kwargs[k] = create(name, global_cfg)
else:
- raise ValueError(f'Inject does not support {_k}')
-
- module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith('_')}
+ raise ValueError(f"Inject does not support {_k}")
+
+ module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}
- return module(**module_kwargs)
\ No newline at end of file
+ return module(**module_kwargs)
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py
index 8e0e02fd4..3e7129227 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py
@@ -1,13 +1,15 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch
import copy
+import torch
+
from ._config import BaseConfig
from .workspace import create
from .yaml_utils import load_config, merge_config, merge_dict
+
class YAMLConfig(BaseConfig):
def __init__(self, cfg_path: str, **kwargs) -> None:
super().__init__()
@@ -15,24 +17,30 @@ def __init__(self, cfg_path: str, **kwargs) -> None:
cfg = load_config(cfg_path)
cfg = merge_dict(cfg, kwargs)
- self.yaml_cfg = copy.deepcopy(cfg)
-
+ self.yaml_cfg = copy.deepcopy(cfg)
+
for k in super().__dict__:
- if not k.startswith('_') and k in cfg:
+ if not k.startswith("_") and k in cfg:
self.__dict__[k] = cfg[k]
@property
- def global_cfg(self, ):
+ def global_cfg(
+ self,
+ ):
return merge_config(self.yaml_cfg, inplace=False, overwrite=False)
-
+
@property
- def model(self, ) -> torch.nn.Module:
- if self._model is None and 'model' in self.yaml_cfg:
- self._model = create(self.yaml_cfg['model'], self.global_cfg)
- return super().model
+ def model(
+ self,
+ ) -> torch.nn.Module:
+ if self._model is None and "model" in self.yaml_cfg:
+ self._model = create(self.yaml_cfg["model"], self.global_cfg)
+ return super().model
@property
- def postprocessor(self, ) -> torch.nn.Module:
- if self._postprocessor is None and 'postprocessor' in self.yaml_cfg:
- self._postprocessor = create(self.yaml_cfg['postprocessor'], self.global_cfg)
- return super().postprocessor
\ No newline at end of file
+ def postprocessor(
+ self,
+ ) -> torch.nn.Module:
+ if self._postprocessor is None and "postprocessor" in self.yaml_cfg:
+ self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg)
+ return super().postprocessor
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py
index 1033bcf21..9a47d35e0 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py
@@ -1,28 +1,28 @@
""""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import os
import copy
-import yaml
-from typing import Any, Dict, Optional, List
+import os
+from typing import Any, Dict, List, Optional
+
+import yaml
from .workspace import GLOBAL_CONFIG
__all__ = [
- 'load_config',
- 'merge_config',
- 'merge_dict',
+ "load_config",
+ "merge_config",
+ "merge_dict",
]
-INCLUDE_KEY = '__include__'
+INCLUDE_KEY = "__include__"
def load_config(file_path, cfg=dict()):
- """load config
- """
+ """load config"""
_, ext = os.path.splitext(file_path)
- assert ext in ['.yml', '.yaml'], "only support yaml files"
+ assert ext in [".yml", ".yaml"], "only support yaml files"
with open(file_path) as f:
file_cfg = yaml.load(f, Loader=yaml.Loader)
@@ -32,10 +32,10 @@ def load_config(file_path, cfg=dict()):
if INCLUDE_KEY in file_cfg:
base_yamls = list(file_cfg[INCLUDE_KEY])
for base_yaml in base_yamls:
- if base_yaml.startswith('~'):
+ if base_yaml.startswith("~"):
base_yaml = os.path.expanduser(base_yaml)
- if not base_yaml.startswith('/'):
+ if not base_yaml.startswith("/"):
base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)
with open(base_yaml) as f:
@@ -46,24 +46,24 @@ def load_config(file_path, cfg=dict()):
def merge_dict(dct, another_dct, inplace=True) -> Dict:
- """merge another_dct into dct
- """
+ """merge another_dct into dct"""
+
def _merge(dct, another) -> Dict:
for k in another:
- if (k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict)):
+ if k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict):
_merge(dct[k], another[k])
else:
dct[k] = another[k]
return dct
-
+
if not inplace:
dct = copy.deepcopy(dct)
-
+
return _merge(dct, another_dct)
-def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool=False, overwrite: bool=False):
+def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False):
"""
Merge another_cfg into cfg, return the merged config
@@ -78,19 +78,20 @@ def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool=False, overwrite:
model1 = create(cfg1['model'], cfg1)
model2 = create(cfg2['model'], cfg2)
"""
+
def _merge(dct, another):
for k in another:
if k not in dct:
dct[k] = another[k]
-
+
elif isinstance(dct[k], dict) and isinstance(another[k], dict):
- _merge(dct[k], another[k])
-
+ _merge(dct[k], another[k])
+
elif overwrite:
dct[k] = another[k]
return cfg
-
+
if not inplace:
cfg = copy.deepcopy(cfg)
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py
index 1df1f9669..f11cde740 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py
@@ -2,7 +2,7 @@
"""
-from .rtdetr import RTDETR
from .hybrid_encoder import HybridEncoder
+from .rtdetr import RTDETR
from .rtdetr_postprocessor import RTDETRPostProcessor
-from .rtdetrv2_decoder import RTDETRTransformerv2
\ No newline at end of file
+from .rtdetrv2_decoder import RTDETRTransformerv2
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py
index c45752a6c..f910fddf5 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py
@@ -10,13 +10,11 @@
def box_cxcywh_to_xyxy(x: Tensor) -> Tensor:
x_c, y_c, w, h = x.unbind(-1)
- b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
- (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x: Tensor) -> Tensor:
x0, y0, x1, y1 = x.unbind(-1)
- b = [(x0 + x1) / 2, (y0 + y1) / 2,
- (x1 - x0), (y1 - y0)]
- return torch.stack(b, dim=-1)
\ No newline at end of file
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py
index c50f214c6..e63002e94 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py
@@ -1,26 +1,28 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch
+import torch
-from .utils import inverse_sigmoid
from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
+from .utils import inverse_sigmoid
-def get_contrastive_denoising_training_group(targets,
- num_classes,
- num_queries,
- class_embed,
- num_denoising=100,
- label_noise_ratio=0.5,
- box_noise_scale=1.0,):
+def get_contrastive_denoising_training_group(
+ targets,
+ num_classes,
+ num_queries,
+ class_embed,
+ num_denoising=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+):
"""cnd"""
if num_denoising <= 0:
return None, None, None, None
- num_gts = [len(t['labels']) for t in targets]
- device = targets[0]['labels'].device
-
+ num_gts = [len(t["labels"]) for t in targets]
+ device = targets[0]["labels"].device
+
max_gt_num = max(num_gts)
if max_gt_num == 0:
return None, None, None, None
@@ -37,8 +39,8 @@ def get_contrastive_denoising_training_group(targets,
for i in range(bs):
num_gt = num_gts[i]
if num_gt > 0:
- input_query_class[i, :num_gt] = targets[i]['labels']
- input_query_bbox[i, :num_gt] = targets[i]['boxes']
+ input_query_class[i, :num_gt] = targets[i]["labels"]
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
pad_gt_mask[i, :num_gt] = 1
# each group has positive and negative queries.
input_query_class = input_query_class.tile([1, 2 * num_group])
@@ -68,7 +70,7 @@ def get_contrastive_denoising_training_group(targets,
rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(input_query_bbox)
rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
- known_bbox += (rand_sign * rand_part * diff)
+ known_bbox += rand_sign * rand_part * diff
known_bbox = torch.clip(known_bbox, min=0.0, max=1.0)
input_query_bbox = box_xyxy_to_cxcywh(known_bbox)
input_query_bbox_unact = inverse_sigmoid(input_query_bbox)
@@ -79,21 +81,27 @@ def get_contrastive_denoising_training_group(targets,
attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
# match query cannot see the reconstruction
attn_mask[num_denoising:, :num_denoising] = True
-
+
# reconstruct cannot see each other
for i in range(num_group):
if i == 0:
- attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
+ attn_mask[
+ max_gt_num * 2 * i : max_gt_num * 2 * (i + 1),
+ max_gt_num * 2 * (i + 1) : num_denoising,
+ ] = True
if i == num_group - 1:
- attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True
+ attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2] = True
else:
- attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
- attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True
-
+ attn_mask[
+ max_gt_num * 2 * i : max_gt_num * 2 * (i + 1),
+ max_gt_num * 2 * (i + 1) : num_denoising,
+ ] = True
+ attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True
+
dn_meta = {
"dn_positive_idx": dn_positive_idx,
"dn_num_group": num_group,
- "dn_num_split": [num_denoising, num_queries]
+ "dn_num_split": [num_denoising, num_queries],
}
return input_query_logits, input_query_bbox_unact, attn_mask, dn_meta
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py
index 15e5acf7e..a4446a988 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py
@@ -4,47 +4,45 @@
import copy
from collections import OrderedDict
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .utils import get_activation
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
from ..core import register
+from .utils import get_activation
-
-__all__ = ['HybridEncoder']
-
+__all__ = ["HybridEncoder"]
class ConvNormLayer(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
super().__init__()
self.conv = nn.Conv2d(
- ch_in,
- ch_out,
- kernel_size,
- stride,
- padding=(kernel_size-1)//2 if padding is None else padding,
- bias=bias)
+ ch_in,
+ ch_out,
+ kernel_size,
+ stride,
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
+ bias=bias,
+ )
self.norm = nn.BatchNorm2d(ch_out)
- self.act = nn.Identity() if act is None else get_activation(act)
+ self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
return self.act(self.norm(self.conv(x)))
class RepVggBlock(nn.Module):
- def __init__(self, ch_in, ch_out, act='relu'):
+ def __init__(self, ch_in, ch_out, act="relu"):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
- self.act = nn.Identity() if act is None else get_activation(act)
+ self.act = nn.Identity() if act is None else get_activation(act)
def forward(self, x):
- if hasattr(self, 'conv'):
+ if hasattr(self, "conv"):
y = self.conv(x)
else:
y = self.conv1(x) + self.conv2(x)
@@ -52,17 +50,17 @@ def forward(self, x):
return self.act(y)
def convert_to_deploy(self):
- if not hasattr(self, 'conv'):
+ if not hasattr(self, "conv"):
self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.conv.weight.data = kernel
- self.conv.bias.data = bias
+ self.conv.bias.data = bias
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
-
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
@@ -86,20 +84,16 @@ def _fuse_bn_tensor(self, branch: ConvNormLayer):
class CSPRepLayer(nn.Module):
- def __init__(self,
- in_channels,
- out_channels,
- num_blocks=3,
- expansion=1.0,
- bias=None,
- act="silu"):
+ def __init__(
+ self, in_channels, out_channels, num_blocks=3, expansion=1.0, bias=None, act="silu"
+ ):
super(CSPRepLayer, self).__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
- self.bottlenecks = nn.Sequential(*[
- RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
- ])
+ self.bottlenecks = nn.Sequential(
+ *[RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)]
+ )
if hidden_channels != out_channels:
self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
else:
@@ -114,13 +108,15 @@ def forward(self, x):
# transformer
class TransformerEncoderLayer(nn.Module):
- def __init__(self,
- d_model,
- nhead,
- dim_feedforward=2048,
- dropout=0.1,
- activation="relu",
- normalize_before=False):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
super().__init__()
self.normalize_before = normalize_before
@@ -134,7 +130,7 @@ def __init__(self,
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
- self.activation = get_activation(activation)
+ self.activation = get_activation(activation)
@staticmethod
def with_pos_embed(tensor, pos_embed):
@@ -181,24 +177,28 @@ def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
@register()
class HybridEncoder(nn.Module):
- __share__ = ['eval_spatial_size', ]
-
- def __init__(self,
- in_channels=[512, 1024, 2048],
- feat_strides=[8, 16, 32],
- hidden_dim=256,
- nhead=8,
- dim_feedforward = 1024,
- dropout=0.0,
- enc_act='gelu',
- use_encoder_idx=[2],
- num_encoder_layers=1,
- pe_temperature=10000,
- expansion=1.0,
- depth_mult=1.0,
- act='silu',
- eval_spatial_size=None,
- version='v2'):
+ __share__ = [
+ "eval_spatial_size",
+ ]
+
+ def __init__(
+ self,
+ in_channels=[512, 1024, 2048],
+ feat_strides=[8, 16, 32],
+ hidden_dim=256,
+ nhead=8,
+ dim_feedforward=1024,
+ dropout=0.0,
+ enc_act="gelu",
+ use_encoder_idx=[2],
+ num_encoder_layers=1,
+ pe_temperature=10000,
+ expansion=1.0,
+ depth_mult=1.0,
+ act="silu",
+ eval_spatial_size=None,
+ version="v2",
+ ):
super().__init__()
self.in_channels = in_channels
self.feat_strides = feat_strides
@@ -206,38 +206,47 @@ def __init__(self,
self.use_encoder_idx = use_encoder_idx
self.num_encoder_layers = num_encoder_layers
self.pe_temperature = pe_temperature
- self.eval_spatial_size = eval_spatial_size
+ self.eval_spatial_size = eval_spatial_size
self.out_channels = [hidden_dim for _ in range(len(in_channels))]
self.out_strides = feat_strides
-
+
# channel projection
self.input_proj = nn.ModuleList()
for in_channel in in_channels:
- if version == 'v1':
+ if version == "v1":
proj = nn.Sequential(
nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
- nn.BatchNorm2d(hidden_dim))
- elif version == 'v2':
- proj = nn.Sequential(OrderedDict([
- ('conv', nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)),
- ('norm', nn.BatchNorm2d(hidden_dim))
- ]))
+ nn.BatchNorm2d(hidden_dim),
+ )
+ elif version == "v2":
+ proj = nn.Sequential(
+ OrderedDict(
+ [
+ ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)),
+ ("norm", nn.BatchNorm2d(hidden_dim)),
+ ]
+ )
+ )
else:
raise AttributeError()
-
+
self.input_proj.append(proj)
# encoder transformer
encoder_layer = TransformerEncoderLayer(
- hidden_dim,
+ hidden_dim,
nhead=nhead,
- dim_feedforward=dim_feedforward,
+ dim_feedforward=dim_feedforward,
dropout=dropout,
- activation=enc_act)
+ activation=enc_act,
+ )
- self.encoder = nn.ModuleList([
- TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))
- ])
+ self.encoder = nn.ModuleList(
+ [
+ TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
+ for _ in range(len(use_encoder_idx))
+ ]
+ )
# top-down fpn
self.lateral_convs = nn.ModuleList()
@@ -245,18 +254,20 @@ def __init__(self,
for _ in range(len(in_channels) - 1, 0, -1):
self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
self.fpn_blocks.append(
- CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
+ CSPRepLayer(
+ hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion
+ )
)
# bottom-up pan
self.downsample_convs = nn.ModuleList()
self.pan_blocks = nn.ModuleList()
for _ in range(len(in_channels) - 1):
- self.downsample_convs.append(
- ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act)
- )
+ self.downsample_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act))
self.pan_blocks.append(
- CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
+ CSPRepLayer(
+ hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion
+ )
)
self._reset_parameters()
@@ -266,23 +277,26 @@ def _reset_parameters(self):
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
pos_embed = self.build_2d_sincos_position_embedding(
- self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride,
- self.hidden_dim, self.pe_temperature)
- setattr(self, f'pos_embed{idx}', pos_embed)
- #self.register_buffer(f'pos_embed{idx}', pos_embed)
+ self.eval_spatial_size[1] // stride,
+ self.eval_spatial_size[0] // stride,
+ self.hidden_dim,
+ self.pe_temperature,
+ )
+ setattr(self, f"pos_embed{idx}", pos_embed)
+ # self.register_buffer(f'pos_embed{idx}', pos_embed)
@staticmethod
- def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
- """
- """
+ def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
+ """ """
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
- grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
- assert embed_dim % 4 == 0, \
- 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+ assert (
+ embed_dim % 4 == 0
+ ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
- omega = 1. / (temperature ** omega)
+ omega = 1.0 / (temperature**omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
@@ -292,7 +306,7 @@ def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
def forward(self, feats):
assert len(feats) == len(self.in_channels)
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
-
+
# encoder
if self.num_encoder_layers > 0:
for i, enc_ind in enumerate(self.use_encoder_idx):
@@ -301,12 +315,15 @@ def forward(self, feats):
src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1)
if self.training or self.eval_spatial_size is None:
pos_embed = self.build_2d_sincos_position_embedding(
- w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device)
+ w, h, self.hidden_dim, self.pe_temperature
+ ).to(src_flatten.device)
else:
- pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device)
+ pos_embed = getattr(self, f"pos_embed{enc_ind}", None).to(src_flatten.device)
- memory :torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
- proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
+ memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
+ proj_feats[enc_ind] = (
+ memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
+ )
# broadcasting and fusion
inner_outs = [proj_feats[-1]]
@@ -315,8 +332,10 @@ def forward(self, feats):
feat_low = proj_feats[idx - 1]
feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh)
inner_outs[0] = feat_heigh
- upsample_feat = F.interpolate(feat_heigh, scale_factor=2., mode='nearest')
- inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1))
+ upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
+ inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
+ torch.concat([upsample_feat, feat_low], dim=1)
+ )
inner_outs.insert(0, inner_out)
outs = [inner_outs[0]]
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py
index 77a23be5c..1bf28c22f 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py
@@ -1,44 +1,52 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import random
+from typing import List
-import random
-import numpy as np
-from typing import List
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
from ..core import register
-
-__all__ = ['RTDETR', ]
+__all__ = [
+ "RTDETR",
+]
@register()
class RTDETR(nn.Module):
- __inject__ = ['backbone', 'encoder', 'decoder', ]
-
- def __init__(self, \
- backbone: nn.Module,
- encoder: nn.Module,
- decoder: nn.Module,
+ __inject__ = [
+ "backbone",
+ "encoder",
+ "decoder",
+ ]
+
+ def __init__(
+ self,
+ backbone: nn.Module,
+ encoder: nn.Module,
+ decoder: nn.Module,
):
super().__init__()
self.backbone = backbone
self.decoder = decoder
self.encoder = encoder
-
+
def forward(self, x, targets=None):
x = self.backbone(x)
- x = self.encoder(x)
+ x = self.encoder(x)
x = self.decoder(x, targets)
return x
-
- def deploy(self, ):
+
+ def deploy(
+ self,
+ ):
self.eval()
for m in self.modules():
- if hasattr(m, 'convert_to_deploy'):
+ if hasattr(m, "convert_to_deploy"):
m.convert_to_deploy()
- return self
+ return self
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py
index 6d024fa45..1afc7c87c 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py
@@ -1,16 +1,14 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
import torchvision
from ..core import register
-
-__all__ = ['RTDETRPostProcessor']
+__all__ = ["RTDETRPostProcessor"]
def mod(a, b):
@@ -20,36 +18,27 @@ def mod(a, b):
@register()
class RTDETRPostProcessor(nn.Module):
- __share__ = [
- 'num_classes',
- 'use_focal_loss',
- 'num_top_queries',
- 'remap_mscoco_category'
- ]
-
+ __share__ = ["num_classes", "use_focal_loss", "num_top_queries", "remap_mscoco_category"]
+
def __init__(
- self,
- num_classes=80,
- use_focal_loss=True,
- num_top_queries=300,
- remap_mscoco_category=False
+ self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False
) -> None:
super().__init__()
self.use_focal_loss = use_focal_loss
self.num_top_queries = num_top_queries
self.num_classes = int(num_classes)
- self.remap_mscoco_category = remap_mscoco_category
- self.deploy_mode = False
+ self.remap_mscoco_category = remap_mscoco_category
+ self.deploy_mode = False
def extra_repr(self) -> str:
- return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'
-
+ return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}"
+
# def forward(self, outputs, orig_target_sizes):
def forward(self, outputs, orig_target_sizes: torch.Tensor):
- logits, boxes = outputs['pred_logits'], outputs['pred_boxes']
- # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
+ logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
+ # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
- bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
+ bbox_pred = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
if self.use_focal_loss:
@@ -59,16 +48,20 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor):
# labels = index % self.num_classes
labels = mod(index, self.num_classes)
index = index // self.num_classes
- boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
-
+ boxes = bbox_pred.gather(
+ dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])
+ )
+
else:
scores = F.softmax(logits)[:, :, :-1]
scores, labels = scores.max(dim=-1)
if scores.shape[1] > self.num_top_queries:
scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
labels = torch.gather(labels, dim=1, index=index)
- boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
-
+ boxes = torch.gather(
+ boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])
+ )
+
# TODO for onnx export
if self.deploy_mode:
return labels, boxes, scores
@@ -76,18 +69,23 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor):
# TODO
if self.remap_mscoco_category:
from ...data.dataset import mscoco_label2category
- labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\
- .to(boxes.device).reshape(labels.shape)
+
+ labels = (
+ torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])
+ .to(boxes.device)
+ .reshape(labels.shape)
+ )
results = []
for lab, box, sco in zip(labels, boxes, scores):
result = dict(labels=lab, boxes=box, scores=sco)
results.append(result)
-
+
return results
-
- def deploy(self, ):
+ def deploy(
+ self,
+ ):
self.eval()
self.deploy_mode = True
- return self
+ return self
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py
index 945fbfd49..87e6a6e44 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py
@@ -1,31 +1,37 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
-import math
-import copy
+import copy
import functools
+import math
from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.nn.init as init
from typing import List
-from .denoising import get_contrastive_denoising_training_group
-from .utils import deformable_attention_core_func_v2, get_activation, inverse_sigmoid, bias_init_with_prob
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
from ..core import register
+from .denoising import get_contrastive_denoising_training_group
+from .utils import (
+ bias_init_with_prob,
+ deformable_attention_core_func_v2,
+ get_activation,
+ inverse_sigmoid,
+)
-__all__ = ['RTDETRTransformerv2']
+__all__ = ["RTDETRTransformerv2"]
class MLP(nn.Module):
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'):
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
self.act = get_activation(act)
def forward(self, x):
@@ -36,16 +42,15 @@ def forward(self, x):
class MSDeformableAttention(nn.Module):
def __init__(
- self,
- embed_dim=256,
- num_heads=8,
- num_levels=4,
- num_points=4,
- method='default',
+ self,
+ embed_dim=256,
+ num_heads=8,
+ num_levels=4,
+ num_points=4,
+ method="default",
offset_scale=0.5,
):
- """Multi-Scale Deformable Attention
- """
+ """Multi-Scale Deformable Attention"""
super(MSDeformableAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -53,43 +58,53 @@ def __init__(
self.offset_scale = offset_scale
if isinstance(num_points, list):
- assert len(num_points) == num_levels, ''
+ assert len(num_points) == num_levels, ""
num_points_list = num_points
else:
num_points_list = [num_points for _ in range(num_levels)]
self.num_points_list = num_points_list
-
- num_points_scale = [1/n for n in num_points_list for _ in range(n)]
- self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32))
+
+ num_points_scale = [1 / n for n in num_points_list for _ in range(n)]
+ self.register_buffer(
+ "num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32)
+ )
self.total_points = num_heads * sum(num_points_list)
self.method = method
self.head_dim = embed_dim // num_heads
- assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
self.attention_weights = nn.Linear(embed_dim, self.total_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
- self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method)
+ self.ms_deformable_attn_core = functools.partial(
+ deformable_attention_core_func_v2, method=self.method
+ )
self._reset_parameters()
- if method == 'discrete':
+ if method == "discrete":
for p in self.sampling_offsets.parameters():
p.requires_grad = False
def _reset_parameters(self):
# sampling_offsets
init.constant_(self.sampling_offsets.weight, 0)
- thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
+ 2.0 * math.pi / self.num_heads
+ )
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1])
- scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1)
+ scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(
+ 1, -1, 1
+ )
grid_init *= scaling
self.sampling_offsets.bias.data[...] = grid_init.flatten()
@@ -103,13 +118,14 @@ def _reset_parameters(self):
init.xavier_uniform_(self.output_proj.weight)
init.constant_(self.output_proj.bias, 0)
-
- def forward(self,
- query: torch.Tensor,
- reference_points: torch.Tensor,
- value: torch.Tensor,
- value_spatial_shapes: List[int],
- value_mask: torch.Tensor=None):
+ def forward(
+ self,
+ query: torch.Tensor,
+ reference_points: torch.Tensor,
+ value: torch.Tensor,
+ value_spatial_shapes: List[int],
+ value_mask: torch.Tensor = None,
+ ):
"""
Args:
query (Tensor): [bs, query_length, C]
@@ -132,27 +148,45 @@ def forward(self,
value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)
sampling_offsets: torch.Tensor = self.sampling_offsets(query)
- sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2)
+ sampling_offsets = sampling_offsets.reshape(
+ bs, Len_q, self.num_heads, sum(self.num_points_list), 2
+ )
- attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list))
- attention_weights = F.softmax(attention_weights, dim=-1).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list))
+ attention_weights = self.attention_weights(query).reshape(
+ bs, Len_q, self.num_heads, sum(self.num_points_list)
+ )
+ attention_weights = F.softmax(attention_weights, dim=-1).reshape(
+ bs, Len_q, self.num_heads, sum(self.num_points_list)
+ )
if reference_points.shape[-1] == 2:
offset_normalizer = torch.tensor(value_spatial_shapes)
offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2)
- sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer
+ sampling_locations = (
+ reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2)
+ + sampling_offsets / offset_normalizer
+ )
elif reference_points.shape[-1] == 4:
# reference_points [8, 480, None, 1, 4]
# sampling_offsets [8, 480, 8, 12, 2]
num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1)
- offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
+ offset = (
+ sampling_offsets
+ * num_points_scale
+ * reference_points[:, :, None, :, 2:]
+ * self.offset_scale
+ )
sampling_locations = reference_points[:, :, None, :, :2] + offset
else:
raise ValueError(
- "Last dim of reference_points must be 2 or 4, but get {} instead.".
- format(reference_points.shape[-1]))
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
+ reference_points.shape[-1]
+ )
+ )
- output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list)
+ output = self.ms_deformable_attn_core(
+ value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list
+ )
output = self.output_proj(output)
@@ -160,15 +194,17 @@ def forward(self,
class TransformerDecoderLayer(nn.Module):
- def __init__(self,
- d_model=256,
- n_head=8,
- dim_feedforward=1024,
- dropout=0.,
- activation='relu',
- n_levels=4,
- n_points=4,
- cross_attn_method='default'):
+ def __init__(
+ self,
+ d_model=256,
+ n_head=8,
+ dim_feedforward=1024,
+ dropout=0.0,
+ activation="relu",
+ n_levels=4,
+ n_points=4,
+ cross_attn_method="default",
+ ):
super(TransformerDecoderLayer, self).__init__()
# self attention
@@ -177,7 +213,9 @@ def __init__(self,
self.norm1 = nn.LayerNorm(d_model)
# cross attention
- self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points, method=cross_attn_method)
+ self.cross_attn = MSDeformableAttention(
+ d_model, n_head, n_levels, n_points, method=cross_attn_method
+ )
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@@ -188,7 +226,7 @@ def __init__(self,
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
-
+
self._reset_parameters()
def _reset_parameters(self):
@@ -201,14 +239,16 @@ def with_pos_embed(self, tensor, pos):
def forward_ffn(self, tgt):
return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
- def forward(self,
- target,
- reference_points,
- memory,
- memory_spatial_shapes,
- attn_mask=None,
- memory_mask=None,
- query_pos_embed=None):
+ def forward(
+ self,
+ target,
+ reference_points,
+ memory,
+ memory_spatial_shapes,
+ attn_mask=None,
+ memory_mask=None,
+ query_pos_embed=None,
+ ):
# self attention
q = k = self.with_pos_embed(target, query_pos_embed)
@@ -217,12 +257,13 @@ def forward(self,
target = self.norm1(target)
# cross attention
- target2 = self.cross_attn(\
- self.with_pos_embed(target, query_pos_embed),
- reference_points,
- memory,
- memory_spatial_shapes,
- memory_mask)
+ target2 = self.cross_attn(
+ self.with_pos_embed(target, query_pos_embed),
+ reference_points,
+ memory,
+ memory_spatial_shapes,
+ memory_mask,
+ )
target = target + self.dropout2(target2)
target = self.norm2(target)
@@ -242,16 +283,18 @@ def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
self.num_layers = num_layers
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
- def forward(self,
- target,
- ref_points_unact,
- memory,
- memory_spatial_shapes,
- bbox_head,
- score_head,
- query_pos_head,
- attn_mask=None,
- memory_mask=None):
+ def forward(
+ self,
+ target,
+ ref_points_unact,
+ memory,
+ memory_spatial_shapes,
+ bbox_head,
+ score_head,
+ query_pos_head,
+ attn_mask=None,
+ memory_mask=None,
+ ):
dec_out_bboxes = []
dec_out_logits = []
ref_points_detach = F.sigmoid(ref_points_unact)
@@ -261,7 +304,15 @@ def forward(self,
ref_points_input = ref_points_detach.unsqueeze(2)
query_pos_embed = query_pos_head(ref_points_detach)
- output = layer(output, ref_points_input, memory, memory_spatial_shapes, attn_mask, memory_mask, query_pos_embed)
+ output = layer(
+ output,
+ ref_points_input,
+ memory,
+ memory_spatial_shapes,
+ attn_mask,
+ memory_mask,
+ query_pos_embed,
+ )
inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
@@ -270,7 +321,9 @@ def forward(self,
if i == 0:
dec_out_bboxes.append(inter_ref_bbox)
else:
- dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
+ dec_out_bboxes.append(
+ F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))
+ )
elif i == self.eval_idx:
dec_out_logits.append(score_head[i](output))
@@ -285,35 +338,37 @@ def forward(self,
@register()
class RTDETRTransformerv2(nn.Module):
- __share__ = ['num_classes', 'eval_spatial_size']
-
- def __init__(self,
- num_classes=80,
- hidden_dim=256,
- num_queries=300,
- feat_channels=[512, 1024, 2048],
- feat_strides=[8, 16, 32],
- num_levels=3,
- num_points=4,
- nhead=8,
- num_layers=6,
- dim_feedforward=1024,
- dropout=0.,
- activation="relu",
- num_denoising=100,
- label_noise_ratio=0.5,
- box_noise_scale=1.0,
- learn_query_content=False,
- eval_spatial_size=None,
- eval_idx=-1,
- eps=1e-2,
- aux_loss=True,
- cross_attn_method='default',
- query_select_method='default'):
+ __share__ = ["num_classes", "eval_spatial_size"]
+
+ def __init__(
+ self,
+ num_classes=80,
+ hidden_dim=256,
+ num_queries=300,
+ feat_channels=[512, 1024, 2048],
+ feat_strides=[8, 16, 32],
+ num_levels=3,
+ num_points=4,
+ nhead=8,
+ num_layers=6,
+ dim_feedforward=1024,
+ dropout=0.0,
+ activation="relu",
+ num_denoising=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ learn_query_content=False,
+ eval_spatial_size=None,
+ eval_idx=-1,
+ eps=1e-2,
+ aux_loss=True,
+ cross_attn_method="default",
+ query_select_method="default",
+ ):
super().__init__()
assert len(feat_channels) <= num_levels
assert len(feat_strides) == len(feat_channels)
-
+
for _ in range(num_levels - len(feat_strides)):
feat_strides.append(feat_strides[-1] * 2)
@@ -328,8 +383,8 @@ def __init__(self,
self.eval_spatial_size = eval_spatial_size
self.aux_loss = aux_loss
- assert query_select_method in ('default', 'one2many', 'agnostic'), ''
- assert cross_attn_method in ('default', 'discrete'), ''
+ assert query_select_method in ("default", "one2many", "agnostic"), ""
+ assert cross_attn_method in ("default", "discrete"), ""
self.cross_attn_method = cross_attn_method
self.query_select_method = query_select_method
@@ -337,16 +392,26 @@ def __init__(self,
self._build_input_proj_layer(feat_channels)
# Transformer module
- decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, \
- activation, num_levels, num_points, cross_attn_method=cross_attn_method)
+ decoder_layer = TransformerDecoderLayer(
+ hidden_dim,
+ nhead,
+ dim_feedforward,
+ dropout,
+ activation,
+ num_levels,
+ num_points,
+ cross_attn_method=cross_attn_method,
+ )
self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_layers, eval_idx)
# denoising
self.num_denoising = num_denoising
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
- if num_denoising > 0:
- self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
+ if num_denoising > 0:
+ self.denoising_class_embed = nn.Embedding(
+ num_classes + 1, hidden_dim, padding_idx=num_classes
+ )
init.normal_(self.denoising_class_embed.weight[:-1])
# decoder embedding
@@ -359,12 +424,21 @@ def __init__(self,
# layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu')
# self.encoder = TransformerEncoder(layer, 1)
- self.enc_output = nn.Sequential(OrderedDict([
- ('proj', nn.Linear(hidden_dim, hidden_dim)),
- ('norm', nn.LayerNorm(hidden_dim,)),
- ]))
+ self.enc_output = nn.Sequential(
+ OrderedDict(
+ [
+ ("proj", nn.Linear(hidden_dim, hidden_dim)),
+ (
+ "norm",
+ nn.LayerNorm(
+ hidden_dim,
+ ),
+ ),
+ ]
+ )
+ )
- if query_select_method == 'agnostic':
+ if query_select_method == "agnostic":
self.enc_score_head = nn.Linear(hidden_dim, 1)
else:
self.enc_score_head = nn.Linear(hidden_dim, num_classes)
@@ -372,21 +446,21 @@ def __init__(self,
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3)
# decoder head
- self.dec_score_head = nn.ModuleList([
- nn.Linear(hidden_dim, num_classes) for _ in range(num_layers)
- ])
- self.dec_bbox_head = nn.ModuleList([
- MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers)
- ])
+ self.dec_score_head = nn.ModuleList(
+ [nn.Linear(hidden_dim, num_classes) for _ in range(num_layers)]
+ )
+ self.dec_bbox_head = nn.ModuleList(
+ [MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers)]
+ )
# init encoder output anchors and valid_mask
if self.eval_spatial_size:
anchors, valid_mask = self._generate_anchors()
- self.register_buffer('anchors', anchors)
- self.register_buffer('valid_mask', valid_mask)
+ self.register_buffer("anchors", anchors)
+ self.register_buffer("valid_mask", valid_mask)
self._reset_parameters()
-
+
def _reset_parameters(self):
bias = bias_init_with_prob(0.01)
init.constant_(self.enc_score_head.bias, bias)
@@ -397,7 +471,7 @@ def _reset_parameters(self):
init.constant_(_cls.bias, bias)
init.constant_(_reg.layers[-1].weight, 0)
init.constant_(_reg.layers[-1].bias, 0)
-
+
init.xavier_uniform_(self.enc_output[0].weight)
if self.learn_query_content:
init.xavier_uniform_(self.tgt_embed.weight)
@@ -410,9 +484,18 @@ def _build_input_proj_layer(self, feat_channels):
self.input_proj = nn.ModuleList()
for in_channels in feat_channels:
self.input_proj.append(
- nn.Sequential(OrderedDict([
- ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)),
- ('norm', nn.BatchNorm2d(self.hidden_dim,))])
+ nn.Sequential(
+ OrderedDict(
+ [
+ ("conv", nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)),
+ (
+ "norm",
+ nn.BatchNorm2d(
+ self.hidden_dim,
+ ),
+ ),
+ ]
+ )
)
)
@@ -420,9 +503,18 @@ def _build_input_proj_layer(self, feat_channels):
for _ in range(self.num_levels - len(feat_channels)):
self.input_proj.append(
- nn.Sequential(OrderedDict([
- ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)),
- ('norm', nn.BatchNorm2d(self.hidden_dim))])
+ nn.Sequential(
+ OrderedDict(
+ [
+ (
+ "conv",
+ nn.Conv2d(
+ in_channels, self.hidden_dim, 3, 2, padding=1, bias=False
+ ),
+ ),
+ ("norm", nn.BatchNorm2d(self.hidden_dim)),
+ ]
+ )
)
)
in_channels = self.hidden_dim
@@ -451,11 +543,9 @@ def _get_encoder_input(self, feats: List[torch.Tensor]):
feat_flatten = torch.concat(feat_flatten, 1)
return feat_flatten, spatial_shapes
- def _generate_anchors(self,
- spatial_shapes=None,
- grid_size=0.05,
- dtype=torch.float32,
- device='cpu'):
+ def _generate_anchors(
+ self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"
+ ):
if spatial_shapes is None:
spatial_shapes = []
eval_h, eval_w = self.eval_spatial_size
@@ -464,10 +554,10 @@ def _generate_anchors(self,
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
- grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
+ grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
grid_xy = torch.stack([grid_x, grid_y], dim=-1)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype)
- wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4)
anchors.append(lvl_anchors)
@@ -478,13 +568,9 @@ def _generate_anchors(self,
return anchors, valid_mask
-
- def _get_decoder_input(self,
- memory: torch.Tensor,
- spatial_shapes,
- denoising_logits=None,
- denoising_bbox_unact=None):
-
+ def _get_decoder_input(
+ self, memory: torch.Tensor, spatial_shapes, denoising_logits=None, denoising_bbox_unact=None
+ ):
# prepare input for decoder
if self.training or self.eval_spatial_size is None:
anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device)
@@ -493,82 +579,101 @@ def _get_decoder_input(self,
valid_mask = self.valid_mask
# memory = torch.where(valid_mask, memory, 0)
- # TODO fix type error for onnx export
- memory = valid_mask.to(memory.dtype) * memory
+ # TODO fix type error for onnx export
+ memory = valid_mask.to(memory.dtype) * memory
- output_memory :torch.Tensor = self.enc_output(memory)
- enc_outputs_logits :torch.Tensor = self.enc_score_head(output_memory)
- enc_outputs_coord_unact :torch.Tensor = self.enc_bbox_head(output_memory) + anchors
+ output_memory: torch.Tensor = self.enc_output(memory)
+ enc_outputs_logits: torch.Tensor = self.enc_score_head(output_memory)
+ enc_outputs_coord_unact: torch.Tensor = self.enc_bbox_head(output_memory) + anchors
enc_topk_bboxes_list, enc_topk_logits_list = [], []
- enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = \
- self._select_topk(output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries)
-
+ enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = self._select_topk(
+ output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries
+ )
+
if self.training:
enc_topk_bboxes = F.sigmoid(enc_topk_bbox_unact)
enc_topk_bboxes_list.append(enc_topk_bboxes)
enc_topk_logits_list.append(enc_topk_logits)
- # if self.num_select_queries != self.num_queries:
+ # if self.num_select_queries != self.num_queries:
# raise NotImplementedError('')
if self.learn_query_content:
content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1])
else:
content = enc_topk_memory.detach()
-
+
enc_topk_bbox_unact = enc_topk_bbox_unact.detach()
-
+
if denoising_bbox_unact is not None:
enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
content = torch.concat([denoising_logits, content], dim=1)
-
+
return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
- def _select_topk(self, memory: torch.Tensor, outputs_logits: torch.Tensor, outputs_coords_unact: torch.Tensor, topk: int):
- if self.query_select_method == 'default':
+ def _select_topk(
+ self,
+ memory: torch.Tensor,
+ outputs_logits: torch.Tensor,
+ outputs_coords_unact: torch.Tensor,
+ topk: int,
+ ):
+ if self.query_select_method == "default":
_, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1)
- elif self.query_select_method == 'one2many':
+ elif self.query_select_method == "one2many":
_, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1)
topk_ind = topk_ind // self.num_classes
- elif self.query_select_method == 'agnostic':
+ elif self.query_select_method == "agnostic":
_, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1)
-
+
topk_ind: torch.Tensor
- topk_coords = outputs_coords_unact.gather(dim=1, \
- index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1]))
-
- topk_logits = outputs_logits.gather(dim=1, \
- index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1]))
-
- topk_memory = memory.gather(dim=1, \
- index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1]))
+ topk_coords = outputs_coords_unact.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1])
+ )
- return topk_memory, topk_logits, topk_coords
+ topk_logits = outputs_logits.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1])
+ )
+
+ topk_memory = memory.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])
+ )
+ return topk_memory, topk_logits, topk_coords
def forward(self, feats, targets=None):
# input projection and embedding
memory, spatial_shapes = self._get_encoder_input(feats)
-
+
# prepare denoising training
if self.training and self.num_denoising > 0:
- denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = \
- get_contrastive_denoising_training_group(targets, \
- self.num_classes,
- self.num_queries,
- self.denoising_class_embed,
- num_denoising=self.num_denoising,
- label_noise_ratio=self.label_noise_ratio,
- box_noise_scale=self.box_noise_scale, )
+ (
+ denoising_logits,
+ denoising_bbox_unact,
+ attn_mask,
+ dn_meta,
+ ) = get_contrastive_denoising_training_group(
+ targets,
+ self.num_classes,
+ self.num_queries,
+ self.denoising_class_embed,
+ num_denoising=self.num_denoising,
+ label_noise_ratio=self.label_noise_ratio,
+ box_noise_scale=self.box_noise_scale,
+ )
else:
denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
- init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = \
- self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact)
+ (
+ init_ref_contents,
+ init_ref_points_unact,
+ enc_topk_bboxes_list,
+ enc_topk_logits_list,
+ ) = self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact)
# decoder
out_bboxes, out_logits = self.decoder(
@@ -579,30 +684,29 @@ def forward(self, feats, targets=None):
self.dec_bbox_head,
self.dec_score_head,
self.query_pos_head,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ )
if self.training and dn_meta is not None:
- dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
- dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
+ dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta["dn_num_split"], dim=2)
+ dn_out_logits, out_logits = torch.split(out_logits, dn_meta["dn_num_split"], dim=2)
- out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
+ out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]}
if self.training and self.aux_loss:
- out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
- out['enc_aux_outputs'] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list)
- out['enc_meta'] = {'class_agnostic': self.query_select_method == 'agnostic'}
+ out["aux_outputs"] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
+ out["enc_aux_outputs"] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list)
+ out["enc_meta"] = {"class_agnostic": self.query_select_method == "agnostic"}
if dn_meta is not None:
- out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
- out['dn_meta'] = dn_meta
+ out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
+ out["dn_meta"] = dn_meta
return out
-
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
- return [{'pred_logits': a, 'pred_boxes': b}
- for a, b in zip(outputs_class, outputs_coord)]
+ return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py
index 7653bf9e6..5d25119b7 100644
--- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py
+++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py
@@ -4,13 +4,13 @@
import math
from typing import List
-import torch
+import torch
import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as F
-def inverse_sigmoid(x: torch.Tensor, eps: float=1e-5) -> torch.Tensor:
- x = x.clip(min=0., max=1.)
+def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
+ x = x.clip(min=0.0, max=1.0)
return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps))
@@ -20,13 +20,14 @@ def bias_init_with_prob(prior_prob=0.01):
return bias_init
-def deformable_attention_core_func_v2(\
- value: torch.Tensor,
+def deformable_attention_core_func_v2(
+ value: torch.Tensor,
value_spatial_shapes,
- sampling_locations: torch.Tensor,
- attention_weights: torch.Tensor,
- num_points_list: List[int],
- method='default'):
+ sampling_locations: torch.Tensor,
+ attention_weights: torch.Tensor,
+ num_points_list: List[int],
+ method="default",
+):
"""
Args:
value (Tensor): [bs, value_length, n_head, c]
@@ -40,15 +41,15 @@ def deformable_attention_core_func_v2(\
"""
bs, _, n_head, c = value.shape
_, Len_q, _, _, _ = sampling_locations.shape
-
+
split_shape = [h * w for h, w in value_spatial_shapes]
value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
# sampling_offsets [8, 480, 8, 12, 2]
- if method == 'default':
+ if method == "default":
sampling_grids = 2 * sampling_locations - 1
- elif method == 'discrete':
+ elif method == "discrete":
sampling_grids = sampling_locations
sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
@@ -59,69 +60,77 @@ def deformable_attention_core_func_v2(\
value_l = value_list[level].reshape(bs * n_head, c, h, w)
sampling_grid_l: torch.Tensor = sampling_locations_list[level]
- if method == 'default':
+ if method == "default":
sampling_value_l = F.grid_sample(
- value_l,
- sampling_grid_l,
- mode='bilinear',
- padding_mode='zeros',
- align_corners=False)
-
- elif method == 'discrete':
+ value_l, sampling_grid_l, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+
+ elif method == "discrete":
# n * m, seq, n, 2
- sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5).to(torch.int64)
+ sampling_coord = (
+ sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5
+ ).to(torch.int64)
# FIX ME? for rectangle input
- sampling_coord = sampling_coord.clamp(0, h - 1)
- sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2)
+ sampling_coord = sampling_coord.clamp(0, h - 1)
+ sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2)
+
+ s_idx = (
+ torch.arange(sampling_coord.shape[0], device=value.device)
+ .unsqueeze(-1)
+ .repeat(1, sampling_coord.shape[1])
+ )
+ sampling_value_l: torch.Tensor = value_l[
+ s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]
+ ] # n l c
+
+ sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(
+ bs * n_head, c, Len_q, num_points_list[level]
+ )
- s_idx = torch.arange(sampling_coord.shape[0], device=value.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1])
- sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c
-
- sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level])
-
sampling_value_list.append(sampling_value_l)
- attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, Len_q, sum(num_points_list))
+ attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(
+ bs * n_head, 1, Len_q, sum(num_points_list)
+ )
weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights
output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q)
return output.permute(0, 2, 1)
-def get_activation(act: str, inpace: bool=True):
- """get activation
- """
+def get_activation(act: str, inpace: bool = True):
+ """get activation"""
if act is None:
return nn.Identity()
elif isinstance(act, nn.Module):
- return act
+ return act
act = act.lower()
-
- if act == 'silu' or act == 'swish':
+
+ if act == "silu" or act == "swish":
m = nn.SiLU()
- elif act == 'relu':
+ elif act == "relu":
m = nn.ReLU()
- elif act == 'leaky_relu':
+ elif act == "leaky_relu":
m = nn.LeakyReLU()
- elif act == 'silu':
+ elif act == "silu":
m = nn.SiLU()
-
- elif act == 'gelu':
+
+ elif act == "gelu":
m = nn.GELU()
- elif act == 'hardsigmoid':
+ elif act == "hardsigmoid":
m = nn.Hardsigmoid()
else:
- raise RuntimeError('')
+ raise RuntimeError("")
- if hasattr(m, 'inplace'):
+ if hasattr(m, "inplace"):
m.inplace = inpace
-
- return m
+
+ return m
diff --git a/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py b/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py
index 36c84592f..0d45404ea 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py
@@ -9,23 +9,20 @@
from .yolov8_base import YOLOV8Base
__all__ = [
- 'DeepfauneDetector',
+ "DeepfauneDetector",
]
+
class DeepfauneDetector(YOLOV8Base):
"""
- MegaDetectorV6 is a specialized class derived from the YOLOV8Base class
+ DeepfauneDetector is a specialized class derived from the YOLOV8Base class
that is specifically designed for detecting animals, persons, and vehicles.
-
+
Attributes:
CLASS_NAMES (dict): Mapping of class IDs to their respective names.
"""
-
- CLASS_NAMES = {
- 0: "animal",
- 1: "person",
- 2: "vehicle"
- }
+
+ CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
def __init__(self, weights=None, device="cpu"):
"""
@@ -37,7 +34,7 @@ def __init__(self, weights=None, device="cpu"):
"""
self.IMAGE_SIZE = 960
- url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-yolov8s_960.pt"
+ url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-yolov8s_960.pt"
self.MODEL_NAME = "deepfaune-yolov8s_960.pt"
- super(DeepfauneDetector, self).__init__(weights=weights, device=device, url=url)
\ No newline at end of file
+ super(DeepfauneDetector, self).__init__(weights=weights, device=device, url=url)
diff --git a/PytorchWildlife/models/detection/ultralytics_based/__init__.py b/PytorchWildlife/models/detection/ultralytics_based/__init__.py
index cb9fd8b1e..5906a19e7 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/__init__.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/__init__.py
@@ -1,6 +1,6 @@
-from .yolov5_base import *
-from .yolov8_base import *
+from .Deepfaune import *
from .megadetectorv5 import *
from .megadetectorv6 import *
from .megadetectorv6_distributed import *
-from .Deepfaune import *
+from .yolov5_base import *
+from .yolov8_base import *
diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py
index cbbbb6671..40effd318 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py
@@ -3,40 +3,35 @@
from .yolov5_base import YOLOV5Base
-__all__ = [
- 'MegaDetectorV5'
-]
+__all__ = ["MegaDetectorV5"]
+
class MegaDetectorV5(YOLOV5Base):
"""
- MegaDetectorV5 is a specialized class derived from the YOLOV5Base class
+ MegaDetectorV5 is a specialized class derived from the YOLOV5Base class
that is specifically designed for detecting animals, persons, and vehicles.
-
+
Attributes:
IMAGE_SIZE (int): The standard image size used during training.
STRIDE (int): Stride value used in the detector.
CLASS_NAMES (dict): Mapping of class IDs to their respective names.
"""
-
+
IMAGE_SIZE = 1280 # image size used in training
STRIDE = 64
- CLASS_NAMES = {
- 0: "animal",
- 1: "person",
- 2: "vehicle"
- }
+ CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
def __init__(self, weights=None, device="cpu", pretrained=True, version="a"):
"""
Initializes the MegaDetectorV5 model with the option to load pretrained weights.
-
+
Args:
weights (str, optional): Path to the weights file.
device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu".
pretrained (bool, optional): Whether to load the pretrained model. Default is True.
version (str, optional): Version of the MegaDetectorV5 model to load. Default is "a".
"""
-
+
if pretrained:
if version == "a":
url = "https://zenodo.org/records/13357337/files/md_v5a.0.0.pt?download=1"
@@ -45,14 +40,12 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version="a"):
else:
url = None
- import site
+ import site
import sys
- sys.path.insert(0, site.getsitepackages()[0]+'/yolov5')
+
+ sys.path.insert(0, site.getsitepackages()[0] + "/yolov5")
super(MegaDetectorV5, self).__init__(weights=weights, device=device, url=url)
-
-
-
# %%
diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py
index 0d0ee6584..389ab4010 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py
@@ -1,29 +1,23 @@
-
from .yolov8_base import YOLOV8Base
-__all__ = [
- 'MegaDetectorV6'
-]
+__all__ = ["MegaDetectorV6"]
+
class MegaDetectorV6(YOLOV8Base):
"""
- MegaDetectorV6 is a specialized class derived from the YOLOV8Base class
+ MegaDetectorV6 is a specialized class derived from the YOLOV8Base class
that is specifically designed for detecting animals, persons, and vehicles.
-
+
Attributes:
CLASS_NAMES (dict): Mapping of class IDs to their respective names.
"""
-
- CLASS_NAMES = {
- 0: "animal",
- 1: "person",
- 2: "vehicle"
- }
-
- def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c'):
+
+ CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
+
+ def __init__(self, weights=None, device="cpu", pretrained=True, version="yolov9c"):
"""
Initializes the MegaDetectorV5 model with the option to load pretrained weights.
-
+
Args:
weights (str, optional): Path to the weights file.
device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu".
@@ -32,22 +26,24 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c
"""
self.IMAGE_SIZE = 1280
- if version == 'MDV6-yolov9-c':
- url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1"
+ if version == "MDV6-yolov9-c":
+ url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1"
self.MODEL_NAME = "MDV6b-yolov9-c.pt"
- elif version == 'MDV6-yolov9-e':
+ elif version == "MDV6-yolov9-e":
url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-e-1280.pt?download=1"
self.MODEL_NAME = "MDV6-yolov9-e-1280.pt"
- elif version == 'MDV6-yolov10-c':
+ elif version == "MDV6-yolov10-c":
url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-c.pt?download=1"
self.MODEL_NAME = "MDV6-yolov10-c.pt"
- elif version == 'MDV6-yolov10-e':
+ elif version == "MDV6-yolov10-e":
url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-e-1280.pt?download=1"
self.MODEL_NAME = "MDV6-yolov10-e-1280.pt"
- elif version == 'MDV6-rtdetr-c':
+ elif version == "MDV6-rtdetr-c":
url = "https://zenodo.org/records/15398270/files/MDV6-rtdetr-c.pt?download=1"
self.MODEL_NAME = "MDV6b-rtdetr-c.pt"
else:
- raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c')
+ raise ValueError(
+ "Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c"
+ )
- super(MegaDetectorV6, self).__init__(weights=weights, device=device, url=url)
\ No newline at end of file
+ super(MegaDetectorV6, self).__init__(weights=weights, device=device, url=url)
diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py
index ba1cf0256..a8303acd6 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py
@@ -1,29 +1,23 @@
-
from .yolov8_distributed import YOLOV8_Distributed
-__all__ = [
- 'MegaDetectorV6_Distributed'
-]
+__all__ = ["MegaDetectorV6_Distributed"]
+
class MegaDetectorV6_Distributed(YOLOV8_Distributed):
"""
- MegaDetectorV6 is a specialized class derived from the YOLOV8Base class
+ MegaDetectorV6_Distributed is a specialized class derived from the YOLOV8Base class
that is specifically designed for detecting animals, persons, and vehicles.
-
+
Attributes:
CLASS_NAMES (dict): Mapping of class IDs to their respective names.
"""
-
- CLASS_NAMES = {
- 0: "animal",
- 1: "person",
- 2: "vehicle"
- }
-
- def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c'):
+
+ CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
+
+ def __init__(self, weights=None, device="cpu", pretrained=True, version="yolov9c"):
"""
Initializes the MegaDetectorV5 model with the option to load pretrained weights.
-
+
Args:
weights (str, optional): Path to the weights file.
device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu".
@@ -32,22 +26,24 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c
"""
self.IMAGE_SIZE = 1280
- if version == 'MDV6-yolov9-c':
- url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1"
+ if version == "MDV6-yolov9-c":
+ url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1"
self.MODEL_NAME = "MDV6b-yolov9-c.pt"
- elif version == 'MDV6-yolov9-e':
+ elif version == "MDV6-yolov9-e":
url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-e-1280.pt?download=1"
self.MODEL_NAME = "MDV6-yolov9-e-1280.pt"
- elif version == 'MDV6-yolov10-c':
+ elif version == "MDV6-yolov10-c":
url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-c.pt?download=1"
self.MODEL_NAME = "MDV6-yolov10-c.pt"
- elif version == 'MDV6-yolov10-e':
+ elif version == "MDV6-yolov10-e":
url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-e-1280.pt?download=1"
self.MODEL_NAME = "MDV6-yolov10-e-1280.pt"
- elif version == 'MDV6-rtdetr-c':
+ elif version == "MDV6-rtdetr-c":
url = "https://zenodo.org/records/15398270/files/MDV6-rtdetr-c.pt?download=1"
self.MODEL_NAME = "MDV6b-rtdetr-c.pt"
else:
- raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c')
+ raise ValueError(
+ "Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c"
+ )
- super(MegaDetectorV6_Distributed, self).__init__(weights=weights, device=device, url=url)
\ No newline at end of file
+ super(MegaDetectorV6_Distributed, self).__init__(weights=weights, device=device, url=url)
diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py b/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py
index 917b6d93c..ce00bfe09 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py
@@ -6,18 +6,17 @@
# Importing basic libraries
import numpy as np
-from tqdm import tqdm
-from PIL import Image
import supervision as sv
-
import torch
-from torch.utils.data import DataLoader
+from PIL import Image
from torch.hub import load_state_dict_from_url
-
+from torch.utils.data import DataLoader
+from tqdm import tqdm
from yolov5.utils.general import non_max_suppression, scale_boxes
-from ..base_detector import BaseDetector
-from ....data import transforms as pw_trans
+
from ....data import datasets as pw_data
+from ....data import transforms as pw_trans
+from ..base_detector import BaseDetector
class YOLOV5Base(BaseDetector):
@@ -25,16 +24,17 @@ class YOLOV5Base(BaseDetector):
Base detector class for YOLO V5. This class provides utility methods for
loading the model, generating results, and performing single and batch image detections.
"""
+
def __init__(self, weights=None, device="cpu", url=None, transform=None):
"""
Initialize the YOLO V5 detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
transform (callable, optional):
Optional transform to be applied on the image. Defaults to None.
@@ -46,13 +46,13 @@ def __init__(self, weights=None, device="cpu", url=None, transform=None):
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load the YOLO V5 model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
@@ -64,21 +64,22 @@ def _load_model(self, weights=None, device="cpu", url=None):
else:
raise Exception("Need weights for inference.")
self.model = checkpoint["model"].float().fuse().eval().to(self.device)
-
+
if not self.transform:
- self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE,
- stride=self.STRIDE)
+ self.transform = pw_trans.MegaDetector_v5_Transform(
+ target_size=self.IMAGE_SIZE, stride=self.STRIDE
+ )
def results_generation(self, preds, img_id, id_strip=None) -> dict:
"""
Generate results for detection based on model predictions.
-
+
Args:
- preds (numpy.ndarray):
+ preds (numpy.ndarray):
Model predictions.
- img_id (str):
+ img_id (str):
Image identifier.
- id_strip (str, optional):
+ id_strip (str, optional):
Strip specific characters from img_id. Defaults to None.
Returns:
@@ -86,28 +87,28 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict:
"""
results = {"img_id": str(img_id).strip(id_strip)}
results["detections"] = sv.Detections(
- xyxy=preds[:, :4],
- confidence=preds[:, 4],
- class_id=preds[:, 5].astype(int)
+ xyxy=preds[:, :4], confidence=preds[:, 4], class_id=preds[:, 5].astype(int)
)
results["labels"] = [
f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
- for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id)
+ for confidence, class_id in zip(
+ results["detections"].confidence, results["detections"].class_id
+ )
]
return results
def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None) -> dict:
"""
Perform detection on a single image.
-
+
Args:
- img (str or ndarray):
+ img (str or ndarray):
Image path or ndarray of images.
- img_path (str, optional):
+ img_path (str, optional):
Image path or identifier.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
Returns:
@@ -121,19 +122,28 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri
img = self.transform(img)
if img_size is None:
- img_size = img.permute((1, 2, 0)).shape # We need hwc instead of chw for coord scaling
+ img_size = img.permute((1, 2, 0)).shape # We need hwc instead of chw for coord scaling
preds = self.model(img.unsqueeze(0).to(self.device))[0]
- preds = torch.cat(non_max_suppression(prediction=preds, conf_thres=det_conf_thres), axis=0).cpu().numpy()
+ preds = (
+ torch.cat(non_max_suppression(prediction=preds, conf_thres=det_conf_thres), axis=0)
+ .cpu()
+ .numpy()
+ )
# preds[:, :4] = scale_coords([self.IMAGE_SIZE] * 2, preds[:, :4], img_size).round()
preds[:, :4] = scale_boxes([self.IMAGE_SIZE] * 2, preds[:, :4], img_size).round()
res = self.results_generation(preds, img_path, id_strip)
- normalized_coords = [[x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] for x1, y1, x2, y2 in preds[:, :4]]
+ normalized_coords = [
+ [x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]]
+ for x1, y1, x2, y2 in preds[:, :4]
+ ]
res["normalized_coords"] = normalized_coords
return res
- def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None) -> list[dict]:
+ def batch_image_detection(
+ self, data_path, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None
+ ) -> list[dict]:
"""
Perform detection on a batch of images.
@@ -153,8 +163,14 @@ def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres:
)
# Creating a DataLoader for batching and parallel processing of the images
- loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
- pin_memory=True, num_workers=0, drop_last=False)
+ loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=0,
+ drop_last=False,
+ )
results = []
with tqdm(total=len(loader)) as pbar:
@@ -165,7 +181,7 @@ def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres:
batch_results = []
for i, pred in enumerate(predictions):
- if pred.size(0) == 0:
+ if pred.size(0) == 0:
continue
pred = pred.numpy()
size = sizes[i].numpy()
@@ -174,7 +190,10 @@ def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres:
# pred[:, :4] = scale_coords([self.IMAGE_SIZE] * 2, pred[:, :4], size).round()
pred[:, :4] = scale_boxes([self.IMAGE_SIZE] * 2, pred[:, :4], size).round()
# Normalize the coordinates for timelapse compatibility
- normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in pred[:, :4]]
+ normalized_coords = [
+ [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]]
+ for x1, y1, x2, y2 in pred[:, :4]
+ ]
res = self.results_generation(pred, path, id_strip)
res["normalized_coords"] = normalized_coords
batch_results.append(res)
diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py b/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py
index 50d6cfbbb..e061d5ea1 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py
@@ -6,39 +6,39 @@
# Importing basic libraries
import os
-import wget
+
import numpy as np
-from tqdm import tqdm
-from PIL import Image
import supervision as sv
-
import torch
+import wget
+from PIL import Image
from torch.utils.data import DataLoader
+from tqdm import tqdm
+from ultralytics.models import rtdetr, yolo
-from ultralytics.models import yolo, rtdetr
-
-from ..base_detector import BaseDetector
-from ....data import transforms as pw_trans
from ....data import datasets as pw_data
+from ....data import transforms as pw_trans
+from ..base_detector import BaseDetector
class YOLOV8Base(BaseDetector):
"""
Base detector class for the new ultralytics YOLOV8 framework. This class provides utility methods for
loading the model, generating results, and performing single and batch image detections.
- This base detector class is also compatible with all the new ultralytics models including YOLOV9,
+ This base detector class is also compatible with all the new ultralytics models including YOLOV9,
RTDetr, and more.
"""
+
def __init__(self, weights=None, device="cpu", url=None, transform=None):
"""
Initialize the YOLOV8 detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
"""
super(YOLOV8Base, self).__init__(weights=weights, device=device, url=url)
@@ -48,30 +48,34 @@ def __init__(self, weights=None, device="cpu", url=None, transform=None):
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load the YOLOV8 model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
"""
- if self.MODEL_NAME == 'MDV6b-rtdetrl.pt':
+ if self.MODEL_NAME == "MDV6b-rtdetrl.pt":
self.predictor = rtdetr.RTDETRPredictor()
else:
self.predictor = yolo.detect.DetectionPredictor()
# self.predictor.args.device = device # Will uncomment later
self.predictor.args.imgsz = self.IMAGE_SIZE
- self.predictor.args.save = False # Will see if we want to use ultralytics native inference saving functions.
+ self.predictor.args.save = (
+ False # Will see if we want to use ultralytics native inference saving functions.
+ )
if weights:
self.predictor.setup_model(weights)
elif url:
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
@@ -79,21 +83,22 @@ def _load_model(self, weights=None, device="cpu", url=None):
self.predictor.setup_model(weights)
else:
raise Exception("Need weights for inference.")
-
+
if not self.transform:
- self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE,
- stride=self.STRIDE)
+ self.transform = pw_trans.MegaDetector_v5_Transform(
+ target_size=self.IMAGE_SIZE, stride=self.STRIDE
+ )
def results_generation(self, preds, img_id, id_strip=None) -> dict:
"""
Generate results for detection based on model predictions.
-
+
Args:
- preds (ultralytics.engine.results.Results):
+ preds (ultralytics.engine.results.Results):
Model predictions.
- img_id (str):
+ img_id (str):
Image identifier.
- id_strip (str, optional):
+ id_strip (str, optional):
Strip specific characters from img_id. Defaults to None.
Returns:
@@ -104,32 +109,27 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict:
class_id = preds.boxes.cls.cpu().numpy().astype(int)
results = {"img_id": str(img_id).strip(id_strip)}
- results["detections"] = sv.Detections(
- xyxy=xyxy,
- confidence=confidence,
- class_id=class_id
- )
+ results["detections"] = sv.Detections(xyxy=xyxy, confidence=confidence, class_id=class_id)
results["labels"] = [
- f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
- for _, _, confidence, class_id, _, _ in results["detections"]
+ f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
+ for _, _, confidence, class_id, _, _ in results["detections"]
]
-
+
return results
-
def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None) -> dict:
"""
Perform detection on a single image.
-
+
Args:
- img (str or ndarray):
+ img (str or ndarray):
Image path or ndarray of images.
- img_path (str, optional):
+ img_path (str, optional):
Image path or identifier.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
Returns:
@@ -144,18 +144,22 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri
self.predictor.args.batch = 1
self.predictor.args.conf = det_conf_thres
-
+
det_results = list(self.predictor.stream_inference([img]))
res = self.results_generation(det_results[0], img_path, id_strip)
- normalized_coords = [[x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]]
- for x1, y1, x2, y2 in res["detections"].xyxy]
+ normalized_coords = [
+ [x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]]
+ for x1, y1, x2, y2 in res["detections"].xyxy
+ ]
res["normalized_coords"] = normalized_coords
-
+
return res
- def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None) -> list[dict]:
+ def batch_image_detection(
+ self, data_source, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None
+ ) -> list[dict]:
"""
Perform detection on a batch of images.
@@ -174,24 +178,28 @@ def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thre
# Handle numpy array input
if isinstance(data_source, (list, np.ndarray)):
results = []
- num_batches = (len(data_source) + batch_size - 1) // batch_size # Calculate total batches
-
+ num_batches = (
+ len(data_source) + batch_size - 1
+ ) // batch_size # Calculate total batches
+
with tqdm(total=num_batches) as pbar:
for start_idx in range(0, len(data_source), batch_size):
- batch_arrays = data_source[start_idx:start_idx + batch_size]
+ batch_arrays = data_source[start_idx : start_idx + batch_size]
det_results = self.predictor.stream_inference(batch_arrays)
-
+
for idx, preds in enumerate(det_results):
res = self.results_generation(preds, f"{start_idx + idx}", id_strip)
# Get size directly from numpy array
img_height, img_width = batch_arrays[idx].shape[:2]
- normalized_coords = [[x1/img_width, y1/img_height, x2/img_width, y2/img_height]
- for x1, y1, x2, y2 in res["detections"].xyxy]
+ normalized_coords = [
+ [x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height]
+ for x1, y1, x2, y2 in res["detections"].xyxy
+ ]
res["normalized_coords"] = normalized_coords
results.append(res)
pbar.update(1)
return results
-
+
# Handle image directory input
dataset = pw_data.DetectionImageFolder(
data_source,
@@ -199,10 +207,15 @@ def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thre
)
# Creating a DataLoader for batching and parallel processing of the images
- loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
- pin_memory=True, num_workers=0, drop_last=False
- )
-
+ loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=0,
+ drop_last=False,
+ )
+
results = []
with tqdm(total=len(loader)) as pbar:
for batch_index, (imgs, paths, sizes) in enumerate(loader):
@@ -212,7 +225,10 @@ def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thre
res = self.results_generation(preds, paths[idx], id_strip)
size = preds.orig_shape
# Normalize the coordinates for timelapse compatibility
- normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy]
+ normalized_coords = [
+ [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]]
+ for x1, y1, x2, y2 in res["detections"].xyxy
+ ]
res["normalized_coords"] = normalized_coords
results.append(res)
pbar.update(1)
diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py b/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py
index 688ea780d..0d071f78f 100644
--- a/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py
+++ b/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py
@@ -7,20 +7,21 @@
import os
import time
from glob import glob
-import supervision as sv
+
import numpy as np
import pandas as pd
-from PIL import Image
-import wget
+import supervision as sv
import torch
-
-from ultralytics.models import yolo, rtdetr
+import wget
+from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
+from ultralytics.models import rtdetr, yolo
-from ..base_detector import BaseDetector
-from ....data import transforms as pw_trans
from ....data import datasets as pw_data
+from ....data import transforms as pw_trans
+from ..base_detector import BaseDetector
+
class YOLOV8_Distributed(BaseDetector):
"""
@@ -32,46 +33,50 @@ class YOLOV8_Distributed(BaseDetector):
def __init__(self, weights=None, device="cpu", url=None, transform=None):
"""
Initialize the YOLOV8 detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
"""
self.transform = transform
super(YOLOV8_Distributed, self).__init__(weights=weights, device=device, url=url)
self._load_model(weights, self.device, url)
-
+
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load the YOLOV8 model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
"""
- if self.MODEL_NAME == 'MDV6b-rtdetrl.pt':
+ if self.MODEL_NAME == "MDV6b-rtdetrl.pt":
self.predictor = rtdetr.RTDETRPredictor()
else:
self.predictor = yolo.detect.DetectionPredictor()
# self.predictor.args.device = device # Will uncomment later
self.predictor.args.imgsz = self.IMAGE_SIZE
- self.predictor.args.save = False # Will see if we want to use ultralytics native inference saving functions.
+ self.predictor.args.save = (
+ False # Will see if we want to use ultralytics native inference saving functions.
+ )
if weights:
self.predictor.setup_model(weights)
elif url:
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
@@ -79,21 +84,22 @@ def _load_model(self, weights=None, device="cpu", url=None):
self.predictor.setup_model(weights)
else:
raise Exception("Need weights for inference.")
-
+
if not self.transform:
- self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE,
- stride=self.STRIDE)
-
+ self.transform = pw_trans.MegaDetector_v5_Transform(
+ target_size=self.IMAGE_SIZE, stride=self.STRIDE
+ )
+
def results_generation(self, preds, img_id, id_strip=None) -> dict:
"""
Generate results for detection based on model predictions.
-
+
Args:
- preds (ultralytics.engine.results.Results):
+ preds (ultralytics.engine.results.Results):
Model predictions.
- img_id (str):
+ img_id (str):
Image identifier.
- id_strip (str, optional):
+ id_strip (str, optional):
Strip specific characters from img_id. Defaults to None.
Returns:
@@ -102,7 +108,7 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict:
xyxy = preds.boxes.xyxy.cpu().numpy()
confidence = preds.boxes.conf.cpu().numpy()
class_id = preds.boxes.cls.cpu().numpy().astype(int)
-
+
results = {"img_id": str(img_id).strip(id_strip)}
# results["detections"] = sv.Detections(
# xyxy=xyxy,
@@ -114,38 +120,45 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict:
results["detections_class_id"] = class_id
# results["labels"] = [
- # f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
- # for _, _, confidence, class_id, _, _ in results["detections"]
+ # f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
+ # for _, _, confidence, class_id, _, _ in results["detections"]
# ]
-
+
results["labels"] = [
- f"{self.CLASS_NAMES[cls_id]} {conf:0.2f}"
- for cls_id, conf in zip(class_id, confidence)
+ f"{self.CLASS_NAMES[cls_id]} {conf:0.2f}" for cls_id, conf in zip(class_id, confidence)
]
-
+
results["n_animal_detected"] = np.sum(class_id == 0)
-
+
return results
-
- def batch_image_detection(self, loader, batch_size, global_rank, local_rank, output_dir, det_conf_thres=0.2, checkpoint_frequency = 1000):
+ def batch_image_detection(
+ self,
+ loader,
+ batch_size,
+ global_rank,
+ local_rank,
+ output_dir,
+ det_conf_thres=0.2,
+ checkpoint_frequency=1000,
+ ):
"""
Perform batch image detection using the YOLOV8 model.
-
+
Args:
- loader (torch.utils.data.DataLoader):
+ loader (torch.utils.data.DataLoader):
DataLoader for input images.
batch_size (int):
Size of the batch for detection.
- global_rank (int):
+ global_rank (int):
Global rank of the process.
- local_rank (int):
+ local_rank (int):
Local rank of the process.
- output_dir (str):
+ output_dir (str):
Directory to save detection results.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for detections. Defaults to 0.2.
- checkpoint_frequency (int, optional):
+ checkpoint_frequency (int, optional):
Frequency of saving intermediate results. Defaults to 1000.
"""
os.makedirs(output_dir, exist_ok=True)
@@ -153,7 +166,6 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out
self.predictor.args.conf = det_conf_thres
self.predictor.args.device = local_rank
-
# Create checkpoint directory
# Track batches and processed items
results = {
@@ -163,7 +175,7 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out
"detections_class_id": [],
"labels": [],
"n_animal_detected": [],
- "normalized_coords": []
+ "normalized_coords": [],
}
checkpoint_dir = os.path.join(output_dir, f"checkpoints_rank{global_rank}")
@@ -178,15 +190,18 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out
# images: tensor of shape [batch_size, 3, H, W]
# Assuming images are transformed & Standardized
det_results = self.predictor.stream_inference(images)
-
+
for idx, preds in enumerate(det_results):
res = self.results_generation(preds, uuids[idx])
-
+
size = preds.orig_shape
- normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections_xyxy"]]
+ normalized_coords = [
+ [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]]
+ for x1, y1, x2, y2 in res["detections_xyxy"]
+ ]
res["normalized_coords"] = normalized_coords
-
- #results.append(res)
+
+ # results.append(res)
results["img_id"].append(res["img_id"])
results["detections_xyxy"].append(res["detections_xyxy"].tolist())
results["detections_confidence"].append(res["detections_confidence"].tolist())
@@ -194,39 +209,29 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out
results["labels"].append(res["labels"])
results["n_animal_detected"].append(int(res["n_animal_detected"]))
results["normalized_coords"].append(res["normalized_coords"])
-
+
if batch_counter % checkpoint_frequency == 0:
elapsed = time.time() - start_time
print(f"[Rank {global_rank}] Processed {processed_count} images in {elapsed}")
-
+
# Save intermediate results
checkpoint_path = os.path.join(
- checkpoint_dir,
- f"checkpoint_{batch_counter:06d}.parquet"
+ checkpoint_dir, f"checkpoint_{batch_counter:06d}.parquet"
+ )
+
+ df = pd.DataFrame(
+ {"img_id": results["img_id"], "n_animal_detected": results["n_animal_detected"]}
)
-
- df = pd.DataFrame({
- "img_id": results["img_id"],
- "n_animal_detected": results["n_animal_detected"]
- })
df.to_parquet(checkpoint_path, index=False)
print(f"[Rank {global_rank}] Saved checkpoint to {checkpoint_path}")
-
+
# Save results to disk
os.makedirs(output_dir, exist_ok=True)
- df = pd.DataFrame({
- "img_id": results["img_id"],
- "n_animal_detected": results["n_animal_detected"]
- })
+ df = pd.DataFrame(
+ {"img_id": results["img_id"], "n_animal_detected": results["n_animal_detected"]}
+ )
out_path = os.path.join(output_dir, f"predictions_rank{global_rank}.parquet")
df.to_parquet(out_path, index=False)
print(f"[rank {global_rank}] Saved predictions to {out_path}")
-
- return results
-
-
-
-
-
-
\ No newline at end of file
+ return results
diff --git a/PytorchWildlife/models/detection/yolo_mit/__init__.py b/PytorchWildlife/models/detection/yolo_mit/__init__.py
index 6d410c7dd..1d77fb0ab 100644
--- a/PytorchWildlife/models/detection/yolo_mit/__init__.py
+++ b/PytorchWildlife/models/detection/yolo_mit/__init__.py
@@ -1,2 +1,2 @@
+from .megadetectorv6_mit import *
from .yolo_mit_base import *
-from .megadetectorv6_mit import *
\ No newline at end of file
diff --git a/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py b/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py
index 24a400b7c..bec7e8e99 100644
--- a/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py
+++ b/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py
@@ -1,29 +1,23 @@
-
from .yolo_mit_base import YOLOMITBase
-__all__ = [
- 'MegaDetectorV6MIT'
-]
+__all__ = ["MegaDetectorV6MIT"]
+
class MegaDetectorV6MIT(YOLOMITBase):
"""
- MegaDetectorV6 is a specialized class derived from the YOLOMITBase class
+ MegaDetectorV6 is a specialized class derived from the YOLOMITBase class
that is specifically designed for detecting animals, persons, and vehicles.
-
+
Attributes:
CLASS_NAMES (dict): Mapping of class IDs to their respective names.
"""
-
- CLASS_NAMES = {
- 0: "animal",
- 1: "person",
- 2: "vehicle"
- }
-
- def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-yolov9-c-mit'):
+
+ CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"}
+
+ def __init__(self, weights=None, device="cpu", pretrained=True, version="MDV6-yolov9-c-mit"):
"""
Initializes the MegaDetectorV6 model with the option to load pretrained weights.
-
+
Args:
weights (str, optional): Path to the weights file.
device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu".
@@ -32,13 +26,13 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-yo
"""
self.IMAGE_SIZE = 640
- if version == 'MDV6-mit-yolov9-c':
+ if version == "MDV6-mit-yolov9-c":
url = "https://zenodo.org/records/15398270/files/MDV6-mit-yolov9-c.ckpt?download=1"
self.MODEL_NAME = "MDV6-mit-yolov9-c.ckpt"
- elif version == 'MDV6-mit-yolov9-e':
+ elif version == "MDV6-mit-yolov9-e":
url = "https://zenodo.org/records/15398270/files/MDV6-mit-yolov9-e.ckpt?download=1"
self.MODEL_NAME = "MDV6-mit-yolov9-e.ckpt"
else:
- raise ValueError('Select a valid model version: MDV6-mit-yolov9-c or MDV6-mit-yolov9-e')
+ raise ValueError("Select a valid model version: MDV6-mit-yolov9-c or MDV6-mit-yolov9-e")
- super(MegaDetectorV6MIT, self).__init__(weights=weights, device=device, url=url)
\ No newline at end of file
+ super(MegaDetectorV6MIT, self).__init__(weights=weights, device=device, url=url)
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py b/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py
index 02d8f69c4..5f7b1b7a7 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py
@@ -1,15 +1,14 @@
-from yolo.model.yolo import create_model
from yolo.config import Config, NMSConfig
+from yolo.model.yolo import create_model
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
-from yolo.utils.model_utils import PostProcess
from yolo.utils.bounding_box_utils import create_converter
+from yolo.utils.model_utils import PostProcess
all = [
"create_model",
"Config",
"NMSConfig",
- "AugmentationComposer"
- "create_dataloader",
+ "AugmentationComposer" "create_dataloader",
"PostProcess",
"create_converter",
]
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/config.py b/PytorchWildlife/models/detection/yolo_mit/yolo/config.py
index b8b69f5d8..af12389a0 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/config.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/config.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
+
from torch import nn
@@ -165,4 +166,4 @@ class YOLOLayer(nn.Module):
tags: str
layer_type: str
usable: bool
- external: Optional[dict]
\ No newline at end of file
+ external: Optional[dict]
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py
index 87b211d43..ef1ecc00a 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py
@@ -1,9 +1,11 @@
+import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
+
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.common_types import _size_2_t
-import inspect
+
# ----------- Utils ----------- #
def get_layer_map():
@@ -109,7 +111,14 @@ def forward(self, x):
class Detection(nn.Module):
"""A single YOLO Detection head for detection models"""
- def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True):
+ def __init__(
+ self,
+ in_channels: Tuple[int],
+ num_classes: int,
+ *,
+ reg_max: int = 16,
+ use_group: bool = True,
+ ):
super().__init__()
groups = 4 if use_group else 1
@@ -125,7 +134,9 @@ def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int =
nn.Conv2d(anchor_neck, anchor_channels, 1, groups=groups),
)
self.class_conv = nn.Sequential(
- Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
+ Conv(in_channels, class_neck, 3),
+ Conv(class_neck, class_neck, 3),
+ nn.Conv2d(class_neck, num_classes, 1),
)
self.anc2vec = Anchor2Vec(reg_max=reg_max)
@@ -151,7 +162,10 @@ def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
DetectionHead = IDetection
self.heads = nn.ModuleList(
- [DetectionHead((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels]
+ [
+ DetectionHead((in_channels[0], in_channel), num_classes, **head_kwargs)
+ for in_channel in in_channels
+ ]
)
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -166,11 +180,11 @@ def __init__(self, reg_max: int = 16) -> None:
self.anc2vec.weight = nn.Parameter(reverse_reg, requires_grad=False)
def forward(self, anchor_x: Tensor) -> Tensor:
- #anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4)
+ # anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4)
B, PR, h, w = anchor_x.shape
P = 4
R = PR // P
- anchor_x = anchor_x.reshape(B, P, R, h, w).permute(0, 2, 1, 3, 4)
+ anchor_x = anchor_x.reshape(B, P, R, h, w).permute(0, 2, 1, 3, 4)
vector_x = anchor_x.softmax(dim=1)
vector_x = self.anc2vec(vector_x)[:, 0]
return anchor_x, vector_x
@@ -338,6 +352,7 @@ def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
return x
+
class ADown(nn.Module):
"""Downsampling module combining average and max pooling with convolution for feature reduction."""
@@ -373,6 +388,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
x = self.conv(x)
return x.split(self.out_channels, dim=1)
+
class CBFuse(nn.Module):
def __init__(self, index: List[int], mode: str = "nearest"):
super().__init__()
@@ -383,10 +399,14 @@ def forward(self, x_list: List[torch.Tensor]) -> List[Tensor]:
target = x_list[-1]
target_size = target.shape[2:] # Batch, Channel, H, W
- res = [F.interpolate(x[pick_id], size=target_size, mode=self.mode) for pick_id, x in zip(self.idx, x_list)]
+ res = [
+ F.interpolate(x[pick_id], size=target_size, mode=self.mode)
+ for pick_id, x in zip(self.idx, x_list)
+ ]
out = torch.stack(res + [target]).sum(dim=0)
return out
-
+
+
class SPPELAN(nn.Module):
"""SPPELAN module comprising multiple pooling and convolution layers."""
@@ -411,4 +431,4 @@ def __init__(self, **kwargs):
self.UpSample = nn.Upsample(**kwargs)
def forward(self, x):
- return self.UpSample(x)
\ No newline at end of file
+ return self.UpSample(x)
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py
index 3834dd43f..eadeebcc8 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py
@@ -5,10 +5,9 @@
import torch
from omegaconf import ListConfig, OmegaConf
from torch import nn
-
from yolo.config import ModelConfig, YOLOLayer
-from yolo.tools.dataset_preparation import prepare_weight
from yolo.model.module import get_layer_map
+from yolo.tools.dataset_preparation import prepare_weight
class YOLO(nn.Module):
@@ -40,9 +39,15 @@ def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
# Find in channels
- if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
+ if any(
+ module in layer_type
+ for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]
+ ):
layer_args["in_channels"] = output_dim[source]
- if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]):
+ if any(
+ module in layer_type
+ for module in ["Detection", "Segmentation", "Classification"]
+ ):
if isinstance(source, list):
layer_args["in_channels"] = [output_dim[idx] for idx in source]
else:
@@ -85,7 +90,9 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] =
return output
return output
- def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
+ def get_out_channels(
+ self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]
+ ):
if hasattr(layer_args, "out_channels"):
return layer_args["out_channels"]
if layer_type == "CBFuse":
@@ -106,7 +113,9 @@ def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
self.model[source - 1].usable = True
return source
- def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
+ def create_layer(
+ self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs
+ ) -> YOLOLayer:
if layer_type in self.layer_map:
layer = self.layer_map[layer_type](**kwargs)
setattr(layer, "layer_type", layer_type)
@@ -133,7 +142,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
weights = weights["state_dict"]
# Drop the prefix 'model.model.' from the keys
- if "model.model." in list(weights.keys())[0]:
+ if "model.model." in list(weights.keys())[0]:
weights = {k.replace("model.model.", ""): v for k, v in weights.items()}
model_state_dict = self.model.state_dict()
@@ -142,7 +151,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
# TODO2: weight transform if num_class difference
error_dict = {"Mismatch": set(), "Not Found": set()}
-
+
for model_key, model_weight in model_state_dict.items():
if model_key not in weights:
error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
@@ -155,7 +164,9 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
self.model.load_state_dict(model_state_dict)
-def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
+def create_model(
+ model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80
+) -> YOLO:
"""Constructs and returns a model from a Dictionary configuration file.
Args:
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py
index a8003e135..aea66964b 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py
@@ -45,11 +45,13 @@ def __call__(self, image: Image, boxes):
pad_left = (self.target_width - new_width) // 2
pad_top = (self.target_height - new_height) // 2
- padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
+ padded_image = Image.new(
+ "RGB", (self.target_width, self.target_height), self.background_color
+ )
padded_image.paste(resized_image, (pad_left, pad_top))
boxes[:, [1, 3]] = (boxes[:, [1, 3]] * new_width + pad_left) / self.target_width
boxes[:, [2, 4]] = (boxes[:, [2, 4]] * new_height + pad_top) / self.target_height
transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
- return padded_image, boxes, transform_info
\ No newline at end of file
+ return padded_image, boxes, transform_info
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py
index 67a91e5e4..cbde2c15f 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py
@@ -8,7 +8,6 @@
from rich.progress import track
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
-
from yolo.config import DataConfig, DatasetConfig
from yolo.tools.data_augmentation import AugmentationComposer
from yolo.tools.dataset_preparation import prepare_dataset
@@ -32,7 +31,9 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
self.transform = AugmentationComposer(transforms, self.image_size, self.base_size)
self.transform.get_more_data = self.get_more_data
- self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
+ self.img_paths, self.bboxes, self.ratios = tensorlize(
+ self.load_data(Path(dataset_cfg.path), phase_name)
+ )
def load_data(self, dataset_path: Path, phase_name: str):
"""
@@ -94,7 +95,9 @@ def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = Fa
if not label_path.is_file():
continue
with open(label_path, "r") as file:
- image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
+ image_seg_annotations = [
+ list(map(float, line.strip().split())) for line in file
+ ]
else:
image_seg_annotations = []
@@ -140,43 +143,42 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te
else:
return torch.zeros((0, 5))
- def adapt_labels(self, bboxes: Tensor) -> Tensor:
- """
- Adapt bounding box labels using vectorized operations.
-
- Args:
- bboxes (Tensor): Tensor of bounding boxes in the format [class_id, width, height, x_center, y_center].
-
- Returns:
- Tensor: Tensor of adapted bounding boxes in the format [class_id, xmin, ymin, xmax, ymax].
- """
- class_ids = bboxes[:, 0]
- widths = bboxes[:, 1]
- heights = bboxes[:, 2]
- x_centers = bboxes[:, 3]
- y_centers = bboxes[:, 4]
-
- xmins = x_centers - widths / 2
- ymins = y_centers - heights / 2
- xmaxs = x_centers + widths / 2
- ymaxs = y_centers + heights / 2
-
- adapted_bboxes = torch.stack([class_ids, xmins, ymins, xmaxs, ymaxs], dim=1)
-
- return adapted_bboxes
+ def adapt_labels(self, bboxes: Tensor) -> Tensor:
+ """
+ Adapt bounding box labels using vectorized operations.
+
+ Args:
+ bboxes (Tensor): Tensor of bounding boxes in the format [class_id, width, height, x_center, y_center].
+
+ Returns:
+ Tensor: Tensor of adapted bounding boxes in the format [class_id, xmin, ymin, xmax, ymax].
+ """
+ class_ids = bboxes[:, 0]
+ widths = bboxes[:, 1]
+ heights = bboxes[:, 2]
+ x_centers = bboxes[:, 3]
+ y_centers = bboxes[:, 4]
+
+ xmins = x_centers - widths / 2
+ ymins = y_centers - heights / 2
+ xmaxs = x_centers + widths / 2
+ ymaxs = y_centers + heights / 2
+
+ adapted_bboxes = torch.stack([class_ids, xmins, ymins, xmaxs, ymaxs], dim=1)
+
+ return adapted_bboxes
def adapt_labels_list(self, points):
-
- x_center = points[0]
- y_center = points[1]
- width = points[2]
- height = points[3]
-
- xmin = x_center - width / 2
- ymin = y_center - height / 2
- xmax = x_center + width / 2
- ymax = y_center + height / 2
-
+ x_center = points[0]
+ y_center = points[1]
+ width = points[2]
+ height = points[3]
+
+ xmin = x_center - width / 2
+ ymin = y_center - height / 2
+ xmax = x_center + width / 2
+ ymax = y_center + height / 2
+
return [xmin, ymin, xmax, ymax]
def get_data(self, idx):
@@ -228,4 +230,4 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st
num_workers=data_cfg.cpu_num,
pin_memory=data_cfg.pin_memory,
collate_fn=collate_fn,
- )
\ No newline at end of file
+ )
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py
index bd41c1ec4..8cb089bd6 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py
@@ -2,7 +2,6 @@
from typing import Optional
import requests
-
from yolo.config import DatasetConfig
@@ -48,4 +47,3 @@ def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path
download_file(weight_link, weight_path)
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Failed to download the weight file: {e}")
-
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py
index b6e1c13e6..34e781f0b 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py
@@ -3,7 +3,6 @@
import torch
from torch import Tensor, tensor
from torchvision.ops import batched_nms
-
from yolo.config import AnchorConfig, NMSConfig
from yolo.model.yolo import YOLO
@@ -109,7 +108,9 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
self.head_num = len(anchor_cfg.anchor)
self.anchor_grids = self.generate_anchors(image_size)
- self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
+ self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(
+ self.head_num, 1, -1, 1, 1, 2
+ )
self.anchor_num = self.anchor_scale.size(2)
self.class_num = model.num_classes
@@ -128,7 +129,9 @@ def generate_anchors(self, image_size: List[int]):
for stride in self.strides:
W, H = image_size[0] // stride, image_size[1] // stride
anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij")
- anchor_grid = torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device)
+ anchor_grid = (
+ torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device)
+ )
anchor_grids.append(anchor_grid)
return anchor_grids
@@ -147,11 +150,9 @@ def __call__(self, predicts: List[Tensor]):
pred_box = pred_box.sigmoid()
pred_box[..., 0:2] = (
- (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]) * self.strides[layer_idx]
- )
- pred_box[..., 2:4] = (
- (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
- )
+ pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]
+ ) * self.strides[layer_idx]
+ pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
B, L, h, w, A = pred_box.shape
preds_box.append(pred_box.reshape(B, L * h * w, A))
@@ -177,20 +178,29 @@ def create_converter(model_version: str = "v9-c", *args, **kwargs) -> Union[Anc2
return converter
-def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
+def bbox_nms(
+ cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None
+):
cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
batch_idx, valid_grid, valid_cls = torch.where(cls_dist > nms_cfg.min_confidence)
valid_con = cls_dist[batch_idx, valid_grid, valid_cls]
valid_box = bbox[batch_idx, valid_grid]
- nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou)
+ nms_idx = batched_nms(
+ valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou
+ )
predicts_nms = []
for idx in range(cls_dist.size(0)):
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
predict_nms = torch.cat(
- [valid_cls[instance_idx][:, None], valid_box[instance_idx], valid_con[instance_idx][:, None]], dim=-1
+ [
+ valid_cls[instance_idx][:, None],
+ valid_box[instance_idx],
+ valid_con[instance_idx][:, None],
+ ],
+ dim=-1,
)
predicts_nms.append(predict_nms[: nms_cfg.max_bbox])
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py
index da1989d59..6111bdd24 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py
@@ -3,9 +3,11 @@
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
+
import numpy as np
import torch
+
def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
"""
Maps each unique 'id' in the list of category dictionaries to a sequential integer index.
@@ -54,8 +56,14 @@ def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str,
"""
with open(labels_path, "r") as file:
labels_data = json.load(file)
- id_to_idx = discretize_categories(labels_data.get("categories", [])) if "categories" in labels_data else None
- annotations_index = organize_annotations_by_image(labels_data, id_to_idx) # check lookup is a good name?
+ id_to_idx = (
+ discretize_categories(labels_data.get("categories", []))
+ if "categories" in labels_data
+ else None
+ )
+ annotations_index = organize_annotations_by_image(
+ labels_data, id_to_idx
+ ) # check lookup is a good name?
image_info_dict = {Path(img["file_name"]).stem: img for img in labels_data["images"]}
return annotations_index, image_info_dict
@@ -89,7 +97,9 @@ def scale_segmentation(
scaled_seg_data = (
np.array(seg_list).reshape(-1, 2) / [w, h]
).tolist() # make the list group in x, y pairs and scaled with image width, height
- scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data)) # flatten the scaled_seg_data list
+ scaled_flat_seg_data = [category_id] + list(
+ chain(*scaled_seg_data)
+ ) # flatten the scaled_seg_data list
seg_array_with_cat.append(scaled_flat_seg_data)
return seg_array_with_cat
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py
index e03942e8e..6707dbcc3 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py
@@ -1,4 +1,5 @@
from typing import List, Optional, Union
+
from torch import Tensor
from yolo.config import NMSConfig
from yolo.model.yolo import YOLO
@@ -25,5 +26,7 @@ def __call__(
pred_conf = prediction[3] if len(prediction) == 4 else None
if rev_tensor is not None:
pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
- pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf) #pred_box: [cls, x1, y1, x2, y2, conf]
- return pred_bbox
\ No newline at end of file
+ pred_bbox = bbox_nms(
+ pred_class, pred_bbox, self.nms, pred_conf
+ ) # pred_box: [cls, x1, y1, x2, y2, conf]
+ return pred_bbox
diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py b/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py
index 4b02a7dff..ee286f492 100644
--- a/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py
+++ b/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py
@@ -6,45 +6,46 @@
# Importing basic libraries
import os
-import supervision as sv
-import numpy as np
-from PIL import Image
-import wget
-import torch
-
-from ..base_detector import BaseDetector
-from ....data import datasets as pw_data
-
import sys
from pathlib import Path
-from lightning import Trainer
+import numpy as np
+import supervision as sv
+import torch
+import wget
import yaml
+from lightning import Trainer
from omegaconf import OmegaConf
+from PIL import Image
+
+from ....data import datasets as pw_data
+from ..base_detector import BaseDetector
project_root = Path(__file__).resolve().parent
sys.path.append(str(project_root))
-from yolo import create_model, create_converter, PostProcess, AugmentationComposer
+from yolo import AugmentationComposer, PostProcess, create_converter, create_model
+
class YOLOMITBase(BaseDetector):
"""
Base detector class for YOLO MIT framework. This class provides utility methods for
loading the model, generating results, and performing single and batch image detections.
"""
+
def __init__(self, weights=None, device="cpu", url=None):
"""
Initialize the YOLO MIT detector.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
"""
-
+
self.cfg = self._load_cfg()
self.transform = AugmentationComposer([], self.cfg.image_size, self.cfg.image_size[0])
self.weights = weights
@@ -54,21 +55,29 @@ def __init__(self, weights=None, device="cpu", url=None):
def _load_cfg(self):
if self.MODEL_NAME == "MDV6-mit-yolov9-c.ckpt":
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml")):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml")
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
url = "https://zenodo.org/records/15178680/files/config_v9s.yaml?download=1"
- config_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
+ config_path = wget.download(
+ url, out=os.path.join(torch.hub.get_dir(), "checkpoints")
+ )
else:
config_path = os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml")
elif self.MODEL_NAME == "MDV6-mit-yolov9-e.ckpt":
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml")):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml")
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
url = "https://zenodo.org/records/15178680/files/config_v9c.yaml?download=1"
- config_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
+ config_path = wget.download(
+ url, out=os.path.join(torch.hub.get_dir(), "checkpoints")
+ )
else:
config_path = os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml")
- with open(config_path, 'r') as f:
+ with open(config_path, "r") as f:
cfg_dict = yaml.safe_load(f)
return OmegaConf.create(cfg_dict)
@@ -76,78 +85,77 @@ def _load_cfg(self):
def _load_model(self, weights=None, device="cpu", url=None):
"""
Load the YOLO MIT model weights.
-
+
Args:
- weights (str, optional):
+ weights (str, optional):
Path to the model weights. Defaults to None.
- device (str, optional):
+ device (str, optional):
Device for model inference. Defaults to "cpu".
- url (str, optional):
+ url (str, optional):
URL to fetch the model weights. Defaults to None.
Raises:
Exception: If weights are not provided.
"""
if url:
- if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)):
+ if not os.path.exists(
+ os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
+ ):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
else:
raise Exception("Need weights for inference.")
-
+
self.cfg.image_size = [self.IMAGE_SIZE, self.IMAGE_SIZE]
self.model = create_model(self.cfg.model, weight_path=weights, class_num=3).to(self.device)
- self.converter = create_converter(self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device)
+ self.converter = create_converter(
+ self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+ )
self.post_proccess = PostProcess(self.converter, self.cfg.task.nms)
def results_generation(self, preds, img_id, id_strip=None):
"""
Generate results for detection based on model predictions.
-
+
Args:
- preds (List[torch.Tensor]):
+ preds (List[torch.Tensor]):
Model predictions.
- img_id (str):
+ img_id (str):
Image identifier.
- id_strip (str, optional):
+ id_strip (str, optional):
Strip specific characters from img_id. Defaults to None.
Returns:
dict: Dictionary containing image ID, detections, and labels.
"""
- #preds: [cls, x1, y1, x2, y2, conf]
- class_id = preds[0][:,0].cpu().numpy().astype(int)
- xyxy = preds[0][:,1:5].cpu().numpy()
- confidence = preds[0][:,5].cpu().numpy()
+ # preds: [cls, x1, y1, x2, y2, conf]
+ class_id = preds[0][:, 0].cpu().numpy().astype(int)
+ xyxy = preds[0][:, 1:5].cpu().numpy()
+ confidence = preds[0][:, 5].cpu().numpy()
results = {"img_id": str(img_id).strip(id_strip)}
- results["detections"] = sv.Detections(
- xyxy=xyxy,
- confidence=confidence,
- class_id=class_id
- )
+ results["detections"] = sv.Detections(xyxy=xyxy, confidence=confidence, class_id=class_id)
results["labels"] = [
- f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
- for _, _, confidence, class_id, _, _ in results["detections"]
+ f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}"
+ for _, _, confidence, class_id, _, _ in results["detections"]
]
results
return results
-
def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None):
"""
Perform detection on a single image.
-
+
Args:
- img (str or ndarray):
+ img (str or ndarray):
Image path or ndarray of images.
- img_path (str, optional):
+ img_path (str, optional):
Image path or identifier.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
Returns:
@@ -160,32 +168,34 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri
if type(img) == str:
if img_path is None:
img_path = img
- im_pil = Image.open(img_path).convert('RGB')
+ im_pil = Image.open(img_path).convert("RGB")
else:
im_pil = Image.fromarray(img)
image, bbox, rev_tensor = self.transform(im_pil)
image = image.to(self.device)[None]
rev_tensor = rev_tensor.to(self.device)[None]
-
+
with torch.no_grad():
predict = self.model(image)
- det_results = self.post_proccess(predict, rev_tensor) #pred_box: [cls, x1, y1, x2, y2, conf]
-
+ det_results = self.post_proccess(
+ predict, rev_tensor
+ ) # pred_box: [cls, x1, y1, x2, y2, conf]
+
return self.results_generation(det_results, img_path, id_strip)
def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id_strip=None):
"""
Perform detection on a batch of images.
-
+
Args:
- data_path (str):
+ data_path (str):
Path containing all images for inference.
batch_size (int, optional):
Batch size for inference. Defaults to 16.
- det_conf_thres (float, optional):
+ det_conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
+ id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
extension (str, optional):
Image extension to search for. Defaults to "JPG"
@@ -196,25 +206,30 @@ def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id
self.cfg.task.data.source = data_path
self.cfg.task.nms.min_confidence = det_conf_thres
self._load_model(weights=self.weights, device=self.device, url=self.url)
-
+
dataset = pw_data.DetectionImageFolder(
data_path,
transform=self.transform,
)
-
+
results = []
for i in range(len(dataset.images)):
- res = self.single_image_detection(dataset.images[i], img_path=dataset.images[i], det_conf_thres=det_conf_thres, id_strip=id_strip)
+ res = self.single_image_detection(
+ dataset.images[i],
+ img_path=dataset.images[i],
+ det_conf_thres=det_conf_thres,
+ id_strip=id_strip,
+ )
# Upload the original image and get the size in the format (height, width)
img = Image.open(dataset.images[i])
img = np.asarray(img)
size = img.shape[:2]
# Normalize the coordinates for timelapse compatibility
- normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy]
+ normalized_coords = [
+ [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]]
+ for x1, y1, x2, y2 in res["detections"].xyxy
+ ]
res["normalized_coords"] = normalized_coords
results.append(res)
return results
-
-
-
diff --git a/PytorchWildlife/utils/__init__.py b/PytorchWildlife/utils/__init__.py
index bca86797f..66efc7844 100644
--- a/PytorchWildlife/utils/__init__.py
+++ b/PytorchWildlife/utils/__init__.py
@@ -1,2 +1,2 @@
from .misc import *
-from .post_process import *
\ No newline at end of file
+from .post_process import *
diff --git a/PytorchWildlife/utils/misc.py b/PytorchWildlife/utils/misc.py
index f516a66c2..ae32d4006 100644
--- a/PytorchWildlife/utils/misc.py
+++ b/PytorchWildlife/utils/misc.py
@@ -3,15 +3,14 @@
""" Miscellaneous functions."""
-import numpy as np
-from tqdm import tqdm
-import cv2
from typing import Callable
+
+import cv2
+import numpy as np
from supervision import VideoInfo, VideoSink, get_video_frames_generator
+from tqdm import tqdm
-__all__ = [
- "process_video"
-]
+__all__ = ["process_video"]
def process_video(
@@ -19,32 +18,32 @@ def process_video(
target_path: str,
callback: Callable[[np.ndarray, int], np.ndarray],
target_fps: int = 1,
- codec: str = "mp4v"
+ codec: str = "mp4v",
) -> None:
"""
- Process a video frame-by-frame, applying a callback function to each frame and saving the results
+ Process a video frame-by-frame, applying a callback function to each frame and saving the results
to a new video. This version includes a progress bar and allows codec selection.
-
+
Args:
- source_path (str):
+ source_path (str):
Path to the source video file.
- target_path (str):
+ target_path (str):
Path to save the processed video.
- callback (Callable[[np.ndarray, int], np.ndarray]):
+ callback (Callable[[np.ndarray, int], np.ndarray]):
A function that takes a video frame and its index as input and returns the processed frame.
- codec (str, optional):
+ codec (str, optional):
Codec used to encode the processed video. Default is "avc1".
"""
source_video_info = VideoInfo.from_video_path(video_path=source_path)
-
+
if source_video_info.fps > target_fps:
stride = int(source_video_info.fps / target_fps)
source_video_info.fps = target_fps
else:
stride = 1
-
+
with VideoSink(target_path=target_path, video_info=source_video_info, codec=codec) as sink:
- with tqdm(total=int(source_video_info.total_frames / stride)) as pbar:
+ with tqdm(total=int(source_video_info.total_frames / stride)) as pbar:
for index, frame in enumerate(
get_video_frames_generator(source_path=source_path, stride=stride)
):
diff --git a/PytorchWildlife/utils/post_process.py b/PytorchWildlife/utils/post_process.py
index d7152ae4e..955a4b661 100644
--- a/PytorchWildlife/utils/post_process.py
+++ b/PytorchWildlife/utils/post_process.py
@@ -3,15 +3,16 @@
""" Post-processing functions."""
-import os
-import numpy as np
import json
-import cv2
-from PIL import Image
-import supervision as sv
+import os
import shutil
from pathlib import Path
+import cv2
+import numpy as np
+import supervision as sv
+from PIL import Image
+
__all__ = [
"save_detection_images",
"save_detection_images_dots",
@@ -21,7 +22,7 @@
"save_detection_classification_json",
"save_detection_timelapse_json",
"save_detection_classification_timelapse_json",
- "detection_folder_separation"
+ "detection_folder_separation",
]
@@ -43,7 +44,7 @@ def save_detection_images(results, output_dir, input_dir=None, overwrite=False):
lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2)
os.makedirs(output_dir, exist_ok=True)
- with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
+ with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
if isinstance(results, list):
for entry in results:
annotated_img = lab_annotator.annotate(
@@ -57,8 +58,8 @@ def save_detection_images(results, output_dir, input_dir=None, overwrite=False):
if input_dir:
relative_path = os.path.relpath(entry["img_id"], input_dir)
save_path = os.path.join(output_dir, relative_path)
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
- image_name = relative_path
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ image_name = relative_path
else:
image_name = os.path.basename(entry["img_id"])
sink.save_image(
@@ -75,9 +76,11 @@ def save_detection_images(results, output_dir, input_dir=None, overwrite=False):
)
sink.save_image(
- image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=os.path.basename(results["img_id"])
+ image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR),
+ image_name=os.path.basename(results["img_id"]),
)
+
def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=False):
"""
Save detected images with bounding boxes and labels annotated.
@@ -92,10 +95,10 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa
overwrite (bool):
Whether overwriting existing image folders. Default to False.
"""
- dot_annotator = sv.DotAnnotator(radius=6)
- lab_annotator = sv.LabelAnnotator(text_position=sv.Position.BOTTOM_RIGHT)
+ dot_annotator = sv.DotAnnotator(radius=6)
+ lab_annotator = sv.LabelAnnotator(text_position=sv.Position.BOTTOM_RIGHT)
os.makedirs(output_dir, exist_ok=True)
-
+
with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
if isinstance(results, list):
for i, entry in enumerate(results):
@@ -104,7 +107,7 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa
image_name = os.path.basename(entry["img_id"])
else:
scene = entry["img"]
- image_name = f"output_image_{i}.jpg" # default name if no image id is provided
+ image_name = f"output_image_{i}.jpg" # default name if no image id is provided
annotated_img = lab_annotator.annotate(
scene=dot_annotator.annotate(
@@ -117,7 +120,7 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa
if input_dir:
relative_path = os.path.relpath(entry["img_id"], input_dir)
save_path = os.path.join(output_dir, relative_path)
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
image_name = relative_path
sink.save_image(
image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name
@@ -128,8 +131,8 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa
image_name = os.path.basename(results["img_id"])
else:
scene = results["img"]
- image_name = "output_image.jpg" # default name if no image id is provided
-
+ image_name = "output_image.jpg" # default name if no image id is provided
+
annotated_img = lab_annotator.annotate(
scene=dot_annotator.annotate(
scene=scene,
@@ -137,7 +140,7 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa
),
detections=results["detections"],
labels=results["labels"],
- )
+ )
sink.save_image(
image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name
)
@@ -164,32 +167,48 @@ def save_crop_images(results, output_dir, input_dir=None, overwrite=False):
with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
if isinstance(results, list):
for entry in results:
- for i, (xyxy, cat) in enumerate(zip(entry["detections"].xyxy, entry["detections"].class_id)):
+ for i, (xyxy, cat) in enumerate(
+ zip(entry["detections"].xyxy, entry["detections"].class_id)
+ ):
cropped_img = sv.crop_image(
image=np.array(Image.open(entry["img_id"]).convert("RGB")), xyxy=xyxy
)
if input_dir:
relative_path = os.path.relpath(entry["img_id"], input_dir)
save_path = os.path.join(output_dir, relative_path)
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
- image_name = os.path.join(os.path.dirname(relative_path), "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"])))
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ image_name = os.path.join(
+ os.path.dirname(relative_path),
+ "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"])),
+ )
else:
- image_name = "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"]))
+ image_name = "{}_{}_{}".format(
+ int(cat), i, os.path.basename(entry["img_id"])
+ )
sink.save_image(
image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR),
image_name=image_name,
)
else:
- for i, (xyxy, cat) in enumerate(zip(results["detections"].xyxy, results["detections"].class_id)):
+ for i, (xyxy, cat) in enumerate(
+ zip(results["detections"].xyxy, results["detections"].class_id)
+ ):
cropped_img = sv.crop_image(
image=np.array(Image.open(results["img_id"]).convert("RGB")), xyxy=xyxy
)
sink.save_image(
image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR),
- image_name="{}_{}_{}".format(int(cat), i, os.path.basename(results["img_id"]),
- ))
+ image_name="{}_{}_{}".format(
+ int(cat),
+ i,
+ os.path.basename(results["img_id"]),
+ ),
+ )
+
-def save_detection_json(det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None):
+def save_detection_json(
+ det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None
+):
"""
Save detection results to a JSON file.
@@ -208,19 +227,20 @@ def save_detection_json(det_results, output_dir, categories=None, exclude_catego
json_results = {"annotations": [], "categories": categories}
for det_r in det_results:
-
# Category filtering
img_id = det_r["img_id"]
category = det_r["detections"].class_id
bbox = det_r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)]
- confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)]
+ confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)]
category = category[~np.isin(category, exclude_category_ids)]
# if not all([x in exclude_category_ids for x in category]):
json_results["annotations"].append(
{
- "img_id": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id,
+ "img_id": img_id.replace(exclude_file_path + os.sep, "")
+ if exclude_file_path
+ else img_id,
"bbox": bbox.tolist(),
"category": category.tolist(),
"confidence": confidence.tolist(),
@@ -230,7 +250,10 @@ def save_detection_json(det_results, output_dir, categories=None, exclude_catego
with open(output_dir, "w") as f:
json.dump(json_results, f, indent=4)
-def save_detection_json_as_dots(det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None):
+
+def save_detection_json_as_dots(
+ det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None
+):
"""
Save detection results to a JSON file in dots format.
@@ -249,20 +272,21 @@ def save_detection_json_as_dots(det_results, output_dir, categories=None, exclud
json_results = {"annotations": [], "categories": categories}
for det_r in det_results:
-
# Category filtering
img_id = det_r["img_id"]
category = det_r["detections"].class_id
bbox = det_r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)]
dot = np.array([[np.mean(row[::2]), np.mean(row[1::2])] for row in bbox])
- confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)]
+ confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)]
category = category[~np.isin(category, exclude_category_ids)]
# if not all([x in exclude_category_ids for x in category]):
json_results["annotations"].append(
{
- "img_id": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id,
+ "img_id": img_id.replace(exclude_file_path + os.sep, "")
+ if exclude_file_path
+ else img_id,
"dot": dot.tolist(),
"category": category.tolist(),
"confidence": confidence.tolist(),
@@ -274,9 +298,13 @@ def save_detection_json_as_dots(det_results, output_dir, categories=None, exclud
def save_detection_timelapse_json(
- det_results, output_dir, categories=None,
- exclude_category_ids=[], exclude_file_path=None, info={"detector": "megadetector_v5"}
- ):
+ det_results,
+ output_dir,
+ categories=None,
+ exclude_category_ids=[],
+ exclude_file_path=None,
+ info={"detector": "megadetector_v5"},
+):
"""
Save detection results to a JSON file.
@@ -295,35 +323,41 @@ def save_detection_timelapse_json(
Default Timelapse info. Defaults to {"detector": "megadetector_v5}.
"""
- json_results = {
- "info": info,
- "detection_categories": categories,
- "images": []
- }
+ json_results = {"info": info, "detection_categories": categories, "images": []}
for det_r in det_results:
-
img_id = det_r["img_id"]
category_id_list = det_r["detections"].class_id
- bbox_list = det_r["detections"].xyxy.astype(int)[~np.isin(category_id_list, exclude_category_ids)]
- confidence_list = det_r["detections"].confidence[~np.isin(category_id_list, exclude_category_ids)]
- normalized_bbox_list = np.array(det_r["normalized_coords"])[~np.isin(category_id_list, exclude_category_ids)]
+ bbox_list = det_r["detections"].xyxy.astype(int)[
+ ~np.isin(category_id_list, exclude_category_ids)
+ ]
+ confidence_list = det_r["detections"].confidence[
+ ~np.isin(category_id_list, exclude_category_ids)
+ ]
+ normalized_bbox_list = np.array(det_r["normalized_coords"])[
+ ~np.isin(category_id_list, exclude_category_ids)
+ ]
category_id_list = category_id_list[~np.isin(category_id_list, exclude_category_ids)]
# if not all([x in exclude_category_ids for x in category_id_list]):
image_annotations = {
- "file": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id,
- "max_detection_conf": float(max(confidence_list)) if len(confidence_list) > 0 else '',
- "detections": []
+ "file": img_id.replace(exclude_file_path + os.sep, "") if exclude_file_path else img_id,
+ "max_detection_conf": float(max(confidence_list)) if len(confidence_list) > 0 else "",
+ "detections": [],
}
for i in range(len(bbox_list)):
normalized_bbox = [float(y) for y in normalized_bbox_list[i]]
detection = {
"category": str(category_id_list[i]),
"conf": float(confidence_list[i]),
- "bbox": [normalized_bbox[0], normalized_bbox[1], normalized_bbox[2]-normalized_bbox[0], normalized_bbox[3]-normalized_bbox[1]],
- "classifications": []
+ "bbox": [
+ normalized_bbox[0],
+ normalized_bbox[1],
+ normalized_bbox[2] - normalized_bbox[0],
+ normalized_bbox[3] - normalized_bbox[1],
+ ],
+ "classifications": [],
}
image_annotations["detections"].append(detection)
@@ -335,7 +369,12 @@ def save_detection_timelapse_json(
def save_detection_classification_json(
- det_results, clf_results, output_path, det_categories=None, clf_categories=None, exclude_file_path=None
+ det_results,
+ clf_results,
+ output_path,
+ det_categories=None,
+ clf_categories=None,
+ exclude_file_path=None,
):
"""
Save classification results to a JSON file.
@@ -377,17 +416,15 @@ def save_detection_classification_json(
json_results["annotations"].append(
{
- "img_id": str(det_r["img_id"]).replace(exclude_file_path + os.sep, '') if exclude_file_path else str(det_r["img_id"]),
+ "img_id": str(det_r["img_id"]).replace(exclude_file_path + os.sep, "")
+ if exclude_file_path
+ else str(det_r["img_id"]),
"bbox": [
[int(x) for x in sublist]
for sublist in det_r["detections"].xyxy.astype(int).tolist()
],
- "det_category": [
- int(x) for x in det_r["detections"].class_id.tolist()
- ],
- "det_confidence": [
- float(x) for x in det_r["detections"].confidence.tolist()
- ],
+ "det_category": [int(x) for x in det_r["detections"].class_id.tolist()],
+ "det_confidence": [float(x) for x in det_r["detections"].confidence.tolist()],
"clf_category": [int(x) for x in clf_categories],
"clf_confidence": [float(x) for x in clf_confidence],
}
@@ -396,8 +433,13 @@ def save_detection_classification_json(
def save_detection_classification_timelapse_json(
- det_results, clf_results, output_path, det_categories=None, clf_categories=None,
- exclude_file_path=None, info={"detector": "megadetector_v5"}
+ det_results,
+ clf_results,
+ output_path,
+ det_categories=None,
+ clf_categories=None,
+ exclude_file_path=None,
+ info={"detector": "megadetector_v5"},
):
"""
Save detection and classification results to a JSON file in the specified format.
@@ -420,14 +462,18 @@ def save_detection_classification_timelapse_json(
"info": info,
"detection_categories": det_categories,
"classification_categories": clf_categories,
- "images": []
+ "images": [],
}
for det_r in det_results:
image_annotations = {
- "file": str(det_r["img_id"]).replace(exclude_file_path + os.sep, '') if exclude_file_path else str(det_r["img_id"]),
- "max_detection_conf": float(max(det_r["detections"].confidence)) if len(det_r["detections"].confidence) > 0 else '',
- "detections": []
+ "file": str(det_r["img_id"]).replace(exclude_file_path + os.sep, "")
+ if exclude_file_path
+ else str(det_r["img_id"]),
+ "max_detection_conf": float(max(det_r["detections"].confidence))
+ if len(det_r["detections"].confidence) > 0
+ else "",
+ "detections": [],
}
for i in range(len(det_r["detections"])):
@@ -436,14 +482,21 @@ def save_detection_classification_timelapse_json(
detection = {
"category": str(det.class_id[0]),
"conf": float(det.confidence[0]),
- "bbox": [normalized_bbox[0], normalized_bbox[1], normalized_bbox[2]-normalized_bbox[0], normalized_bbox[3]-normalized_bbox[1]],
- "classifications": []
+ "bbox": [
+ normalized_bbox[0],
+ normalized_bbox[1],
+ normalized_bbox[2] - normalized_bbox[0],
+ normalized_bbox[3] - normalized_bbox[1],
+ ],
+ "classifications": [],
}
# Find classifications for this detection
for clf_r in clf_results:
if clf_r["img_id"] == det_r["img_id"]:
- detection["classifications"].append([str(clf_r["class_id"]), float(clf_r["confidence"])])
+ detection["classifications"].append(
+ [str(clf_r["class_id"]), float(clf_r["confidence"])]
+ )
image_annotations["detections"].append(detection)
@@ -461,7 +514,7 @@ def detection_folder_separation(json_file, img_path, destination_path, confidenc
This function reads a JSON formatted file containing annotations of image detections.
Each image is checked for detections with category '0' and a confidence level above the specified
threshold. If such detections are found, the image is categorized under 'Animal'. Images without
- any category '0' detections above the threshold, including those with no detections at all, are
+ any category '0' detections above the threshold, including those with no detections at all, are
categorized under 'No_animal'.
Parameters:
@@ -485,41 +538,41 @@ def detection_folder_separation(json_file, img_path, destination_path, confidenc
"""
# Load JSON data from the file
- with open(json_file, 'r') as file:
+ with open(json_file, "r") as file:
data = json.load(file)
-
+
# Ensure the destination directories exist
os.makedirs(destination_path, exist_ok=True)
animal_path = os.path.join(destination_path, "Animal")
no_animal_path = os.path.join(destination_path, "No_animal")
os.makedirs(animal_path, exist_ok=True)
os.makedirs(no_animal_path, exist_ok=True)
-
+
# Process each image detection
i = 0
- for item in data['annotations']:
- i+=1
- img_id = item['img_id']
- categories = item['category']
- confidences = item['confidence']
-
+ for item in data["annotations"]:
+ i += 1
+ img_id = item["img_id"]
+ categories = item["category"]
+ confidences = item["confidence"]
+
# Check if there is any category '0' with confidence above the threshold
file_targeted_for_animal = False
for category, confidence in zip(categories, confidences):
if category == 0 and confidence > confidence_threshold:
file_targeted_for_animal = True
break
-
+
if file_targeted_for_animal:
target_folder = animal_path
else:
target_folder = no_animal_path
-
+
# Construct the source and destination file paths
src_file_path = os.path.join(img_path, img_id)
dest_file_path = os.path.join(target_folder, os.path.dirname(img_id))
os.makedirs(dest_file_path, exist_ok=True)
-
+
# Copy the file to the appropriate directory
shutil.copy(src_file_path, dest_file_path)
diff --git a/README.md b/README.md
index ce3c0efef..de464849d 100644
--- a/README.md
+++ b/README.md
@@ -222,4 +222,3 @@ We are also building a list of contributors and will release in future updates!
>[!IMPORTANT]
>If you would like to be added to this list or have any questions regarding MegaDetector and Pytorch-Wildlife, please [email us](zhongqimiao@microsoft.com) or join us in our Discord channel: [](https://discord.gg/TeEVxzaYtm)
-
diff --git a/demo/detection_classification_pipeline_demo.py b/demo/detection_classification_pipeline_demo.py
index 58f65ae48..c46e6b629 100644
--- a/demo/detection_classification_pipeline_demo.py
+++ b/demo/detection_classification_pipeline_demo.py
@@ -3,34 +3,37 @@
""" Demo for image detection"""
-#%%
+# %%
# Importing necessary basic libraries and modules
import os
+
import numpy as np
-from PIL import Image
import supervision as sv
-# PyTorch imports
+# PyTorch imports
import torch
+from PIL import Image
from torch.utils.data import DataLoader
-#%%
-# Importing the model, dataset, transformations and utility functions from PytorchWildlife
-from PytorchWildlife.models import detection as pw_detection
from PytorchWildlife import utils as pw_utils
+from PytorchWildlife.data import datasets as pw_data
+from PytorchWildlife.data import transforms as pw_trans
+# %%
+# Importing the model, dataset, transformations and utility functions from PytorchWildlife
from PytorchWildlife.models import classification as pw_classification
-from PytorchWildlife.data import transforms as pw_trans
-from PytorchWildlife.data import datasets as pw_data
+from PytorchWildlife.models import detection as pw_detection
-#%%
+# %%
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-#%%
+# %%
# Initializing the MegaDetectorV6 model for image detection
# Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c
-detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e")
+detection_model = pw_detection.MegaDetectorV6(
+ device=DEVICE, pretrained=True, version="MDV6-yolov10-e"
+)
# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6
# detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a")
@@ -38,49 +41,57 @@
# %%
# Initializing a classification model for image classification
# classification_model = pw_classification.DFNE(device=DEVICE)
-classification_model = pw_classification.AI4GAmazonRainforest(device=DEVICE, version='v2')
+classification_model = pw_classification.AI4GAmazonRainforest(device=DEVICE, version="v2")
-#%% Single image detection
+# %% Single image detection
# Specifying the path to the target image TODO: Allow argparsing
-tgt_img_path = os.path.join(".","demo_data","imgs","10050028_0.JPG")
+tgt_img_path = os.path.join(".", "demo_data", "imgs", "10050028_0.JPG")
# Performing the detection on the single image
results = detection_model.single_image_detection(tgt_img_path)
clf_conf_thres = 0.8
-input_img = np.array(Image.open(tgt_img_path).convert('RGB'))
+input_img = np.array(Image.open(tgt_img_path).convert("RGB"))
clf_labels = []
for i, (xyxy, det_id) in enumerate(zip(results["detections"].xyxy, results["detections"].class_id)):
# Only run classifier when detection class is animal
if det_id == 0:
cropped_image = sv.crop_image(image=input_img, xyxy=xyxy)
results_clf = classification_model.single_image_classification(cropped_image)
- clf_labels.append("{} {:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown",
- results_clf["confidence"]))
+ clf_labels.append(
+ "{} {:.2f}".format(
+ results_clf["prediction"]
+ if results_clf["confidence"] > clf_conf_thres
+ else "Unknown",
+ results_clf["confidence"],
+ )
+ )
else:
clf_labels.append(results["labels"][i])
results["labels"] = clf_labels
# %%
-# Saving the detection results
-pw_utils.save_detection_images(results, os.path.join(".","demo_output"), overwrite=False)
+# Saving the detection results
+pw_utils.save_detection_images(results, os.path.join(".", "demo_output"), overwrite=False)
# %%# Saving the detected objects as cropped images
-pw_utils.save_crop_images(results, os.path.join(".","crop_output"), overwrite=False)
+pw_utils.save_crop_images(results, os.path.join(".", "crop_output"), overwrite=False)
# %%
-#%% Batch detection
+# %% Batch detection
""" Batch-detection demo """
# Specifying the folder path containing multiple images for batch detection
-tgt_folder_path = os.path.join(".","demo_data","imgs")
+tgt_folder_path = os.path.join(".", "demo_data", "imgs")
# Performing batch detection on the images
det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16)
-clf_results = classification_model.batch_image_classification(det_results=det_results, id_strip=tgt_folder_path)
+clf_results = classification_model.batch_image_classification(
+ det_results=det_results, id_strip=tgt_folder_path
+)
# %%
merged_results = det_results.copy()
@@ -91,33 +102,43 @@
clf_labels = []
for i, (xyxy, det_id) in enumerate(zip(det["detections"].xyxy, det["detections"].class_id)):
if det_id == 0:
- clf_labels.append("{} {:.2f}".format(clf_results[clf_counter]["prediction"] if clf_results[clf_counter]["confidence"] > clf_conf_thres else "Unknown",
- clf_results[clf_counter]["confidence"]))
+ clf_labels.append(
+ "{} {:.2f}".format(
+ clf_results[clf_counter]["prediction"]
+ if clf_results[clf_counter]["confidence"] > clf_conf_thres
+ else "Unknown",
+ clf_results[clf_counter]["confidence"],
+ )
+ )
else:
clf_labels.append(det["labels"][i])
clf_counter += 1
det["labels"] = clf_labels
-#%% Output to annotated images
+# %% Output to annotated images
# Saving the batch detection results as annotated images
pw_utils.save_detection_images(merged_results, "batch_output", tgt_folder_path, overwrite=False)
-#%% Output to cropped images
+# %% Output to cropped images
# Saving the detected objects as cropped images
pw_utils.save_crop_images(merged_results, "crop_output", tgt_folder_path, overwrite=False)
-#%% Output to JSON results
+# %% Output to JSON results
# Saving the detection results in JSON format
-pw_utils.save_detection_classification_json(det_results=det_results,
- clf_results=clf_results,
- det_categories=detection_model.CLASS_NAMES,
- clf_categories=classification_model.CLASS_NAMES,
- output_path=os.path.join(".","batch_output_classification.json"))
+pw_utils.save_detection_classification_json(
+ det_results=det_results,
+ clf_results=clf_results,
+ det_categories=detection_model.CLASS_NAMES,
+ clf_categories=classification_model.CLASS_NAMES,
+ output_path=os.path.join(".", "batch_output_classification.json"),
+)
# %%
# Saving the detection results in timelapse JSON format
-pw_utils.save_detection_classification_timelapse_json(det_results=det_results,
- clf_results=clf_results,
- det_categories=detection_model.CLASS_NAMES,
- clf_categories=classification_model.CLASS_NAMES,
- output_path=os.path.join(".","batch_output_classification_timelapse.json"))
\ No newline at end of file
+pw_utils.save_detection_classification_timelapse_json(
+ det_results=det_results,
+ clf_results=clf_results,
+ det_categories=detection_model.CLASS_NAMES,
+ clf_categories=classification_model.CLASS_NAMES,
+ output_path=os.path.join(".", "batch_output_classification_timelapse.json"),
+)
diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py
index 58c5d7962..6f5f70ce9 100644
--- a/demo/gradio_demo.py
+++ b/demo/gradio_demo.py
@@ -3,31 +3,33 @@
""" Gradio Demo for image detection"""
+import ast
+
# Importing necessary basic libraries and modules
import os
-# PyTorch imports
-import torch
-from torch.utils.data import DataLoader
-
-# Importing the model, dataset, transformations and utility functions from PytorchWildlife
-from PytorchWildlife.models import detection as pw_detection
-from PytorchWildlife import utils as pw_utils
-
# Importing basic libraries
import shutil
import time
-from PIL import Image
-import supervision as sv
-import gradio as gr
from zipfile import ZipFile
+
+import gradio as gr
import numpy as np
-import ast
+import supervision as sv
+
+# PyTorch imports
+import torch
+from PIL import Image
+from torch.utils.data import DataLoader
+
+from PytorchWildlife import utils as pw_utils
+from PytorchWildlife.data import datasets as pw_data
+from PytorchWildlife.data import transforms as pw_trans
# Importing the models, dataset, transformations, and utility functions from PytorchWildlife
+# Importing the model, dataset, transformations and utility functions from PytorchWildlife
from PytorchWildlife.models import classification as pw_classification
-from PytorchWildlife.data import transforms as pw_trans
-from PytorchWildlife.data import datasets as pw_data
+from PytorchWildlife.models import detection as pw_detection
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -36,15 +38,15 @@
box_annotator = sv.BoxAnnotator(thickness=4)
lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2)
# Create a temp folder
-os.makedirs(os.path.join("..","temp"), exist_ok=True) # ASK: Why do we need this?
+os.makedirs(os.path.join("..", "temp"), exist_ok=True) # ASK: Why do we need this?
# Initializing the detection and classification models
detection_model = None
classification_model = None
-
+
+
# Defining functions for different detection scenarios
def load_models(det, version, clf, wpath=None, wclass=None):
-
global detection_model, classification_model
if det != "None":
if det == "HerdNet General":
@@ -53,11 +55,17 @@ def load_models(det, version, clf, wpath=None, wclass=None):
detection_model = pw_detection.HerdNet(device=DEVICE, version="ennedi")
else:
if "mit" in version:
- detection_model = pw_detection.MegaDetectorV6MIT(device=DEVICE, pretrained=True, version=version)
+ detection_model = pw_detection.MegaDetectorV6MIT(
+ device=DEVICE, pretrained=True, version=version
+ )
elif "apache" in version:
- detection_model = pw_detection.MegaDetectorV6Apache(device=DEVICE, pretrained=True, version=version)
+ detection_model = pw_detection.MegaDetectorV6Apache(
+ device=DEVICE, pretrained=True, version=version
+ )
else:
- detection_model = pw_detection.__dict__[det](device=DEVICE, pretrained=True, version=version)
+ detection_model = pw_detection.__dict__[det](
+ device=DEVICE, pretrained=True, version=version
+ )
else:
detection_model = None
return "NO MODEL LOADED!!"
@@ -65,9 +73,11 @@ def load_models(det, version, clf, wpath=None, wclass=None):
if clf != "None":
# Create an exception for custom weights
if clf == "CustomWeights":
- if (wpath is not None) and (wclass is not None):
+ if (wpath is not None) and (wclass is not None):
wclass = ast.literal_eval(wclass)
- classification_model = pw_classification.__dict__[clf](weights=wpath, class_names=wclass, device=DEVICE)
+ classification_model = pw_classification.__dict__[clf](
+ weights=wpath, class_names=wclass, device=DEVICE
+ )
else:
classification_model = pw_classification.__dict__[clf](device=DEVICE, pretrained=True)
else:
@@ -92,25 +102,35 @@ def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index=
if detection_model.__class__.__name__.__contains__("HerdNet"):
annotator = dot_annotator
# Herdnet receives both clf and det confidence thresholds
- results_det = detection_model.single_image_detection(input_img,
- img_path=img_index,
- det_conf_thres=det_conf_thres,
- clf_conf_thres=clf_conf_thres)
+ results_det = detection_model.single_image_detection(
+ input_img,
+ img_path=img_index,
+ det_conf_thres=det_conf_thres,
+ clf_conf_thres=clf_conf_thres,
+ )
else:
annotator = box_annotator
- results_det = detection_model.single_image_detection(input_img,
- img_path=img_index,
- det_conf_thres = det_conf_thres)
-
+ results_det = detection_model.single_image_detection(
+ input_img, img_path=img_index, det_conf_thres=det_conf_thres
+ )
+
if classification_model is not None:
labels = []
- for i, (xyxy, det_id) in enumerate(zip(results_det["detections"].xyxy, results_det["detections"].class_id)):
+ for i, (xyxy, det_id) in enumerate(
+ zip(results_det["detections"].xyxy, results_det["detections"].class_id)
+ ):
# Only run classifier when detection class is animal
if det_id == 0:
cropped_image = sv.crop_image(image=input_img, xyxy=xyxy)
results_clf = classification_model.single_image_classification(cropped_image)
- labels.append("{} {:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown",
- results_clf["confidence"]))
+ labels.append(
+ "{} {:.2f}".format(
+ results_clf["prediction"]
+ if results_clf["confidence"] > clf_conf_thres
+ else "Unknown",
+ results_clf["confidence"],
+ )
+ )
else:
labels.append(results_det["labels"][i])
else:
@@ -126,9 +146,10 @@ def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index=
)
return annotated_img
+
def batch_detection(zip_file, timelapse, det_conf_thres):
"""Perform detection on a batch of images from a zip file and return path to results JSON.
-
+
Args:
zip_file (File): Zip file containing images.
det_conf_thres (float): Confidence threshold for detection.
@@ -139,7 +160,7 @@ def batch_detection(zip_file, timelapse, det_conf_thres):
json_save_path (str): Path to the JSON file containing detection results.
"""
# Clean the temp folder if it contains files
- extract_path = os.path.join("..","temp","zip_upload")
+ extract_path = os.path.join("..", "temp", "zip_upload")
if os.path.exists(extract_path):
shutil.rmtree(extract_path)
os.makedirs(extract_path)
@@ -149,53 +170,76 @@ def batch_detection(zip_file, timelapse, det_conf_thres):
zfile.extractall(extract_path)
# Check the contents of the extracted folder
extracted_files = os.listdir(extract_path)
-
+
if len(extracted_files) == 1 and os.path.isdir(os.path.join(extract_path, extracted_files[0])):
tgt_folder_path = os.path.join(extract_path, extracted_files[0])
else:
tgt_folder_path = extract_path
# If the detection model is HerdNet set batch_size to 1
if detection_model.__class__.__name__.__contains__("HerdNet"):
- det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=1, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path)
+ det_results = detection_model.batch_image_detection(
+ tgt_folder_path, batch_size=1, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path
+ )
else:
- det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path)
+ det_results = detection_model.batch_image_detection(
+ tgt_folder_path, batch_size=16, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path
+ )
if classification_model is not None:
clf_dataset = pw_data.DetectionCrops(
det_results,
transform=pw_trans.Classification_Inference_Transform(target_size=224),
- path_head=tgt_folder_path
+ path_head=tgt_folder_path,
+ )
+ clf_loader = DataLoader(
+ clf_dataset,
+ batch_size=32,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=4,
+ drop_last=False,
+ )
+ clf_results = classification_model.batch_image_classification(
+ clf_loader, id_strip=tgt_folder_path
)
- clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False,
- pin_memory=True, num_workers=4, drop_last=False)
- clf_results = classification_model.batch_image_classification(clf_loader, id_strip=tgt_folder_path)
if timelapse:
json_save_path = json_save_path.replace(".json", "_timelapse.json")
- pw_utils.save_detection_classification_timelapse_json(det_results=det_results,
- clf_results=clf_results,
- det_categories=detection_model.CLASS_NAMES,
- clf_categories=classification_model.CLASS_NAMES,
- output_path=json_save_path)
+ pw_utils.save_detection_classification_timelapse_json(
+ det_results=det_results,
+ clf_results=clf_results,
+ det_categories=detection_model.CLASS_NAMES,
+ clf_categories=classification_model.CLASS_NAMES,
+ output_path=json_save_path,
+ )
else:
- pw_utils.save_detection_classification_json(det_results=det_results,
- clf_results=clf_results,
- det_categories=detection_model.CLASS_NAMES,
- clf_categories=classification_model.CLASS_NAMES,
- output_path=json_save_path)
+ pw_utils.save_detection_classification_json(
+ det_results=det_results,
+ clf_results=clf_results,
+ det_categories=detection_model.CLASS_NAMES,
+ clf_categories=classification_model.CLASS_NAMES,
+ output_path=json_save_path,
+ )
else:
if timelapse:
json_save_path = json_save_path.replace(".json", "_timelapse.json")
- pw_utils.save_detection_timelapse_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
+ pw_utils.save_detection_timelapse_json(
+ det_results, json_save_path, categories=detection_model.CLASS_NAMES
+ )
elif detection_model.__class__.__name__.__contains__("HerdNet"):
- pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
- else:
- pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
+ pw_utils.save_detection_json_as_dots(
+ det_results, json_save_path, categories=detection_model.CLASS_NAMES
+ )
+ else:
+ pw_utils.save_detection_json(
+ det_results, json_save_path, categories=detection_model.CLASS_NAMES
+ )
return json_save_path
+
def batch_path_detection(tgt_folder_path, det_conf_thres):
"""Perform detection on a batch of images from a zip file and return path to results JSON.
-
+
Args:
tgt_folder_path (str): path to the folder containing the images.
det_conf_thres (float): Confidence threshold for detection.
@@ -204,36 +248,48 @@ def batch_path_detection(tgt_folder_path, det_conf_thres):
"""
json_save_path = os.path.join(tgt_folder_path, "results.json")
- det_results = detection_model.batch_image_detection(tgt_folder_path, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path)
+ det_results = detection_model.batch_image_detection(
+ tgt_folder_path, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path
+ )
if detection_model.__class__.__name__.__contains__("HerdNet"):
- pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
+ pw_utils.save_detection_json_as_dots(
+ det_results, json_save_path, categories=detection_model.CLASS_NAMES
+ )
else:
- pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
+ pw_utils.save_detection_json(
+ det_results, json_save_path, categories=detection_model.CLASS_NAMES
+ )
return json_save_path
def video_detection(video, det_conf_thres, clf_conf_thres, target_fps, codec):
"""Perform detection on a video and return path to processed video.
-
+
Args:
video (str): Video source path.
det_conf_thres (float): Confidence threshold for detection.
clf_conf_thres (float): Confidence threshold for classification.
"""
+
def callback(frame, index):
- annotated_frame = single_image_detection(frame,
- img_index=index,
- det_conf_thres=det_conf_thres,
- clf_conf_thres=clf_conf_thres)
- return annotated_frame
-
- target_path = os.path.join("..","temp","video_detection.mp4")
- pw_utils.process_video(source_path=video, target_path=target_path,
- callback=callback, target_fps=int(target_fps), codec=codec)
+ annotated_frame = single_image_detection(
+ frame, img_index=index, det_conf_thres=det_conf_thres, clf_conf_thres=clf_conf_thres
+ )
+ return annotated_frame
+
+ target_path = os.path.join("..", "temp", "video_detection.mp4")
+ pw_utils.process_video(
+ source_path=video,
+ target_path=target_path,
+ callback=callback,
+ target_fps=int(target_fps),
+ codec=codec,
+ )
return target_path
+
# Building Gradio UI
with gr.Blocks() as demo:
@@ -243,37 +299,72 @@ def callback(frame, index):
["None", "MegaDetectorV5", "MegaDetectorV6", "HerdNet General", "HerdNet Ennedi"],
label="Detection model",
info="Will add more detection models!",
- value="None" # Default
+ value="None", # Default
)
- det_version = gr.Dropdown(
- ["None"],
- label="Model version",
+ det_version = gr.Dropdown(
+ ["None"],
+ label="Model version",
info="Select the version of the model",
value="None",
)
-
+
with gr.Column():
clf_drop = gr.Dropdown(
- ["None", "AI4GOpossum", "AI4GAmazonRainforest", "AI4GSnapshotSerengeti", "CustomWeights"],
+ [
+ "None",
+ "AI4GOpossum",
+ "AI4GAmazonRainforest",
+ "AI4GSnapshotSerengeti",
+ "CustomWeights",
+ ],
interactive=True,
label="Classification model",
info="Will add more classification models!",
visible=False,
- value="None"
+ value="None",
+ )
+ custom_weights_path = gr.Textbox(
+ label="Custom Weights Path",
+ visible=False,
+ interactive=True,
+ placeholder="./weights/my_weight.pt",
+ )
+ custom_weights_class = gr.Textbox(
+ label="Custom Weights Class",
+ visible=False,
+ interactive=True,
+ placeholder="{1:'ocelot', 2:'cow', 3:'bear'}",
)
- custom_weights_path = gr.Textbox(label="Custom Weights Path", visible=False, interactive=True, placeholder="./weights/my_weight.pt")
- custom_weights_class = gr.Textbox(label="Custom Weights Class", visible=False, interactive=True, placeholder="{1:'ocelot', 2:'cow', 3:'bear'}")
load_but = gr.Button("Load Models!")
load_out = gr.Text("NO MODEL LOADED!!", label="Loaded models:")
-
- def update_ui_elements(det_model):
- if det_model == "MegaDetectorV6":
- return gr.Dropdown(choices=["MDV6-yolov9-c", "MDV6-yolov9-e", "MDV6-yolov10-c", "MDV6-yolov10-e", "MDV6-rtdetr-c", "MDV6-yolov9-c-mit", "MDV6-yolov9-e-mit", "MDV6-rtdetr-c-apache", "MDV6-rtdetr-e-apache"], interactive=True, label="Model version", value="MDV6-yolov9e"), gr.update(visible=True)
- elif det_model == "MegaDetectorV5":
- return gr.Dropdown(choices=["a", "b"], interactive=True, label="Model version", value="a"), gr.update(visible=True)
+
+ def update_ui_elements(det_model):
+ if det_model == "MegaDetectorV6":
+ return gr.Dropdown(
+ choices=[
+ "MDV6-yolov9-c",
+ "MDV6-yolov9-e",
+ "MDV6-yolov10-c",
+ "MDV6-yolov10-e",
+ "MDV6-rtdetr-c",
+ "MDV6-yolov9-c-mit",
+ "MDV6-yolov9-e-mit",
+ "MDV6-rtdetr-c-apache",
+ "MDV6-rtdetr-e-apache",
+ ],
+ interactive=True,
+ label="Model version",
+                value="MDV6-yolov9-e",
+ ), gr.update(visible=True)
+ elif det_model == "MegaDetectorV5":
+ return gr.Dropdown(
+ choices=["a", "b"], interactive=True, label="Model version", value="a"
+ ), gr.update(visible=True)
else:
- return gr.Dropdown(choices=["None"], interactive=True, label="Model version", value="None"), gr.update(value="None", visible=False)
-
+ return gr.Dropdown(
+ choices=["None"], interactive=True, label="Model version", value="None"
+ ), gr.update(value="None", visible=False)
+
det_drop.change(update_ui_elements, det_drop, [det_version, clf_drop])
def toggle_textboxes(model):
@@ -281,20 +372,18 @@ def toggle_textboxes(model):
return gr.update(visible=True), gr.update(visible=True)
else:
return gr.update(visible=False), gr.update(visible=False)
-
- clf_drop.change(
- toggle_textboxes,
- clf_drop,
- [custom_weights_path, custom_weights_class]
- )
+
+ clf_drop.change(toggle_textboxes, clf_drop, [custom_weights_path, custom_weights_class])
with gr.Tab("Single Image Process"):
with gr.Row():
with gr.Column():
sgl_in = gr.Image(type="pil")
sgl_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2)
- sgl_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7, visible=True)
- sgl_out = gr.Image()
+ sgl_conf_sl_clf = gr.Slider(
+ 0, 1, label="Classification Confidence Threshold", value=0.7, visible=True
+ )
+ sgl_out = gr.Image()
sgl_but = gr.Button("Detect Animals!")
with gr.Tab("Folder Separation"):
with gr.Row():
@@ -306,9 +395,18 @@ def toggle_textboxes(model):
bth_out2 = gr.File(label="Detection Results JSON.", height=200)
with gr.Column():
process_files_button = gr.Button("Separate files")
- process_result = gr.Text("Click on 'Separate files' once you see the JSON file", label="Separated files:")
- process_btn.click(batch_path_detection, inputs=[inp_path, bth_conf_fs], outputs=bth_out2)
- process_files_button.click(pw_utils.detection_folder_separation, inputs=[bth_out2, inp_path, out_path, bth_conf_fs], outputs=process_result)
+ process_result = gr.Text(
+ "Click on 'Separate files' once you see the JSON file",
+ label="Separated files:",
+ )
+ process_btn.click(
+ batch_path_detection, inputs=[inp_path, bth_conf_fs], outputs=bth_out2
+ )
+ process_files_button.click(
+ pw_utils.detection_folder_separation,
+ inputs=[bth_out2, inp_path, out_path, bth_conf_fs],
+ outputs=process_result,
+ )
with gr.Tab("Batch Image Process"):
with gr.Row():
with gr.Column():
@@ -323,28 +421,42 @@ def toggle_textboxes(model):
with gr.Column():
vid_in = gr.Video(label="Upload a video.")
vid_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2)
- vid_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7)
+ vid_conf_sl_clf = gr.Slider(
+ 0, 1, label="Classification Confidence Threshold", value=0.7
+ )
vid_fr = gr.Dropdown([5, 10, 30], label="Output video framerate", value=30)
vid_enc = gr.Dropdown(
["mp4v", "avc1"],
label="Video encoder",
info="mp4v is default, av1c is faster (needs conda install opencv)",
- value="mp4v"
- )
+ value="mp4v",
+ )
vid_out = gr.Video()
vid_but = gr.Button("Detect Animals!")
-
+
# Show timelapsed checkbox only when detection model is not HerdNet
det_drop.change(
- lambda model: gr.update(visible=True) if "HerdNet" not in model else gr.update(visible=False),
+ lambda model: gr.update(visible=True)
+ if "HerdNet" not in model
+ else gr.update(visible=False),
det_drop,
- [chck_timelapse]
+ [chck_timelapse],
+ )
+
+ load_but.click(
+ load_models,
+ inputs=[det_drop, det_version, clf_drop, custom_weights_path, custom_weights_class],
+ outputs=load_out,
+ )
+ sgl_but.click(
+ single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out
)
-
- load_but.click(load_models, inputs=[det_drop, det_version, clf_drop, custom_weights_path, custom_weights_class], outputs=load_out)
- sgl_but.click(single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out)
bth_but.click(batch_detection, inputs=[bth_in, chck_timelapse, bth_conf_sl], outputs=bth_out)
- vid_but.click(video_detection, inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr, vid_enc], outputs=vid_out)
+ vid_but.click(
+ video_detection,
+ inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr, vid_enc],
+ outputs=vid_out,
+ )
if __name__ == "__main__":
demo.launch(share=True)
diff --git a/demo/image_demo.py b/demo/image_demo.py
index 7e65bf5da..a9d347ef7 100644
--- a/demo/image_demo.py
+++ b/demo/image_demo.py
@@ -3,25 +3,29 @@
""" Demo for image detection"""
-#%%
+# %%
# Importing necessary basic libraries and modules
import os
-# PyTorch imports
+
+# PyTorch imports
import torch
-#%%
+from PytorchWildlife import utils as pw_utils
+
+# %%
# Importing the model, dataset, transformations and utility functions from PytorchWildlife
from PytorchWildlife.models import detection as pw_detection
-from PytorchWildlife import utils as pw_utils
-#%%
+# %%
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-#%%
+# %%
# Initializing the MegaDetectorV6 model for image detection
# Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c
-detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e")
+detection_model = pw_detection.MegaDetectorV6(
+ device=DEVICE, pretrained=True, version="MDV6-yolov10-e"
+)
# Uncomment the following line to use MegaDetectorV6 with yolo v9 MIT weights
# Valid versions are MDV6-mit-yolov9-c, MDV6-mit-yolov9-e
@@ -34,46 +38,52 @@
# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6
# detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a")
-#%% Single image detection
+# %% Single image detection
# Specifying the path to the target image TODO: Allow argparsing
-tgt_img_path = os.path.join(".","demo_data","imgs","10050028_0.JPG")
+tgt_img_path = os.path.join(".", "demo_data", "imgs", "10050028_0.JPG")
# Performing the detection on the single image
results = detection_model.single_image_detection(tgt_img_path)
-# Saving the detection results
-pw_utils.save_detection_images(results, os.path.join(".","demo_output"), overwrite=False)
+# Saving the detection results
+pw_utils.save_detection_images(results, os.path.join(".", "demo_output"), overwrite=False)
# Saving the detected objects as cropped images
-pw_utils.save_crop_images(results, os.path.join(".","crop_output"), overwrite=False)
+pw_utils.save_crop_images(results, os.path.join(".", "crop_output"), overwrite=False)
-#%% Batch detection
+# %% Batch detection
""" Batch-detection demo """
# Specifying the folder path containing multiple images for batch detection
-tgt_folder_path = os.path.join(".","demo_data","imgs")
+tgt_folder_path = os.path.join(".", "demo_data", "imgs")
# Performing batch detection on the images
results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16)
-#%% Output to annotated images
+# %% Output to annotated images
# Saving the batch detection results as annotated images
pw_utils.save_detection_images(results, "batch_output", tgt_folder_path, overwrite=False)
-#%% Output to cropped images
+# %% Output to cropped images
# Saving the detected objects as cropped images
pw_utils.save_crop_images(results, "crop_output", tgt_folder_path, overwrite=False)
-#%% Output to JSON results
+# %% Output to JSON results
# Saving the detection results in JSON format
-pw_utils.save_detection_json(results, os.path.join(".","batch_output.json"),
- categories=detection_model.CLASS_NAMES,
- exclude_category_ids=[], # Category IDs can be found in the definition of each model.
- exclude_file_path=None)
+pw_utils.save_detection_json(
+ results,
+ os.path.join(".", "batch_output.json"),
+ categories=detection_model.CLASS_NAMES,
+ exclude_category_ids=[], # Category IDs can be found in the definition of each model.
+ exclude_file_path=None,
+)
# Saving the detection results in timelapse JSON format
-pw_utils.save_detection_timelapse_json(results, os.path.join(".","batch_output_timelapse.json"),
- categories=detection_model.CLASS_NAMES,
- exclude_category_ids=[], # Category IDs can be found in the definition of each model.
- exclude_file_path=tgt_folder_path,
- info={"detector": "MegaDetectorV6"})
\ No newline at end of file
+pw_utils.save_detection_timelapse_json(
+ results,
+ os.path.join(".", "batch_output_timelapse.json"),
+ categories=detection_model.CLASS_NAMES,
+ exclude_category_ids=[], # Category IDs can be found in the definition of each model.
+ exclude_file_path=tgt_folder_path,
+ info={"detector": "MegaDetectorV6"},
+)
diff --git a/demo/image_demo_herdnet.py b/demo/image_demo_herdnet.py
index d8837caf2..37c0f85eb 100644
--- a/demo/image_demo_herdnet.py
+++ b/demo/image_demo_herdnet.py
@@ -2,51 +2,62 @@
# Licensed under the MIT License.
""" Demo for Herdnet image detection"""
-#%%
+# %%
# Importing necessary basic libraries and modules
import os
-# PyTorch imports
+
+# PyTorch imports
import torch
-#%%
+from PytorchWildlife import utils as pw_utils
+
+# %%
# Importing the model, dataset, transformations and utility functions from PytorchWildlife
from PytorchWildlife.models import detection as pw_detection
-from PytorchWildlife import utils as pw_utils
-#%%
+# %%
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-#%%
+# %%
# Initializing the HerdNet model for image detection
detection_model = pw_detection.HerdNet(device=DEVICE)
# If you want to use ennedi dataset weights, you can use the following line:
# detection_model = pw_detection.HerdNet(device=DEVICE, version="ennedi")
-#%% Single image detection
-img_path = os.path.join(".","demo_data","herdnet_imgs","S_11_05_16_DSC01556.JPG")
+# %% Single image detection
+img_path = os.path.join(".", "demo_data", "herdnet_imgs", "S_11_05_16_DSC01556.JPG")
# Performing the detection on the single image
results = detection_model.single_image_detection(img=img_path)
-#%% Output to annotated images
+# %% Output to annotated images
# Saving the batch detection results as annotated images
-pw_utils.save_detection_images_dots(results, os.path.join(".","herdnet_demo_output"), overwrite=False)
+pw_utils.save_detection_images_dots(
+ results, os.path.join(".", "herdnet_demo_output"), overwrite=False
+)
-#%% Batch image detection
+# %% Batch image detection
""" Batch-detection demo """
# Specifying the folder path containing multiple images for batch detection
-folder_path = os.path.join(".","demo_data","herdnet_imgs")
+folder_path = os.path.join(".", "demo_data", "herdnet_imgs")
# Performing batch detection on the images
-results = detection_model.batch_image_detection(folder_path, batch_size=1) # NOTE: Only use batch size 1 because each image is divided into patches and this batch is enough.
+results = detection_model.batch_image_detection(
+ folder_path, batch_size=1
+) # NOTE: Only use batch size 1 because each image is divided into patches and this batch is enough.
-#%% Output to annotated images
+# %% Output to annotated images
# Saving the batch detection results as annotated images
-pw_utils.save_detection_images_dots(results, "herdnet_demo_batch_output", folder_path, overwrite=False)
+pw_utils.save_detection_images_dots(
+ results, "herdnet_demo_batch_output", folder_path, overwrite=False
+)
# Saving the detection results in JSON format
-pw_utils.save_detection_json_as_dots(results, os.path.join(".","herdnet_demo_batch_output.json"),
- categories=detection_model.CLASS_NAMES,
- exclude_category_ids=[], # Category IDs can be found in the definition of each model.
- exclude_file_path=None)
+pw_utils.save_detection_json_as_dots(
+ results,
+ os.path.join(".", "herdnet_demo_batch_output.json"),
+ categories=detection_model.CLASS_NAMES,
+ exclude_category_ids=[], # Category IDs can be found in the definition of each model.
+ exclude_file_path=None,
+)
diff --git a/demo/image_separation_demo.py b/demo/image_separation_demo.py
index 962287601..10aee3e89 100644
--- a/demo/image_separation_demo.py
+++ b/demo/image_separation_demo.py
@@ -3,31 +3,50 @@
""" Demo for image separation between positive and negative detections"""
-#%%
+# %%
# Importing necessary basic libraries and modules
import argparse
import os
+
import torch
-# PyTorch imports
-from PytorchWildlife.models import detection as pw_detection
from PytorchWildlife import utils as pw_utils
-#%% Argument parsing
+# PyTorch imports
+from PytorchWildlife.models import detection as pw_detection
+
+# %% Argument parsing
parser = argparse.ArgumentParser(description="Batch image detection and separation")
-parser.add_argument('--image_folder', type=str, default=os.path.join(".","demo_data","imgs"), help='Folder path containing images for detection')
-parser.add_argument('--output_path', type=str, default='folder_separation', help='Path where the outputs will be saved')
-parser.add_argument('--threshold', type=float, default='0.2', help='Confidence threshold to consider a detection as positive')
+parser.add_argument(
+ "--image_folder",
+ type=str,
+ default=os.path.join(".", "demo_data", "imgs"),
+ help="Folder path containing images for detection",
+)
+parser.add_argument(
+ "--output_path",
+ type=str,
+ default="folder_separation",
+ help="Path where the outputs will be saved",
+)
+parser.add_argument(
+ "--threshold",
+ type=float,
+ default="0.2",
+ help="Confidence threshold to consider a detection as positive",
+)
args = parser.parse_args()
-#%%
+# %%
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-#%%
+# %%
# Initializing the MegaDetectorV6 model for image detection
# Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c
-detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e")
+detection_model = pw_detection.MegaDetectorV6(
+ device=DEVICE, pretrained=True, version="MDV6-yolov10-e"
+)
# Uncomment the following line to use MegaDetectorV6 with yolo v9 MIT weights
# Valid versions are MDV6-mit-yolov9-c, MDV6-mit-yolov9-e
@@ -40,20 +59,23 @@
# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6
# detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a")
-#%% Batch detection
+# %% Batch detection
""" Batch-detection demo """
# Performing batch detection on the images
results = detection_model.batch_image_detection(args.image_folder, batch_size=16)
-#%% Output to JSON results
+# %% Output to JSON results
# Saving the detection results in JSON format
os.makedirs(args.output_path, exist_ok=True)
json_file = os.path.join(args.output_path, "detection_results.json")
-pw_utils.save_detection_json(results, json_file,
- categories=detection_model.CLASS_NAMES,
- exclude_category_ids=[], # Category IDs can be found in the definition of each model.
- exclude_file_path=args.image_folder)
+pw_utils.save_detection_json(
+ results,
+ json_file,
+ categories=detection_model.CLASS_NAMES,
+ exclude_category_ids=[], # Category IDs can be found in the definition of each model.
+ exclude_file_path=args.image_folder,
+)
# Separate the positive and negative detections through file copying:
-pw_utils.detection_folder_separation(json_file, args.image_folder, args.output_path, args.threshold)
\ No newline at end of file
+pw_utils.detection_folder_separation(json_file, args.image_folder, args.output_path, args.threshold)
diff --git a/demo/video_demo.py b/demo/video_demo.py
index 3def5c9c9..98575522f 100644
--- a/demo/video_demo.py
+++ b/demo/video_demo.py
@@ -2,31 +2,36 @@
# Licensed under the MIT License.
""" Video detection demo """
-#%%
+import os
+
+# %%
# Importing necessary basic libraries and modules
import numpy as np
import supervision as sv
-#%%
+# %%
# PyTorch imports for tensor operations
import torch
-import os
-#%%
-# Importing the models, transformations, and utility functions from PytorchWildlife
-from PytorchWildlife.models import detection as pw_detection
-from PytorchWildlife.models import classification as pw_classification
+
from PytorchWildlife import utils as pw_utils
-#%%
+# %%
+# Importing the models, transformations, and utility functions from PytorchWildlife
+from PytorchWildlife.models import classification as pw_classification
+from PytorchWildlife.models import detection as pw_detection
+
+# %%
# Setting the device to use for computations ('cuda' indicates GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-SOURCE_VIDEO_PATH = os.path.join(".","demo_data","videos","opossum_example.MP4")
-TARGET_VIDEO_PATH = os.path.join(".","demo_data","videos","opossum_example_processed.MP4")
+SOURCE_VIDEO_PATH = os.path.join(".", "demo_data", "videos", "opossum_example.MP4")
+TARGET_VIDEO_PATH = os.path.join(".", "demo_data", "videos", "opossum_example_processed.MP4")
-#%%
+# %%
# Initializing the MegaDetectorV6 model for image detection
# Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c
-detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e")
+detection_model = pw_detection.MegaDetectorV6(
+ device=DEVICE, pretrained=True, version="MDV6-yolov10-e"
+)
# Uncomment the following line to use MegaDetectorV6 with yolo v9 MIT weights
# Valid versions are MDV6-mit-yolov9-c, MDV6-mit-yolov9-e
@@ -39,27 +44,28 @@
# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6
# detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a")
-#%%
+# %%
# Initializing the model for image classification
classification_model = pw_classification.AI4GOpossum(device=DEVICE, pretrained=True)
-#%%
+# %%
# Initializing a box annotator for visualizing detections
box_annotator = sv.BoxAnnotator(thickness=4)
lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2)
+
def callback(frame: np.ndarray, index: int) -> np.ndarray:
"""
Callback function to process each video frame for detection and classification.
-
+
Parameters:
- frame (np.ndarray): Video frame as a numpy array.
- index (int): Frame index.
-
+
Returns:
annotated_frame (np.ndarray): Annotated video frame.
"""
-
+
results_det = detection_model.single_image_detection(frame, img_path=index)
labels = []
@@ -77,8 +83,11 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
detections=results_det["detections"],
labels=labels,
)
-
- return annotated_frame
+
+ return annotated_frame
+
# Processing the video and saving the result with annotated detections and classifications
-pw_utils.process_video(source_path=SOURCE_VIDEO_PATH, target_path=TARGET_VIDEO_PATH, callback=callback, target_fps=10)
+pw_utils.process_video(
+ source_path=SOURCE_VIDEO_PATH, target_path=TARGET_VIDEO_PATH, callback=callback, target_fps=10
+)
diff --git a/docs/base/data/datasets.md b/docs/base/data/datasets.md
index a16a16c8f..74a1ed2d2 100644
--- a/docs/base/data/datasets.md
+++ b/docs/base/data/datasets.md
@@ -1,3 +1,3 @@
# Datasets Module
-::: PytorchWildlife.data.datasets
\ No newline at end of file
+::: PytorchWildlife.data.datasets
diff --git a/docs/base/data/transforms.md b/docs/base/data/transforms.md
index b9b264689..7e12c941f 100644
--- a/docs/base/data/transforms.md
+++ b/docs/base/data/transforms.md
@@ -1,3 +1,3 @@
# Transforms Module
-::: PytorchWildlife.data.transforms
\ No newline at end of file
+::: PytorchWildlife.data.transforms
diff --git a/docs/base/models/classification/base_classifier.md b/docs/base/models/classification/base_classifier.md
index 6f729c9b1..38abedc78 100644
--- a/docs/base/models/classification/base_classifier.md
+++ b/docs/base/models/classification/base_classifier.md
@@ -1,3 +1,3 @@
# Base Classifier
-::: PytorchWildlife.models.classification.base_classifier
\ No newline at end of file
+::: PytorchWildlife.models.classification.base_classifier
diff --git a/docs/base/models/classification/resnet_base/amazon.md b/docs/base/models/classification/resnet_base/amazon.md
index 520bfb778..a299d9376 100644
--- a/docs/base/models/classification/resnet_base/amazon.md
+++ b/docs/base/models/classification/resnet_base/amazon.md
@@ -1,3 +1,3 @@
# Amazon
-::: PytorchWildlife.models.classification.resnet_base.amazon
\ No newline at end of file
+::: PytorchWildlife.models.classification.resnet_base.amazon
diff --git a/docs/base/models/classification/resnet_base/base_classifier.md b/docs/base/models/classification/resnet_base/base_classifier.md
index 68c1b339a..1bc9a62d0 100644
--- a/docs/base/models/classification/resnet_base/base_classifier.md
+++ b/docs/base/models/classification/resnet_base/base_classifier.md
@@ -1,3 +1,3 @@
# ResNet Base
-::: PytorchWildlife.models.classification.resnet_base.base_classifier
\ No newline at end of file
+::: PytorchWildlife.models.classification.resnet_base.base_classifier
diff --git a/docs/base/models/classification/resnet_base/custom_weights.md b/docs/base/models/classification/resnet_base/custom_weights.md
index e41b4e85e..b594cc188 100644
--- a/docs/base/models/classification/resnet_base/custom_weights.md
+++ b/docs/base/models/classification/resnet_base/custom_weights.md
@@ -1,3 +1,3 @@
# Custom Weights
-::: PytorchWildlife.models.classification.resnet_base.custom_weights
\ No newline at end of file
+::: PytorchWildlife.models.classification.resnet_base.custom_weights
diff --git a/docs/base/models/classification/resnet_base/opossum.md b/docs/base/models/classification/resnet_base/opossum.md
index 4f6b11f1c..54c6d0d52 100644
--- a/docs/base/models/classification/resnet_base/opossum.md
+++ b/docs/base/models/classification/resnet_base/opossum.md
@@ -1,3 +1,3 @@
# Opossum
-::: PytorchWildlife.models.classification.resnet_base.opossum
\ No newline at end of file
+::: PytorchWildlife.models.classification.resnet_base.opossum
diff --git a/docs/base/models/classification/resnet_base/serengeti.md b/docs/base/models/classification/resnet_base/serengeti.md
index 6f7a3760b..76c659c63 100644
--- a/docs/base/models/classification/resnet_base/serengeti.md
+++ b/docs/base/models/classification/resnet_base/serengeti.md
@@ -1,3 +1,3 @@
# Serengeti
-::: PytorchWildlife.models.classification.resnet_base.serengeti
\ No newline at end of file
+::: PytorchWildlife.models.classification.resnet_base.serengeti
diff --git a/docs/base/models/classification/timm_base/DFNE.md b/docs/base/models/classification/timm_base/DFNE.md
index 01f11fff9..17be1fa54 100644
--- a/docs/base/models/classification/timm_base/DFNE.md
+++ b/docs/base/models/classification/timm_base/DFNE.md
@@ -1,3 +1,3 @@
# DFNE
-::: PytorchWildlife.models.classification.timm_base.DFNE
\ No newline at end of file
+::: PytorchWildlife.models.classification.timm_base.DFNE
diff --git a/docs/base/models/classification/timm_base/Deepfaune.md b/docs/base/models/classification/timm_base/Deepfaune.md
index a943f5900..fe7ea756a 100644
--- a/docs/base/models/classification/timm_base/Deepfaune.md
+++ b/docs/base/models/classification/timm_base/Deepfaune.md
@@ -1,3 +1,3 @@
# Deepfaune
-::: PytorchWildlife.models.classification.timm_base.Deepfaune
\ No newline at end of file
+::: PytorchWildlife.models.classification.timm_base.Deepfaune
diff --git a/docs/base/models/classification/timm_base/base_classifier.md b/docs/base/models/classification/timm_base/base_classifier.md
index ef03c85cb..59905e415 100644
--- a/docs/base/models/classification/timm_base/base_classifier.md
+++ b/docs/base/models/classification/timm_base/base_classifier.md
@@ -1,3 +1,3 @@
# Timm Base
-::: PytorchWildlife.models.classification.timm_base.base_classifier
\ No newline at end of file
+::: PytorchWildlife.models.classification.timm_base.base_classifier
diff --git a/docs/base/models/detection/base_detector.md b/docs/base/models/detection/base_detector.md
index 7fbb5f9e0..6c0a74140 100644
--- a/docs/base/models/detection/base_detector.md
+++ b/docs/base/models/detection/base_detector.md
@@ -1,3 +1,3 @@
# Base Detector
-::: PytorchWildlife.models.detection.base_detector
\ No newline at end of file
+::: PytorchWildlife.models.detection.base_detector
diff --git a/docs/base/models/detection/herdnet.md b/docs/base/models/detection/herdnet.md
index da27b9618..b77ad508d 100644
--- a/docs/base/models/detection/herdnet.md
+++ b/docs/base/models/detection/herdnet.md
@@ -1,3 +1,3 @@
# HerdNet
-::: PytorchWildlife.models.detection.herdnet
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet
diff --git a/docs/base/models/detection/herdnet/animaloc/data/patches.md b/docs/base/models/detection/herdnet/animaloc/data/patches.md
index 44ddb7616..1be5d70eb 100644
--- a/docs/base/models/detection/herdnet/animaloc/data/patches.md
+++ b/docs/base/models/detection/herdnet/animaloc/data/patches.md
@@ -1,3 +1,3 @@
# Patches
-::: PytorchWildlife.models.detection.herdnet.animaloc.data.patches
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet.animaloc.data.patches
diff --git a/docs/base/models/detection/herdnet/animaloc/data/types.md b/docs/base/models/detection/herdnet/animaloc/data/types.md
index f1e42fae0..56dba606b 100644
--- a/docs/base/models/detection/herdnet/animaloc/data/types.md
+++ b/docs/base/models/detection/herdnet/animaloc/data/types.md
@@ -1,3 +1,3 @@
# Types
-::: PytorchWildlife.models.detection.herdnet.animaloc.data.types
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet.animaloc.data.types
diff --git a/docs/base/models/detection/herdnet/animaloc/eval/lmds.md b/docs/base/models/detection/herdnet/animaloc/eval/lmds.md
index 76682ff3a..663f6daa0 100644
--- a/docs/base/models/detection/herdnet/animaloc/eval/lmds.md
+++ b/docs/base/models/detection/herdnet/animaloc/eval/lmds.md
@@ -1,3 +1,3 @@
# LMDS
-::: PytorchWildlife.models.detection.herdnet.animaloc.eval.lmds
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet.animaloc.eval.lmds
diff --git a/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md b/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md
index 74939e373..325f5158f 100644
--- a/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md
+++ b/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md
@@ -1,3 +1,3 @@
# Stitchers
-::: PytorchWildlife.models.detection.herdnet.animaloc.eval.stitchers
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet.animaloc.eval.stitchers
diff --git a/docs/base/models/detection/herdnet/dla.md b/docs/base/models/detection/herdnet/dla.md
index bf718a19f..2422b7e63 100644
--- a/docs/base/models/detection/herdnet/dla.md
+++ b/docs/base/models/detection/herdnet/dla.md
@@ -1,3 +1,3 @@
# DLA
-::: PytorchWildlife.models.detection.herdnet.dla
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet.dla
diff --git a/docs/base/models/detection/herdnet/model.md b/docs/base/models/detection/herdnet/model.md
index c6ea9ce47..a4a5b685a 100644
--- a/docs/base/models/detection/herdnet/model.md
+++ b/docs/base/models/detection/herdnet/model.md
@@ -1,3 +1,3 @@
# Model
-::: PytorchWildlife.models.detection.herdnet.model
\ No newline at end of file
+::: PytorchWildlife.models.detection.herdnet.model
diff --git a/docs/base/models/detection/ultralytics_based/Deepfaune.md b/docs/base/models/detection/ultralytics_based/Deepfaune.md
index f432ea92f..43956a2ad 100644
--- a/docs/base/models/detection/ultralytics_based/Deepfaune.md
+++ b/docs/base/models/detection/ultralytics_based/Deepfaune.md
@@ -1,3 +1,3 @@
# Deepfaune
-::: PytorchWildlife.models.detection.ultralytics_based.Deepfaune
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.Deepfaune
diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv5.md b/docs/base/models/detection/ultralytics_based/megadetectorv5.md
index acc45ca84..e91a3d86b 100644
--- a/docs/base/models/detection/ultralytics_based/megadetectorv5.md
+++ b/docs/base/models/detection/ultralytics_based/megadetectorv5.md
@@ -1,3 +1,3 @@
# MegaDetector v5
-::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv5
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv5
diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv6.md b/docs/base/models/detection/ultralytics_based/megadetectorv6.md
index 62fc7b329..fd8db0851 100644
--- a/docs/base/models/detection/ultralytics_based/megadetectorv6.md
+++ b/docs/base/models/detection/ultralytics_based/megadetectorv6.md
@@ -1,3 +1,3 @@
# MegaDetector v6
-::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6
diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md b/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md
index afab79d5c..b67a56fc8 100644
--- a/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md
+++ b/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md
@@ -1,3 +1,3 @@
# MegaDetector v6 Distributed
-::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6_distributed
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6_distributed
diff --git a/docs/base/models/detection/ultralytics_based/yolov5_base.md b/docs/base/models/detection/ultralytics_based/yolov5_base.md
index 57366f9b2..ffbf8d2f4 100644
--- a/docs/base/models/detection/ultralytics_based/yolov5_base.md
+++ b/docs/base/models/detection/ultralytics_based/yolov5_base.md
@@ -1,3 +1,3 @@
# YOLOv5 Base
-::: PytorchWildlife.models.detection.ultralytics_based.yolov5_base
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.yolov5_base
diff --git a/docs/base/models/detection/ultralytics_based/yolov8_base.md b/docs/base/models/detection/ultralytics_based/yolov8_base.md
index b71b3ac0e..7661d844a 100644
--- a/docs/base/models/detection/ultralytics_based/yolov8_base.md
+++ b/docs/base/models/detection/ultralytics_based/yolov8_base.md
@@ -1,3 +1,3 @@
# YOLOv8 Base
-::: PytorchWildlife.models.detection.ultralytics_based.yolov8_base
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.yolov8_base
diff --git a/docs/base/models/detection/ultralytics_based/yolov8_distributed.md b/docs/base/models/detection/ultralytics_based/yolov8_distributed.md
index 12feadc67..ebda79a85 100644
--- a/docs/base/models/detection/ultralytics_based/yolov8_distributed.md
+++ b/docs/base/models/detection/ultralytics_based/yolov8_distributed.md
@@ -1,3 +1,3 @@
# YOLOv8 Distributed
-::: PytorchWildlife.models.detection.ultralytics_based.yolov8_distributed
\ No newline at end of file
+::: PytorchWildlife.models.detection.ultralytics_based.yolov8_distributed
diff --git a/docs/base/overview.md b/docs/base/overview.md
index d2833a1ef..648e36c68 100644
--- a/docs/base/overview.md
+++ b/docs/base/overview.md
@@ -34,4 +34,4 @@ from PytorchWildlife.models import classification, detection
from PytorchWildlife.utils import misc, post_process
```
-Refer to the specific submodule documentation for detailed usage instructions.
\ No newline at end of file
+Refer to the specific submodule documentation for detailed usage instructions.
diff --git a/docs/base/utils/misc.md b/docs/base/utils/misc.md
index 72139cdf8..aed3f6d5a 100644
--- a/docs/base/utils/misc.md
+++ b/docs/base/utils/misc.md
@@ -1 +1 @@
-::: PytorchWildlife.utils.misc
\ No newline at end of file
+::: PytorchWildlife.utils.misc
diff --git a/docs/base/utils/post_process.md b/docs/base/utils/post_process.md
index 02b513c1b..eea069bcc 100644
--- a/docs/base/utils/post_process.md
+++ b/docs/base/utils/post_process.md
@@ -1 +1 @@
-::: PytorchWildlife.utils.post_process
\ No newline at end of file
+::: PytorchWildlife.utils.post_process
diff --git a/docs/build_mkdocs.md b/docs/build_mkdocs.md
index 5f80d7041..24ab17adb 100644
--- a/docs/build_mkdocs.md
+++ b/docs/build_mkdocs.md
@@ -33,4 +33,3 @@ To build the MkDocs site locally, follow these steps:
5. **Exclude the `site/` Directory**:
The `site/` directory is automatically generated and should not be included in version control. It is already added to `.gitignore`.
-
diff --git a/docs/cite.md b/docs/cite.md
index d75ba03e3..aa832a5a2 100644
--- a/docs/cite.md
+++ b/docs/cite.md
@@ -21,4 +21,4 @@ Also, don't forget to cite our original paper for MegaDetector:
eprint={1907.06772},
archivePrefix={arXiv},
}
-```
\ No newline at end of file
+```
diff --git a/docs/contribute.md b/docs/contribute.md
index 5dae03407..02ce775ca 100644
--- a/docs/contribute.md
+++ b/docs/contribute.md
@@ -16,4 +16,4 @@ Thanks for your interest in collaborating on Pytorch-Wildlife! Here you can find
Thank you for helping us improve PytorchWildlife!
-We have adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [us](mailto:zhongqimiao@microsoft.com) with any additional questions or comments.
\ No newline at end of file
+We have adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [us](mailto:zhongqimiao@microsoft.com) with any additional questions or comments.
diff --git a/docs/contributors.md b/docs/contributors.md
index 262400844..f929476b8 100644
--- a/docs/contributors.md
+++ b/docs/contributors.md
@@ -1 +1 @@
-# In construction
\ No newline at end of file
+# In construction
diff --git a/docs/demo_and_ui/demo_data.md b/docs/demo_and_ui/demo_data.md
index d0a53015f..41cb1ec9a 100644
--- a/docs/demo_and_ui/demo_data.md
+++ b/docs/demo_and_ui/demo_data.md
@@ -1,2 +1,2 @@
# Prepare data for demos
-Before we run any demos, please download some demo data for our demo notebooks and webapps from this [link](https://zenodo.org/records/15376499/files/demo_data_base.zip?download=1) and decompress the data in the [demo folder](https://github.com/microsoft/CameraTraps/tree/main/demo).
\ No newline at end of file
+Before we run any demos, please download some demo data for our demo notebooks and webapps from this [link](https://zenodo.org/records/15376499/files/demo_data_base.zip?download=1) and decompress the data in the [demo folder](https://github.com/microsoft/CameraTraps/tree/main/demo).
diff --git a/docs/demo_and_ui/ecoassist.md b/docs/demo_and_ui/ecoassist.md
index a28fcc04b..ed71206a7 100644
--- a/docs/demo_and_ui/ecoassist.md
+++ b/docs/demo_and_ui/ecoassist.md
@@ -1,2 +1,2 @@
# Pytorch-Wildlife modelsa are available with AddaxAI (formerly EcoAssist)!
-We are thrilled to announce our collaboration with [AddaxAI](https://addaxdatascience.com/addaxai/#spp-models)---a powerful user interface software that enables users to directly load models from the PyTorch-Wildlife model zoo for image analysis on local computers. With AddaxAI, you can now utilize MegaDetectorV5 and the classification models---AI4GAmazonRainforest and AI4GOpossum---for automatic animal detection and identification, alongside a comprehensive suite of pre- and post-processing tools. This partnership aims to enhance the overall user experience with PyTorch-Wildlife models for a general audience. We will work closely to bring more features together for more efficient and effective wildlife analysis in the future. Please refer to their tutorials on how to use Pytorch-Wildlife models with AddaxAI.
\ No newline at end of file
+We are thrilled to announce our collaboration with [AddaxAI](https://addaxdatascience.com/addaxai/#spp-models)---a powerful user interface software that enables users to directly load models from the PyTorch-Wildlife model zoo for image analysis on local computers. With AddaxAI, you can now utilize MegaDetectorV5 and the classification models---AI4GAmazonRainforest and AI4GOpossum---for automatic animal detection and identification, alongside a comprehensive suite of pre- and post-processing tools. This partnership aims to enhance the overall user experience with PyTorch-Wildlife models for a general audience. We will work closely to bring more features together for more efficient and effective wildlife analysis in the future. Please refer to their tutorials on how to use Pytorch-Wildlife models with AddaxAI.
diff --git a/docs/demo_and_ui/gradio.md b/docs/demo_and_ui/gradio.md
index ee7f7b955..4fc83b482 100644
--- a/docs/demo_and_ui/gradio.md
+++ b/docs/demo_and_ui/gradio.md
@@ -24,4 +24,4 @@ pip uninstall opencv-python
conda install -c conda-forge opencv
```
-
\ No newline at end of file
+
diff --git a/docs/demo_and_ui/notebook.md b/docs/demo_and_ui/notebook.md
index 30f4f7405..fc1cf6657 100644
--- a/docs/demo_and_ui/notebook.md
+++ b/docs/demo_and_ui/notebook.md
@@ -4,4 +4,4 @@
To run our demo notebooks and scripts, please go to the `demo` folder and follow the instructions in our [installation page](../installation.md) on how to use jupyter notebooks.
### Online notebooks
-We are currently migrating our demo jupyter notebooks to this document page so you do not have to always run jupyter on your local machines. Please keep updates!
\ No newline at end of file
+We are currently migrating our demo jupyter notebooks to this document page so you do not have to always run jupyter on your local machines. Please stay tuned for updates!
diff --git a/docs/demo_and_ui/timelapse.md b/docs/demo_and_ui/timelapse.md
index 4b64fcbfb..5e6665a78 100644
--- a/docs/demo_and_ui/timelapse.md
+++ b/docs/demo_and_ui/timelapse.md
@@ -1,2 +1,2 @@
# Pytorch-Wildlife and TimeLapse
-Pytorch-Wildlife offers native output formats that are directly compatible with TimeLapse. We will provide more details on this in future releases! We will keep you posted!
\ No newline at end of file
+Pytorch-Wildlife offers native output formats that are directly compatible with TimeLapse. We will provide more details on this in future releases! We will keep you posted!
diff --git a/docs/fine_tuning_modules/detection/overview.md b/docs/fine_tuning_modules/detection/overview.md
index d4e4cbb9d..114041389 100644
--- a/docs/fine_tuning_modules/detection/overview.md
+++ b/docs/fine_tuning_modules/detection/overview.md
@@ -1,2 +1,2 @@
-# In Construction
\ No newline at end of file
+# In Construction
diff --git a/docs/fine_tuning_modules/overview.md b/docs/fine_tuning_modules/overview.md
index 262400844..f929476b8 100644
--- a/docs/fine_tuning_modules/overview.md
+++ b/docs/fine_tuning_modules/overview.md
@@ -1 +1 @@
-# In construction
\ No newline at end of file
+# In construction
diff --git a/docs/index.md b/docs/index.md
index eb5d80457..6e93e7f54 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -60,4 +60,3 @@ Please refer to our [installation guide](installation.md) for more installation
### Opossum ID with `MegaDetector` and `AI4GOpossum`

*Credits to the Agency for Regulation and Control of Biosecurity and Quarantine for Galápagos (ABG), Ecuador.*
-
diff --git a/docs/license.md b/docs/license.md
index c089e7740..915cd8b4e 100644
--- a/docs/license.md
+++ b/docs/license.md
@@ -13,4 +13,4 @@ In addition, since the **Pytorch-Wildlife** package is under MIT, all the utilit
```
--8<-- "LICENSE.md"
-```
\ No newline at end of file
+```
diff --git a/docs/megadetector.md b/docs/megadetector.md
index 11991fdb0..acdf38f03 100644
--- a/docs/megadetector.md
+++ b/docs/megadetector.md
@@ -1,3 +1,3 @@
--8<-- "megadetector.md"
-
\ No newline at end of file
+
diff --git a/docs/model_zoo/classifiers.md b/docs/model_zoo/classifiers.md
index f83de560b..5b78baaab 100644
--- a/docs/model_zoo/classifiers.md
+++ b/docs/model_zoo/classifiers.md
@@ -10,4 +10,4 @@
|Deepfaune-New-England|v1.0|CC0 1.0 Universal|Released|[Deepfaune-New-England](https://code.usgs.gov/vtcfwru/deepfaune-new-england)|
>[!TIP]
->Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`.
\ No newline at end of file
+>Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`.
diff --git a/docs/model_zoo/other_detectors.md b/docs/model_zoo/other_detectors.md
index 9223a523d..c3ebdf8af 100644
--- a/docs/model_zoo/other_detectors.md
+++ b/docs/model_zoo/other_detectors.md
@@ -7,4 +7,4 @@
|HerdNet-ennedi|ennedi|CC BY-NC-SA-4.0|Released|[Alexandre et. al. 2023](https://github.com/Alexandre-Delplanque/HerdNet)|
>[!TIP]
->Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`.
\ No newline at end of file
+>Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`.
diff --git a/docs/releases/past_releases.md b/docs/releases/past_releases.md
index 8fd371dab..3d6fc9b1f 100644
--- a/docs/releases/past_releases.md
+++ b/docs/releases/past_releases.md
@@ -6,4 +6,4 @@
- We will also make a new roadmap for 2025 in the next couple of updates.
- Special thanks to [José Díaz](https://github.com/jdiaz97) for his great cross-platform app, [BoquilaHUB](https://github.com/boquila/boquilahub), that is even working on ios and android! Please check his repo out! In the future, we will create a project gallery showcasing projects that use or are build upon Pytorch-Wildlife. If you want your projects to be included, please feel free to reach out to us on [](https://discord.gg/TeEVxzaYtm)
-
\ No newline at end of file
+
diff --git a/docs/releases/release_notes.md b/docs/releases/release_notes.md
index 217637604..f69cc1a85 100644
--- a/docs/releases/release_notes.md
+++ b/docs/releases/release_notes.md
@@ -18,4 +18,3 @@ classification_model = pw_classification.DeepfauneClassifier(device=DEVICE)
#### Deepfaune-New-England in Our Model Zoo Too!!
- Besides the original Deepfaune mode, there is another fine-tuned Deepfaune model developed by USGS for the Northeastern NA area called Deepfaune-New-England (DFNE). It can also be loaded with `classification_model = pw_classification.DFNE(device=DEVICE)`
- Please take a look at the orignal [DFNE repo](https://code.usgs.gov/vtcfwru/deepfaune-new-england/-/tree/main?ref_type=heads) and give them a star!
-
diff --git a/mkdocs.yml b/mkdocs.yml
index d3243286f..fbb8f5c4e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -181,4 +181,3 @@ plugins:
python:
options:
docstring_style: google
-
diff --git a/setup.py b/setup.py
index 3402d9614..0ed2440bc 100644
--- a/setup.py
+++ b/setup.py
@@ -1,42 +1,42 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
-with open('README.md', encoding="utf8") as file:
- long_description = file.read()
+with open("README.md", encoding="utf8") as file:
+ long_description = file.read()
setup(
- name='PytorchWildlife',
- version='1.2.4.2',
+ name="PytorchWildlife",
+ version="1.2.4.2",
packages=find_packages(),
include_package_data=True,
package_data={"": ["*.yml"]},
- url='https://github.com/microsoft/CameraTraps/',
- license='MIT',
- author='Andres Hernandez, Zhongqi Miao, Daniela Ruiz Lopez, Isai Daniel Chacon Silva',
- author_email='v-hernandres@microsoft.com, zhongqimiao@microsoft.com, v-druizlopez@microsoft.com, v-ichaconsil@microsoft.com',
- description='a PyTorch Collaborative Deep Learning Framework for Conservation.',
+ url="https://github.com/microsoft/CameraTraps/",
+ license="MIT",
+ author="Andres Hernandez, Zhongqi Miao, Daniela Ruiz Lopez, Isai Daniel Chacon Silva",
+ author_email="v-hernandres@microsoft.com, zhongqimiao@microsoft.com, v-druizlopez@microsoft.com, v-ichaconsil@microsoft.com",
+ description="a PyTorch Collaborative Deep Learning Framework for Conservation.",
long_description=long_description,
- long_description_content_type='text/markdown',
+ long_description_content_type="text/markdown",
install_requires=[
- 'torch',
- 'torchvision',
- 'torchaudio',
- 'tqdm',
- 'Pillow',
- 'supervision==0.23.0',
- 'gradio',
- 'ultralytics',
- 'chardet',
- 'wget',
- 'yolov5',
- 'setuptools',
- 'scikit-learn',
- 'timm',
+ "torch",
+ "torchvision",
+ "torchaudio",
+ "tqdm",
+ "Pillow",
+ "supervision==0.23.0",
+ "gradio",
+ "ultralytics",
+ "chardet",
+ "wget",
+ "yolov5",
+ "setuptools",
+ "scikit-learn",
+ "timm",
],
classifiers=[
- 'Development Status :: 3 - Alpha',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: MIT License',
- 'Programming Language :: Python :: 3',
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3",
],
- keywords='pytorch_wildlife, pytorch, wildlife, megadetector, conservation, animal, detection, classification',
- python_requires='>=3.8',
+ keywords="pytorch_wildlife, pytorch, wildlife, megadetector, conservation, animal, detection, classification",
+ python_requires=">=3.8",
)