-
Notifications
You must be signed in to change notification settings - Fork 10
Autodiff with JAX #122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Autodiff with JAX #122
Changes from all commits
2ba72a5
1daa638
27afde8
89ae916
c8d9bee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,9 +2,16 @@ | |||||
|
|
||||||
| from abc import ABC, abstractmethod | ||||||
|
|
||||||
| from dalia import ArrayLike, NDArray | ||||||
| from dalia import ArrayLike, NDArray, xp,sp | ||||||
| from dalia.configs.likelihood_config import LikelihoodConfig | ||||||
|
|
||||||
| try: | ||||||
| import jax.numpy as jnp | ||||||
| JAX_AVAILABLE = True | ||||||
| except (ImportError, ModuleNotFoundError): | ||||||
| jnp = xp | ||||||
| JAX_AVAILABLE = False | ||||||
|
|
||||||
|
|
||||||
| class Likelihood(ABC): | ||||||
| """Abstract core class for likelihood.""" | ||||||
|
|
@@ -18,14 +25,55 @@ def __init__( | |||||
|
|
||||||
| self.config = config | ||||||
| self.n_observations = n_observations | ||||||
|
|
||||||
def gradient_likelihood(self, eta, y, h=1e-4, **kwargs):
    """Compute the likelihood gradient with the configured method.

    The implementation is selected by ``self.config.method``:
    ``"exact"``, ``"finite_difference"`` or ``"jax_autodiff"``.

    Parameters
    ----------
    eta : NDArray
        Vector of the linear predictor.
    y : NDArray
        Vector of the observations.
    h : float, optional
        Finite difference step size. Only used when the configured
        method is ``"finite_difference"``.
    **kwargs : optional
        Forwarded unchanged to the selected implementation.

    Returns
    -------
    gradient_likelihood
        Whatever the selected implementation returns.

    Raises
    ------
    RuntimeError
        If the method is ``"jax_autodiff"`` but JAX could not be imported.
    NotImplementedError
        If ``self.config.method`` is not one of the supported values.
    """
    method = self.config.method

    if method == "exact":
        return self.evaluate_gradient_likelihood(eta, y, **kwargs)

    if method == "finite_difference":
        return self.finite_difference_gradient_likelihood(eta, y, h, **kwargs)

    if method == "jax_autodiff":
        # JAX_AVAILABLE is set at import time by the guarded jax import.
        if not JAX_AVAILABLE:
            raise RuntimeError("JAX is not available.")
        return self.evaluate_gradient_likelihood_jax(eta, y, **kwargs)

    raise NotImplementedError(f"Method {method} not implemented.")
|
|
||||||
def hessian_likelihood(self, h: float = 1e-2, **kwargs):
    """Compute the likelihood Hessian with the configured method.

    The implementation is selected by ``self.config.method``:
    ``"exact"``, ``"finite_difference"`` or ``"jax_autodiff"``.

    Parameters
    ----------
    h : float, optional
        Finite difference step size. Only used when the configured
        method is ``"finite_difference"``, where it is passed on as the
        keyword argument ``h``.
    **kwargs : optional
        Forwarded unchanged to the selected implementation.

    Returns
    -------
    hessian_likelihood
        Whatever the selected implementation returns.

    Raises
    ------
    RuntimeError
        If the method is ``"jax_autodiff"`` but JAX could not be imported.
    NotImplementedError
        If ``self.config.method`` is not one of the supported values.
    """
    method = self.config.method

    if method == "exact":
        return self.evaluate_hessian_likelihood(**kwargs)

    if method == "finite_difference":
        # ``kwargs`` is always a dict and can never contain "h" (it is a
        # named parameter), so the step size is passed directly.
        return self.finite_difference_hessian_likelihood(h=h, **kwargs)

    if method == "jax_autodiff":
        # JAX_AVAILABLE is set at import time by the guarded jax import.
        if not JAX_AVAILABLE:
            raise RuntimeError("JAX is not available.")
        return self.evaluate_hessian_likelihood_jax(**kwargs)

    raise NotImplementedError(f"Method {method} not implemented.")
|
|
||||||
| @abstractmethod | ||||||
| def evaluate_likelihood( | ||||||
| self, | ||||||
| eta: NDArray, | ||||||
| y: NDArray, | ||||||
| **kwargs, | ||||||
| ) -> float: | ||||||
| ) -> NDArray: | ||||||
| """Evaluate the likelihood. | ||||||
|
|
||||||
| Parameters | ||||||
|
|
@@ -44,6 +92,33 @@ def evaluate_likelihood( | |||||
| Likelihood. | ||||||
| """ | ||||||
| pass | ||||||
|
|
||||||
def evaluate_sum_likelihood(
    self,
    eta: NDArray,
    y: NDArray,
    **kwargs,
) -> float:
    """Return the likelihood summed over all observations.

    Calls ``evaluate_likelihood`` and reduces its result with ``.sum()``.

    Parameters
    ----------
    eta : NDArray
        Vector of the linear predictor.
    y : NDArray
        Vector of the observations.
    **kwargs : optional
        Likelihood-specific parameters (e.g. ``theta``), forwarded to
        ``evaluate_likelihood``.

    Returns
    -------
    sum_likelihood : float
        Sum of the likelihood over all observations, as a Python float.
    """
    return float(self.evaluate_likelihood(eta, y, **kwargs).sum())
