Skip to content

Minor fixes to enable training on Apple Silicon (MPS)#97

Open
cmkobel wants to merge 3 commits into
snap-stanford:mainfrom
cmkobel:mps-speedup
Open

Minor fixes to enable training on Apple Silicon (MPS)#97
cmkobel wants to merge 3 commits into
snap-stanford:mainfrom
cmkobel:mps-speedup

Conversation

@cmkobel

@cmkobel cmkobel commented Jun 3, 2026

Copy link
Copy Markdown

Hi! Thanks for SATURN. We've been using it to align scRNA-seq across several plant species and it's been great.

While getting it running on an Apple Silicon laptop we hit a couple of small issues. None of these change training behaviour; they just remove some friction. Three small commits, one file (train-saturn.py):

  1. Fall back to MPS/CPU when CUDA is unavailable. Upstream calls torch.cuda.set_device() at startup unconditionally, which crashes on hosts without a CUDA build of PyTorch. Gated on cuda.is_available() with a fall-through to MPS, then CPU.

  2. Remove always-on torch.autograd.set_detect_anomaly(True) in the metric-learning loop. Anomaly mode forces a device sync on every backward op. It's intended for debugging NaN gradients, but it's currently on by default.

  3. Vectorise the indices_counts bookkeeping. The per-triplet loop indexes into three device tensors element-by-element to build a dict key. __format__ on a 0-d device tensor calls .item(), forcing one device->host sync per triplet. Moving each tensor to host once via .cpu().tolist() and iterating in pure Python produces the same dict, just without the syncs.

Validation

We ran the full pipeline twice with --seed 0, before and after the patches, at two scales on two devices:

device scale baseline patched speedup
NVIDIA H200 5k cells/species, 10 types, 3 epochs 66.8 s 26.2 s 2.55x
NVIDIA H200 15k cells/species, 15 types, 5 epochs 91.3 s 64.2 s 1.42x
Apple M5 (MPS) 5k cells/species, 10 types, 3 epochs 232.6 s 173.8 s 1.34x
Apple M5 (MPS) 15k cells/species, 15 types, 5 epochs 999.7 s 760.6 s 1.31x

In every run, every pretrain rank-loss, every metric loss, and every mined-triplet count match bit-for-bit between baseline and patched. The patches do not change training math.

The integration speedup is more modest than the patches' per-pattern impact suggests, because forward/backward compute dominates wall time at these scales. A microbenchmark of the indices_counts loop in isolation shows ~50x on H200 and ~3000x on MPS, with fidelity controls (an explicit .item() loop matches the upstream f-string loop within 11%) and a scale sweep (per-triplet cost flat across 10k to 1M) confirming the bottleneck is the per-element device sync. So on workloads where the metric loop runs millions of triplets per epoch, or under parameter combinations that maximise triplet density relative to compute, the wall-clock impact will be larger than our table shows. We've observed real-data runs go from "doesn't finish in 90 minutes" to "completes in 2 minutes" on MPS at certain configurations, though we couldn't reproduce that regime cleanly with synthetic data.

Notes

Happy to split the commits, adjust the comments, or drop any of the three if you'd prefer to take them separately. The MPS device fallback (commit 1) is the only one that changes runtime control flow on non-CUDA hosts; the other two are runtime-only cleanups.

Reproducer (synthetic data + microbench + integration test) is available on request. Happy to share the validation harness if it would help review.

cmkobel added 3 commits June 3, 2026 09:55
Upstream unconditionally calls `torch.cuda.set_device(args.device_num)`
at startup, which crashes on any host without a CUDA build of PyTorch
(Apple Silicon, CPU-only Linux). Gate the call on `cuda.is_available()`
and fall through to MPS, then CPU.

The plain `--device cpu` / `--device mps` argparse path still works;
this only fixes the startup crash, not the in-loop device selection.
… loop

`torch.autograd.set_detect_anomaly(True)` was left enabled at the top of
the metric-learning `train()` loop. Anomaly mode forces a device sync
after every backward op so it can attribute NaNs to the producing
operation. It is meant for debugging NaN gradients, not steady-state
training.

The cost is dramatic on MPS — where each sync round-trips through the
unified-memory boundary — and non-trivial on CUDA. Removing it does not
change the training math; verified by matching pretrain and metric
losses bit-for-bit across small runs.

Users who actively need anomaly detection can re-enable it locally.
… triplet)

The `indices_counts` accumulator in the metric-learning loop indexes
into three device tensors (`indices_mapped[0..2]`) element-by-element
inside a Python loop. Each indexing operation forces a device->host
sync, since the value is interpolated into an f-string key. For a
typical mid-sized run this is ~1.25M syncs per epoch.

Move each of the three tensors to host memory once via `.cpu().tolist()`
and iterate in pure Python. The dictionary contents are bitwise
identical; the data feeds a diagnostic triplet-counts CSV, not the
training math.

Impact on MPS: a microbenchmark of 200k synthetic triplets goes from
125 s to 37 ms — a 3000x+ speedup of this loop in isolation. Real
SATURN runs that previously hung indefinitely on the first metric
epoch now complete in minutes. CUDA also sees a smaller but real
speedup.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant