diff --git a/astrophot/fit/__init__.py b/astrophot/fit/__init__.py index b788aa7a..80553b35 100644 --- a/astrophot/fit/__init__.py +++ b/astrophot/fit/__init__.py @@ -1,4 +1,5 @@ from .lm import LM, LMConstraint +from .batch_lm import BatchLM from .gradient import Grad, Slalom from .iterative import Iter, IterParam from .scipy_fit import ScipyFit @@ -10,6 +11,7 @@ __all__ = [ "LM", "LMConstraint", + "BatchLM", "Grad", "Iter", "MALA", diff --git a/astrophot/fit/batch_lm.py b/astrophot/fit/batch_lm.py new file mode 100644 index 00000000..f8a53470 --- /dev/null +++ b/astrophot/fit/batch_lm.py @@ -0,0 +1,263 @@ +import numpy as np +from ..models import Model +from ..image import TargetImageBatch, WindowBatch +from .base import BaseOptimizer +from ..backend_obj import backend, ArrayLike +from .. import config +from ..errors import OptimizeStopSuccess +from ..param import ValidContext +from . import func + + +class BatchLM(BaseOptimizer): + + def __init__( + self, + model: Model, + batch_target: TargetImageBatch, + batch_window: WindowBatch, + max_iter: int = 100, + relative_tolerance: float = 1e-5, + Lup=11.0, + Ldn=9.0, + L0=1.0, + max_step_iter: int = 3, + likelihood="gaussian", + **kwargs, + ): + + super().__init__( + model=model, + initial_state=model.get_values(), + max_iter=max_iter, + relative_tolerance=relative_tolerance, + **kwargs, + ) + + self.max_step_iter = max_step_iter + + # Likelihood + self.likelihood = likelihood + if self.likelihood not in ["gaussian", "poisson"]: + raise ValueError( + f"Unsupported likelihood: {self.likelihood}, should be one of: 'gaussian' or 'poisson'" + ) + + # mask + mask = backend.flatten(batch_target[batch_window].mask, 1, -1) + self.mask = ~mask + if backend.sum(self.mask).item() == 0: + raise OptimizeStopSuccess("No data to fit. 
All pixels are masked") + + # data + self.data = backend.flatten(batch_target[batch_window].data, 1, -1) + + # Weight + self.weight = backend.flatten(batch_target[batch_window].weight, 1, -1) + + # WCS + crtan = batch_target.crtan + shift = backend.as_array( + batch_window.origin_shifter(self.model.window), dtype=config.DTYPE, device=config.DEVICE + ) + crpix = batch_target[batch_window].crpix + shift + CD = batch_target.CD + psf = batch_target.psf_stack + psf_batch = None if psf is None else 0 + + # Forward + vmodel = backend.vmap( + lambda cd, crt, crp, psf, params: backend.flatten( + self.model(cd, crt, crp, psf, params=params).data + ), + in_dims=(0, 0, 0, psf_batch, 0), + ) + self.forward = lambda x: vmodel(CD, crtan, crpix, psf, x) + + # Jacobian + vjac = backend.vmap( + backend.jacfwd( + lambda cd, crt, crp, psf, params: backend.flatten( + self.model(cd, crt, crp, psf, params=params).data + ), + argnums=4, + ), + in_dims=(0, 0, 0, psf_batch, 0), + ) + self.jacobian = lambda x: vjac(CD, crtan, crpix, psf, x) + + # ndf + self.ndf = backend.clamp( + backend.sum(self.mask, dim=1) - self.current_state.shape[1], backend.as_array(1), None + ) + + # LM parameters + self.Lup = Lup + self.Ldn = Ldn + self.L = L0 * backend.ones( + self.current_state.shape[0], dtype=config.DTYPE, device=config.DEVICE + ) + + def chi2_ndf(self): + return ( + backend.sum( + self.weight * self.mask * (self.data - self.forward(self.current_state)) ** 2, + dim=1, + ) + / self.ndf + ) + + def poisson_2nll_ndf(self): + M = self.forward(self.current_state) + return ( + 2 * backend.sum((M - self.data * backend.log(M + 1e-10)) * self.mask, dim=1) / self.ndf + ) + + def fit(self, update_uncertainty=True): + if self.current_state.shape[1] == 0: + if self.verbose > 0: + config.logger.warning("No parameters to optimize. Exiting fit") + self.message = "No parameters to optimize. 
Exiting fit" + return self + + if self.likelihood == "gaussian": + quantity = "Chi^2/DoF" + self.loss_history = [backend.to_numpy(self.chi2_ndf())] + elif self.likelihood == "poisson": + quantity = "2NLL/DoF" + self.loss_history = [backend.to_numpy(self.poisson_2nll_ndf())] + self._covariance_matrix = None + self.L_history = [backend.to_numpy(self.L)] + self.lambda_history = [backend.to_numpy(backend.copy(self.current_state))] + if self.verbose > 0: + config.logger.info( + f"==Starting LM fit for '{self.model.name}' with batch of {self.current_state.shape[0]} images with {self.current_state.shape[1]} dynamic parameters and {self.data.shape[1]} pixels==" + ) + + for _ in range(self.max_iter): + if self.verbose > 0: + config.logger.info(f"{quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}") + + if self.fit_valid: + with ValidContext(self.model): + res = func.batch_lm_step( + x=self.model.to_valid(self.current_state), + data=self.data, + model=self.forward, + weight=self.weight, + mask=self.mask, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + max_step_iter=self.max_step_iter, + ) + self.current_state = self.model.from_valid(backend.copy(res["x"])) + else: + res = func.batch_lm_step( + x=self.current_state, + data=self.data, + model=self.forward, + weight=self.weight, + mask=self.mask, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + max_step_iter=self.max_step_iter, + ) + self.current_state = backend.copy(res["x"]) + + self.L = backend.clamp(res["L"], backend.as_array(1e-9), backend.as_array(1e9)) + self.L_history.append(backend.to_numpy(self.L)) + self.loss_history.append(2 * res["nll"] / backend.to_numpy(self.ndf)) + self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state))) + + if self.check_convergence(): + break + else: + self.message = self.message + "fail. 
Maximum iterations" + + if self.verbose > 0: + config.logger.info( + f"Final {quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}. Converged: {self.message}" + ) + + self.model.set_values(self.current_state) + if update_uncertainty: + self.update_uncertainty() + + return self + + def check_convergence(self) -> bool: + """Check if the optimization has converged based on the last + iteration's chi^2 and the relative tolerance. + """ + if len(self.loss_history) < 3: + return False + if np.all( + (self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1] + < self.relative_tolerance + ) and np.all(backend.to_numpy(self.L) < 0.1): + self.message = self.message + "success" + return True + if len(self.loss_history) < 10: + return False + if np.all( + (self.loss_history[-10] - self.loss_history[-1]) / self.loss_history[-1] + < self.relative_tolerance + ): + self.message = self.message + "success by immobility. Convergence not guaranteed" + return True + return False + + @property + def covariance_matrix(self) -> ArrayLike: + """The covariance matrix for the model at the current + parameters. This can be used to construct a full Gaussian PDF for the + parameters using: $\\mathcal{N}(\\mu,\\Sigma)$ where $\\mu$ is the + optimized parameters and $\\Sigma$ is the covariance matrix. + + """ + + if self._covariance_matrix is not None: + return self._covariance_matrix + J = self.jacobian(self.current_state) * self.mask.reshape(self.mask.shape + (1,)) + if self.likelihood == "gaussian": + hess = backend.vmap(func.hessian)(J, self.weight * self.mask) + elif self.likelihood == "poisson": + hess = backend.vmap(func.hessian_poisson)( + J, self.data * self.mask, self.forward(self.current_state) * self.mask + ) + try: + self._covariance_matrix = backend.vmap(backend.linalg.inv)(hess) + except: + config.logger.warning( + "WARNING: Hessian is singular, likely at least one parameter is non-physical. 
Will use pseudo-inverse of Hessian to continue but results should be inspected." + ) + self._covariance_matrix = backend.vmap(backend.linalg.pinv)(hess) + return self._covariance_matrix + + def update_uncertainty(self) -> None: + """Call this function after optimization to set the uncertainties for + the parameters. This will use the diagonal of the covariance + matrix to update the uncertainties. See the covariance_matrix + function for the full representation of the uncertainties. + + """ + # set the uncertainty for each parameter + cov = self.covariance_matrix + if backend.all(backend.isfinite(cov)): + try: + self.model.set_values( + backend.sqrt(backend.abs(backend.vmap(backend.diag)(cov))), + attribute="uncertainty", + ) + except RuntimeError as e: + config.logger.warning(f"Unable to update uncertainty due to: {e}") + else: + config.logger.warning( + "Unable to update uncertainty due to non finite covariance matrix" + ) diff --git a/astrophot/fit/func/__init__.py b/astrophot/fit/func/__init__.py index 58da703e..7d13dcd7 100644 --- a/astrophot/fit/func/__init__.py +++ b/astrophot/fit/func/__init__.py @@ -1,9 +1,10 @@ -from .lm import lm_step, hessian, gradient, hessian_poisson, gradient_poisson +from .lm import lm_step, hessian, gradient, hessian_poisson, gradient_poisson, batch_lm_step from .slalom import slalom_step from .mala import mala __all__ = [ "lm_step", + "batch_lm_step", "hessian", "gradient", "slalom_step", diff --git a/astrophot/fit/func/lm.py b/astrophot/fit/func/lm.py index 2e88c816..47632043 100644 --- a/astrophot/fit/func/lm.py +++ b/astrophot/fit/func/lm.py @@ -12,7 +12,7 @@ def nll(D, M, W): M: model prediction W: weights """ - return 0.5 * backend.sum(W * (D - M) ** 2) + return 0.5 * backend.sum(W * (D - M) ** 2, dim=-1) def nll_poisson(D, M): @@ -21,23 +21,23 @@ def nll_poisson(D, M): D: data M: model prediction """ - return backend.sum(M - D * backend.log(M + 1e-10)) # Adding small value to avoid log(0) + return backend.sum(M - D * 
backend.log(M + 1e-10), dim=-1) # Adding small value to avoid log(0) def gradient(J, W, D, M): - return J.T @ (W * (D - M))[:, None] + return J.T @ (W * (D - M)).reshape(D.shape + (1,)) def gradient_poisson(J, D, M): - return J.T @ (D / M - 1)[:, None] + return J.T @ (D / M - 1).reshape(D.shape + (1,)) def hessian(J, W): - return J.T @ (W[:, None] * J) + return J.T @ (W.reshape(W.shape + (1,)) * J) def hessian_poisson(J, D, M): - return J.T @ ((D / (M**2 + 1e-10))[:, None] * J) + return J.T @ ((D / (M**2 + 1e-10)).reshape(D.shape + (1,)) * J) def damp_hessian(hess, L): @@ -46,6 +46,10 @@ def damp_hessian(hess, L): return hess * (I + D / (1 + L)) + L * I * backend.diag(hess) +def rho(nll0, nll1, h, hessD, grad): + return (nll0 - nll1) / backend.abs(h.T @ hessD @ h - 2 * grad.T @ h) + + def solve(hess, grad, L): hessD = damp_hessian(hess, L) # (N, N) while True: @@ -116,15 +120,15 @@ def lm_step( break # actual nll improvement vs expected from linearization - rho = (nll0 - nll1) / backend.abs(h.T @ hessD @ h - 2 * grad.T @ h).item() + _rho = rho(nll0, nll1, h, hessD, grad).item() - if (nll1 < (nll0 + tolerance) and abs(rho - 1) < abs(scary["rho"] - 1)) or ( - nll1 < scary["nll"] and rho > -10 + if (nll1 < (nll0 + tolerance) and abs(_rho - 1) < abs(scary["rho"] - 1)) or ( + nll1 < scary["nll"] and _rho > -10 ): - scary = {"x": x + h.squeeze(1), "nll": nll1, "L": L0, "rho": rho} + scary = {"x": x + h.squeeze(1), "nll": nll1, "L": L0, "rho": _rho} # Avoid highly non-linear regions - if rho < 0.1 or rho > 2: + if _rho < 0.1 or _rho > 2: L *= Lup if improving is True: break @@ -156,3 +160,61 @@ def lm_step( raise OptimizeStopFail("Could not find step to improve chi^2") return best + + +def batch_lm_step( + x, + data, + model, + weight, + mask, + jacobian, + L=1.0, + Lup=9.0, + Ldn=11.0, + tolerance=1e-4, + likelihood="gaussian", + max_step_iter=3, +): + L0 = L # (D,) + M0 = backend.detach(model(x)) # (D, M) + J = backend.detach(jacobian(x)) # (D, M, N) + data = data * 
mask + M0 = M0 * mask + weight = weight * mask + J = J * mask.reshape(mask.shape + (1,)) + + if likelihood == "gaussian": + nll0 = nll(data, M0, weight) # (D,) + grad = backend.vmap(gradient)(J, weight, data, M0) # (D, N, 1) + hess = backend.vmap(hessian)(J, weight) # (D, N, N) + elif likelihood == "poisson": + nll0 = nll_poisson(data, M0) # (D,) + grad = backend.vmap(gradient_poisson)(J, data, M0) # (D, N, 1) + hess = backend.vmap(hessian_poisson)(J, data, M0) # (D, N, N) + else: + raise ValueError(f"Unsupported likelihood: {likelihood}") + + del J + + new_x = backend.copy(x) + new_nll = backend.copy(nll0) + new_L = backend.copy(L) + + for _ in range(max_step_iter): + hessD, h = backend.vmap(solve)(hess, grad, new_L) # (D, N, N), (D, N, 1) + M1 = model(x + h.squeeze(2)) # (D, M) + if likelihood == "gaussian": + nll1 = nll(data, M1, weight) # (D,) + elif likelihood == "poisson": + nll1 = nll_poisson(data, M1) # (D,) + + # actual nll improvement vs expected from linearization + _rho = backend.vmap(rho)(nll0, nll1, h, hessD, grad).reshape(-1) # (D,) + + good = backend.isfinite(nll1) & (nll1 < new_nll) & (_rho > 0.1) & (_rho < 2) + new_x = backend.where(good[:, None], x + h.squeeze(2), new_x) + new_nll = backend.where(good, nll1, new_nll) + new_L = backend.where(good, new_L / Ldn, new_L * Lup) + + return {"x": new_x, "nll": backend.to_numpy(new_nll), "L": new_L} diff --git a/astrophot/fit/lm.py b/astrophot/fit/lm.py index fa1aa2da..1168f7f5 100644 --- a/astrophot/fit/lm.py +++ b/astrophot/fit/lm.py @@ -158,8 +158,6 @@ def __init__( relative_tolerance=relative_tolerance, **kwargs, ) - # Maximum number of iterations of the algorithm - self.max_iter = max_iter # Maximum number of steps while searching for chi^2 improvement on a single jacobian evaluation self.max_step_iter = max_step_iter self.Lup = Lup @@ -167,7 +165,9 @@ def __init__( self.L = L0 self.likelihood = likelihood if self.likelihood not in ["gaussian", "poisson"]: - raise ValueError(f"Unsupported likelihood: 
{self.likelihood}") + raise ValueError( + f"Unsupported likelihood: {self.likelihood}, should be one of: 'gaussian' or 'poisson'" + ) self.constraint = constraint # mask diff --git a/astrophot/models/batch_model_object.py b/astrophot/models/batch_model_object.py index da8894c8..87630c58 100644 --- a/astrophot/models/batch_model_object.py +++ b/astrophot/models/batch_model_object.py @@ -176,12 +176,12 @@ def window(self, window): @forward def __call__(self, model_params=None, model_dims=None, **kwargs): working_image = self.target.model_image(self.window) - crtan = self.target.crtan + crtan = working_image.crtan shift = backend.as_array( self.window.origin_shifter(self.model.window), dtype=config.DTYPE, device=config.DEVICE ) - crpix = self.target.crpix + shift - CD = self.target.CD + crpix = working_image.crpix + shift + CD = working_image.CD psf = self.target.psf_stack psf_batch = None if psf is None else 0 working_image._data = backend.vmap( diff --git a/docs/source/tutorials/ImageTypes.ipynb b/docs/source/tutorials/ImageTypes.ipynb index f6a71d3e..41152f07 100644 --- a/docs/source/tutorials/ImageTypes.ipynb +++ b/docs/source/tutorials/ImageTypes.ipynb @@ -20,7 +20,8 @@ "import astrophot as ap\n", "import torch\n", "import matplotlib.pyplot as plt\n", - "from matplotlib.patches import Rectangle" + "from matplotlib.patches import Rectangle\n", + "import numpy as np" ] }, { @@ -151,6 +152,94 @@ "cell_type": "markdown", "id": "11", "metadata": {}, + "source": [ + "# Batch Images\n", + "\n", + "An `ImageBatch` is much like an `ImageList` except that it has certain restrictions applied to it: all the images must have the same size, and they must be of the \"regular\" image type, not `SIP` or `CMOS`. This extra constraint means they can't be used as generally as other image types, but the trade-off is extra computational advantages. You can see the `batch scene model` in the model zoo for one of the primary advantages of the `ImageBatch` image type. 
Another advantage is that it can be used to vectorize the fitting of images. Let's say we have some simple model (either just a component model, or a full group model) and we want to use that same model to represent a bunch of similar images. We can use batched images to make this much more efficient. Rather than looping through the images one at a time, we can create a batch and send all their calculations to the hardware (CPU or GPU) at once. Here is an example of what this looks like, but the real advantage comes when you seriously scale it up.\n", + "\n", + "Note: Every dynamic parameter for the model must be batched; these are supposed to be independent models, so every parameter is unique for each model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "base_target = ap.TargetImage(data=np.zeros((64, 64)))\n", + "model = ap.Model(\n", + " model_type=\"sersic galaxy model\",\n", + " name=\"batch_demo\",\n", + " center=(32, 40),\n", + " q=0.6,\n", + " PA=3.14 / 3,\n", + " n=1,\n", + " Re=10,\n", + " Ie=1,\n", + " target=base_target,\n", + " integrate_mode=\"none\",\n", + " sampling_mode=\"quad:3\",\n", + ")\n", + "model.initialize()\n", + "\n", + "np.random.seed(42)\n", + "target1 = ap.TargetImage(\n", + " name=\"target1\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", + ")\n", + "model.center = (15, 15)\n", + "model.PA = -3.14 / 4\n", + "target2 = ap.TargetImage(\n", + " name=\"target2\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", + ")\n", + "model.center = (45, 20)\n", + "model.Re = 15\n", + "target3 = ap.TargetImage(\n", + " name=\"target3\", data=model().data + np.random.normal(scale=0.5, size=(64, 64))\n", + ")\n", + "\n", + "batch_target = ap.TargetImageBatch([target1, target2, target3])\n", + "fig, axarr = plt.subplots(1, 3, figsize=(15, 5))\n", + "ap.plots.target_image(fig, axarr, batch_target)\n", + "plt.show()\n", + "\n", + "# 
Every parameter must be batched\n", + "# Notice we set them a bit off the true values\n", + "model.center = ((30, 42), (16, 16), (46, 21))\n", + "model.q = (0.6, 0.6, 0.6)\n", + "model.PA = (3.14 / 3, -3.14 / 5, -3.14 / 4)\n", + "model.n = (1.1, 0.9, 1)\n", + "model.Re = (11, 11, 14)\n", + "model.Ie = (1, 0.8, 1.2)\n", + "\n", + "res = ap.fit.BatchLM(model, batch_target, batch_target.window).fit()\n", + "print(\"Best-fit parameters:\", res.current_state.detach().cpu().numpy())\n", + "\n", + "fig, axarr = plt.subplots(2, 3, figsize=(15, 5))\n", + "for i in range(3):\n", + " model.set_values(res.current_state[i])\n", + " model.target = batch_target.images[i]\n", + " ap.plots.model_image(fig, axarr[0, i], model)\n", + " ap.plots.residual_image(fig, axarr[1, i], model)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "And there we have it! Three models fit at once. Really, just doing three isn't a big deal computationally. But if you have dozens or hundreds of these little 64x64 pixel images then grouping them all together can get huge computational speedups. If you are able to run these big batched fits on a GPU you can get several orders of magnitude runtime reduction!\n", + "\n", + "Happy fitting!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], "source": [] } ],