From c0baf50711db2dfe1777ffa5e5e53f2137b64eb5 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 10:18:13 -0400 Subject: [PATCH 1/7] create batch lm object --- astrophot/fit/batch_lm.py | 36 ++++++++++++++++++++++++++++++++++++ astrophot/fit/lm.py | 2 -- 2 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 astrophot/fit/batch_lm.py diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py new file mode 100644 index 00000000..44da54f8 --- /dev/null +++ b/astrophot/fit/batch_lm.py @@ -0,0 +1,36 @@ +from ..models import Model +from ..image import TargetImageBatch +from .base import BaseOptimizer + + +class BatchLM(BaseOptimizer): + + def __init__( + self, + model: Model, + batch_target: TargetImageBatch, + max_iter: int = 100, + relative_tolerance: float = 1e-5, + Lup=11.0, + Ldn=9.0, + L0=1.0, + max_step_iter: int = 10, + ndf=None, + likelihood="gaussian", + constraint: Optional[LMConstraint] = None, + forward=None, + jacobian=None, + **kwargs, + ): + + super().__init__( + model=model, + initial_state=model.get_values(), + max_iter=max_iter, + relative_tolerance=relative_tolerance, + **kwargs, + ) + + self.Lup = Lup + self.Ldn = Ldn + self.L = L0 diff --git a/astrophot/fit/lm.py b/astrophot/fit/lm.py index adca7c97..fd3b5b8e 100644 --- a/astrophot/fit/lm.py +++ b/astrophot/fit/lm.py @@ -159,8 +159,6 @@ def __init__( relative_tolerance=relative_tolerance, **kwargs, ) - # Maximum number of iterations of the algorithm - self.max_iter = max_iter # Maximum number of steps while searching for chi^2 improvement on a single jacobian evaluation self.max_step_iter = max_step_iter self.Lup = Lup From a7aacc179c146bb6c4f6761afc9dcf2c0e46249a Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 10:28:54 -0400 Subject: [PATCH 2/7] likelihood setting --- astrophot/fit/batch_lm.py | 6 ++++++ astrophot/fit/lm.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py index 44da54f8..f27ab85d 100644 --- a/astrophot/fit/batch_lm.py +++ b/astrophot/fit/batch_lm.py @@ -34,3 +34,9 @@ def __init__( self.Lup = Lup self.Ldn = Ldn self.L = L0 + + self.likelihood = likelihood + if self.likelihood not in ["gaussian", "poisson"]: + raise ValueError( + f"Unsupported likelihood: {self.likelihood}, should be one of: 'gaussian' or 'poisson'" + ) diff --git a/astrophot/fit/lm.py b/astrophot/fit/lm.py index fd3b5b8e..b8a2521d 100644 --- a/astrophot/fit/lm.py +++ b/astrophot/fit/lm.py @@ -166,7 +166,9 @@ def __init__( self.L = L0 self.likelihood = likelihood if self.likelihood not in ["gaussian", "poisson"]: - raise ValueError(f"Unsupported likelihood: {self.likelihood}") + raise ValueError( + f"Unsupported likelihood: {self.likelihood}, should be one of: 'gaussian' or 'poisson'" + ) self.constraint = constraint # mask From f49cad3704d8a85e3337b18df4f969fe3d48de7a Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 11:18:06 -0400 Subject: [PATCH 3/7] building setup for batch LM --- astrophot/fit/batch_lm.py | 77 ++++++++++++++++++++++++-- astrophot/models/batch_model_object.py | 6 +- 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py index f27ab85d..3343ae37 100644 --- a/astrophot/fit/batch_lm.py +++ b/astrophot/fit/batch_lm.py @@ -1,6 +1,9 @@ from ..models import Model -from ..image import TargetImageBatch +from ..image import TargetImageBatch, WindowBatch from .base import BaseOptimizer +from ..backend_obj import backend +from .. import config +from ..errors import OptimizeStopFail, OptimizeStopSuccess class BatchLM(BaseOptimizer): @@ -9,17 +12,15 @@ def __init__( self, model: Model, batch_target: TargetImageBatch, + batch_window: WindowBatch, max_iter: int = 100, relative_tolerance: float = 1e-5, Lup=11.0, Ldn=9.0, L0=1.0, max_step_iter: int = 10, - ndf=None, likelihood="gaussian", constraint: Optional[LMConstraint] = None, - forward=None, - jacobian=None, **kwargs, ): @@ -40,3 +41,71 @@ def __init__( raise ValueError( f"Unsupported likelihood: {self.likelihood}, should be one of: 'gaussian' or 'poisson'" ) + + # mask + mask = backend.flatten(batch_target[batch_window].mask, 1, -1) + self.mask = ~mask + if backend.sum(self.mask).item() == 0: + raise OptimizeStopSuccess("No data to fit. All pixels are masked") + + # data + self.data = backend.flatten(batch_target[batch_window].data, 1, -1) + + # Weight + self.weight = backend.flatten(batch_target[batch_window].weight, 1, -1) + + # WCS + crtan = batch_target.crtan + shift = backend.as_array( + batch_window.origin_shifter(self.model.window), dtype=config.DTYPE, device=config.DEVICE + ) + crpix = batch_target[batch_window].crpix + shift + CD = batch_target.CD + psf = batch_target.psf_stack + psf_batch = None if psf is None else 0 + + # Forward + vmodel = backend.vmap( + lambda cd, crt, crp, psf, params: backend.flatten( + self.model(cd, crt, crp, psf, params=params).data + ), + in_dims=(0, 0, 0, psf_batch, 0), + ) + self.forward = lambda x: vmodel(CD, crtan, crpix, psf, x) + + # Jacobian + vjac = backend.vmap( + backend.jacfwd( + lambda cd, crt, crp, psf, params: backend.flatten( + self.model(cd, crt, crp, psf, params=params).data + ), + argnums=4, + ), + in_dims=(0, 0, 0, psf_batch, 0), + ) + self.jacobian = lambda x: vjac(CD, crtan, crpix, psf, x) + + # ndf + self.ndf = backend.sum(self.mask, axis=1) - self.current_state.shape[1] + + def chi2_ndf(self): + return ( + backend.sum( + self.weight * self.mask * (self.data - self.forward(self.current_state)) ** 2, + axis=1, + ) + / self.ndf + ) + + def poisson_2nll_ndf(self): + M = self.forward(self.current_state) + return ( + 2 * backend.sum((M - self.data * backend.log(M + 1e-10)) * self.mask, axis=1) / self.ndf + ) + + def fit(self, update_uncertainty=True): + if self.current_state.shape[1] == 0: + if self.verbose > 0: + config.logger.warning("No parameters to optimize. Exiting fit") + self.message = "No parameters to optimize. Exiting fit" + return self diff --git a/astrophot/models/batch_model_object.py b/astrophot/models/batch_model_object.py index 273aa5a1..acec95f1 100644 --- a/astrophot/models/batch_model_object.py +++ b/astrophot/models/batch_model_object.py @@ -168,12 +168,12 @@ def window(self, window): @forward def __call__(self, model_params=None, model_dims=None, **kwargs): working_image = self.target.model_image(self.window) - crtan = self.target.crtan + crtan = working_image.crtan shift = backend.as_array( self.window.origin_shifter(self.model.window), dtype=config.DTYPE, device=config.DEVICE ) - crpix = self.target.crpix + shift - CD = self.target.CD + crpix = working_image.crpix + shift + CD = working_image.CD psf = self.target.psf_stack psf_batch = None if psf is None else 0 working_image._data = backend.vmap( From a9c9eae003ca7759f7667200be9e2eac0bff8dab Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 11:48:58 -0400 Subject: [PATCH 4/7] building out fit method --- astrophot/fit/batch_lm.py | 77 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py index 3343ae37..67af6359 100644 --- a/astrophot/fit/batch_lm.py +++ b/astrophot/fit/batch_lm.py @@ -4,6 +4,8 @@ from ..backend_obj import backend from .. import config from ..errors import OptimizeStopFail, OptimizeStopSuccess +from ..param import ValidContext +from . import func class BatchLM(BaseOptimizer): @@ -20,7 +22,6 @@ def __init__( L0=1.0, max_step_iter: int = 10, likelihood="gaussian", - constraint: Optional[LMConstraint] = None, **kwargs, ): @@ -32,10 +33,7 @@ def __init__( **kwargs, ) - self.Lup = Lup - self.Ldn = Ldn - self.L = L0 - + # Likelihood self.likelihood = likelihood if self.likelihood not in ["gaussian", "poisson"]: raise ValueError( @@ -88,6 +86,13 @@ def __init__( # ndf self.ndf = backend.sum(self.mask, axis=1) - self.current_state.shape[1] + # LM parmeters + self.Lup = Lup + self.Ldn = Ldn + self.L = L0 * backend.ones( + self.current_state.shape[0], dtype=config.DTYPE, device=config.DEVICE + ) + def chi2_ndf(self): return ( backend.sum( @@ -109,3 +114,65 @@ def fit(self, update_uncertainty=True): config.logger.warning("No parameters to optimize. Exiting fit") self.message = "No parameters to optimize. Exiting fit" return self + + if self.likelihood == "gaussian": + quantity = "Chi^2/DoF" + self.loss_history = [backend.to_numpy(self.chi2_ndf())] + elif self.likelihood == "poisson": + quantity = "2NLL/DoF" + self.loss_history = [backend.to_numpy(self.poisson_2nll_ndf())] + self._covariance_matrix = None + self.L_history = [backend.to_numpy(self.L)] + self.lambda_history = [backend.to_numpy(backend.copy(self.current_state))] + if self.verbose > 0: + config.logger.info( + f"==Starting LM fit for '{self.model.name}' with batch of {self.current_state.shape[0]} images with {self.current_state.shape[1]} dynamic parameters and {self.data.shape[1]} pixels==" + ) + + for _ in range(self.max_iter): + if self.verbose > 0: + config.logger.info(f"{quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}") + + try: + if self.fit_valid: + with ValidContext(self.model): + res = func.batch_lm_step( + x=self.model.to_valid(self.current_state), + data=self.data, + model=self.forward, + weight=self.weight, + mask=self.mask, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + ) + self.current_state = self.model.from_valid(backend.copy(res["x"])) + else: + res = func.batch_lm_step( + x=self.current_state, + data=self.data, + model=self.forward, + weight=self.weight, + mask=self.mask, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + ) + self.current_state = backend.copy(res["x"]) + except OptimizeStopFail: + if self.verbose > 0: + config.logger.warning("Could not find step to improve Chi^2, stopping") + self.message = ( + self.message + + "success by immobility. Could not find step to improve Chi^2. Convergence not guaranteed" + ) + break + except OptimizeStopSuccess as e: + if self.verbose > 0: + config.logger.info(f"Optimization converged successfully: {e}") + self.message = self.message + "success" + break From d21ea49a2cbe217688c7f8983ca4c9d22be45285 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 15:03:12 -0400 Subject: [PATCH 5/7] working on batchlm first run --- astrophot/fit/__init__.py | 2 + astrophot/fit/batch_lm.py | 155 +++++++++++++++++++------ astrophot/fit/func/__init__.py | 3 +- astrophot/fit/func/lm.py | 84 ++++++++++++-- docs/source/tutorials/ImageTypes.ipynb | 80 ++++++++++++- 5 files changed, 275 insertions(+), 49 deletions(-) diff --git a/astrophot/fit/__init__.py b/astrophot/fit/__init__.py index b788aa7a..80553b35 100644 --- a/astrophot/fit/__init__.py +++ b/astrophot/fit/__init__.py @@ -1,4 +1,5 @@ from .lm import LM, LMConstraint +from .batch_lm import BatchLM from .gradient import Grad, Slalom from .iterative import Iter, IterParam from .scipy_fit import ScipyFit @@ -10,6 +11,7 @@ __all__ = [ "LM", "LMConstraint", + "BatchLM", "Grad", "Iter", "MALA", diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py index 67af6359..6e38cf98 100644 --- a/astrophot/fit/batch_lm.py +++ b/astrophot/fit/batch_lm.py @@ -1,9 +1,10 @@ +import numpy as np from ..models import Model from ..image import TargetImageBatch, WindowBatch from .base import BaseOptimizer -from ..backend_obj import backend +from ..backend_obj import backend, ArrayLike from .. import config -from ..errors import OptimizeStopFail, OptimizeStopSuccess +from ..errors import OptimizeStopSuccess from ..param import ValidContext from . import func @@ -20,7 +21,7 @@ def __init__( Lup=11.0, Ldn=9.0, L0=1.0, - max_step_iter: int = 10, + max_step_iter: int = 3, likelihood="gaussian", **kwargs, ): @@ -33,6 +34,8 @@ def __init__( **kwargs, ) + self.max_step_iter = max_step_iter + # Likelihood self.likelihood = likelihood if self.likelihood not in ["gaussian", "poisson"]: @@ -84,7 +87,7 @@ def __init__( self.jacobian = lambda x: vjac(CD, crtan, crpix, psf, x) # ndf - self.ndf = backend.sum(self.mask, axis=1) - self.current_state.shape[1] + self.ndf = backend.sum(self.mask, dim=1) - self.current_state.shape[1] # LM parmeters self.Lup = Lup @@ -97,7 +100,7 @@ def chi2_ndf(self): return ( backend.sum( self.weight * self.mask * (self.data - self.forward(self.current_state)) ** 2, - axis=1, + dim=1, ) / self.ndf ) @@ -105,7 +108,7 @@ def chi2_ndf(self): def poisson_2nll_ndf(self): M = self.forward(self.current_state) return ( - 2 * backend.sum((M - self.data * backend.log(M + 1e-10)) * self.mask, axis=1) / self.ndf + 2 * backend.sum((M - self.data * backend.log(M + 1e-10)) * self.mask, dim=1) / self.ndf ) def fit(self, update_uncertainty=True): @@ -133,25 +136,10 @@ def fit(self, update_uncertainty=True): if self.verbose > 0: config.logger.info(f"{quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}") - try: - if self.fit_valid: - with ValidContext(self.model): - res = func.batch_lm_step( - x=self.model.to_valid(self.current_state), - data=self.data, - model=self.forward, - weight=self.weight, - mask=self.mask, - jacobian=self.jacobian, - L=self.L, - Lup=self.Lup, - Ldn=self.Ldn, - likelihood=self.likelihood, - ) - self.current_state = self.model.from_valid(backend.copy(res["x"])) - else: + if self.fit_valid: + with ValidContext(self.model): res = func.batch_lm_step( - x=self.current_state, + x=self.model.to_valid(self.current_state), data=self.data, model=self.forward, weight=self.weight, @@ -161,18 +149,113 @@ def fit(self, update_uncertainty=True): Lup=self.Lup, Ldn=self.Ldn, likelihood=self.likelihood, + max_step_iter=self.max_step_iter, ) - self.current_state = backend.copy(res["x"]) - except OptimizeStopFail: - if self.verbose > 0: - config.logger.warning("Could not find step to improve Chi^2, stopping") - self.message = ( - self.message - + "success by immobility. Could not find step to improve Chi^2. Convergence not guaranteed" + self.current_state = self.model.from_valid(backend.copy(res["x"])) + else: + res = func.batch_lm_step( + x=self.current_state, + data=self.data, + model=self.forward, + weight=self.weight, + mask=self.mask, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + max_step_iter=self.max_step_iter, ) + self.current_state = backend.copy(res["x"]) + + self.L = np.clip(res["L"], 1e-9, 1e9) + self.L_history.append(res["L"]) + self.loss_history.append(2 * res["nll"] / backend.to_numpy(self.ndf)) + self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state))) + + if self.check_convergence(): break - except OptimizeStopSuccess as e: - if self.verbose > 0: - config.logger.info(f"Optimization converged successfully: {e}") - self.message = self.message + "success" - break + else: + self.message = self.message + "fail. Maximum iterations" + + if self.verbose > 0: + config.logger.info( + f"Final {quantity}: {np.nanmin(self.loss_history, axis=0)}, L: {self.L_history[np.nanargmin(self.loss_history, axis=0)]}. Converged: {self.message}" + ) + + self.model.set_values(self.current_state) + if update_uncertainty: + self.update_uncertainty() + + return self + + def check_convergence(self) -> bool: + """Check if the optimization has converged based on the last + iteration's chi^2 and the relative tolerance. + """ + if len(self.loss_history) < 3: + return False + if np.all( + (self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1] + < self.relative_tolerance + ) and np.all(self.L < 0.1): + self.message = self.message + "success" + return True + if len(self.loss_history) < 10: + return False + if np.all( + (self.loss_history[-10] - self.loss_history[-1]) / self.loss_history[-1] + < self.relative_tolerance + ): + self.message = self.message + "success by immobility. Convergence not guaranteed" + return True + return False + + @property + def covariance_matrix(self) -> ArrayLike: + """The covariance matrix for the model at the current + parameters. This can be used to construct a full Gaussian PDF for the + parameters using: $\\mathcal{N}(\\mu,\\Sigma)$ where $\\mu$ is the + optimized parameters and $\\Sigma$ is the covariance matrix. + + """ + + if self._covariance_matrix is not None: + return self._covariance_matrix + J = self.jacobian(self.current_state) + if self.likelihood == "gaussian": + hess = backend.vmap(func.hessian)(J, self.weight) + elif self.likelihood == "poisson": + hess = backend.vmap(func.hessian_poisson)( + J, self.data, self.forward(self.current_state) + ) + try: + self._covariance_matrix = backend.vmap(backend.linalg.inv)(hess) + except: + config.logger.warning( + "WARNING: Hessian is singular, likely at least one parameter is non-physical. Will use pseudo-inverse of Hessian to continue but results should be inspected." + ) + self._covariance_matrix = backend.vmap(backend.linalg.pinv)(hess) + return self._covariance_matrix + + def update_uncertainty(self) -> None: + """Call this function after optimization to set the uncertainties for + the parameters. This will use the diagonal of the covariance + matrix to update the uncertainties. See the covariance_matrix + function for the full representation of the uncertainties. + + """ + # set the uncertainty for each parameter + cov = self.covariance_matrix + if backend.all(backend.isfinite(cov)): + try: + self.model.set_values( + backend.sqrt(backend.abs(backend.vmap(backend.diag)(cov))), + attribute="uncertainty", + ) + except RuntimeError as e: + config.logger.warning(f"Unable to update uncertainty due to: {e}") + else: + config.logger.warning( + "Unable to update uncertainty due to non finite covariance matrix" + ) diff --git a/astrophot/fit/func/__init__.py b/astrophot/fit/func/__init__.py index 58da703e..7d13dcd7 100644 --- a/astrophot/fit/func/__init__.py +++ b/astrophot/fit/func/__init__.py @@ -1,9 +1,10 @@ -from .lm import lm_step, hessian, gradient, hessian_poisson, gradient_poisson +from .lm import lm_step, hessian, gradient, hessian_poisson, gradient_poisson, batch_lm_step from .slalom import slalom_step from .mala import mala __all__ = [ "lm_step", + "batch_lm_step", "hessian", "gradient", "slalom_step", diff --git a/astrophot/fit/func/lm.py b/astrophot/fit/func/lm.py index 2e88c816..047d3d2c 100644 --- a/astrophot/fit/func/lm.py +++ b/astrophot/fit/func/lm.py @@ -12,7 +12,7 @@ def nll(D, M, W): M: model prediction W: weights """ - return 0.5 * backend.sum(W * (D - M) ** 2) + return 0.5 * backend.sum(W * (D - M) ** 2, dim=-1) def nll_poisson(D, M): @@ -21,23 +21,23 @@ def nll_poisson(D, M): D: data M: model prediction """ - return backend.sum(M - D * backend.log(M + 1e-10)) # Adding small value to avoid log(0) + return backend.sum(M - D * backend.log(M + 1e-10), dim=-1) # Adding small value to avoid log(0) def gradient(J, W, D, M): - return J.T @ (W * (D - M))[:, None] + return J.T @ (W * (D - M)).reshape(D.shape + (1,)) def gradient_poisson(J, D, M): - return J.T @ (D / M - 1)[:, None] + return J.T @ (D / M - 1).reshape(D.shape + (1,)) def hessian(J, W): - return J.T @ (W[:, None] * J) + return J.T @ (W.reshape(W.shape + (1,)) * J) def hessian_poisson(J, D, M): - return J.T @ ((D / (M**2 + 1e-10))[:, None] * J) + return J.T @ ((D / (M**2 + 1e-10)).reshape(D.shape + (1,)) * J) def damp_hessian(hess, L): @@ -46,6 +46,10 @@ def damp_hessian(hess, L): return hess * (I + D / (1 + L)) + L * I * backend.diag(hess) +def rho(nll0, nll1, h, hessD, grad): + return (nll0 - nll1) / backend.abs(h.T @ hessD @ h - 2 * grad.T @ h) + + def solve(hess, grad, L): hessD = damp_hessian(hess, L) # (N, N) while True: @@ -116,15 +120,15 @@ def lm_step( break # actual nll improvement vs expected from linearization - rho = (nll0 - nll1) / backend.abs(h.T @ hessD @ h - 2 * grad.T @ h).item() + _rho = rho(nll0, nll1, h, hessD, grad).item() - if (nll1 < (nll0 + tolerance) and abs(rho - 1) < abs(scary["rho"] - 1)) or ( - nll1 < scary["nll"] and rho > -10 + if (nll1 < (nll0 + tolerance) and abs(_rho - 1) < abs(scary["rho"] - 1)) or ( + nll1 < scary["nll"] and _rho > -10 ): - scary = {"x": x + h.squeeze(1), "nll": nll1, "L": L0, "rho": rho} + scary = {"x": x + h.squeeze(1), "nll": nll1, "L": L0, "rho": _rho} # Avoid highly non-linear regions - if rho < 0.1 or rho > 2: + if _rho < 0.1 or _rho > 2: L *= Lup if improving is True: break @@ -156,3 +160,61 @@ def lm_step( raise OptimizeStopFail("Could not find step to improve chi^2") return best + + +def batch_lm_step( + x, + data, + model, + weight, + mask, + jacobian, + L=1.0, + Lup=9.0, + Ldn=11.0, + tolerance=1e-4, + likelihood="gaussian", + max_step_iter=3, +): + L0 = L # (D,) + M0 = backend.detach(model(x)) # (D, M) + J = backend.detach(jacobian(x)) # (D, M, N) + data = data * mask + M0 = M0 * mask + weight = weight * mask + J = J * mask.reshape(mask.shape + (1,)) + + if likelihood == "gaussian": + nll0 = nll(data, M0, weight) # (D,) + grad = backend.vmap(gradient)(J, weight, data, M0) # (D, N, 1) + hess = backend.vmap(hessian)(J, weight) # (D, N, N) + elif likelihood == "poisson": + nll0 = nll_poisson(data, M0) # (D,) + grad = backend.vmap(gradient_poisson)(J, data, M0) # (D, N, 1) + hess = backend.vmap(hessian_poisson)(J, data, M0) # (D, N, N) + else: + raise ValueError(f"Unsupported likelihood: {likelihood}") + + del J + + new_x = backend.copy(x) + new_nll = backend.copy(nll0) + new_L = backend.copy(L) + + for _ in range(max_step_iter): + hessD, h = backend.vmap(solve)(hess, grad, new_L) # (D, N, N), (D, N, 1) + M1 = model(x + h.squeeze(2)) # (D, M) + if likelihood == "gaussian": + nll1 = nll(data, M1, weight) # (D,) + elif likelihood == "poisson": + nll1 = nll_poisson(data, M1) # (D,) + + # actual nll improvement vs expected from linearization + _rho = backend.vmap(rho)(nll0, nll1, h, hessD, grad) # (D,) + + good = backend.isfinite(nll1) & (nll1 < new_nll) & (_rho > 0.1) & (_rho < 2) + new_x = backend.where(good[:, None], x + h.squeeze(2), new_x) + new_nll = backend.where(good, nll1, new_nll) + new_L = backend.where(good, new_L / Ldn, new_L * Lup) + + return {"x": new_x, "nll": backend.to_numpy(new_nll), "L": backend.to_numpy(new_L)} diff --git a/docs/source/tutorials/ImageTypes.ipynb b/docs/source/tutorials/ImageTypes.ipynb index f6a71d3e..fdf6f937 100644 --- a/docs/source/tutorials/ImageTypes.ipynb +++ b/docs/source/tutorials/ImageTypes.ipynb @@ -20,7 +20,8 @@ "import astrophot as ap\n", "import torch\n", "import matplotlib.pyplot as plt\n", - "from matplotlib.patches import Rectangle" + "from matplotlib.patches import Rectangle\n", + "import numpy as np" ] }, { @@ -152,6 +153,83 @@ "id": "11", "metadata": {}, "source": [] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "# Batch Images\n", + "\n", + "An `ImageBatch` is much like an `ImageList` except that it has certain restrictions applied to it. Namely that all the images must have the same size, and they must be of the \"regular\" image type, not `SIP` or `CMOS`. This extra constraint means they can't be used as generally as other image types, but the trade off is extra computational advantages. You can see the `batch scene model` in the model zoo for one of the primary advantages of the `ImageBatch` image type. Another advantage is that it can be used to vectorize the fitting of images. Lets say we have some simple model (either just a component model, or a full group model) and we want to use that same model to represent a bunch of similar images. We can use batched images to make this much more efficient. Rather than looping through the images one at a time, we can create a batch and set all their calculations to the hardware (CPU or GPU) at once. Here is an example of what this looks like, but the real advantage is when you seriously scale it up." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "base_target = ap.TargetImage(data=np.zeros((64, 64)))\n", + "model = ap.Model(\n", + " model_type=\"sersic galaxy model\",\n", + " name=\"batch_demo\",\n", + " center=(32, 40),\n", + " q=0.6,\n", + " PA=3.14 / 3,\n", + " n=1,\n", + " Re=10,\n", + " Ie=1,\n", + " target=base_target,\n", + ")\n", + "model.initialize()\n", + "\n", + "target1 = ap.TargetImage(\n", + " name=\"target1\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", + ")\n", + "model.center = (15, 15)\n", + "model.PA = -3.14 / 4\n", + "target2 = ap.TargetImage(\n", + " name=\"target2\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", + ")\n", + "model.center = (45, 20)\n", + "model.Re = 15\n", + "target3 = ap.TargetImage(\n", + " name=\"target3\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", + ")\n", + "\n", + "batch_target = ap.TargetImageBatch([target1, target2, target3])\n", + "fig, axarr = plt.subplots(1, 3, figsize=(15, 5))\n", + "ap.plots.target_image(fig, axarr, batch_target)\n", + "plt.show()\n", + "\n", + "# Every parameter must be batched\n", + "# Notice we set them a bit off the true values\n", + "model.center = ((30, 42), (16, 16), (46, 21))\n", + "model.q = (0.6, 0.6, 0.6)\n", + "model.PA = (3.14 / 3, -3.14 / 5, -3.14 / 4)\n", + "model.n = (1.1, 0.9, 1)\n", + "model.Re = (11, 11, 14)\n", + "model.Ie = (1, 0.8, 1.2)\n", + "\n", + "res = ap.fit.BatchLM(model, batch_target, batch_target.window).fit()\n", + "\n", + "fig, axarr = plt.subplots(2, 3, figsize=(15, 5))\n", + "for i in range(3):\n", + " model.set_values(res.current_state[i])\n", + " ap.plots.model_image(fig, axarr[0, i], model)\n", + " ap.plots.residual_image(fig, axarr[1, i], model, target=batch_target.images[i])\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 9fbaff97f5ee1011533bb79acd2eb26497ac30e7 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 15:41:26 -0400 Subject: [PATCH 6/7] batch lm runs! --- astrophot/fit/batch_lm.py | 8 ++++---- astrophot/fit/func/lm.py | 4 ++-- docs/source/tutorials/ImageTypes.ipynb | 28 +++++++++++++++++--------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py index 6e38cf98..74358447 100644 --- a/astrophot/fit/batch_lm.py +++ b/astrophot/fit/batch_lm.py @@ -168,8 +168,8 @@ def fit(self, update_uncertainty=True): ) self.current_state = backend.copy(res["x"]) - self.L = np.clip(res["L"], 1e-9, 1e9) - self.L_history.append(res["L"]) + self.L = backend.clamp(res["L"], backend.as_array(1e-9), backend.as_array(1e9)) + self.L_history.append(backend.to_numpy(self.L)) self.loss_history.append(2 * res["nll"] / backend.to_numpy(self.ndf)) self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state))) @@ -180,7 +180,7 @@ def fit(self, update_uncertainty=True): if self.verbose > 0: config.logger.info( - f"Final {quantity}: {np.nanmin(self.loss_history, axis=0)}, L: {self.L_history[np.nanargmin(self.loss_history, axis=0)]}. Converged: {self.message}" + f"Final {quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}. Converged: {self.message}" ) self.model.set_values(self.current_state) @@ -198,7 +198,7 @@ def check_convergence(self) -> bool: if np.all( (self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1] < self.relative_tolerance - ) and np.all(self.L < 0.1): + ) and np.all(backend.to_numpy(self.L) < 0.1): self.message = self.message + "success" return True if len(self.loss_history) < 10: diff --git a/astrophot/fit/func/lm.py b/astrophot/fit/func/lm.py index 047d3d2c..47632043 100644 --- a/astrophot/fit/func/lm.py +++ b/astrophot/fit/func/lm.py @@ -210,11 +210,11 @@ def batch_lm_step( nll1 = nll_poisson(data, M1) # (D,) # actual nll improvement vs expected from linearization - _rho = backend.vmap(rho)(nll0, nll1, h, hessD, grad) # (D,) + _rho = backend.vmap(rho)(nll0, nll1, h, hessD, grad).reshape(-1) # (D,) good = backend.isfinite(nll1) & (nll1 < new_nll) & (_rho > 0.1) & (_rho < 2) new_x = backend.where(good[:, None], x + h.squeeze(2), new_x) new_nll = backend.where(good, nll1, new_nll) new_L = backend.where(good, new_L / Ldn, new_L * Lup) - return {"x": new_x, "nll": backend.to_numpy(new_nll), "L": backend.to_numpy(new_L)} + return {"x": new_x, "nll": backend.to_numpy(new_nll), "L": new_L} diff --git a/docs/source/tutorials/ImageTypes.ipynb b/docs/source/tutorials/ImageTypes.ipynb index fdf6f937..b52af740 100644 --- a/docs/source/tutorials/ImageTypes.ipynb +++ b/docs/source/tutorials/ImageTypes.ipynb @@ -152,22 +152,18 @@ "cell_type": "markdown", "id": "11", "metadata": {}, - "source": [] - }, - { - "cell_type": "markdown", - "id": "12", - "metadata": {}, "source": [ "# Batch Images\n", "\n", - "An `ImageBatch` is much like an `ImageList` except that it has certain restrictions applied to it. Namely that all the images must have the same size, and they must be of the \"regular\" image type, not `SIP` or `CMOS`. This extra constraint means they can't be used as generally as other image types, but the trade off is extra computational advantages. You can see the `batch scene model` in the model zoo for one of the primary advantages of the `ImageBatch` image type. Another advantage is that it can be used to vectorize the fitting of images. Lets say we have some simple model (either just a component model, or a full group model) and we want to use that same model to represent a bunch of similar images. We can use batched images to make this much more efficient. Rather than looping through the images one at a time, we can create a batch and set all their calculations to the hardware (CPU or GPU) at once. Here is an example of what this looks like, but the real advantage is when you seriously scale it up." + "An `ImageBatch` is much like an `ImageList` except that it has certain restrictions applied to it. Namely that all the images must have the same size, and they must be of the \"regular\" image type, not `SIP` or `CMOS`. This extra constraint means they can't be used as generally as other image types, but the trade off is extra computational advantages. You can see the `batch scene model` in the model zoo for one of the primary advantages of the `ImageBatch` image type. Another advantage is that it can be used to vectorize the fitting of images. Lets say we have some simple model (either just a component model, or a full group model) and we want to use that same model to represent a bunch of similar images. We can use batched images to make this much more efficient. Rather than looping through the images one at a time, we can create a batch and set all their calculations to the hardware (CPU or GPU) at once. Here is an example of what this looks like, but the real advantage is when you seriously scale it up.\n", + "\n", + "Note: Every dynamic parameter for the model must be batched, these are supposed to be independent models and so every parameter is unique for each model." ] }, { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -182,6 +178,8 @@ " Re=10,\n", " Ie=1,\n", " target=base_target,\n", + " integrate_mode=\"none\",\n", + " sampling_mode=\"quad:3\",\n", ")\n", "model.initialize()\n", "\n", @@ -214,15 +212,27 @@ "model.Ie = (1, 0.8, 1.2)\n", "\n", "res = ap.fit.BatchLM(model, batch_target, batch_target.window).fit()\n", + "print(\"Best-fit parameters:\", res.current_state.detach().cpu().numpy())\n", "\n", "fig, axarr = plt.subplots(2, 3, figsize=(15, 5))\n", "for i in range(3):\n", " model.set_values(res.current_state[i])\n", + " model.target = batch_target.images[i]\n", " ap.plots.model_image(fig, axarr[0, i], model)\n", - " ap.plots.residual_image(fig, axarr[1, i], model, target=batch_target.images[i])\n", + " ap.plots.residual_image(fig, axarr[1, i], model)\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "And there we have it! Three models fit at once. Really, just doing three isn't a big deal computationally. But if you have dozens or hundreds of these little 64x64 pixel images then grouping them all together can get huge computational speedups. If you are able to run these big batched fits on a GPU you can get several orders of magnitude runtime reduction!\n", + "\n", + "Happy fitting!" + ] + }, { "cell_type": "code", "execution_count": null, From dd9704a23fe0e3f69d7263307dfb4466eec48e7f Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 2 Apr 2026 16:00:29 -0400 Subject: [PATCH 7/7] fixes from copilot suggestions --- astrophot/fit/batch_lm.py | 12 +++++++----- docs/source/tutorials/ImageTypes.ipynb | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py index 74358447..f8a53470 100644 --- a/astrophot/fit/batch_lm.py +++ b/astrophot/fit/batch_lm.py @@ -87,9 +87,11 @@ def __init__( self.jacobian = lambda x: vjac(CD, crtan, crpix, psf, x) # ndf - self.ndf = backend.sum(self.mask, dim=1) - self.current_state.shape[1] + self.ndf = backend.clamp( + backend.sum(self.mask, dim=1) - self.current_state.shape[1], backend.as_array(1), None + ) - # LM parmeters + # LM parameters self.Lup = Lup self.Ldn = Ldn self.L = L0 * backend.ones( @@ -222,12 +224,12 @@ def covariance_matrix(self) -> ArrayLike: if self._covariance_matrix is not None: return self._covariance_matrix - J = self.jacobian(self.current_state) + J = self.jacobian(self.current_state) * self.mask.reshape(self.mask.shape + (1,)) if self.likelihood == "gaussian": - hess = backend.vmap(func.hessian)(J, self.weight) + hess = backend.vmap(func.hessian)(J, self.weight * self.mask) elif self.likelihood == "poisson": hess = backend.vmap(func.hessian_poisson)( - J, self.data, self.forward(self.current_state) + J, self.data * self.mask, self.forward(self.current_state) * self.mask ) try: self._covariance_matrix = backend.vmap(backend.linalg.inv)(hess) diff --git a/docs/source/tutorials/ImageTypes.ipynb b/docs/source/tutorials/ImageTypes.ipynb index b52af740..41152f07 100644 --- a/docs/source/tutorials/ImageTypes.ipynb +++ b/docs/source/tutorials/ImageTypes.ipynb @@ -155,7 +155,7 @@ "source": [ "# Batch Images\n", "\n", - "An `ImageBatch` is much like an `ImageList` except that it has certain restrictions applied to it. Namely that all the images must have the same size, and they must be of the \"regular\" image type, not `SIP` or `CMOS`. This extra constraint means they can't be used as generally as other image types, but the trade off is extra computational advantages. You can see the `batch scene model` in the model zoo for one of the primary advantages of the `ImageBatch` image type. Another advantage is that it can be used to vectorize the fitting of images. Lets say we have some simple model (either just a component model, or a full group model) and we want to use that same model to represent a bunch of similar images. We can use batched images to make this much more efficient. Rather than looping through the images one at a time, we can create a batch and set all their calculations to the hardware (CPU or GPU) at once. Here is an example of what this looks like, but the real advantage is when you seriously scale it up.\n", + "An `ImageBatch` is much like an `ImageList` except that it has certain restrictions applied to it. Namely that all the images must have the same size, and they must be of the \"regular\" image type, not `SIP` or `CMOS`. This extra constraint means they can't be used as generally as other image types, but the trade off is extra computational advantages. You can see the `batch scene model` in the model zoo for one of the primary advantages of the `ImageBatch` image type. Another advantage is that it can be used to vectorize the fitting of images. Let's say we have some simple model (either just a component model, or a full group model) and we want to use that same model to represent a bunch of similar images. We can use batched images to make this much more efficient. Rather than looping through the images one at a time, we can create a batch and set all their calculations to the hardware (CPU or GPU) at once. Here is an example of what this looks like, but the real advantage is when you seriously scale it up.\n", "\n", "Note: Every dynamic parameter for the model must be batched, these are supposed to be independent models and so every parameter is unique for each model." ] @@ -183,6 +183,7 @@ ")\n", "model.initialize()\n", "\n", + "np.random.seed(42)\n", "target1 = ap.TargetImage(\n", " name=\"target1\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", ")\n",