The nn_handler repository provides a comprehensive and flexible Python framework designed to streamline the development, training, evaluation, and management of PyTorch neural network models. It aims to abstract away boilerplate code, allowing researchers and developers to focus on model architecture and experimentation.
NNHandler offers a unified interface supporting:
- Standard training and validation loops.
- Advanced features like Automatic Mixed Precision (AMP), gradient accumulation, and Exponential Moving Average (EMA).
- Seamless integration with Distributed Data Parallel (DDP) for multi-GPU and multi-node training.
- A rich, extensible callback system for monitoring, checkpointing, visualization, and custom logic.
- Built-in support for generative models, including score-based models (SDEs) and custom samplers.
- Comprehensive model saving and loading, including full training state resumption.
- Integrated logging and metric tracking with plotting capabilities.
- Support for `torch.compile` for potential performance boosts.
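The gradient-accumulation and EMA options wrap standard PyTorch training patterns. As a rough sketch of what those two features do under the hood (plain PyTorch for illustration, not NNHandler's actual implementation):

```python
import copy
import torch

model = torch.nn.Linear(10, 1)
ema_model = copy.deepcopy(model)       # shadow copy holding the EMA weights
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.functional.mse_loss
accum_steps, ema_decay = 2, 0.99

data = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(4)]
for step, (x, y) in enumerate(data):
    loss = loss_fn(model(x), y) / accum_steps  # scale so accumulated grads average out
    loss.backward()                            # gradients accumulate across backward calls
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        with torch.no_grad():                  # EMA update after each optimizer step
            for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                p_ema.mul_(ema_decay).add_(p, alpha=1 - ema_decay)
```

NNHandler exposes these as the `gradient_accumulation_steps` and `ema_decay` arguments to `train()`, so you never write this loop yourself.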
- Install directly from the repository:

  ```bash
  pip install --upgrade git+https://github.com/rouzib/NNHandler.git
  ```

- Clone the repository:

  ```bash
  git clone https://github.com/rouzib/NNHandler.git
  cd NNHandler  # Or your repository's root directory
  ```

- Install dependencies: Ensure you have Python 3.10+ and PyTorch 1.10+ installed (PyTorch 2.0+ is required for `torch.compile`). It's recommended to use a virtual environment.

  ```bash
  pip install -r requirements.txt
  ```

  Note: For specific features like AMP, EMA, or plotting, ensure the corresponding libraries (`torch_ema`, `matplotlib`, `tqdm`) are installed.
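A quick way to verify which of the optional feature libraries are importable in your environment (a small helper sketch, not part of NNHandler itself):

```python
import importlib.util

def check_optional(names=("torch_ema", "matplotlib", "tqdm")):
    """Report which optional feature libraries are importable."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

print(check_optional())
```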
For a quick introduction and basic usage examples, please refer to the Getting Started Guide.
The framework revolves around the central NNHandler class and its supporting modules:
- NNHandler Class (nn_handler/nn_handler.md): The main orchestrator for model training, evaluation, and management. Implemented in `nn_handler_distributed.py`.
- Distributed Training (nn_handler/distributed.md): Details on how `NNHandler` integrates with PyTorch DDP.
- AutoSaver (nn_handler/autosaver.md): Functionality for automatic checkpoint saving during training (integrated within `NNHandler`).
- Sampler (nn_handler/sampler.md): Support for custom sampling algorithms via the `Sampler` base class.
- Callbacks (nn_handler/callbacks/README.md): A powerful system for customizing the training loop with various hooks.
- Utilities (nn_handler/utils/README.md): A collection of utility functions and classes, including DDP utilities for working with PyTorch's Distributed Data Parallel functionality.
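Callback systems like this typically follow a hook-based pattern: the training loop calls named methods on each registered callback at fixed points. A generic sketch of the idea (hook names here are illustrative; see the callbacks README for NNHandler's actual interface):

```python
class Callback:
    """Base class: subclasses override only the hooks they need."""
    def on_epoch_start(self, epoch, logs): ...
    def on_epoch_end(self, epoch, logs): ...

class EarlyStopping(Callback):
    """Stop training when the monitored metric stops improving."""
    def __init__(self, monitor="val_loss", patience=3):
        self.monitor, self.patience = monitor, patience
        self.best, self.wait, self.stop = float("inf"), 0, False

    def on_epoch_end(self, epoch, logs):
        value = logs[self.monitor]
        if value < self.best:
            self.best, self.wait = value, 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop = True

# Usage: the training loop invokes each hook at the matching point
cb = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.95]):
    cb.on_epoch_end(epoch, {"val_loss": val_loss})
    if cb.stop:
        break
```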
- Getting Started: Quick-start tutorial.
- NNHandler Module: Overview of the core module.
- NNHandler Class: Detailed API reference for `NNHandler`.
- Distributed Training (DDP): Guide to using DDP features.
- SLURM Job Templates: Templates for running distributed training on HPC clusters (single_node, multi-node/single-GPU, multi-node/multi-GPU).
- AutoSaver Feature: Auto-saving configuration.
- Sampler Integration: Using custom samplers.
- Callbacks System: Introduction to callbacks and available implementations.
- Utilities: Documentation for utility functions and classes, including DDP utilities.
```python
import torch
from src.nn_handler import NNHandler  # Assuming src layout
# from your_model_file import YourModel, your_loss_fn
# from your_dataset_file import your_train_dataset, your_val_dataset

# Dummy components for illustration
class YourModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

def your_loss_fn(pred, target):
    return torch.nn.functional.mse_loss(pred, target)

your_train_dataset = torch.utils.data.TensorDataset(torch.randn(100, 10), torch.randn(100, 1))
your_val_dataset = torch.utils.data.TensorDataset(torch.randn(50, 10), torch.randn(50, 1))

# --- Main NNHandler Workflow ---

# 1. Initialize NNHandler (DDP auto-detected if environment is set)
handler = NNHandler(
    model_class=YourModel,
    device="cuda" if torch.cuda.is_available() else "cpu",  # DDP will assign a specific CUDA device
    logger_mode=NNHandler.LoggingMode.CONSOLE,  # Log to console (rank 0 only)
    model_type=NNHandler.ModelType.REGRESSION,  # Specify model type
    # Add model_kwargs if needed: hidden_units=128
)

# 2. Configure components
handler.set_optimizer(torch.optim.Adam, lr=1e-3)
handler.set_loss_fn(your_loss_fn)
handler.set_train_loader(your_train_dataset, batch_size=16)
handler.set_val_loader(your_val_dataset, batch_size=16)

# Optional: add metrics and callbacks
# def your_metric(pred, target): return torch.abs(pred - target).mean().item()
# handler.add_metric("mae", your_metric)
# from src.nn_handler.callbacks import ModelCheckpoint
# handler.add_callback(ModelCheckpoint(filepath="models/best_model.pth", monitor="val_loss"))

# 3. Train
handler.train(
    epochs=10,
    validate_every=1,
    use_amp=True,                   # Enable AMP if available and on CUDA
    gradient_accumulation_steps=2,  # Example: accumulate gradients
    ema_decay=0.99,                 # Example: use EMA
)

# 4. Save final state (rank 0 saves)
handler.save("models/final_handler_state.pth")

# 5. Load and predict
# To run prediction/loading, ensure the environment matches (e.g., DDP or single process)
# loaded_handler = NNHandler.load("models/final_handler_state.pth")
# predictions = loaded_handler.predict(some_data_loader)  # Predict gathers results on rank 0
```

(Remember to run DDP examples using `torchrun`.)
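DDP auto-detection generally keys off the environment variables that `torchrun` sets for each worker process (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`). A minimal sketch of that detection logic, illustrative rather than NNHandler's actual code; you would launch with e.g. `torchrun --nproc_per_node=4 your_script.py`:

```python
import os

def detect_ddp_env():
    """Return (is_distributed, rank, world_size, local_rank) from torchrun's env vars."""
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return world_size > 1, rank, world_size, local_rank

is_ddp, rank, world_size, local_rank = detect_ddp_env()
print(f"ddp={is_ddp} rank={rank}/{world_size} local_rank={local_rank}")
```

When these variables are absent, the handler falls back to ordinary single-process execution, which is why the same script works both with and without `torchrun`.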
For high-performance computing (HPC) clusters using SLURM, NNHandler provides ready-to-use job templates:
- Single Node, Multiple GPUs (`single_node_multiple_gpu_slurm_job.sh`):

  ```bash
  # Copy the template and modify as needed
  cp doc/single_node_multiple_gpu_slurm_job.sh ./my_training_job.sh
  # Edit the script to update account, time, and Python file
  sbatch my_training_job.sh
  ```

- Multiple Nodes, Single GPU per Node (`multiple_nodes_single_gpu_slurm_job.sh`):

  ```bash
  # For distributed training across multiple nodes, each with a single GPU
  sbatch multiple_nodes_single_gpu_slurm_job.sh
  ```

- Multiple Nodes, Multiple GPUs per Node (`multiple_nodes_multiple_gpu_slurm_job.sh`):

  ```bash
  # For large-scale distributed training with multiple GPUs per node
  sbatch multiple_nodes_multiple_gpu_slurm_job.sh
  ```
These templates handle all the necessary environment setup for distributed training, including proper initialization of NCCL for inter-node communication. See the Distributed Training documentation for detailed information.
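The environment setup such templates perform usually looks something like the following. This is a hedged sketch using standard SLURM variables and `torchrun` conventions, not the contents of the actual templates; consult the shipped scripts for the authoritative versions:

```shell
# Illustrative SLURM job-body fragment (assumed, not copied from the templates):
# pick the first allocated node as the rendezvous host for NCCL/c10d
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=29500

# one torchrun launcher per node; torchrun spawns one worker per GPU
srun torchrun \
    --nnodes="$SLURM_NNODES" \
    --nproc_per_node="$SLURM_GPUS_ON_NODE" \
    --rdzv_backend=c10d \
    --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
    train.py
```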
This project is licensed under the MIT License - see the LICENSE file for details.