Skip to content

Modular JAX optimizers + multi-GPU sharding#304

Open
rohandahale wants to merge 2 commits into
dev-backendfrom
feature/jax-optimizers-sharding
Open

Modular JAX optimizers + multi-GPU sharding#304
rohandahale wants to merge 2 commits into
dev-backendfrom
feature/jax-optimizers-sharding

Conversation

@rohandahale

@rohandahale rohandahale commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

Modular JAX optimizers + multi-GPU sharding for the imaging objective, on top of the JAX objective (#295). Both are opt-in; the default make_image path is unchanged (scipy L-BFGS-B, bit-for-bit).

Optimizer layer

make_image(optimizer=...) now accepts:

  • None (default) — scipy L-BFGS-B, exactly as before
  • 'optax-lbfgs' / 'optax-lbfgs-bt' (backtracking line search) / 'adam' / 'adamw' / 'sgd' / 'rmsprop', or any optax.GradientTransformation — an on-device lax.while_loop (value_and_grad stays on device, no per-step host sync)
  • any callable — escape hatch, given the host value_and_grad

New ehtim/imaging/optimizers.py. optax is an optional, lazily-imported dependency (import ehtim stays optax-free).

Multi-GPU sharding

make_image(shard=True, shard_axis='baseline'|'frequency', mesh=None) splits the data-fidelity (chi²) term across the local GPUs; the image and regularizers stay replicated. Off by default (single device).

  • baseline — shard the visibility/baseline axis (direct + nfft)
  • frequency — shard the channel axis (multifrequency)

The sharded objective is bit-for-bit the single-device one: the data axis is padded to a mesh multiple with finite-sample, infinite-sigma rows (exact-zero contribution, no closure NaN), and each chi² is corrected by pad_len/true_len. Validated sharded == single-device == numpy (gradient rel ~1e-16) for Stokes I (vis/amp/cphase/logcamp), polarimetric (IP), and multifrequency, on both transforms.

nfft caveat: jax_finufft's nufft2 forward shards cleanly, but its registered transpose is wrong under shard_map (minimal repro filed upstream). We wrap the grid→samples map in a custom_vjp whose backward is an explicit forward nufft1 + psum — the exact type-1 adjoint, correct on one or many devices.

GPU parameter surveys

ehtim/imaging/survey_gpu.pyrun_survey_gpu(imgr, weight_grid=, regparam_grid=, prior_fwhm=, sys_noise=, ...) reconstructs a whole hyperparameter grid in one vmapped on-device optimization (the GPU counterpart to the CPU paramsurvey tool in ehtim/survey.py, which is untouched). Two axis kinds:

  • scalar axes (data/reg-term weights, RegParams scalars) — trace through the objective, so the whole sub-grid runs as one jax.vmap'd optimization;
  • outer axes (prior FWHM, fractional systematic noise) — rebuild the prior / re-derive the data sigmas on the host, looped and split across GPUs.

Returns the image, objective, per-row grid record, and per-term reduced chi² for each reconstruction (for ranking the Top Set).

Backtracking line search (optax-lbfgs-bt). Profiling found the survey bottleneck is not FP64 (only ~2.5× slower than FP32 on these GPUs) nor the forward model (value_and_grad vmap-parallelizes ~44× from batch 1→625). It is the optax zoom line-search under vmap: its bracket+zoom while_loop runs to the batch worst-case element every iteration (~20 value_and_grad/iter across a 625-config batch). A backtracking (Armijo) line search — a few value evals per step, keeping L-BFGS curvature — is ~7× faster at matched convergence, and is the survey default.

M87 demonstration (real EHT 2019-D01-01 data, FP32): the Paper IV survey grid (5⁴ reg weights × 3 prior-FWHM × 4 systematic-noise = 7,500 reconstructions) runs in ~17 min on 4 GPUs (~7.4 reconstructions/s), recovering the asymmetric ring across the Top Set; best fit χ²(cphase / logcamp) = 2.46 / 0.95, flux 0.60 Jy. A denser 8⁴-reg run — 49,152 reconstructions, exceeding Paper IV's full 37,500-combination eht-imaging survey — completes in 113 min on 4 GPUs; a single CPU reconstruction (numpy/scipy, measured via SLURM) is ~16 s, so the same survey is ~9 days sequential on CPU — a ~115× speedup.

Benchmarks (4× RTX PRO 6000 Blackwell, x64)

Optimizer layer — Gaussian reconstruction (direct, 1600 px, 200 iters):

  optimizer        wall s   nxcorr   chi2_vis
  scipy-lbfgsb        7.3    0.985    1.332
  optax-lbfgs         2.7    0.982    1.405
  adam                1.1    0.902  226.020

optax-lbfgs reaches the same minimum as scipy at 2.7× the speed (on-device loop, no per-step host sync). adam drops in as an arbitrary optax optimizer (first-order — fast but needs more iterations to converge chi²).

Sharding — single value_and_grad:

  baseline,  direct, Nvis=2592, Npix=1600:  CPU 9.45 ms  | 1 GPU 0.18 ms (52×  CPU) | 2-4 GPU 0.3-0.4 ms
  baseline,  nfft,   Nvis=2070, Npix=1024:  CPU 37.8 ms  | 1 GPU 4.4  ms (8.6× CPU) | 4 GPU 5.6 ms | matrix-free
  frequency, 8 channels, Npix=2048:         CPU 182  ms  | 1 GPU 1.44 ms (127× CPU) | 4 GPU 1.04 ms (1.40×)

At these (small, compile-bound) sizes baseline visibility-axis sharding is communication-bound — multi-GPU adds overhead, so the value is per-device memory and production-scale compute: nfft is matrix-free (no dense A, so it scales to large images without the O(Nvis·Npix) matrix), and direct shards the dense A ~1/k across devices. Frequency sharding (independent channels) already scales. Larger sizes via BENCH_NPIX / BENCH_TADV / BENCH_NCHAN; the scripts are compile-bound (the nfft sharded graph especially).

ALMA-scale (the regime that matters). At Npix=16384 the dense A (2.71 GB) is compute-bound and baseline sharding goes near-linear: direct value_and_grad 7.89 ms (1 GPU) → 4.49 (2 GPU, 1.76×) → 2.45 (4 GPU, 3.23×); GPU vs CPU on the dense path ≈ 200×. The memory wall is A (Nvis×Npix): at 1024² it is ~174 GB — exceeds a single 96 GB GPU, so baseline sharding across ≥4 GPUs is what makes the dense problem fit, while nfft (matrix-free) does 1024² in ~hundreds of MB on one GPU. This confirms image-axis sharding (a deferred follow-up) is redundant: the wall is A, already split by the visibility axis, and the image vector itself is ~8 MB at 1024². Full tables in MEMORY_BASELINES.md.

Tests

pytest -m jax — 143 tests, incl. the new survey tests (objective parity vs the fixed-weight objective, batch-matches-single, prior-FWHM / sys-noise outer axes, per-term chi²). Sharding parity tests (sharded == single-device == numpy, direct/nfft × vis/closures, plus frequency) are gated on ≥2 local GPUs.

Notes

  • Pol imaging's random init is now seeded (default_rng(0)) — reproducible run-to-run; existing pol runs get a different (deterministic) initial guess.
  • optax and jax-finufft are added to the dev / gpu extras (optional); import ehtim stays free of both.

Large + phased, stacks on #295 (the JAX objective). Based on feature/jax-objective for a clean diff (just the optimizer + sharding work); retargets to dev-backend once #295 merges.

@rohandahale rohandahale marked this pull request as ready for review June 13, 2026 19:57
@rohandahale rohandahale requested a review from achael June 13, 2026 19:57
@rohandahale rohandahale self-assigned this Jun 13, 2026
@rohandahale rohandahale added this to the 2.0 milestone Jun 13, 2026
@rohandahale rohandahale changed the base branch from feature/jax-objective to dev-backend June 13, 2026 20:09
@rohandahale rohandahale marked this pull request as draft June 13, 2026 20:09
@rohandahale rohandahale marked this pull request as ready for review June 13, 2026 20:09
@rohandahale rohandahale changed the base branch from dev-backend to feature/jax-objective June 13, 2026 20:17
@rohandahale rohandahale marked this pull request as draft June 13, 2026 20:17
@rohandahale rohandahale marked this pull request as ready for review June 15, 2026 04:42
@rohandahale rohandahale force-pushed the feature/jax-optimizers-sharding branch from 7fd55d3 to 9e83fb6 Compare June 19, 2026 03:06
@rohandahale rohandahale changed the base branch from feature/jax-objective to dev-backend June 19, 2026 03:06
@rohandahale rohandahale reopened this Jun 22, 2026
@codecov

codecov Bot commented Jun 22, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 13.79310% with 350 lines in your changes missing coverage. Please review.
✅ Project coverage is 46.90%. Comparing base (20e604e) to head (26a75e4).

Files with missing lines Patch % Lines
ehtim/imaging/sharding.py 0.00% 162 Missing ⚠️
ehtim/imaging/optimizers.py 15.46% 80 Missing and 2 partials ⚠️
ehtim/imaging/survey_gpu.py 0.00% 69 Missing ⚠️
ehtim/imaging/imager_backend.py 40.00% 18 Missing ⚠️
ehtim/imager.py 51.51% 13 Missing and 3 partials ⚠️
ehtim/observing/obs_helpers.py 0.00% 3 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff               @@
##           dev-backend     #304      +/-   ##
===============================================
- Coverage        47.43%   46.90%   -0.53%     
===============================================
  Files               55       58       +3     
  Lines            26979    27362     +383     
  Branches          4599     4647      +48     
===============================================
+ Hits             12798    12835      +37     
- Misses           12689    13031     +342     
- Partials          1492     1496       +4     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@rohandahale

Copy link
Copy Markdown
Collaborator Author

@achael This PR is ready for review!

@achael achael left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great overall, but I think my general feedback so far is just to do a human pass on the docstring and comments to make the flow here clearer for non-experts. Full docstrings with explicit argument/output specs for the new functions would help too.
Some other minor comments as well.

# Caller (compute_init_state) decides when random-pol init applies.
# Here we just honor the flag: True means "use random pol initialization
# regardless of init image content"; False means "use init image's pol".
pol_rng = np.random.default_rng(0) # seeded so the random pol init is reproducible

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a seed argument to the function?

def _prepare_jax_loss(initvec, config, which_solve, data_tuples, logfreqratio_list,
n_obs, dat_term, reg_term, priorvec, norm_reg, reg_params,
embed_mask, device=None):
"""Build the shared jax loss(x) closure for the imaging objective.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useful to add use case in docstring -- this is in both make_objective_jax and make_value_and_grad_jax

def make_objective_jax(initvec, config, which_solve, data_tuples, logfreqratio_list,
n_obs, dat_term, reg_term, priorvec, norm_reg, reg_params,
embed_mask, device=None):
"""Return a scipy fun(x) -> (value, grad) for the imaging objective, via jax.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a jax novice it would be useful in the docstring here and in make_value_and_grad_jax to to make the use cases for the different objective functions a little clearer

Comment thread ehtim/imager.py

# an optax / device optimizer needs the on-device jax objective
if optimizer is not None and not use_jax and classify_optimizer(optimizer) == 'optax':
print("using the jax objective (required by the optax optimizer)")

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to raise here than automatically changing the user flags

Comment thread ehtim/imager.py
# optax optimizer; default to optax-lbfgs when none was given.
if shard and (optimizer is None or classify_optimizer(optimizer) != 'optax'):
optimizer = 'optax-lbfgs'
print("sharding: using optax-lbfgs (the sharded objective runs on-device)")

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to raise here than automatically changing the optimizer



def run_optimizer(optimizer, *, x0, optdict, callback=None,
build_scipy=None, build_device_vg=None, device=None, mesh=None):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since build_scipy and build_device_vg are never used together, can this be one argument? That is "build_loss = build_scipy" or "build_loss = build_device_vg", and then have a check in each block that the passed in functions are appropriate for the specified lane?

return float(value), np.asarray(grad, dtype=np.float64)

return fun

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need both? it sounds like from the docstring that make_value_and_grad_jax is the better choice for gpus -- is make_objective_jax better for cpus? If not, can we get rid of it and only use make_value_and_grad_jax?

Comment thread ehtim/imager.py
use_jax = True

def build_scipy():
# host objective/gradient handles for the scipy and callable lanes

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use clear if-else with comments:
if use_jax: #jitted jax objective and gradient
...
elif grads: # default scipy with analytic gradients
...
else: # default scipy without gradies

Comment thread ehtim/imager.py
from ehtim.imaging.sharding import build_mesh, make_sharded_value_and_grad
m = mesh if mesh is not None else build_mesh()
return make_sharded_value_and_grad(*backend_args, mesh=m, shard_axis=shard_axis)
vg, loss_fn, to_device = make_value_and_grad_jax(*backend_args, device=dev)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again add "else" (just easier for me to read this way -- if there's a legit reason not to or its not correct python style i'm happy to be corrected....)

return loss, put


def make_objective_jax(initvec, config, which_solve, data_tuples, logfreqratio_list,

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it would help clarify things if make_objective_jax, make_value_and_grad_jax, and make_survey_value_and_grad had more parallel names to better illustrate their use case. One suggestion from my naive read would be "make_objective_jax_jit" "make_objective_jax_device" "make_objective_jax_survey" --> but you have a better picture of the code and so don't adopt these unless they make sense

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants