Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions distrax/_src/bijectors/fill_triangular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fill triangular bijector."""

import jax.numpy as jnp
from typing import Tuple, Optional

from distrax._src.bijectors import bijector as base

Array = base.Array


class FillTriangular(base.Bijector):
  """A bijector that maps a vector to a triangular matrix.

  The triangular matrix can be either lower or upper triangular; by default a
  lower-triangular matrix is produced.

  When projecting from a vector to a triangular matrix, entries of the matrix
  are populated row-wise. For example, if the vector is [1, 2, 3, 4, 5, 6],
  the (lower) triangular matrix will be:
  [[1, 0, 0],
   [2, 3, 0],
   [4, 5, 6]].
  """

  def __init__(
      self,
      matrix_shape: int,
      is_lower: bool = True,
  ):
    """Initialises the `FillTriangular` bijector.

    Args:
      matrix_shape: The number of rows (equivalently, columns) of the
        triangular matrix that vectors are mapped into.
      is_lower: If True (the default), vectors are mapped to lower-triangular
        matrices; otherwise, to upper-triangular matrices.
    """
    # NOTE(review): event_ndims_in=0 declares scalar events even though this
    # bijection maps vectors to matrices; existing tests assert this value, so
    # it is preserved — confirm against the base Bijector's event conventions.
    super().__init__(event_ndims_in=0)
    self.matrix_shape = matrix_shape
    # Index generator selecting which triangle is filled/read.
    self.index_fn = jnp.tril_indices if is_lower else jnp.triu_indices

  def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
    """Maps a vector to a triangular matrix.

    Args:
      x: A 1-dimensional vector with `matrix_shape * (matrix_shape + 1) / 2`
        entries, written row-wise into the triangle.

    Returns:
      A `(matrix_shape, matrix_shape)` triangular matrix and the
      log-determinant of the Jacobian, which is 0 because the bijection
      merely rearranges entries.
    """
    y = jnp.zeros((self.matrix_shape, self.matrix_shape))
    # Indices of the (lower or upper) triangle, enumerated row-wise.
    idxs = self.index_fn(self.matrix_shape)
    y = y.at[idxs].set(x)
    return y, jnp.array(0.0)

  def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
    """Maps a triangular matrix back to a vector.

    Args:
      y: A square triangular matrix (lower or upper, matching `is_lower`).

    Returns:
      The row-wise vectorised triangular entries of `y` and the
      log-determinant of the Jacobian, which is 0 because the bijection
      merely rearranges entries.
    """
    matrix_shape = y.shape[0]
    return y[self.index_fn(matrix_shape)], jnp.array(0.0)
101 changes: 101 additions & 0 deletions distrax/_src/bijectors/fill_triangular_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `fill_triangular.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
from distrax._src.bijectors import fill_triangular
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors

RTOL = 1e-5


class FillTriangularTest(parameterized.TestCase):
  """Tests for the `FillTriangular` bijector."""

  def setUp(self):
    super().setUp()
    self.seed = jax.random.PRNGKey(1234)

  def _random_triangular(self, matrix_shape, is_lower):
    """Returns a random triangular matrix built from a Cholesky factor."""
    random_array = jax.random.normal(
        self.seed, shape=(matrix_shape, matrix_shape))
    # random_array @ random_array.T is PSD, so its Cholesky factor exists and
    # is lower triangular; transpose it for the upper-triangular case.
    triangular_mat = jnp.linalg.cholesky(random_array @ random_array.T)
    return triangular_mat if is_lower else triangular_mat.T

  def test_properties(self):
    bijector = fill_triangular.FillTriangular(matrix_shape=3)
    self.assertEqual(bijector.event_ndims_in, 0)
    self.assertEqual(bijector.event_ndims_out, 0)
    self.assertFalse(bijector.is_constant_jacobian)
    self.assertFalse(bijector.is_constant_log_det)

  @chex.all_variants
  @parameterized.parameters(True, False)
  def test_forward_method(self, is_lower):
    base_vector = jnp.array([1, 2, 3, 4, 5, 6])
    bijector = fill_triangular.FillTriangular(matrix_shape=3, is_lower=is_lower)
    x_triangular = self.variant(bijector.forward)(base_vector)
    # The triangle matching the bijector's configuration must account for the
    # whole input vector.
    triangle = jnp.tril if is_lower else jnp.triu
    np.testing.assert_allclose(
        jnp.sum(triangle(x_triangular)), jnp.sum(base_vector), rtol=RTOL)

  @chex.all_variants
  @parameterized.parameters(True, False)
  def test_inverse_method(self, is_lower):
    triangular_mat = self._random_triangular(5, is_lower)
    bijector = fill_triangular.FillTriangular(matrix_shape=5, is_lower=is_lower)
    x_vector = self.variant(bijector.inverse)(triangular_mat)
    triangle = jnp.tril if is_lower else jnp.triu
    np.testing.assert_allclose(
        jnp.sum(triangle(triangular_mat)), jnp.sum(x_vector), rtol=RTOL)

  @chex.all_variants
  @parameterized.parameters(True, False)
  def test_inverse_log_jacobian(self, is_lower):
    triangular_mat = self._random_triangular(5, is_lower)
    bijector = fill_triangular.FillTriangular(matrix_shape=5, is_lower=is_lower)
    log_det_jac = self.variant(
        bijector.inverse_log_det_jacobian)(triangular_mat)
    # Rearranging entries is volume-preserving, so the log-det is exactly 0.
    self.assertEqual(log_det_jac, 0.0)

  @chex.all_variants
  @parameterized.parameters(True, False)
  def test_forward_log_jacobian(self, is_lower):
    base_vector = jnp.array([1, 2, 3, 4, 5, 6])
    bijector = fill_triangular.FillTriangular(matrix_shape=3, is_lower=is_lower)
    log_det_jac = self.variant(
        bijector.forward_log_det_jacobian)(base_vector)
    self.assertEqual(log_det_jac, 0.0)


# Run the test suite via absl's test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()