A template for simple deep learning projects using Lightning
English | 中文
PyTorch Lightning is to deep learning project development what MVC frameworks (such as Spring and Django) are to website development. You can implement everything from scratch for maximum flexibility (especially since PyTorch and its ecosystem are already quite straightforward), but a framework helps you build prototypes quickly with guidance from "best practices" (personal opinion), save a lot of boilerplate code through reuse, and focus on scientific innovation rather than engineering challenges. This template is built on the full Lightning suite, follows the principle of Occam's razor, and is researcher-friendly. It includes a simple handwritten digit recognition task using the MNIST dataset. The repository also contains some tips for reference.
Using PyTorch Lightning as a deep learning framework:
Most deep learning code can be divided into the following three parts (reference [Chinese]):

- Research code: this part concerns the model and typically deals with customizing the model's structure and training procedure. In Lightning, it is abstracted as the `pl.LightningModule` class. Dataset definition could also be placed here, but this is not recommended, since it is not specific to the experiment and belongs in `pl.LightningDataModule` instead.
- Engineering code: this part is essential and highly repeatable across projects, such as setting up early stopping, 16-bit precision, and distributed GPU training. In Lightning, it is abstracted as the `pl.Trainer` class.
- Non-essential code: this code is helpful for conducting experiments but is not directly related to the experiment itself and can even be omitted, for example gradient checking and writing logs to TensorBoard. In Lightning, it is abstracted as `Callbacks`, which are registered with `pl.Trainer`.
The advantages of using Lightning:

- Custom training procedures and learning-rate schedules can be implemented through the various hook functions of `pl.LightningModule`.
- The model and data no longer need to be explicitly moved to devices (`tensor.to`, `tensor.cuda`, etc.); `pl.Trainer` handles this automatically, thereby supporting various accelerators such as CPU, GPU, and TPU.
- `pl.Trainer` implements various training strategies, such as automatic mixed-precision training, multi-GPU training, and distributed training.
- `pl.Trainer` provides various callbacks, such as automatic model checkpointing, automatic config saving, and automatic saving of visualization results.
Using PyTorch Lightning CLI as a command-line tool:

- With `lightning_cli` as the program entry point, model, data, and training parameters can be set through configuration files or command-line arguments, enabling quick switching between multiple experiments.
- `pl.LightningModule.save_hyperparameters()` saves the model's hyperparameters, and the CLI automatically generates a command-line argument table, eliminating the need for tools such as `argparse` or `hydra`.
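With `LightningCLI`, the `__init__` arguments of your module become config keys and CLI flags. As an illustrative sketch (the class path and the `lr` field are hypothetical; the actual names depend on your code):

```yaml
# Hedged example of a model config consumed via `-c`; field names are illustrative
model:
  class_path: src.modules.MNISTModule
  init_args:
    lr: 0.001
```

The same value could then be overridden on the command line, e.g. `--model.init_args.lr 0.01`.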
Using Torchmetrics as a metric computation tool:

- `Torchmetrics` provides many metric implementations, such as `Accuracy`, `Precision`, and `Recall`.
- It is integrated with Lightning and is compatible with parallel training strategies: data is automatically aggregated to the main process for metric computation.
[Optional] Using WandB to track experiments
```mermaid
graph TD;
    A[LightningCLI]---B[LightningModule]
    A---C[LightningDataModule]
    B---D[models]
    B---E[metrics]
    B---F[...]
    C---G[dataloaders]
    G---H[datasets]
```
```
├── configs                      # Configuration files
│   ├── data                     # Dataset configuration
│   │   └── mnist.yaml           # Example configuration for MNIST dataset
│   ├── model                    # Model configuration
│   │   └── simplenet.yaml       # Example configuration for SimpleNet model
│   └── default.yaml             # Default configuration
├── data                         # Dataset directory
├── logs                         # Log directory
├── notebooks                    # Jupyter Notebook directory
├── scripts                      # Script directory
│   └── clear_wandb_cache.py     # Example script to clear wandb cache
├── src                          # Source code directory
│   ├── callbacks                # Callbacks directory
│   │   └── __init__.py
│   ├── data_modules             # Data module directory
│   │   ├── __init__.py
│   │   └── mnist.py             # Example data module for MNIST dataset
│   ├── metrics                  # Metrics directory
│   │   └── __init__.py
│   ├── models                   # Model directory
│   │   ├── __init__.py
│   │   └── simplenet.py         # Example SimpleNet model
│   ├── modules                  # Module directory
│   │   ├── __init__.py
│   │   └── mnist_module.py      # Example MNIST module
│   ├── utils                    # Utility directory
│   │   ├── __init__.py
│   │   └── cli.py               # CLI tool
│   ├── __init__.py
│   └── main.py                  # Main program entry point
├── .env.example                 # Example environment variable file
├── .gitignore                   # Ignore files for git
├── .project-root                # Project root indicator file for pyrootutils
├── LICENSE                      # Open source license
├── pyproject.toml               # Configuration file for Black and Ruff
├── README.md                    # Project documentation
├── README_PROJECT.md            # Project documentation template
├── README_ZH.md                 # Project documentation in Chinese
└── requirements.txt             # Dependency list
```
```shell
# Clone the project
git clone https://github.com/DavidZhang73/pytorch-lightning-template <project_name>
cd <project_name>

# Install uv, https://docs.astral.sh/uv/getting-started/installation/
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install dependencies
uv sync
```

- Define the dataset by inheriting `pl.LightningDataModule` in `src/data_modules`.
- Define the dataset configuration file in `configs/data` as parameters for the custom `pl.LightningDataModule`.
- Define the model by inheriting `nn.Module` in `src/models`.
- Define metrics by inheriting `torchmetrics.Metric` in `src/metrics`.
- Define the training module by inheriting `pl.LightningModule` in `src/modules`.
- Define the configuration file for the training module in `configs/model` as parameters for the custom `pl.LightningModule`.
- Configure `pl.Trainer`, loggers, and other parameters in `configs/default.yaml`.
Fit

```shell
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Validate

```shell
python src/main.py validate -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Test

```shell
python src/main.py test -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Inference

```shell
python src/main.py predict -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Debug

```shell
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.fast_dev_run true
```

Resume

```shell
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --ckpt_path <ckpt_path> --trainer.logger.id exp1_id
```

Prepare a config file for the CLI

Using the `print_config` functionality of `jsonargparse`, you can obtain the parsed arguments and generate default YAML files. Note that the YAML files for data and model must be configured first.

```shell
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --print_config
```
This template implements a custom CLI (`CustomLightningCLI`) with the following features:

- When the program starts, the configuration file is automatically saved to the corresponding log directory (`WandbLogger` only).
- When the program starts, the optimizer and scheduler configurations are saved to the loggers.
- When the program starts, the default configuration file is automatically loaded.
- After testing completes, the `checkpoint_path` used for testing is printed.
- Some extra command-line parameters are added:
  - `--ignore_warnings` (default: `False`): ignore all warnings.
  - `--test_after_fit` (default: `False`): automatically run the test stage after each training run.
  - `--git_commit_before_fit` (default: `False`): `git commit` before each training run; the commit message is `{logger.name}_{logger.version}` (`WandbLogger` only).
See also: Configure hyperparameters from the CLI (Expert) in the Lightning documentation.
When running on a server, especially one with many CPU cores (>= 24), NumPy may spawn too many threads, which can cause experiments to hang inexplicably. You can limit the number of threads by setting environment variables (in the `.env` file):
```shell
OMP_NUM_THREADS=8
MKL_NUM_THREADS=8
GOTO_NUM_THREADS=8
NUMEXPR_NUM_THREADS=8
OPENBLAS_NUM_THREADS=8
MKL_DOMAIN_NUM_THREADS=8
VECLIB_MAXIMUM_THREADS=8
```
The `.env` file is automatically loaded into the environment by `pyrootutils` via `python-dotenv`.
Stack Overflow: Limit number of threads in numpy
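The same limits can also be applied from Python, as long as this happens before NumPy is first imported, since these variables are only read at import time (a stdlib-only sketch; `limit_threads` is a hypothetical helper and 8 is an arbitrary cap):

```python
import os

THREAD_VARS = (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "GOTO_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
)


def limit_threads(n: int = 8) -> None:
    """Cap BLAS/OpenMP thread pools; must run before numpy is first imported."""
    for var in THREAD_VARS:
        # setdefault: do not override values already set, e.g. from the .env file
        os.environ.setdefault(var, str(n))


limit_threads(8)
```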
When you delete an experiment from the wandb web page, its cache still exists in the local `wandb` directory. You can use the `scripts/clear_wandb_cache.py` script to clear it.
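The idea behind such a script can be sketched with the stdlib alone (a simplified, hypothetical version: the real `scripts/clear_wandb_cache.py` would query the wandb API for the runs that still exist remotely, and the `run-<timestamp>-<id>` directory naming is an assumption about wandb's local layout):

```python
import shutil
from pathlib import Path


def clear_run_dirs(wandb_dir: str, keep_ids: set[str]) -> list[str]:
    """Delete local run directories whose run id is not in keep_ids.

    Assumes wandb's local directories are named like `run-<timestamp>-<id>`
    or `offline-run-<timestamp>-<id>`, i.e. the run id is the last dash-field.
    Returns the names of the removed directories.
    """
    removed = []
    for run_dir in Path(wandb_dir).glob("*run-*"):
        run_id = run_dir.name.rsplit("-", 1)[-1]
        if run_dir.is_dir() and run_id not in keep_ids:
            shutil.rmtree(run_dir)
            removed.append(run_dir.name)
    return sorted(removed)
```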
Inspired by: