JAX-differentiable imaging objective on GPU: direct + NFFT, Stokes-I + pol + mf#295
Conversation
Codecov Report
@@ Coverage Diff @@
## dev-backend #295 +/- ##
===============================================
- Coverage 47.54% 47.43% -0.11%
===============================================
Files 55 55
Lines 26977 26978 +1
Branches 4595 4599 +4
===============================================
- Hits 12825 12797 -28
- Misses 12663 12689 +26
- Partials 1489 1492 +3
achael
left a comment
awesome stuff, I'm amazed this has been so simple so far! lots of comments but they boil down to 3 points.
- a few places where there seem to be calls to np. where i'm wondering if you want xp. for the flexible backend. But not being fluent in jax i'm not sure if it's necessary
- coverage of jaxifying the chisqs and gradient functions -- in pol_imager_utils.py in particular, it doesn't seem like all chisq functions are converted, while on the flip side some gradient functions are converted, when gradients aren't touched
- reworking the comments to be clear (some things deserve more comments so those not familiar with jax can follow) but a bit more terse (in bug fixes the comments don't need to signpost that there used to be a bug there).
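To illustrate the np. vs xp. point, here is a minimal sketch of a kernel written once against a generic array namespace. `get_namespace` and `chisq_vis` are hypothetical names for this illustration, not the PR's actual `array_namespace` machinery; the point is only that a bare `np.` call on a jax tracer would fail, while `xp.` follows the input.

```python
import numpy as np


def get_namespace(x):
    """Hypothetical helper: pick jax.numpy for jax arrays/tracers, else numpy."""
    if type(x).__module__.startswith("jax"):
        import jax.numpy as jnp  # lazy, so numpy-only installs never import jax
        return jnp
    return np


def chisq_vis(imvec, A, vis, sigma):
    """Visibility chi-squared written once against the generic namespace xp."""
    xp = get_namespace(imvec)
    samples = xp.matmul(A, imvec)
    return xp.sum(xp.abs((samples - vis) / sigma) ** 2) / (2 * vis.shape[0])


# Toy problem: noiseless visibilities from a random complex "Fourier" matrix.
rng = np.random.default_rng(0)
A = rng.normal(size=(8, 4)) + 1j * rng.normal(size=(8, 4))
imvec = rng.uniform(0.1, 1.0, size=4)
vis = A @ imvec
sigma = np.full(8, 0.1)
print(chisq_vis(imvec, A, vis, sigma))  # zero residual for noiseless data
```

With a numpy `imvec` this runs on numpy; passing a jax array (or tracing under `jax.grad`) would route the same body through `jax.numpy`.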
| nfft_info.plan.f_hat = f_hat | ||
| nfft_info.plan.trafo() | ||
| return nfft_info.plan.f.copy() | ||
| from jax_finufft import nufft2 |
minor: i'd prefer an explicit else statement here
| d3 = np.sqrt(np.abs(im_r2 - im)**2 + np.abs(im_l1r2 - im_r2)**2 + epsilon) | ||
| # Numerators below use cos/sin of the single-angle difference between | ||
| # neighbors, from d|P_l1 - P|^2/d|P| = 2|P| - 2|P_l1|*cos(angle(P_l1) - angle(P)). | ||
| # The back-neighbor magnitude numerators m2/m3 keep a |P| term that does not |
make this a one-line comment (e.g. # mask the first row column gradient terms that don't exist)
| works for any pol mode / ttype the backend supports. jax is imported lazily so | ||
| `import ehtim` stays jax-free. | ||
|
|
||
| The only traced argument is x; everything else is captured in the closure, so |
this last part of the docstring is a bit jargony and i don't think necessary -- i would remove or rephrase
and i'd add some more one-line explanatory comments throughout for non jax experts to be able to follow what's going on here
Cleaned up, dropped the dense paragraphs, reworded the confusing NFFTInfo line, and added short comments through the body
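For readers not fluent in jax, the closure pattern discussed here can be sketched roughly as follows. This is an assumption-laden toy, not the real `make_objective_jax`: the name `make_objective`, its arguments, and the single chi-squared term are hypothetical. It shows the three ideas from the thread: only `x` is traced, jax is imported lazily, and a warning fires when `jax_enable_x64` is off.

```python
import warnings

import numpy as np


def make_objective(A, vis, sigma, use_jax=False):
    """Return fun(x) -> (value, grad); everything except x lives in the closure."""
    if use_jax:
        import jax  # lazy import: the numpy path never needs jax installed
        import jax.numpy as jnp

        if not jax.config.jax_enable_x64:
            warnings.warn("jax_enable_x64 is off; float32 degrades the gradient")
        A_j, vis_j, sigma_j = jnp.asarray(A), jnp.asarray(vis), jnp.asarray(sigma)

        def value(x):
            # Same chi-squared as the numpy branch; jax differentiates it for us.
            return jnp.sum(jnp.abs((A_j @ x - vis_j) / sigma_j) ** 2) / (2 * vis_j.size)

        value_and_grad = jax.jit(jax.value_and_grad(value))

        def fun(x):
            v, g = value_and_grad(jnp.asarray(x))
            return float(v), np.asarray(g, dtype=np.float64)  # scipy wants numpy
    else:

        def fun(x):
            resid = (A @ x - vis) / sigma
            v = np.sum(np.abs(resid) ** 2) / (2 * vis.size)
            g = np.real(A.conj().T @ (resid / sigma)) / vis.size  # hand-written grad
            return float(v), g

    return fun


rng = np.random.default_rng(1)
A = rng.normal(size=(12, 5)) + 1j * rng.normal(size=(12, 5))
x0 = rng.uniform(0.1, 1.0, size=5)
vis = A @ rng.uniform(0.1, 1.0, size=5)
sigma = np.full(12, 0.5)
fun = make_objective(A, vis, sigma)  # numpy path; use_jax=True needs jax installed
value, grad = fun(x0)
```

A function of this shape can be handed to `scipy.optimize.minimize(fun, x0, jac=True, method="L-BFGS-B")`, which is why the factory returns the value and gradient together.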
| a = jnp.asarray(a) | ||
| return jax.device_put(a, device) if device is not None else a | ||
|
|
||
| # A non-array A (the nfft NFFTInfo) is left as a host object for the kernel. |
this comment in particular is a bit confusing
…, cleanups
- Route the three pol nfft value chisqs (chisq_p/m/vvis_nfft) through nufft2_backend so pol+nfft is traceable under jax (the gap Andrew flagged); add obs_pol_nfft fixture + test_pol_nfft_*[IP,IV] parity tests. Value chisqs are now uniformly backend-agnostic; gradient functions stay numpy.
- Move nufft2_backend to obs_helpers.py next to NFFTInfo/FINUFFTPlan; note plan is numpy-only. Explicit else in nufft2_backend + embed; unify embed/embed_imarr on/off index syntax.
- Delete the dead mfrac_max>1 / vfrac_max>1 guards in mcv/vcv.
- Trim comments: drop make_p_image NOTE, collapse reggrad_ptv boundary comment, de-jargon make_objective_jax docstring; add jax_device docstring examples.
- Warn when jax_enable_x64 is off in make_objective_jax (float32 degrades grad).
5cd669b to 0482c7c
|
Thanks @achael this was really useful!
1. np. vs xp. Every bare
2. chisq vs grad conversion. Good catch, it really was inconsistent. The rule I'm following is: the value chisqs are backend-agnostic (jax autodiffs them), and the gradient functions stay pure numpy (under jax, autodiff replaces them entirely, so they're never traced). What you spotted was the pol nfft values -
3. Trimmed the comments, cut the jargon out of the
achael
left a comment
looks great! I misread some of the changes in the git diff before, which led to the wrong comments about function coverage.
My only remaining comments are on test_objective_jax now. I don't think the current version covers all Stokes I chisq/regularizer functions, and the structure is not parallel between stokes I and polarization. It might be useful just to have Claude rewrite this whole module now that the jax functions are fixed, since probably some of the current non-parallel structure came together as the code was being developed.
Here is what I think the full jax test module should contain:
- coverage of ['direct','nfft'] transforms -- already done
- coverage of different pairs of data_terms and reg_terms that, taken together, exercise all of the regularizers and chisq functions
- coverage of different imaging modes - ['I','IP','P','V','IV','IPV'] that together exercise the different transform functions ['polcv','mcv','vcv']
- uniform comparison to analytic and finite diff gradients in all cases. We should xfail the one fixed by the pending bugfix PR, and use epsilon_tv != 0 uniformly between pol and stokes I gradients to avoid singularities
- similar coverage for standard Stokes I multifrequency chisqs, regularizers -- with a note that we need to add real multifreq synthetic data to the test and that multifreq polarization coverage needs to be added.
Finally, I think there should also be a simple test on equality between the numpy/jax branches of nufft2_backend (this is already exercised by the chisqs, but I think that its useful to have a dedicated test).
| out = xp.stack((imarr[0], rho, imarr[2], psi)) | ||
| return out | ||
|
|
||
| def vcv_r(imarr): |
i see, i was misreading what function these lines were in in the github diff. Makes sense; we should just keep in mind this asymmetric jaxification of the _r functions in case they ever did end up being used in an optimization.
| NFFT_EPS = 1e-12 | ||
| NFFT_GRAD_RTOL = 1e-7 | ||
| GRAD_ATOL = 1e-12 | ||
| # Pol grad tests validate jax.grad against central FD (the ground truth), NOT the |
i think it would be good to add in the tests against the analytic gradients and xfail the failing vvis one until we patch in the recent fix.
Isn't the gradient singularity for vtv/ptv also present in normal tv when epsilon_tv=0, as it is by default? I don't see why that should prevent us from testing with finite epsilon.
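The singularity in question can be seen on a single TV term. The sketch below is a generic illustration (the function name `tv_term_and_grad` is made up here): for sqrt(dx^2 + dy^2 + eps), the derivative with respect to dx is dx / sqrt(...), which is 0/0 on a flat patch when eps is zero and finite for any eps > 0. The same holds for tv, ptv and vtv, since they share this form.

```python
import numpy as np


def tv_term_and_grad(dx, dy, epsilon_tv=0.0):
    """One TV term sqrt(dx^2 + dy^2 + eps) and its derivative with respect to dx."""
    d = np.sqrt(dx**2 + dy**2 + epsilon_tv)
    with np.errstate(divide="ignore", invalid="ignore"):
        grad = dx / d  # 0/0 -> nan when the patch is flat and eps == 0
    return d, grad


flat = np.zeros(3)  # neighbor differences on a perfectly smooth patch
_, grad_eps0 = tv_term_and_grad(flat, flat, epsilon_tv=0.0)
_, grad_eps = tv_term_and_grad(flat, flat, epsilon_tv=1e-10)
print(np.isnan(grad_eps0).all(), np.isfinite(grad_eps).all())  # True True
```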
| # Stokes-I reg only here; pol regs are exercised below. | ||
| imgr = eh.imager.Imager( | ||
| obs_pol_direct, gauss_prior, prior_im=gauss_prior, flux=gauss_im_pol.total_flux(), | ||
| data_term={"amp": 100, "pvis": 100, "m": 50}, |
pull data_term and reg_term out to top as in stokes I imager
and make sure the test exercises all of the pol regularizers and chisqs
| assert np.all(np.isfinite(np.asarray(g))) | ||
|
|
||
|
|
||
| def test_ptv_epsilon_tv_removes_singularity(): |
i don't think this should be in the jax-objective test; probably in test-regularizers instead?
| RNG_SEED = 4 | ||
| PERTURB = 0.10 | ||
|
|
||
| DATA_TERM = {"amp": 100, "cphase": 100, "logcamp": 50} |
this doesn't exercise all of the data terms and regularizer terms -- it would be good to loop over different data_term and reg_term combos that in the union cover all of the individual data_term and reg_term functions.
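One way to keep the "union covers everything" requirement honest is a small guard over the parametrized cases. The sketch below is illustrative only: the term names are taken from this thread (with "vvis" assumed as the data-term key for chisq_vvis), the weights are placeholders, and the real lists of supported terms would come from the imager.

```python
# Illustrative term names only; the authoritative lists live in the imager.
ALL_DATA_TERMS = {"amp", "cphase", "logcamp", "pvis", "m", "vvis"}
ALL_REG_TERMS = {"tv", "tv2", "ptv", "vtv", "vtv2"}

# Hypothetical test cases: each bundles a few data terms with a few regularizers.
CASES = [
    {"data_term": {"amp": 100, "cphase": 100, "logcamp": 50},
     "reg_term": {"tv": 1, "tv2": 1}},
    {"data_term": {"amp": 100, "pvis": 100, "m": 50},
     "reg_term": {"ptv": 1}},
    {"data_term": {"amp": 100, "vvis": 100},
     "reg_term": {"vtv": 1, "vtv2": 1}},
]


def uncovered(cases, key, required):
    """Return the required terms that no case exercises under the given key."""
    seen = set()
    for case in cases:
        seen.update(case[key])
    return required - seen


print(uncovered(CASES, "data_term", ALL_DATA_TERMS))
print(uncovered(CASES, "reg_term", ALL_REG_TERMS))
```

Asserting that both results are empty in the test module would make a newly added chisq or regularizer fail loudly until some case covers it.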
|
|
||
| # nfft pol: routes chisq_{p,m,vvis}_nfft through nufft2_backend (jax_finufft). | ||
| # IP exercises chisq_p_nfft (pvis) + chisq_m_nfft (m); IV exercises chisq_vvis_nfft. | ||
| _POL_NFFT_CASES = [ |
add an IPV test to cover full-stokes imaging and the polcv transform.
| f"{rtype} max fractional gradient diff = {max_frac:.6f} (tol={max_tol})" | ||
| ) | ||
|
|
||
| def test_reggrad_ptv_matches_fd_on_boundary(self): |
adapt this test to run over all tv functions; tv, tv2, poltv, vtv?
|
I merged the IV gradient fix to dev-backend in #299 so in the new tests of the jax vs analytic pol gradients we shouldn't need to xfail anything |
|
Thanks @achael! I rewrote Every case now checks the three things uniformly, value parity (numpy vs jax), analytic-gradient parity (jax.grad vs the hand-written objgrad), and finite differences. The parametrized cases together exercise:
Two things made it actually catch bugs:
The bugs it found (all same as #296):
The fix in each is the mirror of #296: compute the needed term whenever the coupled variable is solved (
test_regularizers: I generalized the ptv boundary FD test over all the TV funcs (tv/tv2/ptv/vtv/vtv2), and added a parametrized epsilon_tv-removes-singularity test (tv/ptv/vtv). The |
* Add physical_grad_slots helper Maps the Stokes DOF mask to the physical gradout slots the chisq/reg kernels must fill, centralizing the mcv/vcv cross-coupling that mirrors transform_gradients' Jacobian sparsity. Not yet wired in. + unit tests.
* Wire physical_grad_slots into chisq and reg gradient dicts Feed the cross-coupling-aware mask to the pol gradient kernels in both compute_chisqgrad_dict and compute_reggrad_dict. Behavior-identical for now (kernels still carry the or-patches). Guard physical_grad_slots against sub-4-wide single-pol masks (Stokes-I carries 'mcv' inertly). + regression test.
* Revert vvis kernels to diagonal pol_solve gating The mcv/vcv cross-coupling now lives in physical_grad_slots, so drop the 'or pol_solve[3]' patches (#296) in chisqgrad_vvis / chisqgrad_vvis_nfft; each physical slot keys on its own bit again. Note in each pol chisqgrad docstring that pol_solve flags required physical gradients, not DOFs.
* Fix reggrad_ptv first-row/col boundary masking + epsilon_tv Zero the back-neighbor (m2/m3) terms on the first row/column in reggrad_ptv slots 0/1/3 (the back-neighbor is the zero pad), matching reggrad_vtv/reggrad_tv. Pre-fix the whole first row+col of those slots was wrong (corner ~4x off vs FD). Add epsilon_tv to reg_ptv/reggrad_ptv denominators (default 0, byte-identical) for #295 parity. Note pol_solve = physical-gradient slots in the 8 pol reggrad docstrings. Add full-grid boundary FD regression tests for ptv, vtv, and Stokes-I tv.
* Note pol_solve semantics in polchisqgrad docstring polchisqgrad is a legacy shim (parity tests only); document that its pol_solve is a physical-gradient mask, not a raw DOF mask.
* Drive pol regularizer FD with all four physical slots _pol_solve_for now returns [1,1,1,1] so the previously-blind cross- coupling slots are FD-checked: reggrad_ptv psi (3), reggrad_vflux/l1v/ l2v/vtv rho (1), and slot 0 for every pol reg. Proves the reg-grad slots are individually correct against finite differences.
* Add pol chisq FD + cross-ttype tests in test_chisquared.py New pol coverage in its final-home file: TestPolChisqGradFD checks chisqgrad_p/m/vvis against finite differences of the chisq value in all four physical slots (pol_solve=[1,1,1,1]) for direct+nfft, asserting vvis slot 2 (EVPA) is identically zero. TestPolChisq{,Grad}Consistency check direct-vs-nfft agreement. Closes the m / p-slot-3 blind spots.
* Add parametrized pol objective-FD sampling the polarization DOF block TestObjectiveGradPolarimetricFD checks objgrad vs FD for IP/IV/IQUV x {direct,nfft}, with each case bundling its pol data terms + a pol reg, so both the chisq and reg gradient paths through physical_grad_slots are exercised. Samples the pol DOF block (past the Stokes-I block), where the mcv/vcv cross-coupling lives -- the existing global-sampling FD tests missed it (the dropped IP slot-3 term is ~4% off FD at V=0.02*I, ~430% at V=0.2*I). Comments out the now-subsumed test_fd_matches_analytic_polarimetric (backend) and test_iv_gradient_matches_finite_difference (e2e).
* Use an asymmetric image for chisq/regularizer/gradient FD fixtures Add make_asym_image (broad offset/elongated/rotated double-Gaussian) and switch the Stokes-I FD fixtures (chisq_setup, reg_setup, mfreg_setup, grad_setup) to it. Breaking the reflection/rotation/x<->y symmetry of the centered Gaussian surfaces boundary/axis-ordering bugs a symmetric image hides. Blobs kept broad (grid-filling, no dead pixels) so the |.|-kink TV gradients stay FD-well-conditioned at epsilon_tv=0; all tolerances unchanged.
* Use asymmetric + spatially-varying pol in pol FD test fixtures chisq_setup_pol and a new asym_pol_setup build on make_asym_image and use add_random_pol (ccorr>0) so EVPA, vfrac, rho, and psi all vary spatially instead of a constant pol fraction. polreg_setup switches its Stokes I to the asymmetric image (keeping the per-pixel pol jitter that keeps TV denominators non-degenerate). TestObjectiveGradPolarimetricFD now uses the structured-pol obs. Widen the pol chisq FD check to a median+max split (median 1e-5, max 1e-3): the structured-pol imcur has sharper local curvature, so 2nd-order FD truncation pushes a few small-gradient pixels to ~2.6e-4 -- well below any real pol-gradient bug (%-level), which the tight median still catches.
* Comment cleanup + epsilon_tv consistency in pol_imager_utils Manual review pass: per-slot dR/dX labels, docstrings on the reg kernels, a module-level CONVENTIONS block (imarr = [I, rho, phi=2chi, psi]), and removal of stale TODOs. Two behavior touches, both byte-identical at the defaults: - reg_vtv / reggrad_vtv now honor epsilon_tv (kwargs, default 0) like the ptv pair, instead of the value ignoring it while the grad pinned it to 0. - reggrad_ptv masks the chi-slot back-neighbor terms (c2/c3) too, for uniformity (they already self-zero at the pad). Plus an mcv_r exception-message fix and ruff-clean whitespace.
* Comment cleanup in imager_utils (no behavior change) Manual review pass: docstrings on the Stokes-I reg kernels, per-block comments, 'fourier/transform matrices' labels on the diag Amatrices unpacking, and removal of dead commented-out systematic-noise code in the bispectrum data functions (the intent is now documented in apply_systematic_noise_snrcut). Purely cosmetic; ruff-clean.
* fixed lint errors in test_regularizers and test_chisquared
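The median+max split mentioned in the commit above can be sketched as a small check. This is a hypothetical helper (`fd_agreement` is not a name from the test suite), and it assumes a per-pixel fractional error; the tolerances are the ones quoted in the commit message.

```python
import numpy as np


def fd_agreement(analytic, fd, median_tol=1e-5, max_tol=1e-3):
    """Return (ok, median_err, max_err) for the per-pixel fractional gradient error."""
    analytic = np.asarray(analytic, dtype=float)
    fd = np.asarray(fd, dtype=float)
    frac = np.abs(analytic - fd) / np.maximum(np.abs(fd), 1e-30)  # avoid 0-division
    med, mx = float(np.median(frac)), float(np.max(frac))
    return (med <= median_tol and mx <= max_tol), med, mx


analytic = np.linspace(1.0, 2.0, 101)

# FD truncation noise: tiny everywhere, one pixel pushed to ~2.6e-4.
fd_noisy = analytic * (1 + 1e-7)
fd_noisy[3] = analytic[3] * (1 + 2.6e-4)

# A percent-level gradient bug, either everywhere or in a single pixel.
fd_bug_all = analytic * 1.04
fd_bug_one = fd_noisy.copy()
fd_bug_one[10] = analytic[10] * 1.04

print(fd_agreement(analytic, fd_noisy)[0])    # True: truncation noise is tolerated
print(fd_agreement(analytic, fd_bug_all)[0])  # False: the median catches it
print(fd_agreement(analytic, fd_bug_one)[0])  # False: the max catches it
```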
achael
left a comment
Thanks @rohandahale, looks good, ready to merge!
3 notes for the next PR:
- looks like EPSILON is a global constant in the spectral index TV -- we should change this to epsilon_tv to match the other functions.
- I think _rho_psi_safe should be promoted to a helper function like (rho_psi_from_mfrac_vfrac) or something -- it could have utility outside the specific pol transforms here
- We should consider reviewing the tests and consolidating the gradient tests in particular soon. They are scattered in different files and I think there is some redundancy & non-uniformity in how they test.
| impad = xp.pad(im, 1, mode='constant', constant_values=0) | ||
| im_l1 = xp.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] | ||
| im_l2 = xp.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] | ||
| return xp.sum(xp.sqrt(xp.abs(im_l1 - im)**2 + xp.abs(im_l2 - im)**2 + EPSILON)) / norm |
just noticed this EPSILON is fixed -- should make it epsilon_tv as well I think, later.
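A rough sketch of the suggested follow-up, assuming the kernel keeps the shape shown in the diff context above: the module-level EPSILON becomes an `epsilon_tv` keyword defaulting to 0. The function name `reg_tv_sketch`, its signature, and passing `xp` as an argument are illustrative choices, not the project's API.

```python
import numpy as np


def reg_tv_sketch(imvec, nx, ny, norm=1.0, epsilon_tv=0.0, xp=np):
    """Total variation with the softening constant exposed as a keyword."""
    im = imvec.reshape(ny, nx)
    impad = xp.pad(im, 1, mode="constant", constant_values=0)
    im_l1 = xp.roll(impad, -1, axis=0)[1:ny + 1, 1:nx + 1]  # neighbor im[i+1, j]
    im_l2 = xp.roll(impad, -1, axis=1)[1:ny + 1, 1:nx + 1]  # neighbor im[i, j+1]
    return xp.sum(
        xp.sqrt(xp.abs(im_l1 - im) ** 2 + xp.abs(im_l2 - im) ** 2 + epsilon_tv)
    ) / norm


imvec = np.array([1.0, 2.0, 3.0, 4.0])  # 2x2 image [[1, 2], [3, 4]]
print(reg_tv_sketch(imvec, nx=2, ny=2))
print(reg_tv_sketch(imvec, nx=2, ny=2, epsilon_tv=1e-6))
```

Keeping the default at 0 leaves existing callers unchanged, which mirrors how `epsilon_tv` was added to the ptv/vtv pair in this PR.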
Address #295 review: epsilon_tv, public rho_psi helper, gradient-test consolidation
Makes the full imaging objective differentiable under JAX: jax.grad(compute_objective) now equals the hand-written compute_objective_grad, on CPU and GPU, across Stokes-I + polarization + multifrequency, ttype direct + NFFT, and all regularizers. Builds on #291's backend switch (array_namespace) by making the rest of the objective backend-agnostic. NumPy behavior is byte-identical throughout.

Now also includes the use_jax flag on Imager.make_image (an end-to-end jax reconstruction through scipy L-BFGS-B that recovers the source and matches the numpy recon), plus explicit GPU device + CPU↔GPU-parity tests for the full objective (direct + nfft). Remaining (draft): a full multifrequency imaging recon test (the mf forward has unit coverage; mf is deprioritized this week) and where to document the GPU jax-finufft build.

What changed
- unpack_imarr/transform_imarr made functional (no in-place mutation); new make_objective_jax(...) factory returns a scipy fun(x)->(value, grad) (closure over all-but-x, lazy jax import so import ehtim stays jax-free).
- make_image(use_jax=True) builds the factory and runs scipy L-BFGS-B on the jax objective (numpy path unchanged when off; GPU by default on a CUDA box).
- nufft2_backend dispatches finufft (numpy, byte-identical) vs jax_finufft.nufft2 (jax, differentiable, GPU). NumPy NFFT still uses only finufft. 6 chisq kernels.
- embed()/embed_imarr() functional scatter (full & partial mask) + 8 pol regs + 2 spectral.
- epsilon_tv on reg_ptv/reg_vtv (default 0, byte-identical, matching reg_tv); and reggrad_ptv now zeros its back-neighbor terms on the first row/col, the boundary masking reggrad_tv/reggrad_vtv already have.
- make_*_image, polcv/mcv/vcv (legacy mcv/vcv dead raise guards kept numpy-only), chisq_p/m/vvis, image_at_freq.
- tutorials/ehtim_tutorial_jax.ipynb: install (jax + the GPU jax-finufft source build), the jax.grad(objfunc) == objgrad demo (direct + nfft on the GPU), and an end-to-end make_image(use_jax=True) recon.

Validation (tests/test_objective_jax.py, @pytest.mark.jax)
direct+nfft × {value parity, gold-standard grad parity, FD, factory, no-retrace} + embed/partial-mask + pol IP/IV (grad vs FD ground truth) + make_image(use_jax=True) recovery (matches numpy) + GPU device & CPU↔GPU parity (direct + nfft): 22 tests, green on calypso (CPU + GPU). NumPy regressions green (458 backend+e2e, 146 pol, 102 pol-reg). Ruff clean.
1e-6, 1e-12→7e-10); tests use nfft_eps=1e-12.

Pre-existing bugs surfaced by the autodiff-vs-analytic comparison
- sqrt (unlike reg_tv's epsilon_tv) → singular gradient at smooth pixels. Fixed here via epsilon_tv (default 0, so byte-identical off).
- reggrad_ptv was missing the first-row/col boundary masking that reggrad_tv/reggrad_vtv have → the entire first row + column of the magnitude-gradient slots (I/m/psi) was wrong (corner 4× off vs finite-difference; phase slot was fine because its numerators self-zero at the pad). Fixed here; FD-matched to ~1e-10 everywhere after, and a new full-grid regression test (test_reggrad_ptv_matches_fd_on_boundary) fails pre-fix.
- chisqgrad_vvis/vcv_grad): for IV imaging jax autodiff and finite differences agree, but the hand-written analytic dropped the physical rho gradient the vcv chain consumes (solver v'-gradient ~87% off). Pre-existing (the jax port didn't touch the analytic grad kernels), like stv_pol_grad (Fix factor-of-2 bug in stv_pol_grad gradient #240). Fixed in standalone PR Fix wrong V-pol (IV) imaging gradient: chisqgrad_vvis drops the rho gradient #296 (independent function, no overlap with this PR; kept separate so the correctness fix can merge fast).
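The boundary-masking bug described above can be reproduced on a plain Stokes-I TV analogue. The sketch below is not reggrad_ptv itself; `tv`, `tv_grad` and `fd_grad` are simplified stand-ins written for this note. It shows why the back-neighbor terms must be zeroed on the first row and column: the "previous" pixel there is the zero pad, whose term never appears in the sum, so an unmasked gradient is wrong along that edge while the interior still matches finite differences.

```python
import numpy as np


def tv(im, eps=0.0):
    """Isotropic TV with zero padding past the last row and column."""
    pad = np.pad(im, 1)
    c = pad[1:-1, 1:-1]
    return np.sum(np.sqrt((pad[2:, 1:-1] - c) ** 2 + (pad[1:-1, 2:] - c) ** 2 + eps))


def tv_grad(im, eps=0.0, mask_boundary=True):
    """Analytic TV gradient; back-neighbor terms are optionally masked."""
    pad = np.pad(im, 1)
    c = pad[1:-1, 1:-1]
    l1, l2 = pad[2:, 1:-1], pad[1:-1, 2:]      # im[i+1, j], im[i, j+1]
    r1, r2 = pad[:-2, 1:-1], pad[1:-1, :-2]    # im[i-1, j], im[i, j-1]
    r1l2, l1r2 = pad[:-2, 2:], pad[2:, :-2]    # im[i-1, j+1], im[i+1, j-1]
    g1 = (2 * c - l1 - l2) / np.sqrt((l1 - c) ** 2 + (l2 - c) ** 2 + eps)
    g2 = (c - r1) / np.sqrt((c - r1) ** 2 + (r1l2 - r1) ** 2 + eps)
    g3 = (c - r2) / np.sqrt((c - r2) ** 2 + (l1r2 - r2) ** 2 + eps)
    if mask_boundary:
        g2[0, :] = 0.0  # no term for row -1 exists in the sum
        g3[:, 0] = 0.0  # no term for column -1 exists in the sum
    return g1 + g2 + g3


def fd_grad(f, im, h=1e-6):
    """Central finite-difference gradient over every pixel (boundary included)."""
    g = np.zeros_like(im)
    for idx in np.ndindex(im.shape):
        step = np.zeros_like(im)
        step[idx] = h
        g[idx] = (f(im + step) - f(im - step)) / (2 * h)
    return g


rng = np.random.default_rng(4)
im = rng.uniform(0.5, 1.5, size=(6, 5))
eps = 1e-6
fd = fd_grad(lambda a: tv(a, eps), im)
err_masked = np.abs(tv_grad(im, eps, mask_boundary=True) - fd)
err_unmasked = np.abs(tv_grad(im, eps, mask_boundary=False) - fd)
print(err_masked.max(), err_unmasked[0, 0], err_unmasked[1:, 1:].max())
```

Checking the full grid rather than a random sample of pixels is what lets a test like this catch an edge-only error.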