Skip to content

NikolayNyunin/image-generation-evaluation

Repository files navigation

Оценка качества генерации изображений

Python

Оглавление

Описание

Общая информация

Данный проект представляет собой инструмент для оценки моделей генерации изображений с использованием различных метрик.

Проект оформлен в соответствии с лучшими современными MLOps-практиками и может быть легко расширен для использования новых моделей, датасетов, метрик и систем логирования и мониторинга.

Данные

  • AFHQv2 (Animal Faces High Quality v2). В качестве изначального набора данных для обучения моделей был выбран датасет AFHQv2. Он содержит около 15.8 тысяч изображений мордочек котов, собак и диких животных с разрешением 512x512x3. Из-за однородности изображений данный датасет отлично подходит для обучения генеративных моделей небольшого размера.

Данные сохранены с использованием DVC S3 remote, скачивание их происходит при необходимости в начале обучения, так что вручную ничего загружать не нужно.

Из-за модульной структуры проекта в дальнейшем с минимальными усилиями можно будет добавить и другие датасеты, если это потребуется.

Модели

  • WGAN-GP (Wasserstein Generative Adversarial Network with Gradient Penalty). Это бейзлайн-модель, состоящая из очень маленьких сетей генератора и критика. Модель простая и легковесная, стабильно обучается, но всё же способна генерировать достаточно неплохие изображения, хоть и низкого разрешения.

В ближайшем будущем планируется добавление и других моделей для обучения и тестирования.

Метрики

На данный момент вычисляются и логируются только метрики, получаемые в процессе обучения модели:

  • Ошибка генератора (loss/generator)
  • Ошибка критика (loss/critic)
  • Средняя оценка реальных изображений (critic/real)
  • Средняя оценка синтетических изображений (critic/fake)
  • Средний размер градиентного штрафа критика (critic/gp)

По этим метрикам можно оценить ход обучения, но они напрямую не связаны с качеством генерируемых изображений, поэтому в ближайшем будущем планируется добавить такие метрики, как IS и FID, для оценки именно финальных изображений, не привязываясь к модели, генерирующей их.

Помимо метрик логируются также все основные гиперпараметры для данных и модели, id текущего коммита репозитория (в виде тега в MLflow) и примеры изображений после каждой эпохи обучения в виде артефактов.

Используемые инструменты

  • uv - Управление окружением и зависимостями
  • pre-commit - Запуск форматтеров и линтеров для поддержания качества кода (настройки можно посмотреть в файлах .pre-commit-config.yaml и pyproject.toml)
  • dvc - Хранение и версионирование данных
  • hydra - Управление конфигами с гиперпараметрами
  • fire - CLI проекта
  • lightning - Обучение и применение моделей
  • mlflow - Внешняя система логирования экспериментов (напрямую не запускается; ожидается, что уже запущена до вызова проекта)

Установка (Setup)

Так как в проекте используется uv, чтобы установить основные зависимости проекта, достаточно выполнить команду:

uv sync

Часть зависимостей, таких как mlflow и pre-commit, вынесены в отдельную группу зависимостей dev, так как не требуются для минимального сценария запуска проекта.

Установить все зависимости, включая группу dev, можно командой:

uv sync --dev

Примечание: torch и torchvision по возможности ставятся с поддержкой CUDA 12.1.

В результате выполнения одной из данных команд будет создано виртуальное окружение .venv в корне проекта, далее его необходимо активировать:

source .venv/bin/activate  # Linux/MacOS
source .venv/Scripts/activate  # Windows (Git Bash)
.venv\Scripts\Activate.ps1  # Windows (PowerShell)
.venv\Scripts\activate.bat  # Windows (CMD)

Установить git-хуки и запустить проверку кода проекта с использованием pre-commit при установленных dev-зависимостях можно при помощи команд:

pre-commit install
pre-commit run -a

Запуск обучения (Train)

Всё взаимодействие с проектом происходит через CLI, реализованный при помощи связки Fire + Hydra Compose API.

Корневой файл для использования CLI - commands.py, в нём собраны все доступные команды.

Вызывать команды необходимо, находясь в корне проекта.

Запустить обучение модели в базовом варианте можно командой:

python commands.py train

Данный способ запуска поддерживает перезапись гиперпараметров "на лету" в формате Hydra CLI. Например, для изменения числа эпох обучения и размера батча можно использовать команду:

python commands.py train train.max_epochs=5 data.batch_size=512

Полный список доступных гиперпараметров можно посмотреть в конфигах в папке configs.

❗ Обратите внимание, что для успешного запуска обучения НЕОБХОДИМО наличие запущенного MLflow-сервера по адресу http://127.0.0.1:8080

Запустить MLflow локально при установленных dev-зависимостях можно при помощи команды:

mlflow ui -p 8080

Запуск генерации

*Work in progress*

Image Generation Quality Evaluation

*Work in progress*

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages