Skip to content
Merged
11 changes: 6 additions & 5 deletions ehtim/imaging/multifreq_imager_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ehtim.backends import array_namespace

NORM_REGULARIZER = True
EPSILON = 1.e-12
DD_RHOPOL = 1 # transform paramter for multifrequency polarization fraction
##################################################################################################
# multifrequency transformations
Expand Down Expand Up @@ -249,12 +248,13 @@ def reg_tv_spec(imvec, mask, **kwargs):
imvec = embed(imvec, mask, clipfloor=0, randomfloor=False)
nx, ny, psize = kwargs['xdim'], kwargs['ydim'], kwargs['psize']
beam_size = kwargs.get('beam_size') or psize
epsilon = kwargs.get('epsilon_tv', 0.)
norm = len(imvec) * psize / beam_size if kwargs.get('norm_reg', True) else 1
im = imvec.reshape(ny, nx)
impad = xp.pad(im, 1, mode='constant', constant_values=0)
im_l1 = xp.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
im_l2 = xp.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
return xp.sum(xp.sqrt(xp.abs(im_l1 - im)**2 + xp.abs(im_l2 - im)**2 + EPSILON)) / norm
return xp.sum(xp.sqrt(xp.abs(im_l1 - im)**2 + xp.abs(im_l2 - im)**2 + epsilon)) / norm


def reggrad_tv_spec(imvec, mask, **kwargs):
Expand All @@ -263,6 +263,7 @@ def reggrad_tv_spec(imvec, mask, **kwargs):
imvec = embed(imvec, mask, clipfloor=0, randomfloor=False)
nx, ny, psize = kwargs['xdim'], kwargs['ydim'], kwargs['psize']
beam_size = kwargs.get('beam_size') or psize
epsilon = kwargs.get('epsilon_tv', 0.)
norm = len(imvec) * psize / beam_size if kwargs.get('norm_reg', True) else 1
im = imvec.reshape(ny, nx)
impad = np.pad(im, 1, mode='constant', constant_values=0)
Expand All @@ -272,9 +273,9 @@ def reggrad_tv_spec(imvec, mask, **kwargs):
im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1]
im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1]
g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + EPSILON)
g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + EPSILON)
g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + EPSILON)
g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + epsilon)
g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + epsilon)
g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + epsilon)
mask1 = np.zeros(im.shape)
mask2 = np.zeros(im.shape)
mask1[0, :] = 1
Expand Down
14 changes: 8 additions & 6 deletions ehtim/imaging/pol_imager_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,12 @@ def polcv_grad(imarr, gradarr):
return out


def _rho_psi_safe(xp, mfrac, vfrac):
"""rho=sqrt(mfrac^2+vfrac^2), psi=arcsin(vfrac/rho), guarded so jax.grad stays finite
at zero-polarization pixels (sqrt(0) and arcsin(+/-1) singularities). Values are
unchanged where rho>0 and |vfrac/rho|<1.
def rho_psi_from_mfrac_vfrac(xp, mfrac, vfrac):
"""Polar coordinates (rho, psi) of the (mfrac, vfrac) polarization vector.

rho = sqrt(mfrac^2 + vfrac^2), psi = arcsin(vfrac/rho), guarded so jax.grad
stays finite at zero-polarization pixels (the sqrt(0) and arcsin(+/-1)
singularities). Values are unchanged where rho > 0 and |vfrac/rho| < 1.
"""
r2 = mfrac**2 + vfrac**2
rho = xp.where(r2 > 0, xp.sqrt(xp.where(r2 > 0, r2, 1.0)), 0.0)
Expand All @@ -224,7 +226,7 @@ def mcv(imarr):
mfrac_prime = imarr[1]
mfrac = mfrac_max*(0.5 + xp.arctan(mfrac_prime/TANWIDTH_M)/np.pi)

rho, psi = _rho_psi_safe(xp, mfrac, vfrac)
rho, psi = rho_psi_from_mfrac_vfrac(xp, mfrac, vfrac)

out = xp.stack((imarr[0], rho, imarr[2], psi))
return out
Expand Down Expand Up @@ -304,7 +306,7 @@ def vcv(imarr):
vfrac_prime = imarr[3]
vfrac = 2*vfrac_max*xp.arctan(vfrac_prime/TANWIDTH_V)/np.pi

rho, psi = _rho_psi_safe(xp, mfrac, vfrac)
rho, psi = rho_psi_from_mfrac_vfrac(xp, mfrac, vfrac)

out = xp.stack((imarr[0], rho, imarr[2], psi))
return out
Expand Down
103 changes: 16 additions & 87 deletions tests/test_chisquared.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for chi-squared consistency across transform types (direct, fast, nfft).
"""Tests for chi-squared cross-transform consistency (direct, fast, nfft).

