Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dalia/configs/likelihood_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class LikelihoodConfig(BaseModel, ABC):
model_config = ConfigDict(extra="forbid")

type: Literal["gaussian", "poisson", "binomial"] = None
method: Literal["exact", "finite_difference", "jax_autodiff"] = "exact"

# TODO: cleaner way to let user fix hyperparameters
fix_hyperparameters: bool = False
Expand Down
165 changes: 161 additions & 4 deletions src/dalia/core/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@

from abc import ABC, abstractmethod

from dalia import ArrayLike, NDArray
from dalia import ArrayLike, NDArray, xp,sp
from dalia.configs.likelihood_config import LikelihoodConfig

try:
import jax.numpy as jnp
JAX_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
jnp = xp
JAX_AVAILABLE = False


class Likelihood(ABC):
"""Abstract core class for likelihood."""
Expand All @@ -18,14 +25,55 @@ def __init__(

self.config = config
self.n_observations = n_observations

def gradient_likelihood(self, eta, y, h=1e-4, **kwargs):
if self.config.method == "exact":
return self.evaluate_gradient_likelihood(eta, y, **kwargs)
elif self.config.method == "finite_difference":
grad = self.finite_difference_gradient_likelihood(eta, y, h, **kwargs)
return grad
elif self.config.method == "jax_autodiff":
if not JAX_AVAILABLE:
raise RuntimeError("JAX is not available.")
return self.evaluate_gradient_likelihood_jax(eta, y, **kwargs)
else:
raise NotImplementedError(f"Method {self.config.method} not implemented.")
# ref = self.evaluate_gradient_likelihood(eta, y, **kwargs)
# grad = self.finite_difference_gradient_likelihood(eta, y, h, **kwargs)
# assert xp.allclose(ref, grad), f"Gradient mismatch: {ref} vs {grad}"
# return grad

def hessian_likelihood(self, h: float = 1e-2, **kwargs):
if self.config.method == "exact":
return self.evaluate_hessian_likelihood(**kwargs)
elif self.config.method == "finite_difference":
kwargs = kwargs or {}
kwargs["h"] = h
hess = self.finite_difference_hessian_likelihood(**kwargs)
return hess
elif self.config.method == "jax_autodiff":
if not JAX_AVAILABLE:
raise RuntimeError("JAX is not available.")
return self.evaluate_hessian_likelihood_jax(**kwargs)
else:
raise NotImplementedError(f"Method {self.config.method} not implemented.")
# ref = self.evaluate_hessian_likelihood(**kwargs)
# ref_diag = ref.diagonal()
# kwargs = kwargs or {}
# kwargs["h"] = 1e-3
# hess = self.finite_difference_hessian_likelihood(**kwargs)
# rel_error = xp.linalg.norm(ref_diag - hess) / xp.linalg.norm(ref_diag)
# if not xp.allclose(ref_diag, hess):
# print(f"Hessian mismatch: {rel_error}")
# return hess

@abstractmethod
def evaluate_likelihood(
self,
eta: NDArray,
y: NDArray,
**kwargs,
) -> float:
) -> NDArray:
"""Evaluate the likelihood.

Parameters
Expand All @@ -44,6 +92,33 @@ def evaluate_likelihood(
Likelihood.
"""
pass

def evaluate_sum_likelihood(
self,
eta: NDArray,
y: NDArray,
**kwargs,
) -> float:
"""Evaluate the sum of the likelihood over all observations.

Parameters
----------
eta : NDArray
Vector of the linear predictor.
y : NDArray
Vector of the observations.
kwargs :
theta : float
Specific parameter for the likelihood calculation.

Returns
-------
sum_likelihood : float
Sum of the likelihood over all observations.
"""
likelihood = self.evaluate_likelihood(eta, y, **kwargs)
sum_likelihood = float(likelihood.sum())
return sum_likelihood

