diff --git a/src/tinygp/solvers/quasisep/block.py b/src/tinygp/solvers/quasisep/block.py
index 3bcb0bc4..a6f1dcb5 100644
--- a/src/tinygp/solvers/quasisep/block.py
+++ b/src/tinygp/solvers/quasisep/block.py
@@ -47,9 +47,16 @@ def transpose(self) -> "Block":
     def T(self) -> "Block":
         return self.transpose()
 
+    @property
+    def mT(self) -> "Block":
+        return Block(*(jnp.swapaxes(b, -1, -2) for b in self.blocks))
+
     def to_dense(self) -> JAXArray:
-        assert all(np.ndim(b) == 2 for b in self.blocks)
-        return block_diag(*self.blocks)
+        ndim = self.ndim
+        assert ndim >= 2
+        if ndim == 2:
+            return block_diag(*self.blocks)
+        return jax.vmap(lambda *bs: Block(*bs).to_dense())(*self.blocks)
 
     @jax.jit
     def __mul__(self, other: Any) -> "Block":
@@ -98,16 +105,17 @@ def __matmul__(self, other: Any) -> Any:
             assert len(self.blocks) == len(other.blocks)
             assert all(
                 np.shape(b1) == np.shape(b2)
-                for b1, b2 in zip(self.blocks, other.blocks)
+                for b1, b2 in zip(self.blocks, other.blocks, strict=True)
+            )
+            return Block(
+                *(b1 @ b2 for b1, b2 in zip(self.blocks, other.blocks, strict=True))
             )
-            return Block(*(b1 @ b2 for b1, b2 in zip(self.blocks, other.blocks)))
-        assert all(np.ndim(b) == 2 for b in self.blocks)
         ndim = np.ndim(other)
         assert ndim >= 1
         idx = 0
         ys = []
         for b in self.blocks:
-            size = len(b)
+            size = np.shape(b)[-1]
             x = (
                 other[idx : idx + size]
                 if ndim == 1
@@ -119,11 +127,10 @@ def __matmul__(self, other: Any) -> Any:
 
     @jax.jit
     def __rmatmul__(self, other: Any) -> Any:
-        assert all(np.ndim(b) == 2 for b in self.blocks)
         idx = 0
         ys = []
         for b in self.blocks:
-            size = len(b)
+            size = np.shape(b)[-2]
             x = other[..., idx : idx + size]
             ys.append(x @ b)
             idx += size
diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py
index 412e0e35..c799ea64 100644
--- a/src/tinygp/solvers/quasisep/core.py
+++ b/src/tinygp/solvers/quasisep/core.py
@@ -33,12 +33,12 @@
 
 
 def handle_matvec_shapes(
-    func: Callable[[Any, JAXArray], JAXArray],
-) -> Callable[[Any, JAXArray], JAXArray]:
+    func: Callable[..., JAXArray],
+) -> Callable[..., JAXArray]:
     @wraps(func)
-    def wrapped(self: Any, x: JAXArray) -> JAXArray:
+    def wrapped(self: Any, x: JAXArray, **kwargs: Any) -> JAXArray:
         output_shape = x.shape
-        result = func(self, jnp.reshape(x, (output_shape[0], -1)))
+        result = func(self, jnp.reshape(x, (output_shape[0], -1)), **kwargs)
         return jnp.reshape(result, output_shape)
 
     return wrapped
@@ -61,12 +61,14 @@ def transpose(self) -> Any:
         raise NotImplementedError
 
     @abstractmethod
-    def matmul(self, x: JAXArray) -> JAXArray:
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
         """The dot product of this matrix with a dense vector or matrix
 
         Args:
             x (n, ...): A matrix or vector with leading dimension matching
                 this matrix.
+            parallel: If ``True``, use a parallel associative-scan algorithm
+                instead of the default sequential scan.
         """
         raise NotImplementedError
 
@@ -149,7 +151,8 @@ def transpose(self) -> DiagQSM:
         return self
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        del parallel
         return self.d[:, None] * x
 
     def scale(self, other: JAXArray) -> DiagQSM:
@@ -188,16 +191,12 @@ def shape(self) -> tuple[int, int]:
     def transpose(self) -> StrictUpperTriQSM:
         return StrictUpperTriQSM(p=self.p, q=self.q, a=self.a)
 
-    @jax.jit
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        def impl(f, data):  # type: ignore
-            q, a, x = data
-            return a @ f + jnp.outer(q, x), f
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        from tinygp.solvers.quasisep.ops import lower_matmul, lower_matmul_parallel
 
