# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fill triangular bijector."""

from typing import Tuple

import jax.numpy as jnp

from distrax._src.bijectors import bijector as base

Array = base.Array


class FillTriangular(base.Bijector):
  """A bijector that maps a vector to a triangular matrix.

  The matrix can be either lower or upper triangular; it is lower
  triangular by default. When projecting from a vector to a triangular
  matrix, entries of the matrix are populated row-wise. For example, the
  vector [1, 2, 3, 4, 5, 6] maps to the lower triangular matrix

    [[1, 0, 0],
     [2, 3, 0],
     [4, 5, 6]].
  """

  def __init__(self, matrix_shape: int, is_lower: bool = True):
    """Initialises the `FillTriangular` bijector.

    Args:
      matrix_shape: The number of rows (equivalently, columns) of the
        triangular matrix being produced/consumed.
      is_lower: If True (the default), the bijector maps to and from a
        lower triangular matrix; otherwise an upper triangular matrix.
    """
    super().__init__(event_ndims_in=0)
    self.matrix_shape = matrix_shape
    # tril/triu index functions enumerate the triangular coordinates
    # row-wise, which matches the fill order documented on the class.
    self.index_fn = jnp.tril_indices if is_lower else jnp.triu_indices

  def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
    """Maps a vector to a triangular matrix.

    Args:
      x: A 1-dimensional vector holding the matrix_shape * (matrix_shape
        + 1) // 2 triangular entries, in row-wise order.

    Returns:
      A (matrix_shape, matrix_shape) triangular matrix and the log
      determinant of the Jacobian, which is 0 since the bijection is
      simply a reshaping.
    """
    # Preserve the input dtype so that forward(inverse(y)) round-trips
    # without an implicit cast to the default float dtype.
    y = jnp.zeros((self.matrix_shape, self.matrix_shape), dtype=x.dtype)
    # Indices of the triangular entries to be filled.
    idxs = self.index_fn(self.matrix_shape)
    y = y.at[idxs].set(x)
    return y, jnp.array(0.0)

  def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
    """Maps a triangular matrix back to a vector.

    Args:
      y: A square triangular matrix (lower or upper, matching `is_lower`).

    Returns:
      The vectorised (row-wise) form of the supplied triangular matrix
      and the log determinant of the Jacobian, which is 0 since the
      bijection is simply a reshaping.
    """
    matrix_shape = y.shape[0]
    return y[self.index_fn(matrix_shape)], jnp.array(0.0)
fill_triangular.FillTriangular(matrix_shape=3, is_lower=is_lower) + x_triangular = self.variant(bijector.forward)(base_vector) + if is_lower: + self.assertTrue(jnp.sum(jnp.tril(x_triangular)) == jnp.sum(base_vector)) + elif not is_lower: + self.assertTrue(jnp.sum(jnp.triu(x_triangular)) == jnp.sum(base_vector)) + + @chex.all_variants + @parameterized.parameters(True, False) + def test_inverse_method(self, is_lower): + random_array = jax.random.normal(self.seed, shape=(5, 5)) + psd_matrix = random_array @ random_array.T + triangular_mat = jnp.linalg.cholesky(psd_matrix) + if not is_lower: + triangular_mat = triangular_mat.T + + bijector = fill_triangular.FillTriangular(matrix_shape=5, is_lower=is_lower) + x_vector = self.variant(bijector.inverse)(triangular_mat) + if is_lower: + np.testing.assert_allclose(jnp.sum(jnp.tril(triangular_mat)), + jnp.sum(x_vector), + rtol=RTOL) + elif not is_lower: + np.testing.assert_allclose(jnp.sum(jnp.triu(triangular_mat)), + jnp.sum(x_vector), + rtol=RTOL) + + @chex.all_variants + @parameterized.parameters(True, False) + def test_inverse_log_jacobian(self, is_lower): + random_array = jax.random.normal(self.seed, shape=(5, 5)) + psd_matrix = random_array @ random_array.T + triangular_mat = jnp.linalg.cholesky(psd_matrix) + if not is_lower: + triangular_mat = triangular_mat.T + + bijector = fill_triangular.FillTriangular(matrix_shape=5, is_lower=is_lower) + log_det_jac = self.variant( + bijector.inverse_log_det_jacobian)(triangular_mat) + self.assertTrue(log_det_jac == 0.0) + + @chex.all_variants + @parameterized.parameters(True, False) + def test_forward_log_jacobian(self, is_lower): + base_vector = jnp.array([1, 2, 3, 4, 5, 6]) + bijector = fill_triangular.FillTriangular(matrix_shape=3, is_lower=is_lower) + inv_log_det_jac = self.variant( + bijector.forward_log_det_jacobian)(base_vector) + self.assertTrue(inv_log_det_jac == 0.0) + + +if __name__ == "__main__": + absltest.main()