This repository implements a spatial Graph Attention Network (GAT) to predict cell spatial coordinates using only molecular profiles (RNA + protein) from 10x Xenium outputs. The graph is constructed from expression similarity (K-NN) rather than physical proximity.
- Local subgraphs: one subgraph per cell (central + k neighbors). The model predicts the central cell’s coordinates while learning over its local neighborhood. Batching across many subgraphs enables scalable training and DDP.
- Xenium data ingestion via `spatialdata_io` and AnnData processing.
- Modality-aware Joint Encoder that splits RNA and protein features, then fuses the representations before the GAT.
- GAT stack with residual connections and dropout that produces 2D spatial coordinates.
- Local subgraph construction per cell: k neighbors by expression similarity, bidirectional star edges to the central node.
- Distributed Data Parallel (DDP) training using PyTorch with `torch.multiprocessing.spawn` and the NCCL backend.
- Comprehensive evaluation: MSE/MAE, R² per axis, Euclidean error distribution, prediction-vs-truth plots, and spatial error heatmaps.
- Caching of preprocessed subgraphs, scalers, and splits to accelerate reproducible runs.
- `main.py`: Entry point for the local subgraph pipeline with DDP training and evaluation.
- `data_preprocessing.py`: AnnData filtering and normalization; optional global K-NN graph creation.
- `data_preprocessing_subgraph.py`: Local subgraph builder per cell and train/val/test split creation.
- `model_joint_encoder.py`: Joint Encoder architectures (base/large) with optional cross-modal attention; GAT stacks with residuals and pooling.
- `model.py`: Simpler GAT baselines (base/large) without joint modality separation.
- `train_subgraph.py`: Trainer for batched local subgraphs; supports single-GPU and DDP; smoothing regularizer.
- `evaluate.py`: Metrics and plotting utilities for training curves, prediction scatter, spatial visualization, and error analysis.
- `requirements.txt`: Python dependencies.
- `data/`: Datasets (Xenium and other sources as laid out below).
- `cache_*`: Cached subgraphs, scalers, splits, and metadata.
- `results_*` and `results/`: Saved checkpoints, metrics, and plots.
The default dataset in `main.py` expects Xenium outputs:
`data/Xenium_V1_Human_Kidney_FFPE_Protein_updated_outs/`
Other folders present (not required for the default run):
- `data/Xenium_Prime_Human_Ovary_FF_outs/`
- `data/LiverDataReleaseTileDB/` (TileDB layout)
- `data/breast.zarr/` (Zarr layout)
- `data/GSE158013_RAW/` (scATAC-related inputs)
The pipeline uses `spatialdata_io.xenium(...)` with options to include gene and protein modalities and to represent cells as circles. The main AnnData table is available at `sdata.tables["table"]`. Features are filtered to keep only `var["feature_types"]` in {"Gene Expression", "Protein Expression"}.
Implemented in `data_preprocessing.py::preprocess_adata`:
- Gene features: Scanpy `normalize_total(target_sum=1e4)` followed by `log1p`.
- Protein features: StandardScaler z-score per feature.
- Sparse matrices: converted to dense up front for simplicity.
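The same math, sketched with NumPy and scikit-learn on toy data (the repository calls Scanpy's `normalize_total` and `log1p` directly; the manual per-cell normalization here is equivalent):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
genes = rng.poisson(5.0, size=(6, 10)).astype(float)   # toy gene counts
proteins = rng.normal(100.0, 20.0, size=(6, 3))        # toy protein intensities

# Gene features: equivalent of normalize_total(target_sum=1e4) + log1p —
# rescale each cell so its counts sum to 1e4, then log-transform.
row_sums = genes.sum(axis=1, keepdims=True)
genes_norm = np.log1p(genes / row_sums * 1e4)

# Protein features: per-feature z-score, as described above.
proteins_norm = StandardScaler().fit_transform(proteins)

print(np.allclose(proteins_norm.mean(axis=0), 0.0))  # True
```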
Local subgraph construction in `data_preprocessing_subgraph.py::build_local_subgraphs`:
- Similarity: scikit-learn `NearestNeighbors(n_neighbors=k+1, metric=cosine|euclidean, n_jobs=-1)` fit on normalized features.
- For each cell i, build nodes: `[central=i, neighbors=indices[i, 1:k+1]]`.
- Edges: bidirectional star from the central node (0) to each neighbor (1..k).
- Targets: normalized spatial coordinates for all nodes in the subgraph (supervision over all nodes), with `y_central` kept for compatibility.
- Normalization: StandardScaler fit on global spatial coordinates; inverse transform applied at evaluation if requested.
- Returns: `List[torch_geometric.data.Data]`, each with `x`, `edge_index`, `y`, `y_central`, `pos_original`, and `central_idx`.
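The construction above can be sketched with scikit-learn and NumPy on toy data (the repository wraps each subgraph in `torch_geometric.data.Data`; plain dicts are used here to keep the sketch dependency-light):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 20))          # normalized expression features
coords = rng.uniform(0, 100, (50, 2))  # ground-truth spatial coordinates

k = 5
nn = NearestNeighbors(n_neighbors=k + 1, metric="cosine", n_jobs=-1).fit(X)
_, indices = nn.kneighbors(X)          # column 0 is the cell itself

scaler = StandardScaler().fit(coords)  # global coordinate scaler
y_all = scaler.transform(coords)

subgraphs = []
for i in range(X.shape[0]):
    nodes = np.concatenate([[i], indices[i, 1:k + 1]])  # [central, k neighbors]
    # Bidirectional star edges in local indices: 0 <-> 1..k.
    src = np.concatenate([np.zeros(k, dtype=int), np.arange(1, k + 1)])
    dst = np.concatenate([np.arange(1, k + 1), np.zeros(k, dtype=int)])
    subgraphs.append({
        "x": X[nodes],
        "edge_index": np.stack([src, dst]),  # PyG (2, num_edges) convention
        "y": y_all[nodes],                   # supervision over all nodes
        "y_central": y_all[i],               # kept for compatibility
        "central_idx": 0,
    })

print(len(subgraphs), subgraphs[0]["edge_index"].shape)  # 50 (2, 10)
```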
Train/val/test split creation in `create_subgraph_splits`: random permutation with ratios (default 70/15/15) and a fixed seed.
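A minimal sketch of the split logic described above (the function name matches the repo; the signature and defaults here are illustrative):

```python
import numpy as np

def create_splits(n, ratios=(0.7, 0.15, 0.15), seed=42):
    """Random permutation split with a fixed seed, as in create_subgraph_splits."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(ratios[0] * n)
    n_val = int(ratios[1] * n)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]

train, val, test = create_splits(100)
print(len(train), len(val), len(test))  # 70 15 15
```

The fixed seed makes the partition reproducible across runs, which is why the splits can be cached and reloaded safely.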
Joint Encoder models (`model_joint_encoder.py`):
- Split input features into RNA (n_genes) and protein (n_proteins).
- Independent encoders with LayerNorm, ReLU, Dropout; configurable dimensions.
- Fusion MLP to produce a joint representation fed to GAT layers.
- GAT stack with residual projections to stabilize deep attention layers.
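A minimal sketch of the modality-split encoder, with illustrative hidden sizes and layer names (in the real `model_joint_encoder.py` the fused output feeds the GAT stack, which is omitted here):

```python
import torch
import torch.nn as nn

class JointEncoder(nn.Module):
    """Sketch: split input into RNA/protein blocks, encode separately, fuse.
    The joint representation would then go through GAT layers (not shown)."""
    def __init__(self, n_genes, n_proteins, hidden=64, joint=128, p=0.1):
        super().__init__()
        def encoder(d_in):  # LayerNorm, ReLU, Dropout as described above
            return nn.Sequential(
                nn.Linear(d_in, hidden), nn.LayerNorm(hidden),
                nn.ReLU(), nn.Dropout(p),
            )
        self.rna_enc = encoder(n_genes)
        self.prot_enc = encoder(n_proteins)
        self.fusion = nn.Sequential(  # fusion MLP -> joint representation
            nn.Linear(2 * hidden, joint), nn.LayerNorm(joint), nn.ReLU(),
        )
        self.n_genes = n_genes

    def forward(self, x):
        rna, prot = x[:, :self.n_genes], x[:, self.n_genes:]
        return self.fusion(torch.cat([self.rna_enc(rna), self.prot_enc(prot)], dim=-1))

enc = JointEncoder(n_genes=300, n_proteins=20)
out = enc(torch.randn(8, 320))  # 8 nodes, 300 gene + 20 protein features
print(out.shape)  # torch.Size([8, 128])
```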
Baseline GAT models (`model.py`): simpler two- or three-layer GAT with a final MLP; no modality separation.
Losses and regularizers (`train_subgraph.py`):
- Main loss: MSE on predicted vs normalized coordinates for all nodes in the batch.
- MAE tracked for reporting; ReduceLROnPlateau scheduler on validation loss.
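The objective can be illustrated with a stand-in model (the real trainer runs a GAT over batched subgraphs; the linear layer here is only a placeholder):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 2)  # placeholder for the GAT predicting 2D coordinates
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5)

x = torch.randn(32, 16)       # features for all nodes in the batch
target = torch.randn(32, 2)   # normalized coordinates for all nodes

pred = model(x)
loss = nn.functional.mse_loss(pred, target)   # main loss over all nodes
mae = nn.functional.l1_loss(pred, target)     # tracked for reporting only

loss.backward()
optimizer.step()
scheduler.step(loss.item())  # in the trainer, stepped on the validation loss
```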
The main entry `main.py` orchestrates:
- Resource setup: increase file descriptor limits and set PyTorch's sharing strategy to `file_system`.
- Load the Xenium dataset via `spatialdata_io.xenium(...)`; get the AnnData table.
- Preprocess features with gene/protein normalization.
- Cache handling: if cached `subgraphs.pt`, `scaler.pkl`, and `metadata.json` exist, reuse them; else build and cache.
- Derive `n_genes` and `n_proteins` from `adata.var["feature_types"]` counts.
- Create or load train/val/test splits; save to JSON for reproducibility.
- Determine `in_channels` and select an architecture; the default uses the Joint Encoder with parameters in `create_joint_encoder_model(...)`.
- Spawn DDP processes via `torch.multiprocessing.spawn` with `world_size=4`.
- In each rank:
  - Initialize `torch.distributed` with the NCCL backend and bind the rank to `cuda:rank`.
  - Load cached subgraphs and scaler locally.
  - Build a `SubgraphTrainer` with `DistributedSampler`s and a DDP-wrapped model.
  - Train with early stopping and the scheduler.
- On rank 0 only:
  - Save a checkpoint to `results/spatial_gat_model.pt`.
  - Plot the training history.
  - Predict on the test set; denormalize if requested.
  - Compute metrics; save `results/test_metrics.csv`.
  - Save plots: predictions vs truth, spatial predictions, error distribution; print an extreme-errors summary.
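The cache-handling step can be sketched as follows (file names come from the section above; the placeholder objects, metadata fields, and use of a temporary directory are illustrative):

```python
import json, os, pickle, tempfile
import torch

cache_dir = tempfile.mkdtemp()  # the repo uses a persistent cache_* directory
paths = {name: os.path.join(cache_dir, name)
         for name in ("subgraphs.pt", "scaler.pkl", "metadata.json")}

if all(os.path.exists(p) for p in paths.values()):
    # Reuse cached artifacts.
    subgraphs = torch.load(paths["subgraphs.pt"])
    with open(paths["scaler.pkl"], "rb") as f:
        scaler = pickle.load(f)
    with open(paths["metadata.json"]) as f:
        metadata = json.load(f)
else:
    # Build, then cache. Placeholders stand in for the real build step.
    subgraphs, scaler = [torch.randn(3)], {"mean": 0.0}
    metadata = {"n_genes": 300, "n_proteins": 20, "k": 5}  # illustrative fields
    torch.save(subgraphs, paths["subgraphs.pt"])
    with open(paths["scaler.pkl"], "wb") as f:
        pickle.dump(scaler, f)
    with open(paths["metadata.json"], "w") as f:
        json.dump(metadata, f)

print(sorted(os.listdir(cache_dir)))
```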
On successful training (rank 0):
- `results/spatial_gat_model.pt`: checkpoint with model state, optimizer state, and training history.
- `results/training_history.png`: loss and MAE curves for train/val.
- `results/predictions_vs_true.png`: scatter plots for X and Y predictions.
- `results/spatial_predictions.png`: spatial scatter of true vs predicted and an error heatmap.
- `results/error_distribution.png`: histogram and boxplot of Euclidean errors.
- `results/test_metrics.csv`: MSE, MAE, R² per axis, and Euclidean distance statistics.
- Normalization: the coordinate scaler is fit globally; the model trains in normalized space and predictions are denormalized for reporting.
- Reproducibility: splits and metadata are persisted to cache; seeds applied in split creation.
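The normalization round-trip can be sketched as follows (toy coordinates; `pred_norm` stands in for model output):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

coords = np.random.default_rng(0).uniform(0, 100, (200, 2))
scaler = StandardScaler().fit(coords)   # fit once on global coordinates

y = scaler.transform(coords)            # model trains against normalized targets
pred_norm = y + 0.01                    # stand-in for model predictions
pred_um = scaler.inverse_transform(pred_norm)  # back to original units for reporting

print(np.allclose(scaler.inverse_transform(y), coords))  # True
```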