diff --git a/ehtim/imaging/multifreq_imager_utils.py b/ehtim/imaging/multifreq_imager_utils.py index ecaadecc..678e0f90 100644 --- a/ehtim/imaging/multifreq_imager_utils.py +++ b/ehtim/imaging/multifreq_imager_utils.py @@ -22,7 +22,6 @@ from ehtim.backends import array_namespace NORM_REGULARIZER = True -EPSILON = 1.e-12 DD_RHOPOL = 1 # transform paramter for multifrequency polarization fraction ################################################################################################## # multifrequency transformations @@ -249,12 +248,13 @@ def reg_tv_spec(imvec, mask, **kwargs): imvec = embed(imvec, mask, clipfloor=0, randomfloor=False) nx, ny, psize = kwargs['xdim'], kwargs['ydim'], kwargs['psize'] beam_size = kwargs.get('beam_size') or psize + epsilon = kwargs.get('epsilon_tv', 0.) norm = len(imvec) * psize / beam_size if kwargs.get('norm_reg', True) else 1 im = imvec.reshape(ny, nx) impad = xp.pad(im, 1, mode='constant', constant_values=0) im_l1 = xp.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] im_l2 = xp.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] - return xp.sum(xp.sqrt(xp.abs(im_l1 - im)**2 + xp.abs(im_l2 - im)**2 + EPSILON)) / norm + return xp.sum(xp.sqrt(xp.abs(im_l1 - im)**2 + xp.abs(im_l2 - im)**2 + epsilon)) / norm def reggrad_tv_spec(imvec, mask, **kwargs): @@ -263,6 +263,7 @@ def reggrad_tv_spec(imvec, mask, **kwargs): imvec = embed(imvec, mask, clipfloor=0, randomfloor=False) nx, ny, psize = kwargs['xdim'], kwargs['ydim'], kwargs['psize'] beam_size = kwargs.get('beam_size') or psize + epsilon = kwargs.get('epsilon_tv', 0.) norm = len(imvec) * psize / beam_size if kwargs.get('norm_reg', True) else 1 im = imvec.reshape(ny, nx) impad = np.pad(im, 1, mode='constant', constant_values=0) @@ -272,9 +273,9 @@ def reggrad_tv_spec(imvec, mask, **kwargs): im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1] im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1] - g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + EPSILON) - g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + EPSILON) - g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + EPSILON) + g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + epsilon) + g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + epsilon) + g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + epsilon) mask1 = np.zeros(im.shape) mask2 = np.zeros(im.shape) mask1[0, :] = 1 diff --git a/ehtim/imaging/pol_imager_utils.py b/ehtim/imaging/pol_imager_utils.py index a3efc2f8..ad7fc350 100644 --- a/ehtim/imaging/pol_imager_utils.py +++ b/ehtim/imaging/pol_imager_utils.py @@ -198,10 +198,12 @@ def polcv_grad(imarr, gradarr): return out -def _rho_psi_safe(xp, mfrac, vfrac): - """rho=sqrt(mfrac^2+vfrac^2), psi=arcsin(vfrac/rho), guarded so jax.grad stays finite - at zero-polarization pixels (sqrt(0) and arcsin(+/-1) singularities). Values are - unchanged where rho>0 and |vfrac/rho|<1. +def rho_psi_from_mfrac_vfrac(xp, mfrac, vfrac): + """Polar coordinates (rho, psi) of the (mfrac, vfrac) polarization vector. + + rho = sqrt(mfrac^2 + vfrac^2), psi = arcsin(vfrac/rho), guarded so jax.grad + stays finite at zero-polarization pixels (the sqrt(0) and arcsin(+/-1) + singularities). Values are unchanged where rho > 0 and |vfrac/rho| < 1. """ r2 = mfrac**2 + vfrac**2 rho = xp.where(r2 > 0, xp.sqrt(xp.where(r2 > 0, r2, 1.0)), 0.0) @@ -224,7 +226,7 @@ def mcv(imarr): mfrac_prime = imarr[1] mfrac = mfrac_max*(0.5 + xp.arctan(mfrac_prime/TANWIDTH_M)/np.pi) - rho, psi = _rho_psi_safe(xp, mfrac, vfrac) + rho, psi = rho_psi_from_mfrac_vfrac(xp, mfrac, vfrac) out = xp.stack((imarr[0], rho, imarr[2], psi)) return out @@ -304,7 +306,7 @@ def vcv(imarr): vfrac_prime = imarr[3] vfrac = 2*vfrac_max*xp.arctan(vfrac_prime/TANWIDTH_V)/np.pi - rho, psi = _rho_psi_safe(xp, mfrac, vfrac) + rho, psi = rho_psi_from_mfrac_vfrac(xp, mfrac, vfrac) out = xp.stack((imarr[0], rho, imarr[2], psi)) return out diff --git a/tests/test_chisquared.py b/tests/test_chisquared.py index 8dd411ae..dd6228a0 100644 --- a/tests/test_chisquared.py +++ b/tests/test_chisquared.py @@ -1,8 +1,10 @@ -"""Tests for chi-squared consistency across transform types (direct, fast, nfft). +"""Tests for chi-squared cross-transform consistency (direct, fast, nfft). -Verifies that chi-squared values and gradients agree between DFT, FFT, and NFFT -for all standard data types. All tests use a 32x48 image so xdim != ydim -exercises the rectangular-image code paths (rect subsumes square). +Verifies that chi-squared values and gradients AGREE across DFT, FFT, and NFFT for all +data types -- a cross-check, not a standalone correctness proof: the direct path is the +reference, finite-difference-validated in test_gradients.py (the canonical FD suite), so a +tight direct-vs-nfft gradient bound pins nfft gradient correctness. All tests use a 32x48 +image so xdim != ydim exercises the rectangular-image code paths (rect subsumes square). """ import numpy as np @@ -14,7 +16,6 @@ MfConfig, compute_chisq_term, compute_chisqdata_term, - compute_chisqgrad_term, ) from ehtim.imaging.imager_utils import chisq, chisqdata, chisqgrad @@ -33,20 +34,20 @@ # Transform type pairs to compare TTYPE_PAIRS = [("direct", "fast"), ("direct", "nfft"), ("nfft", "fast")] -# NFFT max gradient tolerance is much wider at this resolution -GRAD_MAX_TOL_NFFT = 10.0 +# Cross-transform gradient agreement. The direct DFT path is the trusted reference -- its +# gradients are finite-difference-validated in test_gradients.py -- so a tight direct-vs-nfft +# bound effectively pins nfft gradient correctness. Pairs that include the gridded 'fast' FFT +# are limited by its interpolation accuracy, not by a gradient bug, so they stay looser. +GRAD_MAX_TOL_DIRECT_NFFT = 1e-2 +GRAD_MAX_TOL = 0.25 # any pair containing 'fast' -# Diagonalized closures orthogonalize per-timestamp covariance, which can -# amplify direct-vs-fast tail outliers in test_grad_max_frac_diff. Bump the -# max-fractional-diff tolerance for these dtypes only; median tolerance and -# chi-squared tolerance unchanged. +# Diagonalized closures orthogonalize per-timestamp covariance, which amplifies tail outliers. GRAD_MAX_TOL_DIAG = 0.5 DIAG_DTYPES = {"cphase_diag", "logcamp_diag"} # Tolerances (calibrated on 32x48 synthetic Gaussian) CHISQ_FRAC_TOL = 0.01 GRAD_MEDIAN_TOL = 0.05 -GRAD_MAX_TOL = 0.25 # --------------------------------------------------------------------------- # chisqdata optional parameters (explicit for tracking across refactors) @@ -175,12 +176,10 @@ def test_grad_median_frac_diff(self, chisq_setup, dtype, pair): @pytest.mark.parametrize("pair", TTYPE_PAIRS, ids=lambda p: f"{p[0]}-{p[1]}") def test_grad_max_frac_diff(self, chisq_setup, dtype, pair): _, max_frac = _gradient_comparison(chisq_setup, dtype, pair) - if "nfft" in pair: - tol = GRAD_MAX_TOL_NFFT - elif dtype in DIAG_DTYPES: - tol = GRAD_MAX_TOL_DIAG + if pair == ("direct", "nfft"): + tol = GRAD_MAX_TOL_DIAG if dtype in DIAG_DTYPES else GRAD_MAX_TOL_DIRECT_NFFT else: - tol = GRAD_MAX_TOL + tol = GRAD_MAX_TOL_DIAG if dtype in DIAG_DTYPES else GRAD_MAX_TOL assert max_frac < tol, ( f"{dtype} {pair[0]}-{pair[1]}: grad max frac diff = {max_frac:.6f}" ) @@ -210,14 +209,6 @@ def _gradient_comparison(chisq_setup, dtype, pair): # Polarimetric chi-squared (pvis / m / vvis). Pol has no 'fast' ttype. # --------------------------------------------------------------------------- POL_DATATERMS = ["pvis", "m", "vvis"] -POL_FD_REL = 1e-6 -POL_FD_FLOOR = 1e-9 -# fractional FD vs analytic (same ttype, self-consistent). median is tight (any -# systematic gradient error blows it up); max is looser to tolerate 2nd-order FD -# truncation outliers at small-gradient pixels of the structured-pol imcur. Real -# pol-gradient bugs are %-level (e.g. the mcv slot-3 coupling), far above these. -POL_GRAD_FD_MEDIAN_TOL = 1e-5 -POL_GRAD_FD_MAX_TOL = 1e-3 POL_CHISQ_FRAC_TOL = 0.01 # direct-vs-nfft value agreement @@ -283,65 +274,3 @@ def test_chisq_values(self, chisq_setup_pol, dtype): vals[tt] = compute_chisq_term(imcur, dtype, A, data, sigma, ttype=tt, mask=mask) frac = abs((vals["direct"] - vals["nfft"]) / abs(vals["direct"])) assert frac < POL_CHISQ_FRAC_TOL, f"{dtype}: chisq frac diff = {frac:.6f}" - - -class TestPolChisqGradConsistency: - """Pol chi-squared gradients agree between direct and nfft.""" - - @pytest.mark.parametrize("dtype", POL_DATATERMS) - def test_grad_values(self, chisq_setup_pol, dtype): - obs, prior = chisq_setup_pol["obs"], chisq_setup_pol["prior"] - mask, imcur = chisq_setup_pol["mask"], chisq_setup_pol["imcur"] - grads = {} - for tt in ("direct", "nfft"): - A, data, sigma = _pol_data_tuple(obs, prior, mask, dtype, tt) - grads[tt] = compute_chisqgrad_term( - imcur, dtype, A, data, sigma, ttype=tt, mask=mask, - pol_solve=np.array([1, 1, 1, 1])) - a, b = grads["direct"], grads["nfft"] - floor = np.min(np.abs(a)) * 1e-20 + 1e-100 - frac = np.abs((a - b) / (np.abs(a) + floor)) - assert np.median(frac) < GRAD_MEDIAN_TOL - assert np.max(frac) < GRAD_MAX_TOL_NFFT - - -class TestPolChisqGradFD: - """Pol chisqgrad matches finite differences of the chisq value, in all four - physical slots (driven with pol_solve=[1,1,1,1]). vvis slot 2 (EVPA) is - asserted identically zero -- Stokes V is independent of EVPA -- and its FD - is also zero, so the all-slot loop covers it consistently.""" - - @pytest.mark.parametrize("dtype", POL_DATATERMS) - @pytest.mark.parametrize("ttype", ["direct", "nfft"]) - def test_grad_matches_fd(self, chisq_setup_pol, dtype, ttype): - obs, prior = chisq_setup_pol["obs"], chisq_setup_pol["prior"] - mask, imcur = chisq_setup_pol["mask"], chisq_setup_pol["imcur"] - A, data, sigma = _pol_data_tuple(obs, prior, mask, dtype, ttype) - grad = compute_chisqgrad_term( - imcur, dtype, A, data, sigma, ttype=ttype, mask=mask, - pol_solve=np.array([1, 1, 1, 1])) - - if dtype == "vvis": # V is independent of EVPA - np.testing.assert_array_equal(grad[2], 0.0) - - rng = np.random.default_rng(RNG_SEED) - n = imcur.shape[1] - sample = rng.choice(n, size=min(30, n), replace=False) - frac = [] - for slot in range(4): - for j in sample: - dx = max(POL_FD_REL * abs(imcur[slot, j]), POL_FD_FLOOR) - ip = imcur.copy() - ip[slot, j] += dx - im_ = imcur.copy() - im_[slot, j] -= dx - fd = (compute_chisq_term(ip, dtype, A, data, sigma, ttype=ttype, mask=mask) - - compute_chisq_term(im_, dtype, A, data, sigma, ttype=ttype, mask=mask)) / (2 * dx) - ex = grad[slot, j] - denom = max(abs(ex), abs(fd), POL_FD_FLOOR) - frac.append(abs(ex - fd) / denom) - frac = np.array(frac) - assert np.median(frac) < POL_GRAD_FD_MEDIAN_TOL, ( - f"{dtype} {ttype}: median frac diff = {np.median(frac):.4g}") - assert np.max(frac) < POL_GRAD_FD_MAX_TOL, ( - f"{dtype} {ttype}: max frac diff = {np.max(frac):.4g}") diff --git a/tests/test_gradients.py b/tests/test_gradients.py index d3d91301..a86e4ffc 100755 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,236 +1,408 @@ -"""Tests for analytic gradient correctness via numeric finite differences. - -Verifies that analytic chi-squared gradients and regularizer gradients -match numeric finite differences computed element-wise. All tests use a -32x48 image so xdim != ydim exercises the rectangular-image code paths. +"""Canonical finite-difference gradient tests for the numpy imaging backend. + +Every analytic gradient -- chi-squared data terms, regularizers, and the pol +change-of-variables transforms -- is checked against central finite differences +of its own scalar value, in pure numpy. The jax-autodiff parity check is the +separate second opinion in test_objective_jax.py; cross-transform value/gradient +consistency (direct vs fast vs nfft) lives in test_chisquared.py. + +Structure -- parallel Stokes-I / pol / spectral sections, term lists at the top: + + S1 chi^2 Stokes-I : vis amp logamp bs cphase cphase_diag camp logcamp logcamp_diag + S2 chi^2 pol : pvis m vvis + S3 reg Stokes-I : every name in imager_backend.REGULARIZERS + S4 reg pol : every name in REGULARIZERS_POL (all four physical slots) + S5 reg spectral : every name in REGULARIZERS_SPECTRAL + S6 transforms : mcv vcv polcv + +Methodology (uniform across every section): central differences (FD_EPS), +compared by max fractional error relative to the per-slot gradient scale -- robust +where the analytic gradient is ~0 (a real missing term gives an O(1) ratio, while +finite-difference noise at a near-zero component does not). Finite differences are +full-grid on a small image so the zero-pad boundary (row 0 / col 0), where the TV +neighbour-roll regularizer gradients live, is always covered. Pol cases run at +v != 0 AND m != 0 at every pixel -- the regime where the mcv/vcv cross-term bugs hid. """ - import numpy as np import pytest import ehtim as eh -import ehtim.imaging.imager_utils as iu +from ehtim.imaging.imager_backend import ( + REGULARIZERS, + REGULARIZERS_SPECTRAL, + ImagerConfig, + MfConfig, + compute_chisq_term, + compute_chisqdata_term, + compute_chisqgrad_term, + compute_regularizer_term, + compute_regularizergrad_term, +) from ehtim.imaging.imager_utils import chisq, chisqdata, chisqgrad +from ehtim.imaging.pol_imager_utils import ( + REGULARIZERS_POL, + mcv, + mcv_grad, + mcv_r, + polcv, + polcv_grad, + polcv_r, + vcv, + vcv_grad, + vcv_r, +) + +# --- finite-difference harness (one implementation for every section) --------- +FD_EPS = 1e-6 +FD_RTOL = 1e-3 # max fractional error vs the gradient scale; correct grads sit far below +NFFT_RTOL = 1e-2 # nfft adds finufft truncation on top of the central-difference floor +ABS_FLOOR = 1e-9 # |grad| below this counts as identically zero (e.g. the vvis EVPA slot) +REL_FLOOR = 1e-3 # ignore components below this fraction of the slot's gradient scale +SEED = 4 + +# observation parameters (match conftest.py) +TINT_SEC, TADV_SEC, TSTART_HR, TSTOP_HR, BW_HZ = 5, 600, 0, 24, 4e9 + + +def fd_grad(value_fn, x, eps=FD_EPS): + """Full-grid central-difference gradient of scalar value_fn over array x.""" + x = np.asarray(x, dtype=float) + g = np.zeros(x.shape) + for idx in np.ndindex(x.shape): + xp, xm = x.copy(), x.copy() + xp[idx] += eps + xm[idx] -= eps + g[idx] = (value_fn(xp) - value_fn(xm)) / (2 * eps) + return g + + +def assert_grad_close(analytic, fd, rtol=FD_RTOL, label="", allow_zero=False): + """Assert analytic == fd by max fractional error relative to the gradient scale. + + Each component is normalized by max(|analytic|, |fd|, REL_FLOOR*scale) so near-zero + components do not blow up the ratio while a real missing term gives an O(1) ratio. + + A finite-difference gradient at machine zero (scale below ABS_FLOOR) is rejected as a + vacuous test unless ``allow_zero`` -- a slot that is identically zero by construction + (e.g. the vvis EVPA slot, or a V regularizer's EVPA slot). When allowed, the analytic + slot must be ~0 too. Callers pair this with a per-case non-vacuousness check so a term + whose every slot collapses to zero still fails. + """ + analytic = np.asarray(analytic, dtype=float) + fd = np.asarray(fd, dtype=float) + scale = float(np.max(np.abs(fd))) if fd.size else 0.0 + if scale < ABS_FLOOR: + assert allow_zero, f"{label}: finite-difference gradient is ~0 -- vacuous test" + amax = float(np.max(np.abs(analytic))) if analytic.size else 0.0 + assert amax < ABS_FLOOR, f"{label}: expected ~0 gradient, got max|analytic|={amax:.2e}" + return + denom = np.maximum(np.maximum(np.abs(analytic), np.abs(fd)), REL_FLOOR * scale) + frac = np.abs(analytic - fd) / denom + assert np.max(frac) < rtol, ( + f"{label}: max frac err {np.max(frac):.2e} (median {np.median(frac):.2e})") -# Data types, regularizers, and transform types to test -DATATERMS = ["vis", "bs", "amp", "cphase", "camp", "logcamp"] -REGULARIZERS = ["simple", "gs", "l1w", "tv", "tv2"] -TTYPES = ["direct"] # expand to ["direct", "fast", "nfft"] to test other transforms - -# Diagonalized closure gradients, checked on the supported transforms. -# ('fast'/plain-FFT is omitted; that mode is slated for deprecation.) -DATATERMS_DIAG = ["cphase_diag", "logcamp_diag"] -TTYPES_DIAG = ["direct", "nfft"] - -# Tolerances (calibrated on 32x48 synthetic Gaussian with relative step size) -CHISQ_GRAD_MEDIAN_TOL = 0.001 -CHISQ_GRAD_MAX_TOL = 0.01 -REG_GRAD_MEDIAN_TOL = 0.01 -REG_GRAD_MAX_TOL = 0.10 - -# --------------------------------------------------------------------------- -# Numeric gradient parameters -# --------------------------------------------------------------------------- -N_GRAD_SAMPLES = 100 -RNG_SEED = 4 -GRAD_DX_REL = 1e-8 # relative step size per pixel -GRAD_DX_FLOOR = 1e-12 # absolute minimum step size - -# --------------------------------------------------------------------------- -# Regularizer optional parameters -# --------------------------------------------------------------------------- -BEAM_SIZE = 20.0 * eh.RADPERUAS -ALPHA_A = 5000.0 -EPSILON_TV = 0.0 - -# --------------------------------------------------------------------------- -# Observation parameters -# --------------------------------------------------------------------------- -TINT_SEC = 5 -TADV_SEC = 600 -TSTART_HR = 0 -TSTOP_HR = 24 -BW_HZ = 4e9 +def assert_nonvacuous(fd, label=""): + """A real gradient must be exercised somewhere -- guards against a setup that + silently zeros every component and turns the finite-difference check into a no-op.""" + assert float(np.max(np.abs(np.asarray(fd)))) > ABS_FLOOR, ( + f"{label}: gradient is ~0 everywhere -- the test exercises nothing") -@pytest.fixture(scope="module") -def grad_setup(eht_array, make_asym_image): - """Set up observation and test image from a 32x48 asymmetric image. - Offset double-Gaussian: xdim != ydim exercises the rectangular-image code - paths, and the broken symmetry + edge flux surface boundary/axis bugs a - centered Gaussian hides. +def _rtol(ttype): + return NFFT_RTOL if ttype == "nfft" else FD_RTOL + + +def _chisqdata_kwargs(ttype): + """chisqdata options, explicit so a default flip cannot silently move a tolerance. + + debias=False keeps raw amplitudes: debiasing floors sqrt(amp^2 - sigma^2) to 0 at + baselines where the source has resolved out, which makes logamp's log|V| singular. """ - im = make_asym_image(32, 48) - im.imvec = im.imvec * 2.0 / im.total_flux() # normalize to 2 Jy + kw = dict(systematic_noise=0.0, snrcut=0.0, debias=False, weighting="natural", + maxset=False, cp_uv_min=False, systematic_cphase_noise=0.0) + if ttype in ("fast", "nfft"): + kw.update(fft_pad_factor=10, p_rad=12, conv_func="gaussian", order=3) + return kw + + +TTYPES = ["direct", "nfft"] - obs = im.observe( - eht_array, TINT_SEC, TADV_SEC, TSTART_HR, TSTOP_HR, BW_HZ, - sgrscat=False, ampcal=True, phasecal=True, - ttype="direct", add_th_noise=False, - ) +# ============================ S1: chi^2 Stokes-I ============================== +# Per-baseline amplitude/visibility terms run on a null-free single Gaussian: logamp is +# log(|V|), singular wherever a visibility nulls (an asymmetric source interferes to zero +# at some baselines). Closure phases/amps run on an asymmetric image: a symmetric source +# has zero closure phase, making those tests vacuous. Each group uses the image that keeps +# its gradient both well-conditioned and non-trivial. +PERBASELINE_TERMS = ["vis", "amp", "logamp"] +CLOSURE_TERMS = ["bs", "cphase", "cphase_diag", "camp", "logcamp", "logcamp_diag"] +DATATERMS_SI = PERBASELINE_TERMS + CLOSURE_TERMS + + +def _si_setup(im, eht_array): + """Normalize a small (8x10) Stokes-I image, observe it once (DFT), jitter the imvec.""" + im.imvec = im.imvec * 2.0 / im.total_flux() + obs = im.observe(eht_array, TINT_SEC, TADV_SEC, TSTART_HR, TSTOP_HR, BW_HZ, + sgrscat=False, ampcal=True, phasecal=True, ttype="direct", + add_th_noise=False) prior = im.copy() - im2 = prior.copy() + rng = np.random.default_rng(SEED) + imvec = im.imvec * (1.0 + 0.05 * (rng.random(im.imvec.size) - 0.5)) + mask = np.ones(imvec.size, dtype=bool) + return {"obs": obs, "prior": prior, "imvec": imvec, "mask": mask} - rng = np.random.default_rng(RNG_SEED) - im2.imvec *= 1.0 + (rng.random(len(im2.imvec)) - 0.5) / 10.0 - im2.imvec += (1.0 + (rng.random(len(im2.imvec)) - 0.5) / 10.0) * np.mean(im2.imvec) - mask = im2.imvec > 0.5 * np.median(im2.imvec) - test_imvec = im2.imvec[mask] if np.any(~mask) else im2.imvec +@pytest.fixture(scope="module") +def si_gauss_setup(eht_array): + """Compact single Gaussian on a fine grid for the per-baseline amplitude terms. - return { - "obs": obs, - "prior": prior, - "test_imvec": test_imvec, - "mask": mask, - "im": im, - } + A compact source keeps |V| well above the noise floor at every EHT baseline, so + log|V| (logamp) is well-conditioned everywhere; an extended source resolves out and + its long-baseline amplitudes underflow. + """ + im = eh.image.make_empty(10, 80 * eh.RADPERUAS, 17.761, -29.0, rf=230e9) + im = im.add_gauss(2.0, (25 * eh.RADPERUAS, 25 * eh.RADPERUAS, 0, 0, 0)) + return _si_setup(im, eht_array) -class TestChisqGradientFiniteDiff: - """Analytic chi-squared gradients match numeric finite differences.""" +@pytest.fixture(scope="module") +def si_asym_setup(eht_array, make_asym_image): + """Asymmetric image (rect 8x10) -> nonzero closure phases, for the closure terms.""" + return _si_setup(make_asym_image(8, 10), eht_array) - @pytest.mark.parametrize("dtype", DATATERMS) - @pytest.mark.parametrize("ttype", TTYPES) - def test_median_frac_diff(self, grad_setup, dtype, ttype): - median_frac, _ = _chisq_gradient_check(grad_setup, dtype, ttype) - assert median_frac < CHISQ_GRAD_MEDIAN_TOL, ( - f"{dtype} ({ttype}) median fractional gradient diff = {median_frac:.6f}" - ) - @pytest.mark.parametrize("dtype", DATATERMS) +class TestChisqGradientStokesI: + """Analytic Stokes-I chi^2 gradients match central finite differences.""" + @pytest.mark.parametrize("ttype", TTYPES) - def test_max_frac_diff(self, grad_setup, dtype, ttype): - _, max_frac = _chisq_gradient_check(grad_setup, dtype, ttype) - assert max_frac < CHISQ_GRAD_MAX_TOL, ( - f"{dtype} ({ttype}) max fractional gradient diff = {max_frac:.6f}" - ) + @pytest.mark.parametrize("dtype", DATATERMS_SI) + def test_grad_matches_fd(self, request, dtype, ttype): + fixture = "si_gauss_setup" if dtype in PERBASELINE_TERMS else "si_asym_setup" + s = request.getfixturevalue(fixture) + obs, prior, mask, imvec = (s[k] for k in ("obs", "prior", "mask", "imvec")) + data, sigma, A = chisqdata(obs, prior, mask, dtype, ttype=ttype, **_chisqdata_kwargs(ttype)) + analytic = chisqgrad(imvec, A, data, sigma, dtype, ttype=ttype, mask=mask) + fd = fd_grad(lambda v: chisq(v, A, data, sigma, dtype, ttype=ttype, mask=mask), imvec) + assert_grad_close(analytic, fd, rtol=_rtol(ttype), label=f"{dtype} {ttype}") + + +def test_diag_chisq_nfft_matches_direct(si_asym_setup): + """Block-diagonal nfft diag chi^2 agrees with the direct-DFT diag chi^2. + + The nfft diagonalized closures apply the per-block decorrelating transforms as + one block-diagonal matmul; the direct terms loop. Both share the same transforms + and measured closures, so their chi^2 must agree -- a self-consistent-but-wrong + restructure error that finite differences alone would not catch. + """ + obs, prior, mask, imvec = (si_asym_setup[k] for k in ("obs", "prior", "mask", "imvec")) + for dtype in ("cphase_diag", "logcamp_diag"): + cdir = chisqdata(obs, prior, mask, dtype, ttype="direct", **_chisqdata_kwargs("direct")) + cnf = chisqdata(obs, prior, mask, dtype, ttype="nfft", **_chisqdata_kwargs("nfft")) + chi_dir = chisq(imvec, cdir[2], cdir[0], cdir[1], dtype, ttype="direct", mask=mask) + chi_nf = chisq(imvec, cnf[2], cnf[0], cnf[1], dtype, ttype="nfft", mask=mask) + assert abs(chi_dir - chi_nf) <= 1e-2 * abs(chi_dir), f"{dtype}: {chi_dir:.6g} vs {chi_nf:.6g}" -class TestChisqGradientFiniteDiffDiag: - """Diagonalized-closure gradients match finite differences. +# ============================== S2: chi^2 pol ================================ +POL_DATATERMS = ["pvis", "m", "vvis"] - Pins the vectorized per-time-block matvec in chisqgrad_{cphase,logcamp}_diag - against numeric finite differences, at the same tolerance as the standard - closures. - """ - @pytest.mark.parametrize("dtype", DATATERMS_DIAG) - @pytest.mark.parametrize("ttype", TTYPES_DIAG) - def test_median_frac_diff(self, grad_setup, dtype, ttype): - median_frac, _ = _chisq_gradient_check(grad_setup, dtype, ttype) - assert median_frac < CHISQ_GRAD_MEDIAN_TOL, ( - f"{dtype} ({ttype}) median fractional gradient diff = {median_frac:.6f}" - ) - - @pytest.mark.parametrize("dtype", DATATERMS_DIAG) - @pytest.mark.parametrize("ttype", TTYPES_DIAG) - def test_max_frac_diff(self, grad_setup, dtype, ttype): - _, max_frac = _chisq_gradient_check(grad_setup, dtype, ttype) - assert max_frac < CHISQ_GRAD_MAX_TOL, ( - f"{dtype} ({ttype}) max fractional gradient diff = {max_frac:.6f}" - ) - - -def test_diag_chisq_nfft_matches_direct(grad_setup): - """Block-diagonal nfft diag chisq agrees with the direct-DFT diag chisq. - - The nfft diagonalized-closure terms apply the per-block decorrelating - transforms as one block-diagonal matmul; the direct terms loop. Both share - the same transforms and measured closures and differ only by Fourier - accuracy, so their chi^2 must agree closely. Guards the block-diagonal - restructure against a self-consistent-but-wrong error (which finite - differences alone would not catch). - """ - obs, prior, mask = grad_setup["obs"], grad_setup["prior"], grad_setup["mask"] - iv = grad_setup["test_imvec"] - for dtype in DATATERMS_DIAG: - cdir = chisqdata(obs, prior, mask, dtype, ttype="direct") - cnf = chisqdata(obs, prior, mask, dtype, ttype="nfft") - chi_dir = chisq(iv, cdir[2], cdir[0], cdir[1], dtype, ttype="direct", mask=mask) - chi_nf = chisq(iv, cnf[2], cnf[0], cnf[1], dtype, ttype="nfft", mask=mask) - assert abs(chi_dir - chi_nf) <= 1e-2 * abs(chi_dir), ( - f"{dtype}: direct={chi_dir:.6g} vs nfft={chi_nf:.6g}" - ) +def _pol_config(ttype): + return ImagerConfig(pol="IP", transforms=[], ttype=ttype, mf=False, + mf_config=MfConfig(mf_order=0, mf_order_pol=0, mf_rm=0, mf_cm=0)) -class TestRegularizerGradientFiniteDiff: - """Analytic regularizer gradients match numeric finite differences.""" +@pytest.fixture(scope="module") +def pol_setup(eht_array, make_asym_image): + """Small (8x10) asymmetric polarized image + a jittered physical imcur [I,rho,phi,psi]. - @pytest.mark.parametrize("rtype", REGULARIZERS) - def test_median_frac_diff(self, grad_setup, rtype): - median_frac, _ = _reg_gradient_check(grad_setup, rtype) - assert median_frac < REG_GRAD_MEDIAN_TOL, ( - f"{rtype} median fractional gradient diff = {median_frac:.6f}" - ) + add_random_pol gives a spatially-varying EVPA and (cmag>0) circular fraction, so + chi, vfrac, rho and psi all vary; the imcur is clipped to rho in (0,1) and psi away + from 0 so v != 0 and m != 0 at every pixel. + """ + im = make_asym_image(8, 10) + im.imvec = im.imvec * 2.0 / im.total_flux() + im = im.add_random_pol(0.25, 40 * eh.RADPERUAS, cmag=0.06, ccorr=40 * eh.RADPERUAS, seed=7) + prior = im.copy() + obs = im.observe(eht_array, TINT_SEC, TADV_SEC, TSTART_HR, TSTOP_HR, BW_HZ, + ampcal=True, phasecal=True, ttype="direct", add_th_noise=False) + mask = np.ones(im.imvec.size, dtype=bool) + rng = np.random.default_rng(SEED) + I = im.imvec + Q, U, V = im.qvec, im.uvec, im.vvec + P = np.sqrt(Q**2 + U**2 + V**2) + n = I.size + imcur = np.array([ + I * (1.0 + 0.05 * (rng.random(n) - 0.5)), + np.clip((P / I) * (1.0 + 0.1 * (rng.random(n) - 0.5)), 0.02, 0.95), + np.arctan2(U, Q) + 0.1 * (rng.random(n) - 0.5), + np.clip(np.abs(np.arcsin(V / (P + 1e-30))) * (1.0 + 0.1 * (rng.random(n) - 0.5)), 0.02, 1.5), + ]) + return {"obs": obs, "prior": prior, "mask": mask, "imcur": imcur} + + +class TestChisqGradientPol: + """Analytic pol chi^2 gradients match central finite differences in all four slots. + + vvis is independent of the EVPA, so its slot-2 gradient is identically zero and + assert_grad_close checks that the analytic slot is ~0 too. + """ + + @pytest.mark.parametrize("ttype", TTYPES) + @pytest.mark.parametrize("dtype", POL_DATATERMS) + def test_grad_matches_fd(self, pol_setup, dtype, ttype): + obs, prior, mask, imcur = (pol_setup[k] for k in ("obs", "prior", "mask", "imcur")) + data, sigma, A = compute_chisqdata_term(obs, prior, mask, dtype, _pol_config(ttype)) + analytic = compute_chisqgrad_term(imcur, dtype, A, data, sigma, ttype=ttype, mask=mask, + pol_solve=np.array([1, 1, 1, 1])) + fd = fd_grad(lambda im: compute_chisq_term(im, dtype, A, data, sigma, ttype=ttype, mask=mask), imcur) + assert_nonvacuous(fd, label=f"{dtype} {ttype}") + for s in range(4): + assert_grad_close(analytic[s], fd[s], rtol=_rtol(ttype), + label=f"{dtype} {ttype} slot{s}", allow_zero=True) + + +# ============================ regularizers (S3/S4/S5) ========================= +# Full-grid finite differences on a small image so the zero-pad boundary (row 0 / col 0), +# where the TV neighbour-roll gradients live, is always covered. epsilon_tv rounds the |.| +# kink where neighbouring pixels coincide (a smooth source's extrema); it is negligible at +# every other pixel and does not change the gradient formula, so a dropped/factor/boundary +# term still gives an O(1) finite-difference mismatch. +REG_BEAM = 20 * eh.RADPERUAS +REG_EPS = 1e-8 +REG_XDIM, REG_YDIM = 6, 8 +REG_N = REG_XDIM * REG_YDIM + + +# ------------------------------- S3: reg Stokes-I ---------------------------- +@pytest.fixture(scope="module") +def reg_si_setup(make_asym_image): + """Small (6x8) asymmetric Stokes-I image + parameters every regularizer can draw from.""" + im = make_asym_image(REG_XDIM, REG_YDIM) + im.imvec = im.imvec * 2.0 / im.total_flux() + imvec = im.imvec + mask = np.ones(imvec.size, dtype=bool) + nprior = np.full(imvec.size, imvec.mean()) # uniform prior != imvec + kw = dict(nprior=nprior, flux=0.5 * imvec.sum(), # flux != sum so reg_flux gradient != 0 + xdim=im.xdim, ydim=im.ydim, psize=im.psize, + beam_size=REG_BEAM, alpha_A=5000.0, epsilon_tv=REG_EPS, + major=50 * eh.RADPERUAS, minor=60 * eh.RADPERUAS, PA=np.pi / 3, norm_reg=True) + return imvec, mask, kw + + +class TestRegularizerGradientStokesI: + """Analytic Stokes-I regularizer gradients match central finite differences.""" @pytest.mark.parametrize("rtype", REGULARIZERS) - def test_max_frac_diff(self, grad_setup, rtype): - _, max_frac = _reg_gradient_check(grad_setup, rtype) - assert max_frac < REG_GRAD_MAX_TOL, ( - f"{rtype} max fractional gradient diff = {max_frac:.6f}" - ) - - -def _chisq_gradient_check(grad_setup, dtype, ttype): - """Compare analytic vs numeric chi-squared gradient on subsampled pixels.""" - obs = grad_setup["obs"] - prior = grad_setup["prior"] - mask = grad_setup["mask"] - test_imvec = grad_setup["test_imvec"] - - cdata = chisqdata(obs, prior, mask, dtype, ttype=ttype) - grad_exact = chisqgrad(test_imvec, cdata[2], cdata[0], cdata[1], dtype, ttype=ttype, mask=mask) - y0 = chisq(test_imvec, cdata[2], cdata[0], cdata[1], dtype, ttype=ttype, mask=mask) - - rng = np.random.default_rng(RNG_SEED) - sample_idx = rng.choice(len(test_imvec), size=N_GRAD_SAMPLES, replace=False) - - grad_numeric = np.zeros(N_GRAD_SAMPLES) - for i, j in enumerate(sample_idx): - dx = max(GRAD_DX_REL * abs(test_imvec[j]), GRAD_DX_FLOOR) - imvec2 = test_imvec.copy() - imvec2[j] += dx - y1 = chisq(imvec2, cdata[2], cdata[0], cdata[1], dtype, ttype=ttype, mask=mask) - grad_numeric[i] = (y1 - y0) / dx - - grad_sampled = grad_exact[sample_idx] - compare_floor = np.min(np.abs(grad_sampled)) * 1e-20 + 1e-100 - frac_diff = np.abs((grad_numeric - grad_sampled) / (np.abs(grad_sampled) + compare_floor)) - return np.median(frac_diff), np.max(frac_diff) - - -def _reg_gradient_check(grad_setup, rtype): - """Compare analytic vs numeric regularizer gradient on subsampled pixels.""" - test_imvec = grad_setup["test_imvec"] - im = grad_setup["im"] - - nprior = np.ones_like(test_imvec) - nprior = nprior * np.sum(test_imvec) / np.sum(nprior) - mask = grad_setup["mask"] - flux = np.sum(test_imvec) - - kwargs = dict( - beam_size=BEAM_SIZE, alpha_A=ALPHA_A, epsilon_tv=EPSILON_TV, norm_reg=True, - ) - - y0 = iu.regularizer(test_imvec, nprior, mask, flux, im.xdim, im.ydim, im.psize, rtype, **kwargs) - grad_exact = iu.regularizergrad(test_imvec, nprior, mask, flux, im.xdim, im.ydim, im.psize, rtype, **kwargs) - - rng = np.random.default_rng(RNG_SEED) - sample_idx = rng.choice(len(test_imvec), size=N_GRAD_SAMPLES, replace=False) - - grad_numeric = np.zeros(N_GRAD_SAMPLES) - for i, j in enumerate(sample_idx): - dx = max(GRAD_DX_REL * abs(test_imvec[j]), GRAD_DX_FLOOR) - imvec2 = test_imvec.copy() - imvec2[j] += dx - y1 = iu.regularizer(imvec2, nprior, mask, flux, im.xdim, im.ydim, im.psize, rtype, **kwargs) - grad_numeric[i] = (y1 - y0) / dx - - grad_sampled = grad_exact[sample_idx] - compare_floor = np.min(np.abs(grad_sampled)) * 1e-20 + 1e-100 - frac_diff = np.abs((grad_numeric - grad_sampled) / (np.abs(grad_sampled) + compare_floor)) - return np.median(frac_diff), np.max(frac_diff) + def test_grad_matches_fd(self, reg_si_setup, rtype): + imvec, mask, kw = reg_si_setup + analytic = compute_regularizergrad_term(imvec, rtype, mask, **kw) + fd = fd_grad(lambda v: compute_regularizer_term(v, rtype, mask, **kw), imvec) + assert_nonvacuous(fd, label=rtype) + assert_grad_close(analytic, fd, label=rtype) + + +# --------------------------------- S4: reg pol ------------------------------- +@pytest.fixture(scope="module") +def reg_pol_setup(): + """Physical imarr [I, rho, phi, psi] with v != 0 AND m != 0 at every pixel. + + pol_solve=(1,1,1,1) ungates every slot so the finite differences check the full + gradient formula, not just the slots the solver happens to optimize. + """ + rng = np.random.default_rng(SEED) + imarr = np.stack([ + 0.5 + rng.random(REG_N), # I > 0 + 0.2 + 0.6 * rng.random(REG_N), # rho (total pol frac) in (0.2, 0.8) + 2 * np.pi * rng.random(REG_N), # phi = 2*chi + 0.3 + 0.5 * rng.random(REG_N), # psi in (0.3, 0.8) -> v != 0 and m != 0 + ]) + mask = np.ones(REG_N, dtype=bool) + kw = dict(flux=1.0, pflux=0.3, vflux=0.1, xdim=REG_XDIM, ydim=REG_YDIM, psize=1.0, + beam_size=2.0, epsilon_tv=REG_EPS, norm_reg=True, pol_solve=(1, 1, 1, 1)) + return imarr, mask, kw + + +class TestRegularizerGradientPol: + """Analytic pol regularizer gradients match central finite differences in all four slots.""" + + @pytest.mark.parametrize("rtype", REGULARIZERS_POL) + def test_grad_matches_fd(self, reg_pol_setup, rtype): + imarr, mask, kw = reg_pol_setup + analytic = compute_regularizergrad_term(imarr, rtype, mask, **kw) + fd = fd_grad(lambda im: compute_regularizer_term(im, rtype, mask, **kw), imarr) + assert_nonvacuous(fd, label=rtype) + for s in range(4): + assert_grad_close(analytic[s], fd[s], label=f"{rtype} slot{s}", allow_zero=True) + + +# ------------------------------ S5: reg spectral ----------------------------- +@pytest.fixture(scope="module") +def reg_spectral_setup(make_asym_image): + """Small (6x8) spectral-coefficient map (e.g. alpha) + a half-amplitude prior.""" + im = make_asym_image(REG_XDIM, REG_YDIM) + imvec = im.imvec * 2.0 / im.total_flux() + mask = np.ones(imvec.size, dtype=bool) + kw = dict(nprior=imvec * 0.5, xdim=im.xdim, ydim=im.ydim, psize=im.psize, + beam_size=REG_BEAM, epsilon_tv=REG_EPS, norm_reg=True) + return imvec, mask, kw + + +class TestRegularizerGradientSpectral: + """Analytic spectral-index regularizer gradients match central finite differences.""" + + @pytest.mark.parametrize("rtype", REGULARIZERS_SPECTRAL) + def test_grad_matches_fd(self, reg_spectral_setup, rtype): + imvec, mask, kw = reg_spectral_setup + analytic = compute_regularizergrad_term(imvec, rtype, mask, **kw) + fd = fd_grad(lambda v: compute_regularizer_term(v, rtype, mask, **kw), imvec) + assert_nonvacuous(fd, label=rtype) + assert_grad_close(analytic, fd, label=rtype) + + +# ============================== S6: transforms =============================== +# Each change-of-variables maps a solver image to a physical one. We finite-difference an +# arbitrary DENSE linear objective sum(grad_phys * fwd(solver)) w.r.t. the solver slots and +# compare to the transform's Jacobian-vector product (its grad fn returns slots 1,2,3). A +# dense grad_phys exercises every output slot -- a real objective like chi^2_p is blind to +# the slot a transform holds constant, which is how the mcv/vcv cross-term bugs once hid. +TRANSFORMS = { + # name: (forward, grad, reverse, solved slots, held-constant slot) + "mcv": (mcv, mcv_grad, mcv_r, [1, 2], 3), # solve m', phi; hold v + "vcv": (vcv, vcv_grad, vcv_r, [2, 3], 1), # solve phi, v'; hold m + "polcv": (polcv, polcv_grad, polcv_r, [1, 2, 3], None), +} + + +def _phys_imarr(n, seed=SEED): + """Physical imarr [I, rho, phi, psi] with v != 0 AND m != 0 at every pixel.""" + rng = np.random.default_rng(seed) + return np.stack([ + 0.5 + rng.random(n), + 0.2 + 0.6 * rng.random(n), + 2 * np.pi * rng.random(n), + 0.3 + 0.5 * rng.random(n), + ]) + + +class TestTransformGradient: + """Analytic change-of-variables Jacobian-vector products match finite differences. + + The forward transforms only touch slots 1,2,3 (slot 0 is the log-Stokes-I path), and the + grad functions return those three slots; the held-constant slot must come back ~0. + """ + + @pytest.mark.parametrize("name", list(TRANSFORMS)) + def test_grad_matches_fd(self, name): + fwd, gradfn, rev, free, fixed = TRANSFORMS[name] + solver = rev(_phys_imarr(REG_N)) + grad_phys = np.random.default_rng(SEED + 5).standard_normal((4, REG_N)) + analytic = gradfn(solver, grad_phys) # J^T @ grad_phys, slots 1..3 + fd = fd_grad(lambda s: float(np.sum(grad_phys * fwd(s))), solver) + assert_nonvacuous(fd, label=name) + for slot in free: + assert_grad_close(analytic[slot - 1], fd[slot], label=f"{name} slot{slot}") + if fixed is not None: + held = float(np.max(np.abs(analytic[fixed - 1]))) + assert held < ABS_FLOOR, f"{name}: held slot {fixed} gradient not ~0 (got {held:.2e})" diff --git a/tests/test_imager_backend.py b/tests/test_imager_backend.py index 194fc85b..fb5431d5 100644 --- a/tests/test_imager_backend.py +++ b/tests/test_imager_backend.py @@ -3641,7 +3641,14 @@ def test_pol_mode_short_which_solve_passthrough(self): class TestComputeRegularizerTerm: - """Parity between compute_regularizer_term and the legacy outer dispatchers.""" + """Routing parity between compute_regularizer_term and the legacy outer dispatchers. + + The legacy dispatchers are thin shims that forward to compute_regularizer{,grad}_term, + so these assert the two entry points reach the SAME leaf (a routing/forwarding check), + NOT that the leaf gradient is mathematically correct. The analytic-vs-finite-difference + gradient correctness for every regularizer lives in test_gradients.py (the canonical FD + suite); a leaf bug would pass here (both sides share it) but fail there. + """ @pytest.fixture(scope="class") def stokes_setup(self, make_rect_image): diff --git a/tests/test_objective_jax.py b/tests/test_objective_jax.py index a6b23467..68f405dd 100644 --- a/tests/test_objective_jax.py +++ b/tests/test_objective_jax.py @@ -22,7 +22,6 @@ import pytest import ehtim as eh -import ehtim.imaging.imager_utils as iu from ehtim.imaging.imager_backend import make_objective_jax from ehtim.observing.obs_helpers import NFFTInfo, ftmatrix, nufft2_backend @@ -182,32 +181,6 @@ def counted(x): assert traces["n"] == 1 -# ============================== embed / partial mask ============================== -def test_embed_functional_jax(): - # embed's functional scatter is byte-identical on numpy and differentiable on jax - mask = np.array([True, False, True, False, True, True]) - imvec = np.array([1.0, 2.0, 3.0, 4.0]) - assert np.array_equal(iu.embed(imvec, mask), np.asarray(iu.embed(jnp.asarray(imvec), mask))) - g = jax.grad(lambda v: jnp.sum(iu.embed(v, mask) ** 2))(jnp.asarray(imvec)) - assert np.allclose(np.asarray(g), 2 * imvec) # grad routes only to on-mask pixels - - -def test_spatial_reg_partial_mask_parity(): - # reg_tv with a partial mask exercises the embed scatter under jax; clipfloor=0 - # makes the off-mask fill deterministic (no seed needed for numpy<->jax parity). - rng = np.random.default_rng(0) - ny, nx = 6, 6 - full = rng.uniform(0.1, 1.0, ny * nx) - mask = np.ones(ny * nx, dtype=bool) - mask[rng.choice(ny * nx, 8, replace=False)] = False - imvec = full[mask] - kw = dict(xdim=nx, ydim=ny, psize=1.0, flux=float(full.sum()), beam_size=2.0, norm_reg=True) - assert np.allclose(iu.reg_tv(imvec, mask, **kw), - float(iu.reg_tv(jnp.asarray(imvec), mask, **kw)), rtol=1e-12) - g_jax = np.asarray(jax.grad(lambda v: iu.reg_tv(v, mask, **kw))(jnp.asarray(imvec))) - assert np.allclose(g_jax, iu.reggrad_tv(imvec, mask, **kw), rtol=1e-8, atol=1e-10) - - # ============================== nufft2_backend equality ============================== def test_nufft2_backend_numpy_jax_equality(): # nufft2_backend dispatches finufft (numpy) vs jax_finufft (jax); the two must agree, diff --git a/tests/test_regularizers.py b/tests/test_regularizers.py index f17450b6..e06ebde2 100644 --- a/tests/test_regularizers.py +++ b/tests/test_regularizers.py @@ -1,8 +1,11 @@ -"""Tests for ehtim regularizer functions. - -Verifies that all regularizer types return finite values and that -analytic gradients match numeric finite differences. All tests use a -32x48 image so xdim != ydim exercises the rectangular-image code paths. +"""Tests for ehtim regularizer values and boundary/edge-case gradients. + +Verifies that all regularizer types return finite (and, for pol/spectral, +non-zero) values, that the TV-family gradients are correct on the zero-pad +boundary, and the center-of-mass regularizer semantics. The systematic +analytic-vs-finite-difference gradient parity for every regularizer (Stokes-I, +pol, spectral) lives in test_gradients.py, the canonical FD suite. All tests use +a 32x48 image so xdim != ydim exercises the rectangular-image code paths. """ import numpy as np @@ -14,10 +17,6 @@ import ehtim.imaging.multifreq_imager_utils as mfu import ehtim.imaging.pol_imager_utils as pu -# Tolerances for gradient checks (calibrated on 32x48 synthetic Gaussian) -MEDIAN_FRAC_TOL = 0.05 -MAX_FRAC_TOL = 0.6 - # --------------------------------------------------------------------------- # Regularizer optional parameters (explicit for tracking across refactors) # --------------------------------------------------------------------------- @@ -30,14 +29,6 @@ RGAUSS_MINOR = 60.0 * eh.RADPERUAS RGAUSS_PA = np.pi / 3 -# --------------------------------------------------------------------------- -# Numeric gradient parameters -# --------------------------------------------------------------------------- -N_GRAD_SAMPLES = 100 -RNG_SEED = 4 -GRAD_DX_REL = 1e-8 # relative step size per pixel -GRAD_DX_FLOOR = 1e-12 # absolute minimum step size - @pytest.fixture(scope="module") def reg_setup(make_asym_image): @@ -72,24 +63,6 @@ def test_returns_finite(self, reg_setup, rtype, norm_reg): assert np.isfinite(val), f"{rtype} (norm_reg={norm_reg}) returned {val}" -class TestRegularizerGradients: - """Analytic regularizer gradients match numeric finite differences.""" - - @pytest.mark.parametrize("rtype", iu.REGULARIZERS) - def test_median_frac_diff(self, reg_setup, rtype): - median_frac, _ = _gradient_check(reg_setup, rtype) - assert median_frac < MEDIAN_FRAC_TOL, ( - f"{rtype} median fractional gradient diff = {median_frac:.6f}" - ) - - @pytest.mark.parametrize("rtype", iu.REGULARIZERS) - def test_max_frac_diff(self, reg_setup, rtype): - _, max_frac = _gradient_check(reg_setup, rtype) - assert max_frac < MAX_FRAC_TOL, ( - f"{rtype} max fractional gradient diff = {max_frac:.6f}" - ) - - class TestCenterOfMassRegularizer: """`reg_cm` / `reggrad_cm` semantics on a full-grid imvec. @@ -172,59 +145,8 @@ def _reg_kwargs(rtype, norm_reg=True): return kwargs -def _gradient_check(reg_setup, rtype): - """Compute median and max fractional diff between analytic and numeric gradient.""" - im, imvec, nprior, mask, flux = reg_setup - kwargs = _reg_kwargs(rtype) - - y0 = iu.regularizer(imvec, nprior, mask, flux, im.xdim, im.ydim, im.psize, rtype, **kwargs) - grad_exact = iu.regularizergrad(imvec, nprior, mask, flux, im.xdim, im.ydim, im.psize, rtype, **kwargs) - - rng = np.random.default_rng(RNG_SEED) - sample_idx = rng.choice(len(imvec), size=N_GRAD_SAMPLES, replace=False) - - grad_numeric = np.zeros(N_GRAD_SAMPLES) - for i, j in enumerate(sample_idx): - dx = max(GRAD_DX_REL * abs(imvec[j]), GRAD_DX_FLOOR) - imvec2 = imvec.copy() - imvec2[j] += dx - y1 = iu.regularizer(imvec2, nprior, mask, flux, im.xdim, im.ydim, im.psize, rtype, **kwargs) - grad_numeric[i] = (y1 - y0) / dx - - grad_exact_sampled = grad_exact[sample_idx] - compare_floor = np.min(np.abs(grad_exact_sampled)) * 1e-20 + 1e-100 - frac_diff = np.abs((grad_numeric - grad_exact_sampled) / (np.abs(grad_exact_sampled) + compare_floor)) - return np.median(frac_diff), np.max(frac_diff) - - # polregularizer / polregularizergrad operate on a (4, nimage) imarr in # solver space: imarr[0]=I, imarr[1]=rho, imarr[2]=phi=2*chi, imarr[3]=psi. -# Linear-pol regs solve slots (rho, phi); circular-pol regs solve psi only. -POL_LIN_REGS = ('msimple', 'hw', 'ptv') -POL_CIRC_REGS = ('vflux', 'l1v', 'l2v', 'vtv', 'vtv2') - -POL_MEDIAN_FRAC_TOL = 0.05 -POL_MAX_FRAC_TOL = 0.6 -# TV-style regs (ptv, vtv, vtv2) have a sqrt-of-squared-differences denominator -# that goes to ~0 on smooth regions; a few FD samples near image edges land in -# the small-denominator regime where the linear approximation breaks down. -POL_MAX_FRAC_TOL_TV = 2.0 - - -def _pol_solve_for(rtype): - # Drive every physical slot (I, rho, phi, psi), not just the mode's DOF - # slots, so the cross-coupling slots are FD-checked: reggrad_ptv psi (3), - # reggrad_vflux/l1v/l2v/vtv rho (1), and slot 0 for every pol reg. Kernels - # that do not fill a slot leave it 0, which FD of the value confirms. - return np.array([1, 1, 1, 1]) - - -def _pol_tols(rtype): - if rtype in ('ptv', 'vtv', 'vtv2'): - return POL_MEDIAN_FRAC_TOL, POL_MAX_FRAC_TOL_TV - return POL_MEDIAN_FRAC_TOL, POL_MAX_FRAC_TOL - - @pytest.fixture(scope="module") def polreg_setup(make_asym_image): """32x48 asymmetric Stokes I with jittered pol structure. @@ -286,23 +208,7 @@ def test_returns_nonzero(self, polreg_setup, rtype): class TestPolRegularizerGradients: - """Analytic gradients match numeric finite differences in pol_solve slots.""" - - @pytest.mark.parametrize("rtype", pu.REGULARIZERS_POL) - def test_median_frac_diff(self, polreg_setup, rtype): - median_tol, _ = _pol_tols(rtype) - median_frac, _ = _pol_gradient_check(polreg_setup, rtype) - assert median_frac < median_tol, ( - f"{rtype} median fractional gradient diff = {median_frac:.6f} (tol={median_tol})" - ) - - @pytest.mark.parametrize("rtype", pu.REGULARIZERS_POL) - def test_max_frac_diff(self, polreg_setup, rtype): - _, max_tol = _pol_tols(rtype) - _, max_frac = _pol_gradient_check(polreg_setup, rtype) - assert max_frac < max_tol, ( - f"{rtype} max fractional gradient diff = {max_frac:.6f} (tol={max_tol})" - ) + """TV-family regularizer gradients on the zero-pad boundary and at the |.| kink.""" # TV-family regularizers: tv/tv2 are Stokes-I (1D imvec), ptv/vtv/vtv2 are pol (4-row imarr) @pytest.mark.parametrize("rtype, is_pol", [("tv", False), ("tv2", False), @@ -368,47 +274,6 @@ def test_tv_epsilon_removes_singularity(self, rtype, is_pol): assert np.all(np.isfinite(ge)), f"{rtype}: expected finite grad at epsilon_tv>0" -def _pol_gradient_check(polreg_setup, rtype): - """FD-vs-analytic check across N pixels in each pol_solve-active slot.""" - im, imarr, priorarr, mask, flux, pflux, vflux = polreg_setup - kwargs = _polreg_kwargs() - pol_solve = _pol_solve_for(rtype) - - y0 = pu.polregularizer( - imarr, priorarr, mask, flux, pflux, vflux, - im.xdim, im.ydim, im.psize, rtype, **kwargs, - ) - grad_exact = pu.polregularizergrad( - imarr, priorarr, mask, flux, pflux, vflux, - im.xdim, im.ydim, im.psize, rtype, - pol_solve=pol_solve, **kwargs, - ) - - rng = np.random.default_rng(RNG_SEED) - nimage = imarr.shape[1] - sample_idx = rng.choice(nimage, size=N_GRAD_SAMPLES, replace=False) - - frac_diffs = [] - for slot in range(4): - if pol_solve[slot] == 0: - continue - for j in sample_idx: - dx = max(GRAD_DX_REL * abs(imarr[slot, j]), GRAD_DX_FLOOR) - imarr2 = imarr.copy() - imarr2[slot, j] += dx - y1 = pu.polregularizer( - imarr2, priorarr, mask, flux, pflux, vflux, - im.xdim, im.ydim, im.psize, rtype, **kwargs, - ) - numeric = (y1 - y0) / dx - exact = grad_exact[slot, j] - compare_floor = max(abs(exact), 1e-100) * 1e-20 + 1e-100 - frac_diffs.append(abs((numeric - exact) / (abs(exact) + compare_floor))) - - frac_diffs = np.array(frac_diffs) - return float(np.median(frac_diffs)), float(np.max(frac_diffs)) - - # reggrad_{ptv,vtv,tv} zero their back-neighbor (m2/m3, g2/g3) terms on the # first row/column, where the back-neighbor is the zero pad and does not exist. # Without it the entire first row+column of the affected slots is wrong (corner @@ -525,50 +390,3 @@ def test_returns_nonzero(self, mfreg_setup, rtype): assert val != 0, f"{rtype} returned 0 - dispatch likely missed" -class TestMFRegularizerGradients: - """Analytic regularizergrad_mf matches numeric finite differences.""" - - @pytest.mark.parametrize("rtype", ib.REGULARIZERS_SPECTRAL) - def test_median_frac_diff(self, mfreg_setup, rtype): - median_frac, _ = _mfreg_gradient_check(mfreg_setup, rtype) - assert median_frac < MEDIAN_FRAC_TOL, ( - f"{rtype} median fractional gradient diff = {median_frac:.6f}" - ) - - @pytest.mark.parametrize("rtype", ib.REGULARIZERS_SPECTRAL) - def test_max_frac_diff(self, mfreg_setup, rtype): - _, max_frac = _mfreg_gradient_check(mfreg_setup, rtype) - assert max_frac < MAX_FRAC_TOL, ( - f"{rtype} max fractional gradient diff = {max_frac:.6f}" - ) - - -def _mfreg_gradient_check(mfreg_setup, rtype): - """Compare analytic and finite-difference gradients on N random pixels.""" - im, imvec, nprior, mask = mfreg_setup - kwargs = dict(beam_size=BEAM_SIZE, norm_reg=True) - - y0 = mfu.regularizer_mf( - imvec, nprior, mask, im.xdim, im.ydim, im.psize, rtype, **kwargs, - ) - grad_exact = mfu.regularizergrad_mf( - imvec, nprior, mask, im.xdim, im.ydim, im.psize, rtype, **kwargs, - ) - - rng = np.random.default_rng(RNG_SEED) - sample_idx = rng.choice(len(imvec), size=N_GRAD_SAMPLES, replace=False) - - grad_numeric = np.zeros(N_GRAD_SAMPLES) - for i, j in enumerate(sample_idx): - dx = max(GRAD_DX_REL * abs(imvec[j]), GRAD_DX_FLOOR) - imvec2 = imvec.copy() - imvec2[j] += dx - y1 = mfu.regularizer_mf( - imvec2, nprior, mask, im.xdim, im.ydim, im.psize, rtype, **kwargs, - ) - grad_numeric[i] = (y1 - y0) / dx - - grad_exact_sampled = grad_exact[sample_idx] - compare_floor = np.min(np.abs(grad_exact_sampled)) * 1e-20 + 1e-100 - frac_diff = np.abs((grad_numeric - grad_exact_sampled) / (np.abs(grad_exact_sampled) + compare_floor)) - return float(np.median(frac_diff)), float(np.max(frac_diff)) diff --git a/tests/test_regularizers_jax.py b/tests/test_regularizers_jax.py index 8d7f7ba8..b5397fc8 100644 --- a/tests/test_regularizers_jax.py +++ b/tests/test_regularizers_jax.py @@ -90,3 +90,29 @@ def test_three_way_gradient(reg_setup, rtype): frac = np.abs((g_fd - g_analytic[idx]) / (np.abs(g_analytic[idx]) + 1e-100)) assert np.median(frac) < FD_MEDIAN_TOL assert np.max(frac) < FD_MAX_TOL + + +# ============================== embed / partial mask ============================== +def test_embed_functional_jax(): + # embed's functional scatter is byte-identical on numpy and differentiable on jax + mask = np.array([True, False, True, False, True, True]) + imvec = np.array([1.0, 2.0, 3.0, 4.0]) + assert np.array_equal(iu.embed(imvec, mask), np.asarray(iu.embed(jnp.asarray(imvec), mask))) + g = jax.grad(lambda v: jnp.sum(iu.embed(v, mask) ** 2))(jnp.asarray(imvec)) + assert np.allclose(np.asarray(g), 2 * imvec) # grad routes only to on-mask pixels + + +def test_spatial_reg_partial_mask_parity(): + # reg_tv with a partial mask exercises the embed scatter under jax; clipfloor=0 + # makes the off-mask fill deterministic (no seed needed for numpy<->jax parity). + rng = np.random.default_rng(0) + ny, nx = 6, 6 + full = rng.uniform(0.1, 1.0, ny * nx) + mask = np.ones(ny * nx, dtype=bool) + mask[rng.choice(ny * nx, 8, replace=False)] = False + imvec = full[mask] + kw = dict(xdim=nx, ydim=ny, psize=1.0, flux=float(full.sum()), beam_size=2.0, norm_reg=True) + assert np.allclose(iu.reg_tv(imvec, mask, **kw), + float(iu.reg_tv(jnp.asarray(imvec), mask, **kw)), rtol=1e-12) + g_jax = np.asarray(jax.grad(lambda v: iu.reg_tv(v, mask, **kw))(jnp.asarray(imvec))) + assert np.allclose(g_jax, iu.reggrad_tv(imvec, mask, **kw), rtol=1e-8, atol=1e-10)