
compute_fake_perturbation_tests memory grows ~15 GB per iteration -> OOM at iter ~16/50 on real datasets #3

@adamklie


Summary

Even with the args.reference_targets fix (filed separately, see #N), compute_fake_perturbation_tests exhausts memory on real-world datasets after ~16 of 50 iterations, hitting the 256 GB SLURM allocation.

Reproduction

Run the U-test calibration on a Huangfu HUES8 cNMF h5mu (~270k cells × ~36k genes, sparse; ~14k guides; 600 non-targeting (NT) guides) with the standard params:

```
--number_run 50 --number_guide 6 \
--components 30 50 60 80 100 200 250 300 \
--sel_thresh 2.0 \
--compute_fake_perturbation_tests
```

Both runs (SLURM jobs ESC=10555425 and DE=10555432) crashed at iteration 16/50 of K=30 with exit code 9. SLURM MaxRSS:

ESC: 261,868,524K ≈ 250 GB
DE:  264,975,876K ≈ 253 GB
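SLURM's `K` suffix denotes KiB, so the MaxRSS figures convert to GiB by dividing by 1024². A small helper for that conversion, as a sketch: `maxrss_to_gib` is a hypothetical name, and it assumes every value carries a K/M/G/T suffix with binary multiples, as in the sacct output above.

```python
_BINARY_UNITS = {"K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}


def maxrss_to_gib(maxrss: str) -> float:
    """Convert a sacct MaxRSS value like '261,868,524K' to GiB.

    Assumes a K/M/G/T suffix with binary multiples (K = 1024 bytes).
    """
    value = maxrss.strip().replace(",", "")  # tolerate thousands separators
    suffix = value[-1:].upper()
    if suffix not in _BINARY_UNITS:
        raise ValueError(f"expected a K/M/G/T suffix, got {maxrss!r}")
    return float(value[:-1]) * _BINARY_UNITS[suffix] / 1024**3


for job, rss in [("ESC", "261,868,524K"), ("DE", "264,975,876K")]:
    print(f"{job}: {maxrss_to_gib(rss):.1f} GiB")
# ESC: 249.7 GiB
# DE: 252.7 GiB
```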

Root cause hypothesis

In compute_fake_perturbation_tests (lines 121–164):

```python
for k in args.components:
    mdata = mu.read(...)                # ~10 GB sparse for our dataset
    _assign_guide(mdata, mdata_guide)
    for i in range(args.number_run):    # 50 iterations
        _mdata = mdata.copy()           # ← deep-copies the full mdata each iter
        _mdata[args.prog_key].obsm[args.guide_assignment_key] = mdata[...][:, non_targeting_idx]
        ...
        for samp in unique:
            mdata_samp = _mdata[mask]
            test_stats_df = compute_perturbation_association(mdata_samp, ...)
            ...
```

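Our hypothesis: the full mdata.copy() made in every iteration leaves behind references that Python's garbage collector does not reclaim before the next iteration allocates another copy. At ~15 GB residual per iteration, 16 iterations × 15 GB ≈ 240 GB, plus the ~10 GB base mdata, gives ≈ 250 GB, in line with the MaxRSS observed at OOM time.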
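A stdlib-only illustration of one way this can happen. This is an assumption: we have not confirmed that mudata/anndata copies actually form reference cycles. An object caught in a reference cycle is not freed when its last name is deleted; only CPython's cyclic collector reclaims it, and that collector is triggered by counts of allocated container objects rather than by bytes, so a few cycles owning multi-GB buffers can outlive several iterations. This is the behavior option (2) below relies on. `LargeNode` is a hypothetical stand-in for a large copied object.

```python
import gc
import weakref


class LargeNode:
    """Hypothetical stand-in for a large copied object (e.g. a copied modality)."""

    def __init__(self, n_bytes: int) -> None:
        self.buffer = bytearray(n_bytes)  # simulated large allocation
        self.self_ref = self  # self-reference creates a reference cycle


def cycle_survives_del(n_bytes: int = 10 * 1024**2) -> tuple[bool, bool]:
    """Return (alive after `del`, alive after `gc.collect()`) for a cyclic object."""
    was_enabled = gc.isenabled()
    gc.disable()  # keep the demo deterministic: no automatic collection mid-test
    try:
        obj = LargeNode(n_bytes)
        probe = weakref.ref(obj)  # observe the object without keeping it alive
        del obj  # drops the name, but the cycle keeps the refcount above zero
        alive_after_del = probe() is not None
        gc.collect()  # the cyclic collector finds and frees the unreachable cycle
        alive_after_collect = probe() is not None
        return alive_after_del, alive_after_collect
    finally:
        if was_enabled:
            gc.enable()


print(cycle_survives_del())  # (True, False)
```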

Workaround we used

We dropped --number_run from 50 to 10. Peak memory is ~150 GB, which fits in the 256 GB allocation. The null distribution is less resolved but still meaningful for QC.

Possible fixes

  1. Avoid the deep copy. The fake test only mutates _mdata[prog_key].obsm[guide_assignment_key] and two uns arrays; it never touches the rna modality, which is what makes the copy expensive (~10 GB sparse). A surgical fix would mutate mdata[prog_key] in place each iteration and restore it at the end, skipping mdata.copy() entirely.

  2. Force release between iterations: add del _mdata, mdata_samp and gc.collect() at the end of each iteration. See oom_remedy.diff for a small bandaid patch. This doesn't fix the underlying redundant deep copy but should keep peak memory bounded.

  3. Process-level isolation: run each iteration in a subprocess so the OS reclaims its memory cleanly when the process exits. Heavier, but bullet-proof.

(1) is the right structural fix. (2) is a small bandaid that would let the current 50-iteration default work in 256 GB.
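The in-place approach in (1) could look like the following minimal sketch. It assumes `obsm` and `uns` on the program modality behave like mutable mappings (membership test, item get/set, `del`); `temporarily_set` is a hypothetical helper, and the demo uses plain dicts with placeholder keys so it runs without mudata. The fake assignment should still be computed from the unmodified `mdata` before entering the block.

```python
from contextlib import contextmanager
from typing import Any, Iterator, MutableMapping

_MISSING = object()  # sentinel: distinguishes "key absent" from "key maps to None"


@contextmanager
def temporarily_set(mapping: MutableMapping[str, Any], key: str, value: Any) -> Iterator[None]:
    """Set mapping[key] = value inside the block; restore the previous state on exit, even on error."""
    original = mapping[key] if key in mapping else _MISSING
    mapping[key] = value
    try:
        yield
    finally:
        if original is _MISSING:
            del mapping[key]  # key did not exist before: remove it again
        else:
            mapping[key] = original  # put the real value back


# Demo with plain dicts standing in for prog.obsm / prog.uns (keys are placeholders).
obsm = {"guide_assignment": "real"}
uns = {}
with temporarily_set(obsm, "guide_assignment", "fake"), temporarily_set(uns, "guide_names", ["NT_1", "NT_2"]):
    print(obsm["guide_assignment"], uns["guide_names"])  # fake ['NT_1', 'NT_2']
print(obsm, uns)  # {'guide_assignment': 'real'} {}

# Hypothetical use inside the iteration loop, replacing `_mdata = mdata.copy()`:
# prog = mdata[args.prog_key]
# with temporarily_set(prog.obsm, args.guide_assignment_key, fake_assignment):
#     ...  # per-sample tests on mdata[mask] views
```

Because the restore runs in `finally`, an exception in one iteration cannot leave the real guide assignment overwritten.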

Why this wasn't caught earlier

Same answer as the related args.reference_targets issue: PerturbNMF is a publishable rewrite of Stanford's older cNMF_benchmarking tool. Pre-existing fake-test outputs at the Engreitz lab come from the older tool, not this new code path. We seem to be the first to drive the new PerturbNMF U-test code end-to-end on a real-sized dataset, so this leak is surfacing now rather than during the rewrite.

Environment

  • PerturbNMF main @ 8f7c9dd (with the line 160 args.reference_targets fix applied locally on a branch)
  • Python 3.10, mudata 0.4.x
  • Carter HPC (UCSD), 256 GB SLURM allocation per job
  • Real dataset: Huangfu HUES8 endoderm differentiation, 8 K values × sel_thresh=2.0