Verifies that chi-squared values and gradients agree between DFT, FFT, and NFFT
for all standard data types. All tests use a 32x48 image so xdim != ydim
exercises the rectangular-image code paths (rect subsumes square).
Verifies that chi-squared values and gradients AGREE across DFT, FFT, and NFFT for all
data types -- a cross-check, not a standalone correctness proof: the direct path is the
reference, finite-difference-validated in test_gradients.py (the canonical FD suite), so a
tight direct-vs-nfft gradient bound pins nfft gradient correctness. All tests use a 32x48
image so xdim != ydim exercises the rectangular-image code paths (rect subsumes square).
"""

import numpy as np
Expand All @@ -14,7 +16,6 @@
MfConfig,
compute_chisq_term,
compute_chisqdata_term,
compute_chisqgrad_term,
)
from ehtim.imaging.imager_utils import chisq, chisqdata, chisqgrad

Expand All @@ -33,20 +34,20 @@
# Transform type pairs to compare
TTYPE_PAIRS = [("direct", "fast"), ("direct", "nfft"), ("nfft", "fast")]

# NFFT max gradient tolerance is much wider at this resolution
GRAD_MAX_TOL_NFFT = 10.0
# Cross-transform gradient agreement. The direct DFT path is the trusted reference -- its
# gradients are finite-difference-validated in test_gradients.py -- so a tight direct-vs-nfft
# bound effectively pins nfft gradient correctness. Pairs that include the gridded 'fast' FFT
# are limited by its interpolation accuracy, not by a gradient bug, so they stay looser.
GRAD_MAX_TOL_DIRECT_NFFT = 1e-2
GRAD_MAX_TOL = 0.25 # any pair containing 'fast'

# Diagonalized closures orthogonalize per-timestamp covariance, which can
# amplify direct-vs-fast tail outliers in test_grad_max_frac_diff. Bump the
# max-fractional-diff tolerance for these dtypes only; median tolerance and
# chi-squared tolerance unchanged.
# Diagonalized closures orthogonalize per-timestamp covariance, which amplifies tail outliers.
GRAD_MAX_TOL_DIAG = 0.5
DIAG_DTYPES = {"cphase_diag", "logcamp_diag"}

# Tolerances (calibrated on 32x48 synthetic Gaussian)
CHISQ_FRAC_TOL = 0.01
GRAD_MEDIAN_TOL = 0.05
GRAD_MAX_TOL = 0.25

# ---------------------------------------------------------------------------
# chisqdata optional parameters (explicit for tracking across refactors)
Expand Down Expand Up @@ -175,12 +176,10 @@ def test_grad_median_frac_diff(self, chisq_setup, dtype, pair):
@pytest.mark.parametrize("pair", TTYPE_PAIRS, ids=lambda p: f"{p[0]}-{p[1]}")
def test_grad_max_frac_diff(self, chisq_setup, dtype, pair):
_, max_frac = _gradient_comparison(chisq_setup, dtype, pair)
if "nfft" in pair:
tol = GRAD_MAX_TOL_NFFT
elif dtype in DIAG_DTYPES:
tol = GRAD_MAX_TOL_DIAG
if pair == ("direct", "nfft"):
tol = GRAD_MAX_TOL_DIAG if dtype in DIAG_DTYPES else GRAD_MAX_TOL_DIRECT_NFFT
else:
tol = GRAD_MAX_TOL
tol = GRAD_MAX_TOL_DIAG if dtype in DIAG_DTYPES else GRAD_MAX_TOL
assert max_frac < tol, (
f"{dtype} {pair[0]}-{pair[1]}: grad max frac diff = {max_frac:.6f}"
)
Expand Down Expand Up @@ -210,14 +209,6 @@ def _gradient_comparison(chisq_setup, dtype, pair):
# Polarimetric chi-squared (pvis / m / vvis). Pol has no 'fast' ttype.
# ---------------------------------------------------------------------------
POL_DATATERMS = ["pvis", "m", "vvis"]
POL_FD_REL = 1e-6
POL_FD_FLOOR = 1e-9
# fractional FD vs analytic (same ttype, self-consistent). median is tight (any
# systematic gradient error blows it up); max is looser to tolerate 2nd-order FD
# truncation outliers at small-gradient pixels of the structured-pol imcur. Real
# pol-gradient bugs are %-level (e.g. the mcv slot-3 coupling), far above these.
POL_GRAD_FD_MEDIAN_TOL = 1e-5
POL_GRAD_FD_MAX_TOL = 1e-3
POL_CHISQ_FRAC_TOL = 0.01 # direct-vs-nfft value agreement


Expand Down Expand Up @@ -283,65 +274,3 @@ def test_chisq_values(self, chisq_setup_pol, dtype):
vals[tt] = compute_chisq_term(imcur, dtype, A, data, sigma, ttype=tt, mask=mask)
frac = abs((vals["direct"] - vals["nfft"]) / abs(vals["direct"]))
assert frac < POL_CHISQ_FRAC_TOL, f"{dtype}: chisq frac diff = {frac:.6f}"


class TestPolChisqGradConsistency:
"""Pol chi-squared gradients agree between direct and nfft."""

@pytest.mark.parametrize("dtype", POL_DATATERMS)
def test_grad_values(self, chisq_setup_pol, dtype):
obs, prior = chisq_setup_pol["obs"], chisq_setup_pol["prior"]
mask, imcur = chisq_setup_pol["mask"], chisq_setup_pol["imcur"]
grads = {}
for tt in ("direct", "nfft"):
A, data, sigma = _pol_data_tuple(obs, prior, mask, dtype, tt)
grads[tt] = compute_chisqgrad_term(
imcur, dtype, A, data, sigma, ttype=tt, mask=mask,
pol_solve=np.array([1, 1, 1, 1]))
a, b = grads["direct"], grads["nfft"]
floor = np.min(np.abs(a)) * 1e-20 + 1e-100
frac = np.abs((a - b) / (np.abs(a) + floor))
assert np.median(frac) < GRAD_MEDIAN_TOL
assert np.max(frac) < GRAD_MAX_TOL_NFFT


class TestPolChisqGradFD:
"""Pol chisqgrad matches finite differences of the chisq value, in all four
physical slots (driven with pol_solve=[1,1,1,1]). vvis slot 2 (EVPA) is
asserted identically zero -- Stokes V is independent of EVPA -- and its FD
is also zero, so the all-slot loop covers it consistently."""

@pytest.mark.parametrize("dtype", POL_DATATERMS)
@pytest.mark.parametrize("ttype", ["direct", "nfft"])
def test_grad_matches_fd(self, chisq_setup_pol, dtype, ttype):
obs, prior = chisq_setup_pol["obs"], chisq_setup_pol["prior"]
mask, imcur = chisq_setup_pol["mask"], chisq_setup_pol["imcur"]
A, data, sigma = _pol_data_tuple(obs, prior, mask, dtype, ttype)
grad = compute_chisqgrad_term(
imcur, dtype, A, data, sigma, ttype=ttype, mask=mask,
pol_solve=np.array([1, 1, 1, 1]))

if dtype == "vvis": # V is independent of EVPA
np.testing.assert_array_equal(grad[2], 0.0)

rng = np.random.default_rng(RNG_SEED)
n = imcur.shape[1]
sample = rng.choice(n, size=min(30, n), replace=False)
frac = []
for slot in range(4):
for j in sample:
dx = max(POL_FD_REL * abs(imcur[slot, j]), POL_FD_FLOOR)
ip = imcur.copy()
ip[slot, j] += dx
im_ = imcur.copy()
im_[slot, j] -= dx
fd = (compute_chisq_term(ip, dtype, A, data, sigma, ttype=ttype, mask=mask)
- compute_chisq_term(im_, dtype, A, data, sigma, ttype=ttype, mask=mask)) / (2 * dx)
ex = grad[slot, j]
denom = max(abs(ex), abs(fd), POL_FD_FLOOR)
frac.append(abs(ex - fd) / denom)
frac = np.array(frac)
assert np.median(frac) < POL_GRAD_FD_MEDIAN_TOL, (
f"{dtype} {ttype}: median frac diff = {np.median(frac):.4g}")
assert np.max(frac) < POL_GRAD_FD_MAX_TOL, (
f"{dtype} {ttype}: max frac diff = {np.max(frac):.4g}")
Loading
Loading