From 18d53a6fc4f05a1723c666c88ae590238bdd516a Mon Sep 17 00:00:00 2001 From: abishop1990 Date: Sat, 21 Mar 2026 09:43:36 -0700 Subject: [PATCH 1/2] feat: Add pre-commit hooks configuration for code quality Implements automated code quality checks including: - black: Code formatting (PEP 8 compliant) - flake8: Linting (catches style issues, unused imports) - isort: Import sorting and organization - trailing-whitespace, end-of-file-fixer: Cleanup hooks - check-yaml, check-merge-conflict: Safety hooks Also adds DEVELOPMENT.md with comprehensive setup and usage guide for developers. Fixes #600 --- .pre-commit-config.yaml | 51 ++++++++++ DEVELOPMENT.md | 211 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 .pre-commit-config.yaml create mode 100644 DEVELOPMENT.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..67de664d4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,51 @@ +# Pre-commit hooks configuration for PyTorch Wildlife +# Installation: pip install pre-commit +# Setup: pre-commit install +# Run: pre-commit run --all-files + +repos: + # Code formatting + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + language_version: python3 + args: ['--line-length=100'] + + # Import sorting + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ['--profile=black', '--line-length=100'] + + # Linting + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + args: ['--max-line-length=100', '--extend-ignore=E203,W503'] + additional_dependencies: ['flake8-docstrings', 'flake8-bugbear'] + + # Built-in pre-commit hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + args: ['--markdown-linebreak-ext=md'] + - id: end-of-file-fixer + - id: check-yaml + args: ['--unsafe'] + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: 
check-merge-conflict + - id: debug-statements + + # Type checking (optional but recommended) + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.7.1 + hooks: + - id: mypy + additional_dependencies: []  # NOTE: 'types-all' is uninstallable on PyPI; list specific types-* stub packages here instead + args: ['--ignore-missing-imports', '--no-warn-unused-ignores'] + stages: [manual] diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 000000000..da04c597a --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,211 @@ +# Development Guide for PyTorch Wildlife + +This guide covers setting up your development environment and following our code quality standards. + +## Setup + +### Prerequisites +- Python >= 3.8 +- Git + +### Virtual Environment + +We recommend using a virtual environment: + +```bash +# Create virtual environment +python -m venv venv + +# Activate it +source venv/bin/activate # On Windows: venv\Scripts\activate +``` + +### Install Dependencies + +```bash +# Install package in development mode with all dependencies +pip install -e ".[dev]" + +# Or install core dependencies + development tools +pip install -r requirements.txt +pip install pre-commit black flake8 isort +``` + +## Code Quality + +We use automated tools to maintain consistent code quality across the project. + +### Pre-commit Hooks + +Pre-commit hooks automatically check your code before each commit, catching common issues early. + +**Setup (one-time):** + +```bash +pip install pre-commit +pre-commit install +``` + +Now every time you commit, hooks will run automatically. If any issues are found, the commit will be blocked until you fix them. + +**Running Manually:** + +```bash +# Check all files +pre-commit run --all-files + +# Check specific file +pre-commit run --files path/to/file.py + +# Skip a specific hook +SKIP=flake8 pre-commit run --all-files + +# Update hook versions +pre-commit autoupdate +``` + +### What Each Hook Does + +1. 
**black** - Automatic code formatting (PEP 8 compliant) + - Enforces consistent style across the codebase + - Max line length: 100 characters + +2. **isort** - Import organization + - Sorts imports alphabetically + - Groups: stdlib → third-party → local + - Compatible with black + +3. **flake8** - Linting + - Checks for common style issues + - Detects unused variables/imports + - Catches potential bugs + +4. **trailing-whitespace** - Removes trailing spaces + +5. **end-of-file-fixer** - Ensures files end with a newline + +6. **check-yaml** - Validates YAML syntax + +7. **check-merge-conflict** - Prevents committed merge conflict markers + +### Fixing Code Issues + +If pre-commit finds issues, you can fix them automatically: + +```bash +# Black and isort will auto-fix formatting/imports +black . +isort . + +# Flake8 requires manual fixes (shows issues, doesn't auto-fix) +flake8 . +``` + +After fixing, commit again: + +```bash +git add . +git commit -m "feat: Your message here" +``` + +### Type Checking (Optional) + +For advanced type checking (optional, not enforced on commit): + +```bash +pre-commit run mypy --all-files +``` + +## Testing + +Run the test suite before submitting PRs: + +```bash +pytest tests/ +``` + +## Committing Code + +Follow these guidelines: + +1. **Create a feature branch:** + ```bash + git checkout -b issue-600-add-precommit-hooks + ``` + +2. **Make changes and test:** + ```bash + # Edit files... + pre-commit run --all-files # Run checks + pytest tests/ # Run tests + ``` + +3. **Commit with clear messages:** + ```bash + git commit -m "feat: Add pre-commit hooks for code quality" + ``` + +4. **Push and create PR:** + ```bash + git push origin issue-600-add-precommit-hooks + ``` + +## Commit Message Format + +We follow conventional commits for clear, organized history: + +``` +type: short description + +Optional longer description explaining the change. 
+ +Fixes #123 +``` + +### Types +- `feat:` - New feature +- `fix:` - Bug fix +- `docs:` - Documentation +- `style:` - Code style (formatting, etc.) +- `refactor:` - Code refactoring without functionality change +- `test:` - Test additions/changes +- `chore:` - Maintenance tasks + +Example: +``` +feat: Add pre-commit hooks for code quality + +Implements black, flake8, and isort for consistent code formatting. +Developers can now use pre-commit hooks to catch issues before commit. + +Fixes #600 +``` + +## Troubleshooting + +**Pre-commit is slow:** +- Large repositories can be slow on first run. Subsequent runs are faster. +- You can run specific hooks: `pre-commit run black --all-files` + +**Hook keeps failing on the same issue:** +- Some hooks (black, isort) auto-fix. Run them to fix automatically. +- For flake8: read the output and fix manually. + +**Need to bypass hooks temporarily:** +- For emergencies only: `git commit --no-verify` +- But please fix issues before pushing! + +**Confused about a hook error:** +- Run the tool directly: `black path/to/file.py` +- Check the hook documentation in `.pre-commit-config.yaml` + +## Resources + +- [Pre-commit Documentation](https://pre-commit.com/) +- [Black Documentation](https://black.readthedocs.io/) +- [Flake8 Documentation](https://flake8.pycqa.org/) +- [isort Documentation](https://pycqa.github.io/isort/) + +## Questions? + +If you have questions about the development setup, open an issue or reach out on Discord: https://discord.gg/TeEVxzaYtm From fe9b8306bb06a7f29a8ba16f987cf41b6d77e7ea Mon Sep 17 00:00:00 2001 From: abishop1990 Date: Sat, 21 Mar 2026 09:43:39 -0700 Subject: [PATCH 2/2] style: Format codebase with pre-commit hooks (black, isort, etc.) Runs the full pre-commit hook suite on the entire codebase to establish a consistent code style baseline. 
This includes: - black formatting (line length 100) - isort import sorting - flake8 linting - trailing whitespace and EOF fixes - YAML validation This is the baseline formatting - future PRs will use these same hooks to maintain consistency. --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- .github/ISSUE_TEMPLATE/feature-request.yml | 2 +- .github/ISSUE_TEMPLATE/question.yml | 2 +- Dockerfile | 1 - MANIFEST.in | 2 +- PW_FT_classification/__init__.py | 2 +- PW_FT_classification/configs/config.yaml | 1 - PW_FT_classification/main.py | 173 +++--- PW_FT_classification/requirements.txt | 2 +- .../src/algorithms/__init__.py | 1 - PW_FT_classification/src/algorithms/plain.py | 203 ++++--- PW_FT_classification/src/algorithms/utils.py | 1 + PW_FT_classification/src/datasets/__init__.py | 2 +- PW_FT_classification/src/datasets/custom.py | 167 ++++-- PW_FT_classification/src/models/__init__.py | 2 +- .../src/models/plain_resnet.py | 47 +- .../src/utils/batch_detection_cropping.py | 17 +- .../src/utils/data_splitting.py | 113 ++-- PW_FT_classification/src/utils/utils.py | 53 +- PW_FT_detection/config.yaml | 4 - PW_FT_detection/main.py | 74 +-- PW_FT_detection/requirements.txt | 2 +- PW_FT_detection/utils.py | 17 +- PytorchWildlife/__init__.py | 3 +- PytorchWildlife/data/__init__.py | 2 +- PytorchWildlife/data/datasets.py | 69 ++- PytorchWildlife/data/transforms.py | 60 ++- PytorchWildlife/models/__init__.py | 2 +- .../models/classification/__init__.py | 2 +- .../models/classification/base_classifier.py | 3 +- .../classification/resnet_base/__init__.py | 4 +- .../classification/resnet_base/amazon.py | 104 ++-- .../resnet_base/base_classifier.py | 59 ++- .../resnet_base/custom_weights.py | 21 +- .../classification/resnet_base/opossum.py | 27 +- .../classification/resnet_base/serengeti.py | 44 +- .../models/classification/timm_base/DFNE.py | 70 +-- .../classification/timm_base/Deepfaune.py | 189 ++++++- .../classification/timm_base/__init__.py | 2 +- 
.../timm_base/base_classifier.py | 147 +++--- PytorchWildlife/models/detection/__init__.py | 4 +- .../models/detection/base_detector.py | 40 +- .../models/detection/herdnet/__init__.py | 2 +- .../detection/herdnet/animaloc/__init__.py | 3 +- .../herdnet/animaloc/data/__init__.py | 3 +- .../herdnet/animaloc/data/patches.py | 122 +++-- .../detection/herdnet/animaloc/data/types.py | 126 ++--- .../herdnet/animaloc/eval/__init__.py | 5 +- .../detection/herdnet/animaloc/eval/lmds.py | 113 ++-- .../herdnet/animaloc/eval/stitchers.py | 159 +++--- .../models/detection/herdnet/dla.py | 454 ++++++++++------ .../models/detection/herdnet/herdnet.py | 309 ++++++----- .../models/detection/herdnet/model.py | 109 ++-- .../detection/rtdetr_apache/__init__.py | 2 +- .../rtdetr_apache/megadetectorv6_apache.py | 28 +- .../rtdetr_apache/rtdetr_apache_base.py | 144 ++--- .../dataset/megadetector_detection.yml | 2 +- .../rtdetrv2/include/rtdetrv2_r50vd.yml | 1 - .../rtdetrv2_r18vd_120e_megadetector.yml | 2 +- .../rtdetrv2_pytorch/src/__init__.py | 3 +- .../rtdetrv2_pytorch/src/backbone/__init__.py | 2 +- .../rtdetrv2_pytorch/src/backbone/common.py | 52 +- .../rtdetrv2_pytorch/src/backbone/presnet.py | 133 ++--- .../rtdetrv2_pytorch/src/core/__init__.py | 4 +- .../rtdetrv2_pytorch/src/core/_config.py | 107 ++-- .../rtdetrv2_pytorch/src/core/workspace.py | 129 +++-- .../rtdetrv2_pytorch/src/core/yaml_config.py | 36 +- .../rtdetrv2_pytorch/src/core/yaml_utils.py | 45 +- .../rtdetrv2_pytorch/src/rtdetr/__init__.py | 4 +- .../rtdetrv2_pytorch/src/rtdetr/box_ops.py | 8 +- .../rtdetrv2_pytorch/src/rtdetr/denoising.py | 52 +- .../src/rtdetr/hybrid_encoder.py | 211 ++++---- .../rtdetrv2_pytorch/src/rtdetr/rtdetr.py | 48 +- .../src/rtdetr/rtdetr_postprocessor.py | 68 ++- .../src/rtdetr/rtdetrv2_decoder.py | 498 +++++++++++------- .../rtdetrv2_pytorch/src/rtdetr/utils.py | 101 ++-- .../detection/ultralytics_based/Deepfaune.py | 19 +- .../detection/ultralytics_based/__init__.py | 6 +- 
.../ultralytics_based/megadetectorv5.py | 29 +- .../ultralytics_based/megadetectorv6.py | 42 +- .../megadetectorv6_distributed.py | 42 +- .../ultralytics_based/yolov5_base.py | 97 ++-- .../ultralytics_based/yolov8_base.py | 132 +++-- .../ultralytics_based/yolov8_distributed.py | 151 +++--- .../models/detection/yolo_mit/__init__.py | 2 +- .../detection/yolo_mit/megadetectorv6_mit.py | 32 +- .../detection/yolo_mit/yolo/__init__.py | 7 +- .../models/detection/yolo_mit/yolo/config.py | 3 +- .../detection/yolo_mit/yolo/model/module.py | 38 +- .../detection/yolo_mit/yolo/model/yolo.py | 29 +- .../yolo_mit/yolo/tools/data_augmentation.py | 6 +- .../yolo_mit/yolo/tools/data_loader.py | 80 +-- .../yolo/tools/dataset_preparation.py | 2 - .../yolo_mit/yolo/utils/bounding_box_utils.py | 32 +- .../yolo_mit/yolo/utils/dataset_utils.py | 16 +- .../yolo_mit/yolo/utils/model_utils.py | 7 +- .../detection/yolo_mit/yolo_mit_base.py | 143 ++--- PytorchWildlife/utils/__init__.py | 2 +- PytorchWildlife/utils/misc.py | 31 +- PytorchWildlife/utils/post_process.py | 213 +++++--- README.md | 1 - .../detection_classification_pipeline_demo.py | 97 ++-- demo/gradio_demo.py | 328 ++++++++---- demo/image_demo.py | 62 ++- demo/image_demo_herdnet.py | 49 +- demo/image_separation_demo.py | 56 +- demo/video_demo.py | 49 +- docs/base/data/datasets.md | 2 +- docs/base/data/transforms.md | 2 +- .../models/classification/base_classifier.md | 2 +- .../classification/resnet_base/amazon.md | 2 +- .../resnet_base/base_classifier.md | 2 +- .../resnet_base/custom_weights.md | 2 +- .../classification/resnet_base/opossum.md | 2 +- .../classification/resnet_base/serengeti.md | 2 +- .../models/classification/timm_base/DFNE.md | 2 +- .../classification/timm_base/Deepfaune.md | 2 +- .../timm_base/base_classifier.md | 2 +- docs/base/models/detection/base_detector.md | 2 +- docs/base/models/detection/herdnet.md | 2 +- .../herdnet/animaloc/data/patches.md | 2 +- .../detection/herdnet/animaloc/data/types.md | 2 +- 
.../detection/herdnet/animaloc/eval/lmds.md | 2 +- .../herdnet/animaloc/eval/stitchers.md | 2 +- docs/base/models/detection/herdnet/dla.md | 2 +- docs/base/models/detection/herdnet/model.md | 2 +- .../detection/ultralytics_based/Deepfaune.md | 2 +- .../ultralytics_based/megadetectorv5.md | 2 +- .../ultralytics_based/megadetectorv6.md | 2 +- .../megadetectorv6_distributed.md | 2 +- .../ultralytics_based/yolov5_base.md | 2 +- .../ultralytics_based/yolov8_base.md | 2 +- .../ultralytics_based/yolov8_distributed.md | 2 +- docs/base/overview.md | 2 +- docs/base/utils/misc.md | 2 +- docs/base/utils/post_process.md | 2 +- docs/build_mkdocs.md | 1 - docs/cite.md | 2 +- docs/contribute.md | 2 +- docs/contributors.md | 2 +- docs/demo_and_ui/demo_data.md | 2 +- docs/demo_and_ui/ecoassist.md | 2 +- docs/demo_and_ui/gradio.md | 2 +- docs/demo_and_ui/notebook.md | 2 +- docs/demo_and_ui/timelapse.md | 2 +- .../fine_tuning_modules/detection/overview.md | 2 +- docs/fine_tuning_modules/overview.md | 2 +- docs/index.md | 1 - docs/license.md | 2 +- docs/megadetector.md | 2 +- docs/model_zoo/classifiers.md | 2 +- docs/model_zoo/other_detectors.md | 2 +- docs/releases/past_releases.md | 2 +- docs/releases/release_notes.md | 1 - mkdocs.yml | 1 - setup.py | 62 +-- 156 files changed, 3889 insertions(+), 2895 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index ca823661f..db3919d59 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -61,4 +61,4 @@ body: description: > (Optional) We encourage you to submit a [Pull Request](https://github.com/microsoft/CameraTraps/pulls) (PR) to help contribute Pytorch-Wildlife for everyone, especially if you have a good understanding of how to implement a fix or feature. options: - - label: Yes I'd like to help by submitting a PR! \ No newline at end of file + - label: Yes I'd like to help by submitting a PR! 
diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml index f26991c4c..576da754b 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.yml +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -46,4 +46,4 @@ body: description: > (Optional) We encourage you to submit a [Pull Request](https://github.com/microsoft/CameraTraps/pulls) (PR) to help contribute Pytorch-Wildlife for everyone, especially if you have a good understanding of how to implement a fix or feature. options: - - label: Yes I'd like to help by submitting a PR! \ No newline at end of file + - label: Yes I'd like to help by submitting a PR! diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml index bd601f8f9..040de4107 100644 --- a/.github/ISSUE_TEMPLATE/question.yml +++ b/.github/ISSUE_TEMPLATE/question.yml @@ -30,4 +30,4 @@ body: - type: textarea attributes: label: Additional - description: Anything else you would like to share? \ No newline at end of file + description: Anything else you would like to share? 
diff --git a/Dockerfile b/Dockerfile index b99a41062..5b06bafcd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,4 +23,3 @@ RUN rm -rf /tmp/* RUN pip install --no-cache-dir PytorchWildlife EXPOSE 80 - diff --git a/MANIFEST.in b/MANIFEST.in index 349ed67ad..f11cb4320 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -global-include *.yml \ No newline at end of file +global-include *.yml diff --git a/PW_FT_classification/__init__.py b/PW_FT_classification/__init__.py index c18128e6d..4f9811159 100644 --- a/PW_FT_classification/__init__.py +++ b/PW_FT_classification/__init__.py @@ -1 +1 @@ -from batch_detection_cropping import * \ No newline at end of file +from batch_detection_cropping import * diff --git a/PW_FT_classification/configs/config.yaml b/PW_FT_classification/configs/config.yaml index e0c4bb19b..cea9ce95f 100644 --- a/PW_FT_classification/configs/config.yaml +++ b/PW_FT_classification/configs/config.yaml @@ -38,4 +38,3 @@ weight_decay_classifier: 0.0005 ## lr_scheduler step_size: 10 gamma: 0.1 - diff --git a/PW_FT_classification/main.py b/PW_FT_classification/main.py index ec66333a1..c72b4dcfb 100644 --- a/PW_FT_classification/main.py +++ b/PW_FT_classification/main.py @@ -1,40 +1,43 @@ # %% # Importing libraries import os -import yaml -import typer -from munch import Munch + +import pytorch_lightning as pl + # %% import torch -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.loggers import CSVLogger, CometLogger, TensorBoardLogger, WandbLogger +import typer +import yaml +from munch import Munch +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import CometLogger, CSVLogger, TensorBoardLogger, WandbLogger + # %% -from src import algorithms -from src import datasets +from src import algorithms, datasets + # %% -from src.utils import batch_detection_cropping -from src.utils 
import data_splitting +from src.utils import batch_detection_cropping, data_splitting app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_show_locals=False) + + # %% @app.command() def main( - config:str='./configs/config.yaml', - project:str='Custom-classification', - gpus:str='0', - logger_type:str='csv', - evaluate:str=None, - np_threads:str='32', - session:int=0, - seed:int=0, - dev:bool=False, - val:bool=False, - test:bool=False, - predict:bool=False, - predict_root:str="" - ): + config: str = "./configs/config.yaml", + project: str = "Custom-classification", + gpus: str = "0", + logger_type: str = "csv", + evaluate: str = None, + np_threads: str = "32", + session: int = 0, + seed: int = 0, + dev: bool = False, + val: bool = False, + test: bool = False, + predict: bool = False, + predict_root: str = "", +): """ Main function for training or evaluating a ResNet model (50 or 18) using PyTorch Lightning. It loads configurations, initializes the model, logger, and other components based on provided arguments. @@ -56,7 +59,7 @@ def main( # GPU configuration: set up GPUs based on availability and user specification gpus = gpus if torch.cuda.is_available() else None - gpus = [int(i) for i in gpus.split(',')] + gpus = [int(i) for i in gpus.split(",")] # Environment variable setup for numpy multi-threading. It is important to avoid cpu and ram issues. 
os.environ["OMP_NUM_THREADS"] = str(np_threads) @@ -81,87 +84,113 @@ def main( # Replace annotation dir from config with the directory containing the split files conf.annotation_dir = os.path.dirname(conf.split_path) # Split the data according to the split type - if conf.split_type == 'location': - data_splitting.split_by_location(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size) - elif conf.split_type == 'sequence': - data_splitting.split_by_seq(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size) - elif conf.split_type == 'random': - data_splitting.create_splits(conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size) + if conf.split_type == "location": + data_splitting.split_by_location( + conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size + ) + elif conf.split_type == "sequence": + data_splitting.split_by_seq( + conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size + ) + elif conf.split_type == "random": + data_splitting.create_splits( + conf.split_path, conf.annotation_dir, conf.test_size, conf.val_size + ) else: - raise ValueError('Invalid split type: {}. Available options: random, location, sequence.'.format(conf.split_type)) - + raise ValueError( + "Invalid split type: {}. 
Available options: random, location, sequence.".format( + conf.split_type + ) + ) + if not conf.predict: # Get the path to the annotation files, and we only want to do this if we are not predicting if conf.test: - test_annotations = os.path.join(conf.dataset_root, 'test_annotations.csv') + test_annotations = os.path.join(conf.dataset_root, "test_annotations.csv") # Crop test data (most likely we don't need this) - batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), test_annotations) + batch_detection_cropping.batch_detection_cropping( + conf.dataset_root, + os.path.join(conf.dataset_root, "cropped_resized"), + test_annotations, + ) else: - train_annotations = os.path.join(conf.dataset_root, 'train_annotations.csv') - val_annotations = os.path.join(conf.dataset_root, 'val_annotations.csv') + train_annotations = os.path.join(conf.dataset_root, "train_annotations.csv") + val_annotations = os.path.join(conf.dataset_root, "val_annotations.csv") # Crop training data - batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), train_annotations) + batch_detection_cropping.batch_detection_cropping( + conf.dataset_root, + os.path.join(conf.dataset_root, "cropped_resized"), + train_annotations, + ) # Crop validation data - batch_detection_cropping.batch_detection_cropping(conf.dataset_root, os.path.join(conf.dataset_root, "cropped_resized"), val_annotations) + batch_detection_cropping.batch_detection_cropping( + conf.dataset_root, + os.path.join(conf.dataset_root, "cropped_resized"), + val_annotations, + ) # Dataset and algorithm loading based on the configuration dataset = datasets.__dict__[conf.dataset_name](conf=conf) - learner = algorithms.__dict__[conf.algorithm](conf=conf, - train_class_counts=dataset.train_class_counts, - id_to_labels=dataset.id_to_labels) + learner = algorithms.__dict__[conf.algorithm]( + conf=conf, 
train_class_counts=dataset.train_class_counts, id_to_labels=dataset.id_to_labels + ) # Logger setup based on the specified logger type - log_folder = 'log_dev' if dev else 'log' + log_folder = "log_dev" if dev else "log" logger = None - if logger_type == 'csv': + if logger_type == "csv": logger = CSVLogger( - save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), + save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm), prefix=project, - name='{}_{}'.format(conf.algorithm, conf.conf_id), - version=session + name="{}_{}".format(conf.algorithm, conf.conf_id), + version=session, ) - elif logger_type == 'tensorboard': + elif logger_type == "tensorboard": logger = TensorBoardLogger( - save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), + save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm), prefix=project, - name='{}_{}'.format(conf.algorithm, conf.conf_id), - version=session + name="{}_{}".format(conf.algorithm, conf.conf_id), + version=session, ) - elif logger_type == 'comet': + elif logger_type == "comet": logger = CometLogger( api_key=os.environ.get("COMET_API_KEY"), - save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), - project_name=project, - experiment_name='{}_{}_{}'.format(conf.algorithm, conf.conf_id, session), + save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm), + project_name=project, + experiment_name="{}_{}_{}".format(conf.algorithm, conf.conf_id, session), ) - elif logger_type == 'wandb': + elif logger_type == "wandb": logger = WandbLogger( - save_dir='./{}/{}/{}'.format(log_folder, conf.log_dir, conf.algorithm), - project=project, - name='{}_{}_{}'.format(conf.algorithm, conf.conf_id, session), + save_dir="./{}/{}/{}".format(log_folder, conf.log_dir, conf.algorithm), + project=project, + name="{}_{}_{}".format(conf.algorithm, conf.conf_id, session), ) # Callbacks for model checkpointing and learning rate monitoring - weights_folder = 
'weights_dev' if dev else 'weights' + weights_folder = "weights_dev" if dev else "weights" checkpoint_callback = ModelCheckpoint( - monitor='valid_mac_acc', mode='max', dirpath='./{}/{}/{}'.format(weights_folder, conf.log_dir, conf.algorithm), - save_top_k=1, filename='{}-{}'.format(conf.conf_id, session) + '-{epoch:02d}-{valid_mac_acc:.2f}', verbose=True + monitor="valid_mac_acc", + mode="max", + dirpath="./{}/{}/{}".format(weights_folder, conf.log_dir, conf.algorithm), + save_top_k=1, + filename="{}-{}".format(conf.conf_id, session) + "-{epoch:02d}-{valid_mac_acc:.2f}", + verbose=True, ) - lr_monitor = LearningRateMonitor(logging_interval='step') + lr_monitor = LearningRateMonitor(logging_interval="step") # Trainer configuration in PyTorch Lightning trainer = pl.Trainer( max_epochs=conf.num_epochs, - check_val_every_n_epoch=1, - log_every_n_steps = conf.log_interval, - accelerator='gpu', + check_val_every_n_epoch=1, + log_every_n_steps=conf.log_interval, + accelerator="gpu", devices=gpus, logger=None if evaluate is not None else logger, callbacks=[lr_monitor, checkpoint_callback], - strategy='auto', + strategy="auto", num_sanity_val_steps=0, - profiler=None + profiler=None, ) # Training, validation, or evaluation execution based on the mode if evaluate is not None: @@ -172,11 +201,13 @@ def main( elif test: trainer.test(learner, dataloaders=[dataset.test_dataloader()], ckpt_path=evaluate) else: - print('Invalid mode for evaluation.') + print("Invalid mode for evaluation.") else: trainer.fit(learner, datamodule=dataset) + + # %% -if __name__ == '__main__': +if __name__ == "__main__": app() # %% diff --git a/PW_FT_classification/requirements.txt b/PW_FT_classification/requirements.txt index c20d78a7c..0dbc206c2 100644 --- a/PW_FT_classification/requirements.txt +++ b/PW_FT_classification/requirements.txt @@ -2,4 +2,4 @@ PytorchWildlife scikit_learn lightning munch -typer \ No newline at end of file +typer diff --git a/PW_FT_classification/src/algorithms/__init__.py 
b/PW_FT_classification/src/algorithms/__init__.py index fa84a4a0c..b79897b14 100644 --- a/PW_FT_classification/src/algorithms/__init__.py +++ b/PW_FT_classification/src/algorithms/__init__.py @@ -1,3 +1,2 @@ from . import utils from .plain import * - diff --git a/PW_FT_classification/src/algorithms/plain.py b/PW_FT_classification/src/algorithms/plain.py index be3212ae0..4c604eaaa 100644 --- a/PW_FT_classification/src/algorithms/plain.py +++ b/PW_FT_classification/src/algorithms/plain.py @@ -1,21 +1,19 @@ -import os -import numpy as np import json -from datetime import datetime -from tqdm import tqdm +import os import random +from datetime import datetime +import numpy as np +import pytorch_lightning as pl import torch import torch.optim as optim -import pytorch_lightning as pl +from src import models +from tqdm import tqdm from .utils import acc -from src import models +__all__ = ["Plain"] -__all__ = [ - 'Plain' -] class Plain(pl.LightningModule): """ @@ -25,7 +23,7 @@ class Plain(pl.LightningModule): and training/validation/testing steps for the training process. 
""" - name = 'Plain' + name = "Plain" def __init__(self, conf, train_class_counts, id_to_labels, **kwargs): """ @@ -39,11 +37,12 @@ def __init__(self, conf, train_class_counts, id_to_labels, **kwargs): """ super().__init__() self.hparams.update(conf.__dict__) - self.save_hyperparameters(ignore=['conf', 'train_class_counts']) + self.save_hyperparameters(ignore=["conf", "train_class_counts"]) self.train_class_counts = train_class_counts self.id_to_labels = id_to_labels - self.net = models.__dict__[self.hparams.model_name](num_cls=self.hparams.num_classes, - num_layers=self.hparams.num_layers) + self.net = models.__dict__[self.hparams.model_name]( + num_cls=self.hparams.num_classes, num_layers=self.hparams.num_layers + ) def configure_optimizers(self): """ @@ -55,19 +54,25 @@ def configure_optimizers(self): # Define parameters for the optimizer net_optim_params_list = [ # Optimizer parameters for feature extraction - {'params': self.net.feature.parameters(), - 'lr': self.hparams.lr_feature, - 'momentum': self.hparams.momentum_feature, - 'weight_decay': self.hparams.weight_decay_feature}, + { + "params": self.net.feature.parameters(), + "lr": self.hparams.lr_feature, + "momentum": self.hparams.momentum_feature, + "weight_decay": self.hparams.weight_decay_feature, + }, # Optimizer parameters for the classifier - {'params': self.net.classifier.parameters(), - 'lr': self.hparams.lr_classifier, - 'momentum': self.hparams.momentum_classifier, - 'weight_decay': self.hparams.weight_decay_classifier} + { + "params": self.net.classifier.parameters(), + "lr": self.hparams.lr_classifier, + "momentum": self.hparams.momentum_classifier, + "weight_decay": self.hparams.weight_decay_classifier, + }, ] # Setup optimizer and optimizer scheduler optimizer = torch.optim.SGD(net_optim_params_list) - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma) + scheduler = optim.lr_scheduler.StepLR( + optimizer, 
step_size=self.hparams.step_size, gamma=self.hparams.gamma + ) return [optimizer], [scheduler] def on_train_start(self): @@ -90,14 +95,14 @@ def training_step(self, batch, batch_idx): Tensor: The loss for the current training step. """ data, label_ids = batch[0], batch[1] - + # Forward pass feats = self.net.feature(data) logits = self.net.classifier(feats) # Calculate loss loss = self.net.criterion_cls(logits, label_ids) self.log("train_loss", loss) - + return loss def on_validation_start(self): @@ -119,9 +124,8 @@ def validation_step(self, batch, batch_idx): feats = self.net.feature(data) logits = self.net.classifier(feats) preds = logits.argmax(dim=1) - - self.val_st_outs.append((preds.detach().cpu().numpy(), - label_ids.detach().cpu().numpy())) + + self.val_st_outs.append((preds.detach().cpu().numpy(), label_ids.detach().cpu().numpy())) def on_validation_epoch_end(self): """ @@ -150,14 +154,17 @@ def test_step(self, batch, batch_idx): feats = self.net.feature(data) logits = self.net.classifier(feats) preds = logits.argmax(dim=1) - - self.te_st_outs.append((preds.detach().cpu().numpy(), - label_ids.detach().cpu().numpy(), - feats.detach().cpu().numpy(), - logits.detach().cpu().numpy(), - labels, file_ids - )) - + + self.te_st_outs.append( + ( + preds.detach().cpu().numpy(), + label_ids.detach().cpu().numpy(), + feats.detach().cpu().numpy(), + logits.detach().cpu().numpy(), + labels, + file_ids, + ) + ) def on_test_epoch_end(self): """ @@ -172,14 +179,23 @@ def on_test_epoch_end(self): total_file_ids = np.concatenate([x[5] for x in self.te_st_outs], axis=0) # Calculate the metrics and save the output - self.eval_logging(total_preds[total_label_ids != -1], - total_label_ids[total_label_ids != -1], - print_class_acc=False) - - output_path = self.hparams.evaluate.replace('.ckpt', 'eval.npz') - np.savez(output_path, preds=total_preds, label_ids=total_label_ids, feats=total_feats, - logits=total_logits, labels=total_labels, file_ids=total_file_ids) - print('Test output 
saved to {}.'.format(output_path)) + self.eval_logging( + total_preds[total_label_ids != -1], + total_label_ids[total_label_ids != -1], + print_class_acc=False, + ) + + output_path = self.hparams.evaluate.replace(".ckpt", "eval.npz") + np.savez( + output_path, + preds=total_preds, + label_ids=total_label_ids, + feats=total_feats, + logits=total_logits, + labels=total_labels, + file_ids=total_file_ids, + ) + print("Test output saved to {}.".format(output_path)) def on_predict_start(self): """ @@ -201,14 +217,16 @@ def predict_step(self, batch, batch_idx): logits = self.net.classifier(feats) preds = logits.argmax(dim=1) probs = torch.softmax(logits, dim=1).max(dim=1)[0] - - self.pr_st_outs.append((preds.detach().cpu().numpy(), - feats.detach().cpu().numpy(), - logits.detach().cpu().numpy(), - probs.detach().cpu().numpy(), - file_ids - )) - + + self.pr_st_outs.append( + ( + preds.detach().cpu().numpy(), + feats.detach().cpu().numpy(), + logits.detach().cpu().numpy(), + probs.detach().cpu().numpy(), + file_ids, + ) + ) def on_predict_epoch_end(self): """ @@ -223,25 +241,31 @@ def on_predict_epoch_end(self): json_output = [] for i in range(len(total_preds)): - json_output.append({ - "marker_id": "", - "survey_pic_id": total_file_ids[i], - "marker_confidence": float(total_probs[i]), - "marker_gear_type": "ghostnet" if total_preds[i] == 1 else "neg", - "marker_bounding_polygon": "", - "marker_status": "unverified", - "marker_ai_model": "" - }) - - output_path_full = self.hparams.evaluate.replace('.ckpt', '_predict.npz') - np.savez(output_path_full, preds=total_preds, feats=total_feats, - logits=total_logits, file_ids=total_file_ids) - print('Predict output saved to {}.'.format(output_path_full)) - - output_path_json = self.hparams.evaluate.replace('.ckpt', '_predict.json') - json.dump(json_output, open(output_path_json, 'w')) - print('Predict output json saved to {}.'.format(output_path_json)) - + json_output.append( + { + "marker_id": "", + "survey_pic_id": 
total_file_ids[i], + "marker_confidence": float(total_probs[i]), + "marker_gear_type": "ghostnet" if total_preds[i] == 1 else "neg", + "marker_bounding_polygon": "", + "marker_status": "unverified", + "marker_ai_model": "", + } + ) + + output_path_full = self.hparams.evaluate.replace(".ckpt", "_predict.npz") + np.savez( + output_path_full, + preds=total_preds, + feats=total_feats, + logits=total_logits, + file_ids=total_file_ids, + ) + print("Predict output saved to {}.".format(output_path_full)) + + output_path_json = self.hparams.evaluate.replace(".ckpt", "_predict.json") + json.dump(json_output, open(output_path_json, "w")) + print("Predict output json saved to {}.".format(output_path_json)) def eval_logging(self, preds, labels, print_class_acc=False): """ @@ -259,27 +283,32 @@ def eval_logging(self, preds, labels, print_class_acc=False): self.log("valid_mic_acc", mic_acc * 100) if print_class_acc: - if self.train_class_counts: - acc_list = [(class_acc[i], unique_eval_labels[i], - self.id_to_labels[unique_eval_labels[i]], - self.train_class_counts[unique_eval_labels[i]]) - for i in range(len(class_acc))] - - print('\n') + acc_list = [ + ( + class_acc[i], + unique_eval_labels[i], + self.id_to_labels[unique_eval_labels[i]], + self.train_class_counts[unique_eval_labels[i]], + ) + for i in range(len(class_acc)) + ] + + print("\n") for i in range(len(class_acc)): - info = '{:>20} ({:<3}, tr {:>3}) Acc: '.format(acc_list[i][2], - acc_list[i][1], - acc_list[i][3]) - info += '{:.2f}'.format(acc_list[i][0] * 100) + info = "{:>20} ({:<3}, tr {:>3}) Acc: ".format( + acc_list[i][2], acc_list[i][1], acc_list[i][3] + ) + info += "{:.2f}".format(acc_list[i][0] * 100) print(info) else: - acc_list = [(class_acc[i], unique_eval_labels[i], - self.id_to_labels[unique_eval_labels[i]]) - for i in range(len(class_acc))] + acc_list = [ + (class_acc[i], unique_eval_labels[i], self.id_to_labels[unique_eval_labels[i]]) + for i in range(len(class_acc)) + ] - print('\n') + print("\n") for i 
in range(len(class_acc)): - info = '{:>20} ({:<3}) Acc: '.format(acc_list[i][2], acc_list[i][1]) - info += '{:.2f}'.format(acc_list[i][0] * 100) + info = "{:>20} ({:<3}) Acc: ".format(acc_list[i][2], acc_list[i][1]) + info += "{:.2f}".format(acc_list[i][0] * 100) print(info) diff --git a/PW_FT_classification/src/algorithms/utils.py b/PW_FT_classification/src/algorithms/utils.py index 11a14da33..60804fb26 100644 --- a/PW_FT_classification/src/algorithms/utils.py +++ b/PW_FT_classification/src/algorithms/utils.py @@ -1,5 +1,6 @@ from sklearn.metrics import confusion_matrix + def acc(preds, labels): """ Calculate the accuracy metrics based on predictions and true labels. diff --git a/PW_FT_classification/src/datasets/__init__.py b/PW_FT_classification/src/datasets/__init__.py index a04dcce14..8ef4d4c1a 100644 --- a/PW_FT_classification/src/datasets/__init__.py +++ b/PW_FT_classification/src/datasets/__init__.py @@ -1 +1 @@ -from .custom import * \ No newline at end of file +from .custom import * diff --git a/PW_FT_classification/src/datasets/custom.py b/PW_FT_classification/src/datasets/custom.py index cbac698a2..38152e741 100644 --- a/PW_FT_classification/src/datasets/custom.py +++ b/PW_FT_classification/src/datasets/custom.py @@ -1,29 +1,33 @@ # Import necessary libraries import os from glob import glob + import numpy as np import pandas as pd +import pytorch_lightning as pl import torch from PIL import Image +from torch.utils.data import DataLoader, Dataset from torchvision import transforms -from torch.utils.data import Dataset, DataLoader -import pytorch_lightning as pl # Exportable class names for external use -__all__ = [ - 'Custom_Crop' -] - -# Define the allowed image extensions -IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") - -def has_file_allowed_extension(filename: str, extensions: tuple) -> bool: - """Checks if a file is an allowed extension.""" - return filename.lower().endswith(extensions if 
isinstance(extensions, str) else tuple(extensions)) - -def is_image_file(filename: str) -> bool: - """Checks if a file is an allowed image extension.""" - return has_file_allowed_extension(filename, IMG_EXTENSIONS) +__all__ = ["Custom_Crop"] + +# Define the allowed image extensions +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") + + +def has_file_allowed_extension(filename: str, extensions: tuple) -> bool: + """Checks if a file is an allowed extension.""" + return filename.lower().endswith( + extensions if isinstance(extensions, str) else tuple(extensions) + ) + + +def is_image_file(filename: str) -> bool: + """Checks if a file is an allowed image extension.""" + return has_file_allowed_extension(filename, IMG_EXTENSIONS) + # Define normalization mean and standard deviation for image preprocessing mean = [0.485, 0.456, 0.406] @@ -31,21 +35,22 @@ def is_image_file(filename: str) -> bool: # Define data transformations for training and validation datasets data_transforms = { - 'train': transforms.Compose([ - transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0), ratio=(0.8, 1.2)), - transforms.RandomHorizontalFlip(p=0.5), - transforms.RandomVerticalFlip(p=0.5), - transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), - transforms.ToTensor(), - transforms.Normalize(mean, std) - ]), - 'val': transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize(mean, std) - ]), + "train": transforms.Compose( + [ + transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0), ratio=(0.8, 1.2)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.5), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] + ), + "val": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean, std)] + ), } + class 
Custom_Base_DS(Dataset): """ Base dataset class for handling custom datasets. @@ -80,13 +85,18 @@ def load_data(self): if self.predict: # Load data for prediction # self.data = glob(os.path.join(self.img_root,"*.{}".format(self.extension))) - self.data = [os.path.join(dp, f) for dp, dn, filenames in os.walk(self.img_root) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename + self.data = [ + os.path.join(dp, f) + for dp, dn, filenames in os.walk(self.img_root) + for f in filenames + if is_image_file(f) + ] # dp: directory path, dn: directory name, f: filename else: # Load data for training/validation - self.data = list(self.ann['path']) - self.label_ids = list(self.ann['classification']) - self.labels = list(self.ann['label']) - print('Number of images loaded: ', len(self.data)) + self.data = list(self.ann["path"]) + self.label_ids = list(self.ann["classification"]) + self.labels = list(self.ann["label"]) + print("Number of images loaded: ", len(self.data)) def class_counts_cal(self): """ @@ -120,8 +130,8 @@ def __getitem__(self, index): file_id = self.data[index] file_dir = os.path.join(self.img_root, file_id) if not self.predict else file_id - with open(file_dir, 'rb') as f: - sample = Image.open(f).convert('RGB') + with open(file_dir, "rb") as f: + sample = Image.open(f).convert("RGB") if self.transform is not None: sample = self.transform(sample) @@ -142,7 +152,7 @@ class Custom_Crop_DS(Custom_Base_DS): Inherits from Custom_Base_DS and includes specific handling for cropped data. """ - def __init__(self, rootdir, dset='train', transform=None): + def __init__(self, rootdir, dset="train", transform=None): """ Initialize the Custom_Crop_DS with the dataset directory, type, and transformations. @@ -151,12 +161,17 @@ def __init__(self, rootdir, dset='train', transform=None): dset (str): Type of dataset (train, val, test, predict). transform (callable, optional): Transformations to be applied to each data sample. 
""" - self.predict = dset == 'predict' + self.predict = dset == "predict" super().__init__(rootdir=rootdir, transform=transform, predict=self.predict) - self.img_root = rootdir if self.predict else os.path.join(self.rootdir, 'cropped_resized') + self.img_root = rootdir if self.predict else os.path.join(self.rootdir, "cropped_resized") if not self.predict: - self.ann = pd.read_csv(os.path.join(self.rootdir, 'cropped_resized', '{}_annotations_cropped.csv' - .format('test' if dset == 'test' else dset))) + self.ann = pd.read_csv( + os.path.join( + self.rootdir, + "cropped_resized", + "{}_annotations_cropped.csv".format("test" if dset == "test" else dset), + ) + ) self.load_data() @@ -178,27 +193,41 @@ def __init__(self, conf): """ super().__init__() self._log_hyperparams = True - self.id_to_labels = None # We don't need this for evaluations. We should save this in model weights in the future + self.id_to_labels = None # We don't need this for evaluations. We should save this in model weights in the future self.train_class_counts = None self.conf = conf - print('Loading datasets...') + print("Loading datasets...") # Load datasets for different modes (training, validation, testing, prediction) if self.conf.predict: - self.dset_pr = self.ds(rootdir=self.conf.predict_root, dset='predict', transform=data_transforms['val']) + self.dset_pr = self.ds( + rootdir=self.conf.predict_root, dset="predict", transform=data_transforms["val"] + ) elif self.conf.test: - self.dset_te = self.ds(rootdir=self.conf.dataset_root, dset='test', transform=data_transforms['val']) - self.id_to_labels = {i: l for i, l in np.unique(pd.Series(zip(self.dset_te.label_ids, self.dset_te.labels)))} + self.dset_te = self.ds( + rootdir=self.conf.dataset_root, dset="test", transform=data_transforms["val"] + ) + self.id_to_labels = { + i: l + for i, l in np.unique(pd.Series(zip(self.dset_te.label_ids, self.dset_te.labels))) + } else: - self.dset_tr = self.ds(rootdir=self.conf.dataset_root, dset='train', 
transform=data_transforms['train']) - self.dset_val = self.ds(rootdir=self.conf.dataset_root, dset='val', transform=data_transforms['val']) - - self.id_to_labels = {i: l for i, l in np.unique(pd.Series(zip(self.dset_tr.label_ids, self.dset_tr.labels)))} + self.dset_tr = self.ds( + rootdir=self.conf.dataset_root, dset="train", transform=data_transforms["train"] + ) + self.dset_val = self.ds( + rootdir=self.conf.dataset_root, dset="val", transform=data_transforms["val"] + ) + + self.id_to_labels = { + i: l + for i, l in np.unique(pd.Series(zip(self.dset_tr.label_ids, self.dset_tr.labels))) + } # Calculate class counts and label mappings self.unique_label_ids, self.train_class_counts = self.dset_tr.class_counts_cal() - print('Datasets loaded.') + print("Datasets loaded.") def train_dataloader(self): """ @@ -207,7 +236,14 @@ def train_dataloader(self): Returns: DataLoader: DataLoader for the training dataset. """ - return DataLoader(self.dset_tr, batch_size=self.conf.batch_size, shuffle=True, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + return DataLoader( + self.dset_tr, + batch_size=self.conf.batch_size, + shuffle=True, + pin_memory=True, + num_workers=self.conf.num_workers, + drop_last=False, + ) def val_dataloader(self): """ @@ -216,7 +252,14 @@ def val_dataloader(self): Returns: DataLoader: DataLoader for the validation dataset. """ - return DataLoader(self.dset_val, batch_size=self.conf.batch_size, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + return DataLoader( + self.dset_val, + batch_size=self.conf.batch_size, + shuffle=False, + pin_memory=True, + num_workers=self.conf.num_workers, + drop_last=False, + ) def test_dataloader(self): """ @@ -225,7 +268,14 @@ def test_dataloader(self): Returns: DataLoader: DataLoader for the testing dataset. 
""" - return DataLoader(self.dset_te, batch_size=256, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + return DataLoader( + self.dset_te, + batch_size=256, + shuffle=False, + pin_memory=True, + num_workers=self.conf.num_workers, + drop_last=False, + ) def predict_dataloader(self): """ @@ -234,7 +284,14 @@ def predict_dataloader(self): Returns: DataLoader: DataLoader for the prediction dataset. """ - return DataLoader(self.dset_pr, batch_size=64, shuffle=False, pin_memory=True, num_workers=self.conf.num_workers, drop_last=False) + return DataLoader( + self.dset_pr, + batch_size=64, + shuffle=False, + pin_memory=True, + num_workers=self.conf.num_workers, + drop_last=False, + ) class Custom_Crop(Custom_Base): diff --git a/PW_FT_classification/src/models/__init__.py b/PW_FT_classification/src/models/__init__.py index 70be09a93..9e8b093ac 100644 --- a/PW_FT_classification/src/models/__init__.py +++ b/PW_FT_classification/src/models/__init__.py @@ -1 +1 @@ -from .plain_resnet import * \ No newline at end of file +from .plain_resnet import * diff --git a/PW_FT_classification/src/models/plain_resnet.py b/PW_FT_classification/src/models/plain_resnet.py index d664c91b7..e8fd61113 100644 --- a/PW_FT_classification/src/models/plain_resnet.py +++ b/PW_FT_classification/src/models/plain_resnet.py @@ -1,26 +1,25 @@ -import os import copy +import os from collections import OrderedDict + import torch import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck from torchvision.models.resnet import * - +from torchvision.models.resnet import BasicBlock, Bottleneck try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_state_dict_from_url # Exportable class names for external use -__all__ = [ - 'PlainResNetClassifier' -] +__all__ = ["PlainResNetClassifier"] model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', - 'resnet50': 
'https://download.pytorch.org/models/resnet50-0676ba61.pth' + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", } + class ResNetBackbone(ResNet): """ Custom ResNet backbone class for feature extraction. @@ -93,7 +92,7 @@ class PlainResNetClassifier(nn.Module): Extends nn.Module and provides a complete ResNet-based classifier, including feature extraction and classification layers. """ - name = 'PlainResNetClassifier' + name = "PlainResNetClassifier" def __init__(self, num_cls=10, num_layers=18): """ @@ -123,17 +122,19 @@ def setup_net(self): if self.num_layers == 18: block = BasicBlock layers = [2, 2, 2, 2] - #self.pretrained_weights = ResNet18_Weights.IMAGENET1K_V1 - self.pretrained_weights = state_dict = load_state_dict_from_url(model_urls['resnet18'], - progress=True) + # self.pretrained_weights = ResNet18_Weights.IMAGENET1K_V1 + self.pretrained_weights = state_dict = load_state_dict_from_url( + model_urls["resnet18"], progress=True + ) elif self.num_layers == 50: block = Bottleneck layers = [3, 4, 6, 3] - #self.pretrained_weights = ResNet50_Weights.IMAGENET1K_V1 - self.pretrained_weights = state_dict = load_state_dict_from_url(model_urls['resnet50'], - progress=True) + # self.pretrained_weights = ResNet50_Weights.IMAGENET1K_V1 + self.pretrained_weights = state_dict = load_state_dict_from_url( + model_urls["resnet50"], progress=True + ) else: - raise Exception('ResNet Type not supported.') + raise Exception("ResNet Type not supported.") # Constructing the feature extractor and classifier self.feature = ResNetBackbone(block, layers, **kwargs) @@ -151,10 +152,14 @@ def feat_init(self): Initialize the feature extractor with pre-trained weights. 
""" # Load pre-trained weights and adjust for the current model - #init_weights = self.pretrained_weights.get_state_dict(progress=True) + # init_weights = self.pretrained_weights.get_state_dict(progress=True) init_weights = self.pretrained_weights - init_weights = OrderedDict({k.replace('module.', '').replace('feature.', ''): init_weights[k] - for k in init_weights}) + init_weights = OrderedDict( + { + k.replace("module.", "").replace("feature.", ""): init_weights[k] + for k in init_weights + } + ) # Load the weights into the feature extractor self.feature.load_state_dict(init_weights, strict=False) @@ -165,5 +170,5 @@ def feat_init(self): missing_keys = self_keys - load_keys unused_keys = load_keys - self_keys - print('missing keys: {}'.format(sorted(list(missing_keys)))) - print('unused_keys: {}'.format(sorted(list(unused_keys)))) + print("missing keys: {}".format(sorted(list(missing_keys)))) + print("unused_keys: {}".format(sorted(list(unused_keys)))) diff --git a/PW_FT_classification/src/utils/batch_detection_cropping.py b/PW_FT_classification/src/utils/batch_detection_cropping.py index a89301486..6332a3942 100644 --- a/PW_FT_classification/src/utils/batch_detection_cropping.py +++ b/PW_FT_classification/src/utils/batch_detection_cropping.py @@ -3,16 +3,20 @@ """ Demo for batch detection, cropping and resizing""" -#%% -# PyTorch imports +# %% +# PyTorch imports import torch -# Importing the model, dataset, transformations and utility functions from PytorchWildlife -from PytorchWildlife.models import detection as pw_detection -from PytorchWildlife.data import transforms as pw_trans -from PytorchWildlife.data import datasets as pw_data + # Importing the utility function for saving cropped images from src.utils import utils +from PytorchWildlife.data import datasets as pw_data +from PytorchWildlife.data import transforms as pw_trans + +# Importing the model, dataset, transformations and utility functions from PytorchWildlife +from PytorchWildlife.models import 
detection as pw_detection + + def batch_detection_cropping(folder_path, output_path, annotation_file): # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -29,5 +33,4 @@ def batch_detection_cropping(folder_path, output_path, annotation_file): return crop_annotation_path - # %% diff --git a/PW_FT_classification/src/utils/data_splitting.py b/PW_FT_classification/src/utils/data_splitting.py index f667c6247..f5b37a855 100644 --- a/PW_FT_classification/src/utils/data_splitting.py +++ b/PW_FT_classification/src/utils/data_splitting.py @@ -1,20 +1,22 @@ ## DATA SPLITTING +import os + import pandas as pd from sklearn.model_selection import train_test_split -import os from tqdm import tqdm + def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1): """ Create stratified training, validation, and testing splits. - + Args: - csv_path (str): Path to the csv containing the annotations. - output_folder (str): Destination directory to save the annotation split csv files. - test_size (float): Proportion of the dataset to include in the test split. - val_size (float): Proportion of the training dataset to include in the validation split. - + Returns: - A tuple of DataFrames: (train_set, val_set, test_set) - Saves the splits into separate csv files in the output_folder. 
@@ -22,41 +24,46 @@ def create_splits(csv_path, output_folder, test_size=0.2, val_size=0.1): # Load the data from the csv file data = pd.read_csv(csv_path) # Separate the features and the targets - X = data[['path','label']] - y = data['classification'] - + X = data[["path", "label"]] + y = data["classification"] + # First split to separate out the test set - X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=test_size, stratify=y, random_state=42) - + X_temp, X_test, y_temp, y_test = train_test_split( + X, y, test_size=test_size, stratify=y, random_state=42 + ) + # Adjust val_size to account for the initial split val_size_adjusted = val_size / (1 - test_size) - + # Second split to separate out the validation set - X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=val_size_adjusted, stratify=y_temp, random_state=42) - + X_train, X_val, y_train, y_val = train_test_split( + X_temp, y_temp, test_size=val_size_adjusted, stratify=y_temp, random_state=42 + ) + # Combine features, labels, and classification back into dataframes train_set = pd.concat([X_train.reset_index(drop=True), y_train.reset_index(drop=True)], axis=1) val_set = pd.concat([X_val.reset_index(drop=True), y_val.reset_index(drop=True)], axis=1) test_set = pd.concat([X_test.reset_index(drop=True), y_test.reset_index(drop=True)], axis=1) - + # Create the output directory in case that it does not exist os.makedirs(output_folder, exist_ok=True) # Save the splits to new CSV files - train_set.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False) - val_set.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False) - test_set.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False) + train_set.to_csv(os.path.join(output_folder, "train_annotations.csv"), index=False) + val_set.to_csv(os.path.join(output_folder, "val_annotations.csv"), index=False) + test_set.to_csv(os.path.join(output_folder, "test_annotations.csv"), index=False) 
# Return the dataframes return train_set, val_set, test_set + def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, random_state=None): """ Splits the dataset into train, validation, and test sets based on location, ensuring that: 1. All images from the same location are in the same split. 2. The split is random among the locations. 3. Saves the split datasets into CSV files. - + Parameters: - csv_path: Path to the csv containing the annotations. - train_size, val_size, test_size: float, proportions of the dataset to include in the train, validation, and test splits. @@ -67,27 +74,31 @@ def split_by_location(csv_path, output_folder, val_size=0.15, test_size=0.15, ra # Calculate train size based on val and test size train_size = 1.0 - val_size - test_size - + # Get unique locations - unique_locations = data['Location'].unique() + unique_locations = data["Location"].unique() # Split locations into train and temp (temporary holding for val and test) - train_locs, temp_locs = train_test_split(unique_locations, train_size=train_size, random_state=random_state) - + train_locs, temp_locs = train_test_split( + unique_locations, train_size=train_size, random_state=random_state + ) + # Adjust the proportions for val and test based on the remaining locations temp_size = val_size / (val_size + test_size) - val_locs, test_locs = train_test_split(temp_locs, train_size=temp_size, random_state=random_state) - + val_locs, test_locs = train_test_split( + temp_locs, train_size=temp_size, random_state=random_state + ) + # Allocate images to train, validation, and test sets based on their location - train_data = data[data['Location'].isin(train_locs)] - val_data = data[data['Location'].isin(val_locs)] - test_data = data[data['Location'].isin(test_locs)] - + train_data = data[data["Location"].isin(train_locs)] + val_data = data[data["Location"].isin(val_locs)] + test_data = data[data["Location"].isin(test_locs)] + # Save the datasets to CSV files - 
train_data.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False) - val_data.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False) - test_data.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False) - + train_data.to_csv(os.path.join(output_folder, "train_annotations.csv"), index=False) + val_data.to_csv(os.path.join(output_folder, "val_annotations.csv"), index=False) + test_data.to_csv(os.path.join(output_folder, "test_annotations.csv"), index=False) + # Return the split datasets return train_data, val_data, test_data @@ -98,7 +109,7 @@ def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_ 1. All images from the same sequence are in the same split. 2. The split is random among the sequences. 3. Saves the split datasets into CSV files. - + Parameters: - csv_path: Path to the csv containing the annotations. - train_size, val_size, test_size: float, proportions of the dataset to include in the train, validation, and test splits. 
@@ -108,40 +119,44 @@ def split_by_seq(csv_path, output_folder, val_size=0.15, test_size=0.15, random_ data = pd.read_csv(csv_path) # Convert 'Photo_Time' from string to datetime - data['Photo_Time'] = pd.to_datetime(data['Photo_Time']) + data["Photo_Time"] = pd.to_datetime(data["Photo_Time"]) # Calculate train size based on val and test size train_size = 1 - val_size - test_size - + # Sort by 'Photo_Time' to ensure chronological order - data = data.sort_values(by=['Photo_Time']).reset_index(drop=True) + data = data.sort_values(by=["Photo_Time"]).reset_index(drop=True) # Group photos into sequences based on a 30-second interval - time_groups = data.groupby(pd.Grouper(key='Photo_Time', freq='30S')) + time_groups = data.groupby(pd.Grouper(key="Photo_Time", freq="30S")) # Assign unique sequence IDs to each group for s, i in tqdm(enumerate(time_groups.indices.values())): - data.loc[i, 'Seq_ID'] = int(s) + data.loc[i, "Seq_ID"] = int(s) # Get unique sequence IDs - unique_seq_ids = data['Seq_ID'].unique() - + unique_seq_ids = data["Seq_ID"].unique() + # Split sequence IDs into train and temp (temporary holding for val and test) - train_seq_ids, temp_seq_ids = train_test_split(unique_seq_ids, train_size=train_size, random_state=random_state) - + train_seq_ids, temp_seq_ids = train_test_split( + unique_seq_ids, train_size=train_size, random_state=random_state + ) + # Adjust the proportions for val and test based on the remaining sequences temp_size = val_size / (val_size + test_size) - val_seq_ids, test_seq_ids = train_test_split(temp_seq_ids, train_size=temp_size, random_state=random_state) - + val_seq_ids, test_seq_ids = train_test_split( + temp_seq_ids, train_size=temp_size, random_state=random_state + ) + # Allocate images to train, validation, and test sets based on their sequence ID - train_data = data[data['Seq_ID'].isin(train_seq_ids)] - val_data = data[data['Seq_ID'].isin(val_seq_ids)] - test_data = data[data['Seq_ID'].isin(test_seq_ids)] - + train_data = 
data[data["Seq_ID"].isin(train_seq_ids)] + val_data = data[data["Seq_ID"].isin(val_seq_ids)] + test_data = data[data["Seq_ID"].isin(test_seq_ids)] + # Save the datasets to CSV files - train_data.to_csv(os.path.join(output_folder,'train_annotations.csv'), index=False) - val_data.to_csv(os.path.join(output_folder,'val_annotations.csv'), index=False) - test_data.to_csv(os.path.join(output_folder,'test_annotations.csv'), index=False) + train_data.to_csv(os.path.join(output_folder, "train_annotations.csv"), index=False) + val_data.to_csv(os.path.join(output_folder, "val_annotations.csv"), index=False) + test_data.to_csv(os.path.join(output_folder, "test_annotations.csv"), index=False) # Return the split datasets return train_data, val_data, test_data diff --git a/PW_FT_classification/src/utils/utils.py b/PW_FT_classification/src/utils/utils.py index 0b359a541..7707dc9b4 100644 --- a/PW_FT_classification/src/utils/utils.py +++ b/PW_FT_classification/src/utils/utils.py @@ -1,9 +1,11 @@ import os -import pandas as pd + import cv2 -import supervision as sv -from PIL import Image import numpy as np +import pandas as pd +import supervision as sv +from PIL import Image + def save_crop_images(results, output_dir, original_csv_path, overwrite=False): """ @@ -29,44 +31,51 @@ def save_crop_images(results, output_dir, original_csv_path, overwrite=False): # Prepare a list to store new records for the new CSV new_records = [] - + os.makedirs(output_dir, exist_ok=True) with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: for entry in results: # Process the data if the name of the file is in the dataframe - if os.path.basename(entry["img_id"]) in original_df['path'].values: - for i, (xyxy, cat) in enumerate(zip(entry["detections"].xyxy, entry["detections"].class_id)): + if os.path.basename(entry["img_id"]) in original_df["path"].values: + for i, (xyxy, cat) in enumerate( + zip(entry["detections"].xyxy, entry["detections"].class_id) + ): cropped_img = 
sv.crop_image( image=np.array(Image.open(entry["img_id"]).convert("RGB")), xyxy=xyxy ) new_img_name = "{}_{}_{}".format( - int(cat), i, entry["img_id"].rsplit(os.sep, 1)[1]) + int(cat), i, entry["img_id"].rsplit(os.sep, 1)[1] + ) sink.save_image( - image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), - image_name=new_img_name - ), - + image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), image_name=new_img_name + ), + # Save the crop into a new csv - image_name = entry['img_id'] - - classification_id = original_df[original_df['path'].str.endswith(image_name.split(os.sep)[-1])]['classification'].values[0] - classification_name = original_df[original_df['path'].str.endswith(image_name.split(os.sep)[-1])]['label'].values[0] + image_name = entry["img_id"] + + classification_id = original_df[ + original_df["path"].str.endswith(image_name.split(os.sep)[-1]) + ]["classification"].values[0] + classification_name = original_df[ + original_df["path"].str.endswith(image_name.split(os.sep)[-1]) + ]["label"].values[0] # Add record to the new CSV data - new_records.append({ - 'path': new_img_name, - 'classification': classification_id, - 'label': classification_name - }) + new_records.append( + { + "path": new_img_name, + "classification": classification_id, + "label": classification_name, + } + ) # Create a DataFrame from the new records new_df = pd.DataFrame(new_records) # Define the path for the new CSV file - new_file_name = "{}_cropped.csv".format(original_csv_path.split(os.sep)[-1].split('.')[0]) + new_file_name = "{}_cropped.csv".format(original_csv_path.split(os.sep)[-1].split(".")[0]) new_csv_path = os.path.join(output_dir, new_file_name) # Save the new DataFrame to CSV new_df.to_csv(new_csv_path, index=False) return new_csv_path - diff --git a/PW_FT_detection/config.yaml b/PW_FT_detection/config.yaml index d4c9f5485..05357a0ad 100644 --- a/PW_FT_detection/config.yaml +++ b/PW_FT_detection/config.yaml @@ -25,7 +25,3 @@ save_json: True plot: True device_val: 0 
batch_size_val: 12 - - - - diff --git a/PW_FT_detection/main.py b/PW_FT_detection/main.py index d3f7cd140..d6afac3bb 100644 --- a/PW_FT_detection/main.py +++ b/PW_FT_detection/main.py @@ -1,79 +1,79 @@ -from ultralytics import YOLO, RTDETR -from utils import get_model_path -from munch import Munch -import yaml import os -def main(config:str='./config.yaml'): +import yaml +from munch import Munch +from ultralytics import RTDETR, YOLO +from utils import get_model_path + +def main(config: str = "./config.yaml"): # Load and set configurations from the YAML file with open(config) as f: cfg = Munch(yaml.load(f, Loader=yaml.FullLoader)) - if cfg.resume: - model_path = cfg.weights - else: - model_path = get_model_path(cfg.model_name) - - if cfg.model == "YOLO": - model = YOLO(model_path) - elif cfg.model == "RTDETR": - model = RTDETR(model_path) - else: - raise ValueError("Model not supported") + if cfg.resume: + model_path = cfg.weights + else: + model_path = get_model_path(cfg.model_name) + + if cfg.model == "YOLO": + model = YOLO(model_path) + elif cfg.model == "RTDETR": + model = RTDETR(model_path) + else: + raise ValueError("Model not supported") with open(cfg.data) as f: data = yaml.safe_load(f) if not os.path.isabs(data["path"]): data["path"] = os.path.abspath(data["path"]) - with open(cfg.data, 'w') as f: + with open(cfg.data, "w") as f: yaml.dump(data, f) - + model.info() if cfg.task == "train": - results = model.train( data=cfg.data, - epochs=cfg.epochs, + epochs=cfg.epochs, imgsz=cfg.imgsz, - device=cfg.device_train, - save_period=cfg.save_period, - workers=cfg.workers, - batch=cfg.batch_size_train, + device=cfg.device_train, + save_period=cfg.save_period, + workers=cfg.workers, + batch=cfg.batch_size_train, val=cfg.val, project=f"runs/train_{cfg.exp_name}", name="exp", patience=cfg.patience, - resume=cfg.resume - ) + resume=cfg.resume, + ) if cfg.task == "validation": - metrics = model.val( data=cfg.data, - save_json=cfg.save_json, - plots=cfg.plots, - 
device=cfg.device_val, - project=f'runs/val_{cfg.exp_name}', + save_json=cfg.save_json, + plots=cfg.plots, + device=cfg.device_val, + project=f"runs/val_{cfg.exp_name}", name="exp", - batch=cfg.batch_size_val) + batch=cfg.batch_size_val, + ) - metrics.box.map # map50-95 + metrics.box.map # map50-95 metrics.box.map50 # map50 metrics.box.map75 # map75 - metrics.box.maps # a list contains map50-95 of each category + metrics.box.maps # a list contains map50-95 of each category if cfg.task == "inference": - results = model(cfg.test_data) save_path = os.path.join("inference_results", cfg.exp_name) os.makedirs(save_path, exist_ok=True) for i in range(len(results)): results[i][0].boxes - results[i].save(filename=os.path.join(save_path,f"inference_{i}.jpg")) + results[i].save(filename=os.path.join(save_path, f"inference_{i}.jpg")) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/PW_FT_detection/requirements.txt b/PW_FT_detection/requirements.txt index f98f2933b..bc470a5c0 100644 --- a/PW_FT_detection/requirements.txt +++ b/PW_FT_detection/requirements.txt @@ -1,4 +1,4 @@ PytorchWildlife ultralytics munch -wget \ No newline at end of file +wget diff --git a/PW_FT_detection/utils.py b/PW_FT_detection/utils.py index 9cc6b262f..d20cd00eb 100644 --- a/PW_FT_detection/utils.py +++ b/PW_FT_detection/utils.py @@ -1,11 +1,12 @@ import os -import wget + import torch +import wget -def get_model_path(model): - if model == "MDV6-yolov9-c": - url = "https://zenodo.org/records/14567879/files/MDV6b-yolov9c.pt?download=1" +def get_model_path(model): + if model == "MDV6-yolov9-c": + url = "https://zenodo.org/records/14567879/files/MDV6b-yolov9c.pt?download=1" model_name = "MDV6b-yolov9c.pt" elif model == "MDV6-yolov9-e": url = "https://zenodo.org/records/14567879/files/MDV6-yolov9e.pt?download=1" @@ -20,12 +21,14 @@ def get_model_path(model): url = "https://zenodo.org/records/14567879/files/MDV6b-rtdetrl.pt?download=1" model_name = "MDV6b-rtdetrl.pt" 
else: - raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c') + raise ValueError( + "Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c" + ) if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", model_name)): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) model_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) else: model_path = os.path.join(torch.hub.get_dir(), "checkpoints", model_name) - - return model_path \ No newline at end of file + + return model_path diff --git a/PytorchWildlife/__init__.py b/PytorchWildlife/__init__.py index a59275f7a..1c4f1c9ed 100644 --- a/PytorchWildlife/__init__.py +++ b/PytorchWildlife/__init__.py @@ -1,4 +1,5 @@ import importlib.metadata as importlib_metadata + try: # This will read version from pyproject.toml __version__ = importlib_metadata.version(__package__ or __name__) @@ -7,4 +8,4 @@ from .data import * from .models import * -from .utils import * \ No newline at end of file +from .utils import * diff --git a/PytorchWildlife/data/__init__.py b/PytorchWildlife/data/__init__.py index 82ef5ae2d..cfa56d0b9 100644 --- a/PytorchWildlife/data/__init__.py +++ b/PytorchWildlife/data/__init__.py @@ -1,2 +1,2 @@ from .datasets import * -from .transforms import * \ No newline at end of file +from .transforms import * diff --git a/PytorchWildlife/data/datasets.py b/PytorchWildlife/data/datasets.py index ce280f02f..de0629103 100644 --- a/PytorchWildlife/data/datasets.py +++ b/PytorchWildlife/data/datasets.py @@ -3,10 +3,11 @@ import os from glob import glob -from PIL import Image, ImageFile + import numpy as np import supervision as sv import torch +from PIL import Image, ImageFile from torch.utils.data import Dataset # To handle truncated images during loading @@ -15,23 +16,28 @@ # Making the DetectionImageFolder class available for 
import from this module __all__ = [ "DetectionImageFolder", - ] - -# Define the allowed image extensions -IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") - -def has_file_allowed_extension(filename: str, extensions: tuple) -> bool: - """Checks if a file is an allowed extension.""" - return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions)) - -def is_image_file(filename: str) -> bool: - """Checks if a file is an allowed image extension.""" - return has_file_allowed_extension(filename, IMG_EXTENSIONS) +] + +# Define the allowed image extensions +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") + + +def has_file_allowed_extension(filename: str, extensions: tuple) -> bool: + """Checks if a file is an allowed extension.""" + return filename.lower().endswith( + extensions if isinstance(extensions, str) else tuple(extensions) + ) + + +def is_image_file(filename: str) -> bool: + """Checks if a file is an allowed image extension.""" + return has_file_allowed_extension(filename, IMG_EXTENSIONS) + class ImageFolder(Dataset): """ A PyTorch Dataset for loading images from a specified directory. - Each item in the dataset is a tuple containing the image data, + Each item in the dataset is a tuple containing the image data, the image's path, and the original size of the image. 
""" @@ -46,7 +52,12 @@ def __init__(self, image_dir, transform=None): super(ImageFolder, self).__init__() self.image_dir = image_dir self.transform = transform - self.images = [os.path.join(dp, f) for dp, dn, filenames in os.walk(image_dir) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename + self.images = [ + os.path.join(dp, f) + for dp, dn, filenames in os.walk(image_dir) + for f in filenames + if is_image_file(f) + ] # dp: directory path, dn: directory name, f: filename def __getitem__(self, idx) -> tuple: """ @@ -59,7 +70,7 @@ def __getitem__(self, idx) -> tuple: tuple: Contains the image data, the image's path, and its original size. """ pass - + def __len__(self) -> int: """ Returns the total number of images in the dataset. @@ -69,10 +80,11 @@ def __len__(self) -> int: """ return len(self.images) + class ClassificationImageFolder(ImageFolder): """ A PyTorch Dataset for loading images from a specified directory. - Each item in the dataset is a tuple containing the image data, + Each item in the dataset is a tuple containing the image data, the image's path, and the original size of the image. """ @@ -101,7 +113,7 @@ def __getitem__(self, idx) -> tuple: # Load and convert image to RGB img = Image.open(img_path).convert("RGB") - + # Apply transformation if specified if self.transform: img = self.transform(img) @@ -112,7 +124,7 @@ def __getitem__(self, idx) -> tuple: class DetectionImageFolder(ImageFolder): """ A PyTorch Dataset for loading images from a specified directory. - Each item in the dataset is a tuple containing the image data, + Each item in the dataset is a tuple containing the image data, the image's path, and the original size of the image. 
""" @@ -142,23 +154,23 @@ def __getitem__(self, idx) -> tuple: # Load and convert image to RGB img = Image.open(img_path).convert("RGB") img_size_ori = img.size[::-1] - + # Apply transformation if specified if self.transform: img = self.transform(img) return img, img_path, torch.tensor(img_size_ori) - + # TODO: Under development for efficiency improvement class DetectionCrops(Dataset): - def __init__(self, detection_results, transform=None, path_head=None, animal_cls_id=0): - self.detection_results = detection_results self.transform = transform self.path_head = path_head - self.animal_cls_id = animal_cls_id # This determines which detection class id represents animals. + self.animal_cls_id = ( + animal_cls_id # This determines which detection class id represents animals. + ) self.img_ids = [] self.xyxys = [] @@ -188,11 +200,10 @@ def __getitem__(self, idx) -> tuple: xyxy = self.xyxys[idx] img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id - + # Load and crop image with supervision - img = sv.crop_image(np.array(Image.open(img_path).convert("RGB")), - xyxy=xyxy) - + img = sv.crop_image(np.array(Image.open(img_path).convert("RGB")), xyxy=xyxy) + # Apply transformation if specified if self.transform: img = self.transform(Image.fromarray(img)) @@ -200,4 +211,4 @@ def __getitem__(self, idx) -> tuple: return img, img_path def __len__(self) -> int: - return len(self.img_ids) \ No newline at end of file + return len(self.img_ids) diff --git a/PytorchWildlife/data/transforms.py b/PytorchWildlife/data/transforms.py index c9bceab28..ba1e27e11 100644 --- a/PytorchWildlife/data/transforms.py +++ b/PytorchWildlife/data/transforms.py @@ -3,23 +3,28 @@ import numpy as np import torch -from torchvision import transforms -import torchvision.transforms as T import torch.nn.functional as F +import torchvision.transforms as T from PIL import Image +from torchvision import transforms # Making the provided classes available for import from this module -__all__ 
= [ - "MegaDetector_v5_Transform", - "Classification_Inference_Transform" -] - - -def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True, stride=32) -> torch.Tensor: +__all__ = ["MegaDetector_v5_Transform", "Classification_Inference_Transform"] + + +def letterbox( + im, + new_shape=(640, 640), + color=(114, 114, 114), + auto=False, + scaleFill=False, + scaleup=True, + stride=32, +) -> torch.Tensor: """ Resize and pad an image to a desired shape while keeping the aspect ratio unchanged. - This function is commonly used in object detection tasks to prepare images for models like YOLOv5. + This function is commonly used in object detection tasks to prepare images for models like YOLOv5. It resizes the image to fit into the new shape with the correct aspect ratio and then pads the rest. Args: @@ -64,19 +69,26 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scale dw /= 2 dh /= 2 - + # Resize image if shape[::-1] != new_unpad: - resize_transform = T.Resize(new_unpad[::-1], interpolation=T.InterpolationMode.BILINEAR, - antialias=False) + resize_transform = T.Resize( + new_unpad[::-1], interpolation=T.InterpolationMode.BILINEAR, antialias=False + ) im = resize_transform(im) # Pad image - padding = (int(round(dw - 0.1)), int(round(dw + 0.1)), int(round(dh + 0.1)), int(round(dh - 0.1))) - im = F.pad(im*255.0, padding, value=114)/255.0 + padding = ( + int(round(dw - 0.1)), + int(round(dw + 0.1)), + int(round(dh + 0.1)), + int(round(dh - 0.1)), + ) + im = F.pad(im * 255.0, padding, value=114) / 255.0 return im + class MegaDetector_v5_Transform: """ A transformation class to preprocess images for the MegaDetector v5 model. @@ -113,16 +125,18 @@ def __call__(self, np_img) -> torch.Tensor: np_img = torch.from_numpy(np_img).float() np_img /= 255.0 - # Resize and pad the image using a customized letterbox function. + # Resize and pad the image using a customized letterbox function. 
img = letterbox(np_img, new_shape=self.target_size, stride=self.stride, auto=False) return img + class Classification_Inference_Transform: """ A transformation class to preprocess images for classification inference. This includes resizing, normalization, and conversion to a tensor. """ + # Normalization constants mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] @@ -135,12 +149,14 @@ def __init__(self, target_size=224, **kwargs): target_size (int): Desired size for the height and width after resizing. """ # Define the sequence of transformations - self.trans = transforms.Compose([ - # transforms.Resize((target_size, target_size)), - transforms.Resize((target_size, target_size), **kwargs), - transforms.ToTensor(), - transforms.Normalize(self.mean, self.std) - ]) + self.trans = transforms.Compose( + [ + # transforms.Resize((target_size, target_size)), + transforms.Resize((target_size, target_size), **kwargs), + transforms.ToTensor(), + transforms.Normalize(self.mean, self.std), + ] + ) def __call__(self, img) -> torch.Tensor: """ diff --git a/PytorchWildlife/models/__init__.py b/PytorchWildlife/models/__init__.py index 89e334c07..b134b2934 100644 --- a/PytorchWildlife/models/__init__.py +++ b/PytorchWildlife/models/__init__.py @@ -1,2 +1,2 @@ from .classification import * -from .detection import * \ No newline at end of file +from .detection import * diff --git a/PytorchWildlife/models/classification/__init__.py b/PytorchWildlife/models/classification/__init__.py index e450ab78b..2963d6ec7 100644 --- a/PytorchWildlife/models/classification/__init__.py +++ b/PytorchWildlife/models/classification/__init__.py @@ -1,3 +1,3 @@ +from .base_classifier import * from .resnet_base import * from .timm_base import * -from .base_classifier import * \ No newline at end of file diff --git a/PytorchWildlife/models/classification/base_classifier.py b/PytorchWildlife/models/classification/base_classifier.py index c3a60ed36..5dc57f0cf 100644 --- 
a/PytorchWildlife/models/classification/base_classifier.py +++ b/PytorchWildlife/models/classification/base_classifier.py @@ -1,4 +1,3 @@ - # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. @@ -7,10 +6,12 @@ # Making the PlainResNetInference class available for import from this module __all__ = ["BaseClassifierInference"] + class BaseClassifierInference(nn.Module): """ Inference module for the PlainResNet Classifier. """ + def __init__(self): super(BaseClassifierInference, self).__init__() pass diff --git a/PytorchWildlife/models/classification/resnet_base/__init__.py b/PytorchWildlife/models/classification/resnet_base/__init__.py index 10c6d0ba2..bb2ae189b 100644 --- a/PytorchWildlife/models/classification/resnet_base/__init__.py +++ b/PytorchWildlife/models/classification/resnet_base/__init__.py @@ -1,5 +1,5 @@ +from .amazon import * from .base_classifier import * +from .custom_weights import * from .opossum import * -from .amazon import * from .serengeti import * -from .custom_weights import * \ No newline at end of file diff --git a/PytorchWildlife/models/classification/resnet_base/amazon.py b/PytorchWildlife/models/classification/resnet_base/amazon.py index 8b228b04e..92d237322 100644 --- a/PytorchWildlife/models/classification/resnet_base/amazon.py +++ b/PytorchWildlife/models/classification/resnet_base/amazon.py @@ -2,11 +2,10 @@ # Licensed under the MIT License. import torch + from .base_classifier import PlainResNetInference -__all__ = [ - "AI4GAmazonRainforest" -] +__all__ = ["AI4GAmazonRainforest"] class AI4GAmazonRainforest(PlainResNetInference): @@ -14,48 +13,48 @@ class AI4GAmazonRainforest(PlainResNetInference): Amazon Ranforest Animal Classifier that inherits from PlainResNetInference. This classifier is specialized for recognizing 36 different animals in the Amazon Rainforest. 
""" - + # Image size for the Opossum classifier IMAGE_SIZE = 224 - + # Class names for prediction CLASS_NAMES = { - 0: 'Dasyprocta', - 1: 'Bos', - 2: 'Pecari', - 3: 'Mazama', - 4: 'Cuniculus', - 5: 'Leptotila', - 6: 'Human', - 7: 'Aramides', - 8: 'Tinamus', - 9: 'Eira', - 10: 'Crax', - 11: 'Procyon', - 12: 'Capra', - 13: 'Dasypus', - 14: 'Sciurus', - 15: 'Crypturellus', - 16: 'Tamandua', - 17: 'Proechimys', - 18: 'Leopardus', - 19: 'Equus', - 20: 'Columbina', - 21: 'Nyctidromus', - 22: 'Ortalis', - 23: 'Emballonura', - 24: 'Odontophorus', - 25: 'Geotrygon', - 26: 'Metachirus', - 27: 'Catharus', - 28: 'Cerdocyon', - 29: 'Momotus', - 30: 'Tapirus', - 31: 'Canis', - 32: 'Furnarius', - 33: 'Didelphis', - 34: 'Sylvilagus', - 35: 'Unknown' + 0: "Dasyprocta", + 1: "Bos", + 2: "Pecari", + 3: "Mazama", + 4: "Cuniculus", + 5: "Leptotila", + 6: "Human", + 7: "Aramides", + 8: "Tinamus", + 9: "Eira", + 10: "Crax", + 11: "Procyon", + 12: "Capra", + 13: "Dasypus", + 14: "Sciurus", + 15: "Crypturellus", + 16: "Tamandua", + 17: "Proechimys", + 18: "Leopardus", + 19: "Equus", + 20: "Columbina", + 21: "Nyctidromus", + 22: "Ortalis", + 23: "Emballonura", + 24: "Odontophorus", + 25: "Geotrygon", + 26: "Metachirus", + 27: "Catharus", + 28: "Cerdocyon", + 29: "Momotus", + 30: "Tapirus", + 31: "Canis", + 32: "Furnarius", + 33: "Didelphis", + 34: "Sylvilagus", + 35: "Unknown", } def __init__(self, weights=None, device="cpu", pretrained=True, version="v2"): @@ -71,29 +70,34 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version="v2"): # If pretrained, use the provided URL to fetch the weights if pretrained: - if version == 'v1': + if version == "v1": url = "https://zenodo.org/records/10042023/files/AI4GAmazonClassification_v0.0.0.ckpt?download=1" - elif version == 'v2': - url = "https://zenodo.org/records/14252214/files/AI4GAmazonDeforestationv2?download=1" + elif version == "v2": + url = ( + 
"https://zenodo.org/records/14252214/files/AI4GAmazonDeforestationv2?download=1" + ) else: url = None - super(AI4GAmazonRainforest, self).__init__(weights=weights, device=device, - num_cls=36, num_layers=50, url=url) + super(AI4GAmazonRainforest, self).__init__( + weights=weights, device=device, num_cls=36, num_layers=50, url=url + ) - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: + def results_generation( + self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None + ) -> list[dict]: """ Generate results for classification. Args: logits (torch.Tensor): Output tensor from the model. img_ids (str): Image identifier. - id_strip (str): stiping string for better image id saving. + id_strip (str): stiping string for better image id saving. Returns: dict: Dictionary containing image ID, prediction, and confidence score. """ - + probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] @@ -108,5 +112,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: r["confidence"] = conf.item() r["all_confidences"] = result results.append(r) - + return results diff --git a/PytorchWildlife/models/classification/resnet_base/base_classifier.py b/PytorchWildlife/models/classification/resnet_base/base_classifier.py index 965ff4389..fc18dfbe4 100644 --- a/PytorchWildlife/models/classification/resnet_base/base_classifier.py +++ b/PytorchWildlife/models/classification/resnet_base/base_classifier.py @@ -1,20 +1,20 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-import numpy as np -from PIL import Image -from tqdm import tqdm from collections import OrderedDict +import numpy as np import torch import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet +from PIL import Image from torch.hub import load_state_dict_from_url from torch.utils.data import DataLoader +from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet +from tqdm import tqdm -from ..base_classifier import BaseClassifierInference +from ....data import datasets as pw_data from ....data import transforms as pw_trans -from ....data import datasets as pw_data +from ..base_classifier import BaseClassifierInference # Making the PlainResNetInference class available for import from this module __all__ = ["PlainResNetInference"] @@ -24,6 +24,7 @@ class ResNetBackbone(ResNet): """ Custom ResNet Backbone that extracts features from input images. """ + def _forward_impl(self, x): # Following the ResNet structure to extract features x = self.conv1(x) @@ -45,6 +46,7 @@ class PlainResNetClassifier(nn.Module): """ Basic ResNet Classifier that uses a custom ResNet backbone. """ + name = "PlainResNetClassifier" def __init__(self, num_cls=1, num_layers=50): @@ -88,8 +90,12 @@ def feat_init(self): Initialize the features using pretrained weights. """ init_weights = self.pretrained_weights.get_state_dict(progress=True) - init_weights = OrderedDict({k.replace("module.", "").replace("feature.", ""): init_weights[k] - for k in init_weights}) + init_weights = OrderedDict( + { + k.replace("module.", "").replace("feature.", ""): init_weights[k] + for k in init_weights + } + ) self.feature.load_state_dict(init_weights, strict=False) # Print missing and unused keys for debugging purposes load_keys = set(init_weights.keys()) @@ -104,8 +110,12 @@ class PlainResNetInference(BaseClassifierInference): """ Inference module for the PlainResNet Classifier. 
""" + IMAGE_SIZE = None - def __init__(self, num_cls=36, num_layers=50, weights=None, device="cpu", url=None, transform=None): + + def __init__( + self, num_cls=36, num_layers=50, weights=None, device="cpu", url=None, transform=None + ): super(PlainResNetInference, self).__init__() self.device = device self.net = PlainResNetClassifier(num_cls=num_cls, num_layers=num_layers) @@ -122,16 +132,20 @@ def __init__(self, num_cls=36, num_layers=50, weights=None, device="cpu", url=No if transform: self.transform = transform else: - self.transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE) + self.transform = pw_trans.Classification_Inference_Transform( + target_size=self.IMAGE_SIZE + ) - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: + def results_generation( + self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None + ) -> list[dict]: """ - Process logits to produce final results. + Process logits to produce final results. Args: logits (torch.Tensor): Logits from the network. img_ids (list[str]): List of image paths. - id_strip (str): Stripping string for better image ID saving. + id_strip (str): Stripping string for better image ID saving. Returns: list[dict]: List of dictionaries containing the results. @@ -158,26 +172,19 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip= """ if data_path: - dataset = pw_data.ImageFolder( - data_path, - transform=self.transform, - path_head='.' - ) + dataset = pw_data.ImageFolder(data_path, transform=self.transform, path_head=".") elif det_results: - dataset = pw_data.DetectionCrops( - det_results, - transform=self.transform, - path_head='.' 
- ) + dataset = pw_data.DetectionCrops(det_results, transform=self.transform, path_head=".") else: raise Exception("Need data for inference.") - dataloader = DataLoader(dataset, batch_size=32, shuffle=False, - pin_memory=True, num_workers=4, drop_last=False) + dataloader = DataLoader( + dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4, drop_last=False + ) total_logits = [] total_paths = [] - with tqdm(total=len(dataloader)) as pbar: + with tqdm(total=len(dataloader)) as pbar: for batch in dataloader: imgs, paths = batch imgs = imgs.to(self.device) diff --git a/PytorchWildlife/models/classification/resnet_base/custom_weights.py b/PytorchWildlife/models/classification/resnet_base/custom_weights.py index 7247abcf5..8027e2b70 100644 --- a/PytorchWildlife/models/classification/resnet_base/custom_weights.py +++ b/PytorchWildlife/models/classification/resnet_base/custom_weights.py @@ -2,11 +2,10 @@ # Licensed under the MIT License. import torch + from .base_classifier import PlainResNetInference -__all__ = [ - "CustomWeights" -] +__all__ = ["CustomWeights"] class CustomWeights(PlainResNetInference): @@ -14,11 +13,10 @@ class CustomWeights(PlainResNetInference): Custom Weight Classifier that inherits from PlainResNetInference. This classifier can load any model that was based on the PytorchWildlife finetuning tool. """ - + # Image size for the classifier IMAGE_SIZE = 224 - def __init__(self, weights=None, class_names=None, device="cpu"): """ Initialize the CustomWeights Classifier. 
@@ -30,10 +28,13 @@ def __init__(self, weights=None, class_names=None, device="cpu"): """ self.CLASS_NAMES = class_names self.num_cls = len(self.CLASS_NAMES) - super(CustomWeights, self).__init__(weights=weights, device=device, - num_cls=self.num_cls, num_layers=50, url=None) + super(CustomWeights, self).__init__( + weights=weights, device=device, num_cls=self.num_cls, num_layers=50, url=None + ) - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: + def results_generation( + self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None + ) -> list[dict]: """ Generate results for classification. @@ -45,7 +46,7 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: Returns: list[dict]: List of dictionaries containing image ID, prediction, and confidence score. """ - + probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] @@ -60,5 +61,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: r["confidence"] = conf.item() r["all_confidences"] = result results.append(r) - + return results diff --git a/PytorchWildlife/models/classification/resnet_base/opossum.py b/PytorchWildlife/models/classification/resnet_base/opossum.py index 453533394..5217e18aa 100644 --- a/PytorchWildlife/models/classification/resnet_base/opossum.py +++ b/PytorchWildlife/models/classification/resnet_base/opossum.py @@ -2,11 +2,10 @@ # Licensed under the MIT License. import torch + from .base_classifier import PlainResNetInference -__all__ = [ - "AI4GOpossum" -] +__all__ = ["AI4GOpossum"] class AI4GOpossum(PlainResNetInference): @@ -14,15 +13,12 @@ class AI4GOpossum(PlainResNetInference): Opossum Classifier that inherits from PlainResNetInference. This classifier is specialized for distinguishing between Opossums and Non-opossums. 
""" - + # Image size for the Opossum classifier IMAGE_SIZE = 224 - + # Class names for prediction - CLASS_NAMES = { - 0: "Non-opossum", - 1: "Opossum" - } + CLASS_NAMES = {0: "Non-opossum", 1: "Opossum"} def __init__(self, weights=None, device="cpu", pretrained=True): """ @@ -40,17 +36,20 @@ def __init__(self, weights=None, device="cpu", pretrained=True): else: url = None - super(AI4GOpossum, self).__init__(weights=weights, device=device, - num_cls=1, num_layers=50, url=url) + super(AI4GOpossum, self).__init__( + weights=weights, device=device, num_cls=1, num_layers=50, url=url + ) - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: + def results_generation( + self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None + ) -> list[dict]: """ Generate results for classification. Args: logits (torch.Tensor): Output tensor from the model. img_ids (list): List of image identifier. - id_strip (str): stiping string for better image id saving. + id_strip (str): stiping string for better image id saving. Returns: dict: Dictionary containing image ID, prediction, and confidence score. @@ -66,5 +65,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: r["class_id"] = pred r["confidence"] = prob.item() if pred == 1 else (1 - prob.item()) results.append(r) - + return results diff --git a/PytorchWildlife/models/classification/resnet_base/serengeti.py b/PytorchWildlife/models/classification/resnet_base/serengeti.py index e0820a02f..63d0ee697 100644 --- a/PytorchWildlife/models/classification/resnet_base/serengeti.py +++ b/PytorchWildlife/models/classification/resnet_base/serengeti.py @@ -2,11 +2,10 @@ # Licensed under the MIT License. 
import torch + from .base_classifier import PlainResNetInference -__all__ = [ - "AI4GSnapshotSerengeti" -] +__all__ = ["AI4GSnapshotSerengeti"] class AI4GSnapshotSerengeti(PlainResNetInference): @@ -14,22 +13,22 @@ class AI4GSnapshotSerengeti(PlainResNetInference): Snapshot Serengeti Animal Classifier that inherits from PlainResNetInference. This classifier is specialized for recognizing 9 different animals and has 1 'other' class. """ - + # Image size for the Opossum classifier IMAGE_SIZE = 224 - + # Class names for prediction CLASS_NAMES = { - 0: 'wildebeest', - 1: 'guineafowl', - 2: 'zebra', - 3: 'buffalo', - 4: 'gazellethomsons', - 5: 'gazellegrants', - 6: 'warthog', - 7: 'impala', - 8: 'hyenaspotted', - 9: 'other' + 0: "wildebeest", + 1: "guineafowl", + 2: "zebra", + 3: "buffalo", + 4: "gazellethomsons", + 5: "gazellegrants", + 6: "warthog", + 7: "impala", + 8: "hyenaspotted", + 9: "other", } def __init__(self, weights=None, device="cpu", pretrained=True): @@ -48,22 +47,25 @@ def __init__(self, weights=None, device="cpu", pretrained=True): else: url = None - super(AI4GSnapshotSerengeti, self).__init__(weights=weights, device=device, - num_cls=10, num_layers=18, url=url) + super(AI4GSnapshotSerengeti, self).__init__( + weights=weights, device=device, num_cls=10, num_layers=18, url=url + ) - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: + def results_generation( + self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None + ) -> list[dict]: """ Generate results for classification. Args: logits (torch.Tensor): Output tensor from the model. img_ids (str): Image identifier. - id_strip (str): stiping string for better image id saving. + id_strip (str): stiping string for better image id saving. Returns: dict: Dictionary containing image ID, prediction, and confidence score. 
""" - + probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] @@ -78,5 +80,5 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: r["confidence"] = conf.item() r["all_confidences"] = result results.append(r) - + return results diff --git a/PytorchWildlife/models/classification/timm_base/DFNE.py b/PytorchWildlife/models/classification/timm_base/DFNE.py index 1cb6d71b0..40674d6c1 100644 --- a/PytorchWildlife/models/classification/timm_base/DFNE.py +++ b/PytorchWildlife/models/classification/timm_base/DFNE.py @@ -10,47 +10,53 @@ from .base_classifier import TIMM_BaseClassifierInference -__all__ = [ - "DFNE" -] +__all__ = ["DFNE"] + class DFNE(TIMM_BaseClassifierInference): """ Base detector class for dinov2 classifier. This class provides utility methods - for loading the model, performing single and batch image classifications, and - formatting results. Make sure the appropriate file for the model weights has been + for loading the model, performing single and batch image classifications, and + formatting results. Make sure the appropriate file for the model weights has been downloaded to the "models" folder before running DFNE. 
""" + BACKBONE = "vit_large_patch14_dinov2.lvd142m" MODEL_NAME = "dfne_weights_v1_0.pth" IMAGE_SIZE = 182 CLASS_NAMES = { - 0: "American Marten", - 1: "Bird sp.", - 2: "Black Bear", - 3: "Bobcat", - 4: "Coyote", - 5: "Domestic Cat", - 6: "Domestic Cow", - 7: "Domestic Dog", - 8: "Fisher", - 9: "Gray Fox", - 10: "Gray Squirrel", - 11: "Human", - 12: "Moose", - 13: "Mouse sp.", - 14: "Opossum", - 15: "Raccoon", - 16: "Red Fox", - 17: "Red Squirrel", - 18: "Skunk", - 19: "Snowshoe Hare", - 20: "White-tailed Deer", - 21: "Wild Boar", - 22: "Wild Turkey", - 23: "no-species" - } + 0: "American Marten", + 1: "Bird sp.", + 2: "Black Bear", + 3: "Bobcat", + 4: "Coyote", + 5: "Domestic Cat", + 6: "Domestic Cow", + 7: "Domestic Dog", + 8: "Fisher", + 9: "Gray Fox", + 10: "Gray Squirrel", + 11: "Human", + 12: "Moose", + 13: "Mouse sp.", + 14: "Opossum", + 15: "Raccoon", + 16: "Red Fox", + 17: "Red Squirrel", + 18: "Skunk", + 19: "Snowshoe Hare", + 20: "White-tailed Deer", + 21: "Wild Boar", + 22: "Wild Turkey", + 23: "no-species", + } def __init__(self, weights=None, device="cpu", transform=None): - url = 'https://prod-is-usgs-sb-prod-publish.s3.amazonaws.com/67ae17fcd34e3f09c0e0f002/dfne_weights_v1_0.pth' - super(DFNE, self).__init__(weights=weights, device=device, url=url, transform=transform, weights_key='model_state_dict') \ No newline at end of file + url = "https://prod-is-usgs-sb-prod-publish.s3.amazonaws.com/67ae17fcd34e3f09c0e0f002/dfne_weights_v1_0.pth" + super(DFNE, self).__init__( + weights=weights, + device=device, + url=url, + transform=transform, + weights_key="model_state_dict", + ) diff --git a/PytorchWildlife/models/classification/timm_base/Deepfaune.py b/PytorchWildlife/models/classification/timm_base/Deepfaune.py index a024729b7..8c931838a 100644 --- a/PytorchWildlife/models/classification/timm_base/Deepfaune.py +++ b/PytorchWildlife/models/classification/timm_base/Deepfaune.py @@ -9,39 +9,186 @@ # Import libraries from torchvision.transforms.functional 
import InterpolationMode -from .base_classifier import TIMM_BaseClassifierInference + from ....data import transforms as pw_trans +from .base_classifier import TIMM_BaseClassifierInference + +__all__ = ["DeepfauneClassifier"] -__all__ = [ - "DeepfauneClassifier" -] class DeepfauneClassifier(TIMM_BaseClassifierInference): """ Base detector class for dinov2 classifier. This class provides utility methods - for loading the model, performing single and batch image classifications, and - formatting results. Make sure the appropriate file for the model weights has been + for loading the model, performing single and batch image classifications, and + formatting results. Make sure the appropriate file for the model weights has been downloaded to the "models" folder before running DFNE. """ + BACKBONE = "vit_large_patch14_dinov2.lvd142m" MODEL_NAME = "deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt" IMAGE_SIZE = 182 - CLASS_NAMES={ - 'fr': ['bison', 'blaireau', 'bouquetin', 'castor', 'cerf', 'chamois', 'chat', 'chevre', 'chevreuil', 'chien', 'daim', 'ecureuil', 'elan', 'equide', 'genette', 'glouton', 'herisson', 'lagomorphe', 'loup', 'loutre', 'lynx', 'marmotte', 'micromammifere', 'mouflon', 'mouton', 'mustelide', 'oiseau', 'ours', 'ragondin', 'raton laveur', 'renard', 'renne', 'sanglier', 'vache'], - 'en': ['bison', 'badger', 'ibex', 'beaver', 'red deer', 'chamois', 'cat', 'goat', 'roe deer', 'dog', 'fallow deer', 'squirrel', 'moose', 'equid', 'genet', 'wolverine', 'hedgehog', 'lagomorph', 'wolf', 'otter', 'lynx', 'marmot', 'micromammal', 'mouflon', 'sheep', 'mustelid', 'bird', 'bear', 'nutria', 'raccoon', 'fox', 'reindeer', 'wild boar', 'cow'], - 'it': ['bisonte', 'tasso', 'stambecco', 'castoro', 'cervo', 'camoscio', 'gatto', 'capra', 'capriolo', 'cane', 'daino', 'scoiattolo', 'alce', 'equide', 'genetta', 'ghiottone', 'riccio', 'lagomorfo', 'lupo', 'lontra', 'lince', 'marmotta', 'micromammifero', 'muflone', 'pecora', 'mustelide', 'uccello', 'orso', 'nutria', 'procione', 
'volpe', 'renna', 'cinghiale', 'mucca'], - 'de': ['Bison', 'Dachs', 'Steinbock', 'Biber', 'Rothirsch', 'Gämse', 'Katze', 'Ziege', 'Rehwild', 'Hund', 'Damwild', 'Eichhörnchen', 'Elch', 'Equide', 'Ginsterkatze', 'Vielfraß', 'Igel', 'Lagomorpha', 'Wolf', 'Otter', 'Luchs', 'Murmeltier', 'Kleinsäuger', 'Mufflon', 'Schaf', 'Marder', 'Vogel', 'Bär', 'Nutria', 'Waschbär', 'Fuchs', 'Rentier', 'Wildschwein', 'Kuh'], + CLASS_NAMES = { + "fr": [ + "bison", + "blaireau", + "bouquetin", + "castor", + "cerf", + "chamois", + "chat", + "chevre", + "chevreuil", + "chien", + "daim", + "ecureuil", + "elan", + "equide", + "genette", + "glouton", + "herisson", + "lagomorphe", + "loup", + "loutre", + "lynx", + "marmotte", + "micromammifere", + "mouflon", + "mouton", + "mustelide", + "oiseau", + "ours", + "ragondin", + "raton laveur", + "renard", + "renne", + "sanglier", + "vache", + ], + "en": [ + "bison", + "badger", + "ibex", + "beaver", + "red deer", + "chamois", + "cat", + "goat", + "roe deer", + "dog", + "fallow deer", + "squirrel", + "moose", + "equid", + "genet", + "wolverine", + "hedgehog", + "lagomorph", + "wolf", + "otter", + "lynx", + "marmot", + "micromammal", + "mouflon", + "sheep", + "mustelid", + "bird", + "bear", + "nutria", + "raccoon", + "fox", + "reindeer", + "wild boar", + "cow", + ], + "it": [ + "bisonte", + "tasso", + "stambecco", + "castoro", + "cervo", + "camoscio", + "gatto", + "capra", + "capriolo", + "cane", + "daino", + "scoiattolo", + "alce", + "equide", + "genetta", + "ghiottone", + "riccio", + "lagomorfo", + "lupo", + "lontra", + "lince", + "marmotta", + "micromammifero", + "muflone", + "pecora", + "mustelide", + "uccello", + "orso", + "nutria", + "procione", + "volpe", + "renna", + "cinghiale", + "mucca", + ], + "de": [ + "Bison", + "Dachs", + "Steinbock", + "Biber", + "Rothirsch", + "Gämse", + "Katze", + "Ziege", + "Rehwild", + "Hund", + "Damwild", + "Eichhörnchen", + "Elch", + "Equide", + "Ginsterkatze", + "Vielfraß", + "Igel", + "Lagomorpha", + "Wolf", 
+ "Otter", + "Luchs", + "Murmeltier", + "Kleinsäuger", + "Mufflon", + "Schaf", + "Marder", + "Vogel", + "Bär", + "Nutria", + "Waschbär", + "Fuchs", + "Rentier", + "Wildschwein", + "Kuh", + ], } - - def __init__(self, weights=None, device="cpu", transform=None, class_name_lang='en'): - url = 'https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt' + def __init__(self, weights=None, device="cpu", transform=None, class_name_lang="en"): + url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt" self.CLASS_NAMES = {i: c for i, c in enumerate(self.CLASS_NAMES[class_name_lang])} if transform is None: - transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE, - interpolation=InterpolationMode.BICUBIC, - max_size=None, - antialias=None) - super(DeepfauneClassifier, self).__init__(weights=weights, device=device, url=url, transform=transform, - weights_key='state_dict', weights_prefix='base_model.') - \ No newline at end of file + transform = pw_trans.Classification_Inference_Transform( + target_size=self.IMAGE_SIZE, + interpolation=InterpolationMode.BICUBIC, + max_size=None, + antialias=None, + ) + super(DeepfauneClassifier, self).__init__( + weights=weights, + device=device, + url=url, + transform=transform, + weights_key="state_dict", + weights_prefix="base_model.", + ) diff --git a/PytorchWildlife/models/classification/timm_base/__init__.py b/PytorchWildlife/models/classification/timm_base/__init__.py index 1938f933a..806a29c9f 100644 --- a/PytorchWildlife/models/classification/timm_base/__init__.py +++ b/PytorchWildlife/models/classification/timm_base/__init__.py @@ -1,3 +1,3 @@ from .base_classifier import * from .Deepfaune import * -from .DFNE import * \ No newline at end of file +from .DFNE import * diff --git a/PytorchWildlife/models/classification/timm_base/base_classifier.py 
b/PytorchWildlife/models/classification/timm_base/base_classifier.py index 682333b17..452faea39 100644 --- a/PytorchWildlife/models/classification/timm_base/base_classifier.py +++ b/PytorchWildlife/models/classification/timm_base/base_classifier.py @@ -3,28 +3,27 @@ # Import libraries import os -import wget -import numpy as np -import pandas as pd -from tqdm import tqdm -from PIL import Image from collections import OrderedDict +import numpy as np +import pandas as pd +import timm import torch +import wget +from PIL import Image from torch.utils.data import DataLoader +from tqdm import tqdm -import timm - -from ..base_classifier import BaseClassifierInference +from ....data import datasets as pw_data from ....data import transforms as pw_trans -from ....data import datasets as pw_data +from ..base_classifier import BaseClassifierInference class TIMM_BaseClassifierInference(BaseClassifierInference): """ Base detector class for dinov2 classifier. This class provides utility methods - for loading the model, performing single and batch image classifications, and - formatting results. Make sure the appropriate file for the model weights has been + for loading the model, performing single and batch image classifications, and + formatting results. Make sure the appropriate file for the model weights has been downloaded to the "models" folder before running DFNE. """ @@ -32,21 +31,28 @@ class TIMM_BaseClassifierInference(BaseClassifierInference): MODEL_NAME = None IMAGE_SIZE = None - def __init__(self, weights=None, device="cpu", url=None, transform=None, - weights_key='model_state_dict', weights_prefix=''): + def __init__( + self, + weights=None, + device="cpu", + url=None, + transform=None, + weights_key="model_state_dict", + weights_prefix="", + ): """ Initialize the model. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. 
Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. - weights_key (str, optional): + weights_key (str, optional): Key to fetch the model weights. Defaults to None. - weights_prefix (str, optional): + weights_prefix (str, optional): prefix of model weight keys. Defaults to None. """ super(TIMM_BaseClassifierInference, self).__init__() @@ -55,30 +61,36 @@ def __init__(self, weights=None, device="cpu", url=None, transform=None, if transform: self.transform = transform else: - self.transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE) + self.transform = pw_trans.Classification_Inference_Transform( + target_size=self.IMAGE_SIZE + ) self._load_model(weights, url, weights_key, weights_prefix) - def _load_model(self, weights=None, url=None, weights_key='model_state_dict', weights_prefix=''): + def _load_model( + self, weights=None, url=None, weights_key="model_state_dict", weights_prefix="" + ): """ Load TIMM based model weights - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. (defaults to None) - url (str, optional): + url (str, optional): url to the model weights. 
(defaults to None) """ self.predictor = timm.create_model( - self.BACKBONE, - pretrained = False, - num_classes = len(self.CLASS_NAMES), - dynamic_img_size = True + self.BACKBONE, + pretrained=False, + num_classes=len(self.CLASS_NAMES), + dynamic_img_size=True, ) if url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) else: @@ -86,22 +98,27 @@ def _load_model(self, weights=None, url=None, weights_key='model_state_dict', we elif weights is None: raise Exception("Need weights for inference.") - checkpoint = torch.load( - f = weights, - map_location = self.device, - weights_only = False - )[weights_key] + checkpoint = torch.load(f=weights, map_location=self.device, weights_only=False)[ + weights_key + ] - checkpoint = OrderedDict({k.replace("{}".format(weights_prefix), ""): checkpoint[k] - for k in checkpoint}) + checkpoint = OrderedDict( + {k.replace("{}".format(weights_prefix), ""): checkpoint[k] for k in checkpoint} + ) self.predictor.load_state_dict(checkpoint) - print("Model loaded from {}".format(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME))) + print( + "Model loaded from {}".format( + os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) + ) + ) self.predictor.to(self.device) self.eval() - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: + def results_generation( + self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None + ) -> list[dict]: """ Generate results for classification. 
@@ -113,7 +130,7 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: Returns: list[dict]: List of dictionaries containing image ID, prediction, and confidence score. """ - + probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] @@ -128,17 +145,17 @@ def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: r["confidence"] = conf.item() r["all_confidences"] = result results.append(r) - + return results def single_image_classification(self, img, img_id=None, id_strip=None): """ Perform classification on a single image. - + Args: - img (str or ndarray): + img (str or ndarray): Image path or ndarray of images. - img_id (str, optional): + img_id (str, optional): Image path or identifier. id_strip (str, optional): Whether to strip stings in id. Defaults to None. @@ -155,14 +172,21 @@ def single_image_classification(self, img, img_id=None, id_strip=None): logits = self.predictor(img.unsqueeze(0).to(self.device)) return self.results_generation(logits.cpu(), [img_id], id_strip=id_strip)[0] - def batch_image_classification(self, data_path=None, det_results=None, id_strip=None, - batch_size=32, num_workers=0, **kwargs): + def batch_image_classification( + self, + data_path=None, + det_results=None, + id_strip=None, + batch_size=32, + num_workers=0, + **kwargs + ): """ Perform classification on a batch of images. - + Args: - data_path (str): - Path containing all images for inference. Defaults to None. + data_path (str): + Path containing all images for inference. Defaults to None. det_results (dict): Dirct outputs from detectors. Defaults to None. id_strip (str, optional): @@ -177,27 +201,26 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip= """ if data_path: - dataset = pw_data.ImageFolder( - data_path, - transform=self.transform, - path_head='.' 
- ) + dataset = pw_data.ImageFolder(data_path, transform=self.transform, path_head=".") elif det_results: - dataset = pw_data.DetectionCrops( - det_results, - transform=self.transform, - path_head='.' - ) + dataset = pw_data.DetectionCrops(det_results, transform=self.transform, path_head=".") else: raise Exception("Need data for inference.") - dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, - shuffle=False, pin_memory=True, drop_last=False, **kwargs) - + dataloader = DataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True, + drop_last=False, + **kwargs + ) + total_logits = [] total_paths = [] - with tqdm(total=len(dataloader)) as pbar: + with tqdm(total=len(dataloader)) as pbar: for batch in dataloader: imgs, paths = batch imgs = imgs.to(self.device) diff --git a/PytorchWildlife/models/detection/__init__.py b/PytorchWildlife/models/detection/__init__.py index 7d42c184d..63b2c1840 100644 --- a/PytorchWildlife/models/detection/__init__.py +++ b/PytorchWildlife/models/detection/__init__.py @@ -1,4 +1,4 @@ -from .ultralytics_based import * from .herdnet import * +from .rtdetr_apache import * +from .ultralytics_based import * from .yolo_mit import * -from .rtdetr_apache import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/base_detector.py b/PytorchWildlife/models/detection/base_detector.py index 00cbfc5cc..63300e1c3 100644 --- a/PytorchWildlife/models/detection/base_detector.py +++ b/PytorchWildlife/models/detection/base_detector.py @@ -6,12 +6,13 @@ # Importing basic libraries from torch import nn + class BaseDetector(nn.Module): """ Base detector class. This class provides utility methods for loading the model, generating results, and performing single and batch image detections. 
""" - + # Placeholder class-level attributes to be defined in derived classes IMAGE_SIZE = None STRIDE = None @@ -21,29 +22,28 @@ class BaseDetector(nn.Module): def __init__(self, weights=None, device="cpu", url=None): """ Initialize the base detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. """ super(BaseDetector, self).__init__() self.device = device - def _load_model(self, weights=None, device="cpu", url=None): """ Load model weights. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. @@ -64,20 +64,22 @@ def results_generation(self, preds, img_id: str, id_strip: str = None) -> dict: """ pass - def single_image_detection(self, img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None) -> dict: + def single_image_detection( + self, img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None + ) -> dict: """ Perform detection on a single image. - + Args: - img (str or ndarray): + img (str or ndarray): Image path or ndarray of images. - img_size (tuple): + img_size (tuple): Original image size. - img_path (str): + img_path (str): Image path or identifier. - conf_thres (float, optional): + conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. 
Returns: @@ -85,7 +87,9 @@ def single_image_detection(self, img, img_size=None, img_path=None, conf_thres=0 """ pass - def batch_image_detection(self, dataloader, conf_thres: float = 0.2, id_strip: str = None) -> list[dict]: + def batch_image_detection( + self, dataloader, conf_thres: float = 0.2, id_strip: str = None + ) -> list[dict]: """ Perform detection on a batch of images. diff --git a/PytorchWildlife/models/detection/herdnet/__init__.py b/PytorchWildlife/models/detection/herdnet/__init__.py index b98d2f565..bf695c132 100644 --- a/PytorchWildlife/models/detection/herdnet/__init__.py +++ b/PytorchWildlife/models/detection/herdnet/__init__.py @@ -1 +1 @@ -from .herdnet import * \ No newline at end of file +from .herdnet import * diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py b/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py index 74c0b9052..34d34cd52 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py +++ b/PytorchWildlife/models/detection/herdnet/animaloc/__init__.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py b/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py index 649204b5d..e041cb081 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py +++ b/PytorchWildlife/models/detection/herdnet/animaloc/data/__init__.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. 
diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py b/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py index 2eb05ba12..5c41bd96c 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py +++ b/PytorchWildlife/models/detection/herdnet/animaloc/data/patches.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. @@ -15,41 +14,39 @@ import os +from typing import Tuple, Union + +import matplotlib.pyplot as plt +import numpy +import pandas import PIL import torch -import pandas -import numpy -import matplotlib.pyplot as plt import torchvision from torchvision.utils import make_grid, save_image - -from typing import Union, Tuple - from tqdm import tqdm from .types import BoundingBox -__all__ = ['ImageToPatches'] +__all__ = ["ImageToPatches"] + class ImageToPatches: - ''' Class to make patches from a tensor image ''' + """Class to make patches from a tensor image""" def __init__( - self, - image: Union[PIL.Image.Image, torch.Tensor], - size: Tuple[int,int], - overlap: int = 0 - ) -> None: - ''' + self, image: Union[PIL.Image.Image, torch.Tensor], size: Tuple[int, int], overlap: int = 0 + ) -> None: + """ Args: image (PIL.Image.Image or torch.Tensor): image, if tensor: (C,H,W) size (tuple): patches size (height, width), in pixels - overlap (int, optional): overlap between patches, in pixels. + overlap (int, optional): overlap between patches, in pixels. Defaults to 0. 
- ''' + """ - assert isinstance(image, (PIL.Image.Image, torch.Tensor)), \ - 'image must be a PIL.Image.Image or a torch.Tensor instance' + assert isinstance( + image, (PIL.Image.Image, torch.Tensor) + ), "image must be a PIL.Image.Image or a torch.Tensor instance" self.image = image if isinstance(self.image, PIL.Image.Image): @@ -57,32 +54,34 @@ def __init__( self.size = size self.overlap = overlap - + def make_patches(self) -> torch.Tensor: - ''' Make patches from the image + """Make patches from the image - When the image division is not perfect, a zero-padding is performed + When the image division is not perfect, a zero-padding is performed so that the patches have the same size. Returns: torch.Tensor: patches of shape (B,C,H,W) - ''' + """ # patches' height & width - height = min(self.image.size(1),self.size[0]) - width = min(self.image.size(2),self.size[1]) + height = min(self.image.size(1), self.size[0]) + width = min(self.image.size(2), self.size[1]) - # unfold on height + # unfold on height height_fold = self.image.unfold(1, height, height - self.overlap) # if non-perfect division on height residual = self._img_residual(self.image.size(1), height, self.overlap) if residual != 0: # get the residual patch and add it to the fold - remaining_height = torch.zeros(3, 1, self.image.size(2), height) # padding - remaining_height[:,:,:,:residual] = self.image[:,-residual:,:].permute(0,2,1).unsqueeze(1) + remaining_height = torch.zeros(3, 1, self.image.size(2), height) # padding + remaining_height[:, :, :, :residual] = ( + self.image[:, -residual:, :].permute(0, 2, 1).unsqueeze(1) + ) - height_fold = torch.cat((height_fold,remaining_height),dim=1) + height_fold = torch.cat((height_fold, remaining_height), dim=1) # unfold on width fold = height_fold.unfold(2, width, width - self.overlap) @@ -90,20 +89,22 @@ def make_patches(self) -> torch.Tensor: # if non-perfect division on width, the same residual = self._img_residual(self.image.size(2), width, self.overlap) if 
residual != 0: - remaining_width = torch.zeros(3, fold.shape[1], 1, height, width) # padding - remaining_width[:,:,:,:,:residual] = height_fold[:,:,-residual:,:].permute(0,1,3,2).unsqueeze(2) + remaining_width = torch.zeros(3, fold.shape[1], 1, height, width) # padding + remaining_width[:, :, :, :, :residual] = ( + height_fold[:, :, -residual:, :].permute(0, 1, 3, 2).unsqueeze(2) + ) - fold = torch.cat((fold,remaining_width),dim=2) + fold = torch.cat((fold, remaining_width), dim=2) - self._nrow , self._ncol = fold.shape[2] , fold.shape[1] + self._nrow, self._ncol = fold.shape[2], fold.shape[1] # reshaping - patches = fold.permute(1,2,0,3,4).reshape(-1,self.image.size(0),height,width) + patches = fold.permute(1, 2, 0, 3, 4).reshape(-1, self.image.size(0), height, width) return patches - + def get_limits(self) -> dict: - ''' Get patches limits within the image frame + """Get patches limits within the image frame When the image division is not perfect, the zero-padding is not considered here. 
Hence, the limits are the true limits of patches @@ -112,69 +113,64 @@ def get_limits(self) -> dict: Returns: dict: a dict containing int as key and BoundingBox as value - ''' + """ # patches' height & width - height = min(self.image.size(1),self.size[0]) - width = min(self.image.size(2),self.size[1]) + height = min(self.image.size(1), self.size[0]) + width = min(self.image.size(2), self.size[1]) # lists of pixels numbers - y_pixels = torch.tensor(list(range(0,self.image.size(1)+1))) - x_pixels = torch.tensor(list(range(0,self.image.size(2)+1))) + y_pixels = torch.tensor(list(range(0, self.image.size(1) + 1))) + x_pixels = torch.tensor(list(range(0, self.image.size(2) + 1))) # cut into patches to get limits - y_pixels_fold = y_pixels.unfold(0, height+1, height-self.overlap) + y_pixels_fold = y_pixels.unfold(0, height + 1, height - self.overlap) y_mina = [int(patch[0]) for patch in y_pixels_fold] y_maxa = [int(patch[-1]) for patch in y_pixels_fold] - x_pixels_fold = x_pixels.unfold(0, width+1, width-self.overlap) + x_pixels_fold = x_pixels.unfold(0, width + 1, width - self.overlap) x_mina = [int(patch[0]) for patch in x_pixels_fold] x_maxa = [int(patch[-1]) for patch in x_pixels_fold] # if non-perfect division on height residual = self._img_residual(self.image.size(1), height, self.overlap) if residual != 0: - remaining_y = y_pixels[-residual-1:].unsqueeze(0)[0] + remaining_y = y_pixels[-residual - 1 :].unsqueeze(0)[0] y_mina.append(int(remaining_y[0])) y_maxa.append(int(remaining_y[-1])) - # if non-perfect division on width + # if non-perfect division on width residual = self._img_residual(self.image.size(2), width, self.overlap) if residual != 0: - remaining_x = x_pixels[-residual-1:].unsqueeze(0)[0] + remaining_x = x_pixels[-residual - 1 :].unsqueeze(0)[0] x_mina.append(int(remaining_x[0])) x_maxa.append(int(remaining_x[-1])) - + i = 0 patches_limits = {} - for y_min , y_max in zip(y_mina,y_maxa): - for x_min , x_max in zip(x_mina,x_maxa): - patches_limits[i] = 
BoundingBox(x_min,y_min,x_max,y_max) + for y_min, y_max in zip(y_mina, y_maxa): + for x_min, x_max in zip(x_mina, x_maxa): + patches_limits[i] = BoundingBox(x_min, y_min, x_max, y_max) i += 1 - + return patches_limits - + def show(self) -> None: - ''' Show the grid of patches ''' + """Show the grid of patches""" - grid = make_grid( - self.make_patches(), - padding=50, - nrow=self._nrow - ).permute(1,2,0).numpy() + grid = make_grid(self.make_patches(), padding=50, nrow=self._nrow).permute(1, 2, 0).numpy() plt.imshow(grid) plt.show() return grid - - def _img_residual(self, ims: int, ks: int, overlap: int) -> int: + def _img_residual(self, ims: int, ks: int, overlap: int) -> int: ims, stride = int(ims), int(ks - overlap) n = ims // stride end = n * stride + overlap - + residual = ims % stride if end > ims: @@ -182,6 +178,6 @@ def _img_residual(self, ims: int, ks: int, overlap: int) -> int: residual = ims - (n * stride) return residual - + def __len__(self) -> int: return len(self.get_limits()) diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py b/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py index bd650ac7b..a5e0ad96d 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py +++ b/PytorchWildlife/models/detection/herdnet/animaloc/data/types.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. @@ -13,116 +12,117 @@ __license__ = "MIT License" __version__ = "0.2.1" -from typing import Union, Tuple +from typing import Tuple, Union + +__all__ = ["Point", "BoundingBox"] -__all__ = ['Point', 'BoundingBox'] class Point: - ''' Class to define a Point object in a 2D Cartesian + """Class to define a Point object in a 2D Cartesian coordinate system. 
- ''' + """ - def __init__(self, x: Union[int,float], y: Union[int,float]) -> None: - ''' + def __init__(self, x: Union[int, float], y: Union[int, float]) -> None: + """ Args: x (int, float): x coordinate y (int, float): y coordinate - ''' + """ - assert x >= 0 and y >= 0, f'Coordinates must be positives, got x={x} and y={y}' + assert x >= 0 and y >= 0, f"Coordinates must be positives, got x={x} and y={y}" self.x = x self.y = y - self.area = 1 # always 1 pixel - + self.area = 1 # always 1 pixel + # @property # def area(self) -> int: # ''' To get area ''' # return 1 # always 1 pixel - + @property - def get_tuple(self) -> Tuple[Union[int,float],Union[int,float]]: - ''' To get point's coordinates in tuple ''' - return (self.x,self.y) + def get_tuple(self) -> Tuple[Union[int, float], Union[int, float]]: + """To get point's coordinates in tuple""" + return (self.x, self.y) @property def atype(self) -> str: - ''' To get annotation type string ''' - return 'Point' - + """To get annotation type string""" + return "Point" + def __repr__(self) -> str: - return f'Point(x: {self.x}, y: {self.y})' - + return f"Point(x: {self.x}, y: {self.y})" + def __eq__(self, other) -> bool: - return all([ - self.x == other.x, - self.y == other.y - ]) + return all([self.x == other.x, self.y == other.y]) + class BoundingBox: - ''' Class to define a BoundingBox object in a 2D Cartesian + """Class to define a BoundingBox object in a 2D Cartesian coordinate system. 
- ''' + """ def __init__( - self, - x_min: Union[int,float], - y_min: Union[int,float], - x_max: Union[int,float], - y_max: Union[int,float] - ) -> None: - ''' + self, + x_min: Union[int, float], + y_min: Union[int, float], + x_max: Union[int, float], + y_max: Union[int, float], + ) -> None: + """ Args: x_min (int, float): x bbox top-left coordinate y_min (int, float): y bbox top-left coordinate x_max (int, float): x bbox bottom-right coordinate y_max (int, float): y bbox bottom-right coordinate - ''' + """ - assert all([c >= 0 for c in [x_min,y_min,x_max,y_max]]), \ - f'Coordinates must be positives, got x_min={x_min}, y_min={y_min}, ' \ - f'x_max={x_max} and y_max={y_max}' + assert all([c >= 0 for c in [x_min, y_min, x_max, y_max]]), ( + f"Coordinates must be positives, got x_min={x_min}, y_min={y_min}, " + f"x_max={x_max} and y_max={y_max}" + ) - assert x_max >= x_min and y_max >= y_min, \ - 'Wrong bounding box coordinates.' + assert x_max >= x_min and y_max >= y_min, "Wrong bounding box coordinates." 
self.x_min = x_min self.y_min = y_min self.x_max = x_max self.y_max = y_max - + @property - def area(self) -> Union[int,float]: - ''' To get bbox area ''' + def area(self) -> Union[int, float]: + """To get bbox area""" return max(0, self.width) * max(0, self.height) - + @property - def width(self) -> Union[int,float]: - ''' To get bbox width ''' + def width(self) -> Union[int, float]: + """To get bbox width""" return max(0, self.x_max - self.x_min) - + @property - def height(self) -> Union[int,float]: - ''' To get bbox height ''' + def height(self) -> Union[int, float]: + """To get bbox height""" return max(0, self.y_max - self.y_min) - + @property - def get_tuple(self) -> Tuple[Union[int,float],...]: - ''' To get bbox coordinates in tuple type ''' - return (self.x_min,self.y_min,self.x_max,self.y_max) - + def get_tuple(self) -> Tuple[Union[int, float], ...]: + """To get bbox coordinates in tuple type""" + return (self.x_min, self.y_min, self.x_max, self.y_max) + @property def atype(self) -> str: - ''' To get annotation type string ''' - return 'BoundingBox' + """To get annotation type string""" + return "BoundingBox" def __repr__(self) -> str: - return f'BoundingBox(x_min: {self.x_min}, y_min: {self.y_min}, x_max: {self.x_max}, y_max: {self.y_max})' + return f"BoundingBox(x_min: {self.x_min}, y_min: {self.y_min}, x_max: {self.x_max}, y_max: {self.y_max})" def __eq__(self, other) -> bool: - return all([ - self.x_min == other.x_min, - self.y_min == other.y_min, - self.x_max == other.x_max, - self.y_max == other.y_max - ]) \ No newline at end of file + return all( + [ + self.x_min == other.x_min, + self.y_min == other.y_min, + self.x_max == other.x_max, + self.y_max == other.y_max, + ] + ) diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py b/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py index cb7dc2dc4..bd5ac7932 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py +++ 
b/PytorchWildlife/models/detection/herdnet/animaloc/eval/__init__.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. @@ -13,5 +12,5 @@ __license__ = "MIT License" __version__ = "0.2.1" -from .stitchers import * from .lmds import * +from .stitchers import * diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py b/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py index cda1fe9fd..1e3d9d458 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py +++ b/PytorchWildlife/models/detection/herdnet/animaloc/eval/lmds.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. @@ -13,56 +12,52 @@ __license__ = "MIT License" __version__ = "0.2.1" -import torch -import numpy +from typing import List, Tuple +import numpy +import torch import torch.nn.functional as F -from typing import Tuple, List - -__all__ = ['LMDS', 'HerdNetLMDS'] +__all__ = ["LMDS", "HerdNetLMDS"] class LMDS: - ''' Local Maxima Detection Strategy + """Local Maxima Detection Strategy Adapted and enhanced from https://github.com/dk-liang/FIDTM (author: dklinag) - available under the MIT license ''' + available under the MIT license""" def __init__( - self, - kernel_size: tuple = (3,3), - adapt_ts: float = 100.0/255.0, - neg_ts: float = 0.1 - ) -> None: - ''' + self, kernel_size: tuple = (3, 3), adapt_ts: float = 100.0 / 255.0, neg_ts: float = 0.1 + ) -> None: + """ Args: kernel_size (tuple, optional): size of the kernel used to select local maxima. Defaults to (3,3) (as in the paper). adapt_ts (float, optional): adaptive threshold to select final points from candidates. Defaults to 100.0/255.0 (as in the paper). 
- neg_ts (float, optional): negative sample threshold used to define if + neg_ts (float, optional): negative sample threshold used to define if an image is a negative sample or not. Defaults to 0.1 (as in the paper). - ''' + """ - assert kernel_size[0] == kernel_size[1], \ - f'The kernel shape must be a square, got {kernel_size[0]}x{kernel_size[1]}' - assert not kernel_size[0] % 2 == 0, \ - f'The kernel size must be odd, got {kernel_size[0]}' + assert ( + kernel_size[0] == kernel_size[1] + ), f"The kernel shape must be a square, got {kernel_size[0]}x{kernel_size[1]}" + assert not kernel_size[0] % 2 == 0, f"The kernel size must be odd, got {kernel_size[0]}" self.kernel_size = tuple(kernel_size) self.adapt_ts = adapt_ts self.neg_ts = neg_ts - def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]: - ''' + def __call__(self, est_map: torch.Tensor) -> Tuple[list, list, list, list]: + """ Args: est_map (torch.Tensor): the estimated FIDT map - + Returns: Tuple[list,list,list,list] counts, labels, scores and locations per batch - ''' + """ batch_size, classes = est_map.shape[:2] b_counts, b_labels, b_scores, b_locs = [], [], [], [] @@ -72,7 +67,7 @@ def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]: for c in range(classes): count, loc, score = self._lmds(est_map[b][c]) counts.append(count) - labels = [*labels, *[c+1]*count] + labels = [*labels, *[c + 1] * count] scores = [*scores, *score] locs = [*locs, *loc] @@ -82,36 +77,36 @@ def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]: b_locs.append(locs) return b_counts, b_locs, b_labels, b_scores - + def _local_max(self, est_map: torch.Tensor) -> torch.Tensor: - ''' Shape: est_map = [B,C,H,W] ''' + """Shape: est_map = [B,C,H,W]""" pad = int(self.kernel_size[0] / 2) - keep = torch.nn.functional.max_pool2d(est_map, kernel_size=self.kernel_size, stride=1, padding=pad) + keep = torch.nn.functional.max_pool2d( + est_map, kernel_size=self.kernel_size, stride=1, padding=pad 
+ ) keep = (keep == est_map).float() est_map = keep * est_map return est_map - + def _get_locs_and_scores( - self, - locs_map: torch.Tensor, - scores_map: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - ''' Shapes: locs_map = [H,W] and scores_map = [H,W] ''' + self, locs_map: torch.Tensor, scores_map: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Shapes: locs_map = [H,W] and scores_map = [H,W]""" locs_map = locs_map.data.cpu().numpy() scores_map = scores_map.data.cpu().numpy() locs = [] scores = [] - for i, j in numpy.argwhere(locs_map == 1): - locs.append((i,j)) + for i, j in numpy.argwhere(locs_map == 1): + locs.append((i, j)) scores.append(scores_map[i][j]) - + return torch.Tensor(locs), torch.Tensor(scores) - + def _lmds(self, est_map: torch.Tensor) -> Tuple[int, list, list]: - ''' Shape: est_map = [H,W] ''' + """Shape: est_map = [H,W]""" est_map_max = torch.max(est_map).item() @@ -132,22 +127,21 @@ def _lmds(self, est_map: torch.Tensor) -> Tuple[int, list, list]: # locations and scores locs, scores = self._get_locs_and_scores( - est_map.squeeze(0).squeeze(0), - scores_map.squeeze(0).squeeze(0) - ) + est_map.squeeze(0).squeeze(0), scores_map.squeeze(0).squeeze(0) + ) return count, locs.tolist(), scores.tolist() -class HerdNetLMDS(LMDS): +class HerdNetLMDS(LMDS): def __init__( - self, - up: bool = True, - kernel_size: tuple = (3,3), - adapt_ts: float = 0.3, - neg_ts: float = 0.1 - ) -> None: - ''' + self, + up: bool = True, + kernel_size: tuple = (3, 3), + adapt_ts: float = 0.3, + neg_ts: float = 0.1, + ) -> None: + """ Args: up (bool, optional): set to False to disable class maps upsampling. Defaults to True. @@ -155,14 +149,14 @@ def __init__( maxima. Defaults to (3,3) (as in the paper). adapt_ts (float, optional): adaptive threshold to select final points from candidates. Defaults to 0.3. 
- neg_ts (float, optional): negative sample threshold used to define if + neg_ts (float, optional): negative sample threshold used to define if an image is a negative sample or not. Defaults to 0.1 (as in the paper). - ''' + """ super().__init__(kernel_size=kernel_size, adapt_ts=adapt_ts, neg_ts=neg_ts) self.up = up - + def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, list]: """ Args: @@ -176,14 +170,14 @@ def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, """ heatmap, clsmap = outputs - + # upsample class map if self.up: scale_factor = 16 - clsmap = F.interpolate(clsmap, scale_factor=scale_factor, mode='nearest') + clsmap = F.interpolate(clsmap, scale_factor=scale_factor, mode="nearest") # softmax - cls_scores = torch.softmax(clsmap, dim=1)[:,1:,:,:] + cls_scores = torch.softmax(clsmap, dim=1)[:, 1:, :, :] # cat to heatmap outmaps = torch.cat([heatmap, cls_scores], dim=1) @@ -193,10 +187,9 @@ def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, b_counts, b_labels, b_scores, b_locs, b_dscores = [], [], [], [], [] for b in range(batch_size): - _, locs, _ = self._lmds(heatmap[b][0]) - cls_idx = torch.argmax(clsmap[b,1:,:,:], dim=0) + cls_idx = torch.argmax(clsmap[b, 1:, :, :], dim=0) classes = torch.add(cls_idx, 1) h_idx = torch.Tensor([l[0] for l in locs]).long() @@ -216,4 +209,4 @@ def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, b_counts.append(counts) b_dscores.append(dscores) - return b_counts, b_locs, b_labels, b_scores, b_dscores \ No newline at end of file + return b_counts, b_locs, b_labels, b_scores, b_dscores diff --git a/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py b/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py index f22ddcd17..bda49b8c1 100644 --- a/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py +++ b/PytorchWildlife/models/detection/herdnet/animaloc/eval/stitchers.py 
@@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. @@ -14,20 +13,20 @@ __version__ = "0.2.1" -import torch -import torchvision +from typing import List, Tuple -import torch.nn.functional as F import numpy as np - -from typing import List, Tuple -from torch.utils.data import TensorDataset, DataLoader, SequentialSampler +import torch +import torch.nn.functional as F +import torchvision +from torch.utils.data import DataLoader, SequentialSampler, TensorDataset from ..data import ImageToPatches + class Stitcher(ImageToPatches): - ''' Class to stitch detections of patches into original image - coordinates system + """Class to stitch detections of patches into original image + coordinates system This algorithm works as follow: 1) Cut original image into patches @@ -35,43 +34,44 @@ class Stitcher(ImageToPatches): 3) Patch the detections maps into the coordinate system of the original image Optional: 4) Upsample the patched detection map - ''' + """ def __init__( self, - model: torch.nn.Module, - size: Tuple[int,int], + model: torch.nn.Module, + size: Tuple[int, int], overlap: int = 100, batch_size: int = 1, down_ratio: int = 1, up: bool = False, - reduction: str = 'sum', - device_name: str = 'cuda', - ) -> None: - ''' + reduction: str = "sum", + device_name: str = "cuda", + ) -> None: + """ Args: model (torch.nn.Module): CNN detection model, that takes as inputs image and returns output and dict (i.e. wrapped by LossWrapper) size (tuple): patches size (height, width), in pixels - overlap (int, optional): overlap between patches, in pixels. - Defaults to 100. - batch_size (int, optional): batch size used for inference over patches. + overlap (int, optional): overlap between patches, in pixels. + Defaults to 100. + batch_size (int, optional): batch size used for inference over patches. Defaults to 1. - down_ratio (int, optional): downsample ratio. 
Set to 1 to get output of the same + down_ratio (int, optional): downsample ratio. Set to 1 to get output of the same size as input (i.e. no downsample). Defaults to 1. up (bool, optional): set to True to upsample the patched map. Defaults to False. reduction (str, optional): specifies the reduction to apply on overlapping areas. Possible values are 'sum', 'mean', 'max'. Defaults to 'sum'. - device_name (str, optional): the device name on which tensors will be allocated + device_name (str, optional): the device name on which tensors will be allocated ('cpu' or 'cuda'). Defaults to 'cuda'. - ''' + """ - assert isinstance(model, torch.nn.Module), \ - 'model argument must be an instance of nn.Module()' - - assert reduction in ['sum', 'mean', 'max'], \ - 'reduction argument possible values are \'sum\', \'mean\' and \'max\' ' \ - f'got \'{reduction}\'' + assert isinstance( + model, torch.nn.Module + ), "model argument must be an instance of nn.Module()" + + assert reduction in ["sum", "mean", "max"], ( + "reduction argument possible values are 'sum', 'mean' and 'max' " f"got '{reduction}'" + ) self.model = model self.size = size @@ -84,23 +84,20 @@ def __init__( self.model.to(self.device) - def __call__( - self, - image: torch.Tensor - ) -> torch.Tensor: - ''' Apply the stitching algorithm to the image + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """Apply the stitching algorithm to the image Args: image (torch.Tensor): image of shape [C,H,W] - + Returns: torch.Tensor the detections into the coordinate system of the original image - ''' + """ super(Stitcher, self).__init__(image, self.size, self.overlap) - - self.image = image.to(torch.device('cpu')) + + self.image = image.to(torch.device("cpu")) # step 1 - get patches and limits patches = self.make_patches() @@ -114,23 +111,20 @@ def __call__( # (step 4 - upsample) if self.up: - patched_map = F.interpolate(patched_map, scale_factor=self.down_ratio, - mode='bilinear', align_corners=True) + patched_map = 
F.interpolate( + patched_map, scale_factor=self.down_ratio, mode="bilinear", align_corners=True + ) return patched_map - @torch.no_grad() def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]: - self.model.eval() dataset = TensorDataset(patches) dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - sampler=SequentialSampler(dataset) - ) + dataset, batch_size=self.batch_size, sampler=SequentialSampler(dataset) + ) maps = [] for patch in dataloader: @@ -141,89 +135,84 @@ def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]: return maps def _patch_maps(self, maps: List[torch.Tensor]) -> torch.Tensor: - _, h, w = self.image.shape dh, dw = h // self.down_ratio, w // self.down_ratio kernel_size = np.array(self.size) // self.down_ratio stride = kernel_size - self.overlap // self.down_ratio output_size = ( - self._ncol * kernel_size[0] - ((self._ncol-1) * self.overlap // self.down_ratio), - self._nrow * kernel_size[1] - ((self._nrow-1) * self.overlap // self.down_ratio) - ) + self._ncol * kernel_size[0] - ((self._ncol - 1) * self.overlap // self.down_ratio), + self._nrow * kernel_size[1] - ((self._nrow - 1) * self.overlap // self.down_ratio), + ) maps = torch.cat(maps, dim=0) - if self.reduction == 'max': - out_map = self._max_fold(maps, output_size=output_size, - kernel_size=tuple(kernel_size), stride=tuple(stride)) + if self.reduction == "max": + out_map = self._max_fold( + maps, output_size=output_size, kernel_size=tuple(kernel_size), stride=tuple(stride) + ) else: n_patches = maps.shape[0] - maps = maps.permute(1,2,3,0).contiguous().view(1, -1, n_patches) - out_map = F.fold(maps, output_size=output_size, - kernel_size=tuple(kernel_size), stride=tuple(stride)) + maps = maps.permute(1, 2, 3, 0).contiguous().view(1, -1, n_patches) + out_map = F.fold( + maps, output_size=output_size, kernel_size=tuple(kernel_size), stride=tuple(stride) + ) - out_map = out_map[:,:, 0:dh, 0:dw] + out_map = out_map[:, :, 0:dh, 0:dw] return out_map - - 
def _reduce(self, map: torch.Tensor) -> torch.Tensor: + def _reduce(self, map: torch.Tensor) -> torch.Tensor: dh = self.image.shape[1] // self.down_ratio dw = self.image.shape[2] // self.down_ratio - ones = torch.ones(self.image.shape[0],dh,dw) + ones = torch.ones(self.image.shape[0], dh, dw) - if self.reduction == 'mean': - ones_patches = ImageToPatches(ones, - np.array(self.size)//self.down_ratio, - self.overlap//self.down_ratio - ).make_patches() + if self.reduction == "mean": + ones_patches = ImageToPatches( + ones, np.array(self.size) // self.down_ratio, self.overlap // self.down_ratio + ).make_patches() - ones_patches = [p.unsqueeze(0).unsqueeze(0) for p in ones_patches[:,1,:,:]] + ones_patches = [p.unsqueeze(0).unsqueeze(0) for p in ones_patches[:, 1, :, :]] norm_map = self._patch_maps(ones_patches) - + else: - norm_map = ones[1,:,:] - + norm_map = ones[1, :, :] + return torch.div(map.to(self.device), norm_map.to(self.device)) - - def _max_fold(self, maps: torch.Tensor, output_size: tuple, - kernel_size: tuple, stride: tuple - ) -> torch.Tensor: - + + def _max_fold( + self, maps: torch.Tensor, output_size: tuple, kernel_size: tuple, stride: tuple + ) -> torch.Tensor: output = torch.zeros((1, maps.shape[1], *output_size)) - fn = lambda x: [[i, i+kernel_size[x]] for i in range(0, output_size[x], stride[x])][:-1] + fn = lambda x: [[i, i + kernel_size[x]] for i in range(0, output_size[x], stride[x])][:-1] locs = [[*h, *w] for h in fn(0) for w in fn(1)] for loc, m in zip(locs, maps): patch = torch.zeros(output.shape) - patch[:,:, loc[0]:loc[1], loc[2]:loc[3]] = m + patch[:, :, loc[0] : loc[1], loc[2] : loc[3]] = m output = torch.max(output, patch) return output -class HerdNetStitcher(Stitcher): +class HerdNetStitcher(Stitcher): @torch.no_grad() def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]: - self.model.eval() dataset = TensorDataset(patches) dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - sampler=SequentialSampler(dataset) 
- ) + dataset, batch_size=self.batch_size, sampler=SequentialSampler(dataset) + ) maps = [] for patch in dataloader: patch = patch[0].to(self.device) - #outputs = self.model(patch)[0] - outputs = self.model(patch) # LossWrapper is not used + # outputs = self.model(patch)[0] + outputs = self.model(patch) # LossWrapper is not used heatmap = outputs[0] scale_factor = 16 - clsmap = F.interpolate(outputs[1], scale_factor=scale_factor, mode='nearest') + clsmap = F.interpolate(outputs[1], scale_factor=scale_factor, mode="nearest") # cat outmaps = torch.cat([heatmap, clsmap], dim=1) maps = [*maps, *outmaps.unsqueeze(0)] diff --git a/PytorchWildlife/models/detection/herdnet/dla.py b/PytorchWildlife/models/detection/herdnet/dla.py index dd87cccff..8b08559ce 100644 --- a/PytorchWildlife/models/detection/herdnet/dla.py +++ b/PytorchWildlife/models/detection/herdnet/dla.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ MIT License Copyright (c) 2019 Xingyi Zhou @@ -23,35 +22,40 @@ from os.path import join from posixpath import basename +import numpy as np import torch -from torch import nn import torch.utils.model_zoo as model_zoo - -import numpy as np +from torch import nn BatchNorm = nn.BatchNorm2d -def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'): - return join('http://dl.yf.io/dla/models', data, '{}-{}.pth'.format(name, hash)) + +def get_model_url(data="imagenet", name="dla34", hash="ba72cf86"): + return join("http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash)) def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): def __init__(self, inplanes, planes, stride=1, dilation=1): super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, - stride=stride, 
padding=dilation, - bias=False, dilation=dilation) + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) self.bn1 = BatchNorm(planes) self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=1, padding=dilation, - bias=False, dilation=dilation) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation + ) self.bn2 = BatchNorm(planes) self.stride = stride @@ -79,15 +83,19 @@ def __init__(self, inplanes, planes, stride=1, dilation=1): super(Bottleneck, self).__init__() expansion = Bottleneck.expansion bottle_planes = planes // expansion - self.conv1 = nn.Conv2d(inplanes, bottle_planes, - kernel_size=1, bias=False) + self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) self.bn1 = BatchNorm(bottle_planes) - self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, - stride=stride, padding=dilation, - bias=False, dilation=dilation) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) self.bn2 = BatchNorm(bottle_planes) - self.conv3 = nn.Conv2d(bottle_planes, planes, - kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) self.bn3 = BatchNorm(planes) self.relu = nn.ReLU(inplace=True) self.stride = stride @@ -121,15 +129,20 @@ def __init__(self, inplanes, planes, stride=1, dilation=1): super(BottleneckX, self).__init__() cardinality = BottleneckX.cardinality bottle_planes = planes * cardinality // 32 - self.conv1 = nn.Conv2d(inplanes, bottle_planes, - kernel_size=1, bias=False) + self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) self.bn1 = BatchNorm(bottle_planes) - self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, - stride=stride, padding=dilation, bias=False, - 
dilation=dilation, groups=cardinality) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + groups=cardinality, + ) self.bn2 = BatchNorm(bottle_planes) - self.conv3 = nn.Conv2d(bottle_planes, planes, - kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) self.bn3 = BatchNorm(planes) self.relu = nn.ReLU(inplace=True) self.stride = stride @@ -159,8 +172,8 @@ class Root(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, residual): super(Root, self).__init__() self.conv = nn.Conv2d( - in_channels, out_channels, 1, - stride=1, bias=False, padding=(kernel_size - 1) // 2) + in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2 + ) self.bn = BatchNorm(out_channels) self.relu = nn.ReLU(inplace=True) self.residual = residual @@ -177,31 +190,51 @@ def forward(self, *x): class Tree(nn.Module): - def __init__(self, levels, block, in_channels, out_channels, stride=1, - level_root=False, root_dim=0, root_kernel_size=1, - dilation=1, root_residual=False): + def __init__( + self, + levels, + block, + in_channels, + out_channels, + stride=1, + level_root=False, + root_dim=0, + root_kernel_size=1, + dilation=1, + root_residual=False, + ): super(Tree, self).__init__() if root_dim == 0: root_dim = 2 * out_channels if level_root: root_dim += in_channels if levels == 1: - self.tree1 = block(in_channels, out_channels, stride, - dilation=dilation) - self.tree2 = block(out_channels, out_channels, 1, - dilation=dilation) + self.tree1 = block(in_channels, out_channels, stride, dilation=dilation) + self.tree2 = block(out_channels, out_channels, 1, dilation=dilation) else: - self.tree1 = Tree(levels - 1, block, in_channels, out_channels, - stride, root_dim=0, - root_kernel_size=root_kernel_size, - dilation=dilation, root_residual=root_residual) - self.tree2 = Tree(levels - 1, block, out_channels, 
out_channels, - root_dim=root_dim + out_channels, - root_kernel_size=root_kernel_size, - dilation=dilation, root_residual=root_residual) + self.tree1 = Tree( + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + ) + self.tree2 = Tree( + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + ) if levels == 1: - self.root = Root(root_dim, out_channels, root_kernel_size, - root_residual) + self.root = Root(root_dim, out_channels, root_kernel_size, root_residual) self.level_root = level_root self.root_dim = root_dim self.downsample = None @@ -211,9 +244,8 @@ def __init__(self, levels, block, in_channels, out_channels, stride=1, self.downsample = nn.MaxPool2d(stride, stride=stride) if in_channels != out_channels: self.project = nn.Sequential( - nn.Conv2d(in_channels, out_channels, - kernel_size=1, stride=1, bias=False), - BatchNorm(out_channels) + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + BatchNorm(out_channels), ) def forward(self, x, residual=None, children=None): @@ -233,40 +265,74 @@ def forward(self, x, residual=None, children=None): class DLA(nn.Module): - def __init__(self, levels, channels, num_classes=1000, - block=BasicBlock, residual_root=False, return_levels=False, - pool_size=7, linear_root=False): + def __init__( + self, + levels, + channels, + num_classes=1000, + block=BasicBlock, + residual_root=False, + return_levels=False, + pool_size=7, + linear_root=False, + ): super(DLA, self).__init__() self.channels = channels self.return_levels = return_levels self.num_classes = num_classes self.base_layer = nn.Sequential( - nn.Conv2d(3, channels[0], kernel_size=7, stride=1, - padding=3, bias=False), + nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), 
BatchNorm(channels[0]), - nn.ReLU(inplace=True)) - self.level0 = self._make_conv_level( - channels[0], channels[0], levels[0]) - self.level1 = self._make_conv_level( - channels[0], channels[1], levels[1], stride=2) - self.level2 = Tree(levels[2], block, channels[1], channels[2], 2, - level_root=False, - root_residual=residual_root) - self.level3 = Tree(levels[3], block, channels[2], channels[3], 2, - level_root=True, root_residual=residual_root) - self.level4 = Tree(levels[4], block, channels[3], channels[4], 2, - level_root=True, root_residual=residual_root) - self.level5 = Tree(levels[5], block, channels[4], channels[5], 2, - level_root=True, root_residual=residual_root) + nn.ReLU(inplace=True), + ) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root, + ) + self.level3 = Tree( + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root, + ) + self.level4 = Tree( + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root, + ) + self.level5 = Tree( + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root, + ) self.avgpool = nn.AvgPool2d(pool_size) - self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, - stride=1, padding=0, bias=True) + self.fc = nn.Conv2d( + channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) elif isinstance(m, BatchNorm): m.weight.data.fill_(1) m.bias.data.zero_() @@ -276,8 +342,7 @@ def _make_level(self, block, inplanes, planes, blocks, stride=1): if stride != 1 or inplanes != planes: downsample = nn.Sequential( nn.MaxPool2d(stride, stride=stride), - nn.Conv2d(inplanes, planes, - kernel_size=1, stride=1, bias=False), + nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), BatchNorm(planes), ) @@ -291,12 +356,21 @@ def _make_level(self, block, inplanes, planes, blocks, stride=1): def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): modules = [] for i in range(convs): - modules.extend([ - nn.Conv2d(inplanes, planes, kernel_size=3, - stride=stride if i == 0 else 1, - padding=dilation, bias=False, dilation=dilation), - BatchNorm(planes), - nn.ReLU(inplace=True)]) + modules.extend( + [ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation, + ), + BatchNorm(planes), + nn.ReLU(inplace=True), + ] + ) inplanes = planes return nn.Sequential(*modules) @@ -304,7 +378,7 @@ def forward(self, x): y = [] x = self.base_layer(x) for i in range(6): - x = getattr(self, 'level{}'.format(i))(x) + x = getattr(self, "level{}".format(i))(x) y.append(x) if self.return_levels: return y @@ -315,113 +389,121 @@ def forward(self, x): return x - def load_pretrained_model(self, data='imagenet', name='dla34', hash='ba72cf86'): + def load_pretrained_model(self, data="imagenet", name="dla34", hash="ba72cf86"): fc = self.fc - if name.endswith('.pth'): + if name.endswith(".pth"): model_weights = torch.load(data + name) else: model_url = get_model_url(data, name, hash) model_weights = model_zoo.load_url(model_url) num_classes = len(model_weights[list(model_weights.keys())[-1]]) self.fc = nn.Conv2d( - self.channels[-1], num_classes, - kernel_size=1, stride=1, padding=0, bias=True) + self.channels[-1], num_classes, 
kernel_size=1, stride=1, padding=0, bias=True + ) self.load_state_dict(model_weights) self.fc = fc def dla34(pretrained, **kwargs): # DLA-34 - model = DLA([1, 1, 1, 2, 2, 1], - [16, 32, 64, 128, 256, 512], - block=BasicBlock, **kwargs) + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs) if pretrained: - model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86') + model.load_pretrained_model(data="imagenet", name="dla34", hash="ba72cf86") return model def dla46_c(pretrained=None, **kwargs): # DLA-46-C Bottleneck.expansion = 2 - model = DLA([1, 1, 1, 2, 2, 1], - [16, 32, 64, 64, 128, 256], - block=Bottleneck, **kwargs) + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla46_c') + model.load_pretrained_model(pretrained, "dla46_c") return model def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 2, 2, 1], - [16, 32, 64, 64, 128, 256], - block=BottleneckX, **kwargs) + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla46x_c') + model.load_pretrained_model(pretrained, "dla46x_c") return model def dla60x_c(pretrained, **kwargs): # DLA-X-60-C BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 2, 3, 1], - [16, 32, 64, 64, 128, 256], - block=BottleneckX, **kwargs) + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs) if pretrained: - model.load_pretrained_model(data='imagenet', name='dla60x_c', hash='b870c45c') + model.load_pretrained_model(data="imagenet", name="dla60x_c", hash="b870c45c") return model def dla60(pretrained=None, **kwargs): # DLA-60 Bottleneck.expansion = 2 - model = DLA([1, 1, 1, 2, 3, 1], - [16, 32, 128, 256, 512, 1024], - block=Bottleneck, **kwargs) + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 
256, 512, 1024], block=Bottleneck, **kwargs) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla60') + model.load_pretrained_model(pretrained, "dla60") return model def dla60x(pretrained=None, **kwargs): # DLA-X-60 BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 2, 3, 1], - [16, 32, 128, 256, 512, 1024], - block=BottleneckX, **kwargs) + model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla60x') + model.load_pretrained_model(pretrained, "dla60x") return model def dla102(pretrained=None, **kwargs): # DLA-102 Bottleneck.expansion = 2 - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=Bottleneck, residual_root=True, **kwargs) + model = DLA( + [1, 1, 1, 3, 4, 1], + [16, 32, 128, 256, 512, 1024], + block=Bottleneck, + residual_root=True, + **kwargs + ) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla102') + model.load_pretrained_model(pretrained, "dla102") return model def dla102x(pretrained=None, **kwargs): # DLA-X-102 BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=BottleneckX, residual_root=True, **kwargs) + model = DLA( + [1, 1, 1, 3, 4, 1], + [16, 32, 128, 256, 512, 1024], + block=BottleneckX, + residual_root=True, + **kwargs + ) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla102x') + model.load_pretrained_model(pretrained, "dla102x") return model def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64 BottleneckX.cardinality = 64 - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=BottleneckX, residual_root=True, **kwargs) + model = DLA( + [1, 1, 1, 3, 4, 1], + [16, 32, 128, 256, 512, 1024], + block=BottleneckX, + residual_root=True, + **kwargs + ) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla102x2') + model.load_pretrained_model(pretrained, "dla102x2") return 
model def dla169(pretrained=None, **kwargs): # DLA-169 Bottleneck.expansion = 2 - model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], - block=Bottleneck, residual_root=True, **kwargs) + model = DLA( + [1, 1, 2, 3, 5, 1], + [16, 32, 128, 256, 512, 1024], + block=Bottleneck, + residual_root=True, + **kwargs + ) if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla169') + model.load_pretrained_model(pretrained, "dla169") return model @@ -436,11 +518,10 @@ def forward(self, x): def fill_up_weights(up): w = up.weight.data f = math.ceil(w.size(2) / 2) - c = (2 * f - 1 - f % 2) / (2. * f) + c = (2 * f - 1 - f % 2) / (2.0 * f) for i in range(w.size(2)): for j in range(w.size(3)): - w[0, 0, i, j] = \ - (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) for c in range(1, w.size(0)): w[c, 0, :, :] = w[0, 0, :, :] @@ -455,50 +536,64 @@ def __init__(self, node_kernel, out_dim, channels, up_factors): proj = Identity() else: proj = nn.Sequential( - nn.Conv2d(c, out_dim, - kernel_size=1, stride=1, bias=False), + nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False), BatchNorm(out_dim), - nn.ReLU(inplace=True)) + nn.ReLU(inplace=True), + ) f = int(up_factors[i]) if f == 1: up = Identity() else: up = nn.ConvTranspose2d( - out_dim, out_dim, f * 2, stride=f, padding=f // 2, - output_padding=0, groups=out_dim, bias=False) + out_dim, + out_dim, + f * 2, + stride=f, + padding=f // 2, + output_padding=0, + groups=out_dim, + bias=False, + ) fill_up_weights(up) - setattr(self, 'proj_' + str(i), proj) - setattr(self, 'up_' + str(i), up) + setattr(self, "proj_" + str(i), proj) + setattr(self, "up_" + str(i), up) for i in range(1, len(channels)): node = nn.Sequential( - nn.Conv2d(out_dim * 2, out_dim, - kernel_size=node_kernel, stride=1, - padding=node_kernel // 2, bias=False), + nn.Conv2d( + out_dim * 2, + out_dim, + kernel_size=node_kernel, + stride=1, + padding=node_kernel // 
2, + bias=False, + ), BatchNorm(out_dim), - nn.ReLU(inplace=True)) - setattr(self, 'node_' + str(i), node) + nn.ReLU(inplace=True), + ) + setattr(self, "node_" + str(i), node) for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) elif isinstance(m, BatchNorm): m.weight.data.fill_(1) m.bias.data.zero_() def forward(self, layers): - assert len(self.channels) == len(layers), \ - '{} vs {} layers'.format(len(self.channels), len(layers)) + assert len(self.channels) == len(layers), "{} vs {} layers".format( + len(self.channels), len(layers) + ) layers = list(layers) for i, l in enumerate(layers): - upsample = getattr(self, 'up_' + str(i)) - project = getattr(self, 'proj_' + str(i)) + upsample = getattr(self, "up_" + str(i)) + project = getattr(self, "proj_" + str(i)) layers[i] = upsample(project(l)) x = layers[0] y = [] for i in range(1, len(layers)): - node = getattr(self, 'node_' + str(i)) + node = getattr(self, "node_" + str(i)) x = node(torch.cat([x, layers[i]], 1)) y.append(x) return x, y @@ -514,21 +609,24 @@ def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): scales = np.array(scales, dtype=int) for i in range(len(channels) - 1): j = -i - 2 - setattr(self, 'ida_{}'.format(i), - IDAUp(3, channels[j], in_channels[j:], - scales[j:] // scales[j])) - scales[j + 1:] = scales[j] - in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] + setattr( + self, + "ida_{}".format(i), + IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]), + ) + scales[j + 1 :] = scales[j] + in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] def forward(self, layers): layers = list(layers) assert len(layers) > 1 for i in range(len(layers) - 1): - ida = getattr(self, 'ida_{}'.format(i)) - x, y = ida(layers[-i - 2:]) - layers[-i - 1:] = y + ida = getattr(self, "ida_{}".format(i)) + x, y = 
ida(layers[-i - 2 :]) + layers[-i - 1 :] = y return x + def fill_fc_weights(layers): for m in layers.modules(): if isinstance(m, nn.Conv2d): @@ -536,37 +634,41 @@ def fill_fc_weights(layers): if m.bias is not None: nn.init.constant_(m.bias, 0) + class DLASeg(nn.Module): - def __init__(self, base_name, heads, - pretrained=True, down_ratio=4, head_conv=256): + def __init__(self, base_name, heads, pretrained=True, down_ratio=4, head_conv=256): super(DLASeg, self).__init__() self.heads = heads self.first_level = int(np.log2(down_ratio)) - self.base = globals()[base_name]( - pretrained=pretrained, return_levels=True) + self.base = globals()[base_name](pretrained=pretrained, return_levels=True) channels = self.base.channels - scales = [2 ** i for i in range(len(channels[self.first_level:]))] - self.dla_up = DLAUp(channels[self.first_level:], scales=scales) + scales = [2**i for i in range(len(channels[self.first_level :]))] + self.dla_up = DLAUp(channels[self.first_level :], scales=scales) for head in self.heads: classes = self.heads[head] if head_conv > 0: fc = nn.Sequential( - nn.Conv2d(channels[self.first_level], head_conv, - kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d(head_conv, classes, - kernel_size=1, stride=1, - padding=0, bias=True)) - if 'hm' in head: + nn.Conv2d( + channels[self.first_level], head_conv, kernel_size=3, padding=1, bias=True + ), + nn.ReLU(inplace=True), + nn.Conv2d(head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True), + ) + if "hm" in head: fc[-1].bias.data.fill_(-2.19) else: fill_fc_weights(fc) else: - fc = nn.Conv2d(channels[self.first_level], classes, - kernel_size=1, stride=1, - padding=0, bias=True) - if 'hm' in head: + fc = nn.Conv2d( + channels[self.first_level], + classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + if "hm" in head: fc.bias.data.fill_(-2.19) else: fill_fc_weights(fc) @@ -574,7 +676,7 @@ def __init__(self, base_name, heads, def forward(self, x): x = self.base(x) 
- x = self.dla_up(x[self.first_level:]) + x = self.dla_up(x[self.first_level :]) # x = self.fc(x) # y = self.softmax(self.up(x)) ret = {} @@ -582,9 +684,13 @@ def forward(self, x): ret[head] = self.__getattr__(head)(x) return [ret] + def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4): - model = DLASeg('dla{}'.format(num_layers), heads, - pretrained=True, - down_ratio=down_ratio, - head_conv=head_conv) - return model \ No newline at end of file + model = DLASeg( + "dla{}".format(num_layers), + heads, + pretrained=True, + down_ratio=down_ratio, + head_conv=head_conv, + ) + return model diff --git a/PytorchWildlife/models/detection/herdnet/herdnet.py b/PytorchWildlife/models/detection/herdnet/herdnet.py index 4e2a41963..43edd8382 100644 --- a/PytorchWildlife/models/detection/herdnet/herdnet.py +++ b/PytorchWildlife/models/detection/herdnet/herdnet.py @@ -1,56 +1,65 @@ -from ..base_detector import BaseDetector -from ..herdnet.animaloc.eval import HerdNetStitcher, HerdNetLMDS -from ....data import datasets as pw_data -from .model import HerdNet as HerdNetArch +import os +import cv2 +import numpy as np +import supervision as sv import torch -from torch.hub import load_state_dict_from_url -from torch.utils.data import DataLoader import torchvision.transforms as transforms - -import numpy as np +import wget from PIL import Image +from torch.hub import load_state_dict_from_url +from torch.utils.data import DataLoader from tqdm import tqdm -import supervision as sv -import os -import wget -import cv2 - -class ResizeIfSmaller: - def __init__(self, min_size, interpolation=Image.BILINEAR): - self.min_size = min_size - self.interpolation = interpolation - + +from ....data import datasets as pw_data +from ..base_detector import BaseDetector +from ..herdnet.animaloc.eval import HerdNetLMDS, HerdNetStitcher +from .model import HerdNet as HerdNetArch + + +class ResizeIfSmaller: + def __init__(self, min_size, interpolation=Image.BILINEAR): + self.min_size = min_size + 
self.interpolation = interpolation + def __call__(self, img): if isinstance(img, np.ndarray): img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - assert isinstance(img, Image.Image), "Image should be a PIL Image" + assert isinstance(img, Image.Image), "Image should be a PIL Image" width, height = img.size - if height < self.min_size or width < self.min_size: - ratio = max(self.min_size / height, self.min_size / width) - new_height = int(height * ratio) - new_width = int(width * ratio) - img = img.resize((new_width, new_height), self.interpolation) - return img + if height < self.min_size or width < self.min_size: + ratio = max(self.min_size / height, self.min_size / width) + new_height = int(height * ratio) + new_width = int(width * ratio) + img = img.resize((new_width, new_height), self.interpolation) + return img + class HerdNet(BaseDetector): """ HerdNet detector class. This class provides utility methods for loading the model, generating results, and performing single and batch image detections. """ - - def __init__(self, weights=None, device="cpu", version='general' ,url="https://zenodo.org/records/13899852/files/20220413_HerdNet_General_dataset_2022.pth?download=1", transform=None): + + def __init__( + self, + weights=None, + device="cpu", + version="general", + url="https://zenodo.org/records/13899852/files/20220413_HerdNet_General_dataset_2022.pth?download=1", + transform=None, + ): """ Initialize the HerdNet detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". version (str, optional): Version name based on what dataset the model is trained on. It should be either 'general' or 'ennedi'. Defaults to 'general'. - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. 
transform (torchvision.transforms.Compose, optional): Image transformation for inference. Defaults to None. @@ -58,43 +67,45 @@ def __init__(self, weights=None, device="cpu", version='general' ,url="https://z super(HerdNet, self).__init__(weights=weights, device=device, url=url) # Assert that the dataset is either 'general' or 'ennedi' version = version.lower() - assert version in ['general', 'ennedi'], "Dataset should be either 'general' or 'ennedi'" - if version == 'ennedi': + assert version in ["general", "ennedi"], "Dataset should be either 'general' or 'ennedi'" + if version == "ennedi": url = "https://zenodo.org/records/13914287/files/20220329_HerdNet_Ennedi_dataset_2023.pth?download=1" self._load_model(weights, device, url) - self.stitcher = HerdNetStitcher( # This module enables patch-based inference - model = self.model, - size = (512,512), - overlap = 160, - down_ratio = 2, - up = True, - reduction = 'mean', - device_name = device - ) - - self.lmds_kwargs: dict = {'kernel_size': (3, 3), 'adapt_ts': 0.2, 'neg_ts': 0.1} - self.lmds = HerdNetLMDS(up=False, **self.lmds_kwargs) # Local Maxima Detection Strategy + self.stitcher = HerdNetStitcher( # This module enables patch-based inference + model=self.model, + size=(512, 512), + overlap=160, + down_ratio=2, + up=True, + reduction="mean", + device_name=device, + ) + + self.lmds_kwargs: dict = {"kernel_size": (3, 3), "adapt_ts": 0.2, "neg_ts": 0.1} + self.lmds = HerdNetLMDS(up=False, **self.lmds_kwargs) # Local Maxima Detection Strategy if not transform: - self.transforms = transforms.Compose([ - ResizeIfSmaller(512), - transforms.ToTensor(), - transforms.Normalize(mean=self.img_mean, std=self.img_std) - ]) + self.transforms = transforms.Compose( + [ + ResizeIfSmaller(512), + transforms.ToTensor(), + transforms.Normalize(mean=self.img_mean, std=self.img_std), + ] + ) else: self.transforms = transform def _load_model(self, weights=None, device="cpu", url=None): """ Load the HerdNet model weights. 
- + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. @@ -102,7 +113,9 @@ def _load_model(self, weights=None, device="cpu", url=None): if weights: checkpoint = torch.load(weights, map_location=torch.device(device)) elif url: - filename = url.split('/')[-1][:-11] # Splitting the URL to get the filename and removing the '?download=1' part + filename = url.split("/")[-1][ + :-11 + ] # Splitting the URL to get the filename and removing the '?download=1' part if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", filename)): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) @@ -111,26 +124,30 @@ def _load_model(self, weights=None, device="cpu", url=None): checkpoint = torch.load(weights, map_location=torch.device(device)) else: raise Exception("Need weights for inference.") - + # Load the class names and other metadata from the checkpoint self.CLASS_NAMES = checkpoint["classes"] self.num_classes = len(self.CLASS_NAMES) + 1 - self.img_mean = checkpoint['mean'] - self.img_std = checkpoint['std'] + self.img_mean = checkpoint["mean"] + self.img_std = checkpoint["std"] # Load the model architecture self.model = HerdNetArch(num_classes=self.num_classes, pretrained=False) # Load checkpoint into model - state_dict = checkpoint['model_state_dict'] + state_dict = checkpoint["model_state_dict"] # Remove 'model.' prefix from the state_dict keys if the key starts with 'model.' 
- new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')} - # Load the new state_dict + new_state_dict = { + k.replace("model.", ""): v for k, v in state_dict.items() if k.startswith("model.") + } + # Load the new state_dict self.model.load_state_dict(new_state_dict, strict=True) print(f"Model loaded from {weights}") - def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None) -> dict: + def results_generation( + self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None + ) -> dict: """ Generate results for detection based on model predictions. @@ -151,52 +168,63 @@ def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id: results = {"img": img} results["detections"] = sv.Detections( - xyxy=preds[:, :4], - confidence=preds[:, 4], - class_id=preds[:, 5].astype(int) + xyxy=preds[:, :4], confidence=preds[:, 4], class_id=preds[:, 5].astype(int) ) results["labels"] = [ f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id) + for confidence, class_id in zip( + results["detections"].confidence, results["detections"].class_id + ) ] return results - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_conf_thres=0.2, id_strip=None) -> dict: + + def single_image_detection( + self, img, img_path=None, det_conf_thres=0.2, clf_conf_thres=0.2, id_strip=None + ) -> dict: """ Perform detection on a single image. Args: - img (str or np.ndarray): + img (str or np.ndarray): Image for inference. - img_path (str, optional): + img_path (str, optional): Path to the image. Defaults to None. det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.2. clf_conf_thres (float, optional): Confidence threshold for classification. Defaults to 0.2. 
- id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. Returns: dict: Detection results for the image. """ - if isinstance(img, str): - img_path = img_path or img - img = np.array(Image.open(img_path).convert("RGB")) - if self.transforms: + if isinstance(img, str): + img_path = img_path or img + img = np.array(Image.open(img_path).convert("RGB")) + if self.transforms: img_tensor = self.transforms(img) - preds = self.stitcher(img_tensor) - heatmap, clsmap = preds[:,:1,:,:], preds[:,1:,:,:] + preds = self.stitcher(img_tensor) + heatmap, clsmap = preds[:, :1, :, :], preds[:, 1:, :, :] counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap)) - preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres) + preds_array = self.process_lmds_results( + counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres + ) if img_path: results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip) else: results_dict = self.results_generation(preds_array, img=img) return results_dict - def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.2, clf_conf_thres: float = 0.2, batch_size: int = 1, id_strip: str = None) -> list[dict]: + def batch_image_detection( + self, + data_path: str, + det_conf_thres: float = 0.2, + clf_conf_thres: float = 0.2, + batch_size: int = 1, + id_strip: str = None, + ) -> list[dict]: """ Perform detection on a batch of images. @@ -210,32 +238,51 @@ def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.2, clf Returns: list[dict]: List of detection results for all images. 
""" - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transforms - ) + dataset = pw_data.DetectionImageFolder(data_path, transform=self.transforms) # Creating a Dataloader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False) # TODO: discuss. why is num_workers 0? - + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=True, + num_workers=0, + drop_last=False, + ) # TODO: discuss. why is num_workers 0? + results = [] with tqdm(total=len(loader)) as pbar: for batch_index, (imgs, paths, sizes) in enumerate(loader): imgs = imgs.to(self.device) predictions = self.stitcher(imgs[0]).detach().cpu() - heatmap, clsmap = predictions[:,:1,:,:], predictions[:,1:,:,:] + heatmap, clsmap = predictions[:, :1, :, :], predictions[:, 1:, :, :] counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap)) - preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres) - results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip) + preds_array = self.process_lmds_results( + counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres + ) + results_dict = self.results_generation( + preds_array, img_id=paths[0], id_strip=id_strip + ) pbar.update(1) sizes = sizes.numpy() - normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] # TODO: Check if this is correct due to xy swapping - results_dict['normalized_coords'] = normalized_coords + normalized_coords = [ + [x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] + for x1, y1, x2, y2 in preds_array[:, :4] + ] # TODO: Check if this is correct due to xy swapping + results_dict["normalized_coords"] = normalized_coords results.append(results_dict) return results - def 
process_lmds_results(self, counts: list, locs: list, labels: list, scores: list, dscores: list, det_conf_thres: float = 0.2, clf_conf_thres: float = 0.2) -> np.ndarray: + def process_lmds_results( + self, + counts: list, + locs: list, + labels: list, + scores: list, + dscores: list, + det_conf_thres: float = 0.2, + clf_conf_thres: float = 0.2, + ) -> np.ndarray: """ Process the results from the Local Maxima Detection Strategy. @@ -251,29 +298,31 @@ def process_lmds_results(self, counts: list, locs: list, labels: list, scores: l Returns: numpy.ndarray: Processed detection results. """ - # Flatten the lists since we know its a single image - counts = counts[0] - locs = locs[0] - labels = labels[0] + # Flatten the lists since we know its a single image + counts = counts[0] + locs = locs[0] + labels = labels[0] scores = scores[0] - dscores = dscores[0] - - # Calculate the total number of detections - total_detections = sum(counts) - - # Pre-allocate based on total possible detections - preds_array = np.empty((total_detections, 6)) #xyxy, confidence, class_id format + dscores = dscores[0] + + # Calculate the total number of detections + total_detections = sum(counts) + + # Pre-allocate based on total possible detections + preds_array = np.empty((total_detections, 6)) # xyxy, confidence, class_id format detection_idx = 0 - valid_detections_idx = 0 # Index for valid detections after applying the confidence threshold - # Loop through each species - for specie_idx in range(len(counts)): - count = counts[specie_idx] - if count == 0: - continue - - # Get the detections for this species + valid_detections_idx = ( + 0 # Index for valid detections after applying the confidence threshold + ) + # Loop through each species + for specie_idx in range(len(counts)): + count = counts[specie_idx] + if count == 0: + continue + + # Get the detections for this species species_locs = np.array(locs[detection_idx : detection_idx + count]) - species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # 
Swap x and y in species_locs + species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs species_scores = np.array(scores[detection_idx : detection_idx + count]) species_dscores = np.array(dscores[detection_idx : detection_idx + count]) species_labels = np.array(labels[detection_idx : detection_idx + count]) @@ -281,30 +330,40 @@ def process_lmds_results(self, counts: list, locs: list, labels: list, scores: l # Apply the confidence threshold valid_detections_by_clf_score = species_scores > clf_conf_thres valid_detections_by_det_score = species_dscores > det_conf_thres - valid_detections = np.logical_and(valid_detections_by_clf_score, valid_detections_by_det_score) + valid_detections = np.logical_and( + valid_detections_by_clf_score, valid_detections_by_det_score + ) valid_detections_count = np.sum(valid_detections) valid_detections_idx += valid_detections_count # Fill the preds_array with the valid detections if valid_detections_count > 0: - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, :2] = species_locs[valid_detections] - 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4] = species_locs[valid_detections] + 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 4] = species_scores[valid_detections] - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 5] = species_labels[valid_detections] - - detection_idx += count # Move to the next species - - preds_array = preds_array[:valid_detections_idx] # Remove the empty rows - + preds_array[ + valid_detections_idx - valid_detections_count : valid_detections_idx, :2 + ] = (species_locs[valid_detections] - 1) + preds_array[ + valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4 + ] = (species_locs[valid_detections] + 1) + preds_array[ + valid_detections_idx - valid_detections_count : valid_detections_idx, 4 + ] = species_scores[valid_detections] 
+ preds_array[ + valid_detections_idx - valid_detections_count : valid_detections_idx, 5 + ] = species_labels[valid_detections] + + detection_idx += count # Move to the next species + + preds_array = preds_array[:valid_detections_idx] # Remove the empty rows + return preds_array def forward(self, input: torch.Tensor) -> torch.Tensor: """ Forward pass of the model. - + Args: - input (torch.Tensor): + input (torch.Tensor): Input tensor for the model. - + Returns: torch.Tensor: Model output. """ diff --git a/PytorchWildlife/models/detection/herdnet/model.py b/PytorchWildlife/models/detection/herdnet/model.py index f26bb6ecd..fa31ce425 100644 --- a/PytorchWildlife/models/detection/herdnet/model.py +++ b/PytorchWildlife/models/detection/herdnet/model.py @@ -1,5 +1,4 @@ -__copyright__ = \ - """ +__copyright__ = """ Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life All rights reserved. @@ -14,47 +13,52 @@ __version__ = "0.2.1" -import torch +from typing import Optional -import torch.nn as nn import numpy as np +import torch +import torch.nn as nn import torchvision.transforms as T -from typing import Optional - from . import dla as dla_modules + class HerdNet(nn.Module): - ''' HerdNet architecture ''' + """HerdNet architecture""" def __init__( self, num_layers: int = 34, num_classes: int = 2, - pretrained: bool = True, - down_ratio: Optional[int] = 2, - head_conv: int = 64 - ): - ''' + pretrained: bool = True, + down_ratio: Optional[int] = 2, + head_conv: int = 64, + ): + """ Args: num_layers (int, optional): number of layers of DLA. Defaults to 34. - num_classes (int, optional): number of output classes, background included. + num_classes (int, optional): number of output classes, background included. Defaults to 2. pretrained (bool, optional): set False to disable pretrained DLA encoder parameters from ImageNet. Defaults to True. - down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16. 
+ down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16. Set to 1 to get output of the same size as input (i.e. no downsample). Defaults to 2. - head_conv (int, optional): number of supplementary convolutional layers at the end + head_conv (int, optional): number of supplementary convolutional layers at the end of decoder. Defaults to 64. - ''' + """ super(HerdNet, self).__init__() - assert down_ratio in [1, 2, 4, 8, 16], \ - f'Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}' - - base_name = 'dla{}'.format(num_layers) + assert down_ratio in [ + 1, + 2, + 4, + 8, + 16, + ], f"Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}" + + base_name = "dla{}".format(num_layers) self.down_ratio = down_ratio self.num_classes = num_classes @@ -64,58 +68,45 @@ def __init__( # backbone base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True) - setattr(self, 'base_0', base) - setattr(self, 'channels_0', base.channels) + setattr(self, "base_0", base) + setattr(self, "channels_0", base.channels) channels = self.channels_0 - scales = [2 ** i for i in range(len(channels[self.first_level:]))] - self.dla_up = dla_modules.DLAUp(channels[self.first_level:], scales=scales) + scales = [2**i for i in range(len(channels[self.first_level :]))] + self.dla_up = dla_modules.DLAUp(channels[self.first_level :], scales=scales) # self.cls_dla_up = dla_modules.DLAUp(channels[-3:], scales=scales[:3]) # bottleneck conv self.bottleneck_conv = nn.Conv2d( - channels[-1], channels[-1], - kernel_size=1, stride=1, - padding=0, bias=True + channels[-1], channels[-1], kernel_size=1, stride=1, padding=0, bias=True ) # localization head self.loc_head = nn.Sequential( - nn.Conv2d(channels[self.first_level], head_conv, - kernel_size=3, padding=1, bias=True), + nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, padding=1, bias=True), nn.ReLU(inplace=True), - nn.Conv2d( - head_conv, 1, - kernel_size=1, 
stride=1, - padding=0, bias=True - ), - nn.Sigmoid() - ) + nn.Conv2d(head_conv, 1, kernel_size=1, stride=1, padding=0, bias=True), + nn.Sigmoid(), + ) self.loc_head[-2].bias.data.fill_(0.00) # classification head self.cls_head = nn.Sequential( - nn.Conv2d(channels[-1], head_conv, - kernel_size=3, padding=1, bias=True), + nn.Conv2d(channels[-1], head_conv, kernel_size=3, padding=1, bias=True), nn.ReLU(inplace=True), - nn.Conv2d( - head_conv, self.num_classes, - kernel_size=1, stride=1, - padding=0, bias=True - ) - ) + nn.Conv2d(head_conv, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True), + ) self.cls_head[-1].bias.data.fill_(0.00) - - def forward(self, input: torch.Tensor): - encode = self.base_0(input) + def forward(self, input: torch.Tensor): + encode = self.base_0(input) bottleneck = self.bottleneck_conv(encode[-1]) encode[-1] = bottleneck - decode_hm = self.dla_up(encode[self.first_level:]) + decode_hm = self.dla_up(encode[self.first_level :]) # decode_cls = self.cls_dla_up(encode[-3:]) heatmap = self.loc_head(decode_hm) @@ -123,29 +114,27 @@ def forward(self, input: torch.Tensor): # clsmap = self.cls_head(decode_cls) return heatmap, clsmap - + def freeze(self, layers: list) -> None: - ''' Freeze all layers mentioned in the input list ''' + """Freeze all layers mentioned in the input list""" for layer in layers: self._freeze_layer(layer) - + def _freeze_layer(self, layer_name: str) -> None: for param in getattr(self, layer_name).parameters(): param.requires_grad = False - + def reshape_classes(self, num_classes: int) -> None: - ''' Reshape architecture according to a new number of classes. + """Reshape architecture according to a new number of classes. 
Arg: num_classes (int): new number of classes - ''' - + """ + self.cls_head[-1] = nn.Conv2d( - self.head_conv, num_classes, - kernel_size=1, stride=1, - padding=0, bias=True - ) + self.head_conv, num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) self.cls_head[-1].bias.data.fill_(0.00) - self.num_classes = num_classes \ No newline at end of file + self.num_classes = num_classes diff --git a/PytorchWildlife/models/detection/rtdetr_apache/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/__init__.py index 77db993c5..11866034d 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/__init__.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/__init__.py @@ -1,2 +1,2 @@ +from .megadetectorv6_apache import * from .rtdetr_apache_base import * -from .megadetectorv6_apache import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py b/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py index 2e2cf10dd..2741165c5 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py @@ -1,29 +1,23 @@ - from .rtdetr_apache_base import RTDETRApacheBase -__all__ = [ - 'MegaDetectorV6Apache' -] +__all__ = ["MegaDetectorV6Apache"] + class MegaDetectorV6Apache(RTDETRApacheBase): """ - MegaDetectorV6 is a specialized class derived from the RTDETRApacheBase class + MegaDetectorV6 is a specialized class derived from the RTDETRApacheBase class that is specifically designed for detecting animals, persons, and vehicles. - + Attributes: CLASS_NAMES (dict): Mapping of class IDs to their respective names. 
""" - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-rtdetr-x-apache'): + + CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"} + + def __init__(self, weights=None, device="cpu", pretrained=True, version="MDV6-rtdetr-x-apache"): """ Initializes the MegaDetectorV6 model with the option to load pretrained weights. - + Args: weights (str, optional): Path to the weights file. device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". @@ -39,6 +33,6 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-rt url = "https://zenodo.org/records/15398270/files/MDV6-apa-rtdetr-e.pth?download=1" self.MODEL_NAME = "MDV6-apa-rtdetr-e.pth" else: - raise ValueError('Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e') + raise ValueError("Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e") - super(MegaDetectorV6Apache, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file + super(MegaDetectorV6Apache, self).__init__(weights=weights, device=device, url=url) diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py index 4e5004e38..14d44a082 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py @@ -5,59 +5,66 @@ # Importing basic libraries import os +import sys +from pathlib import Path + import supervision as sv -import wget import torch -import torch.nn as nn +import torch.nn as nn import torchvision.transforms as T +import wget +from PIL import Image -from ..base_detector import BaseDetector from ....data import datasets as pw_data -from PIL import Image +from ..base_detector import BaseDetector -import sys -from pathlib import Path project_root = 
Path(__file__).resolve().parent sys.path.append(str(project_root)) from rtdetrv2_pytorch.src.core import YAMLConfig + class RTDETRApacheBase(BaseDetector): """ Base detector class for RTDETRApacheBase framework. This class provides utility methods for loading the model, generating results, and performing single and batch image detections. """ + def __init__(self, weights=None, device="cpu", url=None): """ Initialize the RT-DETR apache detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. """ - self.transform = T.Compose([ - T.Resize((640, 640)), - T.ToTensor(), - ]) + self.transform = T.Compose( + [ + T.Resize((640, 640)), + T.ToTensor(), + ] + ) self.weights = weights self.device = device self.url = url - super(RTDETRApacheBase, self).__init__(weights=self.weights, device=self.device, url=self.url) + super(RTDETRApacheBase, self).__init__( + weights=self.weights, device=self.device, url=self.url + ) self._load_model(weights=self.weights, device=self.device, url=self.url) def _load_model(self, weights=None, device="cpu", url=None): """ Load the RT-DETR apache model weights. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. 
@@ -65,7 +72,9 @@ def _load_model(self, weights=None, device="cpu", url=None): if weights: resume = weights elif url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) resume = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) else: @@ -74,45 +83,59 @@ def _load_model(self, weights=None, device="cpu", url=None): raise Exception("Need weights for inference.") if self.MODEL_NAME == "MDV6-apa-rtdetr-c.pth": - config = os.path.join(project_root, "rtdetrv2_pytorch", "configs", "rtdetrv2", "rtdetrv2_r18vd_120e_megadetector.yml") + config = os.path.join( + project_root, + "rtdetrv2_pytorch", + "configs", + "rtdetrv2", + "rtdetrv2_r18vd_120e_megadetector.yml", + ) elif self.MODEL_NAME == "MDV6-apa-rtdetr-e.pth": - config = os.path.join(project_root, "rtdetrv2_pytorch", "configs", "rtdetrv2", "rtdetrv2_r101vd_6x_megadetector.yml") + config = os.path.join( + project_root, + "rtdetrv2_pytorch", + "configs", + "rtdetrv2", + "rtdetrv2_r101vd_6x_megadetector.yml", + ) else: - raise ValueError('Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e') - + raise ValueError("Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e") + cfg = YAMLConfig(config, resume=resume) - - checkpoint = torch.load(resume, map_location='cpu') - if 'ema' in checkpoint: - state = checkpoint['ema']['module'] + + checkpoint = torch.load(resume, map_location="cpu") + if "ema" in checkpoint: + state = checkpoint["ema"]["module"] else: - state = checkpoint['model'] + state = checkpoint["model"] cfg.model.load_state_dict(state) class Model(nn.Module): - def __init__(self, ) -> None: + def __init__( + self, + ) -> None: super().__init__() self.model = cfg.model.deploy() self.postprocessor = cfg.postprocessor.deploy() - + def forward(self, 
images, orig_target_sizes): outputs = self.model(images) outputs = self.postprocessor(outputs, orig_target_sizes) return outputs - + self.model = Model().to(self.device) def results_generation(self, preds, img_id, id_strip=None): """ Generate results for detection based on model predictions. - + Args: - preds (List[torch.Tensor]): + preds (List[torch.Tensor]): Model predictions. - img_id (str): + img_id (str): Image identifier. - id_strip (str, optional): + id_strip (str, optional): Strip specific characters from img_id. Defaults to None. Returns: @@ -123,32 +146,27 @@ def results_generation(self, preds, img_id, id_strip=None): confidence = preds[2].detach().cpu().numpy() results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id - ) + results["detections"] = sv.Detections(xyxy=xyxy, confidence=confidence, class_id=class_id) results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, _, _ in results["detections"] + f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, _, _ in results["detections"] ] - + return results - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None): """ Perform detection on a single image. - + Args: - img (str or ndarray): + img (str or ndarray): Image path or ndarray of images. - img_path (str, optional): + img_path (str, optional): Image path or identifier. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. 
Returns: @@ -157,7 +175,7 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri if type(img) == str: if img_path is None: img_path = img - im_pil = Image.open(img_path).convert('RGB') + im_pil = Image.open(img_path).convert("RGB") else: im_pil = Image.fromarray(img) @@ -170,21 +188,21 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri lab = labels[0][scr > det_conf_thres] box = boxes[0][scr > det_conf_thres] scrs = scores[0][scr > det_conf_thres] - + return self.results_generation([lab, box, scrs], img_path, id_strip) def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id_strip=None): """ Perform detection on a batch of images. - + Args: - data_path (str): + data_path (str): Path containing all images for inference. batch_size (int, optional): Batch size for inference. Defaults to 16. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. extension (str, optional): Image extension to search for. 
Defaults to "JPG" @@ -196,10 +214,10 @@ def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id data_path, transform=self.transform, ) - + results = [] for i in range(len(dataset)): - im_pil = Image.open(dataset.images[i]).convert('RGB') + im_pil = Image.open(dataset.images[i]).convert("RGB") w, h = im_pil.size orig_size = torch.tensor([w, h])[None].to(self.device) im_data = self.transform(im_pil)[None].to(self.device) @@ -210,16 +228,16 @@ def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id lab = labels[0][scr > det_conf_thres] box = boxes[0][scr > det_conf_thres] scrs = scores[0][scr > det_conf_thres] - + res = self.results_generation([lab, box, scrs], dataset.images[i], id_strip) # Normalize the coordinates for timelapse compatibility size = orig_size[0].cpu().numpy() - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy] + normalized_coords = [ + [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] + for x1, y1, x2, y2 in res["detections"].xyxy + ] res["normalized_coords"] = normalized_coords results.append(res) return results - - - diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml index 0e4f046a6..4c4a6f420 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml @@ -1,3 +1,3 @@ task: detection num_classes: 3 -remap_mscoco_category: False \ No newline at end of file +remap_mscoco_category: False diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml 
b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml index a5c14909b..d7936f6eb 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml @@ -80,4 +80,3 @@ RTDETRCriterionv2: weight_dict: {cost_class: 2, cost_bbox: 5, cost_giou: 2} alpha: 0.25 gamma: 2.0 - diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml index c703545f0..0ad1881dc 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml @@ -18,4 +18,4 @@ HybridEncoder: RTDETRTransformerv2: - num_layers: 3 \ No newline at end of file + num_layers: 3 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py index 8b850c549..4f00ea85f 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py @@ -2,5 +2,4 @@ """ # for register purpose -from . import backbone -from . import rtdetr \ No newline at end of file +from . 
import backbone, rtdetr diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py index 53ab01265..7c4cdb60c 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py @@ -1,4 +1,4 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -from .presnet import PResNet \ No newline at end of file +from .presnet import PResNet diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py index e3f54ea2e..238cb9b91 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py @@ -1,7 +1,7 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import torch +import torch import torch.nn as nn @@ -12,6 +12,7 @@ class FrozenBatchNorm2d(nn.Module): without which any other models than torchvision.models.resnet[18,34,50,101] produce nans. 
""" + def __init__(self, num_features, eps=1e-5): super(FrozenBatchNorm2d, self).__init__() n = num_features @@ -20,17 +21,18 @@ def __init__(self, num_features, eps=1e-5): self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_var", torch.ones(n)) self.eps = eps - self.num_features = n + self.num_features = n - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - num_batches_tracked_key = prefix + 'num_batches_tracked' + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super(FrozenBatchNorm2d, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def forward(self, x): # move reshapes to the beginning @@ -44,43 +46,41 @@ def forward(self, x): return x * scale + bias def extra_repr(self): - return ( - "{num_features}, eps={eps}".format(**self.__dict__) - ) + return "{num_features}, eps={eps}".format(**self.__dict__) -def get_activation(act: str, inplace: bool=True): - """get activation - """ + +def get_activation(act: str, inplace: bool = True): + """get activation""" if act is None: return nn.Identity() elif isinstance(act, nn.Module): - return act + return act act = act.lower() - - if act == 'silu' or act == 'swish': + + if act == "silu" or act == "swish": m = nn.SiLU() - elif act == 'relu': + elif act == "relu": m = nn.ReLU() - elif act == 'leaky_relu': + elif act == "leaky_relu": m = nn.LeakyReLU() - elif act == 'silu': + elif act == "silu": m = nn.SiLU() - - elif act == 'gelu': + + elif act == "gelu": m = nn.GELU() - elif act == 'hardsigmoid': + elif act == "hardsigmoid": m = nn.Hardsigmoid() 
else: - raise RuntimeError('') + raise RuntimeError("") - if hasattr(m, 'inplace'): + if hasattr(m, "inplace"): m.inplace = inplace - - return m + + return m diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py index 7401c0ef6..a586630db 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py @@ -1,17 +1,15 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import torch -import torch.nn as nn -import torch.nn.functional as F - from collections import OrderedDict -from .common import get_activation, FrozenBatchNorm2d +import torch +import torch.nn as nn +import torch.nn.functional as F from ..core import register +from .common import FrozenBatchNorm2d, get_activation - -__all__ = ['PResNet'] +__all__ = ["PResNet"] ResNet_cfg = { @@ -23,10 +21,10 @@ donwload_url = { - 18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth', - 34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth', - 50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth', - 101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth', + 18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth", + 34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth", + 50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth", + 101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth", } @@ -34,14 +32,15 @@ class 
ConvNormLayer(nn.Module): def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): super().__init__() self.conv = nn.Conv2d( - ch_in, - ch_out, - kernel_size, - stride, - padding=(kernel_size-1)//2 if padding is None else padding, - bias=bias) + ch_in, + ch_out, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=bias, + ) self.norm = nn.BatchNorm2d(ch_out) - self.act = get_activation(act) + self.act = get_activation(act) def forward(self, x): return self.act(self.norm(self.conv(x))) @@ -50,24 +49,27 @@ def forward(self, x): class BasicBlock(nn.Module): expansion = 1 - def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): + def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"): super().__init__() self.shortcut = shortcut if not shortcut: - if variant == 'd' and stride == 2: - self.short = nn.Sequential(OrderedDict([ - ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)), - ('conv', ConvNormLayer(ch_in, ch_out, 1, 1)) - ])) + if variant == "d" and stride == 2: + self.short = nn.Sequential( + OrderedDict( + [ + ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)), + ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)), + ] + ) + ) else: self.short = ConvNormLayer(ch_in, ch_out, 1, stride) self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act) self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None) - self.act = nn.Identity() if act is None else get_activation(act) - + self.act = nn.Identity() if act is None else get_activation(act) def forward(self, x): out = self.branch2a(x) @@ -76,7 +78,7 @@ def forward(self, x): short = x else: short = self.short(x) - + out = out + short out = self.act(out) @@ -86,15 +88,15 @@ def forward(self, x): class BottleNeck(nn.Module): expansion = 4 - def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): + def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"): 
super().__init__() - if variant == 'a': + if variant == "a": stride1, stride2 = stride, 1 else: stride1, stride2 = 1, stride - width = ch_out + width = ch_out self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act) self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act) @@ -102,15 +104,19 @@ def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): self.shortcut = shortcut if not shortcut: - if variant == 'd' and stride == 2: - self.short = nn.Sequential(OrderedDict([ - ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)), - ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)) - ])) + if variant == "d" and stride == 2: + self.short = nn.Sequential( + OrderedDict( + [ + ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)), + ("conv", ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)), + ] + ) + ) else: self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride) - self.act = nn.Identity() if act is None else get_activation(act) + self.act = nn.Identity() if act is None else get_activation(act) def forward(self, x): out = self.branch2a(x) @@ -129,19 +135,20 @@ def forward(self, x): class Blocks(nn.Module): - def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'): + def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"): super().__init__() self.blocks = nn.ModuleList() for i in range(count): self.blocks.append( block( - ch_in, + ch_in, ch_out, - stride=2 if i == 0 and stage_num != 2 else 1, + stride=2 if i == 0 and stage_num != 2 else 1, shortcut=False if i == 0 else True, variant=variant, - act=act) + act=act, + ) ) if i == 0: @@ -157,20 +164,21 @@ def forward(self, x): @register() class PResNet(nn.Module): def __init__( - self, - depth, - variant='d', - num_stages=4, - return_idx=[0, 1, 2, 3], - act='relu', - freeze_at=-1, - freeze_norm=True, - pretrained=False): + self, + depth, + variant="d", + num_stages=4, + return_idx=[0, 1, 2, 3], + act="relu", + 
freeze_at=-1, + freeze_norm=True, + pretrained=False, + ): super().__init__() block_nums = ResNet_cfg[depth] ch_in = 64 - if variant in ['c', 'd']: + if variant in ["c", "d"]: conv_def = [ [3, ch_in // 2, 3, 2, "conv1_1"], [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"], @@ -179,9 +187,14 @@ def __init__( else: conv_def = [[3, ch_in, 7, 2, "conv1_1"]] - self.conv1 = nn.Sequential(OrderedDict([ - (name, ConvNormLayer(cin, cout, k, s, act=act)) for cin, cout, k, s, name in conv_def - ])) + self.conv1 = nn.Sequential( + OrderedDict( + [ + (name, ConvNormLayer(cin, cout, k, s, act=act)) + for cin, cout, k, s, name in conv_def + ] + ) + ) ch_out_list = [64, 128, 256, 512] block = BottleNeck if depth >= 50 else BasicBlock @@ -193,7 +206,9 @@ def __init__( for i in range(num_stages): stage_num = i + 2 self.res_layers.append( - Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant) + Blocks( + block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant + ) ) ch_in = _out_channels[i] @@ -210,12 +225,12 @@ def __init__( self._freeze_norm(self) if pretrained: - if isinstance(pretrained, bool) or 'http' in pretrained: - state = torch.hub.load_state_dict_from_url(donwload_url[depth], map_location='cpu') + if isinstance(pretrained, bool) or "http" in pretrained: + state = torch.hub.load_state_dict_from_url(donwload_url[depth], map_location="cpu") else: - state = torch.load(pretrained, map_location='cpu') + state = torch.load(pretrained, map_location="cpu") self.load_state_dict(state) - print(f'Load PResNet{depth} state_dict') + print(f"Load PResNet{depth} state_dict") def _freeze_parameters(self, m: nn.Module): for p in m.parameters(): @@ -240,5 +255,3 @@ def forward(self, x): if idx in self.return_idx: outs.append(x) return outs - - diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py index 
e1078b225..fe463ca0a 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py @@ -1,7 +1,7 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -from .workspace import * -from .yaml_utils import * from ._config import BaseConfig +from .workspace import * from .yaml_config import YAMLConfig +from .yaml_utils import * diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py index 91d720ca2..33f6461a9 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py @@ -1,76 +1,81 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import torch.nn as nn -from torch.utils.data import Dataset, DataLoader +from typing import Callable + +import torch.nn as nn +from torch.cuda.amp.grad_scaler import GradScaler from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -from torch.cuda.amp.grad_scaler import GradScaler +from torch.utils.data import DataLoader, Dataset from torch.utils.tensorboard import SummaryWriter -from typing import Callable - -__all__ = ['BaseConfig', ] +__all__ = [ + "BaseConfig", +] class BaseConfig(object): - def __init__(self) -> None: super().__init__() - self.task :str = None + self.task: str = None - # instance / function - self._model :nn.Module = None - self._postprocessor :nn.Module = None - self._criterion :nn.Module = None - self._optimizer :Optimizer = None - self._lr_scheduler :LRScheduler = None - self._lr_warmup_scheduler: LRScheduler = None - self._train_dataloader :DataLoader = None - self._val_dataloader :DataLoader = None - self._ema :nn.Module = None - self._scaler :GradScaler = None - self._train_dataset :Dataset = None - 
self._val_dataset :Dataset = None - self._collate_fn :Callable = None - self._evaluator :Callable[[nn.Module, DataLoader, str], ] = None + # instance / function + self._model: nn.Module = None + self._postprocessor: nn.Module = None + self._criterion: nn.Module = None + self._optimizer: Optimizer = None + self._lr_scheduler: LRScheduler = None + self._lr_warmup_scheduler: LRScheduler = None + self._train_dataloader: DataLoader = None + self._val_dataloader: DataLoader = None + self._ema: nn.Module = None + self._scaler: GradScaler = None + self._train_dataset: Dataset = None + self._val_dataset: Dataset = None + self._collate_fn: Callable = None + self._evaluator: Callable[[nn.Module, DataLoader, str],] = None self._writer: SummaryWriter = None - - # dataset - self.num_workers :int = 0 - self.batch_size :int = None - self._train_batch_size :int = None - self._val_batch_size :int = None - self._train_shuffle: bool = None - self._val_shuffle: bool = None + + # dataset + self.num_workers: int = 0 + self.batch_size: int = None + self._train_batch_size: int = None + self._val_batch_size: int = None + self._train_shuffle: bool = None + self._val_shuffle: bool = None # runtime - self.resume :str = None - self.tuning :str = None + self.resume: str = None + self.tuning: str = None - self.epoches :int = None - self.last_epoch :int = -1 + self.epoches: int = None + self.last_epoch: int = -1 - self.use_amp :bool = False - self.use_ema :bool = False - self.ema_decay :float = 0.9999 + self.use_amp: bool = False + self.use_ema: bool = False + self.ema_decay: float = 0.9999 self.ema_warmups: int = 2000 - self.sync_bn :bool = False - self.clip_max_norm : float = 0. 
- self.find_unused_parameters :bool = None + self.sync_bn: bool = False + self.clip_max_norm: float = 0.0 + self.find_unused_parameters: bool = None - self.seed :int = None - self.print_freq :int = None - self.checkpoint_freq :int = 1 - self.output_dir :str = None - self.summary_dir :str = None - self.device : str = '' + self.seed: int = None + self.print_freq: int = None + self.checkpoint_freq: int = 1 + self.output_dir: str = None + self.summary_dir: str = None + self.device: str = "" @property - def model(self, ) -> nn.Module: - return self._model + def model( + self, + ) -> nn.Module: + return self._model @property - def postprocessor(self, ) -> nn.Module: - return self._postprocessor \ No newline at end of file + def postprocessor( + self, + ) -> nn.Module: + return self._postprocessor diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py index ea0e41af6..b96a83381 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py @@ -1,171 +1,168 @@ """"Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import inspect -import importlib import functools +import importlib import inspect from collections import defaultdict -from typing import Any, Dict, Optional, List - +from typing import Any, Dict, List, Optional GLOBAL_CONFIG = defaultdict(dict) -def register(dct :Any=GLOBAL_CONFIG, name=None, force=False): +def register(dct: Any = GLOBAL_CONFIG, name=None, force=False): """ - dct: - if dct is Dict, register foo into dct as key-value pair - if dct is Clas, register as modules attibute - force - whether force register. + dct: + if dct is Dict, register foo into dct as key-value pair + if dct is Clas, register as modules attibute + force + whether force register. 
""" + def decorator(foo): register_name = foo.__name__ if name is None else name if not force: if inspect.isclass(dct): - assert not hasattr(dct, foo.__name__), \ - f'module {dct.__name__} has {foo.__name__}' + assert not hasattr(dct, foo.__name__), f"module {dct.__name__} has {foo.__name__}" else: - assert foo.__name__ not in dct, \ - f'{foo.__name__} has been already registered' + assert foo.__name__ not in dct, f"{foo.__name__} has been already registered" if inspect.isfunction(foo): + @functools.wraps(foo) def wrap_func(*args, **kwargs): return foo(*args, **kwargs) + if isinstance(dct, dict): dct[foo.__name__] = wrap_func elif inspect.isclass(dct): setattr(dct, foo.__name__, wrap_func) else: - raise AttributeError('') + raise AttributeError("") return wrap_func elif inspect.isclass(foo): - dct[register_name] = extract_schema(foo) + dct[register_name] = extract_schema(foo) else: - raise ValueError(f'Do not support {type(foo)} register') + raise ValueError(f"Do not support {type(foo)} register") return foo return decorator - def extract_schema(module: type): """ Args: module (type), Return: - Dict, + Dict, """ argspec = inspect.getfullargspec(module.__init__) - arg_names = [arg for arg in argspec.args if arg != 'self'] + arg_names = [arg for arg in argspec.args if arg != "self"] num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0 num_requires = len(arg_names) - num_defualts schame = dict() - schame['_name'] = module.__name__ - schame['_pymodule'] = importlib.import_module(module.__module__) - schame['_inject'] = getattr(module, '__inject__', []) - schame['_share'] = getattr(module, '__share__', []) - schame['_kwargs'] = {} + schame["_name"] = module.__name__ + schame["_pymodule"] = importlib.import_module(module.__module__) + schame["_inject"] = getattr(module, "__inject__", []) + schame["_share"] = getattr(module, "__share__", []) + schame["_kwargs"] = {} for i, name in enumerate(arg_names): - if name in schame['_share']: - assert i >= 
num_requires, 'share config must have default value.' + if name in schame["_share"]: + assert i >= num_requires, "share config must have default value." value = argspec.defaults[i - num_requires] - + elif i >= num_requires: value = argspec.defaults[i - num_requires] else: - value = None + value = None schame[name] = value - schame['_kwargs'][name] = value - + schame["_kwargs"][name] = value + return schame def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs): - """ - """ - assert type(type_or_name) in (type, str), 'create should be modules or name.' + """ """ + assert type(type_or_name) in (type, str), "create should be modules or name." name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__ if name in global_cfg: - if hasattr(global_cfg[name], '__dict__'): + if hasattr(global_cfg[name], "__dict__"): return global_cfg[name] else: - raise ValueError('The module {} is not registered'.format(name)) + raise ValueError("The module {} is not registered".format(name)) cfg = global_cfg[name] - if isinstance(cfg, dict) and 'type' in cfg: - _cfg: dict = global_cfg[cfg['type']] + if isinstance(cfg, dict) and "type" in cfg: + _cfg: dict = global_cfg[cfg["type"]] # clean args - _keys = [k for k in _cfg.keys() if not k.startswith('_')] + _keys = [k for k in _cfg.keys() if not k.startswith("_")] for _arg in _keys: del _cfg[_arg] - _cfg.update(_cfg['_kwargs']) # restore default args - _cfg.update(cfg) # load config args - _cfg.update(kwargs) # TODO recive extra kwargs - name = _cfg.pop('type') # pop extra key `type` (from cfg) - + _cfg.update(_cfg["_kwargs"]) # restore default args + _cfg.update(cfg) # load config args + _cfg.update(kwargs) # TODO recive extra kwargs + name = _cfg.pop("type") # pop extra key `type` (from cfg) + return create(name, global_cfg) - - module = getattr(cfg['_pymodule'], name) + + module = getattr(cfg["_pymodule"], name) module_kwargs = {} module_kwargs.update(cfg) - + # shared var - for k in cfg['_share']: + for k in 
cfg["_share"]: if k in global_cfg: module_kwargs[k] = global_cfg[k] else: module_kwargs[k] = cfg[k] # inject - for k in cfg['_inject']: + for k in cfg["_inject"]: _k = cfg[k] if _k is None: continue - if isinstance(_k, str): + if isinstance(_k, str): if _k not in global_cfg: - raise ValueError(f'Missing inject config of {_k}.') + raise ValueError(f"Missing inject config of {_k}.") _cfg = global_cfg[_k] - + if isinstance(_cfg, dict): - module_kwargs[k] = create(_cfg['_name'], global_cfg) + module_kwargs[k] = create(_cfg["_name"], global_cfg) else: - module_kwargs[k] = _cfg + module_kwargs[k] = _cfg elif isinstance(_k, dict): - if 'type' not in _k.keys(): - raise ValueError(f'Missing inject for `type` style.') + if "type" not in _k.keys(): + raise ValueError(f"Missing inject for `type` style.") - _type = str(_k['type']) + _type = str(_k["type"]) if _type not in global_cfg: - raise ValueError(f'Missing {_type} in inspect stage.') + raise ValueError(f"Missing {_type} in inspect stage.") _cfg: dict = global_cfg[_type] - _keys = [k for k in _cfg.keys() if not k.startswith('_')] + _keys = [k for k in _cfg.keys() if not k.startswith("_")] for _arg in _keys: del _cfg[_arg] - _cfg.update(_cfg['_kwargs']) # restore default values - _cfg.update(_k) # load config args - name = _cfg.pop('type') # pop extra key (`type` from _k) + _cfg.update(_cfg["_kwargs"]) # restore default values + _cfg.update(_k) # load config args + name = _cfg.pop("type") # pop extra key (`type` from _k) module_kwargs[k] = create(name, global_cfg) else: - raise ValueError(f'Inject does not support {_k}') - - module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith('_')} + raise ValueError(f"Inject does not support {_k}") + + module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")} - return module(**module_kwargs) \ No newline at end of file + return module(**module_kwargs) diff --git 
a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py index 8e0e02fd4..3e7129227 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py @@ -1,13 +1,15 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import torch import copy +import torch + from ._config import BaseConfig from .workspace import create from .yaml_utils import load_config, merge_config, merge_dict + class YAMLConfig(BaseConfig): def __init__(self, cfg_path: str, **kwargs) -> None: super().__init__() @@ -15,24 +17,30 @@ def __init__(self, cfg_path: str, **kwargs) -> None: cfg = load_config(cfg_path) cfg = merge_dict(cfg, kwargs) - self.yaml_cfg = copy.deepcopy(cfg) - + self.yaml_cfg = copy.deepcopy(cfg) + for k in super().__dict__: - if not k.startswith('_') and k in cfg: + if not k.startswith("_") and k in cfg: self.__dict__[k] = cfg[k] @property - def global_cfg(self, ): + def global_cfg( + self, + ): return merge_config(self.yaml_cfg, inplace=False, overwrite=False) - + @property - def model(self, ) -> torch.nn.Module: - if self._model is None and 'model' in self.yaml_cfg: - self._model = create(self.yaml_cfg['model'], self.global_cfg) - return super().model + def model( + self, + ) -> torch.nn.Module: + if self._model is None and "model" in self.yaml_cfg: + self._model = create(self.yaml_cfg["model"], self.global_cfg) + return super().model @property - def postprocessor(self, ) -> torch.nn.Module: - if self._postprocessor is None and 'postprocessor' in self.yaml_cfg: - self._postprocessor = create(self.yaml_cfg['postprocessor'], self.global_cfg) - return super().postprocessor \ No newline at end of file + def postprocessor( + self, + ) -> torch.nn.Module: + if self._postprocessor is None and "postprocessor" in self.yaml_cfg: 
+ self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg) + return super().postprocessor diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py index 1033bcf21..9a47d35e0 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py @@ -1,28 +1,28 @@ """"Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import os import copy -import yaml -from typing import Any, Dict, Optional, List +import os +from typing import Any, Dict, List, Optional + +import yaml from .workspace import GLOBAL_CONFIG __all__ = [ - 'load_config', - 'merge_config', - 'merge_dict', + "load_config", + "merge_config", + "merge_dict", ] -INCLUDE_KEY = '__include__' +INCLUDE_KEY = "__include__" def load_config(file_path, cfg=dict()): - """load config - """ + """load config""" _, ext = os.path.splitext(file_path) - assert ext in ['.yml', '.yaml'], "only support yaml files" + assert ext in [".yml", ".yaml"], "only support yaml files" with open(file_path) as f: file_cfg = yaml.load(f, Loader=yaml.Loader) @@ -32,10 +32,10 @@ def load_config(file_path, cfg=dict()): if INCLUDE_KEY in file_cfg: base_yamls = list(file_cfg[INCLUDE_KEY]) for base_yaml in base_yamls: - if base_yaml.startswith('~'): + if base_yaml.startswith("~"): base_yaml = os.path.expanduser(base_yaml) - if not base_yaml.startswith('/'): + if not base_yaml.startswith("/"): base_yaml = os.path.join(os.path.dirname(file_path), base_yaml) with open(base_yaml) as f: @@ -46,24 +46,24 @@ def load_config(file_path, cfg=dict()): def merge_dict(dct, another_dct, inplace=True) -> Dict: - """merge another_dct into dct - """ + """merge another_dct into dct""" + def _merge(dct, another) -> Dict: for k in another: - if (k in dct and isinstance(dct[k], dict) and 
isinstance(another[k], dict)): + if k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict): _merge(dct[k], another[k]) else: dct[k] = another[k] return dct - + if not inplace: dct = copy.deepcopy(dct) - + return _merge(dct, another_dct) -def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool=False, overwrite: bool=False): +def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False): """ Merge another_cfg into cfg, return the merged config @@ -78,19 +78,20 @@ def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool=False, overwrite: model1 = create(cfg1['model'], cfg1) model2 = create(cfg2['model'], cfg2) """ + def _merge(dct, another): for k in another: if k not in dct: dct[k] = another[k] - + elif isinstance(dct[k], dict) and isinstance(another[k], dict): - _merge(dct[k], another[k]) - + _merge(dct[k], another[k]) + elif overwrite: dct[k] = another[k] return cfg - + if not inplace: cfg = copy.deepcopy(cfg) diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py index 1df1f9669..f11cde740 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py @@ -2,7 +2,7 @@ """ -from .rtdetr import RTDETR from .hybrid_encoder import HybridEncoder +from .rtdetr import RTDETR from .rtdetr_postprocessor import RTDETRPostProcessor -from .rtdetrv2_decoder import RTDETRTransformerv2 \ No newline at end of file +from .rtdetrv2_decoder import RTDETRTransformerv2 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py index c45752a6c..f910fddf5 100644 --- 
a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py @@ -10,13 +10,11 @@ def box_cxcywh_to_xyxy(x: Tensor) -> Tensor: x_c, y_c, w, h = x.unbind(-1) - b = [(x_c - 0.5 * w), (y_c - 0.5 * h), - (x_c + 0.5 * w), (y_c + 0.5 * h)] + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1) def box_xyxy_to_cxcywh(x: Tensor) -> Tensor: x0, y0, x1, y1 = x.unbind(-1) - b = [(x0 + x1) / 2, (y0 + y1) / 2, - (x1 - x0), (y1 - y0)] - return torch.stack(b, dim=-1) \ No newline at end of file + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py index c50f214c6..e63002e94 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py @@ -1,26 +1,28 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. 
""" -import torch +import torch -from .utils import inverse_sigmoid from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh +from .utils import inverse_sigmoid -def get_contrastive_denoising_training_group(targets, - num_classes, - num_queries, - class_embed, - num_denoising=100, - label_noise_ratio=0.5, - box_noise_scale=1.0,): +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): """cnd""" if num_denoising <= 0: return None, None, None, None - num_gts = [len(t['labels']) for t in targets] - device = targets[0]['labels'].device - + num_gts = [len(t["labels"]) for t in targets] + device = targets[0]["labels"].device + max_gt_num = max(num_gts) if max_gt_num == 0: return None, None, None, None @@ -37,8 +39,8 @@ def get_contrastive_denoising_training_group(targets, for i in range(bs): num_gt = num_gts[i] if num_gt > 0: - input_query_class[i, :num_gt] = targets[i]['labels'] - input_query_bbox[i, :num_gt] = targets[i]['boxes'] + input_query_class[i, :num_gt] = targets[i]["labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] pad_gt_mask[i, :num_gt] = 1 # each group has positive and negative queries. 
input_query_class = input_query_class.tile([1, 2 * num_group]) @@ -68,7 +70,7 @@ def get_contrastive_denoising_training_group(targets, rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 rand_part = torch.rand_like(input_query_bbox) rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) - known_bbox += (rand_sign * rand_part * diff) + known_bbox += rand_sign * rand_part * diff known_bbox = torch.clip(known_bbox, min=0.0, max=1.0) input_query_bbox = box_xyxy_to_cxcywh(known_bbox) input_query_bbox_unact = inverse_sigmoid(input_query_bbox) @@ -79,21 +81,27 @@ def get_contrastive_denoising_training_group(targets, attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) # match query cannot see the reconstruction attn_mask[num_denoising:, :num_denoising] = True - + # reconstruct cannot see each other for i in range(num_group): if i == 0: - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), + max_gt_num * 2 * (i + 1) : num_denoising, + ] = True if i == num_group - 1: - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2] = True else: - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True - + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), + max_gt_num * 2 * (i + 1) : num_denoising, + ] = True + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True + dn_meta = { "dn_positive_idx": dn_positive_idx, "dn_num_group": num_group, - "dn_num_split": [num_denoising, num_queries] + "dn_num_split": [num_denoising, num_queries], } return input_query_logits, input_query_bbox_unact, attn_mask, dn_meta diff --git 
a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py index 15e5acf7e..a4446a988 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py @@ -4,47 +4,45 @@ import copy from collections import OrderedDict -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import get_activation +import torch +import torch.nn as nn +import torch.nn.functional as F from ..core import register +from .utils import get_activation - -__all__ = ['HybridEncoder'] - +__all__ = ["HybridEncoder"] class ConvNormLayer(nn.Module): def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): super().__init__() self.conv = nn.Conv2d( - ch_in, - ch_out, - kernel_size, - stride, - padding=(kernel_size-1)//2 if padding is None else padding, - bias=bias) + ch_in, + ch_out, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=bias, + ) self.norm = nn.BatchNorm2d(ch_out) - self.act = nn.Identity() if act is None else get_activation(act) + self.act = nn.Identity() if act is None else get_activation(act) def forward(self, x): return self.act(self.norm(self.conv(x))) class RepVggBlock(nn.Module): - def __init__(self, ch_in, ch_out, act='relu'): + def __init__(self, ch_in, ch_out, act="relu"): super().__init__() self.ch_in = ch_in self.ch_out = ch_out self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) - self.act = nn.Identity() if act is None else get_activation(act) + self.act = nn.Identity() if act is None else get_activation(act) def forward(self, x): - if hasattr(self, 'conv'): + if hasattr(self, "conv"): y = self.conv(x) else: y = 
self.conv1(x) + self.conv2(x) @@ -52,17 +50,17 @@ def forward(self, x): return self.act(y) def convert_to_deploy(self): - if not hasattr(self, 'conv'): + if not hasattr(self, "conv"): self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) kernel, bias = self.get_equivalent_kernel_bias() self.conv.weight.data = kernel - self.conv.bias.data = bias + self.conv.bias.data = bias def get_equivalent_kernel_bias(self): kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) - + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 def _pad_1x1_to_3x3_tensor(self, kernel1x1): @@ -86,20 +84,16 @@ def _fuse_bn_tensor(self, branch: ConvNormLayer): class CSPRepLayer(nn.Module): - def __init__(self, - in_channels, - out_channels, - num_blocks=3, - expansion=1.0, - bias=None, - act="silu"): + def __init__( + self, in_channels, out_channels, num_blocks=3, expansion=1.0, bias=None, act="silu" + ): super(CSPRepLayer, self).__init__() hidden_channels = int(out_channels * expansion) self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) - self.bottlenecks = nn.Sequential(*[ - RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks) - ]) + self.bottlenecks = nn.Sequential( + *[RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)] + ) if hidden_channels != out_channels: self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act) else: @@ -114,13 +108,15 @@ def forward(self, x): # transformer class TransformerEncoderLayer(nn.Module): - def __init__(self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): super().__init__() 
self.normalize_before = normalize_before @@ -134,7 +130,7 @@ def __init__(self, self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) - self.activation = get_activation(activation) + self.activation = get_activation(activation) @staticmethod def with_pos_embed(tensor, pos_embed): @@ -181,24 +177,28 @@ def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: @register() class HybridEncoder(nn.Module): - __share__ = ['eval_spatial_size', ] - - def __init__(self, - in_channels=[512, 1024, 2048], - feat_strides=[8, 16, 32], - hidden_dim=256, - nhead=8, - dim_feedforward = 1024, - dropout=0.0, - enc_act='gelu', - use_encoder_idx=[2], - num_encoder_layers=1, - pe_temperature=10000, - expansion=1.0, - depth_mult=1.0, - act='silu', - eval_spatial_size=None, - version='v2'): + __share__ = [ + "eval_spatial_size", + ] + + def __init__( + self, + in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + hidden_dim=256, + nhead=8, + dim_feedforward=1024, + dropout=0.0, + enc_act="gelu", + use_encoder_idx=[2], + num_encoder_layers=1, + pe_temperature=10000, + expansion=1.0, + depth_mult=1.0, + act="silu", + eval_spatial_size=None, + version="v2", + ): super().__init__() self.in_channels = in_channels self.feat_strides = feat_strides @@ -206,38 +206,47 @@ def __init__(self, self.use_encoder_idx = use_encoder_idx self.num_encoder_layers = num_encoder_layers self.pe_temperature = pe_temperature - self.eval_spatial_size = eval_spatial_size + self.eval_spatial_size = eval_spatial_size self.out_channels = [hidden_dim for _ in range(len(in_channels))] self.out_strides = feat_strides - + # channel projection self.input_proj = nn.ModuleList() for in_channel in in_channels: - if version == 'v1': + if version == "v1": proj = nn.Sequential( nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(hidden_dim)) - elif version == 'v2': - proj = nn.Sequential(OrderedDict([ - ('conv', 
nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), - ('norm', nn.BatchNorm2d(hidden_dim)) - ])) + nn.BatchNorm2d(hidden_dim), + ) + elif version == "v2": + proj = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), + ("norm", nn.BatchNorm2d(hidden_dim)), + ] + ) + ) else: raise AttributeError() - + self.input_proj.append(proj) # encoder transformer encoder_layer = TransformerEncoderLayer( - hidden_dim, + hidden_dim, nhead=nhead, - dim_feedforward=dim_feedforward, + dim_feedforward=dim_feedforward, dropout=dropout, - activation=enc_act) + activation=enc_act, + ) - self.encoder = nn.ModuleList([ - TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx)) - ]) + self.encoder = nn.ModuleList( + [ + TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) + for _ in range(len(use_encoder_idx)) + ] + ) # top-down fpn self.lateral_convs = nn.ModuleList() @@ -245,18 +254,20 @@ def __init__(self, for _ in range(len(in_channels) - 1, 0, -1): self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act)) self.fpn_blocks.append( - CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion) + CSPRepLayer( + hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion + ) ) # bottom-up pan self.downsample_convs = nn.ModuleList() self.pan_blocks = nn.ModuleList() for _ in range(len(in_channels) - 1): - self.downsample_convs.append( - ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act) - ) + self.downsample_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act)) self.pan_blocks.append( - CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion) + CSPRepLayer( + hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion + ) ) self._reset_parameters() @@ -266,23 +277,26 @@ def _reset_parameters(self): for idx in 
self.use_encoder_idx: stride = self.feat_strides[idx] pos_embed = self.build_2d_sincos_position_embedding( - self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride, - self.hidden_dim, self.pe_temperature) - setattr(self, f'pos_embed{idx}', pos_embed) - #self.register_buffer(f'pos_embed{idx}', pos_embed) + self.eval_spatial_size[1] // stride, + self.eval_spatial_size[0] // stride, + self.hidden_dim, + self.pe_temperature, + ) + setattr(self, f"pos_embed{idx}", pos_embed) + # self.register_buffer(f'pos_embed{idx}', pos_embed) @staticmethod - def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.): - """ - """ + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): + """ """ grid_w = torch.arange(int(w), dtype=torch.float32) grid_h = torch.arange(int(h), dtype=torch.float32) - grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') - assert embed_dim % 4 == 0, \ - 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + assert ( + embed_dim % 4 == 0 + ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim - omega = 1. 
/ (temperature ** omega) + omega = 1.0 / (temperature**omega) out_w = grid_w.flatten()[..., None] @ omega[None] out_h = grid_h.flatten()[..., None] @ omega[None] @@ -292,7 +306,7 @@ def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.): def forward(self, feats): assert len(feats) == len(self.in_channels) proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] - + # encoder if self.num_encoder_layers > 0: for i, enc_ind in enumerate(self.use_encoder_idx): @@ -301,12 +315,15 @@ def forward(self, feats): src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1) if self.training or self.eval_spatial_size is None: pos_embed = self.build_2d_sincos_position_embedding( - w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device) + w, h, self.hidden_dim, self.pe_temperature + ).to(src_flatten.device) else: - pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device) + pos_embed = getattr(self, f"pos_embed{enc_ind}", None).to(src_flatten.device) - memory :torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed) - proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() + memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed) + proj_feats[enc_ind] = ( + memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() + ) # broadcasting and fusion inner_outs = [proj_feats[-1]] @@ -315,8 +332,10 @@ def forward(self, feats): feat_low = proj_feats[idx - 1] feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh) inner_outs[0] = feat_heigh - upsample_feat = F.interpolate(feat_heigh, scale_factor=2., mode='nearest') - inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1)) + upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest") + inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx]( + torch.concat([upsample_feat, feat_low], dim=1) + ) 
inner_outs.insert(0, inner_out) outs = [inner_outs[0]] diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py index 77a23be5c..1bf28c22f 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py @@ -1,44 +1,52 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import torch -import torch.nn as nn -import torch.nn.functional as F +import random +from typing import List -import random -import numpy as np -from typing import List +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F from ..core import register - -__all__ = ['RTDETR', ] +__all__ = [ + "RTDETR", +] @register() class RTDETR(nn.Module): - __inject__ = ['backbone', 'encoder', 'decoder', ] - - def __init__(self, \ - backbone: nn.Module, - encoder: nn.Module, - decoder: nn.Module, + __inject__ = [ + "backbone", + "encoder", + "decoder", + ] + + def __init__( + self, + backbone: nn.Module, + encoder: nn.Module, + decoder: nn.Module, ): super().__init__() self.backbone = backbone self.decoder = decoder self.encoder = encoder - + def forward(self, x, targets=None): x = self.backbone(x) - x = self.encoder(x) + x = self.encoder(x) x = self.decoder(x, targets) return x - - def deploy(self, ): + + def deploy( + self, + ): self.eval() for m in self.modules(): - if hasattr(m, 'convert_to_deploy'): + if hasattr(m, "convert_to_deploy"): m.convert_to_deploy() - return self + return self diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py index 6d024fa45..1afc7c87c 100644 --- 
a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py @@ -1,16 +1,14 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import torch -import torch.nn as nn -import torch.nn.functional as F - +import torch +import torch.nn as nn +import torch.nn.functional as F import torchvision from ..core import register - -__all__ = ['RTDETRPostProcessor'] +__all__ = ["RTDETRPostProcessor"] def mod(a, b): @@ -20,36 +18,27 @@ def mod(a, b): @register() class RTDETRPostProcessor(nn.Module): - __share__ = [ - 'num_classes', - 'use_focal_loss', - 'num_top_queries', - 'remap_mscoco_category' - ] - + __share__ = ["num_classes", "use_focal_loss", "num_top_queries", "remap_mscoco_category"] + def __init__( - self, - num_classes=80, - use_focal_loss=True, - num_top_queries=300, - remap_mscoco_category=False + self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False ) -> None: super().__init__() self.use_focal_loss = use_focal_loss self.num_top_queries = num_top_queries self.num_classes = int(num_classes) - self.remap_mscoco_category = remap_mscoco_category - self.deploy_mode = False + self.remap_mscoco_category = remap_mscoco_category + self.deploy_mode = False def extra_repr(self) -> str: - return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' - + return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}" + # def forward(self, outputs, orig_target_sizes): def forward(self, outputs, orig_target_sizes: torch.Tensor): - logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] - # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + logits, boxes = outputs["pred_logits"], outputs["pred_boxes"] + # orig_target_sizes = torch.stack([t["orig_size"] for 
t in targets], dim=0) - bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') + bbox_pred = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy") bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) if self.use_focal_loss: @@ -59,16 +48,20 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor): # labels = index % self.num_classes labels = mod(index, self.num_classes) index = index // self.num_classes - boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) - + boxes = bbox_pred.gather( + dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]) + ) + else: scores = F.softmax(logits)[:, :, :-1] scores, labels = scores.max(dim=-1) if scores.shape[1] > self.num_top_queries: scores, index = torch.topk(scores, self.num_top_queries, dim=-1) labels = torch.gather(labels, dim=1, index=index) - boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) - + boxes = torch.gather( + boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]) + ) + # TODO for onnx export if self.deploy_mode: return labels, boxes, scores @@ -76,18 +69,23 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor): # TODO if self.remap_mscoco_category: from ...data.dataset import mscoco_label2category - labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ - .to(boxes.device).reshape(labels.shape) + + labels = ( + torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()]) + .to(boxes.device) + .reshape(labels.shape) + ) results = [] for lab, box, sco in zip(labels, boxes, scores): result = dict(labels=lab, boxes=box, scores=sco) results.append(result) - + return results - - def deploy(self, ): + def deploy( + self, + ): self.eval() self.deploy_mode = True - return self + return self diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py 
b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py index 945fbfd49..87e6a6e44 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py @@ -1,31 +1,37 @@ """Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ -import math -import copy +import copy import functools +import math from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init from typing import List -from .denoising import get_contrastive_denoising_training_group -from .utils import deformable_attention_core_func_v2, get_activation, inverse_sigmoid, bias_init_with_prob +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init from ..core import register +from .denoising import get_contrastive_denoising_training_group +from .utils import ( + bias_init_with_prob, + deformable_attention_core_func_v2, + get_activation, + inverse_sigmoid, +) -__all__ = ['RTDETRTransformerv2'] +__all__ = ["RTDETRTransformerv2"] class MLP(nn.Module): - def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) self.act = get_activation(act) def forward(self, x): @@ -36,16 +42,15 @@ def forward(self, x): class MSDeformableAttention(nn.Module): def __init__( - self, - embed_dim=256, - num_heads=8, - num_levels=4, - num_points=4, - method='default', + self, + embed_dim=256, + num_heads=8, + num_levels=4, + num_points=4, + method="default", 
offset_scale=0.5, ): - """Multi-Scale Deformable Attention - """ + """Multi-Scale Deformable Attention""" super(MSDeformableAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads @@ -53,43 +58,53 @@ def __init__( self.offset_scale = offset_scale if isinstance(num_points, list): - assert len(num_points) == num_levels, '' + assert len(num_points) == num_levels, "" num_points_list = num_points else: num_points_list = [num_points for _ in range(num_levels)] self.num_points_list = num_points_list - - num_points_scale = [1/n for n in num_points_list for _ in range(n)] - self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32)) + + num_points_scale = [1 / n for n in num_points_list for _ in range(n)] + self.register_buffer( + "num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32) + ) self.total_points = num_heads * sum(num_points_list) self.method = method self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2) self.attention_weights = nn.Linear(embed_dim, self.total_points) self.value_proj = nn.Linear(embed_dim, embed_dim) self.output_proj = nn.Linear(embed_dim, embed_dim) - self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method) + self.ms_deformable_attn_core = functools.partial( + deformable_attention_core_func_v2, method=self.method + ) self._reset_parameters() - if method == 'discrete': + if method == "discrete": for p in self.sampling_offsets.parameters(): p.requires_grad = False def _reset_parameters(self): # sampling_offsets init.constant_(self.sampling_offsets.weight, 0) - thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + thetas = 
torch.arange(self.num_heads, dtype=torch.float32) * ( + 2.0 * math.pi / self.num_heads + ) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1]) - scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1) + scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape( + 1, -1, 1 + ) grid_init *= scaling self.sampling_offsets.bias.data[...] = grid_init.flatten() @@ -103,13 +118,14 @@ def _reset_parameters(self): init.xavier_uniform_(self.output_proj.weight) init.constant_(self.output_proj.bias, 0) - - def forward(self, - query: torch.Tensor, - reference_points: torch.Tensor, - value: torch.Tensor, - value_spatial_shapes: List[int], - value_mask: torch.Tensor=None): + def forward( + self, + query: torch.Tensor, + reference_points: torch.Tensor, + value: torch.Tensor, + value_spatial_shapes: List[int], + value_mask: torch.Tensor = None, + ): """ Args: query (Tensor): [bs, query_length, C] @@ -132,27 +148,45 @@ def forward(self, value = value.reshape(bs, Len_v, self.num_heads, self.head_dim) sampling_offsets: torch.Tensor = self.sampling_offsets(query) - sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2) + sampling_offsets = sampling_offsets.reshape( + bs, Len_q, self.num_heads, sum(self.num_points_list), 2 + ) - attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) - attention_weights = F.softmax(attention_weights, dim=-1).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) + attention_weights = self.attention_weights(query).reshape( + bs, Len_q, self.num_heads, sum(self.num_points_list) + ) + attention_weights = F.softmax(attention_weights, dim=-1).reshape( + bs, Len_q, self.num_heads, sum(self.num_points_list) + ) if 
reference_points.shape[-1] == 2: offset_normalizer = torch.tensor(value_spatial_shapes) offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) - sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer + sampling_locations = ( + reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + + sampling_offsets / offset_normalizer + ) elif reference_points.shape[-1] == 4: # reference_points [8, 480, None, 1, 4] # sampling_offsets [8, 480, 8, 12, 2] num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1) - offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + offset = ( + sampling_offsets + * num_points_scale + * reference_points[:, :, None, :, 2:] + * self.offset_scale + ) sampling_locations = reference_points[:, :, None, :, :2] + offset else: raise ValueError( - "Last dim of reference_points must be 2 or 4, but get {} instead.". 
- format(reference_points.shape[-1])) + "Last dim of reference_points must be 2 or 4, but get {} instead.".format( + reference_points.shape[-1] + ) + ) - output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list) + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list + ) output = self.output_proj(output) @@ -160,15 +194,17 @@ def forward(self, class TransformerDecoderLayer(nn.Module): - def __init__(self, - d_model=256, - n_head=8, - dim_feedforward=1024, - dropout=0., - activation='relu', - n_levels=4, - n_points=4, - cross_attn_method='default'): + def __init__( + self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + n_levels=4, + n_points=4, + cross_attn_method="default", + ): super(TransformerDecoderLayer, self).__init__() # self attention @@ -177,7 +213,9 @@ def __init__(self, self.norm1 = nn.LayerNorm(d_model) # cross attention - self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points, method=cross_attn_method) + self.cross_attn = MSDeformableAttention( + d_model, n_head, n_levels, n_points, method=cross_attn_method + ) self.dropout2 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) @@ -188,7 +226,7 @@ def __init__(self, self.linear2 = nn.Linear(dim_feedforward, d_model) self.dropout4 = nn.Dropout(dropout) self.norm3 = nn.LayerNorm(d_model) - + self._reset_parameters() def _reset_parameters(self): @@ -201,14 +239,16 @@ def with_pos_embed(self, tensor, pos): def forward_ffn(self, tgt): return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) - def forward(self, - target, - reference_points, - memory, - memory_spatial_shapes, - attn_mask=None, - memory_mask=None, - query_pos_embed=None): + def forward( + self, + target, + reference_points, + memory, + memory_spatial_shapes, + attn_mask=None, + memory_mask=None, + query_pos_embed=None, + 
): # self attention q = k = self.with_pos_embed(target, query_pos_embed) @@ -217,12 +257,13 @@ def forward(self, target = self.norm1(target) # cross attention - target2 = self.cross_attn(\ - self.with_pos_embed(target, query_pos_embed), - reference_points, - memory, - memory_spatial_shapes, - memory_mask) + target2 = self.cross_attn( + self.with_pos_embed(target, query_pos_embed), + reference_points, + memory, + memory_spatial_shapes, + memory_mask, + ) target = target + self.dropout2(target2) target = self.norm2(target) @@ -242,16 +283,18 @@ def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): self.num_layers = num_layers self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx - def forward(self, - target, - ref_points_unact, - memory, - memory_spatial_shapes, - bbox_head, - score_head, - query_pos_head, - attn_mask=None, - memory_mask=None): + def forward( + self, + target, + ref_points_unact, + memory, + memory_spatial_shapes, + bbox_head, + score_head, + query_pos_head, + attn_mask=None, + memory_mask=None, + ): dec_out_bboxes = [] dec_out_logits = [] ref_points_detach = F.sigmoid(ref_points_unact) @@ -261,7 +304,15 @@ def forward(self, ref_points_input = ref_points_detach.unsqueeze(2) query_pos_embed = query_pos_head(ref_points_detach) - output = layer(output, ref_points_input, memory, memory_spatial_shapes, attn_mask, memory_mask, query_pos_embed) + output = layer( + output, + ref_points_input, + memory, + memory_spatial_shapes, + attn_mask, + memory_mask, + query_pos_embed, + ) inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) @@ -270,7 +321,9 @@ def forward(self, if i == 0: dec_out_bboxes.append(inter_ref_bbox) else: - dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) + dec_out_bboxes.append( + F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)) + ) elif i == self.eval_idx: dec_out_logits.append(score_head[i](output)) @@ -285,35 +338,37 @@ def 
forward(self, @register() class RTDETRTransformerv2(nn.Module): - __share__ = ['num_classes', 'eval_spatial_size'] - - def __init__(self, - num_classes=80, - hidden_dim=256, - num_queries=300, - feat_channels=[512, 1024, 2048], - feat_strides=[8, 16, 32], - num_levels=3, - num_points=4, - nhead=8, - num_layers=6, - dim_feedforward=1024, - dropout=0., - activation="relu", - num_denoising=100, - label_noise_ratio=0.5, - box_noise_scale=1.0, - learn_query_content=False, - eval_spatial_size=None, - eval_idx=-1, - eps=1e-2, - aux_loss=True, - cross_attn_method='default', - query_select_method='default'): + __share__ = ["num_classes", "eval_spatial_size"] + + def __init__( + self, + num_classes=80, + hidden_dim=256, + num_queries=300, + feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_points=4, + nhead=8, + num_layers=6, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_query_content=False, + eval_spatial_size=None, + eval_idx=-1, + eps=1e-2, + aux_loss=True, + cross_attn_method="default", + query_select_method="default", + ): super().__init__() assert len(feat_channels) <= num_levels assert len(feat_strides) == len(feat_channels) - + for _ in range(num_levels - len(feat_strides)): feat_strides.append(feat_strides[-1] * 2) @@ -328,8 +383,8 @@ def __init__(self, self.eval_spatial_size = eval_spatial_size self.aux_loss = aux_loss - assert query_select_method in ('default', 'one2many', 'agnostic'), '' - assert cross_attn_method in ('default', 'discrete'), '' + assert query_select_method in ("default", "one2many", "agnostic"), "" + assert cross_attn_method in ("default", "discrete"), "" self.cross_attn_method = cross_attn_method self.query_select_method = query_select_method @@ -337,16 +392,26 @@ def __init__(self, self._build_input_proj_layer(feat_channels) # Transformer module - decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, 
dropout, \ - activation, num_levels, num_points, cross_attn_method=cross_attn_method) + decoder_layer = TransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_levels, + num_points, + cross_attn_method=cross_attn_method, + ) self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_layers, eval_idx) # denoising self.num_denoising = num_denoising self.label_noise_ratio = label_noise_ratio self.box_noise_scale = box_noise_scale - if num_denoising > 0: - self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes) + if num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + num_classes + 1, hidden_dim, padding_idx=num_classes + ) init.normal_(self.denoising_class_embed.weight[:-1]) # decoder embedding @@ -359,12 +424,21 @@ def __init__(self, # layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu') # self.encoder = TransformerEncoder(layer, 1) - self.enc_output = nn.Sequential(OrderedDict([ - ('proj', nn.Linear(hidden_dim, hidden_dim)), - ('norm', nn.LayerNorm(hidden_dim,)), - ])) + self.enc_output = nn.Sequential( + OrderedDict( + [ + ("proj", nn.Linear(hidden_dim, hidden_dim)), + ( + "norm", + nn.LayerNorm( + hidden_dim, + ), + ), + ] + ) + ) - if query_select_method == 'agnostic': + if query_select_method == "agnostic": self.enc_score_head = nn.Linear(hidden_dim, 1) else: self.enc_score_head = nn.Linear(hidden_dim, num_classes) @@ -372,21 +446,21 @@ def __init__(self, self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) # decoder head - self.dec_score_head = nn.ModuleList([ - nn.Linear(hidden_dim, num_classes) for _ in range(num_layers) - ]) - self.dec_bbox_head = nn.ModuleList([ - MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers) - ]) + self.dec_score_head = nn.ModuleList( + [nn.Linear(hidden_dim, num_classes) for _ in range(num_layers)] + ) + self.dec_bbox_head = nn.ModuleList( + [MLP(hidden_dim, hidden_dim, 4, 3) for _ in 
range(num_layers)] + ) # init encoder output anchors and valid_mask if self.eval_spatial_size: anchors, valid_mask = self._generate_anchors() - self.register_buffer('anchors', anchors) - self.register_buffer('valid_mask', valid_mask) + self.register_buffer("anchors", anchors) + self.register_buffer("valid_mask", valid_mask) self._reset_parameters() - + def _reset_parameters(self): bias = bias_init_with_prob(0.01) init.constant_(self.enc_score_head.bias, bias) @@ -397,7 +471,7 @@ def _reset_parameters(self): init.constant_(_cls.bias, bias) init.constant_(_reg.layers[-1].weight, 0) init.constant_(_reg.layers[-1].bias, 0) - + init.xavier_uniform_(self.enc_output[0].weight) if self.learn_query_content: init.xavier_uniform_(self.tgt_embed.weight) @@ -410,9 +484,18 @@ def _build_input_proj_layer(self, feat_channels): self.input_proj = nn.ModuleList() for in_channels in feat_channels: self.input_proj.append( - nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), - ('norm', nn.BatchNorm2d(self.hidden_dim,))]) + nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), + ( + "norm", + nn.BatchNorm2d( + self.hidden_dim, + ), + ), + ] + ) ) ) @@ -420,9 +503,18 @@ def _build_input_proj_layer(self, feat_channels): for _ in range(self.num_levels - len(feat_channels)): self.input_proj.append( - nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)), - ('norm', nn.BatchNorm2d(self.hidden_dim))]) + nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + in_channels, self.hidden_dim, 3, 2, padding=1, bias=False + ), + ), + ("norm", nn.BatchNorm2d(self.hidden_dim)), + ] + ) ) ) in_channels = self.hidden_dim @@ -451,11 +543,9 @@ def _get_encoder_input(self, feats: List[torch.Tensor]): feat_flatten = torch.concat(feat_flatten, 1) return feat_flatten, spatial_shapes - def _generate_anchors(self, - spatial_shapes=None, - grid_size=0.05, - 
dtype=torch.float32, - device='cpu'): + def _generate_anchors( + self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu" + ): if spatial_shapes is None: spatial_shapes = [] eval_h, eval_w = self.eval_spatial_size @@ -464,10 +554,10 @@ def _generate_anchors(self, anchors = [] for lvl, (h, w) in enumerate(spatial_shapes): - grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') + grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") grid_xy = torch.stack([grid_x, grid_y], dim=-1) grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype) - wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl) + wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl) lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4) anchors.append(lvl_anchors) @@ -478,13 +568,9 @@ def _generate_anchors(self, return anchors, valid_mask - - def _get_decoder_input(self, - memory: torch.Tensor, - spatial_shapes, - denoising_logits=None, - denoising_bbox_unact=None): - + def _get_decoder_input( + self, memory: torch.Tensor, spatial_shapes, denoising_logits=None, denoising_bbox_unact=None + ): # prepare input for decoder if self.training or self.eval_spatial_size is None: anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) @@ -493,82 +579,101 @@ def _get_decoder_input(self, valid_mask = self.valid_mask # memory = torch.where(valid_mask, memory, 0) - # TODO fix type error for onnx export - memory = valid_mask.to(memory.dtype) * memory + # TODO fix type error for onnx export + memory = valid_mask.to(memory.dtype) * memory - output_memory :torch.Tensor = self.enc_output(memory) - enc_outputs_logits :torch.Tensor = self.enc_score_head(output_memory) - enc_outputs_coord_unact :torch.Tensor = self.enc_bbox_head(output_memory) + anchors + output_memory: torch.Tensor = self.enc_output(memory) + enc_outputs_logits: torch.Tensor = self.enc_score_head(output_memory) + 
enc_outputs_coord_unact: torch.Tensor = self.enc_bbox_head(output_memory) + anchors enc_topk_bboxes_list, enc_topk_logits_list = [], [] - enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = \ - self._select_topk(output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries) - + enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = self._select_topk( + output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries + ) + if self.training: enc_topk_bboxes = F.sigmoid(enc_topk_bbox_unact) enc_topk_bboxes_list.append(enc_topk_bboxes) enc_topk_logits_list.append(enc_topk_logits) - # if self.num_select_queries != self.num_queries: + # if self.num_select_queries != self.num_queries: # raise NotImplementedError('') if self.learn_query_content: content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1]) else: content = enc_topk_memory.detach() - + enc_topk_bbox_unact = enc_topk_bbox_unact.detach() - + if denoising_bbox_unact is not None: enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) content = torch.concat([denoising_logits, content], dim=1) - + return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list - def _select_topk(self, memory: torch.Tensor, outputs_logits: torch.Tensor, outputs_coords_unact: torch.Tensor, topk: int): - if self.query_select_method == 'default': + def _select_topk( + self, + memory: torch.Tensor, + outputs_logits: torch.Tensor, + outputs_coords_unact: torch.Tensor, + topk: int, + ): + if self.query_select_method == "default": _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1) - elif self.query_select_method == 'one2many': + elif self.query_select_method == "one2many": _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1) topk_ind = topk_ind // self.num_classes - elif self.query_select_method == 'agnostic': + elif self.query_select_method == "agnostic": _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, 
dim=-1) - + topk_ind: torch.Tensor - topk_coords = outputs_coords_unact.gather(dim=1, \ - index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1])) - - topk_logits = outputs_logits.gather(dim=1, \ - index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1])) - - topk_memory = memory.gather(dim=1, \ - index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])) + topk_coords = outputs_coords_unact.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1]) + ) - return topk_memory, topk_logits, topk_coords + topk_logits = outputs_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1]) + ) + + topk_memory = memory.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1]) + ) + return topk_memory, topk_logits, topk_coords def forward(self, feats, targets=None): # input projection and embedding memory, spatial_shapes = self._get_encoder_input(feats) - + # prepare denoising training if self.training and self.num_denoising > 0: - denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = \ - get_contrastive_denoising_training_group(targets, \ - self.num_classes, - self.num_queries, - self.denoising_class_embed, - num_denoising=self.num_denoising, - label_noise_ratio=self.label_noise_ratio, - box_noise_scale=self.box_noise_scale, ) + ( + denoising_logits, + denoising_bbox_unact, + attn_mask, + dn_meta, + ) = get_contrastive_denoising_training_group( + targets, + self.num_classes, + self.num_queries, + self.denoising_class_embed, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + ) else: denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None - init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = \ - self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact) + ( + init_ref_contents, + 
init_ref_points_unact, + enc_topk_bboxes_list, + enc_topk_logits_list, + ) = self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact) # decoder out_bboxes, out_logits = self.decoder( @@ -579,30 +684,29 @@ def forward(self, feats, targets=None): self.dec_bbox_head, self.dec_score_head, self.query_pos_head, - attn_mask=attn_mask) + attn_mask=attn_mask, + ) if self.training and dn_meta is not None: - dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2) - dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2) + dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta["dn_num_split"], dim=2) + dn_out_logits, out_logits = torch.split(out_logits, dn_meta["dn_num_split"], dim=2) - out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]} + out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]} if self.training and self.aux_loss: - out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1]) - out['enc_aux_outputs'] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) - out['enc_meta'] = {'class_agnostic': self.query_select_method == 'agnostic'} + out["aux_outputs"] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1]) + out["enc_aux_outputs"] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) + out["enc_meta"] = {"class_agnostic": self.query_select_method == "agnostic"} if dn_meta is not None: - out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) - out['dn_meta'] = dn_meta + out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) + out["dn_meta"] = dn_meta return out - @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_coord): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. 
- return [{'pred_logits': a, 'pred_boxes': b} - for a, b in zip(outputs_class, outputs_coord)] + return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py index 7653bf9e6..5d25119b7 100644 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py +++ b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py @@ -4,13 +4,13 @@ import math from typing import List -import torch +import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as F -def inverse_sigmoid(x: torch.Tensor, eps: float=1e-5) -> torch.Tensor: - x = x.clip(min=0., max=1.) +def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + x = x.clip(min=0.0, max=1.0) return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps)) @@ -20,13 +20,14 @@ def bias_init_with_prob(prior_prob=0.01): return bias_init -def deformable_attention_core_func_v2(\ - value: torch.Tensor, +def deformable_attention_core_func_v2( + value: torch.Tensor, value_spatial_shapes, - sampling_locations: torch.Tensor, - attention_weights: torch.Tensor, - num_points_list: List[int], - method='default'): + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, + num_points_list: List[int], + method="default", +): """ Args: value (Tensor): [bs, value_length, n_head, c] @@ -40,15 +41,15 @@ def deformable_attention_core_func_v2(\ """ bs, _, n_head, c = value.shape _, Len_q, _, _, _ = sampling_locations.shape - + split_shape = [h * w for h, w in value_spatial_shapes] value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1) # sampling_offsets [8, 480, 8, 12, 2] - if method == 'default': + if method == "default": sampling_grids = 2 * sampling_locations - 1 - elif method == 'discrete': + elif 
method == "discrete": sampling_grids = sampling_locations sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) @@ -59,69 +60,77 @@ def deformable_attention_core_func_v2(\ value_l = value_list[level].reshape(bs * n_head, c, h, w) sampling_grid_l: torch.Tensor = sampling_locations_list[level] - if method == 'default': + if method == "default": sampling_value_l = F.grid_sample( - value_l, - sampling_grid_l, - mode='bilinear', - padding_mode='zeros', - align_corners=False) - - elif method == 'discrete': + value_l, sampling_grid_l, mode="bilinear", padding_mode="zeros", align_corners=False + ) + + elif method == "discrete": # n * m, seq, n, 2 - sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5).to(torch.int64) + sampling_coord = ( + sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5 + ).to(torch.int64) # FIX ME? for rectangle input - sampling_coord = sampling_coord.clamp(0, h - 1) - sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) + sampling_coord = sampling_coord.clamp(0, h - 1) + sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) + + s_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l: torch.Tensor = value_l[ + s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0] + ] # n l c + + sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape( + bs * n_head, c, Len_q, num_points_list[level] + ) - s_idx = torch.arange(sampling_coord.shape[0], device=value.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1]) - sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c - - sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level]) - sampling_value_list.append(sampling_value_l) - attn_weights = attention_weights.permute(0, 2, 1, 
3).reshape(bs * n_head, 1, Len_q, sum(num_points_list)) + attn_weights = attention_weights.permute(0, 2, 1, 3).reshape( + bs * n_head, 1, Len_q, sum(num_points_list) + ) weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q) return output.permute(0, 2, 1) -def get_activation(act: str, inpace: bool=True): - """get activation - """ +def get_activation(act: str, inpace: bool = True): + """get activation""" if act is None: return nn.Identity() elif isinstance(act, nn.Module): - return act + return act act = act.lower() - - if act == 'silu' or act == 'swish': + + if act == "silu" or act == "swish": m = nn.SiLU() - elif act == 'relu': + elif act == "relu": m = nn.ReLU() - elif act == 'leaky_relu': + elif act == "leaky_relu": m = nn.LeakyReLU() - elif act == 'silu': + elif act == "silu": m = nn.SiLU() - - elif act == 'gelu': + + elif act == "gelu": m = nn.GELU() - elif act == 'hardsigmoid': + elif act == "hardsigmoid": m = nn.Hardsigmoid() else: - raise RuntimeError('') + raise RuntimeError("") - if hasattr(m, 'inplace'): + if hasattr(m, "inplace"): m.inplace = inpace - - return m + + return m diff --git a/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py b/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py index 36c84592f..0d45404ea 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py +++ b/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py @@ -9,23 +9,20 @@ from .yolov8_base import YOLOV8Base __all__ = [ - 'DeepfauneDetector', + "DeepfauneDetector", ] + class DeepfauneDetector(YOLOV8Base): """ - MegaDetectorV6 is a specialized class derived from the YOLOV8Base class + MegaDetectorV6 is a specialized class derived from the YOLOV8Base class that is specifically designed for detecting animals, persons, and vehicles. - + Attributes: CLASS_NAMES (dict): Mapping of class IDs to their respective names. 
""" - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } + + CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"} def __init__(self, weights=None, device="cpu"): """ @@ -37,7 +34,7 @@ def __init__(self, weights=None, device="cpu"): """ self.IMAGE_SIZE = 960 - url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-yolov8s_960.pt" + url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-yolov8s_960.pt" self.MODEL_NAME = "deepfaune-yolov8s_960.pt" - super(DeepfauneDetector, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file + super(DeepfauneDetector, self).__init__(weights=weights, device=device, url=url) diff --git a/PytorchWildlife/models/detection/ultralytics_based/__init__.py b/PytorchWildlife/models/detection/ultralytics_based/__init__.py index cb9fd8b1e..5906a19e7 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/__init__.py +++ b/PytorchWildlife/models/detection/ultralytics_based/__init__.py @@ -1,6 +1,6 @@ -from .yolov5_base import * -from .yolov8_base import * +from .Deepfaune import * from .megadetectorv5 import * from .megadetectorv6 import * from .megadetectorv6_distributed import * -from .Deepfaune import * +from .yolov5_base import * +from .yolov8_base import * diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py index cbbbb6671..40effd318 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py +++ b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py @@ -3,40 +3,35 @@ from .yolov5_base import YOLOV5Base -__all__ = [ - 'MegaDetectorV5' -] +__all__ = ["MegaDetectorV5"] + class MegaDetectorV5(YOLOV5Base): """ - MegaDetectorV5 is a specialized class derived from the YOLOV5Base class + MegaDetectorV5 is a specialized class derived from the YOLOV5Base class that is specifically designed for detecting animals, 
persons, and vehicles. - + Attributes: IMAGE_SIZE (int): The standard image size used during training. STRIDE (int): Stride value used in the detector. CLASS_NAMES (dict): Mapping of class IDs to their respective names. """ - + IMAGE_SIZE = 1280 # image size used in training STRIDE = 64 - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } + CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"} def __init__(self, weights=None, device="cpu", pretrained=True, version="a"): """ Initializes the MegaDetectorV5 model with the option to load pretrained weights. - + Args: weights (str, optional): Path to the weights file. device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". pretrained (bool, optional): Whether to load the pretrained model. Default is True. version (str, optional): Version of the MegaDetectorV5 model to load. Default is "a". """ - + if pretrained: if version == "a": url = "https://zenodo.org/records/13357337/files/md_v5a.0.0.pt?download=1" @@ -45,14 +40,12 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version="a"): else: url = None - import site + import site import sys - sys.path.insert(0, site.getsitepackages()[0]+'/yolov5') + + sys.path.insert(0, site.getsitepackages()[0] + "/yolov5") super(MegaDetectorV5, self).__init__(weights=weights, device=device, url=url) - - - # %% diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py index 0d0ee6584..389ab4010 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py +++ b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py @@ -1,29 +1,23 @@ - from .yolov8_base import YOLOV8Base -__all__ = [ - 'MegaDetectorV6' -] +__all__ = ["MegaDetectorV6"] + class MegaDetectorV6(YOLOV8Base): """ - MegaDetectorV6 is a specialized class derived from the YOLOV8Base class + MegaDetectorV6 is a specialized class 
derived from the YOLOV8Base class that is specifically designed for detecting animals, persons, and vehicles. - + Attributes: CLASS_NAMES (dict): Mapping of class IDs to their respective names. """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c'): + + CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"} + + def __init__(self, weights=None, device="cpu", pretrained=True, version="yolov9c"): """ Initializes the MegaDetectorV5 model with the option to load pretrained weights. - + Args: weights (str, optional): Path to the weights file. device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". @@ -32,22 +26,24 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c """ self.IMAGE_SIZE = 1280 - if version == 'MDV6-yolov9-c': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1" + if version == "MDV6-yolov9-c": + url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1" self.MODEL_NAME = "MDV6b-yolov9-c.pt" - elif version == 'MDV6-yolov9-e': + elif version == "MDV6-yolov9-e": url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-e-1280.pt?download=1" self.MODEL_NAME = "MDV6-yolov9-e-1280.pt" - elif version == 'MDV6-yolov10-c': + elif version == "MDV6-yolov10-c": url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-c.pt?download=1" self.MODEL_NAME = "MDV6-yolov10-c.pt" - elif version == 'MDV6-yolov10-e': + elif version == "MDV6-yolov10-e": url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-e-1280.pt?download=1" self.MODEL_NAME = "MDV6-yolov10-e-1280.pt" - elif version == 'MDV6-rtdetr-c': + elif version == "MDV6-rtdetr-c": url = "https://zenodo.org/records/15398270/files/MDV6-rtdetr-c.pt?download=1" self.MODEL_NAME = "MDV6b-rtdetr-c.pt" else: - raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, 
MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c') + raise ValueError( + "Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c" + ) - super(MegaDetectorV6, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file + super(MegaDetectorV6, self).__init__(weights=weights, device=device, url=url) diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py index ba1cf0256..a8303acd6 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py +++ b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py @@ -1,29 +1,23 @@ - from .yolov8_distributed import YOLOV8_Distributed -__all__ = [ - 'MegaDetectorV6_Distributed' -] +__all__ = ["MegaDetectorV6_Distributed"] + class MegaDetectorV6_Distributed(YOLOV8_Distributed): """ - MegaDetectorV6 is a specialized class derived from the YOLOV8Base class + MegaDetectorV6 is a specialized class derived from the YOLOV8Base class that is specifically designed for detecting animals, persons, and vehicles. - + Attributes: CLASS_NAMES (dict): Mapping of class IDs to their respective names. """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c'): + + CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"} + + def __init__(self, weights=None, device="cpu", pretrained=True, version="yolov9c"): """ Initializes the MegaDetectorV5 model with the option to load pretrained weights. - + Args: weights (str, optional): Path to the weights file. device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". 
@@ -32,22 +26,24 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c """ self.IMAGE_SIZE = 1280 - if version == 'MDV6-yolov9-c': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1" + if version == "MDV6-yolov9-c": + url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1" self.MODEL_NAME = "MDV6b-yolov9-c.pt" - elif version == 'MDV6-yolov9-e': + elif version == "MDV6-yolov9-e": url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-e-1280.pt?download=1" self.MODEL_NAME = "MDV6-yolov9-e-1280.pt" - elif version == 'MDV6-yolov10-c': + elif version == "MDV6-yolov10-c": url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-c.pt?download=1" self.MODEL_NAME = "MDV6-yolov10-c.pt" - elif version == 'MDV6-yolov10-e': + elif version == "MDV6-yolov10-e": url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-e-1280.pt?download=1" self.MODEL_NAME = "MDV6-yolov10-e-1280.pt" - elif version == 'MDV6-rtdetr-c': + elif version == "MDV6-rtdetr-c": url = "https://zenodo.org/records/15398270/files/MDV6-rtdetr-c.pt?download=1" self.MODEL_NAME = "MDV6b-rtdetr-c.pt" else: - raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c') + raise ValueError( + "Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c" + ) - super(MegaDetectorV6_Distributed, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file + super(MegaDetectorV6_Distributed, self).__init__(weights=weights, device=device, url=url) diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py b/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py index 917b6d93c..ce00bfe09 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py +++ b/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py @@ -6,18 +6,17 @@ # Importing 
basic libraries import numpy as np -from tqdm import tqdm -from PIL import Image import supervision as sv - import torch -from torch.utils.data import DataLoader +from PIL import Image from torch.hub import load_state_dict_from_url - +from torch.utils.data import DataLoader +from tqdm import tqdm from yolov5.utils.general import non_max_suppression, scale_boxes -from ..base_detector import BaseDetector -from ....data import transforms as pw_trans + from ....data import datasets as pw_data +from ....data import transforms as pw_trans +from ..base_detector import BaseDetector class YOLOV5Base(BaseDetector): @@ -25,16 +24,17 @@ class YOLOV5Base(BaseDetector): Base detector class for YOLO V5. This class provides utility methods for loading the model, generating results, and performing single and batch image detections. """ + def __init__(self, weights=None, device="cpu", url=None, transform=None): """ Initialize the YOLO V5 detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. transform (callable, optional): Optional transform to be applied on the image. Defaults to None. @@ -46,13 +46,13 @@ def __init__(self, weights=None, device="cpu", url=None, transform=None): def _load_model(self, weights=None, device="cpu", url=None): """ Load the YOLO V5 model weights. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. 
@@ -64,21 +64,22 @@ def _load_model(self, weights=None, device="cpu", url=None): else: raise Exception("Need weights for inference.") self.model = checkpoint["model"].float().fuse().eval().to(self.device) - + if not self.transform: - self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE, - stride=self.STRIDE) + self.transform = pw_trans.MegaDetector_v5_Transform( + target_size=self.IMAGE_SIZE, stride=self.STRIDE + ) def results_generation(self, preds, img_id, id_strip=None) -> dict: """ Generate results for detection based on model predictions. - + Args: - preds (numpy.ndarray): + preds (numpy.ndarray): Model predictions. - img_id (str): + img_id (str): Image identifier. - id_strip (str, optional): + id_strip (str, optional): Strip specific characters from img_id. Defaults to None. Returns: @@ -86,28 +87,28 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict: """ results = {"img_id": str(img_id).strip(id_strip)} results["detections"] = sv.Detections( - xyxy=preds[:, :4], - confidence=preds[:, 4], - class_id=preds[:, 5].astype(int) + xyxy=preds[:, :4], confidence=preds[:, 4], class_id=preds[:, 5].astype(int) ) results["labels"] = [ f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id) + for confidence, class_id in zip( + results["detections"].confidence, results["detections"].class_id + ) ] return results def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None) -> dict: """ Perform detection on a single image. - + Args: - img (str or ndarray): + img (str or ndarray): Image path or ndarray of images. - img_path (str, optional): + img_path (str, optional): Image path or identifier. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. 
Defaults to None. Returns: @@ -121,19 +122,28 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri img = self.transform(img) if img_size is None: - img_size = img.permute((1, 2, 0)).shape # We need hwc instead of chw for coord scaling + img_size = img.permute((1, 2, 0)).shape # We need hwc instead of chw for coord scaling preds = self.model(img.unsqueeze(0).to(self.device))[0] - preds = torch.cat(non_max_suppression(prediction=preds, conf_thres=det_conf_thres), axis=0).cpu().numpy() + preds = ( + torch.cat(non_max_suppression(prediction=preds, conf_thres=det_conf_thres), axis=0) + .cpu() + .numpy() + ) # preds[:, :4] = scale_coords([self.IMAGE_SIZE] * 2, preds[:, :4], img_size).round() preds[:, :4] = scale_boxes([self.IMAGE_SIZE] * 2, preds[:, :4], img_size).round() res = self.results_generation(preds, img_path, id_strip) - normalized_coords = [[x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] for x1, y1, x2, y2 in preds[:, :4]] + normalized_coords = [ + [x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] + for x1, y1, x2, y2 in preds[:, :4] + ] res["normalized_coords"] = normalized_coords return res - def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None) -> list[dict]: + def batch_image_detection( + self, data_path, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None + ) -> list[dict]: """ Perform detection on a batch of images. 
@@ -153,8 +163,14 @@ def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres: ) # Creating a DataLoader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False) + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=True, + num_workers=0, + drop_last=False, + ) results = [] with tqdm(total=len(loader)) as pbar: @@ -165,7 +181,7 @@ def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres: batch_results = [] for i, pred in enumerate(predictions): - if pred.size(0) == 0: + if pred.size(0) == 0: continue pred = pred.numpy() size = sizes[i].numpy() @@ -174,7 +190,10 @@ def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres: # pred[:, :4] = scale_coords([self.IMAGE_SIZE] * 2, pred[:, :4], size).round() pred[:, :4] = scale_boxes([self.IMAGE_SIZE] * 2, pred[:, :4], size).round() # Normalize the coordinates for timelapse compatibility - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in pred[:, :4]] + normalized_coords = [ + [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] + for x1, y1, x2, y2 in pred[:, :4] + ] res = self.results_generation(pred, path, id_strip) res["normalized_coords"] = normalized_coords batch_results.append(res) diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py b/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py index 50d6cfbbb..e061d5ea1 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py +++ b/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py @@ -6,39 +6,39 @@ # Importing basic libraries import os -import wget + import numpy as np -from tqdm import tqdm -from PIL import Image import supervision as sv - import torch +import wget +from PIL import Image from torch.utils.data import DataLoader 
+from tqdm import tqdm +from ultralytics.models import rtdetr, yolo -from ultralytics.models import yolo, rtdetr - -from ..base_detector import BaseDetector -from ....data import transforms as pw_trans from ....data import datasets as pw_data +from ....data import transforms as pw_trans +from ..base_detector import BaseDetector class YOLOV8Base(BaseDetector): """ Base detector class for the new ultralytics YOLOV8 framework. This class provides utility methods for loading the model, generating results, and performing single and batch image detections. - This base detector class is also compatible with all the new ultralytics models including YOLOV9, + This base detector class is also compatible with all the new ultralytics models including YOLOV9, RTDetr, and more. """ + def __init__(self, weights=None, device="cpu", url=None, transform=None): """ Initialize the YOLOV8 detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. """ super(YOLOV8Base, self).__init__(weights=weights, device=device, url=url) @@ -48,30 +48,34 @@ def __init__(self, weights=None, device="cpu", url=None, transform=None): def _load_model(self, weights=None, device="cpu", url=None): """ Load the YOLOV8 model weights. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. 
""" - if self.MODEL_NAME == 'MDV6b-rtdetrl.pt': + if self.MODEL_NAME == "MDV6b-rtdetrl.pt": self.predictor = rtdetr.RTDETRPredictor() else: self.predictor = yolo.detect.DetectionPredictor() # self.predictor.args.device = device # Will uncomment later self.predictor.args.imgsz = self.IMAGE_SIZE - self.predictor.args.save = False # Will see if we want to use ultralytics native inference saving functions. + self.predictor.args.save = ( + False # Will see if we want to use ultralytics native inference saving functions. + ) if weights: self.predictor.setup_model(weights) elif url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) else: @@ -79,21 +83,22 @@ def _load_model(self, weights=None, device="cpu", url=None): self.predictor.setup_model(weights) else: raise Exception("Need weights for inference.") - + if not self.transform: - self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE, - stride=self.STRIDE) + self.transform = pw_trans.MegaDetector_v5_Transform( + target_size=self.IMAGE_SIZE, stride=self.STRIDE + ) def results_generation(self, preds, img_id, id_strip=None) -> dict: """ Generate results for detection based on model predictions. - + Args: - preds (ultralytics.engine.results.Results): + preds (ultralytics.engine.results.Results): Model predictions. - img_id (str): + img_id (str): Image identifier. - id_strip (str, optional): + id_strip (str, optional): Strip specific characters from img_id. Defaults to None. 
Returns: @@ -104,32 +109,27 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict: class_id = preds.boxes.cls.cpu().numpy().astype(int) results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id - ) + results["detections"] = sv.Detections(xyxy=xyxy, confidence=confidence, class_id=class_id) results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, _, _ in results["detections"] + f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, _, _ in results["detections"] ] - + return results - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None) -> dict: """ Perform detection on a single image. - + Args: - img (str or ndarray): + img (str or ndarray): Image path or ndarray of images. - img_path (str, optional): + img_path (str, optional): Image path or identifier. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. 
Returns: @@ -144,18 +144,22 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri self.predictor.args.batch = 1 self.predictor.args.conf = det_conf_thres - + det_results = list(self.predictor.stream_inference([img])) res = self.results_generation(det_results[0], img_path, id_strip) - normalized_coords = [[x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] - for x1, y1, x2, y2 in res["detections"].xyxy] + normalized_coords = [ + [x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] + for x1, y1, x2, y2 in res["detections"].xyxy + ] res["normalized_coords"] = normalized_coords - + return res - def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None) -> list[dict]: + def batch_image_detection( + self, data_source, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None + ) -> list[dict]: """ Perform detection on a batch of images. @@ -174,24 +178,28 @@ def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thre # Handle numpy array input if isinstance(data_source, (list, np.ndarray)): results = [] - num_batches = (len(data_source) + batch_size - 1) // batch_size # Calculate total batches - + num_batches = ( + len(data_source) + batch_size - 1 + ) // batch_size # Calculate total batches + with tqdm(total=num_batches) as pbar: for start_idx in range(0, len(data_source), batch_size): - batch_arrays = data_source[start_idx:start_idx + batch_size] + batch_arrays = data_source[start_idx : start_idx + batch_size] det_results = self.predictor.stream_inference(batch_arrays) - + for idx, preds in enumerate(det_results): res = self.results_generation(preds, f"{start_idx + idx}", id_strip) # Get size directly from numpy array img_height, img_width = batch_arrays[idx].shape[:2] - normalized_coords = [[x1/img_width, y1/img_height, x2/img_width, y2/img_height] - for x1, y1, x2, y2 in res["detections"].xyxy] + 
normalized_coords = [ + [x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height] + for x1, y1, x2, y2 in res["detections"].xyxy + ] res["normalized_coords"] = normalized_coords results.append(res) pbar.update(1) return results - + # Handle image directory input dataset = pw_data.DetectionImageFolder( data_source, @@ -199,10 +207,15 @@ def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thre ) # Creating a DataLoader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False - ) - + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=True, + num_workers=0, + drop_last=False, + ) + results = [] with tqdm(total=len(loader)) as pbar: for batch_index, (imgs, paths, sizes) in enumerate(loader): @@ -212,7 +225,10 @@ def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thre res = self.results_generation(preds, paths[idx], id_strip) size = preds.orig_shape # Normalize the coordinates for timelapse compatibility - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy] + normalized_coords = [ + [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] + for x1, y1, x2, y2 in res["detections"].xyxy + ] res["normalized_coords"] = normalized_coords results.append(res) pbar.update(1) diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py b/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py index 688ea780d..0d071f78f 100644 --- a/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py +++ b/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py @@ -7,20 +7,21 @@ import os import time from glob import glob -import supervision as sv + import numpy as np import pandas as pd -from PIL import Image -import wget +import supervision as 
sv import torch - -from ultralytics.models import yolo, rtdetr +import wget +from PIL import Image from torch.utils.data import DataLoader from tqdm import tqdm +from ultralytics.models import rtdetr, yolo -from ..base_detector import BaseDetector -from ....data import transforms as pw_trans from ....data import datasets as pw_data +from ....data import transforms as pw_trans +from ..base_detector import BaseDetector + class YOLOV8_Distributed(BaseDetector): """ @@ -32,46 +33,50 @@ class YOLOV8_Distributed(BaseDetector): def __init__(self, weights=None, device="cpu", url=None, transform=None): """ Initialize the YOLOV8 detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. """ self.transform = transform super(YOLOV8_Distributed, self).__init__(weights=weights, device=device, url=url) self._load_model(weights, self.device, url) - + def _load_model(self, weights=None, device="cpu", url=None): """ Load the YOLOV8 model weights. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. """ - if self.MODEL_NAME == 'MDV6b-rtdetrl.pt': + if self.MODEL_NAME == "MDV6b-rtdetrl.pt": self.predictor = rtdetr.RTDETRPredictor() else: self.predictor = yolo.detect.DetectionPredictor() # self.predictor.args.device = device # Will uncomment later self.predictor.args.imgsz = self.IMAGE_SIZE - self.predictor.args.save = False # Will see if we want to use ultralytics native inference saving functions. 
+ self.predictor.args.save = ( + False # Will see if we want to use ultralytics native inference saving functions. + ) if weights: self.predictor.setup_model(weights) elif url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) else: @@ -79,21 +84,22 @@ def _load_model(self, weights=None, device="cpu", url=None): self.predictor.setup_model(weights) else: raise Exception("Need weights for inference.") - + if not self.transform: - self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE, - stride=self.STRIDE) - + self.transform = pw_trans.MegaDetector_v5_Transform( + target_size=self.IMAGE_SIZE, stride=self.STRIDE + ) + def results_generation(self, preds, img_id, id_strip=None) -> dict: """ Generate results for detection based on model predictions. - + Args: - preds (ultralytics.engine.results.Results): + preds (ultralytics.engine.results.Results): Model predictions. - img_id (str): + img_id (str): Image identifier. - id_strip (str, optional): + id_strip (str, optional): Strip specific characters from img_id. Defaults to None. 
Returns: @@ -102,7 +108,7 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict: xyxy = preds.boxes.xyxy.cpu().numpy() confidence = preds.boxes.conf.cpu().numpy() class_id = preds.boxes.cls.cpu().numpy().astype(int) - + results = {"img_id": str(img_id).strip(id_strip)} # results["detections"] = sv.Detections( # xyxy=xyxy, @@ -114,38 +120,45 @@ def results_generation(self, preds, img_id, id_strip=None) -> dict: results["detections_class_id"] = class_id # results["labels"] = [ - # f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - # for _, _, confidence, class_id, _, _ in results["detections"] + # f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" + # for _, _, confidence, class_id, _, _ in results["detections"] # ] - + results["labels"] = [ - f"{self.CLASS_NAMES[cls_id]} {conf:0.2f}" - for cls_id, conf in zip(class_id, confidence) + f"{self.CLASS_NAMES[cls_id]} {conf:0.2f}" for cls_id, conf in zip(class_id, confidence) ] - + results["n_animal_detected"] = np.sum(class_id == 0) - + return results - - def batch_image_detection(self, loader, batch_size, global_rank, local_rank, output_dir, det_conf_thres=0.2, checkpoint_frequency = 1000): + def batch_image_detection( + self, + loader, + batch_size, + global_rank, + local_rank, + output_dir, + det_conf_thres=0.2, + checkpoint_frequency=1000, + ): """ Perform batch image detection using the YOLOV8 model. - + Args: - loader (torch.utils.data.DataLoader): + loader (torch.utils.data.DataLoader): DataLoader for input images. batch_size (int): Size of the batch for detection. - global_rank (int): + global_rank (int): Global rank of the process. - local_rank (int): + local_rank (int): Local rank of the process. - output_dir (str): + output_dir (str): Directory to save detection results. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.2. 
- checkpoint_frequency (int, optional): + checkpoint_frequency (int, optional): Frequency of saving intermediate results. Defaults to 1000. """ os.makedirs(output_dir, exist_ok=True) @@ -153,7 +166,6 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out self.predictor.args.conf = det_conf_thres self.predictor.args.device = local_rank - # Create checkpoint directory # Track batches and processed items results = { @@ -163,7 +175,7 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out "detections_class_id": [], "labels": [], "n_animal_detected": [], - "normalized_coords": [] + "normalized_coords": [], } checkpoint_dir = os.path.join(output_dir, f"checkpoints_rank{global_rank}") @@ -178,15 +190,18 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out # images: tensor of shape [batch_size, 3, H, W] # Assuming images are transformed & Standardized det_results = self.predictor.stream_inference(images) - + for idx, preds in enumerate(det_results): res = self.results_generation(preds, uuids[idx]) - + size = preds.orig_shape - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections_xyxy"]] + normalized_coords = [ + [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] + for x1, y1, x2, y2 in res["detections_xyxy"] + ] res["normalized_coords"] = normalized_coords - - #results.append(res) + + # results.append(res) results["img_id"].append(res["img_id"]) results["detections_xyxy"].append(res["detections_xyxy"].tolist()) results["detections_confidence"].append(res["detections_confidence"].tolist()) @@ -194,39 +209,29 @@ def batch_image_detection(self, loader, batch_size, global_rank, local_rank, out results["labels"].append(res["labels"]) results["n_animal_detected"].append(int(res["n_animal_detected"])) results["normalized_coords"].append(res["normalized_coords"]) - + if batch_counter % checkpoint_frequency == 0: elapsed = 
time.time() - start_time print(f"[Rank {global_rank}] Processed {processed_count} images in {elapsed}") - + # Save intermediate results checkpoint_path = os.path.join( - checkpoint_dir, - f"checkpoint_{batch_counter:06d}.parquet" + checkpoint_dir, f"checkpoint_{batch_counter:06d}.parquet" + ) + + df = pd.DataFrame( + {"img_id": results["img_id"], "n_animal_detected": results["n_animal_detected"]} ) - - df = pd.DataFrame({ - "img_id": results["img_id"], - "n_animal_detected": results["n_animal_detected"] - }) df.to_parquet(checkpoint_path, index=False) print(f"[Rank {global_rank}] Saved checkpoint to {checkpoint_path}") - + # Save results to disk os.makedirs(output_dir, exist_ok=True) - df = pd.DataFrame({ - "img_id": results["img_id"], - "n_animal_detected": results["n_animal_detected"] - }) + df = pd.DataFrame( + {"img_id": results["img_id"], "n_animal_detected": results["n_animal_detected"]} + ) out_path = os.path.join(output_dir, f"predictions_rank{global_rank}.parquet") df.to_parquet(out_path, index=False) print(f"[rank {global_rank}] Saved predictions to {out_path}") - - return results - - - - - - \ No newline at end of file + return results diff --git a/PytorchWildlife/models/detection/yolo_mit/__init__.py b/PytorchWildlife/models/detection/yolo_mit/__init__.py index 6d410c7dd..1d77fb0ab 100644 --- a/PytorchWildlife/models/detection/yolo_mit/__init__.py +++ b/PytorchWildlife/models/detection/yolo_mit/__init__.py @@ -1,2 +1,2 @@ +from .megadetectorv6_mit import * from .yolo_mit_base import * -from .megadetectorv6_mit import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py b/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py index 24a400b7c..bec7e8e99 100644 --- a/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py +++ b/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py @@ -1,29 +1,23 @@ - from .yolo_mit_base import YOLOMITBase -__all__ = [ - 'MegaDetectorV6MIT' -] 
+__all__ = ["MegaDetectorV6MIT"] + class MegaDetectorV6MIT(YOLOMITBase): """ - MegaDetectorV6 is a specialized class derived from the YOLOMITBase class + MegaDetectorV6 is a specialized class derived from the YOLOMITBase class that is specifically designed for detecting animals, persons, and vehicles. - + Attributes: CLASS_NAMES (dict): Mapping of class IDs to their respective names. """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-yolov9-c-mit'): + + CLASS_NAMES = {0: "animal", 1: "person", 2: "vehicle"} + + def __init__(self, weights=None, device="cpu", pretrained=True, version="MDV6-yolov9-c-mit"): """ Initializes the MegaDetectorV6 model with the option to load pretrained weights. - + Args: weights (str, optional): Path to the weights file. device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". @@ -32,13 +26,13 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-yo """ self.IMAGE_SIZE = 640 - if version == 'MDV6-mit-yolov9-c': + if version == "MDV6-mit-yolov9-c": url = "https://zenodo.org/records/15398270/files/MDV6-mit-yolov9-c.ckpt?download=1" self.MODEL_NAME = "MDV6-mit-yolov9-c.ckpt" - elif version == 'MDV6-mit-yolov9-e': + elif version == "MDV6-mit-yolov9-e": url = "https://zenodo.org/records/15398270/files/MDV6-mit-yolov9-e.ckpt?download=1" self.MODEL_NAME = "MDV6-mit-yolov9-e.ckpt" else: - raise ValueError('Select a valid model version: MDV6-mit-yolov9-c or MDV6-mit-yolov9-e') + raise ValueError("Select a valid model version: MDV6-mit-yolov9-c or MDV6-mit-yolov9-e") - super(MegaDetectorV6MIT, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file + super(MegaDetectorV6MIT, self).__init__(weights=weights, device=device, url=url) diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py 
b/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py index 02d8f69c4..5f7b1b7a7 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py @@ -1,15 +1,14 @@ -from yolo.model.yolo import create_model from yolo.config import Config, NMSConfig +from yolo.model.yolo import create_model from yolo.tools.data_loader import AugmentationComposer, create_dataloader -from yolo.utils.model_utils import PostProcess from yolo.utils.bounding_box_utils import create_converter +from yolo.utils.model_utils import PostProcess all = [ "create_model", "Config", "NMSConfig", - "AugmentationComposer" - "create_dataloader", + "AugmentationComposer" "create_dataloader", "PostProcess", "create_converter", ] diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/config.py b/PytorchWildlife/models/detection/yolo_mit/yolo/config.py index b8b69f5d8..af12389a0 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/config.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/config.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union + from torch import nn @@ -165,4 +166,4 @@ class YOLOLayer(nn.Module): tags: str layer_type: str usable: bool - external: Optional[dict] \ No newline at end of file + external: Optional[dict] diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py index 87b211d43..ef1ecc00a 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py @@ -1,9 +1,11 @@ +import inspect from typing import Any, Dict, List, Optional, Tuple, Union + import torch import torch.nn.functional as F from torch import Tensor, nn from torch.nn.common_types import _size_2_t -import inspect + # ----------- Utils ----------- # def get_layer_map(): @@ -109,7 +111,14 @@ def forward(self, 
x): class Detection(nn.Module): """A single YOLO Detection head for detection models""" - def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True): + def __init__( + self, + in_channels: Tuple[int], + num_classes: int, + *, + reg_max: int = 16, + use_group: bool = True, + ): super().__init__() groups = 4 if use_group else 1 @@ -125,7 +134,9 @@ def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = nn.Conv2d(anchor_neck, anchor_channels, 1, groups=groups), ) self.class_conv = nn.Sequential( - Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1) + Conv(in_channels, class_neck, 3), + Conv(class_neck, class_neck, 3), + nn.Conv2d(class_neck, num_classes, 1), ) self.anc2vec = Anchor2Vec(reg_max=reg_max) @@ -151,7 +162,10 @@ def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs): DetectionHead = IDetection self.heads = nn.ModuleList( - [DetectionHead((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels] + [ + DetectionHead((in_channels[0], in_channel), num_classes, **head_kwargs) + for in_channel in in_channels + ] ) def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: @@ -166,11 +180,11 @@ def __init__(self, reg_max: int = 16) -> None: self.anc2vec.weight = nn.Parameter(reverse_reg, requires_grad=False) def forward(self, anchor_x: Tensor) -> Tensor: - #anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4) + # anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4) B, PR, h, w = anchor_x.shape P = 4 R = PR // P - anchor_x = anchor_x.reshape(B, P, R, h, w).permute(0, 2, 1, 3, 4) + anchor_x = anchor_x.reshape(B, P, R, h, w).permute(0, 2, 1, 3, 4) vector_x = anchor_x.softmax(dim=1) vector_x = self.anc2vec(vector_x)[:, 0] return anchor_x, vector_x @@ -338,6 +352,7 @@ def forward(self, x: Tensor) -> Tensor: x = self.conv(x) return x + class ADown(nn.Module): 
"""Downsampling module combining average and max pooling with convolution for feature reduction.""" @@ -373,6 +388,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]: x = self.conv(x) return x.split(self.out_channels, dim=1) + class CBFuse(nn.Module): def __init__(self, index: List[int], mode: str = "nearest"): super().__init__() @@ -383,10 +399,14 @@ def forward(self, x_list: List[torch.Tensor]) -> List[Tensor]: target = x_list[-1] target_size = target.shape[2:] # Batch, Channel, H, W - res = [F.interpolate(x[pick_id], size=target_size, mode=self.mode) for pick_id, x in zip(self.idx, x_list)] + res = [ + F.interpolate(x[pick_id], size=target_size, mode=self.mode) + for pick_id, x in zip(self.idx, x_list) + ] out = torch.stack(res + [target]).sum(dim=0) return out - + + class SPPELAN(nn.Module): """SPPELAN module comprising multiple pooling and convolution layers.""" @@ -411,4 +431,4 @@ def __init__(self, **kwargs): self.UpSample = nn.Upsample(**kwargs) def forward(self, x): - return self.UpSample(x) \ No newline at end of file + return self.UpSample(x) diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py index 3834dd43f..eadeebcc8 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py @@ -5,10 +5,9 @@ import torch from omegaconf import ListConfig, OmegaConf from torch import nn - from yolo.config import ModelConfig, YOLOLayer -from yolo.tools.dataset_preparation import prepare_weight from yolo.model.module import get_layer_map +from yolo.tools.dataset_preparation import prepare_weight class YOLO(nn.Module): @@ -40,9 +39,15 @@ def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]): source = self.get_source_idx(layer_info.get("source", -1), layer_idx) # Find in channels - if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]): + if any( + 
module in layer_type + for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"] + ): layer_args["in_channels"] = output_dim[source] - if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]): + if any( + module in layer_type + for module in ["Detection", "Segmentation", "Classification"] + ): if isinstance(source, list): layer_args["in_channels"] = [output_dim[idx] for idx in source] else: @@ -85,7 +90,9 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = return output return output - def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]): + def get_out_channels( + self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list] + ): if hasattr(layer_args, "out_channels"): return layer_args["out_channels"] if layer_type == "CBFuse": @@ -106,7 +113,9 @@ def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int): self.model[source - 1].usable = True return source - def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer: + def create_layer( + self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs + ) -> YOLOLayer: if layer_type in self.layer_map: layer = self.layer_map[layer_type](**kwargs) setattr(layer, "layer_type", layer_type) @@ -133,7 +142,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): weights = weights["state_dict"] # Drop the prefix 'model.model.' from the keys - if "model.model." in list(weights.keys())[0]: + if "model.model." 
in list(weights.keys())[0]: weights = {k.replace("model.model.", ""): v for k, v in weights.items()} model_state_dict = self.model.state_dict() @@ -142,7 +151,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): # TODO2: weight transform if num_class difference error_dict = {"Mismatch": set(), "Not Found": set()} - + for model_key, model_weight in model_state_dict.items(): if model_key not in weights: error_dict["Not Found"].add(tuple(model_key.split(".")[:-2])) @@ -155,7 +164,9 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): self.model.load_state_dict(model_state_dict) -def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: +def create_model( + model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80 +) -> YOLO: """Constructs and returns a model from a Dictionary configuration file. Args: diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py index a8003e135..aea66964b 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py @@ -45,11 +45,13 @@ def __call__(self, image: Image, boxes): pad_left = (self.target_width - new_width) // 2 pad_top = (self.target_height - new_height) // 2 - padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color) + padded_image = Image.new( + "RGB", (self.target_width, self.target_height), self.background_color + ) padded_image.paste(resized_image, (pad_left, pad_top)) boxes[:, [1, 3]] = (boxes[:, [1, 3]] * new_width + pad_left) / self.target_width boxes[:, [2, 4]] = (boxes[:, [2, 4]] * new_height + pad_top) / self.target_height transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top]) - return padded_image, boxes, transform_info \ No newline at end of 
file + return padded_image, boxes, transform_info diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py index 67a91e5e4..cbde2c15f 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py @@ -8,7 +8,6 @@ from rich.progress import track from torch import Tensor from torch.utils.data import DataLoader, Dataset - from yolo.config import DataConfig, DatasetConfig from yolo.tools.data_augmentation import AugmentationComposer from yolo.tools.dataset_preparation import prepare_dataset @@ -32,7 +31,9 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()] self.transform = AugmentationComposer(transforms, self.image_size, self.base_size) self.transform.get_more_data = self.get_more_data - self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name)) + self.img_paths, self.bboxes, self.ratios = tensorlize( + self.load_data(Path(dataset_cfg.path), phase_name) + ) def load_data(self, dataset_path: Path, phase_name: str): """ @@ -94,7 +95,9 @@ def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = Fa if not label_path.is_file(): continue with open(label_path, "r") as file: - image_seg_annotations = [list(map(float, line.strip().split())) for line in file] + image_seg_annotations = [ + list(map(float, line.strip().split())) for line in file + ] else: image_seg_annotations = [] @@ -140,43 +143,42 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te else: return torch.zeros((0, 5)) - def adapt_labels(self, bboxes: Tensor) -> Tensor: - """ - Adapt bounding box labels using vectorized operations. 
- - Args: - bboxes (Tensor): Tensor of bounding boxes in the format [class_id, width, height, x_center, y_center]. - - Returns: - Tensor: Tensor of adapted bounding boxes in the format [class_id, xmin, ymin, xmax, ymax]. - """ - class_ids = bboxes[:, 0] - widths = bboxes[:, 1] - heights = bboxes[:, 2] - x_centers = bboxes[:, 3] - y_centers = bboxes[:, 4] - - xmins = x_centers - widths / 2 - ymins = y_centers - heights / 2 - xmaxs = x_centers + widths / 2 - ymaxs = y_centers + heights / 2 - - adapted_bboxes = torch.stack([class_ids, xmins, ymins, xmaxs, ymaxs], dim=1) - - return adapted_bboxes + def adapt_labels(self, bboxes: Tensor) -> Tensor: + """ + Adapt bounding box labels using vectorized operations. + + Args: + bboxes (Tensor): Tensor of bounding boxes in the format [class_id, width, height, x_center, y_center]. + + Returns: + Tensor: Tensor of adapted bounding boxes in the format [class_id, xmin, ymin, xmax, ymax]. + """ + class_ids = bboxes[:, 0] + widths = bboxes[:, 1] + heights = bboxes[:, 2] + x_centers = bboxes[:, 3] + y_centers = bboxes[:, 4] + + xmins = x_centers - widths / 2 + ymins = y_centers - heights / 2 + xmaxs = x_centers + widths / 2 + ymaxs = y_centers + heights / 2 + + adapted_bboxes = torch.stack([class_ids, xmins, ymins, xmaxs, ymaxs], dim=1) + + return adapted_bboxes def adapt_labels_list(self, points): - - x_center = points[0] - y_center = points[1] - width = points[2] - height = points[3] - - xmin = x_center - width / 2 - ymin = y_center - height / 2 - xmax = x_center + width / 2 - ymax = y_center + height / 2 - + x_center = points[0] + y_center = points[1] + width = points[2] + height = points[3] + + xmin = x_center - width / 2 + ymin = y_center - height / 2 + xmax = x_center + width / 2 + ymax = y_center + height / 2 + return [xmin, ymin, xmax, ymax] def get_data(self, idx): @@ -228,4 +230,4 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st num_workers=data_cfg.cpu_num, pin_memory=data_cfg.pin_memory, 
collate_fn=collate_fn, - ) \ No newline at end of file + ) diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py index bd41c1ec4..8cb089bd6 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py @@ -2,7 +2,6 @@ from typing import Optional import requests - from yolo.config import DatasetConfig @@ -48,4 +47,3 @@ def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path download_file(weight_link, weight_path) except requests.exceptions.RequestException as e: raise RuntimeError(f"Failed to download the weight file: {e}") - diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py index b6e1c13e6..34e781f0b 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py @@ -3,7 +3,6 @@ import torch from torch import Tensor, tensor from torchvision.ops import batched_nms - from yolo.config import AnchorConfig, NMSConfig from yolo.model.yolo import YOLO @@ -109,7 +108,9 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): self.head_num = len(anchor_cfg.anchor) self.anchor_grids = self.generate_anchors(image_size) - self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2) + self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view( + self.head_num, 1, -1, 1, 1, 2 + ) self.anchor_num = self.anchor_scale.size(2) self.class_num = model.num_classes @@ -128,7 +129,9 @@ def generate_anchors(self, image_size: List[int]): for stride in self.strides: W, H = image_size[0] // stride, image_size[1] // stride anchor_h, anchor_w = 
torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij") - anchor_grid = torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device) + anchor_grid = ( + torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device) + ) anchor_grids.append(anchor_grid) return anchor_grids @@ -147,11 +150,9 @@ def __call__(self, predicts: List[Tensor]): pred_box = pred_box.sigmoid() pred_box[..., 0:2] = ( - (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]) * self.strides[layer_idx] - ) - pred_box[..., 2:4] = ( - (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx] - ) + pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx] + ) * self.strides[layer_idx] + pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx] B, L, h, w, A = pred_box.shape preds_box.append(pred_box.reshape(B, L * h * w, A)) @@ -177,20 +178,29 @@ def create_converter(model_version: str = "v9-c", *args, **kwargs) -> Union[Anc2 return converter -def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None): +def bbox_nms( + cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None +): cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence) batch_idx, valid_grid, valid_cls = torch.where(cls_dist > nms_cfg.min_confidence) valid_con = cls_dist[batch_idx, valid_grid, valid_cls] valid_box = bbox[batch_idx, valid_grid] - nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou) + nms_idx = batched_nms( + valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou + ) predicts_nms = [] for idx in range(cls_dist.size(0)): instance_idx = nms_idx[idx == batch_idx[nms_idx]] predict_nms = torch.cat( - [valid_cls[instance_idx][:, None], valid_box[instance_idx], valid_con[instance_idx][:, None]], dim=-1 + [ + valid_cls[instance_idx][:, None], + valid_box[instance_idx], + 
valid_con[instance_idx][:, None], + ], + dim=-1, ) predicts_nms.append(predict_nms[: nms_cfg.max_bbox]) diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py index da1989d59..6111bdd24 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py @@ -3,9 +3,11 @@ from itertools import chain from pathlib import Path from typing import Any, Dict, List, Optional, Tuple + import numpy as np import torch + def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]: """ Maps each unique 'id' in the list of category dictionaries to a sequential integer index. @@ -54,8 +56,14 @@ def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, """ with open(labels_path, "r") as file: labels_data = json.load(file) - id_to_idx = discretize_categories(labels_data.get("categories", [])) if "categories" in labels_data else None - annotations_index = organize_annotations_by_image(labels_data, id_to_idx) # check lookup is a good name? + id_to_idx = ( + discretize_categories(labels_data.get("categories", [])) + if "categories" in labels_data + else None + ) + annotations_index = organize_annotations_by_image( + labels_data, id_to_idx + ) # check lookup is a good name? 
image_info_dict = {Path(img["file_name"]).stem: img for img in labels_data["images"]} return annotations_index, image_info_dict @@ -89,7 +97,9 @@ def scale_segmentation( scaled_seg_data = ( np.array(seg_list).reshape(-1, 2) / [w, h] ).tolist() # make the list group in x, y pairs and scaled with image width, height - scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data)) # flatten the scaled_seg_data list + scaled_flat_seg_data = [category_id] + list( + chain(*scaled_seg_data) + ) # flatten the scaled_seg_data list seg_array_with_cat.append(scaled_flat_seg_data) return seg_array_with_cat diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py index e03942e8e..6707dbcc3 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py @@ -1,4 +1,5 @@ from typing import List, Optional, Union + from torch import Tensor from yolo.config import NMSConfig from yolo.model.yolo import YOLO @@ -25,5 +26,7 @@ def __call__( pred_conf = prediction[3] if len(prediction) == 4 else None if rev_tensor is not None: pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None] - pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf) #pred_box: [cls, x1, y1, x2, y2, conf] - return pred_bbox \ No newline at end of file + pred_bbox = bbox_nms( + pred_class, pred_bbox, self.nms, pred_conf + ) # pred_box: [cls, x1, y1, x2, y2, conf] + return pred_bbox diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py b/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py index 4b02a7dff..ee286f492 100644 --- a/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py +++ b/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py @@ -6,45 +6,46 @@ # Importing basic libraries import os -import supervision as sv -import numpy as np -from PIL import Image 
-import wget -import torch - -from ..base_detector import BaseDetector -from ....data import datasets as pw_data - import sys from pathlib import Path -from lightning import Trainer +import numpy as np +import supervision as sv +import torch +import wget import yaml +from lightning import Trainer from omegaconf import OmegaConf +from PIL import Image + +from ....data import datasets as pw_data +from ..base_detector import BaseDetector project_root = Path(__file__).resolve().parent sys.path.append(str(project_root)) -from yolo import create_model, create_converter, PostProcess, AugmentationComposer +from yolo import AugmentationComposer, PostProcess, create_converter, create_model + class YOLOMITBase(BaseDetector): """ Base detector class for YOLO MIT framework. This class provides utility methods for loading the model, generating results, and performing single and batch image detections. """ + def __init__(self, weights=None, device="cpu", url=None): """ Initialize the YOLO MIT detector. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. 
""" - + self.cfg = self._load_cfg() self.transform = AugmentationComposer([], self.cfg.image_size, self.cfg.image_size[0]) self.weights = weights @@ -54,21 +55,29 @@ def __init__(self, weights=None, device="cpu", url=None): def _load_cfg(self): if self.MODEL_NAME == "MDV6-mit-yolov9-c.ckpt": - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml")): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml") + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) url = "https://zenodo.org/records/15178680/files/config_v9s.yaml?download=1" - config_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) + config_path = wget.download( + url, out=os.path.join(torch.hub.get_dir(), "checkpoints") + ) else: config_path = os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml") elif self.MODEL_NAME == "MDV6-mit-yolov9-e.ckpt": - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml")): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml") + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) url = "https://zenodo.org/records/15178680/files/config_v9c.yaml?download=1" - config_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) + config_path = wget.download( + url, out=os.path.join(torch.hub.get_dir(), "checkpoints") + ) else: config_path = os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml") - with open(config_path, 'r') as f: + with open(config_path, "r") as f: cfg_dict = yaml.safe_load(f) return OmegaConf.create(cfg_dict) @@ -76,78 +85,77 @@ def _load_cfg(self): def _load_model(self, weights=None, device="cpu", url=None): """ Load the YOLO MIT model weights. - + Args: - weights (str, optional): + weights (str, optional): Path to the model weights. Defaults to None. 
- device (str, optional): + device (str, optional): Device for model inference. Defaults to "cpu". - url (str, optional): + url (str, optional): URL to fetch the model weights. Defaults to None. Raises: Exception: If weights are not provided. """ if url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): + if not os.path.exists( + os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) + ): os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) else: weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) else: raise Exception("Need weights for inference.") - + self.cfg.image_size = [self.IMAGE_SIZE, self.IMAGE_SIZE] self.model = create_model(self.cfg.model, weight_path=weights, class_num=3).to(self.device) - self.converter = create_converter(self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device) + self.converter = create_converter( + self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device + ) self.post_proccess = PostProcess(self.converter, self.cfg.task.nms) def results_generation(self, preds, img_id, id_strip=None): """ Generate results for detection based on model predictions. - + Args: - preds (List[torch.Tensor]): + preds (List[torch.Tensor]): Model predictions. - img_id (str): + img_id (str): Image identifier. - id_strip (str, optional): + id_strip (str, optional): Strip specific characters from img_id. Defaults to None. Returns: dict: Dictionary containing image ID, detections, and labels. 
""" - #preds: [cls, x1, y1, x2, y2, conf] - class_id = preds[0][:,0].cpu().numpy().astype(int) - xyxy = preds[0][:,1:5].cpu().numpy() - confidence = preds[0][:,5].cpu().numpy() + # preds: [cls, x1, y1, x2, y2, conf] + class_id = preds[0][:, 0].cpu().numpy().astype(int) + xyxy = preds[0][:, 1:5].cpu().numpy() + confidence = preds[0][:, 5].cpu().numpy() results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id - ) + results["detections"] = sv.Detections(xyxy=xyxy, confidence=confidence, class_id=class_id) results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, _, _ in results["detections"] + f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, _, _ in results["detections"] ] results return results - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None): """ Perform detection on a single image. - + Args: - img (str or ndarray): + img (str or ndarray): Image path or ndarray of images. - img_path (str, optional): + img_path (str, optional): Image path or identifier. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. 
Returns: @@ -160,32 +168,34 @@ def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_stri if type(img) == str: if img_path is None: img_path = img - im_pil = Image.open(img_path).convert('RGB') + im_pil = Image.open(img_path).convert("RGB") else: im_pil = Image.fromarray(img) image, bbox, rev_tensor = self.transform(im_pil) image = image.to(self.device)[None] rev_tensor = rev_tensor.to(self.device)[None] - + with torch.no_grad(): predict = self.model(image) - det_results = self.post_proccess(predict, rev_tensor) #pred_box: [cls, x1, y1, x2, y2, conf] - + det_results = self.post_proccess( + predict, rev_tensor + ) # pred_box: [cls, x1, y1, x2, y2, conf] + return self.results_generation(det_results, img_path, id_strip) def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id_strip=None): """ Perform detection on a batch of images. - + Args: - data_path (str): + data_path (str): Path containing all images for inference. batch_size (int, optional): Batch size for inference. Defaults to 16. - det_conf_thres (float, optional): + det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): + id_strip (str, optional): Characters to strip from img_id. Defaults to None. extension (str, optional): Image extension to search for. 
Defaults to "JPG" @@ -196,25 +206,30 @@ def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id self.cfg.task.data.source = data_path self.cfg.task.nms.min_confidence = det_conf_thres self._load_model(weights=self.weights, device=self.device, url=self.url) - + dataset = pw_data.DetectionImageFolder( data_path, transform=self.transform, ) - + results = [] for i in range(len(dataset.images)): - res = self.single_image_detection(dataset.images[i], img_path=dataset.images[i], det_conf_thres=det_conf_thres, id_strip=id_strip) + res = self.single_image_detection( + dataset.images[i], + img_path=dataset.images[i], + det_conf_thres=det_conf_thres, + id_strip=id_strip, + ) # Upload the original image and get the size in the format (height, width) img = Image.open(dataset.images[i]) img = np.asarray(img) size = img.shape[:2] # Normalize the coordinates for timelapse compatibility - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy] + normalized_coords = [ + [x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] + for x1, y1, x2, y2 in res["detections"].xyxy + ] res["normalized_coords"] = normalized_coords results.append(res) return results - - - diff --git a/PytorchWildlife/utils/__init__.py b/PytorchWildlife/utils/__init__.py index bca86797f..66efc7844 100644 --- a/PytorchWildlife/utils/__init__.py +++ b/PytorchWildlife/utils/__init__.py @@ -1,2 +1,2 @@ from .misc import * -from .post_process import * \ No newline at end of file +from .post_process import * diff --git a/PytorchWildlife/utils/misc.py b/PytorchWildlife/utils/misc.py index f516a66c2..ae32d4006 100644 --- a/PytorchWildlife/utils/misc.py +++ b/PytorchWildlife/utils/misc.py @@ -3,15 +3,14 @@ """ Miscellaneous functions.""" -import numpy as np -from tqdm import tqdm -import cv2 from typing import Callable + +import cv2 +import numpy as np from supervision import VideoInfo, VideoSink, get_video_frames_generator +from 
tqdm import tqdm -__all__ = [ - "process_video" -] +__all__ = ["process_video"] def process_video( @@ -19,32 +18,32 @@ def process_video( target_path: str, callback: Callable[[np.ndarray, int], np.ndarray], target_fps: int = 1, - codec: str = "mp4v" + codec: str = "mp4v", ) -> None: """ - Process a video frame-by-frame, applying a callback function to each frame and saving the results + Process a video frame-by-frame, applying a callback function to each frame and saving the results to a new video. This version includes a progress bar and allows codec selection. - + Args: - source_path (str): + source_path (str): Path to the source video file. - target_path (str): + target_path (str): Path to save the processed video. - callback (Callable[[np.ndarray, int], np.ndarray]): + callback (Callable[[np.ndarray, int], np.ndarray]): A function that takes a video frame and its index as input and returns the processed frame. - codec (str, optional): + codec (str, optional): Codec used to encode the processed video. Default is "avc1". 
""" source_video_info = VideoInfo.from_video_path(video_path=source_path) - + if source_video_info.fps > target_fps: stride = int(source_video_info.fps / target_fps) source_video_info.fps = target_fps else: stride = 1 - + with VideoSink(target_path=target_path, video_info=source_video_info, codec=codec) as sink: - with tqdm(total=int(source_video_info.total_frames / stride)) as pbar: + with tqdm(total=int(source_video_info.total_frames / stride)) as pbar: for index, frame in enumerate( get_video_frames_generator(source_path=source_path, stride=stride) ): diff --git a/PytorchWildlife/utils/post_process.py b/PytorchWildlife/utils/post_process.py index d7152ae4e..955a4b661 100644 --- a/PytorchWildlife/utils/post_process.py +++ b/PytorchWildlife/utils/post_process.py @@ -3,15 +3,16 @@ """ Post-processing functions.""" -import os -import numpy as np import json -import cv2 -from PIL import Image -import supervision as sv +import os import shutil from pathlib import Path +import cv2 +import numpy as np +import supervision as sv +from PIL import Image + __all__ = [ "save_detection_images", "save_detection_images_dots", @@ -21,7 +22,7 @@ "save_detection_classification_json", "save_detection_timelapse_json", "save_detection_classification_timelapse_json", - "detection_folder_separation" + "detection_folder_separation", ] @@ -43,7 +44,7 @@ def save_detection_images(results, output_dir, input_dir=None, overwrite=False): lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2) os.makedirs(output_dir, exist_ok=True) - with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: + with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: if isinstance(results, list): for entry in results: annotated_img = lab_annotator.annotate( @@ -57,8 +58,8 @@ def save_detection_images(results, output_dir, input_dir=None, overwrite=False): if input_dir: relative_path = os.path.relpath(entry["img_id"], input_dir) save_path = 
os.path.join(output_dir, relative_path) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - image_name = relative_path + os.makedirs(os.path.dirname(save_path), exist_ok=True) + image_name = relative_path else: image_name = os.path.basename(entry["img_id"]) sink.save_image( @@ -75,9 +76,11 @@ def save_detection_images(results, output_dir, input_dir=None, overwrite=False): ) sink.save_image( - image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=os.path.basename(results["img_id"]) + image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), + image_name=os.path.basename(results["img_id"]), ) + def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=False): """ Save detected images with bounding boxes and labels annotated. @@ -92,10 +95,10 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa overwrite (bool): Whether overwriting existing image folders. Default to False. """ - dot_annotator = sv.DotAnnotator(radius=6) - lab_annotator = sv.LabelAnnotator(text_position=sv.Position.BOTTOM_RIGHT) + dot_annotator = sv.DotAnnotator(radius=6) + lab_annotator = sv.LabelAnnotator(text_position=sv.Position.BOTTOM_RIGHT) os.makedirs(output_dir, exist_ok=True) - + with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: if isinstance(results, list): for i, entry in enumerate(results): @@ -104,7 +107,7 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa image_name = os.path.basename(entry["img_id"]) else: scene = entry["img"] - image_name = f"output_image_{i}.jpg" # default name if no image id is provided + image_name = f"output_image_{i}.jpg" # default name if no image id is provided annotated_img = lab_annotator.annotate( scene=dot_annotator.annotate( @@ -117,7 +120,7 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa if input_dir: relative_path = os.path.relpath(entry["img_id"], input_dir) save_path = os.path.join(output_dir, 
relative_path) - os.makedirs(os.path.dirname(save_path), exist_ok=True) + os.makedirs(os.path.dirname(save_path), exist_ok=True) image_name = relative_path sink.save_image( image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name @@ -128,8 +131,8 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa image_name = os.path.basename(results["img_id"]) else: scene = results["img"] - image_name = "output_image.jpg" # default name if no image id is provided - + image_name = "output_image.jpg" # default name if no image id is provided + annotated_img = lab_annotator.annotate( scene=dot_annotator.annotate( scene=scene, @@ -137,7 +140,7 @@ def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=Fa ), detections=results["detections"], labels=results["labels"], - ) + ) sink.save_image( image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name ) @@ -164,32 +167,48 @@ def save_crop_images(results, output_dir, input_dir=None, overwrite=False): with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: if isinstance(results, list): for entry in results: - for i, (xyxy, cat) in enumerate(zip(entry["detections"].xyxy, entry["detections"].class_id)): + for i, (xyxy, cat) in enumerate( + zip(entry["detections"].xyxy, entry["detections"].class_id) + ): cropped_img = sv.crop_image( image=np.array(Image.open(entry["img_id"]).convert("RGB")), xyxy=xyxy ) if input_dir: relative_path = os.path.relpath(entry["img_id"], input_dir) save_path = os.path.join(output_dir, relative_path) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - image_name = os.path.join(os.path.dirname(relative_path), "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"]))) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + image_name = os.path.join( + os.path.dirname(relative_path), + "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"])), + ) else: - image_name = 
"{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"])) + image_name = "{}_{}_{}".format( + int(cat), i, os.path.basename(entry["img_id"]) + ) sink.save_image( image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), image_name=image_name, ) else: - for i, (xyxy, cat) in enumerate(zip(results["detections"].xyxy, results["detections"].class_id)): + for i, (xyxy, cat) in enumerate( + zip(results["detections"].xyxy, results["detections"].class_id) + ): cropped_img = sv.crop_image( image=np.array(Image.open(results["img_id"]).convert("RGB")), xyxy=xyxy ) sink.save_image( image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), - image_name="{}_{}_{}".format(int(cat), i, os.path.basename(results["img_id"]), - )) + image_name="{}_{}_{}".format( + int(cat), + i, + os.path.basename(results["img_id"]), + ), + ) + -def save_detection_json(det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None): +def save_detection_json( + det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None +): """ Save detection results to a JSON file. 
@@ -208,19 +227,20 @@ def save_detection_json(det_results, output_dir, categories=None, exclude_catego json_results = {"annotations": [], "categories": categories} for det_r in det_results: - # Category filtering img_id = det_r["img_id"] category = det_r["detections"].class_id bbox = det_r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)] - confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)] + confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)] category = category[~np.isin(category, exclude_category_ids)] # if not all([x in exclude_category_ids for x in category]): json_results["annotations"].append( { - "img_id": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id, + "img_id": img_id.replace(exclude_file_path + os.sep, "") + if exclude_file_path + else img_id, "bbox": bbox.tolist(), "category": category.tolist(), "confidence": confidence.tolist(), @@ -230,7 +250,10 @@ def save_detection_json(det_results, output_dir, categories=None, exclude_catego with open(output_dir, "w") as f: json.dump(json_results, f, indent=4) -def save_detection_json_as_dots(det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None): + +def save_detection_json_as_dots( + det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None +): """ Save detection results to a JSON file in dots format. 
@@ -249,20 +272,21 @@ def save_detection_json_as_dots(det_results, output_dir, categories=None, exclud json_results = {"annotations": [], "categories": categories} for det_r in det_results: - # Category filtering img_id = det_r["img_id"] category = det_r["detections"].class_id bbox = det_r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)] dot = np.array([[np.mean(row[::2]), np.mean(row[1::2])] for row in bbox]) - confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)] + confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)] category = category[~np.isin(category, exclude_category_ids)] # if not all([x in exclude_category_ids for x in category]): json_results["annotations"].append( { - "img_id": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id, + "img_id": img_id.replace(exclude_file_path + os.sep, "") + if exclude_file_path + else img_id, "dot": dot.tolist(), "category": category.tolist(), "confidence": confidence.tolist(), @@ -274,9 +298,13 @@ def save_detection_json_as_dots(det_results, output_dir, categories=None, exclud def save_detection_timelapse_json( - det_results, output_dir, categories=None, - exclude_category_ids=[], exclude_file_path=None, info={"detector": "megadetector_v5"} - ): + det_results, + output_dir, + categories=None, + exclude_category_ids=[], + exclude_file_path=None, + info={"detector": "megadetector_v5"}, +): """ Save detection results to a JSON file. @@ -295,35 +323,41 @@ def save_detection_timelapse_json( Default Timelapse info. Defaults to {"detector": "megadetector_v5}. 
""" - json_results = { - "info": info, - "detection_categories": categories, - "images": [] - } + json_results = {"info": info, "detection_categories": categories, "images": []} for det_r in det_results: - img_id = det_r["img_id"] category_id_list = det_r["detections"].class_id - bbox_list = det_r["detections"].xyxy.astype(int)[~np.isin(category_id_list, exclude_category_ids)] - confidence_list = det_r["detections"].confidence[~np.isin(category_id_list, exclude_category_ids)] - normalized_bbox_list = np.array(det_r["normalized_coords"])[~np.isin(category_id_list, exclude_category_ids)] + bbox_list = det_r["detections"].xyxy.astype(int)[ + ~np.isin(category_id_list, exclude_category_ids) + ] + confidence_list = det_r["detections"].confidence[ + ~np.isin(category_id_list, exclude_category_ids) + ] + normalized_bbox_list = np.array(det_r["normalized_coords"])[ + ~np.isin(category_id_list, exclude_category_ids) + ] category_id_list = category_id_list[~np.isin(category_id_list, exclude_category_ids)] # if not all([x in exclude_category_ids for x in category_id_list]): image_annotations = { - "file": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id, - "max_detection_conf": float(max(confidence_list)) if len(confidence_list) > 0 else '', - "detections": [] + "file": img_id.replace(exclude_file_path + os.sep, "") if exclude_file_path else img_id, + "max_detection_conf": float(max(confidence_list)) if len(confidence_list) > 0 else "", + "detections": [], } for i in range(len(bbox_list)): normalized_bbox = [float(y) for y in normalized_bbox_list[i]] detection = { "category": str(category_id_list[i]), "conf": float(confidence_list[i]), - "bbox": [normalized_bbox[0], normalized_bbox[1], normalized_bbox[2]-normalized_bbox[0], normalized_bbox[3]-normalized_bbox[1]], - "classifications": [] + "bbox": [ + normalized_bbox[0], + normalized_bbox[1], + normalized_bbox[2] - normalized_bbox[0], + normalized_bbox[3] - normalized_bbox[1], + ], + 
"classifications": [], } image_annotations["detections"].append(detection) @@ -335,7 +369,12 @@ def save_detection_timelapse_json( def save_detection_classification_json( - det_results, clf_results, output_path, det_categories=None, clf_categories=None, exclude_file_path=None + det_results, + clf_results, + output_path, + det_categories=None, + clf_categories=None, + exclude_file_path=None, ): """ Save classification results to a JSON file. @@ -377,17 +416,15 @@ def save_detection_classification_json( json_results["annotations"].append( { - "img_id": str(det_r["img_id"]).replace(exclude_file_path + os.sep, '') if exclude_file_path else str(det_r["img_id"]), + "img_id": str(det_r["img_id"]).replace(exclude_file_path + os.sep, "") + if exclude_file_path + else str(det_r["img_id"]), "bbox": [ [int(x) for x in sublist] for sublist in det_r["detections"].xyxy.astype(int).tolist() ], - "det_category": [ - int(x) for x in det_r["detections"].class_id.tolist() - ], - "det_confidence": [ - float(x) for x in det_r["detections"].confidence.tolist() - ], + "det_category": [int(x) for x in det_r["detections"].class_id.tolist()], + "det_confidence": [float(x) for x in det_r["detections"].confidence.tolist()], "clf_category": [int(x) for x in clf_categories], "clf_confidence": [float(x) for x in clf_confidence], } @@ -396,8 +433,13 @@ def save_detection_classification_json( def save_detection_classification_timelapse_json( - det_results, clf_results, output_path, det_categories=None, clf_categories=None, - exclude_file_path=None, info={"detector": "megadetector_v5"} + det_results, + clf_results, + output_path, + det_categories=None, + clf_categories=None, + exclude_file_path=None, + info={"detector": "megadetector_v5"}, ): """ Save detection and classification results to a JSON file in the specified format. 
@@ -420,14 +462,18 @@ def save_detection_classification_timelapse_json( "info": info, "detection_categories": det_categories, "classification_categories": clf_categories, - "images": [] + "images": [], } for det_r in det_results: image_annotations = { - "file": str(det_r["img_id"]).replace(exclude_file_path + os.sep, '') if exclude_file_path else str(det_r["img_id"]), - "max_detection_conf": float(max(det_r["detections"].confidence)) if len(det_r["detections"].confidence) > 0 else '', - "detections": [] + "file": str(det_r["img_id"]).replace(exclude_file_path + os.sep, "") + if exclude_file_path + else str(det_r["img_id"]), + "max_detection_conf": float(max(det_r["detections"].confidence)) + if len(det_r["detections"].confidence) > 0 + else "", + "detections": [], } for i in range(len(det_r["detections"])): @@ -436,14 +482,21 @@ def save_detection_classification_timelapse_json( detection = { "category": str(det.class_id[0]), "conf": float(det.confidence[0]), - "bbox": [normalized_bbox[0], normalized_bbox[1], normalized_bbox[2]-normalized_bbox[0], normalized_bbox[3]-normalized_bbox[1]], - "classifications": [] + "bbox": [ + normalized_bbox[0], + normalized_bbox[1], + normalized_bbox[2] - normalized_bbox[0], + normalized_bbox[3] - normalized_bbox[1], + ], + "classifications": [], } # Find classifications for this detection for clf_r in clf_results: if clf_r["img_id"] == det_r["img_id"]: - detection["classifications"].append([str(clf_r["class_id"]), float(clf_r["confidence"])]) + detection["classifications"].append( + [str(clf_r["class_id"]), float(clf_r["confidence"])] + ) image_annotations["detections"].append(detection) @@ -461,7 +514,7 @@ def detection_folder_separation(json_file, img_path, destination_path, confidenc This function reads a JSON formatted file containing annotations of image detections. Each image is checked for detections with category '0' and a confidence level above the specified threshold. 
If such detections are found, the image is categorized under 'Animal'. Images without - any category '0' detections above the threshold, including those with no detections at all, are + any category '0' detections above the threshold, including those with no detections at all, are categorized under 'No_animal'. Parameters: @@ -485,41 +538,41 @@ def detection_folder_separation(json_file, img_path, destination_path, confidenc """ # Load JSON data from the file - with open(json_file, 'r') as file: + with open(json_file, "r") as file: data = json.load(file) - + # Ensure the destination directories exist os.makedirs(destination_path, exist_ok=True) animal_path = os.path.join(destination_path, "Animal") no_animal_path = os.path.join(destination_path, "No_animal") os.makedirs(animal_path, exist_ok=True) os.makedirs(no_animal_path, exist_ok=True) - + # Process each image detection i = 0 - for item in data['annotations']: - i+=1 - img_id = item['img_id'] - categories = item['category'] - confidences = item['confidence'] - + for item in data["annotations"]: + i += 1 + img_id = item["img_id"] + categories = item["category"] + confidences = item["confidence"] + # Check if there is any category '0' with confidence above the threshold file_targeted_for_animal = False for category, confidence in zip(categories, confidences): if category == 0 and confidence > confidence_threshold: file_targeted_for_animal = True break - + if file_targeted_for_animal: target_folder = animal_path else: target_folder = no_animal_path - + # Construct the source and destination file paths src_file_path = os.path.join(img_path, img_id) dest_file_path = os.path.join(target_folder, os.path.dirname(img_id)) os.makedirs(dest_file_path, exist_ok=True) - + # Copy the file to the appropriate directory shutil.copy(src_file_path, dest_file_path) diff --git a/README.md b/README.md index ce3c0efef..de464849d 100644 --- a/README.md +++ b/README.md @@ -222,4 +222,3 @@ We are also building a list of contributors and 
will release in future updates! >[!IMPORTANT] >If you would like to be added to this list or have any questions regarding MegaDetector and Pytorch-Wildlife, please [email us](zhongqimiao@microsoft.com) or join us in our Discord channel: [![](https://img.shields.io/badge/any_text-Join_us!-blue?logo=discord&label=PytorchWildife)](https://discord.gg/TeEVxzaYtm) - diff --git a/demo/detection_classification_pipeline_demo.py b/demo/detection_classification_pipeline_demo.py index 58f65ae48..c46e6b629 100644 --- a/demo/detection_classification_pipeline_demo.py +++ b/demo/detection_classification_pipeline_demo.py @@ -3,34 +3,37 @@ """ Demo for image detection""" -#%% +# %% # Importing necessary basic libraries and modules import os + import numpy as np -from PIL import Image import supervision as sv -# PyTorch imports +# PyTorch imports import torch +from PIL import Image from torch.utils.data import DataLoader -#%% -# Importing the model, dataset, transformations and utility functions from PytorchWildlife -from PytorchWildlife.models import detection as pw_detection from PytorchWildlife import utils as pw_utils +from PytorchWildlife.data import datasets as pw_data +from PytorchWildlife.data import transforms as pw_trans +# %% +# Importing the model, dataset, transformations and utility functions from PytorchWildlife from PytorchWildlife.models import classification as pw_classification -from PytorchWildlife.data import transforms as pw_trans -from PytorchWildlife.data import datasets as pw_data +from PytorchWildlife.models import detection as pw_detection -#%% +# %% # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -#%% +# %% # Initializing the MegaDetectorV6 model for image detection # Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c -detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e") 
+detection_model = pw_detection.MegaDetectorV6( + device=DEVICE, pretrained=True, version="MDV6-yolov10-e" +) # Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6 # detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a") @@ -38,49 +41,57 @@ # %% # Initializing a classification model for image classification # classification_model = pw_classification.DFNE(device=DEVICE) -classification_model = pw_classification.AI4GAmazonRainforest(device=DEVICE, version='v2') +classification_model = pw_classification.AI4GAmazonRainforest(device=DEVICE, version="v2") -#%% Single image detection +# %% Single image detection # Specifying the path to the target image TODO: Allow argparsing -tgt_img_path = os.path.join(".","demo_data","imgs","10050028_0.JPG") +tgt_img_path = os.path.join(".", "demo_data", "imgs", "10050028_0.JPG") # Performing the detection on the single image results = detection_model.single_image_detection(tgt_img_path) clf_conf_thres = 0.8 -input_img = np.array(Image.open(tgt_img_path).convert('RGB')) +input_img = np.array(Image.open(tgt_img_path).convert("RGB")) clf_labels = [] for i, (xyxy, det_id) in enumerate(zip(results["detections"].xyxy, results["detections"].class_id)): # Only run classifier when detection class is animal if det_id == 0: cropped_image = sv.crop_image(image=input_img, xyxy=xyxy) results_clf = classification_model.single_image_classification(cropped_image) - clf_labels.append("{} {:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown", - results_clf["confidence"])) + clf_labels.append( + "{} {:.2f}".format( + results_clf["prediction"] + if results_clf["confidence"] > clf_conf_thres + else "Unknown", + results_clf["confidence"], + ) + ) else: clf_labels.append(results["labels"][i]) results["labels"] = clf_labels # %% -# Saving the detection results -pw_utils.save_detection_images(results, os.path.join(".","demo_output"), overwrite=False) +# 
Saving the detection results +pw_utils.save_detection_images(results, os.path.join(".", "demo_output"), overwrite=False) # %%# Saving the detected objects as cropped images -pw_utils.save_crop_images(results, os.path.join(".","crop_output"), overwrite=False) +pw_utils.save_crop_images(results, os.path.join(".", "crop_output"), overwrite=False) # %% -#%% Batch detection +# %% Batch detection """ Batch-detection demo """ # Specifying the folder path containing multiple images for batch detection -tgt_folder_path = os.path.join(".","demo_data","imgs") +tgt_folder_path = os.path.join(".", "demo_data", "imgs") # Performing batch detection on the images det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16) -clf_results = classification_model.batch_image_classification(det_results=det_results, id_strip=tgt_folder_path) +clf_results = classification_model.batch_image_classification( + det_results=det_results, id_strip=tgt_folder_path +) # %% merged_results = det_results.copy() @@ -91,33 +102,43 @@ clf_labels = [] for i, (xyxy, det_id) in enumerate(zip(det["detections"].xyxy, det["detections"].class_id)): if det_id == 0: - clf_labels.append("{} {:.2f}".format(clf_results[clf_counter]["prediction"] if clf_results[clf_counter]["confidence"] > clf_conf_thres else "Unknown", - clf_results[clf_counter]["confidence"])) + clf_labels.append( + "{} {:.2f}".format( + clf_results[clf_counter]["prediction"] + if clf_results[clf_counter]["confidence"] > clf_conf_thres + else "Unknown", + clf_results[clf_counter]["confidence"], + ) + ) else: clf_labels.append(det["labels"][i]) clf_counter += 1 det["labels"] = clf_labels -#%% Output to annotated images +# %% Output to annotated images # Saving the batch detection results as annotated images pw_utils.save_detection_images(merged_results, "batch_output", tgt_folder_path, overwrite=False) -#%% Output to cropped images +# %% Output to cropped images # Saving the detected objects as cropped images 
pw_utils.save_crop_images(merged_results, "crop_output", tgt_folder_path, overwrite=False) -#%% Output to JSON results +# %% Output to JSON results # Saving the detection results in JSON format -pw_utils.save_detection_classification_json(det_results=det_results, - clf_results=clf_results, - det_categories=detection_model.CLASS_NAMES, - clf_categories=classification_model.CLASS_NAMES, - output_path=os.path.join(".","batch_output_classification.json")) +pw_utils.save_detection_classification_json( + det_results=det_results, + clf_results=clf_results, + det_categories=detection_model.CLASS_NAMES, + clf_categories=classification_model.CLASS_NAMES, + output_path=os.path.join(".", "batch_output_classification.json"), +) # %% # Saving the detection results in timelapse JSON format -pw_utils.save_detection_classification_timelapse_json(det_results=det_results, - clf_results=clf_results, - det_categories=detection_model.CLASS_NAMES, - clf_categories=classification_model.CLASS_NAMES, - output_path=os.path.join(".","batch_output_classification_timelapse.json")) \ No newline at end of file +pw_utils.save_detection_classification_timelapse_json( + det_results=det_results, + clf_results=clf_results, + det_categories=detection_model.CLASS_NAMES, + clf_categories=classification_model.CLASS_NAMES, + output_path=os.path.join(".", "batch_output_classification_timelapse.json"), +) diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py index 58c5d7962..6f5f70ce9 100644 --- a/demo/gradio_demo.py +++ b/demo/gradio_demo.py @@ -3,31 +3,33 @@ """ Gradio Demo for image detection""" +import ast + # Importing necessary basic libraries and modules import os -# PyTorch imports -import torch -from torch.utils.data import DataLoader - -# Importing the model, dataset, transformations and utility functions from PytorchWildlife -from PytorchWildlife.models import detection as pw_detection -from PytorchWildlife import utils as pw_utils - # Importing basic libraries import shutil import time -from PIL 
import Image -import supervision as sv -import gradio as gr from zipfile import ZipFile + +import gradio as gr import numpy as np -import ast +import supervision as sv + +# PyTorch imports +import torch +from PIL import Image +from torch.utils.data import DataLoader + +from PytorchWildlife import utils as pw_utils +from PytorchWildlife.data import datasets as pw_data +from PytorchWildlife.data import transforms as pw_trans # Importing the models, dataset, transformations, and utility functions from PytorchWildlife +# Importing the model, dataset, transformations and utility functions from PytorchWildlife from PytorchWildlife.models import classification as pw_classification -from PytorchWildlife.data import transforms as pw_trans -from PytorchWildlife.data import datasets as pw_data +from PytorchWildlife.models import detection as pw_detection # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -36,15 +38,15 @@ box_annotator = sv.BoxAnnotator(thickness=4) lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2) # Create a temp folder -os.makedirs(os.path.join("..","temp"), exist_ok=True) # ASK: Why do we need this? +os.makedirs(os.path.join("..", "temp"), exist_ok=True) # ASK: Why do we need this? 
# Initializing the detection and classification models detection_model = None classification_model = None - + + # Defining functions for different detection scenarios def load_models(det, version, clf, wpath=None, wclass=None): - global detection_model, classification_model if det != "None": if det == "HerdNet General": @@ -53,11 +55,17 @@ def load_models(det, version, clf, wpath=None, wclass=None): detection_model = pw_detection.HerdNet(device=DEVICE, version="ennedi") else: if "mit" in version: - detection_model = pw_detection.MegaDetectorV6MIT(device=DEVICE, pretrained=True, version=version) + detection_model = pw_detection.MegaDetectorV6MIT( + device=DEVICE, pretrained=True, version=version + ) elif "apache" in version: - detection_model = pw_detection.MegaDetectorV6Apache(device=DEVICE, pretrained=True, version=version) + detection_model = pw_detection.MegaDetectorV6Apache( + device=DEVICE, pretrained=True, version=version + ) else: - detection_model = pw_detection.__dict__[det](device=DEVICE, pretrained=True, version=version) + detection_model = pw_detection.__dict__[det]( + device=DEVICE, pretrained=True, version=version + ) else: detection_model = None return "NO MODEL LOADED!!" 
@@ -65,9 +73,11 @@ def load_models(det, version, clf, wpath=None, wclass=None): if clf != "None": # Create an exception for custom weights if clf == "CustomWeights": - if (wpath is not None) and (wclass is not None): + if (wpath is not None) and (wclass is not None): wclass = ast.literal_eval(wclass) - classification_model = pw_classification.__dict__[clf](weights=wpath, class_names=wclass, device=DEVICE) + classification_model = pw_classification.__dict__[clf]( + weights=wpath, class_names=wclass, device=DEVICE + ) else: classification_model = pw_classification.__dict__[clf](device=DEVICE, pretrained=True) else: @@ -92,25 +102,35 @@ def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index= if detection_model.__class__.__name__.__contains__("HerdNet"): annotator = dot_annotator # Herdnet receives both clf and det confidence thresholds - results_det = detection_model.single_image_detection(input_img, - img_path=img_index, - det_conf_thres=det_conf_thres, - clf_conf_thres=clf_conf_thres) + results_det = detection_model.single_image_detection( + input_img, + img_path=img_index, + det_conf_thres=det_conf_thres, + clf_conf_thres=clf_conf_thres, + ) else: annotator = box_annotator - results_det = detection_model.single_image_detection(input_img, - img_path=img_index, - det_conf_thres = det_conf_thres) - + results_det = detection_model.single_image_detection( + input_img, img_path=img_index, det_conf_thres=det_conf_thres + ) + if classification_model is not None: labels = [] - for i, (xyxy, det_id) in enumerate(zip(results_det["detections"].xyxy, results_det["detections"].class_id)): + for i, (xyxy, det_id) in enumerate( + zip(results_det["detections"].xyxy, results_det["detections"].class_id) + ): # Only run classifier when detection class is animal if det_id == 0: cropped_image = sv.crop_image(image=input_img, xyxy=xyxy) results_clf = classification_model.single_image_classification(cropped_image) - labels.append("{} 
{:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown", - results_clf["confidence"])) + labels.append( + "{} {:.2f}".format( + results_clf["prediction"] + if results_clf["confidence"] > clf_conf_thres + else "Unknown", + results_clf["confidence"], + ) + ) else: labels.append(results_det["labels"][i]) else: @@ -126,9 +146,10 @@ def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index= ) return annotated_img + def batch_detection(zip_file, timelapse, det_conf_thres): """Perform detection on a batch of images from a zip file and return path to results JSON. - + Args: zip_file (File): Zip file containing images. det_conf_thres (float): Confidence threshold for detection. @@ -139,7 +160,7 @@ def batch_detection(zip_file, timelapse, det_conf_thres): json_save_path (str): Path to the JSON file containing detection results. """ # Clean the temp folder if it contains files - extract_path = os.path.join("..","temp","zip_upload") + extract_path = os.path.join("..", "temp", "zip_upload") if os.path.exists(extract_path): shutil.rmtree(extract_path) os.makedirs(extract_path) @@ -149,53 +170,76 @@ def batch_detection(zip_file, timelapse, det_conf_thres): zfile.extractall(extract_path) # Check the contents of the extracted folder extracted_files = os.listdir(extract_path) - + if len(extracted_files) == 1 and os.path.isdir(os.path.join(extract_path, extracted_files[0])): tgt_folder_path = os.path.join(extract_path, extracted_files[0]) else: tgt_folder_path = extract_path # If the detection model is HerdNet set batch_size to 1 if detection_model.__class__.__name__.__contains__("HerdNet"): - det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=1, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path) + det_results = detection_model.batch_image_detection( + tgt_folder_path, batch_size=1, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path + ) else: - det_results = 
detection_model.batch_image_detection(tgt_folder_path, batch_size=16, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path) + det_results = detection_model.batch_image_detection( + tgt_folder_path, batch_size=16, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path + ) if classification_model is not None: clf_dataset = pw_data.DetectionCrops( det_results, transform=pw_trans.Classification_Inference_Transform(target_size=224), - path_head=tgt_folder_path + path_head=tgt_folder_path, + ) + clf_loader = DataLoader( + clf_dataset, + batch_size=32, + shuffle=False, + pin_memory=True, + num_workers=4, + drop_last=False, + ) + clf_results = classification_model.batch_image_classification( + clf_loader, id_strip=tgt_folder_path ) - clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False, - pin_memory=True, num_workers=4, drop_last=False) - clf_results = classification_model.batch_image_classification(clf_loader, id_strip=tgt_folder_path) if timelapse: json_save_path = json_save_path.replace(".json", "_timelapse.json") - pw_utils.save_detection_classification_timelapse_json(det_results=det_results, - clf_results=clf_results, - det_categories=detection_model.CLASS_NAMES, - clf_categories=classification_model.CLASS_NAMES, - output_path=json_save_path) + pw_utils.save_detection_classification_timelapse_json( + det_results=det_results, + clf_results=clf_results, + det_categories=detection_model.CLASS_NAMES, + clf_categories=classification_model.CLASS_NAMES, + output_path=json_save_path, + ) else: - pw_utils.save_detection_classification_json(det_results=det_results, - clf_results=clf_results, - det_categories=detection_model.CLASS_NAMES, - clf_categories=classification_model.CLASS_NAMES, - output_path=json_save_path) + pw_utils.save_detection_classification_json( + det_results=det_results, + clf_results=clf_results, + det_categories=detection_model.CLASS_NAMES, + clf_categories=classification_model.CLASS_NAMES, + output_path=json_save_path, + ) else: if 
timelapse: json_save_path = json_save_path.replace(".json", "_timelapse.json") - pw_utils.save_detection_timelapse_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) + pw_utils.save_detection_timelapse_json( + det_results, json_save_path, categories=detection_model.CLASS_NAMES + ) elif detection_model.__class__.__name__.__contains__("HerdNet"): - pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES) - else: - pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) + pw_utils.save_detection_json_as_dots( + det_results, json_save_path, categories=detection_model.CLASS_NAMES + ) + else: + pw_utils.save_detection_json( + det_results, json_save_path, categories=detection_model.CLASS_NAMES + ) return json_save_path + def batch_path_detection(tgt_folder_path, det_conf_thres): """Perform detection on a batch of images from a zip file and return path to results JSON. - + Args: tgt_folder_path (str): path to the folder containing the images. det_conf_thres (float): Confidence threshold for detection. 
@@ -204,36 +248,48 @@ def batch_path_detection(tgt_folder_path, det_conf_thres): """ json_save_path = os.path.join(tgt_folder_path, "results.json") - det_results = detection_model.batch_image_detection(tgt_folder_path, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path) + det_results = detection_model.batch_image_detection( + tgt_folder_path, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path + ) if detection_model.__class__.__name__.__contains__("HerdNet"): - pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES) + pw_utils.save_detection_json_as_dots( + det_results, json_save_path, categories=detection_model.CLASS_NAMES + ) else: - pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) + pw_utils.save_detection_json( + det_results, json_save_path, categories=detection_model.CLASS_NAMES + ) return json_save_path def video_detection(video, det_conf_thres, clf_conf_thres, target_fps, codec): """Perform detection on a video and return path to processed video. - + Args: video (str): Video source path. det_conf_thres (float): Confidence threshold for detection. clf_conf_thres (float): Confidence threshold for classification. 
""" + def callback(frame, index): - annotated_frame = single_image_detection(frame, - img_index=index, - det_conf_thres=det_conf_thres, - clf_conf_thres=clf_conf_thres) - return annotated_frame - - target_path = os.path.join("..","temp","video_detection.mp4") - pw_utils.process_video(source_path=video, target_path=target_path, - callback=callback, target_fps=int(target_fps), codec=codec) + annotated_frame = single_image_detection( + frame, img_index=index, det_conf_thres=det_conf_thres, clf_conf_thres=clf_conf_thres + ) + return annotated_frame + + target_path = os.path.join("..", "temp", "video_detection.mp4") + pw_utils.process_video( + source_path=video, + target_path=target_path, + callback=callback, + target_fps=int(target_fps), + codec=codec, + ) return target_path + # Building Gradio UI with gr.Blocks() as demo: @@ -243,37 +299,72 @@ def callback(frame, index): ["None", "MegaDetectorV5", "MegaDetectorV6", "HerdNet General", "HerdNet Ennedi"], label="Detection model", info="Will add more detection models!", - value="None" # Default + value="None", # Default ) - det_version = gr.Dropdown( - ["None"], - label="Model version", + det_version = gr.Dropdown( + ["None"], + label="Model version", info="Select the version of the model", value="None", ) - + with gr.Column(): clf_drop = gr.Dropdown( - ["None", "AI4GOpossum", "AI4GAmazonRainforest", "AI4GSnapshotSerengeti", "CustomWeights"], + [ + "None", + "AI4GOpossum", + "AI4GAmazonRainforest", + "AI4GSnapshotSerengeti", + "CustomWeights", + ], interactive=True, label="Classification model", info="Will add more classification models!", visible=False, - value="None" + value="None", + ) + custom_weights_path = gr.Textbox( + label="Custom Weights Path", + visible=False, + interactive=True, + placeholder="./weights/my_weight.pt", + ) + custom_weights_class = gr.Textbox( + label="Custom Weights Class", + visible=False, + interactive=True, + placeholder="{1:'ocelot', 2:'cow', 3:'bear'}", ) - custom_weights_path = 
gr.Textbox(label="Custom Weights Path", visible=False, interactive=True, placeholder="./weights/my_weight.pt") - custom_weights_class = gr.Textbox(label="Custom Weights Class", visible=False, interactive=True, placeholder="{1:'ocelot', 2:'cow', 3:'bear'}") load_but = gr.Button("Load Models!") load_out = gr.Text("NO MODEL LOADED!!", label="Loaded models:") - - def update_ui_elements(det_model): - if det_model == "MegaDetectorV6": - return gr.Dropdown(choices=["MDV6-yolov9-c", "MDV6-yolov9-e", "MDV6-yolov10-c", "MDV6-yolov10-e", "MDV6-rtdetr-c", "MDV6-yolov9-c-mit", "MDV6-yolov9-e-mit", "MDV6-rtdetr-c-apache", "MDV6-rtdetr-e-apache"], interactive=True, label="Model version", value="MDV6-yolov9e"), gr.update(visible=True) - elif det_model == "MegaDetectorV5": - return gr.Dropdown(choices=["a", "b"], interactive=True, label="Model version", value="a"), gr.update(visible=True) + + def update_ui_elements(det_model): + if det_model == "MegaDetectorV6": + return gr.Dropdown( + choices=[ + "MDV6-yolov9-c", + "MDV6-yolov9-e", + "MDV6-yolov10-c", + "MDV6-yolov10-e", + "MDV6-rtdetr-c", + "MDV6-yolov9-c-mit", + "MDV6-yolov9-e-mit", + "MDV6-rtdetr-c-apache", + "MDV6-rtdetr-e-apache", + ], + interactive=True, + label="Model version", + value="MDV6-yolov9e", + ), gr.update(visible=True) + elif det_model == "MegaDetectorV5": + return gr.Dropdown( + choices=["a", "b"], interactive=True, label="Model version", value="a" + ), gr.update(visible=True) else: - return gr.Dropdown(choices=["None"], interactive=True, label="Model version", value="None"), gr.update(value="None", visible=False) - + return gr.Dropdown( + choices=["None"], interactive=True, label="Model version", value="None" + ), gr.update(value="None", visible=False) + det_drop.change(update_ui_elements, det_drop, [det_version, clf_drop]) def toggle_textboxes(model): @@ -281,20 +372,18 @@ def toggle_textboxes(model): return gr.update(visible=True), gr.update(visible=True) else: return gr.update(visible=False), 
gr.update(visible=False) - - clf_drop.change( - toggle_textboxes, - clf_drop, - [custom_weights_path, custom_weights_class] - ) + + clf_drop.change(toggle_textboxes, clf_drop, [custom_weights_path, custom_weights_class]) with gr.Tab("Single Image Process"): with gr.Row(): with gr.Column(): sgl_in = gr.Image(type="pil") sgl_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) - sgl_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7, visible=True) - sgl_out = gr.Image() + sgl_conf_sl_clf = gr.Slider( + 0, 1, label="Classification Confidence Threshold", value=0.7, visible=True + ) + sgl_out = gr.Image() sgl_but = gr.Button("Detect Animals!") with gr.Tab("Folder Separation"): with gr.Row(): @@ -306,9 +395,18 @@ def toggle_textboxes(model): bth_out2 = gr.File(label="Detection Results JSON.", height=200) with gr.Column(): process_files_button = gr.Button("Separate files") - process_result = gr.Text("Click on 'Separate files' once you see the JSON file", label="Separated files:") - process_btn.click(batch_path_detection, inputs=[inp_path, bth_conf_fs], outputs=bth_out2) - process_files_button.click(pw_utils.detection_folder_separation, inputs=[bth_out2, inp_path, out_path, bth_conf_fs], outputs=process_result) + process_result = gr.Text( + "Click on 'Separate files' once you see the JSON file", + label="Separated files:", + ) + process_btn.click( + batch_path_detection, inputs=[inp_path, bth_conf_fs], outputs=bth_out2 + ) + process_files_button.click( + pw_utils.detection_folder_separation, + inputs=[bth_out2, inp_path, out_path, bth_conf_fs], + outputs=process_result, + ) with gr.Tab("Batch Image Process"): with gr.Row(): with gr.Column(): @@ -323,28 +421,42 @@ def toggle_textboxes(model): with gr.Column(): vid_in = gr.Video(label="Upload a video.") vid_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) - vid_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence 
Threshold", value=0.7) + vid_conf_sl_clf = gr.Slider( + 0, 1, label="Classification Confidence Threshold", value=0.7 + ) vid_fr = gr.Dropdown([5, 10, 30], label="Output video framerate", value=30) vid_enc = gr.Dropdown( ["mp4v", "avc1"], label="Video encoder", info="mp4v is default, av1c is faster (needs conda install opencv)", - value="mp4v" - ) + value="mp4v", + ) vid_out = gr.Video() vid_but = gr.Button("Detect Animals!") - + # Show timelapsed checkbox only when detection model is not HerdNet det_drop.change( - lambda model: gr.update(visible=True) if "HerdNet" not in model else gr.update(visible=False), + lambda model: gr.update(visible=True) + if "HerdNet" not in model + else gr.update(visible=False), det_drop, - [chck_timelapse] + [chck_timelapse], + ) + + load_but.click( + load_models, + inputs=[det_drop, det_version, clf_drop, custom_weights_path, custom_weights_class], + outputs=load_out, + ) + sgl_but.click( + single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out ) - - load_but.click(load_models, inputs=[det_drop, det_version, clf_drop, custom_weights_path, custom_weights_class], outputs=load_out) - sgl_but.click(single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out) bth_but.click(batch_detection, inputs=[bth_in, chck_timelapse, bth_conf_sl], outputs=bth_out) - vid_but.click(video_detection, inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr, vid_enc], outputs=vid_out) + vid_but.click( + video_detection, + inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr, vid_enc], + outputs=vid_out, + ) if __name__ == "__main__": demo.launch(share=True) diff --git a/demo/image_demo.py b/demo/image_demo.py index 7e65bf5da..a9d347ef7 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -3,25 +3,29 @@ """ Demo for image detection""" -#%% +# %% # Importing necessary basic libraries and modules import os -# PyTorch imports + +# PyTorch imports import torch -#%% +from PytorchWildlife 
import utils as pw_utils + +# %% # Importing the model, dataset, transformations and utility functions from PytorchWildlife from PytorchWildlife.models import detection as pw_detection -from PytorchWildlife import utils as pw_utils -#%% +# %% # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -#%% +# %% # Initializing the MegaDetectorV6 model for image detection # Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c -detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e") +detection_model = pw_detection.MegaDetectorV6( + device=DEVICE, pretrained=True, version="MDV6-yolov10-e" +) # Uncomment the following line to use MegaDetectorV6 with yolo v9 MIT weights # Valid versions are MDV6-mit-yolov9-c, MDV6-mit-yolov9-e @@ -34,46 +38,52 @@ # Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6 # detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a") -#%% Single image detection +# %% Single image detection # Specifying the path to the target image TODO: Allow argparsing -tgt_img_path = os.path.join(".","demo_data","imgs","10050028_0.JPG") +tgt_img_path = os.path.join(".", "demo_data", "imgs", "10050028_0.JPG") # Performing the detection on the single image results = detection_model.single_image_detection(tgt_img_path) -# Saving the detection results -pw_utils.save_detection_images(results, os.path.join(".","demo_output"), overwrite=False) +# Saving the detection results +pw_utils.save_detection_images(results, os.path.join(".", "demo_output"), overwrite=False) # Saving the detected objects as cropped images -pw_utils.save_crop_images(results, os.path.join(".","crop_output"), overwrite=False) +pw_utils.save_crop_images(results, os.path.join(".", "crop_output"), overwrite=False) -#%% Batch detection +# %% Batch detection """ Batch-detection demo """ # 
Specifying the folder path containing multiple images for batch detection -tgt_folder_path = os.path.join(".","demo_data","imgs") +tgt_folder_path = os.path.join(".", "demo_data", "imgs") # Performing batch detection on the images results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16) -#%% Output to annotated images +# %% Output to annotated images # Saving the batch detection results as annotated images pw_utils.save_detection_images(results, "batch_output", tgt_folder_path, overwrite=False) -#%% Output to cropped images +# %% Output to cropped images # Saving the detected objects as cropped images pw_utils.save_crop_images(results, "crop_output", tgt_folder_path, overwrite=False) -#%% Output to JSON results +# %% Output to JSON results # Saving the detection results in JSON format -pw_utils.save_detection_json(results, os.path.join(".","batch_output.json"), - categories=detection_model.CLASS_NAMES, - exclude_category_ids=[], # Category IDs can be found in the definition of each model. - exclude_file_path=None) +pw_utils.save_detection_json( + results, + os.path.join(".", "batch_output.json"), + categories=detection_model.CLASS_NAMES, + exclude_category_ids=[], # Category IDs can be found in the definition of each model. + exclude_file_path=None, +) # Saving the detection results in timelapse JSON format -pw_utils.save_detection_timelapse_json(results, os.path.join(".","batch_output_timelapse.json"), - categories=detection_model.CLASS_NAMES, - exclude_category_ids=[], # Category IDs can be found in the definition of each model. - exclude_file_path=tgt_folder_path, - info={"detector": "MegaDetectorV6"}) \ No newline at end of file +pw_utils.save_detection_timelapse_json( + results, + os.path.join(".", "batch_output_timelapse.json"), + categories=detection_model.CLASS_NAMES, + exclude_category_ids=[], # Category IDs can be found in the definition of each model. 
+ exclude_file_path=tgt_folder_path, + info={"detector": "MegaDetectorV6"}, +) diff --git a/demo/image_demo_herdnet.py b/demo/image_demo_herdnet.py index d8837caf2..37c0f85eb 100644 --- a/demo/image_demo_herdnet.py +++ b/demo/image_demo_herdnet.py @@ -2,51 +2,62 @@ # Licensed under the MIT License. """ Demo for Herdnet image detection""" -#%% +# %% # Importing necessary basic libraries and modules import os -# PyTorch imports + +# PyTorch imports import torch -#%% +from PytorchWildlife import utils as pw_utils + +# %% # Importing the model, dataset, transformations and utility functions from PytorchWildlife from PytorchWildlife.models import detection as pw_detection -from PytorchWildlife import utils as pw_utils -#%% +# %% # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -#%% +# %% # Initializing the HerdNet model for image detection detection_model = pw_detection.HerdNet(device=DEVICE) # If you want to use ennedi dataset weights, you can use the following line: # detection_model = pw_detection.HerdNet(device=DEVICE, version="ennedi") -#%% Single image detection -img_path = os.path.join(".","demo_data","herdnet_imgs","S_11_05_16_DSC01556.JPG") +# %% Single image detection +img_path = os.path.join(".", "demo_data", "herdnet_imgs", "S_11_05_16_DSC01556.JPG") # Performing the detection on the single image results = detection_model.single_image_detection(img=img_path) -#%% Output to annotated images +# %% Output to annotated images # Saving the batch detection results as annotated images -pw_utils.save_detection_images_dots(results, os.path.join(".","herdnet_demo_output"), overwrite=False) +pw_utils.save_detection_images_dots( + results, os.path.join(".", "herdnet_demo_output"), overwrite=False +) -#%% Batch image detection +# %% Batch image detection """ Batch-detection demo """ # Specifying the folder path containing multiple images for batch detection -folder_path = 
os.path.join(".","demo_data","herdnet_imgs") +folder_path = os.path.join(".", "demo_data", "herdnet_imgs") # Performing batch detection on the images -results = detection_model.batch_image_detection(folder_path, batch_size=1) # NOTE: Only use batch size 1 because each image is divided into patches and this batch is enough. +results = detection_model.batch_image_detection( + folder_path, batch_size=1 +) # NOTE: Only use batch size 1 because each image is divided into patches and this batch is enough. -#%% Output to annotated images +# %% Output to annotated images # Saving the batch detection results as annotated images -pw_utils.save_detection_images_dots(results, "herdnet_demo_batch_output", folder_path, overwrite=False) +pw_utils.save_detection_images_dots( + results, "herdnet_demo_batch_output", folder_path, overwrite=False +) # Saving the detection results in JSON format -pw_utils.save_detection_json_as_dots(results, os.path.join(".","herdnet_demo_batch_output.json"), - categories=detection_model.CLASS_NAMES, - exclude_category_ids=[], # Category IDs can be found in the definition of each model. - exclude_file_path=None) +pw_utils.save_detection_json_as_dots( + results, + os.path.join(".", "herdnet_demo_batch_output.json"), + categories=detection_model.CLASS_NAMES, + exclude_category_ids=[], # Category IDs can be found in the definition of each model. 
+ exclude_file_path=None, +) diff --git a/demo/image_separation_demo.py b/demo/image_separation_demo.py index 962287601..10aee3e89 100644 --- a/demo/image_separation_demo.py +++ b/demo/image_separation_demo.py @@ -3,31 +3,50 @@ """ Demo for image separation between positive and negative detections""" -#%% +# %% # Importing necessary basic libraries and modules import argparse import os + import torch -# PyTorch imports -from PytorchWildlife.models import detection as pw_detection from PytorchWildlife import utils as pw_utils -#%% Argument parsing +# PyTorch imports +from PytorchWildlife.models import detection as pw_detection + +# %% Argument parsing parser = argparse.ArgumentParser(description="Batch image detection and separation") -parser.add_argument('--image_folder', type=str, default=os.path.join(".","demo_data","imgs"), help='Folder path containing images for detection') -parser.add_argument('--output_path', type=str, default='folder_separation', help='Path where the outputs will be saved') -parser.add_argument('--threshold', type=float, default='0.2', help='Confidence threshold to consider a detection as positive') +parser.add_argument( + "--image_folder", + type=str, + default=os.path.join(".", "demo_data", "imgs"), + help="Folder path containing images for detection", +) +parser.add_argument( + "--output_path", + type=str, + default="folder_separation", + help="Path where the outputs will be saved", +) +parser.add_argument( + "--threshold", + type=float, + default="0.2", + help="Confidence threshold to consider a detection as positive", +) args = parser.parse_args() -#%% +# %% # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -#%% +# %% # Initializing the MegaDetectorV6 model for image detection # Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c -detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, 
version="MDV6-yolov10-e") +detection_model = pw_detection.MegaDetectorV6( + device=DEVICE, pretrained=True, version="MDV6-yolov10-e" +) # Uncomment the following line to use MegaDetectorV6 with yolo v9 MIT weights # Valid versions are MDV6-mit-yolov9-c, MDV6-mit-yolov9-e @@ -40,20 +59,23 @@ # Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6 # detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a") -#%% Batch detection +# %% Batch detection """ Batch-detection demo """ # Performing batch detection on the images results = detection_model.batch_image_detection(args.image_folder, batch_size=16) -#%% Output to JSON results +# %% Output to JSON results # Saving the detection results in JSON format os.makedirs(args.output_path, exist_ok=True) json_file = os.path.join(args.output_path, "detection_results.json") -pw_utils.save_detection_json(results, json_file, - categories=detection_model.CLASS_NAMES, - exclude_category_ids=[], # Category IDs can be found in the definition of each model. - exclude_file_path=args.image_folder) +pw_utils.save_detection_json( + results, + json_file, + categories=detection_model.CLASS_NAMES, + exclude_category_ids=[], # Category IDs can be found in the definition of each model. + exclude_file_path=args.image_folder, +) # Separate the positive and negative detections through file copying: -pw_utils.detection_folder_separation(json_file, args.image_folder, args.output_path, args.threshold) \ No newline at end of file +pw_utils.detection_folder_separation(json_file, args.image_folder, args.output_path, args.threshold) diff --git a/demo/video_demo.py b/demo/video_demo.py index 3def5c9c9..98575522f 100644 --- a/demo/video_demo.py +++ b/demo/video_demo.py @@ -2,31 +2,36 @@ # Licensed under the MIT License. 
""" Video detection demo """ -#%% +import os + +# %% # Importing necessary basic libraries and modules import numpy as np import supervision as sv -#%% +# %% # PyTorch imports for tensor operations import torch -import os -#%% -# Importing the models, transformations, and utility functions from PytorchWildlife -from PytorchWildlife.models import detection as pw_detection -from PytorchWildlife.models import classification as pw_classification + from PytorchWildlife import utils as pw_utils -#%% +# %% +# Importing the models, transformations, and utility functions from PytorchWildlife +from PytorchWildlife.models import classification as pw_classification +from PytorchWildlife.models import detection as pw_detection + +# %% # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -SOURCE_VIDEO_PATH = os.path.join(".","demo_data","videos","opossum_example.MP4") -TARGET_VIDEO_PATH = os.path.join(".","demo_data","videos","opossum_example_processed.MP4") +SOURCE_VIDEO_PATH = os.path.join(".", "demo_data", "videos", "opossum_example.MP4") +TARGET_VIDEO_PATH = os.path.join(".", "demo_data", "videos", "opossum_example_processed.MP4") -#%% +# %% # Initializing the MegaDetectorV6 model for image detection # Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c -detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="MDV6-yolov10-e") +detection_model = pw_detection.MegaDetectorV6( + device=DEVICE, pretrained=True, version="MDV6-yolov10-e" +) # Uncomment the following line to use MegaDetectorV6 with yolo v9 MIT weights # Valid versions are MDV6-mit-yolov9-c, MDV6-mit-yolov9-e @@ -39,27 +44,28 @@ # Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6 # detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a") -#%% +# %% # Initializing the model for image classification 
classification_model = pw_classification.AI4GOpossum(device=DEVICE, pretrained=True) -#%% +# %% # Initializing a box annotator for visualizing detections box_annotator = sv.BoxAnnotator(thickness=4) lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2) + def callback(frame: np.ndarray, index: int) -> np.ndarray: """ Callback function to process each video frame for detection and classification. - + Parameters: - frame (np.ndarray): Video frame as a numpy array. - index (int): Frame index. - + Returns: annotated_frame (np.ndarray): Annotated video frame. """ - + results_det = detection_model.single_image_detection(frame, img_path=index) labels = [] @@ -77,8 +83,11 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray: detections=results_det["detections"], labels=labels, ) - - return annotated_frame + + return annotated_frame + # Processing the video and saving the result with annotated detections and classifications -pw_utils.process_video(source_path=SOURCE_VIDEO_PATH, target_path=TARGET_VIDEO_PATH, callback=callback, target_fps=10) +pw_utils.process_video( + source_path=SOURCE_VIDEO_PATH, target_path=TARGET_VIDEO_PATH, callback=callback, target_fps=10 +) diff --git a/docs/base/data/datasets.md b/docs/base/data/datasets.md index a16a16c8f..74a1ed2d2 100644 --- a/docs/base/data/datasets.md +++ b/docs/base/data/datasets.md @@ -1,3 +1,3 @@ # Datasets Module -::: PytorchWildlife.data.datasets \ No newline at end of file +::: PytorchWildlife.data.datasets diff --git a/docs/base/data/transforms.md b/docs/base/data/transforms.md index b9b264689..7e12c941f 100644 --- a/docs/base/data/transforms.md +++ b/docs/base/data/transforms.md @@ -1,3 +1,3 @@ # Transforms Module -::: PytorchWildlife.data.transforms \ No newline at end of file +::: PytorchWildlife.data.transforms diff --git a/docs/base/models/classification/base_classifier.md b/docs/base/models/classification/base_classifier.md index 6f729c9b1..38abedc78 100644 --- 
a/docs/base/models/classification/base_classifier.md +++ b/docs/base/models/classification/base_classifier.md @@ -1,3 +1,3 @@ # Base Classifier -::: PytorchWildlife.models.classification.base_classifier \ No newline at end of file +::: PytorchWildlife.models.classification.base_classifier diff --git a/docs/base/models/classification/resnet_base/amazon.md b/docs/base/models/classification/resnet_base/amazon.md index 520bfb778..a299d9376 100644 --- a/docs/base/models/classification/resnet_base/amazon.md +++ b/docs/base/models/classification/resnet_base/amazon.md @@ -1,3 +1,3 @@ # Amazon -::: PytorchWildlife.models.classification.resnet_base.amazon \ No newline at end of file +::: PytorchWildlife.models.classification.resnet_base.amazon diff --git a/docs/base/models/classification/resnet_base/base_classifier.md b/docs/base/models/classification/resnet_base/base_classifier.md index 68c1b339a..1bc9a62d0 100644 --- a/docs/base/models/classification/resnet_base/base_classifier.md +++ b/docs/base/models/classification/resnet_base/base_classifier.md @@ -1,3 +1,3 @@ # ResNet Base -::: PytorchWildlife.models.classification.resnet_base.base_classifier \ No newline at end of file +::: PytorchWildlife.models.classification.resnet_base.base_classifier diff --git a/docs/base/models/classification/resnet_base/custom_weights.md b/docs/base/models/classification/resnet_base/custom_weights.md index e41b4e85e..b594cc188 100644 --- a/docs/base/models/classification/resnet_base/custom_weights.md +++ b/docs/base/models/classification/resnet_base/custom_weights.md @@ -1,3 +1,3 @@ # Custom Weights -::: PytorchWildlife.models.classification.resnet_base.custom_weights \ No newline at end of file +::: PytorchWildlife.models.classification.resnet_base.custom_weights diff --git a/docs/base/models/classification/resnet_base/opossum.md b/docs/base/models/classification/resnet_base/opossum.md index 4f6b11f1c..54c6d0d52 100644 --- a/docs/base/models/classification/resnet_base/opossum.md +++ 
b/docs/base/models/classification/resnet_base/opossum.md @@ -1,3 +1,3 @@ # Opossum -::: PytorchWildlife.models.classification.resnet_base.opossum \ No newline at end of file +::: PytorchWildlife.models.classification.resnet_base.opossum diff --git a/docs/base/models/classification/resnet_base/serengeti.md b/docs/base/models/classification/resnet_base/serengeti.md index 6f7a3760b..76c659c63 100644 --- a/docs/base/models/classification/resnet_base/serengeti.md +++ b/docs/base/models/classification/resnet_base/serengeti.md @@ -1,3 +1,3 @@ # Serengeti -::: PytorchWildlife.models.classification.resnet_base.serengeti \ No newline at end of file +::: PytorchWildlife.models.classification.resnet_base.serengeti diff --git a/docs/base/models/classification/timm_base/DFNE.md b/docs/base/models/classification/timm_base/DFNE.md index 01f11fff9..17be1fa54 100644 --- a/docs/base/models/classification/timm_base/DFNE.md +++ b/docs/base/models/classification/timm_base/DFNE.md @@ -1,3 +1,3 @@ # DFNE -::: PytorchWildlife.models.classification.timm_base.DFNE \ No newline at end of file +::: PytorchWildlife.models.classification.timm_base.DFNE diff --git a/docs/base/models/classification/timm_base/Deepfaune.md b/docs/base/models/classification/timm_base/Deepfaune.md index a943f5900..fe7ea756a 100644 --- a/docs/base/models/classification/timm_base/Deepfaune.md +++ b/docs/base/models/classification/timm_base/Deepfaune.md @@ -1,3 +1,3 @@ # Deepfaune -::: PytorchWildlife.models.classification.timm_base.Deepfaune \ No newline at end of file +::: PytorchWildlife.models.classification.timm_base.Deepfaune diff --git a/docs/base/models/classification/timm_base/base_classifier.md b/docs/base/models/classification/timm_base/base_classifier.md index ef03c85cb..59905e415 100644 --- a/docs/base/models/classification/timm_base/base_classifier.md +++ b/docs/base/models/classification/timm_base/base_classifier.md @@ -1,3 +1,3 @@ # Timm Base -::: 
PytorchWildlife.models.classification.timm_base.base_classifier \ No newline at end of file +::: PytorchWildlife.models.classification.timm_base.base_classifier diff --git a/docs/base/models/detection/base_detector.md b/docs/base/models/detection/base_detector.md index 7fbb5f9e0..6c0a74140 100644 --- a/docs/base/models/detection/base_detector.md +++ b/docs/base/models/detection/base_detector.md @@ -1,3 +1,3 @@ # Base Detector -::: PytorchWildlife.models.detection.base_detector \ No newline at end of file +::: PytorchWildlife.models.detection.base_detector diff --git a/docs/base/models/detection/herdnet.md b/docs/base/models/detection/herdnet.md index da27b9618..b77ad508d 100644 --- a/docs/base/models/detection/herdnet.md +++ b/docs/base/models/detection/herdnet.md @@ -1,3 +1,3 @@ # HerdNet -::: PytorchWildlife.models.detection.herdnet \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet diff --git a/docs/base/models/detection/herdnet/animaloc/data/patches.md b/docs/base/models/detection/herdnet/animaloc/data/patches.md index 44ddb7616..1be5d70eb 100644 --- a/docs/base/models/detection/herdnet/animaloc/data/patches.md +++ b/docs/base/models/detection/herdnet/animaloc/data/patches.md @@ -1,3 +1,3 @@ # Patches -::: PytorchWildlife.models.detection.herdnet.animaloc.data.patches \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet.animaloc.data.patches diff --git a/docs/base/models/detection/herdnet/animaloc/data/types.md b/docs/base/models/detection/herdnet/animaloc/data/types.md index f1e42fae0..56dba606b 100644 --- a/docs/base/models/detection/herdnet/animaloc/data/types.md +++ b/docs/base/models/detection/herdnet/animaloc/data/types.md @@ -1,3 +1,3 @@ # Types -::: PytorchWildlife.models.detection.herdnet.animaloc.data.types \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet.animaloc.data.types diff --git a/docs/base/models/detection/herdnet/animaloc/eval/lmds.md 
b/docs/base/models/detection/herdnet/animaloc/eval/lmds.md index 76682ff3a..663f6daa0 100644 --- a/docs/base/models/detection/herdnet/animaloc/eval/lmds.md +++ b/docs/base/models/detection/herdnet/animaloc/eval/lmds.md @@ -1,3 +1,3 @@ # LMDS -::: PytorchWildlife.models.detection.herdnet.animaloc.eval.lmds \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet.animaloc.eval.lmds diff --git a/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md b/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md index 74939e373..325f5158f 100644 --- a/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md +++ b/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md @@ -1,3 +1,3 @@ # Stitchers -::: PytorchWildlife.models.detection.herdnet.animaloc.eval.stitchers \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet.animaloc.eval.stitchers diff --git a/docs/base/models/detection/herdnet/dla.md b/docs/base/models/detection/herdnet/dla.md index bf718a19f..2422b7e63 100644 --- a/docs/base/models/detection/herdnet/dla.md +++ b/docs/base/models/detection/herdnet/dla.md @@ -1,3 +1,3 @@ # DLA -::: PytorchWildlife.models.detection.herdnet.dla \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet.dla diff --git a/docs/base/models/detection/herdnet/model.md b/docs/base/models/detection/herdnet/model.md index c6ea9ce47..a4a5b685a 100644 --- a/docs/base/models/detection/herdnet/model.md +++ b/docs/base/models/detection/herdnet/model.md @@ -1,3 +1,3 @@ # Model -::: PytorchWildlife.models.detection.herdnet.model \ No newline at end of file +::: PytorchWildlife.models.detection.herdnet.model diff --git a/docs/base/models/detection/ultralytics_based/Deepfaune.md b/docs/base/models/detection/ultralytics_based/Deepfaune.md index f432ea92f..43956a2ad 100644 --- a/docs/base/models/detection/ultralytics_based/Deepfaune.md +++ b/docs/base/models/detection/ultralytics_based/Deepfaune.md @@ -1,3 +1,3 @@ # Deepfaune 
-::: PytorchWildlife.models.detection.ultralytics_based.Deepfaune \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.Deepfaune diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv5.md b/docs/base/models/detection/ultralytics_based/megadetectorv5.md index acc45ca84..e91a3d86b 100644 --- a/docs/base/models/detection/ultralytics_based/megadetectorv5.md +++ b/docs/base/models/detection/ultralytics_based/megadetectorv5.md @@ -1,3 +1,3 @@ # MegaDetector v5 -::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv5 \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv5 diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv6.md b/docs/base/models/detection/ultralytics_based/megadetectorv6.md index 62fc7b329..fd8db0851 100644 --- a/docs/base/models/detection/ultralytics_based/megadetectorv6.md +++ b/docs/base/models/detection/ultralytics_based/megadetectorv6.md @@ -1,3 +1,3 @@ # MegaDetector v6 -::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6 \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6 diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md b/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md index afab79d5c..b67a56fc8 100644 --- a/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md +++ b/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md @@ -1,3 +1,3 @@ # MegaDetector v6 Distributed -::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6_distributed \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6_distributed diff --git a/docs/base/models/detection/ultralytics_based/yolov5_base.md b/docs/base/models/detection/ultralytics_based/yolov5_base.md index 57366f9b2..ffbf8d2f4 100644 --- 
a/docs/base/models/detection/ultralytics_based/yolov5_base.md +++ b/docs/base/models/detection/ultralytics_based/yolov5_base.md @@ -1,3 +1,3 @@ # YOLOv5 Base -::: PytorchWildlife.models.detection.ultralytics_based.yolov5_base \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.yolov5_base diff --git a/docs/base/models/detection/ultralytics_based/yolov8_base.md b/docs/base/models/detection/ultralytics_based/yolov8_base.md index b71b3ac0e..7661d844a 100644 --- a/docs/base/models/detection/ultralytics_based/yolov8_base.md +++ b/docs/base/models/detection/ultralytics_based/yolov8_base.md @@ -1,3 +1,3 @@ # YOLOv8 Base -::: PytorchWildlife.models.detection.ultralytics_based.yolov8_base \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.yolov8_base diff --git a/docs/base/models/detection/ultralytics_based/yolov8_distributed.md b/docs/base/models/detection/ultralytics_based/yolov8_distributed.md index 12feadc67..ebda79a85 100644 --- a/docs/base/models/detection/ultralytics_based/yolov8_distributed.md +++ b/docs/base/models/detection/ultralytics_based/yolov8_distributed.md @@ -1,3 +1,3 @@ # YOLOv8 Distributed -::: PytorchWildlife.models.detection.ultralytics_based.yolov8_distributed \ No newline at end of file +::: PytorchWildlife.models.detection.ultralytics_based.yolov8_distributed diff --git a/docs/base/overview.md b/docs/base/overview.md index d2833a1ef..648e36c68 100644 --- a/docs/base/overview.md +++ b/docs/base/overview.md @@ -34,4 +34,4 @@ from PytorchWildlife.models import classification, detection from PytorchWildlife.utils import misc, post_process ``` -Refer to the specific submodule documentation for detailed usage instructions. \ No newline at end of file +Refer to the specific submodule documentation for detailed usage instructions. 
diff --git a/docs/base/utils/misc.md b/docs/base/utils/misc.md index 72139cdf8..aed3f6d5a 100644 --- a/docs/base/utils/misc.md +++ b/docs/base/utils/misc.md @@ -1 +1 @@ -::: PytorchWildlife.utils.misc \ No newline at end of file +::: PytorchWildlife.utils.misc diff --git a/docs/base/utils/post_process.md b/docs/base/utils/post_process.md index 02b513c1b..eea069bcc 100644 --- a/docs/base/utils/post_process.md +++ b/docs/base/utils/post_process.md @@ -1 +1 @@ -::: PytorchWildlife.utils.post_process \ No newline at end of file +::: PytorchWildlife.utils.post_process diff --git a/docs/build_mkdocs.md b/docs/build_mkdocs.md index 5f80d7041..24ab17adb 100644 --- a/docs/build_mkdocs.md +++ b/docs/build_mkdocs.md @@ -33,4 +33,3 @@ To build the MkDocs site locally, follow these steps: 5. **Exclude the `site/` Directory**: The `site/` directory is automatically generated and should not be included in version control. It is already added to `.gitignore`. - diff --git a/docs/cite.md b/docs/cite.md index d75ba03e3..aa832a5a2 100644 --- a/docs/cite.md +++ b/docs/cite.md @@ -21,4 +21,4 @@ Also, don't forget to cite our original paper for MegaDetector: eprint={1907.06772}, archivePrefix={arXiv}, } -``` \ No newline at end of file +``` diff --git a/docs/contribute.md b/docs/contribute.md index 5dae03407..02ce775ca 100644 --- a/docs/contribute.md +++ b/docs/contribute.md @@ -16,4 +16,4 @@ Thanks for your interest in collaborating on Pytorch-Wildlife! Here you can find Thank you for helping us improve PytorchWildlife! -We have adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [us](mailto:zhongqimiao@microsoft.com) with any additional questions or comments. \ No newline at end of file +We have adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [us](mailto:zhongqimiao@microsoft.com) with any additional questions or comments. diff --git a/docs/contributors.md b/docs/contributors.md index 262400844..f929476b8 100644 --- a/docs/contributors.md +++ b/docs/contributors.md @@ -1 +1 @@ -# In construction \ No newline at end of file +# In construction diff --git a/docs/demo_and_ui/demo_data.md b/docs/demo_and_ui/demo_data.md index d0a53015f..41cb1ec9a 100644 --- a/docs/demo_and_ui/demo_data.md +++ b/docs/demo_and_ui/demo_data.md @@ -1,2 +1,2 @@ # Prepare data for demos -Before we run any demos, please download some demo data for our demo notebooks and webapps from this [link](https://zenodo.org/records/15376499/files/demo_data_base.zip?download=1) and decompress the data in the [demo folder](https://github.com/microsoft/CameraTraps/tree/main/demo). \ No newline at end of file +Before we run any demos, please download some demo data for our demo notebooks and webapps from this [link](https://zenodo.org/records/15376499/files/demo_data_base.zip?download=1) and decompress the data in the [demo folder](https://github.com/microsoft/CameraTraps/tree/main/demo). diff --git a/docs/demo_and_ui/ecoassist.md b/docs/demo_and_ui/ecoassist.md index a28fcc04b..ed71206a7 100644 --- a/docs/demo_and_ui/ecoassist.md +++ b/docs/demo_and_ui/ecoassist.md @@ -1,2 +1,2 @@ # Pytorch-Wildlife modelsa are available with AddaxAI (formerly EcoAssist)! -We are thrilled to announce our collaboration with [AddaxAI](https://addaxdatascience.com/addaxai/#spp-models)---a powerful user interface software that enables users to directly load models from the PyTorch-Wildlife model zoo for image analysis on local computers. 
With AddaxAI, you can now utilize MegaDetectorV5 and the classification models---AI4GAmazonRainforest and AI4GOpossum---for automatic animal detection and identification, alongside a comprehensive suite of pre- and post-processing tools. This partnership aims to enhance the overall user experience with PyTorch-Wildlife models for a general audience. We will work closely to bring more features together for more efficient and effective wildlife analysis in the future. Please refer to their tutorials on how to use Pytorch-Wildlife models with AddaxAI. \ No newline at end of file +We are thrilled to announce our collaboration with [AddaxAI](https://addaxdatascience.com/addaxai/#spp-models)---a powerful user interface software that enables users to directly load models from the PyTorch-Wildlife model zoo for image analysis on local computers. With AddaxAI, you can now utilize MegaDetectorV5 and the classification models---AI4GAmazonRainforest and AI4GOpossum---for automatic animal detection and identification, alongside a comprehensive suite of pre- and post-processing tools. This partnership aims to enhance the overall user experience with PyTorch-Wildlife models for a general audience. We will work closely to bring more features together for more efficient and effective wildlife analysis in the future. Please refer to their tutorials on how to use Pytorch-Wildlife models with AddaxAI. 
diff --git a/docs/demo_and_ui/gradio.md b/docs/demo_and_ui/gradio.md index ee7f7b955..4fc83b482 100644 --- a/docs/demo_and_ui/gradio.md +++ b/docs/demo_and_ui/gradio.md @@ -24,4 +24,4 @@ pip uninstall opencv-python conda install -c conda-forge opencv ``` -![image](https://zenodo.org/records/15376499/files/gradio_UI.png) \ No newline at end of file +![image](https://zenodo.org/records/15376499/files/gradio_UI.png) diff --git a/docs/demo_and_ui/notebook.md b/docs/demo_and_ui/notebook.md index 30f4f7405..fc1cf6657 100644 --- a/docs/demo_and_ui/notebook.md +++ b/docs/demo_and_ui/notebook.md @@ -4,4 +4,4 @@ To run our demo notebooks and scripts, please go to the `demo` folder and follow the instructions in our [installation page](../installation.md) on how to use jupyter notebooks. ### Online notebooks -We are currently migrating our demo jupyter notebooks to this document page so you do not have to always run jupyter on your local machines. Please keep updates! \ No newline at end of file +We are currently migrating our demo jupyter notebooks to this document page so you do not have to always run jupyter on your local machines. Please keep updates! diff --git a/docs/demo_and_ui/timelapse.md b/docs/demo_and_ui/timelapse.md index 4b64fcbfb..5e6665a78 100644 --- a/docs/demo_and_ui/timelapse.md +++ b/docs/demo_and_ui/timelapse.md @@ -1,2 +1,2 @@ # Pytorch-Wildlife and TimeLapse -Pytorch-Wildlife offers native output formats that are directly compatible with TimeLapse. We will provide more details on this in future releases! We will keep you posted! \ No newline at end of file +Pytorch-Wildlife offers native output formats that are directly compatible with TimeLapse. We will provide more details on this in future releases! We will keep you posted! 
diff --git a/docs/fine_tuning_modules/detection/overview.md b/docs/fine_tuning_modules/detection/overview.md index d4e4cbb9d..114041389 100644 --- a/docs/fine_tuning_modules/detection/overview.md +++ b/docs/fine_tuning_modules/detection/overview.md @@ -1,2 +1,2 @@ -# In Construction \ No newline at end of file +# In Construction diff --git a/docs/fine_tuning_modules/overview.md b/docs/fine_tuning_modules/overview.md index 262400844..f929476b8 100644 --- a/docs/fine_tuning_modules/overview.md +++ b/docs/fine_tuning_modules/overview.md @@ -1 +1 @@ -# In construction \ No newline at end of file +# In construction diff --git a/docs/index.md b/docs/index.md index eb5d80457..6e93e7f54 100644 --- a/docs/index.md +++ b/docs/index.md @@ -60,4 +60,3 @@ Please refer to our [installation guide](installation.md) for more installation ### Opossum ID with `MegaDetector` and `AI4GOpossum` opossum_det
*Credits to the Agency for Regulation and Control of Biosecurity and Quarantine for Galápagos (ABG), Ecuador.* - diff --git a/docs/license.md b/docs/license.md index c089e7740..915cd8b4e 100644 --- a/docs/license.md +++ b/docs/license.md @@ -13,4 +13,4 @@ In addition, since the **Pytorch-Wildlife** package is under MIT, all the utilit ``` --8<-- "LICENSE.md" -``` \ No newline at end of file +``` diff --git a/docs/megadetector.md b/docs/megadetector.md index 11991fdb0..acdf38f03 100644 --- a/docs/megadetector.md +++ b/docs/megadetector.md @@ -1,3 +1,3 @@ --8<-- "megadetector.md" - \ No newline at end of file + diff --git a/docs/model_zoo/classifiers.md b/docs/model_zoo/classifiers.md index f83de560b..5b78baaab 100644 --- a/docs/model_zoo/classifiers.md +++ b/docs/model_zoo/classifiers.md @@ -10,4 +10,4 @@ |Deepfaune-New-England|v1.0|CC0 1.0 Universal|Released|[Deepfaune-New-England](https://code.usgs.gov/vtcfwru/deepfaune-new-england)| >[!TIP] ->Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`. \ No newline at end of file +>Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`. diff --git a/docs/model_zoo/other_detectors.md b/docs/model_zoo/other_detectors.md index 9223a523d..c3ebdf8af 100644 --- a/docs/model_zoo/other_detectors.md +++ b/docs/model_zoo/other_detectors.md @@ -7,4 +7,4 @@ |HerdNet-ennedi|ennedi|CC BY-NC-SA-4.0|Released|[Alexandre et. al. 2023](https://github.com/Alexandre-Delplanque/HerdNet)| >[!TIP] ->Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. 
Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`. \ No newline at end of file +>Some models, such as MegaDetectorV6, HerdNet, and AI4G-Amazon, have different versions, and they are loaded by their corresponding version names. Here is an example: `detection_model = pw_detection.MegaDetectorV6(version="MDV6-yolov10-e")`. diff --git a/docs/releases/past_releases.md b/docs/releases/past_releases.md index 8fd371dab..3d6fc9b1f 100644 --- a/docs/releases/past_releases.md +++ b/docs/releases/past_releases.md @@ -6,4 +6,4 @@ - We will also make a new roadmap for 2025 in the next couple of updates. - Special thanks to [José Díaz](https://github.com/jdiaz97) for his great cross-platform app, [BoquilaHUB](https://github.com/boquila/boquilahub), that is even working on ios and android! Please check his repo out! In the future, we will create a project gallery showcasing projects that use or are build upon Pytorch-Wildlife. If you want your projects to be included, please feel free to reach out to us on [![](https://img.shields.io/badge/any_text-Join_us!-blue?logo=discord&label=PytorchWildife)](https://discord.gg/TeEVxzaYtm) -animal_det_1
\ No newline at end of file +animal_det_1
diff --git a/docs/releases/release_notes.md b/docs/releases/release_notes.md index 217637604..f69cc1a85 100644 --- a/docs/releases/release_notes.md +++ b/docs/releases/release_notes.md @@ -18,4 +18,3 @@ classification_model = pw_classification.DeepfauneClassifier(device=DEVICE) #### Deepfaune-New-England in Our Model Zoo Too!! - Besides the original Deepfaune mode, there is another fine-tuned Deepfaune model developed by USGS for the Northeastern NA area called Deepfaune-New-England (DFNE). It can also be loaded with `classification_model = pw_classification.DFNE(device=DEVICE)` - Please take a look at the orignal [DFNE repo](https://code.usgs.gov/vtcfwru/deepfaune-new-england/-/tree/main?ref_type=heads) and give them a star! - diff --git a/mkdocs.yml b/mkdocs.yml index d3243286f..fbb8f5c4e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -181,4 +181,3 @@ plugins: python: options: docstring_style: google - diff --git a/setup.py b/setup.py index 3402d9614..0ed2440bc 100644 --- a/setup.py +++ b/setup.py @@ -1,42 +1,42 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup -with open('README.md', encoding="utf8") as file: - long_description = file.read() +with open("README.md", encoding="utf8") as file: + long_description = file.read() setup( - name='PytorchWildlife', - version='1.2.4.2', + name="PytorchWildlife", + version="1.2.4.2", packages=find_packages(), include_package_data=True, package_data={"": ["*.yml"]}, - url='https://github.com/microsoft/CameraTraps/', - license='MIT', - author='Andres Hernandez, Zhongqi Miao, Daniela Ruiz Lopez, Isai Daniel Chacon Silva', - author_email='v-hernandres@microsoft.com, zhongqimiao@microsoft.com, v-druizlopez@microsoft.com, v-ichaconsil@microsoft.com', - description='a PyTorch Collaborative Deep Learning Framework for Conservation.', + url="https://github.com/microsoft/CameraTraps/", + license="MIT", + author="Andres Hernandez, Zhongqi Miao, Daniela Ruiz Lopez, Isai Daniel Chacon Silva", + 
author_email="v-hernandres@microsoft.com, zhongqimiao@microsoft.com, v-druizlopez@microsoft.com, v-ichaconsil@microsoft.com", + description="a PyTorch Collaborative Deep Learning Framework for Conservation.", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", install_requires=[ - 'torch', - 'torchvision', - 'torchaudio', - 'tqdm', - 'Pillow', - 'supervision==0.23.0', - 'gradio', - 'ultralytics', - 'chardet', - 'wget', - 'yolov5', - 'setuptools', - 'scikit-learn', - 'timm', + "torch", + "torchvision", + "torchaudio", + "tqdm", + "Pillow", + "supervision==0.23.0", + "gradio", + "ultralytics", + "chardet", + "wget", + "yolov5", + "setuptools", + "scikit-learn", + "timm", ], classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", ], - keywords='pytorch_wildlife, pytorch, wildlife, megadetector, conservation, animal, detection, classification', - python_requires='>=3.8', + keywords="pytorch_wildlife, pytorch, wildlife, megadetector, conservation, animal, detection, classification", + python_requires=">=3.8", )