
JAX-differentiable imaging objective on GPU: direct + NFFT, Stokes-I + pol + mf #295

Merged
achael merged 23 commits into dev-backend from feature/jax-objective
Jun 18, 2026


@rohandahale rohandahale commented Jun 9, 2026

Collaborator

Update (2026-06-17) — rebased onto current dev-backend.
The pol-gradient correctness fixes this PR originally surfaced are now upstream, so #295 is scoped to the JAX implementation + tests:

Validated on the rebased branch: 1784 numpy + 71 jax green. The sections below are retained as history of where those fixes originated.


Makes the full imaging objective differentiable under JAX: jax.grad(compute_objective) now equals the hand-written compute_objective_grad, on CPU and GPU, across Stokes-I + polarization + multifrequency, ttype direct + NFFT, and all regularizers. Builds on #291's backend switch (array_namespace) by making the rest of the objective backend-agnostic. NumPy behavior is byte-identical throughout.

Now also includes the use_jax flag on Imager.make_image — an end-to-end jax reconstruction through scipy L-BFGS-B that recovers the source and matches the numpy recon — plus explicit GPU device + CPU↔GPU-parity tests for the full objective (direct + nfft). Remaining (draft): a full multifrequency imaging recon test (the mf forward has unit coverage; mf is deprioritized this week) and where to document the GPU jax-finufft build.

What changed

  • direct objective: unpack_imarr/transform_imarr made functional (no in-place mutation); new make_objective_jax(...) factory returns a scipy fun(x)->(value, grad) (closure over all-but-x, lazy jax import so import ehtim stays jax-free).
  • OO integration: make_image(use_jax=True) builds the factory and runs scipy L-BFGS-B on the jax objective (numpy path unchanged when off; GPU by default on a CUDA box).
  • NFFT under jax-finufft: nufft2_backend dispatches finufft (numpy, byte-identical) vs jax_finufft.nufft2 (jax, differentiable, GPU). NumPy NFFT still uses only finufft. 6 chisq kernels.
  • all regularizers agnostic: 6 embed-free (from #291, "Add a NumPy/JAX backend switch and differentiable Stokes-I imaging objectives") + 9 spatial + full embed()/embed_imarr() functional scatter (full & partial mask) + 8 pol regs + 2 spectral.
  • two pol-TV gradient correctness fixes (both surfaced by the autodiff-vs-analytic comparison, see below): epsilon_tv on reg_ptv/reg_vtv (default 0, byte-identical, matching reg_tv); and reggrad_ptv now zeros its back-neighbor terms on the first row/col — the boundary masking reggrad_tv/reggrad_vtv already have.
  • pol/mf forward: make_*_image, polcv/mcv/vcv (legacy mcv/vcv dead raise guards kept numpy-only), chisq_p/m/vvis, image_at_freq.
  • tutorial: tutorials/ehtim_tutorial_jax.ipynb — install (jax + the GPU jax-finufft source build), the jax.grad(objfunc) == objgrad demo (direct + nfft on the GPU), and an end-to-end make_image(use_jax=True) recon.
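
For readers less familiar with the scipy/jax hand-off, the factory pattern in the first bullet can be sketched roughly as below. This is a minimal illustration under stated assumptions, not the actual `make_objective_jax`: the names `make_value_and_grad_fun` and `toy_chisq` are hypothetical, the toy least-squares objective stands in for the real imaging objective, and enabling `jax_enable_x64` inside the factory is a choice of this sketch.

```python
import numpy as np


def toy_chisq(x, A, d):
    """Toy least-squares objective standing in for the real imaging objective."""
    import jax.numpy as jnp
    return jnp.sum((jnp.matmul(A, x) - d) ** 2)


def make_value_and_grad_fun(objective, *closure_args):
    """Return fun(x) -> (value, grad), the shape scipy expects with jac=True."""
    import jax  # lazy import: loading this module does not require jax
    import jax.numpy as jnp

    # Sketch-only choice: use float64 so the gradient is not degraded by float32.
    jax.config.update("jax_enable_x64", True)

    # Only x is traced; everything else is captured as a constant in the closure.
    value_and_grad = jax.jit(
        jax.value_and_grad(lambda x: objective(x, *closure_args))
    )

    def fun(x):
        value, grad = value_and_grad(jnp.asarray(x))
        # scipy wants a Python float and a float64 NumPy array on the host.
        return float(value), np.asarray(grad, dtype=np.float64)

    return fun
```

A function built this way can be passed to `scipy.optimize.minimize(fun, x0, jac=True, method="L-BFGS-B")`, which is the general shape of the hand-off described above.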

Validation (tests/test_objective_jax.py, @pytest.mark.jax)

direct+nfft × {value parity, gold-standard grad parity, FD, factory, no-retrace} + embed/partial-mask + pol IP/IV (grad vs FD ground truth) + make_image(use_jax=True) recovery (matches numpy) + GPU device & CPU↔GPU parity (direct + nfft) — 22 tests, green on calypso (CPU + GPU). NumPy regressions green (458 backend+e2e, 146 pol, 102 pol-reg). Ruff clean.

  • NFFT grad parity is eps-limited, not a bug (proven: the residual scales with the NFFT tolerance, eps=1e-9 → 1e-6 and eps=1e-12 → 7e-10); tests use nfft_eps=1e-12.

Pre-existing bugs surfaced by the autodiff-vs-analytic comparison

  1. pol TV regularizers had no epsilon in their sqrt (unlike reg_tv's epsilon_tv) → singular gradient at smooth pixels. Fixed here via epsilon_tv (default 0, so byte-identical off).
  2. reggrad_ptv was missing the first-row/col boundary masking that reggrad_tv/reggrad_vtv have → the entire first row + column of the magnitude-gradient slots (I/m/psi) was wrong (corner 4× off vs finite-difference; phase slot was fine because its numerators self-zero at the pad). Fixed here; FD-matched to ~1e-10 everywhere after, and a new full-grid regression test (test_reggrad_ptv_matches_fd_on_boundary) fails pre-fix.
  3. The V-pol analytic gradient is wrong (chisqgrad_vvis / vcv_grad): for IV imaging jax autodiff and finite differences agree, but the hand-written analytic gradient dropped the physical rho gradient the vcv chain consumes (solver v'-gradient ~87% off). This is pre-existing (the jax port didn't touch the analytic grad kernels), like stv_pol_grad (#240, "Fix factor-of-2 bug in stv_pol_grad gradient"). Fixed in standalone PR #296, "Fix wrong V-pol (IV) imaging gradient: chisqgrad_vvis drops the rho gradient" (independent function, no overlap with this PR; kept separate so the correctness fix can merge fast).
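
The first two items can be illustrated on a plain Stokes-I style TV term. The sketch below is a simplified stand-in, not the ehtim kernels: `tv`, `tv_grad`, and `fd_grad` are hypothetical names, and the regularizer is an isotropic TV with forward differences and a zero pad. It shows the analytic gradient matching central finite differences once the first row/column get no back-neighbor term, and the 0/0 at smooth pixels when the epsilon is zero.

```python
import numpy as np


def tv(im, eps=0.0):
    """Isotropic TV with forward differences; neighbors past the edge are 0."""
    pad = np.pad(im, ((0, 1), (0, 1)))      # zero pad after the last row/col
    a = pad[1:, :-1] - im                   # forward difference along axis 0
    b = pad[:-1, 1:] - im                   # forward difference along axis 1
    return np.sum(np.sqrt(a**2 + b**2 + eps))


def tv_grad(im, eps=0.0):
    """Analytic gradient of tv() with respect to every pixel."""
    pad = np.pad(im, ((0, 1), (0, 1)))
    a = pad[1:, :-1] - im
    b = pad[:-1, 1:] - im
    d = np.sqrt(a**2 + b**2 + eps)
    grad = -(a + b) / d                     # each pixel's own term
    # Back-neighbor terms: the first row/col has no back neighbor inside the
    # image, so those entries get nothing added (the boundary masking).
    grad[1:, :] += (a / d)[:-1, :]
    grad[:, 1:] += (b / d)[:, :-1]
    return grad


def fd_grad(f, im, h=1e-6):
    """Central finite-difference gradient, one pixel at a time."""
    g = np.zeros_like(im)
    for idx in np.ndindex(im.shape):
        step = np.zeros_like(im)
        step[idx] = h
        g[idx] = (f(im + step) - f(im - step)) / (2.0 * h)
    return g


rng = np.random.default_rng(0)
rough = rng.uniform(0.5, 1.5, size=(5, 4))  # no flat patches
flat = np.ones((4, 4))                      # smooth interior pixels

fd = fd_grad(lambda x: tv(x, 1e-3), rough)
matches_fd = np.allclose(tv_grad(rough, 1e-3), fd, atol=1e-6)

with np.errstate(divide="ignore", invalid="ignore"):
    singular = np.isnan(tv_grad(flat, 0.0)).any()   # 0/0 at smooth pixels
regularized = np.isfinite(tv_grad(flat, 1e-6)).all()
```

Checking the full grid, rather than a random sample of pixels, is what makes a first-row/column error visible in a test of this kind.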

@codecov

codecov Bot commented Jun 9, 2026


Codecov Report

❌ Patch coverage is 84.31373% with 40 lines in your changes missing coverage. Please review.
✅ Project coverage is 47.43%. Comparing base (c3ea9a2) to head (2e3edcb).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ehtim/imaging/imager_backend.py | 58.13% | 17 Missing and 1 partial ⚠️ |
| ehtim/imaging/imager_utils.py | 90.00% | 7 Missing and 2 partials ⚠️ |
| ehtim/imaging/multifreq_imager_utils.py | 64.70% | 6 Missing ⚠️ |
| ehtim/observing/obs_helpers.py | 69.23% | 3 Missing and 1 partial ⚠️ |
| ehtim/imager.py | 25.00% | 2 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@               Coverage Diff               @@
##           dev-backend     #295      +/-   ##
===============================================
- Coverage        47.54%   47.43%   -0.11%     
===============================================
  Files               55       55              
  Lines            26977    26978       +1     
  Branches          4595     4599       +4     
===============================================
- Hits             12825    12797      -28     
- Misses           12663    12689      +26     
- Partials          1489     1492       +3     


@rohandahale rohandahale requested a review from achael June 10, 2026 11:51
@rohandahale rohandahale self-assigned this Jun 10, 2026
@rohandahale rohandahale added this to the 2.0 milestone Jun 10, 2026
@rohandahale rohandahale marked this pull request as ready for review June 10, 2026 11:54

@achael achael left a comment

Owner

awesome stuff, I'm amazed this has been so simple so far! lots of comments but they boil down to 3 points.

  1. a few places where there seem to be calls to np. where i'm wondering if you want xp. for the flexible backend. But not being fluent in jax i'm not sure if it's necessary
  2. coverage of jaxifying the chisqs and gradient functions -- in pol_imager_utils.py in particular, it doesn't seem like all chisq functions are converted, while on the flip side some gradient functions are converted, when gradients aren't touched
  3. reworking the comments to be clear (some things deserve more comments so those not familiar with jax can follow) but a bit more terse (in bug fixes the comments don't need to signpost that there used to be a bug there).

Comment thread ehtim/imaging/imager_utils.py Outdated
nfft_info.plan.f_hat = f_hat
nfft_info.plan.trafo()
return nfft_info.plan.f.copy()
from jax_finufft import nufft2

Owner

minor: i'd prefer an explicit else statement here

Collaborator Author

Done

Comment thread ehtim/imaging/imager_utils.py
Comment thread ehtim/imaging/imager_utils.py
Comment thread ehtim/imaging/imager_utils.py
Comment thread ehtim/imaging/imager_utils.py
Comment thread ehtim/imaging/pol_imager_utils.py
Comment thread ehtim/imaging/pol_imager_utils.py
Comment thread ehtim/imaging/pol_imager_utils.py Outdated
d3 = np.sqrt(np.abs(im_r2 - im)**2 + np.abs(im_l1r2 - im_r2)**2 + epsilon)
# Numerators below use cos/sin of the single-angle difference between
# neighbors, from d|P_l1 - P|^2/d|P| = 2|P| - 2|P_l1|*cos(angle(P_l1) - angle(P)).
# The back-neighbor magnitude numerators m2/m3 keep a |P| term that does not

Owner

make this a one-line comment (e.g. # mask the first row/column gradient terms that don't exist)

Collaborator Author

Done

Comment thread ehtim/imaging/imager_backend.py Outdated
works for any pol mode / ttype the backend supports. jax is imported lazily so
`import ehtim` stays jax-free.

The only traced argument is x; everything else is captured in the closure, so

Owner

this last part of the docstring is a bit jargony and i don't think necessary -- i would remove or rephrase

Owner

and i'd add some more one-line explanatory comments throughout for non jax experts to be able to follow what's going on here

Collaborator Author

Cleaned up, dropped the dense paragraphs, reworded the confusing NFFTInfo line, and added short comments through the body

Comment thread ehtim/imaging/imager_backend.py Outdated
a = jnp.asarray(a)
return jax.device_put(a, device) if device is not None else a

# A non-array A (the nfft NFFTInfo) is left as a host object for the kernel.

Owner

this comment in particular is a bit confusing

Collaborator Author

fixed it

rohandahale pushed a commit that referenced this pull request Jun 11, 2026
…, cleanups

- Route the three pol nfft value chisqs (chisq_p/m/vvis_nfft) through
  nufft2_backend so pol+nfft is traceable under jax (the gap Andrew flagged);
  add obs_pol_nfft fixture + test_pol_nfft_*[IP,IV] parity tests. Value chisqs
  are now uniformly backend-agnostic; gradient functions stay numpy.
- Move nufft2_backend to obs_helpers.py next to NFFTInfo/FINUFFTPlan; note plan
  is numpy-only. Explicit else in nufft2_backend + embed; unify embed/embed_imarr
  on/off index syntax.
- Delete the dead mfrac_max>1 / vfrac_max>1 guards in mcv/vcv.
- Trim comments: drop make_p_image NOTE, collapse reggrad_ptv boundary comment,
  de-jargon make_objective_jax docstring; add jax_device docstring examples.
- Warn when jax_enable_x64 is off in make_objective_jax (float32 degrades grad).
@rohandahale rohandahale force-pushed the feature/jax-objective branch from 5cd669b to 0482c7c Compare June 11, 2026 18:22
@rohandahale

Collaborator Author

Thanks @achael this was really useful!

1. np. vs xp. Every bare np. you flagged is on something that doesn't depend on the image we're solving for — coordinate grids from meshgrid, or scalar arithmetic like np.log(flux/npix) where flux/npix is just a Python float. None of that gets traced. And when one of these constant numpy arrays multiplies the traced image (e.g. imvec*xx), jax folds it in as a constant on its own, so switching to xp. wouldn't get anything — for the meshgrid it'd actually be worse, since xp.meshgrid would force a device array where a baked-in constant is exactly what we want. The one spot where it matters, the log of the image itself, already uses xp.log(imvec). So the bare np.s are all intentional.

2. chisq vs grad conversion. Good catch, it really was inconsistent. The rule I'm following is: the value chisqs are backend-agnostic (jax autodiffs them), and the gradient functions stay pure numpy (under jax, autodiff replaces them entirely, so they're never traced). What you spotted was the pol nfft values: chisq_p_nfft / chisq_m_nfft / chisq_vvis_nfft were still driving the finufft plan directly, which jax can't trace through. They now go through nufft2_backend like the Stokes-I nfft kernels and the non-nfft pol ones. That path actually didn't work under jax at all before this, so I added `test_pol_nfft_*[IP,IV]`; they pass on the GPU.

3. Trimmed the comments, cut the jargon out of the make_objective_jax docstring, and added plain one-liners.
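
Point 1 above (image-independent constants stay plain NumPy, image-dependent operations go through `xp`) can be sketched as follows. This is an illustrative toy, not ehtim code: `make_objective`, the centroid-plus-entropy objective, and the `xp` keyword are assumptions made for the example.

```python
import numpy as np


def make_objective(nx, ny, flux, xp=np):
    """Build a toy objective; xp is the array namespace (numpy by default)."""
    # Image-independent constants: built once with NumPy, never traced.
    x = np.arange(nx, dtype=float) - nx // 2
    y = np.arange(ny, dtype=float) - ny // 2
    xx, _ = np.meshgrid(x, y)
    xx = xx.ravel()
    log_mean = np.log(flux / (nx * ny))     # scalar arithmetic on plain floats

    def objective(imvec):
        # Image-dependent operations go through xp so a tracer can follow them.
        centroid_x = xp.sum(imvec * xx) / xp.sum(imvec)
        entropy = xp.sum(imvec * (xp.log(imvec) - log_mean))
        return centroid_x**2 + entropy

    return objective


objective = make_objective(4, 4, flux=2.0)
uniform = np.full(16, 2.0 / 16)             # flat image carrying the full flux
value = objective(uniform)
```

Under jax the same factory would presumably be called with `xp=jax.numpy` and the returned function wrapped in `jax.grad`; the NumPy `xx` and `log_mean` then enter the trace as constants, which is the behavior described in point 1.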

@achael achael left a comment

Owner

looks great! I misread some of the changes in the git diff before, which led to the wrong comments about function coverage.

My only remaining comments are on test_objective_jax now. I don't think the current version covers all Stokes I chisq/regularizer functions, and the structure is not parallel between stokes I and polarization. It might be useful just to have Claude rewrite this whole module now that the jax functions are fixed, since probably some of the current non-parallel structure came together as the code was being developed.

Here is what I think the full jax test module should contain:

  • coverage of ['direct','nfft'] transforms -- already done
  • coverage of different pairs of data_terms and reg_terms that, taken together, exercise all of the regularizers and chisq functions
  • coverage of different imaging modes - ['I','IP','P','V','IV','IPV'] that together exercise the different transform functions ['polcv','mcv','vcv']
  • uniform comparison to analytic and finite diff gradients in all cases. We should xfail the one fixed by the pending bugfix PR, and use epsilon_tv != 0 uniformly between pol and stokes I gradients to avoid singularities
  • similar coverage for standard Stokes I multifrequency chisqs, regularizers -- with a note that we need to add real multifreq synthetic data to the test and that multifreq polarization coverage needs to be added.

Finally, I think there should also be a simple test on equality between the numpy/jax branches of nufft2_backend (this is already exercised by the chisqs, but I think that it's useful to have a dedicated test).
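
One way to keep the "covers everything in the union" requirement honest is a small coverage check next to the parametrized cases. The sketch below is only an illustration: the case list, the weights, and the helper name `uncovered` are hypothetical, and `ALL_REG_TERMS` is an illustrative subset rather than the package's actual registry.

```python
# Data terms named in this thread; the regularizer set is an illustrative subset.
ALL_DATA_TERMS = {"vis", "amp", "bs", "cphase", "camp", "logcamp"}
ALL_REG_TERMS = {"simple", "tv", "tv2", "l1", "flux", "cm"}

# Hypothetical (data_term, reg_term) combinations; the weights are placeholders.
CASES = [
    {"data_term": {"vis": 100}, "reg_term": {"simple": 1, "flux": 100}},
    {"data_term": {"amp": 100, "cphase": 100}, "reg_term": {"tv": 1, "cm": 50}},
    {"data_term": {"bs": 100}, "reg_term": {"tv2": 1}},
    {"data_term": {"camp": 100, "logcamp": 50}, "reg_term": {"l1": 1}},
]


def uncovered(cases, key, required):
    """Return the required terms that no case exercises."""
    covered = set()
    for case in cases:
        covered |= set(case[key])
    return required - covered


missing_data = uncovered(CASES, "data_term", ALL_DATA_TERMS)
missing_regs = uncovered(CASES, "reg_term", ALL_REG_TERMS)
```

The same `CASES` list could then be handed to `pytest.mark.parametrize`, with one extra test asserting that both `missing_*` sets are empty so a newly added term cannot silently go untested.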

out = xp.stack((imarr[0], rho, imarr[2], psi))
return out

def vcv_r(imarr):

Owner

i see, i was misreading which function these lines were in in the github diff. Makes sense; we should just keep in mind this asymmetric jaxification of the _r functions in case they ever do end up being used in an optimization.

Comment thread ehtim/imaging/pol_imager_utils.py
Comment thread tests/test_objective_jax.py Outdated
NFFT_EPS = 1e-12
NFFT_GRAD_RTOL = 1e-7
GRAD_ATOL = 1e-12
# Pol grad tests validate jax.grad against central FD (the ground truth), NOT the

Owner

i think it would be good to add in the tests against the analytic gradients and xfail the failing vvis one until we patch in the recent fix.

Isn't the gradient singularity for vtv/ptv also present in normal tv when epsilon_tv=0, as it is by default? I don't see why that should prevent us from testing with finite epsilon.

Comment thread tests/test_objective_jax.py Outdated
# Stokes-I reg only here; pol regs are exercised below.
imgr = eh.imager.Imager(
obs_pol_direct, gauss_prior, prior_im=gauss_prior, flux=gauss_im_pol.total_flux(),
data_term={"amp": 100, "pvis": 100, "m": 50},

Owner

pull data_term and reg_term out to top as in stokes I imager

Owner

and make sure the test exercises all of the pol regularizers and chisqs

Comment thread tests/test_objective_jax.py
Comment thread tests/test_objective_jax.py Outdated
assert np.all(np.isfinite(np.asarray(g)))


def test_ptv_epsilon_tv_removes_singularity():

Owner

i don't think this should be in the jax-objective test; probably in test-regularizers instead?

Comment thread tests/test_objective_jax.py Outdated
RNG_SEED = 4
PERTURB = 0.10

DATA_TERM = {"amp": 100, "cphase": 100, "logcamp": 50}

Owner

this doesn't exercise all of the data terms and regularizer terms -- it would be good to loop over different data_term and reg_term combos that in the union cover all of the individual data_term and reg_term functions.

Comment thread tests/test_objective_jax.py
Comment thread tests/test_objective_jax.py Outdated

# nfft pol: routes chisq_{p,m,vvis}_nfft through nufft2_backend (jax_finufft).
# IP exercises chisq_p_nfft (pvis) + chisq_m_nfft (m); IV exercises chisq_vvis_nfft.
_POL_NFFT_CASES = [

Owner

add an IPV test to cover full-stokes imaging and the polcv transform.

Comment thread tests/test_regularizers.py Outdated
f"{rtype} max fractional gradient diff = {max_frac:.6f} (tol={max_tol})"
)

def test_reggrad_ptv_matches_fd_on_boundary(self):

Owner

adapt this test to run over all tv functions; tv, tv2, poltv, vtv?

@achael

achael commented Jun 12, 2026

Owner

I merged the IV gradient fix to dev-backend in #299 so in the new tests of the jax vs analytic pol gradients we shouldn't need to xfail anything

@rohandahale

rohandahale commented Jun 13, 2026

Collaborator Author

Thanks @achael! I rewrote test_objective_jax along the lines you laid out. The broader coverage caught more gradient bugs in the same family as the IV one (#296).

Every case now checks the same three things uniformly: value parity (numpy vs jax), analytic-gradient parity (jax.grad vs the hand-written objgrad), and finite differences. The parametrized cases together exercise:

  • both transforms (direct + nfft),
  • all the Stokes-I chisqs (vis/amp/bs/cphase/camp/logcamp) and regularizers,
  • all six pol modes: IP/P (mcv), IV/V (vcv), IQUV/IPV (polcv) and all the pol chisqs (pvis/m/vvis) and pol regs,
  • a multifrequency Stokes-I case.

Two things made it actually catch bugs:

  1. An asymmetric image. A centered, circular source zeros gradient components by symmetry, so the analytic and autodiff gradients can disagree and you'd never see it. Switching to an off-center, elongated source exposed the disagreements.
  2. Comparing analytic vs autodiff vs FD together: three independent gradients
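
The symmetry argument in item 1 can be reproduced with a toy 1-D example. Everything here is hypothetical (the objective, `grad_buggy`, and the two test images are made up for illustration): the "buggy" gradient drops a first-moment term that happens to vanish for a centered, symmetric image, so a finite-difference check passes on that image and fails on an off-center one.

```python
import numpy as np

x = np.arange(-3, 4, dtype=float)           # pixel coordinates, symmetric about 0


def objective(im):
    """Toy objective: squared first moment plus an L2 term."""
    return np.sum(x * im) ** 2 + np.sum(im**2)


def grad_correct(im):
    return 2.0 * x * np.sum(x * im) + 2.0 * im


def grad_buggy(im):
    # Drops the first-moment term; invisible whenever sum(x * im) == 0.
    return 2.0 * im


def fd_grad(f, im, h=1e-6):
    """Central finite-difference gradient."""
    g = np.zeros_like(im)
    for i in range(im.size):
        step = np.zeros_like(im)
        step[i] = h
        g[i] = (f(im + step) - f(im - step)) / (2.0 * h)
    return g


symmetric = np.exp(-0.5 * x**2)                       # centered, even source
asymmetric = np.exp(-0.5 * ((x - 1.0) / 1.5) ** 2)    # off-center, wider source

bug_hidden = np.allclose(grad_buggy(symmetric), fd_grad(objective, symmetric), atol=1e-6)
bug_exposed = not np.allclose(grad_buggy(asymmetric), fd_grad(objective, asymmetric), atol=1e-6)
correct_ok = np.allclose(grad_correct(asymmetric), fd_grad(objective, asymmetric), atol=1e-6)
```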

The bugs it found (all in the same family as #296):
The mcv/vcv transforms have non-diagonal Jacobians (mprime couples to both rho and psi, vprime likewise) so the gradient has to give both physical-slot gradients that the chain rule consumes. Several places used the term on the wrong pol_solve index and dropped it (exactly like chisqgrad_vvis did for IV):

  • chisqgrad_p / chisqgrad_m dropped grad_psi for pol='IP' --> wrong IP/P gradient
  • V-pol regs (vflux/l1v/l2v/vtv/vtv2) dropped grad_rho for IV/V.
  • linear-pol regs (msimple/hw/ptv) dropped grad_psi for IP/P.

The fix in each is the mirror of #296: compute the needed term whenever the coupled variable is solved (pol_solve[1] or pol_solve[3]). These affect numpy pol imaging on the main branch (IP especially); symmetric test images had been hiding them. They are in this PR, but we can move them to dev and main if you want.
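
The chain-rule point can be shown with a toy transform in which one solver variable feeds two physical slots. The functions below are hypothetical and are not the actual mcv/vcv formulas; they only illustrate why dropping one physical-slot gradient gives a wrong solver-space gradient.

```python
import math


def transform(u):
    """Toy map from one solver variable to two physical quantities."""
    rho = 0.5 * (1.0 + math.tanh(u))
    psi = 0.3 * math.tanh(u)
    return rho, psi


def transform_jac(u):
    sech2 = 1.0 - math.tanh(u) ** 2
    return 0.5 * sech2, 0.3 * sech2         # d rho / du, d psi / du


def loss(rho, psi):
    return (rho - 0.4) ** 2 + 3.0 * (psi - 0.1) ** 2


def loss_grads(rho, psi):
    return 2.0 * (rho - 0.4), 6.0 * (psi - 0.1)


def grad_u(u, include_psi=True):
    """dL/du = dL/drho * drho/du + dL/dpsi * dpsi/du."""
    rho, psi = transform(u)
    dl_drho, dl_dpsi = loss_grads(rho, psi)
    drho_du, dpsi_du = transform_jac(u)
    g = dl_drho * drho_du
    if include_psi:                         # the term the buggy kernels dropped
        g += dl_dpsi * dpsi_du
    return g


def fd_u(u, h=1e-6):
    """Central finite difference of the loss through the transform."""
    return (loss(*transform(u + h)) - loss(*transform(u - h))) / (2.0 * h)


u0 = 0.7
full_error = abs(grad_u(u0) - fd_u(u0))
dropped_error = abs(grad_u(u0, include_psi=False) - fd_u(u0))
```

With both terms the analytic gradient agrees with finite differences to roundoff; without the second term the mismatch is of the same order as the gradient itself.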

  • NaN-safe mcv/vcv transforms. jax.grad of the pol transform returned NaN at zero-polarization pixels (sqrt(0) / arcsin(±1) in the backward pass), so jax pol imaging would NaN on any realistic image. I used where to fix it. numpy values are unchanged.
  • IPV was in POLARIZATION_MODES but the validator rejected it. It now works (identical to IQUV / polcv).
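
The NaN in the backward pass and the `where` fix follow a common jax pattern, sketched here on a plain norm rather than on the actual pol transforms. The function names are hypothetical, and jax is imported lazily so the snippet can be loaded without it.

```python
def naive_norm(a, b):
    import jax.numpy as jnp  # lazy import, so the snippet loads without jax
    return jnp.sqrt(a**2 + b**2)


def safe_norm(a, b):
    import jax.numpy as jnp
    sq = a**2 + b**2
    nonzero = sq > 0
    # Inner where: feed sqrt a harmless value wherever sq == 0 ...
    safe_sq = jnp.where(nonzero, sq, 1.0)
    # Outer where: ... then put the true value (0) back at those pixels.
    return jnp.where(nonzero, jnp.sqrt(safe_sq), 0.0)


def grad_of_sum(norm_fn, a, b):
    """Gradient of sum(norm_fn(a, b)) with respect to a."""
    import jax
    import jax.numpy as jnp
    return jax.grad(lambda a_: jnp.sum(norm_fn(a_, b)))(a)
```

At a pixel where both inputs are zero, the naive version multiplies an infinite sqrt derivative by a zero inner derivative and yields NaN, while the masked version returns a finite gradient and leaves the forward values unchanged. A single `where` around the output is not enough, because the NaN is produced inside the untaken branch before the mask is applied.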

test_regularizers: I generalized the ptv boundary FD test over all the TV funcs (tv/tv2/ptv/vtv/vtv2), and added a parametrized epsilon_tv-removes-singularity test (tv/ptv/vtv). The image_at_freq test moved out of the jax module; it's covered there now by a proper multifrequency-imager case instead.

@achael achael mentioned this pull request Jun 16, 2026
achael added a commit that referenced this pull request Jun 16, 2026
* Add physical_grad_slots helper

Maps the Stokes DOF mask to the physical gradout slots the chisq/reg
kernels must fill, centralizing the mcv/vcv cross-coupling that mirrors
transform_gradients' Jacobian sparsity. Not yet wired in. + unit tests.

* Wire physical_grad_slots into chisq and reg gradient dicts

Feed the cross-coupling-aware mask to the pol gradient kernels in both
compute_chisqgrad_dict and compute_reggrad_dict. Behavior-identical for
now (kernels still carry the or-patches). Guard physical_grad_slots
against sub-4-wide single-pol masks (Stokes-I carries 'mcv' inertly).
+ regression test.

* Revert vvis kernels to diagonal pol_solve gating

The mcv/vcv cross-coupling now lives in physical_grad_slots, so drop the
'or pol_solve[3]' patches (#296) in chisqgrad_vvis / chisqgrad_vvis_nfft;
each physical slot keys on its own bit again. Note in each pol chisqgrad
docstring that pol_solve flags required physical gradients, not DOFs.

* Fix reggrad_ptv first-row/col boundary masking + epsilon_tv

Zero the back-neighbor (m2/m3) terms on the first row/column in
reggrad_ptv slots 0/1/3 (the back-neighbor is the zero pad), matching
reggrad_vtv/reggrad_tv. Pre-fix the whole first row+col of those slots
was wrong (corner ~4x off vs FD). Add epsilon_tv to reg_ptv/reggrad_ptv
denominators (default 0, byte-identical) for #295 parity. Note pol_solve
= physical-gradient slots in the 8 pol reggrad docstrings.

Add full-grid boundary FD regression tests for ptv, vtv, and Stokes-I tv.

* Note pol_solve semantics in polchisqgrad docstring

polchisqgrad is a legacy shim (parity tests only); document that its
pol_solve is a physical-gradient mask, not a raw DOF mask.

* Drive pol regularizer FD with all four physical slots

_pol_solve_for now returns [1,1,1,1] so the previously-blind cross-
coupling slots are FD-checked: reggrad_ptv psi (3), reggrad_vflux/l1v/
l2v/vtv rho (1), and slot 0 for every pol reg. Proves the reg-grad
slots are individually correct against finite differences.

* Add pol chisq FD + cross-ttype tests in test_chisquared.py

New pol coverage in its final-home file: TestPolChisqGradFD checks
chisqgrad_p/m/vvis against finite differences of the chisq value in all
four physical slots (pol_solve=[1,1,1,1]) for direct+nfft, asserting
vvis slot 2 (EVPA) is identically zero. TestPolChisq{,Grad}Consistency
check direct-vs-nfft agreement. Closes the m / p-slot-3 blind spots.

* Add parametrized pol objective-FD sampling the polarization DOF block

TestObjectiveGradPolarimetricFD checks objgrad vs FD for IP/IV/IQUV x
{direct,nfft}, with each case bundling its pol data terms + a pol reg, so
both the chisq and reg gradient paths through physical_grad_slots are
exercised. Samples the pol DOF block (past the Stokes-I block), where the
mcv/vcv cross-coupling lives -- the existing global-sampling FD tests
missed it (the dropped IP slot-3 term is ~4% off FD at V=0.02*I, ~430% at
V=0.2*I). Comments out the now-subsumed test_fd_matches_analytic_polarimetric
(backend) and test_iv_gradient_matches_finite_difference (e2e).

* Use an asymmetric image for chisq/regularizer/gradient FD fixtures

Add make_asym_image (broad offset/elongated/rotated double-Gaussian) and
switch the Stokes-I FD fixtures (chisq_setup, reg_setup, mfreg_setup,
grad_setup) to it. Breaking the reflection/rotation/x<->y symmetry of the
centered Gaussian surfaces boundary/axis-ordering bugs a symmetric image
hides. Blobs kept broad (grid-filling, no dead pixels) so the |.|-kink TV
gradients stay FD-well-conditioned at epsilon_tv=0; all tolerances unchanged.

* Use asymmetric + spatially-varying pol in pol FD test fixtures

chisq_setup_pol and a new asym_pol_setup build on make_asym_image and use
add_random_pol (ccorr>0) so EVPA, vfrac, rho, and psi all vary spatially
instead of a constant pol fraction. polreg_setup switches its Stokes I to
the asymmetric image (keeping the per-pixel pol jitter that keeps TV
denominators non-degenerate). TestObjectiveGradPolarimetricFD now uses the
structured-pol obs.

Widen the pol chisq FD check to a median+max split (median 1e-5, max 1e-3):
the structured-pol imcur has sharper local curvature, so 2nd-order FD
truncation pushes a few small-gradient pixels to ~2.6e-4 -- well below any
real pol-gradient bug (%-level), which the tight median still catches.

* Comment cleanup + epsilon_tv consistency in pol_imager_utils

Manual review pass: per-slot dR/dX labels, docstrings on the reg kernels,
a module-level CONVENTIONS block (imarr = [I, rho, phi=2chi, psi]), and
removal of stale TODOs. Two behavior touches, both byte-identical at the
defaults:
- reg_vtv / reggrad_vtv now honor epsilon_tv (kwargs, default 0) like the
  ptv pair, instead of the value ignoring it while the grad pinned it to 0.
- reggrad_ptv masks the chi-slot back-neighbor terms (c2/c3) too, for
  uniformity (they already self-zero at the pad).
Plus an mcv_r exception-message fix and ruff-clean whitespace.

* Comment cleanup in imager_utils (no behavior change)

Manual review pass: docstrings on the Stokes-I reg kernels, per-block
comments, 'fourier/transform matrices' labels on the diag Amatrices
unpacking, and removal of dead commented-out systematic-noise code in the
bispectrum data functions (the intent is now documented in
apply_systematic_noise_snrcut). Purely cosmetic; ruff-clean.

* fixed lint errors in test_regularizers and test_chisquared
@rohandahale

Collaborator Author

@achael This PR is now ready to be merged! I resolved all conflicts after the #306 fix, but please do a quick check before merging.

@achael achael left a comment

Owner

Thanks @rohandahale, looks good, ready to merge!

3 notes for the next PR:

  1. looks like EPSILON is a global constant in the spectral index TV -- we should change this to epsilon_tv to match the other functions.
  2. I think _rho_psi_safe should be promoted to a helper function (like rho_psi_from_mfrac_vfrac or something); it could have utility outside the specific pol transforms here
  3. We should consider reviewing the tests and consolidating the gradient tests in particular soon. They are scattered in different files and I think there is some redundancy & non-uniformity in how they test.

impad = xp.pad(im, 1, mode='constant', constant_values=0)
im_l1 = xp.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
im_l2 = xp.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
return xp.sum(xp.sqrt(xp.abs(im_l1 - im)**2 + xp.abs(im_l2 - im)**2 + EPSILON)) / norm

Owner

just noticed this EPSILON is fixed -- should make it epsilon_tv as well I think, later.

@achael achael merged commit 05a4529 into dev-backend Jun 18, 2026
6 checks passed
rohandahale added a commit that referenced this pull request Jun 22, 2026
Address #295 review: epsilon_tv, public rho_psi helper, gradient-test consolidation

2 participants