|
|
||||||
| @abstractmethod | ||||||
| def evaluate_gradient_likelihood( | ||||||
|
|
@@ -68,7 +143,49 @@ def evaluate_gradient_likelihood( | |||||
| gradient_likelihood : NDArray | ||||||
| Gradient of the likelihood. | ||||||
| """ | ||||||
| pass | ||||||
| self.finite_difference_gradient_likelihood(eta, y, **kwargs) | ||||||
|
|
||||||
def finite_difference_gradient_likelihood(
    self,
    eta: NDArray,
    y: NDArray,
    h: float = 1e-4,
    **kwargs,
) -> NDArray:
    """Evaluate the finite difference gradient of the likelihood wrt eta = Ax.

    Uses the central five-point stencil

        f'(x) ~ (-f(x+2h) + 8 f(x+h) - 8 f(x-h) + f(x-2h)) / (12 h)

    The same step ``h`` is added to ``eta`` as a whole in each
    evaluation, so the result is an elementwise derivative only if each
    entry of ``evaluate_likelihood`` depends on the matching entry of
    ``eta`` alone (NOTE(review): assumed, confirm against subclasses).

    Parameters
    ----------
    eta : NDArray
        Vector of the linear predictor.
    y : NDArray
        Vector of the observations.
    h : float, optional
        Finite difference step size.
    **kwargs : optional
        Hyperparameters for likelihood.

    Returns
    -------
    finite_difference_gradient : NDArray
        Finite difference gradient of the likelihood.
    """
    forward_1 = self.evaluate_likelihood(eta + h, y, **kwargs)
    forward_2 = self.evaluate_likelihood(eta + 2 * h, y, **kwargs)
    backward_1 = self.evaluate_likelihood(eta - h, y, **kwargs)
    backward_2 = self.evaluate_likelihood(eta - 2 * h, y, **kwargs)

    return (-forward_2 + 8 * forward_1 - 8 * backward_1 + backward_2) / (12 * h)
|
|
||||||
def evaluate_gradient_likelihood_jax(
    self,
    eta: NDArray,
    y: NDArray,
    **kwargs,
) -> NDArray:
    """JAX gradient hook; this base version always raises.

    Presumably meant to be overridden by likelihoods that provide a
    JAX-based gradient (NOTE(review): confirm against subclasses).

    Parameters
    ----------
    eta : NDArray
        Vector of the linear predictor.
    y : NDArray
        Vector of the observations.
    **kwargs : optional
        Unused here.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "JAX gradient not implemented for this likelihood."
    raise NotImplementedError(message)
|
|
||||||
| @abstractmethod | ||||||
| def evaluate_hessian_likelihood( | ||||||
|
|
@@ -92,4 +209,44 @@ def evaluate_hessian_likelihood( | |||||
| hessian_likelihood : ArrayLike | ||||||
| Hessian of the likelihood. | ||||||
| """ | ||||||
| pass | ||||||
| self.finite_difference_hessian_likelihood(eta, y, **kwargs) | ||||||
|
||||||
| self.finite_difference_hessian_likelihood(eta, y, **kwargs) | |
| return self.finite_difference_hessian_likelihood(eta, y, **kwargs) |
Copilot
AI
Sep 12, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This call passes `eta` and `y` explicitly, but the abstract method `evaluate_hessian_likelihood` does not declare `eta` and `y` as parameters of its own (they would have to arrive through `**kwargs`). The finite difference implementation expects these parameters to be passed explicitly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These abstract method implementations are missing return statements. They should return the result of the finite difference computation.