Official implementation of "Variational Perturbation Personalized Federated Learning via Prior-Posterior Distance".
VPFL is a personalized federated learning algorithm that leverages Prior-Posterior Distance (PPD) to achieve better personalization in non-IID environments.
Three Core Components:
- PPD (Prior-Posterior Distance): Measures the discrepancy between global and local historical models
- Constrained Update: Guides model updates based on PPD
- Variational Perturbation: Adds Gaussian noise based on PPD for robustness
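The exact PPD metric is defined in the paper; as a minimal sketch, assuming PPD is the parameter-space L2 distance between the received global model (prior) and the client's previous local model (posterior):

```python
import torch

def prior_posterior_distance(global_model, local_model):
    """Hypothetical sketch: L2 distance between global (prior) and local
    historical (posterior) parameters. The paper's exact metric may differ."""
    dist = torch.tensor(0.0)
    for gp, lp in zip(global_model.parameters(), local_model.parameters()):
        dist = dist + torch.sum((gp.detach() - lp.detach()) ** 2)
    return torch.sqrt(dist)
```

A larger PPD indicates that the client has drifted further from the global model, which the constrained update and perturbation components can then use as a signal.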
Optimized Performance:
- Fashion-MNIST Pathological: 81.49% (target: 80%) ✓
- CIFAR-10 Dirichlet: 67.08% (target: 65%) ✓
- CIFAR-10 Pathological: 58.02% (vs 51.19% baseline)
PFLlib-Compatible Architecture: Easy to integrate with existing FL frameworks
```bash
# Clone the repository
git clone https://github.com/yourusername/VPFL.git
cd VPFL

# Install dependencies
pip install -r requirements.txt
```

```bash
# Generate CIFAR-10 datasets
cd dataset
python generate_Cifar10.py

# Generate Fashion-MNIST datasets
python generate_FashionMNIST.py

# Or generate all at once
bash generate_all_datasets.sh
cd ..
```

```bash
# Train on Fashion-MNIST (5 clients, pathological)
python main.py --dataset FashionMNIST_5_pat --global_rounds 100

# Train on CIFAR-10 (5 clients, pathological)
python main.py --dataset Cifar10_5_pat --global_rounds 100

# Train with custom hyperparameters
python main.py \
    --dataset FashionMNIST_5_pat \
    --lambda_param 10.0 \
    --momentum 0.9 \
    --global_rounds 100
```

```python
from system.flcore.trainmodel.models import create_model
import torch

# Create model
model = create_model('Cifar10_5_pat', device='cuda')

# Load pre-trained weights
model.load_state_dict(torch.load('results/Cifar10_5_pat_VPFL_global_model.pt'))

# Evaluate
model.eval()
```

| Dataset | # Clients | Partition | Directory |
|---|---|---|---|
| CIFAR-10 | 5 | Pathological | dataset/Cifar10_5_pat/ |
| CIFAR-10 | 10 | Dirichlet(0.1) | dataset/Cifar10_10_dir/ |
| Fashion-MNIST | 5 | Pathological | dataset/FashionMNIST_5_pat/ |
| Fashion-MNIST | 10 | Dirichlet(0.1) | dataset/FashionMNIST_10_dir/ |
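The generation scripts above handle partitioning; as an illustration of the Dirichlet(0.1) scheme (not the repository's exact implementation), labels can be split across clients by drawing per-class client proportions from a Dirichlet distribution:

```python
import numpy as np

def dirichlet_partition(labels, num_clients=10, alpha=0.1, seed=0):
    """Illustrative sketch: for each class, draw client proportions from
    Dirichlet(alpha) and assign that class's sample indices accordingly.
    Small alpha (e.g. 0.1) yields highly non-IID splits."""
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        props = rng.dirichlet([alpha] * num_clients)
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for cid, part in enumerate(np.split(idx, cuts)):
            client_idx[cid].extend(part.tolist())
    return client_idx
```

The pathological split instead gives each client only a few fixed classes; both schemes produce the non-IID client data that VPFL is designed for.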
```python
VPFL_CONFIG = {
    'lambda_param': 10.0,    # PPD constraint strength
    'mu': 3.0,               # Perturbation layer control
    'perturb_scale': 0.01,   # Perturbation magnitude
    'warmup_rounds': 20,     # Warmup before PPD activation
}

# Optimizer settings
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9  # CRITICAL for best performance!
)
```

- `lambda_param` (λ): Controls the strength of the PPD constraint. Higher values allow more personalization.
- `mu` (μ): Controls perturbation scaling across layers.
- `perturb_scale`: Base magnitude of the Gaussian perturbation.
- `warmup_rounds`: Number of initial rounds without PPD.
- `momentum`: SGD momentum. 0.9 is crucial for best performance (+5.52% improvement).
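How `perturb_scale`, `mu`, and `warmup_rounds` interact is only summarized above; a hedged sketch of the perturbation step, assuming the Gaussian noise on each layer is scaled by `perturb_scale`, the current PPD value, and a depth-dependent factor controlled by `mu` (the decay form `exp(-depth / mu)` is illustrative, not the paper's exact rule):

```python
import torch

def apply_variational_perturbation(model, ppd, round_idx,
                                   perturb_scale=0.01, mu=3.0,
                                   warmup_rounds=20):
    """Add PPD-scaled Gaussian noise to each layer after warmup.
    The layer factor exp(-depth / mu) is an assumed form of the
    'perturbation layer control'; the repository may scale differently."""
    if round_idx < warmup_rounds:
        return  # no perturbation during the warmup phase
    with torch.no_grad():
        for depth, param in enumerate(model.parameters()):
            layer_factor = torch.exp(torch.tensor(-depth / mu))
            noise = torch.randn_like(param)
            param.add_(perturb_scale * ppd * layer_factor * noise)
```

Under this reading, a larger `mu` spreads the perturbation more evenly across layers, while warmup rounds let the model stabilize before noise is injected.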
```
VPFL/
├── README.md                    # This file
├── LICENSE                      # Apache 2.0 License
├── requirements.txt             # Python dependencies
├── main.py                      # Main entry point
│
├── dataset/                     # Dataset generation scripts
│   ├── generate_Cifar10.py
│   ├── generate_FashionMNIST.py
│   └── generate_all_datasets.sh
│
├── system/                      # Core algorithm (PFLlib-style)
│   └── flcore/
│       ├── servers/
│       │   ├── serverbase.py    # Base server
│       │   └── servervpfl.py    # VPFL server
│       ├── clients/
│       │   ├── clientbase.py    # Base client
│       │   └── clientvpfl.py    # VPFL client
│       └── trainmodel/
│           └── models.py        # Neural network models
│
├── utils/                       # Utilities
│   ├── data_utils.py            # Data loading utilities
│   └── result_utils.py          # Result saving utilities
│
└── results/                     # Results and saved models
```
```bash
bash scripts/run_experiments.sh
```

```bash
# Quick test (20 rounds)
python main.py --dataset FashionMNIST_5_pat --global_rounds 20

# Full training (100 rounds)
python main.py --dataset Cifar10_5_pat --global_rounds 100

# With model saving
python main.py --dataset FashionMNIST_5_pat --global_rounds 100
```

| Method | CIFAR-10 Path | Fashion Path | CIFAR-10 Dirichlet |
|---|---|---|---|
| VPFL (Ours) | 58.02% | 81.49% | 67.08% |
| FedAvg | ~50% | ~75% | ~65% |
| VPFL Baseline | 51.19% | - | - |
- Momentum is Critical: Adding momentum=0.9 improved accuracy by +5.52%
- λ=10.0 is Optimal: Increasing λ from 5.0 to 10.0 added +0.66%
- Fashion-MNIST is Easier: Achieved 81.49% vs 58.02% on CIFAR-10
- Three Components Work Together: PPD + Constrained Update + Perturbation
```bash
# Generate datasets first
cd dataset
python generate_Cifar10.py
python generate_FashionMNIST.py
```

```bash
# Reduce batch size
python main.py --dataset Cifar10_5_pat --batch_size 5

# Use fewer rounds
python main.py --dataset Cifar10_5_pat --global_rounds 50
```

If you use this code in your research, please cite:
```bibtex
@inproceedings{zhou2025variational,
  title={Variational Perturbation Personalized Federated Learning via Prior-Posterior Distance},
  author={Zhou, Hefeng and Wang, Yuanbin and Wang, Jun and Lou, Jiong and Bao, Wugedele and Wu, Chentao and Li, Jie},
  booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={1--5},
  year={2025},
  organization={IEEE}
}
```

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
- Architecture inspired by PFLlib
- PyTorch team for the excellent deep learning framework