diff --git a/src/dalia/configs/likelihood_config.py b/src/dalia/configs/likelihood_config.py index 74f48ecd..72120c23 100644 --- a/src/dalia/configs/likelihood_config.py +++ b/src/dalia/configs/likelihood_config.py @@ -18,6 +18,7 @@ class LikelihoodConfig(BaseModel, ABC): model_config = ConfigDict(extra="forbid") type: Literal["gaussian", "poisson", "binomial"] = None + method: Literal["exact", "finite_difference", "jax_autodiff"] = "exact" # TODO: cleaner way to let user fix hyperparameters fix_hyperparameters: bool = False diff --git a/src/dalia/core/likelihood.py b/src/dalia/core/likelihood.py index 27e13adb..e1c8f20e 100644 --- a/src/dalia/core/likelihood.py +++ b/src/dalia/core/likelihood.py @@ -2,9 +2,16 @@ from abc import ABC, abstractmethod -from dalia import ArrayLike, NDArray +from dalia import ArrayLike, NDArray, xp,sp from dalia.configs.likelihood_config import LikelihoodConfig +try: + import jax.numpy as jnp + JAX_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + jnp = xp + JAX_AVAILABLE = False + class Likelihood(ABC): """Abstract core class for likelihood.""" @@ -18,6 +25,47 @@ def __init__( self.config = config self.n_observations = n_observations + + def gradient_likelihood(self, eta, y, h=1e-4, **kwargs): + if self.config.method == "exact": + return self.evaluate_gradient_likelihood(eta, y, **kwargs) + elif self.config.method == "finite_difference": + grad = self.finite_difference_gradient_likelihood(eta, y, h, **kwargs) + return grad + elif self.config.method == "jax_autodiff": + if not JAX_AVAILABLE: + raise RuntimeError("JAX is not available.") + return self.evaluate_gradient_likelihood_jax(eta, y, **kwargs) + else: + raise NotImplementedError(f"Method {self.config.method} not implemented.") + # ref = self.evaluate_gradient_likelihood(eta, y, **kwargs) + # grad = self.finite_difference_gradient_likelihood(eta, y, h, **kwargs) + # assert xp.allclose(ref, grad), f"Gradient mismatch: {ref} vs {grad}" + # return grad + + def 
hessian_likelihood(self, h: float = 1e-2, **kwargs): + if self.config.method == "exact": + return self.evaluate_hessian_likelihood(**kwargs) + elif self.config.method == "finite_difference": + kwargs = kwargs or {} + kwargs["h"] = h + hess = self.finite_difference_hessian_likelihood(**kwargs) + return hess + elif self.config.method == "jax_autodiff": + if not JAX_AVAILABLE: + raise RuntimeError("JAX is not available.") + return self.evaluate_hessian_likelihood_jax(**kwargs) + else: + raise NotImplementedError(f"Method {self.config.method} not implemented.") + # ref = self.evaluate_hessian_likelihood(**kwargs) + # ref_diag = ref.diagonal() + # kwargs = kwargs or {} + # kwargs["h"] = 1e-3 + # hess = self.finite_difference_hessian_likelihood(**kwargs) + # rel_error = xp.linalg.norm(ref_diag - hess) / xp.linalg.norm(ref_diag) + # if not xp.allclose(ref_diag, hess): + # print(f"Hessian mismatch: {rel_error}") + # return hess @abstractmethod def evaluate_likelihood( @@ -25,7 +73,7 @@ def evaluate_likelihood( eta: NDArray, y: NDArray, **kwargs, - ) -> float: + ) -> NDArray: """Evaluate the likelihood. Parameters @@ -44,6 +92,33 @@ def evaluate_likelihood( Likelihood. """ pass + + def evaluate_sum_likelihood( + self, + eta: NDArray, + y: NDArray, + **kwargs, + ) -> float: + """Evaluate the sum of the likelihood over all observations. + + Parameters + ---------- + eta : NDArray + Vector of the linear predictor. + y : NDArray + Vector of the observations. + kwargs : + theta : float + Specific parameter for the likelihood calculation. + + Returns + ------- + sum_likelihood : float + Sum of the likelihood over all observations. + """ + likelihood = self.evaluate_likelihood(eta, y, **kwargs) + sum_likelihood = float(likelihood.sum()) + return sum_likelihood @abstractmethod def evaluate_gradient_likelihood( @@ -68,7 +143,49 @@ def evaluate_gradient_likelihood( gradient_likelihood : NDArray Gradient of the likelihood. 
""" - pass + self.finite_difference_gradient_likelihood(eta, y, **kwargs) + + def finite_difference_gradient_likelihood( + self, + eta: NDArray, + y: NDArray, + h: float = 1e-4, + **kwargs, + ) -> NDArray: + """Evaluate the finite difference gradient of the likelihood wrt to eta = Ax. + + Parameters + ---------- + eta : NDArray + Vector of the linear predictor. + y : NDArray + Vector of the observations. + **kwargs : optional + Hyperparameters for likelihood. + + Returns + ------- + finite_difference_gradient : NDArray + Finite difference gradient of the likelihood. + """ + # grad = (-f(x+2h) + 8f(x+h) - 8f(x-h) + f(x-2h)) / 12h + # hessian = (-f(x+2h) + 16f(x+h) - 30f(x) + 16f(x-h) - f(x-2h)) / 12h^2 + f1 = self.evaluate_likelihood(eta + h, y, **kwargs) + f2 = self.evaluate_likelihood(eta + 2 * h, y, **kwargs) + b1 = self.evaluate_likelihood(eta - h, y, **kwargs) + b2 = self.evaluate_likelihood(eta - 2 * h, y, **kwargs) + # c = self.evaluate_likelihood(eta, y, **kwargs) + grad = (-f2 + 8 * f1 - 8 * b1 + b2) / (12 * h) + # hessian = (-f2 + 16 * f1 - 30 * c + 16 * b1 - b2) / (12 * h * h) + return grad + + def evaluate_gradient_likelihood_jax( + self, + eta: NDArray, + y: NDArray, + **kwargs, + ) -> NDArray: + raise NotImplementedError("JAX gradient not implemented for this likelihood.") @abstractmethod def evaluate_hessian_likelihood( @@ -92,4 +209,44 @@ def evaluate_hessian_likelihood( hessian_likelihood : ArrayLike Hessian of the likelihood. """ - pass + self.finite_difference_hessian_likelihood(eta, y, **kwargs) + + def finite_difference_hessian_likelihood( + self, + eta: NDArray, + y: NDArray, + h: float = 1e-2, + **kwargs, + ) -> NDArray: + """Evaluate the finite difference Hessian of the likelihood wrt to eta = Ax. + + Parameters + ---------- + eta : NDArray + Vector of the linear predictor. + y : NDArray + Vector of the observations. + **kwargs : optional + Hyperparameters for likelihood. 
+ + Returns + ------- + finite_difference_hessian : NDArray + Finite difference Hessian of the likelihood. + """ + # grad = (-f(x+2h) + 8f(x+h) - 8f(x-h) + f(x-2h)) / 12h + # hessian = (-f(x+2h) + 16f(x+h) - 30f(x) + 16f(x-h) - f(x-2h)) / 12h^2 + f1 = self.evaluate_likelihood(eta + h, y, **kwargs) + f2 = self.evaluate_likelihood(eta + 2 * h, y, **kwargs) + b1 = self.evaluate_likelihood(eta - h, y, **kwargs) + b2 = self.evaluate_likelihood(eta - 2 * h, y, **kwargs) + c = self.evaluate_likelihood(eta, y, **kwargs) + # grad = (-f2 + 8 * f1 - 8 * b1 + b2) / (12 * h) + hessian = (-f2 + 16 * f1 - 30 * c + 16 * b1 - b2) / (12 * h * h) + return hessian + + def evaluate_hessian_likelihood_jax( + self, + **kwargs, + ) -> ArrayLike: + raise NotImplementedError("JAX Hessian not implemented for this likelihood.") diff --git a/src/dalia/core/model.py b/src/dalia/core/model.py index 91608d62..ede0e035 100644 --- a/src/dalia/core/model.py +++ b/src/dalia/core/model.py @@ -431,19 +431,30 @@ def construct_Q_conditional( d_matrix = self.submodels[0].evaluate_d_matrix(**kwargs) else: # General rules - d_matrix = self.likelihood.evaluate_hessian_likelihood(**kwargs) + kwargs["y"] = self.y + d_matrix = self.likelihood.hessian_likelihood(**kwargs) + + if d_matrix.ndim == 1: + d_matrix_diagonal_0 = d_matrix[0] + d_matrix = sp.sparse.diags(d_matrix) + elif d_matrix.ndim == 2: + d_matrix_diagonal_0 = d_matrix.diagonal()[0] + else: + raise ValueError("d_matrix must be 1D or 2D array.") # if self.a is sparse -> Q_conditional should be sparse, else dense if sp.sparse.issparse(self.a): if self.aTa is not None: - self.Q_conditional = self.Q_prior - d_matrix.diagonal()[0] * self.aTa + # self.Q_conditional = self.Q_prior - d_matrix.diagonal()[0] * self.aTa + self.Q_conditional = self.Q_prior - d_matrix_diagonal_0 * self.aTa else: self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a # self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a else: if self.aTa is not None:
self.Q_conditional = ( - self.Q_prior.toarray() - d_matrix.diagonal()[0] * self.aTa + # self.Q_prior.toarray() - d_matrix.diagonal()[0] * self.aTa + self.Q_prior.toarray() - d_matrix_diagonal_0 * self.aTa ) else: self.Q_conditional = ( @@ -467,7 +478,7 @@ def construct_information_vector( ) else: - gradient_likelihood = self.likelihood.evaluate_gradient_likelihood( + gradient_likelihood = self.likelihood.gradient_likelihood( eta=eta, y=self.y, theta=self.theta[self.hyperparameters_idx[-1] :], @@ -540,7 +551,7 @@ def evaluate_likelihood(self, eta: NDArray, **kwargs) -> float: kwargs["h2"] = float(self.theta[0]) likelihood = self.submodels[0].evaluate_likelihood(eta, self.y, **kwargs) else: - likelihood = self.likelihood.evaluate_likelihood( + likelihood = self.likelihood.evaluate_sum_likelihood( eta, self.y, theta=self.theta[self.hyperparameters_idx[-1] :] ) diff --git a/src/dalia/likelihoods/binomial.py b/src/dalia/likelihoods/binomial.py index 53c2754d..9156352e 100644 --- a/src/dalia/likelihoods/binomial.py +++ b/src/dalia/likelihoods/binomial.py @@ -43,7 +43,7 @@ def evaluate_likelihood( eta: NDArray, y: NDArray, **kwargs, - ) -> float: + ) -> NDArray: """Evalutate the a binomial likelihood. 
Parameters @@ -64,9 +64,10 @@ def evaluate_likelihood( """ linkEta: NDArray = self.link_function(eta) - likelihood: float = xp.dot(y, xp.log(linkEta)) + xp.dot( - self.n_trials - y, xp.log(1 - linkEta) - ) + # likelihood: float = xp.dot(y, xp.log(linkEta)) + xp.dot( + # self.n_trials - y, xp.log(1 - linkEta) + # ) + likelihood = y * xp.log(linkEta) + (self.n_trials - y) * xp.log(1 - linkEta) return likelihood diff --git a/src/dalia/likelihoods/gaussian.py b/src/dalia/likelihoods/gaussian.py index ee69bd16..629b9173 100644 --- a/src/dalia/likelihoods/gaussian.py +++ b/src/dalia/likelihoods/gaussian.py @@ -4,6 +4,14 @@ from dalia.configs.likelihood_config import GaussianLikelihoodConfig from dalia.core.likelihood import Likelihood +try: + import jax.numpy as jnp + from jax import grad, jit, vmap + JAX_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + jnp = xp + JAX_AVAILABLE = False + class GaussianLikelihood(Likelihood): """Gaussian likelihood.""" @@ -16,12 +24,18 @@ def __init__( """Initializes the Gaussian likelihood.""" super().__init__(n_observations, config) + if JAX_AVAILABLE: + first_derivative = grad(self.evaluate_likelihood_jax, argnums=0) + second_derivative = grad(first_derivative, argnums=0) + self.gradient_jax = jit(vmap(first_derivative)) + self.hessian_jax = jit(vmap(second_derivative)) + def evaluate_likelihood( self, eta: NDArray, y: NDArray, **kwargs, - ) -> float: + ) -> NDArray: """Evaluate a Gaussian likelihood. 
Notes @@ -55,11 +69,14 @@ def evaluate_likelihood( yEta = eta - y # print("xp.exp(theta) in lh:", xp.exp(theta)) - likelihood: float = ( - 0.5 * theta * self.n_observations - 0.5 * xp.exp(theta) * yEta.T @ yEta - ) + likelihood = 0.5 * theta - 0.5 * xp.exp(theta) * yEta * yEta + return likelihood + + def evaluate_likelihood_jax(self, eta, y, theta): + yEta = eta - y + return 0.5 * theta - 0.5 * jnp.exp(theta) * yEta * yEta def evaluate_gradient_likelihood( self, @@ -94,6 +111,21 @@ def evaluate_gradient_likelihood( gradient_likelihood: NDArray = -xp.exp(theta) * (eta - y) return gradient_likelihood + + def evaluate_gradient_likelihood_jax( + self, + eta: NDArray, + y: NDArray, + **kwargs, + ) -> NDArray: + jax_eta = jnp.from_dlpack(eta) + jax_y = jnp.from_dlpack(y) + theta = kwargs.get("theta", None) + if not isinstance(theta, float): + theta = float(theta[0]) + jax_theta = jnp.full_like(jax_eta, theta) + grad = self.gradient_jax(jax_eta, jax_y, jax_theta) + return xp.from_dlpack(grad) def evaluate_hessian_likelihood( self, @@ -129,3 +161,16 @@ def evaluate_hessian_likelihood( ) return hessian_likelihood + + def evaluate_hessian_likelihood_jax( + self, + **kwargs, + ) -> ArrayLike: + jax_eta = jnp.from_dlpack(kwargs.get("eta")) + jax_y = jnp.from_dlpack(kwargs.get("y")) + theta = kwargs.get("theta", None) + if not isinstance(theta, float): + theta = float(theta[0]) + jax_theta = jnp.full_like(jax_eta, theta) + hessian = self.hessian_jax(jax_eta, jax_y, jax_theta) + return xp.from_dlpack(hessian) diff --git a/src/dalia/likelihoods/poisson.py b/src/dalia/likelihoods/poisson.py index 6fbea733..3244cea2 100644 --- a/src/dalia/likelihoods/poisson.py +++ b/src/dalia/likelihoods/poisson.py @@ -8,6 +8,14 @@ from dalia.configs.likelihood_config import PoissonLikelihoodConfig from dalia.core.likelihood import Likelihood +try: + import jax.numpy as jnp + from jax import grad, jit, vmap + JAX_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + jnp = xp + 
JAX_AVAILABLE = False + class PoissonLikelihood(Likelihood): """Poisson likelihood.""" @@ -18,7 +26,7 @@ def __init__( config: PoissonLikelihoodConfig, ) -> None: """Initializes the Poisson likelihood.""" - super().__init__(config, n_observations) + super().__init__(n_observations, config) # Load the extra coeficients for Poisson likelihood try: @@ -31,16 +39,27 @@ def __init__( self.e: NDArray = e else: self.e: NDArray = xp.asarray(e) + + if JAX_AVAILABLE: + first_derivative = grad(self.evaluate_likelihood_jax, argnums=0) + second_derivative = grad(first_derivative, argnums=0) + self.gradient_jax = jit(vmap(first_derivative)) + self.hessian_jax = jit(vmap(second_derivative)) + self.jax_e = jnp.from_dlpack(self.e) def evaluate_likelihood( self, eta: NDArray, y: NDArray, **kwargs, - ) -> float: - likelihood: float = xp.dot(eta, y) - xp.sum(self.e * xp.exp(eta)) + ) -> NDArray: + # likelihood: float = xp.dot(eta, y) - xp.sum(self.e * xp.exp(eta)) + likelihood = eta * y - self.e * xp.exp(eta) return likelihood + + def evaluate_likelihood_jax(self, eta, y, e): + return eta * y - e * jnp.exp(eta) def evaluate_gradient_likelihood( self, @@ -51,6 +70,17 @@ def evaluate_gradient_likelihood( gradient_likelihood: NDArray = y - self.e * xp.exp(eta) return gradient_likelihood + + def evaluate_gradient_likelihood_jax( + self, + eta: NDArray, + y: NDArray, + **kwargs, + ) -> NDArray: + jax_eta = jnp.from_dlpack(eta) + jax_y = jnp.from_dlpack(y) + grad = self.gradient_jax(jax_eta, jax_y, self.jax_e) + return xp.from_dlpack(grad) def evaluate_hessian_likelihood( self, @@ -61,3 +91,12 @@ def evaluate_hessian_likelihood( hessian_likelihood: ArrayLike = -1.0 * sp.sparse.diags(self.e * xp.exp(eta)) return hessian_likelihood + + def evaluate_hessian_likelihood_jax( + self, + **kwargs, + ) -> ArrayLike: + jax_eta = jnp.from_dlpack(kwargs.get("eta")) + jax_y = jnp.from_dlpack(kwargs.get("y")) + hessian = self.hessian_jax(jax_eta, jax_y, self.jax_e) + return xp.from_dlpack(hessian) diff 
--git a/src/dalia/models/coregional_model.py b/src/dalia/models/coregional_model.py index 3f5717f7..3d9c5d33 100644 --- a/src/dalia/models/coregional_model.py +++ b/src/dalia/models/coregional_model.py @@ -630,9 +630,19 @@ def construct_Q_conditional( } # d_list[i] = model.likelihood.evaluate_hessian_likelihood(**kwargs) + kwargs["y"] = self.y[self.n_observations_idx[i] : self.n_observations_idx[i + 1]] + hessian = model.likelihood.hessian_likelihood(**kwargs) + if hessian.ndim == 1: + diag = hessian + elif hessian.ndim == 2: + diag = hessian.diagonal() + else: + raise ValueError( + "Hessian of the likelihood must be either 1D or 2D array." + ) d_vec[ self.n_observations_idx[i] : self.n_observations_idx[i + 1] - ] = model.likelihood.evaluate_hessian_likelihood(**kwargs).diagonal() + ] = diag self.Qconditional = self.custom_Q_ATDA( Q=self.Q_prior, @@ -653,7 +663,7 @@ def construct_information_vector( gradient_vector_list = [] for i, model in enumerate(self.models): - gradient_likelihood = model.likelihood.evaluate_gradient_likelihood( + gradient_likelihood = model.likelihood.gradient_likelihood( eta=eta[self.n_observations_idx[i] : self.n_observations_idx[i + 1]], y=self.y[self.n_observations_idx[i] : self.n_observations_idx[i + 1]], theta=float(self.theta[self.hyperparameters_idx[i + 1] - 1]), @@ -682,7 +692,7 @@ def evaluate_likelihood( ) -> float: likelihood: float = 0.0 for i, model in enumerate(self.models): - likelihood += model.likelihood.evaluate_likelihood( + likelihood += model.likelihood.evaluate_sum_likelihood( eta=eta[self.n_observations_idx[i] : self.n_observations_idx[i + 1]], y=self.y[self.n_observations_idx[i] : self.n_observations_idx[i + 1]], theta=float(self.theta[self.hyperparameters_idx[i + 1] - 1]), diff --git a/src/dalia/submodels/brainiac.py b/src/dalia/submodels/brainiac.py index 61e01bdf..b3154a61 100644 --- a/src/dalia/submodels/brainiac.py +++ b/src/dalia/submodels/brainiac.py @@ -98,6 +98,8 @@ def evaluate_likelihood(self, eta: NDArray, 
y: NDArray, **kwargs) -> float: return likelihood + # TODO/NOTE: Maybe specialize the Gaussian likelihood in its own class + # and have BrainiacSubModel use it as its likelihood? def evaluate_gradient_likelihood( self, eta: NDArray, y: NDArray, **kwargs ) -> NDArray: