A comprehensive study of Graph Neural Network architectures for node-level regression tasks on the Wiki-Squirrel dataset
This repository contains the implementation and experimental results from a BSc research project exploring the application of various Graph Neural Network (GNN) architectures to predict continuous node values in graph-structured data. Unlike the more common node classification task, this work focuses on node regression, a significantly underexplored area in GNN research.
- Novel Application: One of the first comprehensive studies applying GNNs to the node regression task (on Wikipedia article networks)
- Multiple Architectures: Implementation and comparison of 4 state-of-the-art GNN models (GAT, GATv2, GCN, GraphSAGE)
- Real-World Data: Experiments on 3 Wikipedia page-page networks with continuous traffic prediction targets
- Reproducible Research: Complete pipeline from data preprocessing to model evaluation
- Production-Ready Code: Clean, modular implementation with comprehensive documentation
Node regression in graphs aims to predict continuous values for each node based on:
- Node features (informative nouns from Wikipedia article text)
- Graph structure (mutual hyperlinks between articles)
- Neighborhood information
Task: Predict average monthly traffic for Wikipedia articles (Oct 2017 - Nov 2018)
Challenges:
- Limited prior work on GNN-based node regression
- Handling heterogeneous graph structures
- Balancing local and global graph information
- Dealing with outliers in continuous target values
We use three page-page networks from the Multi-Scale Attributed Node Embedding dataset:
| Dataset | Nodes | Edges | Density | Transitivity | Topic |
|---|---|---|---|---|---|
| Chameleon | 2,277 | 31,421 | 0.012 | 0.314 | Chameleons |
| Squirrel | 5,201 | 198,493 | 0.015 | 0.348 | Squirrels |
| Crocodile | 11,631 | 170,918 | 0.003 | 0.026 | Crocodiles |
Node Features: Binary vectors indicating presence of informative nouns in article text
Target Variable: Average monthly page views (continuous value)
Edge Type: Undirected mutual hyperlinks between Wikipedia articles
data/wikipedia/
├── chameleon/
│   ├── musae_chameleon_edges.csv     # Edge list (id1, id2)
│   ├── musae_chameleon_features.json # Node features (dict of lists)
│   └── musae_chameleon_target.csv    # Target values (id, target)
├── squirrel/
│   └── ...
└── crocodile/
    └── ...
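Assuming the column headers noted above (`id1,id2` for edges, `id,target` for targets), the three per-dataset files can be read with the Python standard library alone. A minimal loading sketch (function names are illustrative, not from the repository):

```python
import csv
import json

def load_edges(path):
    """Read the (id1, id2) edge list from a musae_*_edges.csv file."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return [(int(a), int(b)) for a, b in reader]

def load_features(path):
    """Read the {node_id: [noun ids]} dict from a musae_*_features.json file."""
    with open(path) as f:
        raw = json.load(f)
    return {int(node): feats for node, feats in raw.items()}

def load_targets(path):
    """Read the (id, target) traffic values from a musae_*_target.csv file."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return {int(node): float(target) for node, target in reader}
```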
- Utilizes attention mechanisms to weight neighbor contributions
- Multi-head attention for capturing diverse graph patterns
- Architecture: 2 GAT layers (8 attention heads each) + fully connected output
Input → GAT(in_dim, 8, heads=8) → ReLU → GAT(64, 8, heads=8) → ReLU → FC(64, 1) → Output
- Enhanced attention mechanism with dynamic attention computation
- Addresses limitations of static attention in GAT
- Architecture: 2 GATv2 layers + direct regression output
Input → GATv2(in_dim, 8, heads=8) → ReLU → GATv2(64, 8, heads=8) → ReLU → Conv(64, 1) → Output
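The static-vs-dynamic distinction between GAT and GATv2 comes down to where the nonlinearity sits in the attention scoring function. A minimal pure-PyTorch sketch of the two scoring rules (dimensions and weights are illustrative; the actual models use DGL's attention layers):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, out = 4, 4                 # illustrative feature dimensions
W = torch.randn(out, d)       # shared node transform (GAT form)
a = torch.randn(2 * out)      # attention vector over the concatenated pair

def gat_score(h_i, h_j):
    # Static attention (GAT): e_ij = LeakyReLU(a . [W h_i || W h_j]).
    # The attention vector acts before the nonlinearity, so neighbor
    # rankings collapse to a fixed ordering independent of node i.
    z = torch.cat([W @ h_i, W @ h_j])
    return F.leaky_relu(a @ z, negative_slope=0.2)

W2 = torch.randn(2 * out, 2 * d)  # joint transform (GATv2 form)
a2 = torch.randn(2 * out)

def gatv2_score(h_i, h_j):
    # Dynamic attention (GATv2): e_ij = a . LeakyReLU(W [h_i || h_j]).
    # The nonlinearity sits between the two learned maps, so scores
    # can genuinely depend on the pairing of i and j.
    z = F.leaky_relu(W2 @ torch.cat([h_i, h_j]), negative_slope=0.2)
    return a2 @ z
```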
- Spectral-based graph convolutions
- Efficient neighborhood aggregation
- Architecture: 2 GCN layers + linear regression head
Input → GCN(in_dim, 16) → ReLU → Dropout(0.5) → GCN(16, 16) → FC(16, 1) → Output
- Sampling-based neighborhood aggregation
- Scalable to large graphs
- Architecture: 3 SAGE layers with mean aggregation
Input → SAGE(in_dim, 16, 'mean') → ReLU → SAGE(16, 16, 'mean') → ReLU → SAGE(16, 1, 'mean') → Output
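For illustration, a single mean-aggregation SAGE layer can be written in a few lines of plain PyTorch over a dense adjacency matrix; this is a conceptual sketch only, as the repository's models are built on DGL's convolution modules:

```python
import torch
import torch.nn as nn

class MeanSAGELayer(nn.Module):
    """One GraphSAGE layer with mean aggregation (illustrative,
    pure PyTorch; not the repository's DGL-based implementation)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin_self = nn.Linear(in_dim, out_dim)   # transform of the node itself
        self.lin_neigh = nn.Linear(in_dim, out_dim)  # transform of the neighbor mean

    def forward(self, adj, h):
        # adj: (N, N) dense 0/1 adjacency matrix; h: (N, in_dim) node features.
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # avoid divide-by-zero
        h_neigh = (adj @ h) / deg                        # mean over neighbors
        return torch.relu(self.lin_self(h) + self.lin_neigh(h_neigh))
```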
| Model | Parameters | Attention | Aggregation | Best For |
|---|---|---|---|---|
| GAT | ~50K | Multi-head | Weighted | Capturing node importance |
| GATv2 | ~50K | Dynamic | Weighted | Complex attention patterns |
| GCN | ~25K | None | Mean | Efficient spectral learning |
| GraphSAGE | ~25K | None | Mean/Max/LSTM | Large-scale graphs |
# IQR-based outlier detection
Q1 = target_df['target'].quantile(0.25)
Q3 = target_df['target'].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
- One-Hot Encoding: Convert sparse feature IDs to binary vectors
- Normalization: Min-max scaling of target values to [0, 1]
- Graph Construction: Self-loops added for better feature aggregation
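The normalization and one-hot encoding steps above can be sketched with NumPy (function names are illustrative, not from the repository):

```python
import numpy as np

def minmax_normalize(targets):
    """Min-max scale continuous targets into [0, 1]."""
    t = np.asarray(targets, dtype=float)
    return (t - t.min()) / (t.max() - t.min())

def one_hot_features(feature_dict, num_features):
    """Turn a {node: [noun ids]} dict into a dense binary (N, F) matrix."""
    mat = np.zeros((len(feature_dict), num_features), dtype=np.float32)
    for node, feats in feature_dict.items():
        mat[node, feats] = 1.0  # mark each informative noun present in the article
    return mat
```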
| Hyperparameter | Value | Description |
|---|---|---|
| Optimizer | Adam | Adaptive learning rate |
| Learning Rate | 0.005 | Consistent across all models |
| Loss Function | MSE | Mean Squared Error |
| Epochs | 500 | With early stopping |
| Train/Val/Test Split | 60/20/20 | Stratified random split |
| Dropout | 0.5 (GCN) | Regularization |
| Attention Dropout | 0.6 (GAT/GATv2) | Attention regularization |
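The early-stopping rule in the table can be sketched as a simple patience counter; the patience value below is an assumption for illustration, not a number taken from the thesis:

```python
class EarlyStopping:
    """Stop training once validation loss fails to improve for
    `patience` consecutive epochs (patience value is illustrative)."""

    def __init__(self, patience=50):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        # Returns True when training should stop.
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```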
- MSE (Mean Squared Error): Primary metric for optimization
- RMSE (Root Mean Squared Error): Interpretable error magnitude
- MAE (Mean Absolute Error): Robust to outliers
- Training Time: Per-epoch computation time
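The three error metrics can be computed directly in PyTorch; the dependency list includes torchmetrics, which offers equivalent classes:

```python
import torch

def regression_metrics(pred, target):
    """Return MSE, RMSE and MAE for a batch of predictions."""
    err = pred - target
    mse = (err ** 2).mean()
    return {
        "mse": mse.item(),
        "rmse": mse.sqrt().item(),   # same units as the (normalized) target
        "mae": err.abs().mean().item(),
    }
```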
| Model | Test MSE ↓ | Test RMSE ↓ | Best Epoch | Parameters |
|---|---|---|---|---|
| GATv2 | 0.0143 | 0.1196 | 487 | ~50K |
| GAT | 0.0151 | 0.1229 | 465 | ~50K |
| GCN | 0.0167 | 0.1292 | 423 | ~25K |
| GraphSAGE | 0.0182 | 0.1349 | 401 | ~25K |
| Model | Test MSE ↓ | Test RMSE ↓ | Notes |
|---|---|---|---|
| GATv2 | 0.0156 | 0.1249 | Best overall |
| GAT | 0.0168 | 0.1296 | Close second |
| GCN | 0.0189 | 0.1375 | Good efficiency |
| GraphSAGE | 0.0201 | 0.1418 | Scalable |
| Model | Test MSE ↓ | Test RMSE ↓ | Notes |
|---|---|---|---|
| GATv2 | 0.0134 | 0.1158 | Best performance |
| GAT | 0.0145 | 0.1204 | Strong baseline |
| GCN | 0.0171 | 0.1308 | Efficient |
| GraphSAGE | 0.0186 | 0.1364 | Large-scale capable |
- Attention Mechanisms Superior: GAT and GATv2 consistently outperform convolution-based methods
- GATv2 Dominates: Dynamic attention provides 5-8% improvement over static GAT
- Dataset-Dependent Performance: Model effectiveness varies with graph density and transitivity
- Trade-off: Attention models have 2× the parameters but achieve significantly better accuracy
Test MSE Comparison (Chameleon, lower is better)
────────────────────────────────────────────────
GATv2     ████████████████       0.0143
GAT       █████████████████      0.0151
GCN       ██████████████████     0.0167
GraphSAGE ████████████████████   0.0182
.
├── data/
│   └── wikipedia/
│       ├── README.txt
│       ├── citing.txt
│       ├── chameleon/
│       │   ├── musae_chameleon_edges.csv
│       │   ├── musae_chameleon_features.json
│       │   └── musae_chameleon_target.csv
│       ├── squirrel/
│       │   └── ...
│       └── crocodile/
│           └── ...
├── src/
│   ├── main.py     # Main training pipeline
│   └── models.py   # GNN model implementations
├── docs/
│   ├── Amirmehdi Zarrinnezhad_9731087_BSc_Project_Thesis.pdf
│   └── Amirmehdi Zarrinnezhad_9731087_BSc_Project_Presentation.pdf
├── LICENSE
└── README.md
- Python 3.8+
- CUDA-capable GPU (optional, but recommended)
- Clone the repository
git clone https://github.com/zamirmehdi/GNN-Node-Regression.git
cd GNN-Node-Regression
- Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
- Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html
pip install pandas numpy networkx matplotlib torchmetrics
Requirements:
torch>=2.0.0
dgl>=1.0.0
pandas>=1.5.0
numpy>=1.23.0
networkx>=3.0
matplotlib>=3.5.0
torchmetrics>=0.11.0
# Train all models on Chameleon dataset
python src/main.py
# Edit dataset selection in main.py
dataset_name = 'chameleon' # Options: 'chameleon', 'squirrel', 'crocodile'
# Run specific model
run_model(gnn='GAT', graph=graph, graph_details=graph_details,
          hidden_dim=8, num_heads=8)
The training process includes:
- Data Loading: Read edges, features, and targets
- Preprocessing:
- Outlier removal using IQR method
- Min-max normalization of targets (0-1)
- One-hot encoding of features
- Graph Construction: Build DGL graph with features and masks
- Model Training: 500 epochs with early stopping
- Evaluation: MSE, RMSE, MAE on test set
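The mask construction in step 3 can be sketched as a plain random 60/20/20 split over node indices; note that the thesis describes a stratified split, which this minimal version omits:

```python
import torch

def split_masks(num_nodes, train=0.6, val=0.2, seed=0):
    """Random boolean train/val/test masks over nodes (60/20/20 by default).
    Illustrative: a plain random permutation, not the stratified split."""
    g = torch.Generator().manual_seed(seed)  # reproducible shuffle
    perm = torch.randperm(num_nodes, generator=g)
    n_train = int(train * num_nodes)
    n_val = int(val * num_nodes)
    masks = {}
    for name, idx in [("train", perm[:n_train]),
                      ("val", perm[n_train:n_train + n_val]),
                      ("test", perm[n_train + n_val:])]:
        m = torch.zeros(num_nodes, dtype=torch.bool)
        m[idx] = True
        masks[name] = m
    return masks
```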
import torch
import torch.nn as nn

from models import GATv2NodeRegression
# Initialize model
model = GATv2NodeRegression(
    in_feats=num_features,
    hidden_feats=16,
    num_heads=8,
    output_dim=1
)
# Custom training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
criterion = nn.MSELoss()
for epoch in range(num_epochs):
    model.train()
    predictions = model(graph, features).squeeze()
    loss = criterion(predictions[train_mask], targets[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Grid search over hyperparameters
hidden_dims = [8, 16, 32]
num_heads = [4, 8, 16]
learning_rates = [0.001, 0.005, 0.01]
for h_dim in hidden_dims:
    for n_heads in num_heads:
        for lr in learning_rates:
            run_model(gnn='GATv2', hidden_dim=h_dim,
                      num_heads=n_heads, learning_rate=lr)
The complete research methodology, theoretical background, and detailed analysis are available in:
- BSc Thesis PDF (Persian)
- Presentation Slides (Persian)
- Chapter 4: Dataset selection, preprocessing, and preparation
- Chapter 5: Model architectures and implementation details
- Chapter 6: Experimental results and comparative analysis
If you use this code or dataset in your research, please cite:
@thesis{zarrinnezhad2023gnn,
  title={Comparative Analysis of Graph Neural Networks for Node Regression on Wiki-Squirrel dataset},
  author={Zarrinnezhad, Amirmehdi},
  year={2023},
  type={BSc Thesis},
  school={Amirkabir University of Technology}
}
@misc{rozemberczki2019multiscale,
  title={Multi-scale Attributed Node Embedding},
  author={Benedek Rozemberczki and Carl Allen and Rik Sarkar},
  year={2019},
  eprint={1909.13021},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
Contributions are welcome! Please feel free to submit a Pull Request. For major changes:
- Fork the repository
- Create your feature branch (git checkout -b feature/AmazingFeature)
- Commit your changes (git commit -m 'Add some AmazingFeature')
- Push to the branch (git push origin feature/AmazingFeature)
- Open a Pull Request
This project is licensed under the MIT License - see the LICENSE file for details.
- Extend to other Wikipedia language editions
- Implement additional GNN architectures (GAE, GraphTransformer)
- Multi-task learning (regression + classification)
- Temporal analysis of traffic patterns
- Deployment as REST API for real-time predictions
- Integration with Wikipedia API for live data
Author: Amirmehdi Zarrinnezhad
Project: Comparative analysis of Graph Neural Networks for Node Regression on Wiki-Squirrel dataset
Dataset: MUSAE Wikipedia Networks by Benedek Rozemberczki et al.
Frameworks: DGL (Deep Graph Library), PyTorch
Language: English (README), Persian (Instruction and Report PDFs)
University: Amirkabir University of Technology (Tehran Polytechnic) - 2023
Supervisor: Prof. Mostafa H. Chehreghani
GitHub Link: GNN Node Regression
Questions or collaborations? Feel free to reach out!
📧 Email: amzarrinnezhad@gmail.com
💬 Open an Issue
🔗 GitHub: @zamirmehdi
⭐ If you found this project helpful, please consider giving it a star! ⭐
Amirmehdi Zarrinnezhad