@abstractmethod
def evaluate_gradient_likelihood(
Expand All @@ -68,7 +143,49 @@ def evaluate_gradient_likelihood(
gradient_likelihood : NDArray
Gradient of the likelihood.
"""
pass
self.finite_difference_gradient_likelihood(eta, y, **kwargs)

Copilot AI Sep 12, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These abstract method implementations are missing return statements. They should return the result of the finite difference computation.

Suggested change
self.finite_difference_gradient_likelihood(eta, y, **kwargs)
return self.finite_difference_gradient_likelihood(eta, y, **kwargs)

Copilot uses AI. Check for mistakes.

def finite_difference_gradient_likelihood(
self,
eta: NDArray,
y: NDArray,
h: float = 1e-4,
**kwargs,
) -> NDArray:
"""Evaluate the finite difference gradient of the likelihood wrt to eta = Ax.

Parameters
----------
eta : NDArray
Vector of the linear predictor.
y : NDArray
Vector of the observations.
**kwargs : optional
Hyperparameters for likelihood.

Returns
-------
finite_difference_gradient : NDArray
Finite difference gradient of the likelihood.
"""
# grad = (-f(x+2h) + 8f(x+h) - 8f(x-h) + f(x-2h)) / 12h
# hessian = (-f(x+2h) + 16f(x+h) - 30f(x) + 16f(x-h) - f(x-2h)) / 12h^2
f1 = self.evaluate_likelihood(eta + h, y, **kwargs)
f2 = self.evaluate_likelihood(eta + 2 * h, y, **kwargs)
b1 = self.evaluate_likelihood(eta - h, y, **kwargs)
b2 = self.evaluate_likelihood(eta - 2 * h, y, **kwargs)
# c = self.evaluate_likelihood(eta, y, **kwargs)
grad = (-f2 + 8 * f1 - 8 * b1 + b2) / (12 * h)
# hessian = (-f2 + 16 * f1 - 30 * c + 16 * b1 - b2) / (12 * h * h)
return grad

def evaluate_gradient_likelihood_jax(
self,
eta: NDArray,
y: NDArray,
**kwargs,
) -> NDArray:
raise NotImplementedError("JAX gradient not implemented for this likelihood.")

@abstractmethod
def evaluate_hessian_likelihood(
Expand All @@ -92,4 +209,44 @@ def evaluate_hessian_likelihood(
hessian_likelihood : ArrayLike
Hessian of the likelihood.
"""
pass
self.finite_difference_hessian_likelihood(eta, y, **kwargs)

Copilot AI Sep 12, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These abstract method implementations are missing return statements. They should return the result of the finite difference computation.

Suggested change
self.finite_difference_hessian_likelihood(eta, y, **kwargs)
return self.finite_difference_hessian_likelihood(eta, y, **kwargs)

Copilot uses AI. Check for mistakes.

def finite_difference_hessian_likelihood(
self,
eta: NDArray,
y: NDArray,
h: float = 1e-2,
**kwargs,
) -> NDArray:
Comment on lines +214 to +220

Copilot AI Sep 12, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method signature includes eta and y parameters, but the abstract method evaluate_hessian_likelihood doesn't include these parameters in its kwargs. The finite difference implementation expects these parameters to be passed explicitly.

Copilot uses AI. Check for mistakes.
"""Evaluate the finite difference Hessian of the likelihood wrt to eta = Ax.

Parameters
----------
eta : NDArray
Vector of the linear predictor.
y : NDArray
Vector of the observations.
**kwargs : optional
Hyperparameters for likelihood.

Returns
-------
finite_difference_gradient : NDArray
Finite difference gradient of the likelihood.
"""
# grad = (-f(x+2h) + 8f(x+h) - 8f(x-h) + f(x-2h)) / 12h
# hessian = (-f(x+2h) + 16f(x+h) - 30f(x) + 16f(x-h) - f(x-2h)) / 12h^2
f1 = self.evaluate_likelihood(eta + h, y, **kwargs)
f2 = self.evaluate_likelihood(eta + 2 * h, y, **kwargs)
b1 = self.evaluate_likelihood(eta - h, y, **kwargs)
b2 = self.evaluate_likelihood(eta - 2 * h, y, **kwargs)
c = self.evaluate_likelihood(eta, y, **kwargs)
# grad = (-f2 + 8 * f1 - 8 * b1 + b2) / (12 * h)
hessian = (-f2 + 16 * f1 - 30 * c + 16 * b1 - b2) / (12 * h * h)
return hessian

def evaluate_hessian_likelihood_jax(
self,
**kwargs,
) -> ArrayLike:
raise NotImplementedError("JAX Hessian not implemented for this likelihood.")
21 changes: 16 additions & 5 deletions src/dalia/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,19 +431,30 @@ def construct_Q_conditional(
d_matrix = self.submodels[0].evaluate_d_matrix(**kwargs)
else:
# General rules
d_matrix = self.likelihood.evaluate_hessian_likelihood(**kwargs)
kwargs["y"] = self.y
d_matrix = self.likelihood.hessian_likelihood(**kwargs)

if d_matrix.ndim == 1:
d_matrix_diagonal_0 = d_matrix[0]
d_matrix = sp.sparse.diags(d_matrix)
elif d_matrix.ndim == 2:
d_matrix_diagonal_0 = d_matrix.diagonal()[0]
else:
raise ValueError("d_matrix must be 1D or 2D array.")

# if self.a is sparse -> Q_conditional should be sparse, else dense
if sp.sparse.issparse(self.a):
if self.aTa is not None:
self.Q_conditional = self.Q_prior - d_matrix.diagonal()[0] * self.aTa
# self.Q_conditional = self.Q_prior - d_matrix.diagonal()[0] * self.aTa
self.Q_conditional = self.Q_prior - d_matrix_diagonal_0 * self.aTa
else:
self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a
# self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a
else:
if self.aTa is not None:
self.Q_conditional = (
self.Q_prior.toarray() - d_matrix.diagonal()[0] * self.aTa
# self.Q_prior.toarray() - d_matrix.diagonal()[0] * self.aTa
self.Q_prior.toarray() - d_matrix_diagonal_0 * self.aTa
)
else:
self.Q_conditional = (
Expand All @@ -467,7 +478,7 @@ def construct_information_vector(
)

else:
gradient_likelihood = self.likelihood.evaluate_gradient_likelihood(
gradient_likelihood = self.likelihood.gradient_likelihood(
eta=eta,
y=self.y,
theta=self.theta[self.hyperparameters_idx[-1] :],
Expand Down Expand Up @@ -540,7 +551,7 @@ def evaluate_likelihood(self, eta: NDArray, **kwargs) -> float:
kwargs["h2"] = float(self.theta[0])
likelihood = self.submodels[0].evaluate_likelihood(eta, self.y, **kwargs)
else:
likelihood = self.likelihood.evaluate_likelihood(
likelihood = self.likelihood.evaluate_sum_likelihood(
eta, self.y, theta=self.theta[self.hyperparameters_idx[-1] :]
)

Expand Down
9 changes: 5 additions & 4 deletions src/dalia/likelihoods/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def evaluate_likelihood(
eta: NDArray,
y: NDArray,
**kwargs,
) -> float:
) -> NDArray:
"""Evalutate the a binomial likelihood.

Parameters
Expand All @@ -64,9 +64,10 @@ def evaluate_likelihood(
"""
linkEta: NDArray = self.link_function(eta)

likelihood: float = xp.dot(y, xp.log(linkEta)) + xp.dot(
self.n_trials - y, xp.log(1 - linkEta)
)
# likelihood: float = xp.dot(y, xp.log(linkEta)) + xp.dot(
# self.n_trials - y, xp.log(1 - linkEta)
# )
likelihood = y * xp.log(linkEta) + (self.n_trials - y) * xp.log(1 - linkEta)

return likelihood

Expand Down
53 changes: 49 additions & 4 deletions src/dalia/likelihoods/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
from dalia.configs.likelihood_config import GaussianLikelihoodConfig
from dalia.core.likelihood import Likelihood

try:
import jax.numpy as jnp
from jax import grad, jit, vmap
JAX_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
jnp = xp
JAX_AVAILABLE = False


class GaussianLikelihood(Likelihood):
"""Gaussian likelihood."""
Expand All @@ -16,12 +24,18 @@ def __init__(
"""Initializes the Gaussian likelihood."""
super().__init__(n_observations, config)

if JAX_AVAILABLE:
first_derivative = grad(self.evaluate_likelihood_jax, argnums=0)
second_derivative = grad(first_derivative, argnums=0)
self.gradient_jax = jit(vmap(first_derivative))
self.hessian_jax = jit(vmap(second_derivative))

def evaluate_likelihood(
self,
eta: NDArray,
y: NDArray,
**kwargs,
) -> float:
) -> NDArray:
"""Evaluate a Gaussian likelihood.

Notes
Expand Down Expand Up @@ -55,11 +69,14 @@ def evaluate_likelihood(
yEta = eta - y
# print("xp.exp(theta) in lh:", xp.exp(theta))

likelihood: float = (
0.5 * theta * self.n_observations - 0.5 * xp.exp(theta) * yEta.T @ yEta
)
likelihood = 0.5 * theta - 0.5 * xp.exp(theta) * yEta * yEta


return likelihood

def evaluate_likelihood_jax(self, eta, y, theta):
yEta = eta - y
return 0.5 * theta - 0.5 * jnp.exp(theta) * yEta * yEta

def evaluate_gradient_likelihood(
self,
Expand Down Expand Up @@ -94,6 +111,21 @@ def evaluate_gradient_likelihood(
gradient_likelihood: NDArray = -xp.exp(theta) * (eta - y)

return gradient_likelihood

def evaluate_gradient_likelihood_jax(
self,
eta: NDArray,
y: NDArray,
**kwargs,
) -> NDArray:
jax_eta = jnp.from_dlpack(eta)
jax_y = jnp.from_dlpack(y)
theta = kwargs.get("theta", None)
if not isinstance(theta, float):
theta = float(theta[0])
jax_theta = jnp.full_like(jax_eta, theta)
grad = self.gradient_jax(jax_eta, jax_y, jax_theta)
return xp.from_dlpack(grad)

def evaluate_hessian_likelihood(
self,
Expand Down Expand Up @@ -129,3 +161,16 @@ def evaluate_hessian_likelihood(
)

return hessian_likelihood

def evaluate_hessian_likelihood_jax(
self,
**kwargs,
) -> ArrayLike:
jax_eta = jnp.from_dlpack(kwargs.get("eta"))
jax_y = jnp.from_dlpack(kwargs.get("y"))
theta = kwargs.get("theta", None)
if not isinstance(theta, float):
theta = float(theta[0])
jax_theta = jnp.full_like(jax_eta, theta)
hessian = self.hessian_jax(jax_eta, jax_y, jax_theta)
return xp.from_dlpack(hessian)
Loading