Modular JAX optimizers + multi-GPU sharding #304
Codecov Report ❌

Coverage diff (dev-backend vs. #304):

| | dev-backend | #304 | +/- |
|---|---|---|---|
| Coverage | 47.43% | 46.90% | -0.53% |
| Files | 55 | 58 | +3 |
| Lines | 26979 | 27362 | +383 |
| Branches | 4599 | 4647 | +48 |
| Hits | 12798 | 12835 | +37 |
| Misses | 12689 | 13031 | +342 |
| Partials | 1492 | 1496 | +4 |
@achael This PR is ready for review!
achael left a comment:
looks great overall, but I think my general feedback so far is just to do a human pass on the docstring and comments to make the flow here clearer for non-experts. Full docstrings with explicit argument/output specs for the new functions would help too.
Some other minor comments as well.
# Caller (compute_init_state) decides when random-pol init applies.
# Here we just honor the flag: True means "use random pol initialization
# regardless of init image content"; False means "use init image's pol".
pol_rng = np.random.default_rng(0)  # seeded so the random pol init is reproducible
add a seed argument to the function?
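One way the suggested seed argument could look is sketched below. The helper name `random_pol_init`, its return values, and the sampled ranges are hypothetical and not taken from the PR; only the `np.random.default_rng(0)` default mirrors the line under review.

```python
import numpy as np


def random_pol_init(npix, seed=0):
    """Hypothetical helper: draw a reproducible random polarization init.

    `seed` defaults to 0 so the behaviour matches the hard-coded
    default_rng(0) in the diff; callers can pass another seed to vary it.
    The name, return values, and sampled ranges are illustrative assumptions.
    """
    pol_rng = np.random.default_rng(seed)
    # fractional polarization in [0, 1) and polarization angle in [-pi/2, pi/2)
    frac = pol_rng.uniform(0.0, 1.0, size=npix)
    angle = pol_rng.uniform(-np.pi / 2, np.pi / 2, size=npix)
    return frac, angle
```

With this shape, the default call stays reproducible run-to-run, while a survey or a test can request an independent draw by passing a different `seed`.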
def _prepare_jax_loss(initvec, config, which_solve, data_tuples, logfreqratio_list,
                      n_obs, dat_term, reg_term, priorvec, norm_reg, reg_params,
                      embed_mask, device=None):
    """Build the shared jax loss(x) closure for the imaging objective.
useful to add use case in docstring -- this is in both make_objective_jax and make_value_and_grad_jax
def make_objective_jax(initvec, config, which_solve, data_tuples, logfreqratio_list,
                       n_obs, dat_term, reg_term, priorvec, norm_reg, reg_params,
                       embed_mask, device=None):
    """Return a scipy fun(x) -> (value, grad) for the imaging objective, via jax.
as a jax novice it would be useful in the docstring here and in make_value_and_grad_jax to make the use cases for the different objective functions a little clearer
# an optax / device optimizer needs the on-device jax objective
if optimizer is not None and not use_jax and classify_optimizer(optimizer) == 'optax':
    print("using the jax objective (required by the optax optimizer)")
I think it's better to raise here than automatically changing the user flags
# optax optimizer; default to optax-lbfgs when none was given.
if shard and (optimizer is None or classify_optimizer(optimizer) != 'optax'):
    optimizer = 'optax-lbfgs'
    print("sharding: using optax-lbfgs (the sharded objective runs on-device)")
I think it's better to raise here than automatically changing the optimizer
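The two "raise instead of silently changing the flag" comments above could be addressed along these lines. This is only a sketch: `classify_optimizer` here is a stand-in that handles string names only (the PR's real implementation is not shown in the diff), and `check_optimizer_flags` and the error messages are hypothetical.

```python
OPTAX_NAMES = {"optax-lbfgs", "optax-lbfgs-bt", "adam", "adamw", "sgd", "rmsprop"}


def classify_optimizer(optimizer):
    """Stand-in for the PR's classify_optimizer (string names only)."""
    return "optax" if optimizer in OPTAX_NAMES else "scipy"


def check_optimizer_flags(optimizer, use_jax, shard):
    """Hypothetical validator: fail loudly instead of overriding user flags."""
    kind = None if optimizer is None else classify_optimizer(optimizer)
    if kind == "optax" and not use_jax:
        raise ValueError(
            f"optimizer={optimizer!r} needs the on-device jax objective; "
            "pass use_jax=True"
        )
    if shard and kind != "optax":
        raise ValueError(
            "shard=True needs an optax optimizer (for example 'optax-lbfgs'); "
            f"got optimizer={optimizer!r}"
        )
```

The conditions are the same ones the diff tests before printing; the only change is that the user is told which flag to set rather than having it set for them.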
def run_optimizer(optimizer, *, x0, optdict, callback=None,
                  build_scipy=None, build_device_vg=None, device=None, mesh=None):
since build_scipy and build_device_vg are never used together, can this be one argument? That is "build_loss = build_scipy" or "build_loss = build_device_vg", and then have a check in each block that the passed in functions are appropriate for the specified lane?
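A minimal sketch of the single-argument idea, assuming each builder carries a tag naming the lane it serves. The `tag_lane` decorator, the `lane` attribute, the toy quadratic objective, and the reduced `run_optimizer` signature are all hypothetical; the PR's real builders and optimizer loops are not reproduced here.

```python
DEVICE_OPTIMIZERS = {"optax-lbfgs", "optax-lbfgs-bt", "adam", "adamw", "sgd", "rmsprop"}


def tag_lane(lane):
    """Mark a builder with the lane ('scipy' or 'device') it is meant for."""
    def decorate(builder):
        builder.lane = lane
        return builder
    return decorate


@tag_lane("scipy")
def build_scipy():
    # host-side fun(x) -> (value, grad); a toy quadratic stands in for the objective
    def fun(x):
        return sum(xi * xi for xi in x), [2.0 * xi for xi in x]
    return fun


@tag_lane("device")
def build_device_vg():
    # in the PR this would return an on-device value_and_grad; same toy objective here
    def vg(x):
        return sum(xi * xi for xi in x), [2.0 * xi for xi in x]
    return vg


def run_optimizer(optimizer, *, x0, build_loss):
    lane = "device" if optimizer in DEVICE_OPTIMIZERS else "scipy"
    # reject a builder that was written for the other lane
    if getattr(build_loss, "lane", None) != lane:
        raise TypeError(f"optimizer {optimizer!r} runs in the {lane!r} lane, "
                        f"but build_loss is tagged {getattr(build_loss, 'lane', None)!r}")
    value_and_grad = build_loss()
    # a real implementation would iterate; this sketch evaluates once at x0
    return lane, value_and_grad(x0)
```

An explicit tag is one option; checking the builder's return type inside each lane would serve the same purpose without the decorator.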
    return float(value), np.asarray(grad, dtype=np.float64)

return fun
do we need both? it sounds like from the docstring that make_value_and_grad_jax is the better choice for gpus -- is make_objective_jax better for cpus? If not, can we get rid of it and only use make_value_and_grad_jax?
use_jax = True

def build_scipy():
    # host objective/gradient handles for the scipy and callable lanes
let's use clear if-else with comments:
if use_jax:  # jitted jax objective and gradient
    ...
elif grads:  # default scipy with analytic gradients
    ...
else:  # default scipy without gradients
    ...
from ehtim.imaging.sharding import build_mesh, make_sharded_value_and_grad
m = mesh if mesh is not None else build_mesh()
return make_sharded_value_and_grad(*backend_args, mesh=m, shard_axis=shard_axis)
vg, loss_fn, to_device = make_value_and_grad_jax(*backend_args, device=dev)
again add "else" (just easier for me to read this way -- if there's a legit reason not to or it's not correct python style i'm happy to be corrected....)
return loss, put

def make_objective_jax(initvec, config, which_solve, data_tuples, logfreqratio_list,
i think it would help clarify things if make_objective_jax, make_value_and_grad_jax, and make_survey_value_and_grad had more parallel names to better illustrate their use case. One suggestion from my naive read would be "make_objective_jax_jit" "make_objective_jax_device" "make_objective_jax_survey" --> but you have a better picture of the code and so don't adopt these unless they make sense
Modular JAX optimizers + multi-GPU sharding for the imaging objective, on top of the JAX objective (#295). Both are opt-in; the default make_image path is unchanged (scipy L-BFGS-B, bit-for-bit).

Optimizer layer

make_image(optimizer=...) now accepts:
- None (default): scipy L-BFGS-B, exactly as before
- 'optax-lbfgs' / 'optax-lbfgs-bt' (backtracking line search) / 'adam' / 'adamw' / 'sgd' / 'rmsprop', or any optax.GradientTransformation: an on-device lax.while_loop (value_and_grad stays on device, no per-step host sync)

New ehtim/imaging/optimizers.py. optax is an optional, lazily-imported dependency (import ehtim stays optax-free).

Multi-GPU sharding
make_image(shard=True, shard_axis='baseline'|'frequency', mesh=None) splits the data-fidelity (chi²) term across the local GPUs; the image and regularizers stay replicated. Off by default (single device).

The sharded objective is bit-for-bit the single-device one: the data axis is padded to a mesh multiple with finite-sample, infinite-sigma rows (exact-zero contribution, no closure NaN), and each chi² is corrected by pad_len/true_len. Validated sharded == single-device == numpy (gradient rel ~1e-16) for Stokes I (vis/amp/cphase/logcamp), polarimetric (IP), and multifrequency, on both transforms.

nfft caveat: jax_finufft's nufft2 forward shards cleanly, but its registered transpose is wrong under shard_map (minimal repro filed upstream). We wrap the grid→samples map in a custom_vjp whose backward is an explicit forward nufft1 + psum: the exact type-1 adjoint, correct on one or many devices.

GPU parameter surveys

ehtim/imaging/survey_gpu.py: run_survey_gpu(imgr, weight_grid=, regparam_grid=, prior_fwhm=, sys_noise=, ...) reconstructs a whole hyperparameter grid in one vmapped on-device optimization (the GPU counterpart to the CPU paramsurvey tool in ehtim/survey.py, which is untouched). Two axis kinds: jax.vmap'd optimization;

Returns the image, objective, per-row grid record, and per-term reduced chi² for each reconstruction (for ranking the Top Set).
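The padding scheme described under Multi-GPU sharding (infinite-sigma rows plus a pad_len/true_len correction) can be illustrated with plain NumPy. This is a sketch under stated assumptions, not the PR's code: the chi² is taken to be a mean of squared normalized residuals over real-valued data, the "devices" are simulated with `np.split`, and the function names are hypothetical.

```python
import numpy as np


def chi2_single(model, data, sigma):
    """Single-device chi^2: mean squared normalized residual (assumed form)."""
    return np.mean(((data - model) / sigma) ** 2)


def chi2_sharded(model, data, sigma, n_devices):
    """Same chi^2 computed over a padded, evenly split data axis."""
    true_len = data.shape[0]
    pad_len = -(-true_len // n_devices) * n_devices  # round up to a mesh multiple
    n_pad = pad_len - true_len
    # padding rows: finite samples with infinite sigma, so residual / sigma == 0 exactly
    data_p = np.concatenate([data, np.zeros(n_pad)])
    model_p = np.concatenate([model, np.zeros(n_pad)])
    sigma_p = np.concatenate([sigma, np.full(n_pad, np.inf)])
    # each simulated device sums its shard; adding the partials plays the role of psum
    shards = zip(np.split(data_p, n_devices),
                 np.split(model_p, n_devices),
                 np.split(sigma_p, n_devices))
    total = sum(np.sum(((d - m) / s) ** 2) for d, m, s in shards)
    chi2_padded = total / pad_len          # mean over the padded length
    return chi2_padded * pad_len / true_len  # undo the padding in the normalization


rng = np.random.default_rng(0)
data = rng.normal(size=10)
model = rng.normal(size=10)
sigma = rng.uniform(0.5, 2.0, size=10)
print(chi2_single(model, data, sigma), chi2_sharded(model, data, sigma, n_devices=4))
```

Here 10 rows are padded to 12 for 4 devices; the two padding rows contribute exactly zero, and the pad_len/true_len factor restores the normalization by the true number of rows.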
Backtracking line search (optax-lbfgs-bt). Profiling found the survey bottleneck is neither FP64 (only ~2.5× slower than FP32 on these GPUs) nor the forward model (value_and_grad vmap-parallelizes ~44× from batch 1→625). It is the optax zoom line search under vmap: its bracket+zoom while_loop runs to the batch worst-case element every iteration (~20 value_and_grad/iter across a 625-config batch). A backtracking (Armijo) line search, which needs only a few value evals per step and keeps the L-BFGS curvature, is ~7× faster at matched convergence, and is the survey default.

M87 demonstration (real EHT 2019-D01-01 data, FP32): the Paper IV survey grid (5⁴ reg weights × 3 prior-FWHM × 4 systematic-noise = 7,500 reconstructions) runs in ~17 min on 4 GPUs (~7.4 reconstructions/s), recovering the asymmetric ring across the Top Set; best fit χ²(cphase / logcamp) = 2.46 / 0.95, flux 0.60 Jy. A denser 8⁴-reg run (49,152 reconstructions, exceeding Paper IV's full 37,500-combination eht-imaging survey) completes in 113 min on 4 GPUs; a single CPU reconstruction (numpy/scipy, measured via SLURM) is ~16 s, so the same survey is ~9 days sequential on CPU, a ~115× speedup.
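For readers unfamiliar with the Armijo rule mentioned above, a generic backtracking line search looks roughly like the NumPy sketch below. It is not the optax implementation used by optax-lbfgs-bt: the search direction is plain steepest descent rather than an L-BFGS direction, and the function name, constants, and toy quadratic are assumptions chosen for illustration.

```python
import numpy as np


def backtracking_step(f, x, fx, g, direction, step0=1.0, shrink=0.5, c1=1e-4, max_evals=20):
    """Shrink the step until the Armijo sufficient-decrease condition holds."""
    slope = float(np.dot(g, direction))  # negative for a descent direction
    step = step0
    for _ in range(max_evals):
        x_new = x + step * direction
        f_new = f(x_new)  # value-only evaluation, no gradient needed here
        if f_new <= fx + c1 * step * slope:
            return x_new, f_new
        step *= shrink
    return x, fx  # no acceptable step found; stay put


# toy ill-conditioned quadratic standing in for the imaging objective
A = np.diag([1.0, 10.0])
f = lambda x: 0.5 * x @ A @ x
grad = lambda x: A @ x

x = np.array([1.0, 1.0])
fx = f(x)
history = [fx]
for _ in range(200):
    g = grad(x)
    x, fx = backtracking_step(f, x, fx, g, direction=-g)
    history.append(fx)
print(history[0], history[-1])
```

Each iteration costs one gradient plus a handful of value evaluations, which is the property the description relies on: under vmap, no batch element forces the others through a long bracket-and-zoom loop.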
Benchmarks (4× RTX PRO 6000 Blackwell, x64)
Optimizer layer — Gaussian reconstruction (direct, 1600 px, 200 iters):
optax-lbfgs reaches the same minimum as scipy at 2.7× the speed (on-device loop, no per-step host sync). adam drops in as an arbitrary optax optimizer (first-order — fast but needs more iterations to converge chi²).
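The "on-device loop, no per-step host sync" pattern can be sketched in JAX as follows. This is an illustration only: plain gradient descent stands in for the optax update so that the block needs nothing beyond JAX, and the toy loss, learning rate, and tolerance are assumptions rather than values from the PR.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(x):
    # toy quadratic standing in for the imaging objective (an assumption)
    return jnp.sum((x - 3.0) ** 2)


@jax.jit
def minimize_on_device(x0, lr=0.1, tol=1e-3, max_iters=500):
    value_and_grad = jax.value_and_grad(loss)

    def cond(state):
        i, _, g = state
        # keep iterating while under the cap and the gradient is still large
        return jnp.logical_and(i < max_iters, jnp.linalg.norm(g) > tol)

    def body(state):
        i, x, g = state
        x_new = x - lr * g                 # plain gradient step (optax update in the PR)
        _, g_new = value_and_grad(x_new)   # evaluated inside the compiled loop
        return i + 1, x_new, g_new

    _, g0 = value_and_grad(x0)
    init = (jnp.zeros((), dtype=jnp.int32), x0, g0)
    n_iters, x_final, _ = lax.while_loop(cond, body, init)
    return x_final, n_iters


x_final, n_iters = minimize_on_device(jnp.zeros(4))
print(x_final, n_iters)
```

Because the convergence test lives inside `lax.while_loop`, the whole optimization compiles to one device program; the host only sees the final iterate and iteration count, in contrast to a scipy driver that pulls value and gradient back to the host on every step.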
Sharding — single value_and_grad:
At these (small, compile-bound) sizes baseline visibility-axis sharding is communication-bound: multi-GPU adds overhead, so the value is per-device memory and production-scale compute. nfft is matrix-free (no dense A, so it scales to large images without the O(Nvis·Npix) matrix), and direct shards the dense A ~1/k across devices. Frequency sharding (independent channels) already scales. Larger sizes via BENCH_NPIX/BENCH_TADV/BENCH_NCHAN; the scripts are compile-bound (the nfft sharded graph especially).

ALMA-scale (the regime that matters). At Npix=16384 the dense A (2.71 GB) is compute-bound and baseline sharding goes near-linear: direct value_and_grad 7.89 ms (1 GPU) → 4.49 (2 GPU, 1.76×) → 2.45 (4 GPU, 3.23×); GPU vs CPU on the dense path ≈ 200×. The memory wall is A (Nvis×Npix): at 1024² it is ~174 GB, which exceeds a single 96 GB GPU, so baseline sharding across ≥4 GPUs is what makes the dense problem fit, while nfft (matrix-free) does 1024² in ~hundreds of MB on one GPU. This confirms image-axis sharding (a deferred follow-up) is redundant: the wall is A, already split by the visibility axis, and the image vector itself is ~8 MB at 1024². Full tables in MEMORY_BASELINES.md.

Tests
pytest -m jax: 143 tests, incl. the new survey tests (objective parity vs the fixed-weight objective, batch-matches-single, prior-FWHM / sys-noise outer axes, per-term chi²). Sharding parity tests (sharded == single-device == numpy, direct/nfft × vis/closures, plus frequency) are gated on ≥2 local GPUs.

Notes

- Random pol initialization is seeded (default_rng(0)): reproducible run-to-run; existing pol runs get a different (deterministic) initial guess.
- optax and jax-finufft are added to the dev/gpu extras (optional); import ehtim stays free of both.
- Large + phased, stacks on #295 (the JAX objective). Based on feature/jax-objective for a clean diff (just the optimizer + sharding work); retargets to dev-backend once #295 merges.
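The survey and memory figures quoted in the description can be cross-checked with a few lines of arithmetic. The inputs below are the numbers stated in the text; the 8-bytes-per-pixel (float64) image vector and the assumption that the dense A scales linearly with Npix at fixed Nvis are mine.

```python
# survey sizes: reg-weight grid x prior-FWHM x systematic-noise
n_paper_iv = 5**4 * 3 * 4                 # Paper IV grid
n_dense = 8**4 * 3 * 4                    # denser 8^4 grid

rate = n_paper_iv / (17 * 60)             # reconstructions per second on 4 GPUs
cpu_seconds = n_dense * 16                # ~16 s per sequential CPU reconstruction
cpu_days = cpu_seconds / 86400
speedup = cpu_seconds / (113 * 60)        # vs. 113 min on 4 GPUs

# baseline-sharding speedups from the quoted value_and_grad timings (ms)
speedup_2gpu = 7.89 / 4.49
speedup_4gpu = 7.89 / 2.45

# dense A grows linearly with Npix at fixed Nvis (assumption)
a_gb_1024sq = 2.71 * (1024**2 / 16384)
image_mb_1024sq = 1024**2 * 8 / 1e6       # float64 image vector (assumption)

print(n_paper_iv, n_dense, round(rate, 2), round(cpu_days, 2), round(speedup, 1))
print(round(speedup_2gpu, 2), round(speedup_4gpu, 2), round(a_gb_1024sq, 1), round(image_mb_1024sq, 2))
```

The results agree with the quoted figures to within the rounding of the inputs; the 4-GPU ratio comes out slightly below the stated 3.23× when computed from the rounded timings.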