-        init = jnp.zeros_like(jnp.outer(self.q[0], x[0]))
-        _, f = jax.lax.scan(impl, init, (self.q, self.a, x))
-        return jax.vmap(jnp.dot)(self.p, f)
+        impl = lower_matmul_parallel if parallel else lower_matmul
+        return impl(self.p, self.q, self.a, x)
 
     def scale(self, other: JAXArray) -> StrictLowerTriQSM:
         return StrictLowerTriQSM(p=self.p * other, q=self.q, a=self.a)
@@ -265,16 +264,12 @@ def shape(self) -> tuple[int, int]:
     def transpose(self) -> StrictLowerTriQSM:
         return StrictLowerTriQSM(p=self.p, q=self.q, a=self.a)
 
-    @jax.jit
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        def impl(f, data):  # type: ignore
-            p, a, x = data
-            return a.T @ f + jnp.outer(p, x), f
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        from tinygp.solvers.quasisep.ops import upper_matmul, upper_matmul_parallel
 
-        init = jnp.zeros_like(jnp.outer(self.p[-1], x[-1]))
-        _, f = jax.lax.scan(impl, init, (self.p, self.a, x), reverse=True)
-        return jax.vmap(jnp.dot)(self.q, f)
+        impl = upper_matmul_parallel if parallel else upper_matmul
+        return impl(self.p, self.q, self.a, x)
 
     def scale(self, other: JAXArray) -> StrictUpperTriQSM:
         return StrictUpperTriQSM(p=self.p, q=self.q * other, a=self.a)
@@ -306,8 +301,8 @@ def transpose(self) -> UpperTriQSM:
         return UpperTriQSM(diag=self.diag, upper=self.lower.transpose())
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        return self.diag.matmul(x) + self.lower.matmul(x)
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        return self.diag.matmul(x) + self.lower.matmul(x, parallel=parallel)
 
     def scale(self, other: JAXArray) -> LowerTriQSM:
         return LowerTriQSM(diag=self.diag.scale(other), lower=self.lower.scale(other))
@@ -321,9 +316,8 @@ def inv(self) -> LowerTriQSM:
         b = a - jax.vmap(jnp.outer)(v, p)
         return LowerTriQSM(diag=DiagQSM(g), lower=StrictLowerTriQSM(p=u, q=v, a=b))
 
-    @jax.jit
     @handle_matvec_shapes
-    def solve(self, y: JAXArray) -> JAXArray:
+    def solve(self, y: JAXArray, *, parallel: bool = False) -> JAXArray:
         """Solve a linear system with this matrix
 
         If this matrix is called ``L``, this solves ``L @ x = y`` for ``x``
@@ -332,16 +326,14 @@ def solve(self, y: JAXArray) -> JAXArray:
         Args:
             y (n, ...): A matrix or vector with leading dimension matching
                 this matrix.
+            parallel: If ``True``, use a parallel associative-scan algorithm.
         """
+        from tinygp.solvers.quasisep.ops import lower_solve, lower_solve_parallel
 
-        def impl(fn, data):  # type: ignore
-            ((cn,), (pn, wn, an)), yn = data
-            xn = (yn - pn @ fn) / cn
-            return an @ fn + jnp.outer(wn, xn), xn
-
-        init = jnp.zeros_like(jnp.outer(self.lower.q[0], y[0]))
-        _, x = jax.lax.scan(impl, init, (self, y))
-        return x
+        (d,) = self.diag
+        p, q, a = self.lower
+        impl = lower_solve_parallel if parallel else lower_solve
+        return impl(d, p, q, a, y)
 
     def __neg__(self) -> LowerTriQSM:
         return LowerTriQSM(diag=-self.diag, lower=-self.lower)
@@ -362,8 +354,8 @@ def transpose(self) -> LowerTriQSM:
         return LowerTriQSM(diag=self.diag, lower=self.upper.transpose())
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        return self.diag.matmul(x) + self.upper.matmul(x)
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        return self.diag.matmul(x) + self.upper.matmul(x, parallel=parallel)
 
     def scale(self, other: JAXArray) -> UpperTriQSM:
         return UpperTriQSM(diag=self.diag.scale(other), upper=self.upper.scale(other))
@@ -371,9 +363,8 @@ def scale(self, other: JAXArray) -> UpperTriQSM:
     def inv(self) -> UpperTriQSM:
         return self.transpose().inv().transpose()
 
-    @jax.jit
     @handle_matvec_shapes
-    def solve(self, y: JAXArray) -> JAXArray:
+    def solve(self, y: JAXArray, *, parallel: bool = False) -> JAXArray:
         """Solve a linear system with this matrix
 
         If this matrix is called ``U``, this solves ``U @ x = y`` for ``x``
@@ -382,16 +373,14 @@ def solve(self, y: JAXArray) -> JAXArray:
         Args:
             y (n, ...): A matrix or vector with leading dimension matching
                 this matrix.
+            parallel: If ``True``, use a parallel associative-scan algorithm.
         """
