Minor fixes to enable training on Apple Silicon (MPS)#97
Open
cmkobel wants to merge 3 commits into
Open
Conversation
Upstream unconditionally calls `torch.cuda.set_device(args.device_num)` at startup, which crashes on any host without a CUDA build of PyTorch (Apple Silicon, CPU-only Linux). Gate the call on `cuda.is_available()` and fall through to MPS, then CPU. The plain `--device cpu` / `--device mps` argparse path still works; this only fixes the startup crash, not the in-loop device selection.
… loop `torch.autograd.set_detect_anomaly(True)` was left enabled at the top of the metric-learning `train()` loop. Anomaly mode forces a device sync after every backward op so it can attribute NaNs to the producing operation. It is meant for debugging NaN gradients, not steady-state training. The cost is dramatic on MPS — where each sync round-trips through the unified-memory boundary — and non-trivial on CUDA. Removing it does not change the training math; verified by matching pretrain and metric losses bit-for-bit across small runs. Users who actively need anomaly detection can re-enable it locally.
… triplet) The `indices_counts` accumulator in the metric-learning loop indexes into three device tensors (`indices_mapped[0..2]`) element-by-element inside a Python loop. Each indexing operation forces a device->host sync, since the value is interpolated into an f-string key. For a typical mid-sized run this is ~1.25M syncs per epoch. Move each of the three tensors to host memory once via `.cpu().tolist()` and iterate in pure Python. The dictionary contents are bitwise identical; the data feeds a diagnostic triplet-counts CSV, not the training math. Impact on MPS: a microbenchmark of 200k synthetic triplets goes from 125 s to 37 ms — a 3000x+ speedup of this loop in isolation. Real SATURN runs that previously hung indefinitely on the first metric epoch now complete in minutes. CUDA also sees a smaller but real speedup.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Hi! Thanks for SATURN. We've been using it to align scRNA-seq across several plant species and it's been great.
While getting it running on an Apple Silicon laptop we hit a couple of small issues. None of these change training behaviour; they just remove some friction. Three small commits, one file (
train-saturn.py):Fall back to MPS/CPU when CUDA is unavailable. Upstream calls
torch.cuda.set_device()at startup unconditionally, which crashes on hosts without a CUDA build of PyTorch. Gated oncuda.is_available()with a fall-through to MPS, then CPU.Remove always-on
torch.autograd.set_detect_anomaly(True)in the metric-learning loop. Anomaly mode forces a device sync on every backward op. It's intended for debugging NaN gradients, but it's currently on by default.Vectorise the
indices_countsbookkeeping. The per-triplet loop indexes into three device tensors element-by-element to build a dict key.__format__on a 0-d device tensor calls.item(), forcing one device->host sync per triplet. Moving each tensor to host once via.cpu().tolist()and iterating in pure Python produces the same dict, just without the syncs.Validation
We ran the full pipeline twice with
--seed 0, before and after the patches, at two scales on two devices:In every run, every pretrain rank-loss, every metric loss, and every mined-triplet count match bit-for-bit between baseline and patched. The patches do not change training math.
The integration speedup is more modest than the patches' per-pattern impact suggests, because forward/backward compute dominates wall time at these scales. A microbenchmark of the
indices_countsloop in isolation shows ~50x on H200 and ~3000x on MPS, with fidelity controls (an explicit.item()loop matches the upstream f-string loop within 11%) and a scale sweep (per-triplet cost flat across 10k to 1M) confirming the bottleneck is the per-element device sync. So on workloads where the metric loop runs millions of triplets per epoch, or under parameter combinations that maximise triplet density relative to compute, the wall-clock impact will be larger than our table shows. We've observed real-data runs go from "doesn't finish in 90 minutes" to "completes in 2 minutes" on MPS at certain configurations, though we couldn't reproduce that regime cleanly with synthetic data.Notes
Happy to split the commits, adjust the comments, or drop any of the three if you'd prefer to take them separately. The MPS device fallback (commit 1) is the only one that changes runtime control flow on non-CUDA hosts; the other two are runtime-only cleanups.
Reproducer (synthetic data + microbench + integration test) is available on request. Happy to share the validation harness if it would help review.