Данный проект представляет собой инструмент для оценки моделей генерации изображений с использованием различных метрик.
Проект оформлен в соответствии с лучшими современными MLOps-практиками и может быть легко расширен для использования новых моделей, датасетов, метрик и систем логирования и мониторинга.
- AFHQv2 (Animal Faces High Quality v2). В качестве изначального набора данных для обучения моделей был выбран датасет AFHQv2. Он содержит около 15.8 тысяч изображений мордочек котов, собак и диких животных с разрешением 512x512x3. Из-за однородности изображений данный датасет отлично подходит для обучения генеративных моделей небольшого размера.
Данные сохранены с использованием DVC S3 remote, скачивание их происходит при необходимости в начале обучения, так что вручную ничего загружать не нужно.
Из-за модульной структуры проекта в дальнейшем с минимальными усилиями можно будет добавить и другие датасеты, если это потребуется.
- WGAN-GP (Wasserstein Generative Adversarial Network with Gradient Penalty). Это бейзлайн-модель, состоящая из очень маленьких сетей генератора и критика. Модель простая и легковесная, стабильно обучается, но всё же способна генерировать достаточно неплохие изображения, хоть и низкого разрешения.
В ближайшем будущем планируется добавление и других моделей для обучения и тестирования.
На данный момент вычисляются и логируются только метрики, получаемые в процессе обучения модели:
- Ошибка генератора (
loss/generator) - Ошибка критика (
loss/critic) - Средняя оценка реальных изображений (
critic/real) - Средняя оценка синтетических изображений (
critic/fake) - Средний размер градиентного штрафа критика (
critic/gp)
По этим метрикам можно оценить ход обучения, но они напрямую не связаны с
качеством генерируемых изображений, поэтому в ближайшем будущем планируется
добавить такие метрики, как IS и FID, для оценки именно финальных
изображений, не привязываясь к модели, генерирующей их.
Помимо метрик логируются также все основные гиперпараметры для данных и модели, id текущего коммита репозитория (в виде тега в MLflow) и примеры изображений после каждой эпохи обучения в виде артефактов.
uv- Управление окружением и зависимостямиpre-commit- Запуск форматтеров и линтеров для поддержания качества кода (настройки можно посмотреть в файлах.pre-commit-config.yamlиpyproject.toml)dvc- Хранение и версионирование данныхhydra- Управление конфигами с гиперпараметрамиfire- CLI проектаlightning- Обучение и применение моделейmlflow- Внешняя система логирования экспериментов (напрямую не запускается; ожидается, что уже запущена до вызова проекта)
Так как в проекте используется uv, чтобы установить основные зависимости
проекта, достаточно выполнить команду:
uv syncЧасть зависимостей, таких как mlflow и pre-commit, вынесены в отдельную
группу зависимостей dev, так как не требуются для минимального сценария
запуска проекта.
Установить все зависимости, включая группу dev, можно командой:
uv sync --devПримечание: torch и torchvision по возможности ставятся с поддержкой CUDA
12.1.
В результате выполнения одной из данных команд будет создано виртуальное
окружение .venv в корне проекта, далее его необходимо активировать:
source .venv/bin/activate # Linux/MacOS
source .venv/Scripts/activate # Windows (Git Bash)
.venv\Scripts\Activate.ps1 # Windows (PowerShell)
.venv\Scripts\activate.bat # Windows (CMD)Установить git-хуки и запустить проверку кода проекта с использованием
pre-commit при установленных dev-зависимостях можно при помощи команд:
pre-commit install
pre-commit run -aВсё взаимодействие с проектом происходит через CLI, реализованный при помощи связки Fire + Hydra Compose API.
Корневой файл для использования CLI - commands.py, в нём собраны все доступные
команды.
Вызывать команды необходимо, находясь в корне проекта.
Запустить обучение модели в базовом варианте можно командой:
python commands.py trainДанный способ запуска поддерживает перезапись гиперпараметров "на лету" в формате Hydra CLI. Например, для изменения числа эпох обучения и размера батча можно использовать команду:
python commands.py train train.max_epochs=5 data.batch_size=512Полный список доступных гиперпараметров можно посмотреть в конфигах в папке
configs.
❗ Обратите внимание, что для успешного запуска обучения НЕОБХОДИМО наличие запущенного MLflow-сервера по адресу http://127.0.0.1:8080 ❗
Запустить MLflow локально при установленных dev-зависимостях можно при помощи
команды:
mlflow ui -p 8080*Work in progress*
*Work in progress*