+        from tinygp.solvers.quasisep.ops import upper_solve, upper_solve_parallel
 
-        def impl(fn, data):  # type: ignore
-            ((cn,), (pn, wn, an)), yn = data
-            xn = (yn - wn @ fn) / cn
-            return an.T @ fn + jnp.outer(pn, xn), xn
-
-        init = jnp.zeros_like(jnp.outer(self.upper.p[-1], y[-1]))
-        _, x = jax.lax.scan(impl, init, (self, y), reverse=True)
-        return x
+        (d,) = self.diag
+        p, q, a = self.upper
+        impl = upper_solve_parallel if parallel else upper_solve
+        return impl(d, p, q, a, y)
 
     def __neg__(self) -> UpperTriQSM:
         return UpperTriQSM(diag=-self.diag, upper=-self.upper)
@@ -418,8 +407,12 @@ def transpose(self) -> SquareQSM:
         )
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        return self.diag.matmul(x) + self.lower.matmul(x) + self.upper.matmul(x)
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        return (
+            self.diag.matmul(x)
+            + self.lower.matmul(x, parallel=parallel)
+            + self.upper.matmul(x, parallel=parallel)
+        )
 
     def scale(self, other: JAXArray) -> SquareQSM:
         return SquareQSM(
@@ -504,71 +497,45 @@ def transpose(self) -> SymmQSM:
         return self
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
         return (
             self.diag.matmul(x)
-            + self.lower.matmul(x)
-            + self.lower.transpose().matmul(x)
+            + self.lower.matmul(x, parallel=parallel)
+            + self.lower.transpose().matmul(x, parallel=parallel)
         )
 
     def scale(self, other: JAXArray) -> SymmQSM:
         return SymmQSM(diag=self.diag.scale(other), lower=self.lower.scale(other))
 
-    @jax.jit
-    def inv(self) -> SymmQSM:
-        """The inverse of this matrix"""
-        (d,) = self.diag
-        p, q, a = self.lower
+    def inv(self, *, parallel: bool = False) -> SymmQSM:
+        """The inverse of this matrix
 
-        def forward(carry, data):  # type: ignore
-            f = carry
-            dk, pk, qk, ak = data
-            fpk = f @ pk
-            left = qk - ak @ fpk
-            igk = 1 / (dk - pk @ fpk)
-            sk = igk * left
-            ellk = ak - jnp.outer(sk, pk)
-            fk = ak @ f @ ak.T + igk * jnp.outer(left, left.T)
-            return fk, (igk, sk, ellk)
-
-        init = jnp.zeros_like(jnp.outer(q[0], q[0]))
-        ig, s, ell = jax.lax.scan(forward, init, (d, p, q, a))[1]
+        Args:
+            parallel: If ``True``, use a parallel associative-scan algorithm.
+        """
+        from tinygp.solvers.quasisep.ops import symm_inv, symm_inv_parallel
 
-        def backward(carry, data):  # type: ignore
-            z = carry
-            igk, pk, ak, sk = data
-            zak = z @ ak
-            skzak = sk @ zak
-            lk = igk + sk @ z @ sk
-            tk = skzak - lk * pk
-            zk = ak.T @ zak - jnp.outer(skzak, pk) - jnp.outer(pk, tk)
-            return zk, (lk, tk)
-
-        init = jnp.zeros_like(jnp.outer(p[-1], p[-1]))
-        lam, t = jax.lax.scan(backward, init, (ig, p, a, s), reverse=True)[1]
+        (d,) = self.diag
+        p, q, a = self.lower
+        impl = symm_inv_parallel if parallel else symm_inv
+        lam, t, s, ell = impl(d, p, q, a)
         return SymmQSM(diag=DiagQSM(d=lam), lower=StrictLowerTriQSM(p=t, q=s, a=ell))
 
-    @jax.jit
-    def cholesky(self) -> LowerTriQSM:
+    def cholesky(self, *, parallel: bool = False) -> LowerTriQSM:
         """The Cholesky decomposition of this matrix
 
         If this matrix is called ``A``, this method returns the
         :class:`LowerTriQSM` ``L`` such that ``L @ L.T = A``.
+
+        Args:
+            parallel: If ``True``, use a parallel associative-scan algorithm.
         """
+        from tinygp.solvers.quasisep.ops import cholesky, cholesky_parallel
+
         (d,) = self.diag
         p, q, a = self.lower
-
-        def impl(carry, data):  # type: ignore
-            fp = carry
-            dk, pk, qk, ak = data
-            ck = jnp.sqrt(dk - pk @ fp @ pk)
-            tmp = fp @ ak.T
-            wk = (qk - pk @ tmp) / ck
-            fk = ak @ tmp + jnp.outer(wk, wk)
-            return fk, (ck, wk)
-
-        init = jnp.zeros_like(jnp.outer(q[0], q[0]))
-        _, (c, w) = jax.lax.scan(impl, init, (d, p, q, a))
+        impl = cholesky_parallel if parallel else cholesky
+        c, w = impl(d, p, q, a)
         return LowerTriQSM(diag=DiagQSM(c), lower=StrictLowerTriQSM(p=p, q=w, a=a))
 
     def __neg__(self) -> SymmQSM:
diff --git a/src/tinygp/solvers/quasisep/ops.py b/src/tinygp/solvers/quasisep/ops.py
index 4d88d6dd..95c9d26a 100644
--- a/src/tinygp/solvers/quasisep/ops.py
+++ b/src/tinygp/solvers/quasisep/ops.py
@@ -291,3 +291,222 @@ def none_safe_add(a: JAXArray | None, b: JAXArray | None) -> JAXArray | None:
     if a is not None and b is not None:
         return a + b
     return a if a is not None else b
+
+
+# Ops with parallel implementations
+
+
+def _shift_fwd(x):
+    # associative_scan is inclusive; the strict-lower recurrences are exclusive.
+    return jnp.concatenate((jnp.zeros_like(x[:1]), x[:-1]), axis=0)
+
+
+def _shift_bwd(x):
+    return jnp.concatenate((x[1:], jnp.zeros_like(x[-1:])), axis=0)
+
+
+@jax.jit
+def lower_matmul(p, q, a, x):
+    def impl(f, data):
+        q, a, x = data
+        return a @ f + jnp.outer(q, x), f
+
+    init = jnp.zeros_like(jnp.outer(q[0], x[0]))
+    _, f = jax.lax.scan(impl, init, (q, a, x))
+    return jnp.einsum("nj,njk->nk", p, f)
+
+
+@jax.jit
+def lower_matmul_parallel(p, q, a, x):
+    def combine(left, right):
+        (Al, Bl), (Ar, Br) = left, right
+        return Ar @ Al, Ar @ Bl + Br
+
+    b = jnp.einsum("nj,nk->njk", q, x)
+    _, f = jax.lax.associative_scan(combine, (a, b))
+    return jnp.einsum("nj,njk->nk", p, _shift_fwd(f))
+
+
+@jax.jit
+def upper_matmul(p, q, a, x):
+    def impl(f, data):
+        p, a, x = data
+        return a.T @ f + jnp.outer(p, x), f
+
+    init = jnp.zeros_like(jnp.outer(p[-1], x[-1]))
+    _, f = jax.lax.scan(impl, init, (p, a, x), reverse=True)
+    return jnp.einsum("nj,njk->nk", q, f)
+
+
+@jax.jit
+def upper_matmul_parallel(p, q, a, x):
+    def combine(left, right):
+        (Al, Bl), (Ar, Br) = left, right
+        return Ar @ Al, Ar @ Bl + Br
+
+    b = jnp.einsum("nj,nk->njk", p, x)
+    _, f = jax.lax.associative_scan(combine, (a.mT, b), reverse=True)
+    return jnp.einsum("nj,njk->nk", q, _shift_bwd(f))
+
+
+@jax.jit
+def cholesky(d, p, q, a):
+    def impl(carry, data):
+        fp = carry
+        dk, pk, qk, ak = data
+        ck = jnp.sqrt(dk - pk @ fp @ pk)
+        tmp = fp @ ak.T
+        wk = (qk - pk @ tmp) / ck
+        fk = ak @ tmp + jnp.outer(wk, wk)
+        return fk, (ck, wk)
+
+    init = jnp.zeros_like(jnp.outer(q[0], q[0]))
+    _, (c, w) = jax.lax.scan(impl, init, (d, p, q, a))
+    return c, w
+
+
+def _riccati_scan(d, p, q, a):
+    J = p.shape[1]
+    I = jnp.eye(J)
+    inv_d = 1.0 / d
+    A = a - jnp.einsum("n,nj,nk->njk", inv_d, q, p)
+    F = jnp.einsum("n,nj,nk->njk", inv_d, q, q)
+    G = -jnp.einsum("n,nj,nk->njk", inv_d, p, p)
+
+    def combine(left, right):
+        (Al, Fl, Gl), (Ar, Fr, Gr) = left, right
+        M = I + Fl @ Gr
+        return (
+            Ar @ jnp.linalg.solve(M, Al),
+            Fr + Ar @ jnp.linalg.solve(M, Fl) @ Ar.mT,
+            Gl + Al.mT @ jnp.linalg.solve(M.mT, Gr) @ Al,
+        )
+
+    _, f, _ = jax.lax.associative_scan(combine, (A, F, G))
+    return _shift_fwd(f)
+
+
+@jax.jit
+def cholesky_parallel(d, p, q, a):
+    f = _riccati_scan(d, p, q, a)
+
+    def emit(f, dk, pk, qk, ak):
+        ck = jnp.sqrt(dk - pk @ f @ pk)
+        wk = (qk - pk @ f @ ak.T) / ck
+        return ck, wk
+
+    c, w = jax.vmap(emit)(f, d, p, q, a)
+    return c, w
+
+
+@jax.jit
+def symm_inv(d, p, q, a):
+    def forward(f, data):
+        dk, pk, qk, ak = data
+        fpk = f @ pk
+        left = qk - ak @ fpk
+        igk = 1 / (dk - pk @ fpk)
+        sk = igk * left
+        ellk = ak - jnp.outer(sk, pk)
+        fk = ak @ f @ ak.T + igk * jnp.outer(left, left)
+        return fk, (igk, sk, ellk)
+
+    init = jnp.zeros_like(jnp.outer(q[0], q[0]))
+    _, (ig, s, ell) = jax.lax.scan(forward, init, (d, p, q, a))
+
+    def backward(z, data):
+        igk, pk, ak, sk = data
+        zak = z @ ak
+        skzak = sk @ zak
+        lk = igk + sk @ z @ sk
+        tk = skzak - lk * pk
+        zk = ak.T @ zak - jnp.outer(skzak, pk) - jnp.outer(pk, tk)
+        return zk, (lk, tk)
+
+    init = jnp.zeros_like(jnp.outer(p[-1], p[-1]))
+    _, (lam, t) = jax.lax.scan(backward, init, (ig, p, a, s), reverse=True)
+    return lam, t, s, ell
+
+
+@jax.jit
+def symm_inv_parallel(d, p, q, a):
+    f = _riccati_scan(d, p, q, a)
+
+    def fwd_emit(f, dk, pk, qk, ak):
+        fpk = f @ pk
+        left = qk - ak @ fpk
+        igk = 1 / (dk - pk @ fpk)
+        sk = igk * left
+        ellk = ak - jnp.outer(sk, pk)
+        return igk, sk, ellk
+
+    ig, s, ell = jax.vmap(fwd_emit)(f, d, p, q, a)
+
+    def bwd_combine(left, right):
+        (Al, Bl), (Ar, Br) = left, right
+        return Ar @ Al, Ar @ Bl @ Ar.mT + Br
+
+    B = jnp.einsum("n,nj,nk->njk", ig, p, p)
+    _, z = jax.lax.associative_scan(bwd_combine, (ell.mT, B), reverse=True)
+    z = _shift_bwd(z)
+
+    def bwd_emit(z, igk, pk, ak, sk):
+        skz = sk @ z
+        lk = igk + skz @ sk
+        tk = skz @ ak - lk * pk
+        return lk, tk
+
+    lam, t = jax.vmap(bwd_emit)(z, ig, p, a, s)
+    return lam, t, s, ell
+
+
+@jax.jit
+def lower_solve(d, p, q, a, x):
+    def impl(f, data):
+        d, p, q, a, x = data
+        y = (x - p @ f) / d
+        return a @ f + jnp.outer(q, y), y
+
+    init = jnp.zeros_like(jnp.outer(q[0], x[0]))
+    _, x = jax.lax.scan(impl, init, (d, p, q, a, x))
+    return x
+
+
+@jax.jit
+def lower_solve_parallel(d, p, q, a, x):
+    def combine(left, right):
+        (Al, Bl), (Ar, Br) = left, right
+        return Ar @ Al, Ar @ Bl + Br
+
+    inv_d = 1.0 / d[:, None]
+    q_ = q * inv_d
+    A = a - jnp.einsum("nj,nk->njk", q_, p)
+    b = jnp.einsum("nj,nk->njk", q_, x)
+    _, f = jax.lax.associative_scan(combine, (A, b))
+    return (x - jnp.einsum("nj,njk->nk", p, _shift_fwd(f))) * inv_d
+
+
+@jax.jit
+def upper_solve(d, p, q, a, x):
+    def impl(f, data):
+        d, p, q, a, x = data
+        y = (x - q @ f) / d
+        return a.T @ f + jnp.outer(p, y), y
+
+    init = jnp.zeros_like(jnp.outer(p[-1], x[-1]))
+    _, x = jax.lax.scan(impl, init, (d, p, q, a, x), reverse=True)
+    return x
+
+
+@jax.jit
+def upper_solve_parallel(d, p, q, a, x):
+    def combine(left, right):
+        (Al, Bl), (Ar, Br) = left, right
+        return Ar @ Al, Ar @ Bl + Br
+
+    inv_d = 1.0 / d[:, None]
+    p_ = p * inv_d
+    A = a.mT - jnp.einsum("nj,nk->njk", p_, q)
+    b = jnp.einsum("nj,nk->njk", p_, x)
+    _, f = jax.lax.associative_scan(combine, (A, b), reverse=True)
+    return (x - jnp.einsum("nj,njk->nk", q, _shift_bwd(f))) * inv_d
diff --git a/src/tinygp/solvers/quasisep/solver.py b/src/tinygp/solvers/quasisep/solver.py
index d9d52652..c76f5ea7 100644
--- a/src/tinygp/solvers/quasisep/solver.py
+++ b/src/tinygp/solvers/quasisep/solver.py
@@ -4,6 +4,7 @@
 
 from typing import TYPE_CHECKING, Any
 
+import equinox as eqx
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -29,6 +30,7 @@ class QuasisepSolver(Solver):
     X: JAXArray
     matrix: SymmQSM
     factor: LowerTriQSM
+    parallel: bool = eqx.field(static=True)
 
     def __init__(
         self,
@@ -38,6 +40,7 @@ def __init__(
         *,
         covariance: Any | None = None,
         assume_sorted: bool = False,
+        parallel: bool = False,
     ):
         """Build a :class:`QuasisepSolver` for a given kernel and coordinates
 
@@ -54,6 +57,11 @@ def __init__(
                 error if they are not. This can introduce a runtime overhead,
                 and you can pass ``assume_sorted=True`` to get the best
                 performance.
+            parallel: If ``True``, use parallel associative-scan algorithms for
+                the Cholesky factorization, triangular solves, and matrix
+                products. This trades increased FLOPs for reduced sequential
+                depth and can be substantially faster on GPUs/TPUs for large
+                ``N``.
         """
         from tinygp.kernels.quasisep import Quasisep
 
@@ -70,7 +78,8 @@ def __init__(
             matrix = covariance
         self.X = X
         self.matrix = matrix
-        self.factor = matrix.cholesky()
+        self.parallel = parallel
+        self.factor = matrix.cholesky(parallel=parallel)
 
     def variance(self) -> JAXArray:
         return self.matrix.diag.d
@@ -85,12 +94,12 @@ def normalization(self) -> JAXArray:
 
     def solve_triangular(self, y: JAXArray, *, transpose: bool = False) -> JAXArray:
         if transpose:
-            return self.factor.transpose().solve(y)
+            return self.factor.transpose().solve(y, parallel=self.parallel)
         else:
-            return self.factor.solve(y)
+            return self.factor.solve(y, parallel=self.parallel)
 
     def dot_triangular(self, y: JAXArray) -> JAXArray:
-        return self.factor @ y
+        return self.factor.matmul(y, parallel=self.parallel)
 
     def condition(self, kernel: Kernel, X_test: JAXArray | None, noise: Noise) -> Any:
         """Compute the covariance matrix for a conditional GP
diff --git a/tests/test_solvers/test_quasisep/test_core.py b/tests/test_solvers/test_quasisep/test_core.py
index ae1de4c4..09545403 100644
--- a/tests/test_solvers/test_quasisep/test_core.py
+++ b/tests/test_solvers/test_quasisep/test_core.py
@@ -22,6 +22,11 @@ def name(request):
     return request.param
 
 
+@pytest.fixture(params=[False, True], ids=["sequential", "parallel"])
+def parallel(request):
+    return request.param
+
+
 @pytest.fixture
 def matrices(name):
     return get_matrices(name)
@@ -149,7 +154,7 @@ def get_value(i, j):
             assert_allclose(get_value(i, j), m[i, j])
 
 
-def test_strict_tri_matmul(matrices):
+def test_strict_tri_matmul(matrices, parallel):
     _, p, q, a, v, m, l, u = matrices
     mat = StrictLowerTriQSM(p=p, q=q, a=a)
 
@@ -158,15 +163,15 @@ def test_strict_tri_matmul(matrices):
     assert_allclose(mat.T.to_dense(), u)
 
     # Check matvec
-    assert_allclose(mat @ v, l @ v)
-    assert_allclose(mat.T @ v, u @ v)
+    assert_allclose(mat.matmul(v, parallel=parallel), l @ v)
+    assert_allclose(mat.T.matmul(v, parallel=parallel), u @ v)
 
     # Check matmat
-    assert_allclose(mat @ m, l @ m)
-    assert_allclose(mat.T @ m, u @ m)
+    assert_allclose(mat.matmul(m, parallel=parallel), l @ m)
+    assert_allclose(mat.T.matmul(m, parallel=parallel), u @ m)
 
 
-def test_tri_matmul(matrices):
+def test_tri_matmul(matrices, parallel):
     diag, p, q, a, v, m, l, _ = matrices
     mat = LowerTriQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a))
     dense = l + jnp.diag(diag)
@@ -176,16 +181,16 @@ def test_tri_matmul(matrices):
     assert_allclose(mat.T.to_dense(), dense.T)
 
     # Check matvec
-    assert_allclose(mat @ v, dense @ v)
-    assert_allclose(mat.T @ v, dense.T @ v)
+    assert_allclose(mat.matmul(v, parallel=parallel), dense @ v)
+    assert_allclose(mat.T.matmul(v, parallel=parallel), dense.T @ v)
 
     # Check matmat
-    assert_allclose(mat @ m, dense @ m)
-    assert_allclose(mat.T @ m, dense.T @ m)
+    assert_allclose(mat.matmul(m, parallel=parallel), dense @ m)
+    assert_allclose(mat.T.matmul(m, parallel=parallel), dense.T @ m)
 
 
 @pytest.mark.parametrize("symm", [True, False])
-def test_square_matmul(symm, matrices):
+def test_square_matmul(symm, matrices, parallel):
     diag, p, q, a, v, m, l, u = matrices
     if symm:
         mat = SymmQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a))
@@ -203,8 +208,8 @@ def test_square_matmul(symm, matrices):
     assert_allclose(jnp.diag(dense), diag)
 
     # Test matmuls
-    assert_allclose(mat @ v, dense @ v)
-    assert_allclose(mat @ m, dense @ m)
+    assert_allclose(mat.matmul(v, parallel=parallel), dense @ v)
+    assert_allclose(mat.matmul(m, parallel=parallel), dense @ m)
     assert_allclose(v.T @ mat, v.T @ dense)
     assert_allclose(m.T @ mat, m.T @ dense)
 
@@ -220,20 +225,20 @@ def test_tri_inv(matrices):
 
 
 @pytest.mark.parametrize("name", ["celerite"])
-def test_tri_solve(matrices):
+def test_tri_solve(matrices, parallel):
     diag, p, q, a, v, m, _, _ = matrices
     mat = LowerTriQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a))
     dense = mat.to_dense()
-    assert_allclose(mat.solve(v), jnp.linalg.solve(dense, v))
-    assert_allclose(mat.solve(m), jnp.linalg.solve(dense, m))
+    assert_allclose(mat.solve(v, parallel=parallel), jnp.linalg.solve(dense, v))
+    assert_allclose(mat.solve(m, parallel=parallel), jnp.linalg.solve(dense, m))
 
-    assert_allclose(mat.T.solve(v), jnp.linalg.solve(dense.T, v))
-    assert_allclose(mat.T.solve(m), jnp.linalg.solve(dense.T, m))
+    assert_allclose(mat.T.solve(v, parallel=parallel), jnp.linalg.solve(dense.T, v))
+    assert_allclose(mat.T.solve(m, parallel=parallel), jnp.linalg.solve(dense.T, m))
 
-    assert_allclose(mat.inv().solve(v), dense @ v)
-    assert_allclose(mat.inv().solve(m), dense @ m)
-    assert_allclose(mat.T.inv().solve(v), dense.T @ v)
-    assert_allclose(mat.T.inv().solve(m), dense.T @ m)
+    assert_allclose(mat.inv().solve(v, parallel=parallel), dense @ v)
+    assert_allclose(mat.inv().solve(m, parallel=parallel), dense @ m)
+    assert_allclose(mat.T.inv().solve(v, parallel=parallel), dense.T @ v)
+    assert_allclose(mat.T.inv().solve(m, parallel=parallel), dense.T @ m)
 
 
 @pytest.mark.parametrize("symm", [True, False])
@@ -301,20 +306,21 @@ def test_gram(matrices):
 
 
 @pytest.mark.parametrize("name", ["celerite"])
-def test_cholesky(matrices):
+def test_cholesky(matrices, parallel):
     diag, p, q, a, v, m, _, _ = matrices
     mat = SymmQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a))
     dense = mat.to_dense()
-    chol = mat.cholesky()
+    chol = mat.cholesky(parallel=parallel)
     assert_allclose(chol.to_dense(), jnp.linalg.cholesky(dense))
 
     mat = mat.inv()
     dense = mat.to_dense()
-    chol = mat.cholesky()
+    chol = mat.cholesky(parallel=parallel)
     assert_allclose(chol.to_dense(), jnp.linalg.cholesky(dense))
 
-    assert_allclose(chol.solve(v), jnp.linalg.solve(chol.to_dense(), v))
-    assert_allclose(chol.solve(m), jnp.linalg.solve(chol.to_dense(), m))
+    dense = chol.to_dense()
+    assert_allclose(chol.solve(v, parallel=parallel), jnp.linalg.solve(dense, v))
+    assert_allclose(chol.solve(m, parallel=parallel), jnp.linalg.solve(dense, m))
 
 
 def test_tri_qsmul(some_nice_matrices):
diff --git a/tests/test_solvers/test_quasisep/test_ops.py b/tests/test_solvers/test_quasisep/test_ops.py
new file mode 100644
index 00000000..20da5013
--- /dev/null
+++ b/tests/test_solvers/test_quasisep/test_ops.py
@@ -0,0 +1,76 @@
+# mypy: ignore-errors
+
+import jax.numpy as jnp
+import pytest
+from numpy import random as np_random
+
+from tinygp.kernels.quasisep import Matern32, Matern52
+from tinygp.solvers.quasisep.core import DiagQSM
+from tinygp.solvers.quasisep.ops import (
+    cholesky,
+    cholesky_parallel,
+    lower_matmul,
+    lower_matmul_parallel,
+    lower_solve,
+    lower_solve_parallel,
+    symm_inv,
+    symm_inv_parallel,
+    upper_matmul,
+    upper_matmul_parallel,
+    upper_solve,
+    upper_solve_parallel,
+)
+from tinygp.test_utils import assert_allclose
+
+
+@pytest.fixture(params=[Matern32, Matern52])
+def data(request):
+    N = 100
+    random = np_random.default_rng(1234)
+    t = jnp.sort(jnp.asarray(random.uniform(0, 10, N)))
+    kernel = request.param(scale=1.3)
+    qsm = kernel.to_symm_qsm(t) + DiagQSM(jnp.full(N, 0.1))
+    (d,) = qsm.diag
+    p, q, a = qsm.lower
+    x = jnp.asarray(random.normal(size=(N, 3)))
+    return d, p, q, a, x
+
+
+def test_lower_matmul_parallel(data):
+    _, p, q, a, x = data
+    assert_allclose(lower_matmul_parallel(p, q, a, x), lower_matmul(p, q, a, x))
+
+
+def test_upper_matmul_parallel(data):
+    _, p, q, a, x = data
+    assert_allclose(upper_matmul_parallel(p, q, a, x), upper_matmul(p, q, a, x))
+
+
+def test_cholesky_parallel(data):
+    d, p, q, a, _ = data
+    c_seq, w_seq = cholesky(d, p, q, a)
+    c_par, w_par = cholesky_parallel(d, p, q, a)
+    assert_allclose(c_par, c_seq)
+    assert_allclose(w_par, w_seq)
+
+
+def test_lower_solve_parallel(data):
+    d, p, q, a, x = data
+    c, w = cholesky(d, p, q, a)
+    assert_allclose(lower_solve_parallel(c, p, w, a, x), lower_solve(c, p, w, a, x))
+
+
+def test_upper_solve_parallel(data):
+    d, p, q, a, x = data
+    c, w = cholesky(d, p, q, a)
+    assert_allclose(upper_solve_parallel(c, p, w, a, x), upper_solve(c, p, w, a, x))
+
+
+def test_symm_inv_parallel(data):
+    d, p, q, a, _ = data
+    lam_s, t_s, s_s, ell_s = symm_inv(d, p, q, a)
+    lam_p, t_p, s_p, ell_p = symm_inv_parallel(d, p, q, a)
+    assert_allclose(lam_p, lam_s)
+    assert_allclose(t_p, t_s)
+    assert_allclose(s_p, s_s)
+    assert_allclose(ell_p, ell_s)
diff --git a/tests/test_solvers/test_quasisep/test_solver.py b/tests/test_solvers/test_quasisep/test_solver.py
index 5fcb64f1..bd4704e6 100644
--- a/tests/test_solvers/test_quasisep/test_solver.py
+++ b/tests/test_solvers/test_quasisep/test_solver.py
@@ -46,17 +46,25 @@ def data(random):
             quasisep.Cosine(sigma=1.8, scale=1.5),
             1.8**2 * kernels.Cosine(1.5),
         ),
+        (
+            quasisep.Matern32(sigma=1.8, scale=1.5)
+            + quasisep.Matern52(sigma=0.9, scale=0.7),
+            1.8**2 * kernels.Matern32(1.5) + 0.9**2 * kernels.Matern52(0.7),
+        ),
     ]
 )
 def kernel_pair(request):
     return request.param
 
 
-def test_consistent_with_direct(kernel_pair, data):
+@pytest.mark.parametrize("parallel", [False, True], ids=["sequential", "parallel"])
+def test_consistent_with_direct(kernel_pair, data, parallel):
     kernel0 = quasisep.Matern32(sigma=3.8, scale=4.5)
     kernel1, kernel2 = kernel_pair
     x, y, t = data
-    gp1 = GaussianProcess(kernel1, x, diag=0.1, solver=QuasisepSolver)
+    gp1 = GaussianProcess(
+        kernel1, x, diag=0.1, solver=QuasisepSolver, parallel=parallel
+    )
     gp2 = GaussianProcess(kernel2, x, diag=0.1, solver=DirectSolver)
 
     assert_allclose(gp1.covariance, gp2.covariance)