From 57e4188c42a24f77d4b7f2bac93d22a5b1da653f Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 2 Apr 2026 02:03:36 +0000 Subject: [PATCH 1/4] Fix Block transition matrices breaking Sum kernel operations (#265) Block diagonal matrices used in Sum kernel transition matrices, stationary covariances, and design matrices were incompatible with several operations: adding QSMs (banded noise), product kernels (_prod_helper indexing), and elementwise multiplication (self_mul fancy indexing). Convert Block to dense in these contexts since the state-space matrices are small. Also add a use_block=False option to Sum for users who want to bypass Block entirely. https://claude.ai/code/session_01Y2ACGEqvh9fTrCzR5WEPuJ --- src/tinygp/kernels/quasisep.py | 33 +++++++- src/tinygp/solvers/quasisep/core.py | 15 +++- src/tinygp/solvers/quasisep/ops.py | 34 +++++--- tests/test_kernels/test_quasisep.py | 127 ++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 15 deletions(-) diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py index af232811..ef414387 100644 --- a/src/tinygp/kernels/quasisep.py +++ b/src/tinygp/kernels/quasisep.py @@ -220,20 +220,41 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: class Sum(Quasisep): - """A helper to represent the sum of two quasiseparable kernels""" + """A helper to represent the sum of two quasiseparable kernels + + Args: + kernel1: The first kernel. + kernel2: The second kernel. + use_block: If ``True`` (default), use :class:`Block` diagonal matrices + for the transition matrices, design matrices, and stationary + covariance. If ``False``, use dense ``block_diag`` representations + instead, which avoids compatibility issues with some operations + (e.g. banded noise, product kernels) at a small performance cost + for the state-space matrices. + """ kernel1: Quasisep kernel2: Quasisep + use_block: bool = eqx.field(static=True, default=True) def coord_to_sortable(self, X: JAXArray) -> JAXArray: """We assume that both kernels use the same coordinates""" return self.kernel1.coord_to_sortable(X) + def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray: + if self.use_block: + return Block(m1, m2) + from jax.scipy.linalg import block_diag as jsp_block_diag + + return jsp_block_diag(m1, m2) + def design_matrix(self) -> JAXArray: - return Block(self.kernel1.design_matrix(), self.kernel2.design_matrix()) + return self._block_or_dense( + self.kernel1.design_matrix(), self.kernel2.design_matrix() + ) def stationary_covariance(self) -> JAXArray: - return Block( + return self._block_or_dense( self.kernel1.stationary_covariance(), self.kernel2.stationary_covariance(), ) @@ -247,7 +268,7 @@ def observation_model(self, X: JAXArray) -> JAXArray: ) def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - return Block( + return self._block_or_dense( self.kernel1.transition_matrix(X1, X2), self.kernel2.transition_matrix(X1, X2), ) @@ -632,6 +653,10 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray: + if isinstance(a1, Block): + a1 = a1.to_dense() + if isinstance(a2, Block): + a2 = a2.to_dense() i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0])) i = i.flatten() j = j.flatten() diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index ac1fdf18..505f6bf8 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -203,6 +203,7 @@ def scale(self, other: JAXArray) -> StrictLowerTriQSM: def self_add(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM: """The sum of two :class:`StrictLowerTriQSM` matrices""" + from tinygp.solvers.quasisep.block import Block @jax.vmap def impl( @@ -210,6 +211,10 @@ def impl( ) -> StrictLowerTriQSM: p1, q1, a1 = self p2, q2, a2 = other + if isinstance(a1, Block): + a1 = a1.to_dense() + if isinstance(a2, Block): + a2 = a2.to_dense() return StrictLowerTriQSM( p=jnp.concatenate((p1, p2)), q=jnp.concatenate((q1, q2)), @@ -220,13 +225,21 @@ def impl( def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM: """The elementwise product of two :class:`StrictLowerTriQSM` matrices""" + from tinygp.solvers.quasisep.block import Block + + self_a = self.a + other_a = other.a + if isinstance(self_a, Block): + self_a = jax.vmap(lambda b: b.to_dense())(self_a) + if isinstance(other_a, Block): + other_a = jax.vmap(lambda b: b.to_dense())(other_a) i, j = np.meshgrid(np.arange(self.p.shape[1]), np.arange(other.p.shape[1])) i = i.flatten() j = j.flatten() return StrictLowerTriQSM( p=self.p[:, i] * other.p[:, j], q=self.q[:, i] * other.q[:, j], - a=self.a[:, i[:, None], i[None, :]] * other.a[:, j[:, None], j[None, :]], + a=self_a[:, i[:, None], i[None, :]] * other_a[:, j[:, None], j[None, :]], ) def __neg__(self) -> StrictLowerTriQSM: diff --git a/src/tinygp/solvers/quasisep/ops.py b/src/tinygp/solvers/quasisep/ops.py index ff48fb18..0ee70c3b 100644 --- a/src/tinygp/solvers/quasisep/ops.py +++ b/src/tinygp/solvers/quasisep/ops.py @@ -8,6 +8,16 @@ import jax.numpy as jnp from tinygp.helpers import JAXArray +from tinygp.solvers.quasisep.block import Block + + +def _ensure_dense(x: JAXArray) -> JAXArray: + """Convert Block to dense array if needed.""" + if isinstance(x, Block): + return x.to_dense() + return x + + from tinygp.solvers.quasisep.core import ( QSM, DiagQSM, @@ -145,15 +155,17 @@ def impl( u += [upper_b.p] if upper_b is not None else [] if lower_a is not None and lower_b is not None: + la_a = _ensure_dense(lower_a.a) + lb_a = _ensure_dense(lower_b.a) ell = jnp.concatenate( ( jnp.concatenate( - (lower_a.a, jnp.outer(lower_a.q, lower_b.p)), axis=-1 + (la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1 ), jnp.concatenate( ( - jnp.zeros((lower_b.a.shape[0], lower_a.a.shape[0])), - lower_b.a, + jnp.zeros((lb_a.shape[0], la_a.shape[0])), + lb_a, ), axis=-1, ), @@ -162,23 +174,25 @@ def impl( ) else: ell = ( - lower_a.a + _ensure_dense(lower_a.a) if lower_a is not None - else lower_b.a if lower_b is not None else None + else _ensure_dense(lower_b.a) if lower_b is not None else None ) if upper_a is not None and upper_b is not None: + ua_a = _ensure_dense(upper_a.a) + ub_a = _ensure_dense(upper_b.a) delta = jnp.concatenate( ( jnp.concatenate( ( - upper_a.a, - jnp.zeros((upper_a.a.shape[0], upper_b.a.shape[0])), + ua_a, + jnp.zeros((ua_a.shape[0], ub_a.shape[0])), ), axis=-1, ), jnp.concatenate( - (jnp.outer(upper_b.q, upper_a.p), upper_b.a), axis=-1 + (jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1 ), ), axis=0, @@ -186,9 +200,9 @@ def impl( else: delta = ( - upper_a.a + _ensure_dense(upper_a.a) if upper_a is not None - else upper_b.a if upper_b is not None else None + else _ensure_dense(upper_b.a) if upper_b is not None else None ) return ( diff --git a/tests/test_kernels/test_quasisep.py b/tests/test_kernels/test_quasisep.py index c426224d..99efd6ee 100644 --- a/tests/test_kernels/test_quasisep.py +++ b/tests/test_kernels/test_quasisep.py @@ -157,3 +157,130 @@ def test_carma_quads(): assert_allclose(carma31.arroots, carma31_quads.arroots) assert_allclose(carma31.acf, carma31_quads.acf) assert_allclose(carma31.obsmodel, carma31_quads.obsmodel) + + +# Regression tests for https://github.com/dfm/tinygp/issues/265 +# Block transition matrices in Sum kernels broke several operations. + + +def test_sum_kernel_with_banded_noise(data): + """Sum kernel + banded noise: self_add must handle Block transition matrices""" + from tinygp.noise import Banded + + x, y, _ = data + N = len(x) + k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) + banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1))) + gp = GaussianProcess(k, x, noise=banded) + assert jnp.isfinite(gp.log_probability(y)) + + +def test_sum_kernel_with_banded_noise_condition(data): + """Sum kernel + banded noise: conditioning must handle Block in qsm_mul""" + from tinygp.noise import Banded + + x, y, _ = data + N = len(x) + k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) + banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1))) + gp = GaussianProcess(k, x, noise=banded) + lp, cond_gp = gp.condition(y) + assert jnp.isfinite(lp) + + +def test_product_of_sum_kernel(data): + """Product kernel with Sum factor: _prod_helper must handle Block inputs""" + x, y, _ = data + N = len(x) + k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0) + gp = GaussianProcess(k, x, diag=jnp.ones(N)) + assert jnp.isfinite(gp.log_probability(y)) + + +def test_product_of_sum_kernel_consistency(data): + """Product of sum kernel QSM must match direct kernel evaluation""" + x, _, _ = data + k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0) + assert_allclose(k.to_symm_qsm(x).to_dense(), k(x, x)) + + +def test_sum_times_sum_kernel(data): + """Product of two Sum kernels: self_mul must handle Block transition matrices""" + x, y, _ = data + N = len(x) + k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * ( + quasisep.Exp(0.5) + quasisep.Matern32(1.0) + ) + gp = GaussianProcess(k, x, diag=jnp.ones(N)) + assert jnp.isfinite(gp.log_probability(y)) + + +def test_sum_kernel_use_block_false(data): + """Sum kernel with use_block=False bypasses Block entirely""" + x, y, _ = data + N = len(x) + k = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False) + gp = GaussianProcess(k, x, diag=0.1 * jnp.ones(N)) + assert jnp.isfinite(gp.log_probability(y)) + + +def test_sum_kernel_use_block_consistency(data): + """Block and non-block Sum kernels must produce the same results""" + x, y, _ = data + N = len(x) + k_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) + k_dense = quasisep.Sum( + quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False + ) + gp_block = GaussianProcess(k_block, x, diag=0.1 * jnp.ones(N)) + gp_dense = GaussianProcess(k_dense, x, diag=0.1 * jnp.ones(N)) + assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y)) + + +def test_sum_kernel_use_block_false_with_banded_noise(data): + """Sum kernel use_block=False with banded noise""" + from tinygp.noise import Banded + + x, y, _ = data + N = len(x) + k = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False) + banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1))) + gp = GaussianProcess(k, x, noise=banded) + assert jnp.isfinite(gp.log_probability(y)) + + +def test_sum_kernel_use_block_false_product(data): + """Sum kernel use_block=False in a product""" + x, y, _ = data + N = len(x) + k = quasisep.Sum( + quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False + ) * quasisep.Exp(1.0) + gp = GaussianProcess(k, x, diag=jnp.ones(N)) + assert jnp.isfinite(gp.log_probability(y)) + + +def test_jit_sum_kernel_block(data): + """JIT must work with Sum kernel block computations""" + x, y, _ = data + N = len(x) + k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) + + @jax.jit + def compute(x, y): + gp = GaussianProcess(k, x, diag=0.1 * jnp.ones(N)) + return gp.log_probability(y) + + assert jnp.isfinite(compute(x, y)) + + +def test_grad_product_of_sum_kernel(data): + """Gradients must work through product of sum kernel""" + x, y, _ = data + N = len(x) + + def loss(sigma): + k = (quasisep.Cosine(sigma) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0) + return GaussianProcess(k, x, diag=jnp.ones(N)).log_probability(y) + + assert jnp.isfinite(jax.grad(loss)(1.0)) From d5910cefca73ac5f534b33afd3a378e41006421b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 02:08:37 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/tinygp/solvers/quasisep/ops.py | 8 ++------ tests/test_kernels/test_quasisep.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/tinygp/solvers/quasisep/ops.py b/src/tinygp/solvers/quasisep/ops.py index 0ee70c3b..5d0bca19 100644 --- a/src/tinygp/solvers/quasisep/ops.py +++ b/src/tinygp/solvers/quasisep/ops.py @@ -159,9 +159,7 @@ def impl( lb_a = _ensure_dense(lower_b.a) ell = jnp.concatenate( ( - jnp.concatenate( - (la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1 - ), + jnp.concatenate((la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1), jnp.concatenate( ( jnp.zeros((lb_a.shape[0], la_a.shape[0])), @@ -191,9 +189,7 @@ def impl( ), axis=-1, ), - jnp.concatenate( - (jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1 - ), + jnp.concatenate((jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1), ), axis=0, ) diff --git a/tests/test_kernels/test_quasisep.py b/tests/test_kernels/test_quasisep.py index 99efd6ee..1e4903e5 100644 --- a/tests/test_kernels/test_quasisep.py +++ b/tests/test_kernels/test_quasisep.py @@ -229,9 +229,7 @@ def test_sum_kernel_use_block_consistency(data): x, y, _ = data N = len(x) k_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) - k_dense = quasisep.Sum( - quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False - ) + k_dense = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False) gp_block = GaussianProcess(k_block, x, diag=0.1 * jnp.ones(N)) gp_dense = GaussianProcess(k_dense, x, diag=0.1 * jnp.ones(N)) assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y)) From 0415ec06391436b8df1102a80fa40a67e7e0bded Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 2 Apr 2026 02:21:41 +0000 Subject: [PATCH 3/4] Address PR review: extract ensure_dense helper, move imports to top, trim tests - Extract ensure_dense() into block.py as a shared helper - Move all imports to module top level (no lazy imports) - Remove test docstrings and comments - Consolidate tests to minimal set covering the three failure modes + use_block https://claude.ai/code/session_01Y2ACGEqvh9fTrCzR5WEPuJ --- src/tinygp/kernels/quasisep.py | 12 ++-- src/tinygp/solvers/quasisep/block.py | 7 +++ src/tinygp/solvers/quasisep/core.py | 20 ++---- src/tinygp/solvers/quasisep/ops.py | 27 +++----- tests/test_kernels/test_quasisep.py | 92 +--------------------------- 5 files changed, 30 insertions(+), 128 deletions(-) diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py index ef414387..2e071ecf 100644 --- a/src/tinygp/kernels/quasisep.py +++ b/src/tinygp/kernels/quasisep.py @@ -33,9 +33,11 @@ import jax.numpy as jnp import numpy as np +from jax.scipy.linalg import block_diag as jsp_block_diag + from tinygp.helpers import JAXArray from tinygp.kernels.base import Kernel -from tinygp.solvers.quasisep.block import Block +from tinygp.solvers.quasisep.block import Block, ensure_dense from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM from tinygp.solvers.quasisep.general import GeneralQSM @@ -244,8 +246,6 @@ def coord_to_sortable(self, X: JAXArray) -> JAXArray: def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray: if self.use_block: return Block(m1, m2) - from jax.scipy.linalg import block_diag as jsp_block_diag - return jsp_block_diag(m1, m2) def design_matrix(self) -> JAXArray: @@ -653,10 +653,8 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray: - if isinstance(a1, Block): - a1 = a1.to_dense() - if isinstance(a2, Block): - a2 = a2.to_dense() + a1 = ensure_dense(a1) + a2 = ensure_dense(a2) i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0])) i = i.flatten() j = j.flatten() diff --git a/src/tinygp/solvers/quasisep/block.py b/src/tinygp/solvers/quasisep/block.py index f5736064..6e102cd8 100644 --- a/src/tinygp/solvers/quasisep/block.py +++ b/src/tinygp/solvers/quasisep/block.py @@ -9,6 +9,13 @@ from tinygp.helpers import JAXArray +def ensure_dense(x: Any) -> Any: + """Convert a Block to a dense array, passing through non-Block inputs.""" + if isinstance(x, Block): + return x.to_dense() + return x + + class Block(eqx.Module): blocks: tuple[Any, ...] __array_priority__ = 1999 diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 505f6bf8..412e0e35 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -29,6 +29,7 @@ from jax.scipy.linalg import block_diag from tinygp.helpers import JAXArray +from tinygp.solvers.quasisep.block import ensure_dense def handle_matvec_shapes( @@ -203,7 +204,6 @@ def scale(self, other: JAXArray) -> StrictLowerTriQSM: def self_add(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM: """The sum of two :class:`StrictLowerTriQSM` matrices""" - from tinygp.solvers.quasisep.block import Block @jax.vmap def impl( @@ -211,28 +211,20 @@ def impl( ) -> StrictLowerTriQSM: p1, q1, a1 = self p2, q2, a2 = other - if isinstance(a1, Block): - a1 = a1.to_dense() - if isinstance(a2, Block): - a2 = a2.to_dense() return StrictLowerTriQSM( p=jnp.concatenate((p1, p2)), q=jnp.concatenate((q1, q2)), - a=block_diag(a1, a2), + a=block_diag(ensure_dense(a1), ensure_dense(a2)), ) return impl(self, other) def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM: """The elementwise product of two :class:`StrictLowerTriQSM` matrices""" - from tinygp.solvers.quasisep.block import Block - - self_a = self.a - other_a = other.a - if isinstance(self_a, Block): - self_a = jax.vmap(lambda b: b.to_dense())(self_a) - if isinstance(other_a, Block): - other_a = jax.vmap(lambda b: b.to_dense())(other_a) + # vmap is needed because a batched Block has 3D block arrays that + # block_diag (used by to_dense) cannot handle without unbatching. + self_a = jax.vmap(ensure_dense)(self.a) + other_a = jax.vmap(ensure_dense)(other.a) i, j = np.meshgrid(np.arange(self.p.shape[1]), np.arange(other.p.shape[1])) i = i.flatten() j = j.flatten() diff --git a/src/tinygp/solvers/quasisep/ops.py b/src/tinygp/solvers/quasisep/ops.py index 5d0bca19..4d88d6dd 100644 --- a/src/tinygp/solvers/quasisep/ops.py +++ b/src/tinygp/solvers/quasisep/ops.py @@ -8,16 +8,7 @@ import jax.numpy as jnp from tinygp.helpers import JAXArray -from tinygp.solvers.quasisep.block import Block - - -def _ensure_dense(x: JAXArray) -> JAXArray: - """Convert Block to dense array if needed.""" - if isinstance(x, Block): - return x.to_dense() - return x - - +from tinygp.solvers.quasisep.block import ensure_dense from tinygp.solvers.quasisep.core import ( QSM, DiagQSM, @@ -155,8 +146,8 @@ def impl( u += [upper_b.p] if upper_b is not None else [] if lower_a is not None and lower_b is not None: - la_a = _ensure_dense(lower_a.a) - lb_a = _ensure_dense(lower_b.a) + la_a = ensure_dense(lower_a.a) + lb_a = ensure_dense(lower_b.a) ell = jnp.concatenate( ( jnp.concatenate((la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1), @@ -172,14 +163,14 @@ def impl( ) else: ell = ( - _ensure_dense(lower_a.a) + ensure_dense(lower_a.a) if lower_a is not None - else _ensure_dense(lower_b.a) if lower_b is not None else None + else ensure_dense(lower_b.a) if lower_b is not None else None ) if upper_a is not None and upper_b is not None: - ua_a = _ensure_dense(upper_a.a) - ub_a = _ensure_dense(upper_b.a) + ua_a = ensure_dense(upper_a.a) + ub_a = ensure_dense(upper_b.a) delta = jnp.concatenate( ( jnp.concatenate( @@ -196,9 +187,9 @@ def impl( else: delta = ( - _ensure_dense(upper_a.a) + ensure_dense(upper_a.a) if upper_a is not None - else _ensure_dense(upper_b.a) if upper_b is not None else None + else ensure_dense(upper_b.a) if upper_b is not None else None ) return ( diff --git a/tests/test_kernels/test_quasisep.py b/tests/test_kernels/test_quasisep.py index 1e4903e5..c585f989 100644 --- a/tests/test_kernels/test_quasisep.py +++ b/tests/test_kernels/test_quasisep.py @@ -6,6 +6,7 @@ from tinygp import GaussianProcess from tinygp.kernels import quasisep +from tinygp.noise import Banded from tinygp.test_utils import assert_allclose @@ -159,73 +160,35 @@ def test_carma_quads(): assert_allclose(carma31.obsmodel, carma31_quads.obsmodel) -# Regression tests for https://github.com/dfm/tinygp/issues/265 -# Block transition matrices in Sum kernels broke several operations. - - def test_sum_kernel_with_banded_noise(data): - """Sum kernel + banded noise: self_add must handle Block transition matrices""" - from tinygp.noise import Banded - x, y, _ = data N = len(x) k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1))) gp = GaussianProcess(k, x, noise=banded) assert jnp.isfinite(gp.log_probability(y)) - - -def test_sum_kernel_with_banded_noise_condition(data): - """Sum kernel + banded noise: conditioning must handle Block in qsm_mul""" - from tinygp.noise import Banded - - x, y, _ = data - N = len(x) - k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) - banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1))) - gp = GaussianProcess(k, x, noise=banded) lp, cond_gp = gp.condition(y) assert jnp.isfinite(lp) def test_product_of_sum_kernel(data): - """Product kernel with Sum factor: _prod_helper must handle Block inputs""" x, y, _ = data - N = len(x) k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0) - gp = GaussianProcess(k, x, diag=jnp.ones(N)) + gp = GaussianProcess(k, x, diag=jnp.ones(len(x))) assert jnp.isfinite(gp.log_probability(y)) - - -def test_product_of_sum_kernel_consistency(data): - """Product of sum kernel QSM must match direct kernel evaluation""" - x, _, _ = data - k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0) assert_allclose(k.to_symm_qsm(x).to_dense(), k(x, x)) def test_sum_times_sum_kernel(data): - """Product of two Sum kernels: self_mul must handle Block transition matrices""" x, y, _ = data - N = len(x) k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * ( quasisep.Exp(0.5) + quasisep.Matern32(1.0) ) - gp = GaussianProcess(k, x, diag=jnp.ones(N)) + gp = GaussianProcess(k, x, diag=jnp.ones(len(x))) assert jnp.isfinite(gp.log_probability(y)) def test_sum_kernel_use_block_false(data): - """Sum kernel with use_block=False bypasses Block entirely""" - x, y, _ = data - N = len(x) - k = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False) - gp = GaussianProcess(k, x, diag=0.1 * jnp.ones(N)) - assert jnp.isfinite(gp.log_probability(y)) - - -def test_sum_kernel_use_block_consistency(data): - """Block and non-block Sum kernels must produce the same results""" x, y, _ = data N = len(x) k_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) @@ -233,52 +196,3 @@ def test_sum_kernel_use_block_consistency(data): gp_block = GaussianProcess(k_block, x, diag=0.1 * jnp.ones(N)) gp_dense = GaussianProcess(k_dense, x, diag=0.1 * jnp.ones(N)) assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y)) - - -def test_sum_kernel_use_block_false_with_banded_noise(data): - """Sum kernel use_block=False with banded noise""" - from tinygp.noise import Banded - - x, y, _ = data - N = len(x) - k = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False) - banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1))) - gp = GaussianProcess(k, x, noise=banded) - assert jnp.isfinite(gp.log_probability(y)) - - -def test_sum_kernel_use_block_false_product(data): - """Sum kernel use_block=False in a product""" - x, y, _ = data - N = len(x) - k = quasisep.Sum( - quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False - ) * quasisep.Exp(1.0) - gp = GaussianProcess(k, x, diag=jnp.ones(N)) - assert jnp.isfinite(gp.log_probability(y)) - - -def test_jit_sum_kernel_block(data): - """JIT must work with Sum kernel block computations""" - x, y, _ = data - N = len(x) - k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0) - - @jax.jit - def compute(x, y): - gp = GaussianProcess(k, x, diag=0.1 * jnp.ones(N)) - return gp.log_probability(y) - - assert jnp.isfinite(compute(x, y)) - - -def test_grad_product_of_sum_kernel(data): - """Gradients must work through product of sum kernel""" - x, y, _ = data - N = len(x) - - def loss(sigma): - k = (quasisep.Cosine(sigma) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0) - return GaussianProcess(k, x, diag=jnp.ones(N)).log_probability(y) - - assert jnp.isfinite(jax.grad(loss)(1.0)) From 35d976862488abfd964d67b347ef461aa96cb448 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 02:21:53 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/tinygp/kernels/quasisep.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py index 2e071ecf..b380f08a 100644 --- a/src/tinygp/kernels/quasisep.py +++ b/src/tinygp/kernels/quasisep.py @@ -32,7 +32,6 @@ import jax import jax.numpy as jnp import numpy as np - from jax.scipy.linalg import block_diag as jsp_block_diag from tinygp.helpers import JAXArray