diff --git a/distrax/_src/bijectors/diag_linear.py b/distrax/_src/bijectors/diag_linear.py index d5b1d62..4e21ddb 100644 --- a/distrax/_src/bijectors/diag_linear.py +++ b/distrax/_src/bijectors/diag_linear.py @@ -57,11 +57,21 @@ def __init__(self, diag: Array): batch_shape=diag.shape[:-1], dtype=diag.dtype) self._diag = diag - self.forward = self._bijector.forward - self.forward_log_det_jacobian = self._bijector.forward_log_det_jacobian - self.inverse = self._bijector.inverse - self.inverse_log_det_jacobian = self._bijector.inverse_log_det_jacobian - self.inverse_and_log_det = self._bijector.inverse_and_log_det + + def forward(self, x: Array) -> Array: + return self._bijector.forward(x) + + def forward_log_det_jacobian(self, x: Array) -> Array: + return self._bijector.forward_log_det_jacobian(x) + + def inverse(self, y: Array) -> Array: + return self._bijector.inverse(y) + + def inverse_log_det_jacobian(self, y: Array) -> Array: + return self._bijector.inverse_log_det_jacobian(y) + + def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + return self._bijector.inverse_and_log_det(y) @property def diag(self) -> Array: diff --git a/distrax/_src/bijectors/diag_plus_low_rank_linear.py b/distrax/_src/bijectors/diag_plus_low_rank_linear.py index 9d1c13f..b4a6789 100644 --- a/distrax/_src/bijectors/diag_plus_low_rank_linear.py +++ b/distrax/_src/bijectors/diag_plus_low_rank_linear.py @@ -201,11 +201,21 @@ def __init__(self, diag: Array, u_matrix: Array, v_matrix: Array): self._diag = diag self._u_matrix = u_matrix self._v_matrix = v_matrix - self.forward = self._bijector.forward - self.forward_log_det_jacobian = self._bijector.forward_log_det_jacobian - self.inverse = self._bijector.inverse - self.inverse_log_det_jacobian = self._bijector.inverse_log_det_jacobian - self.inverse_and_log_det = self._bijector.inverse_and_log_det + + def forward(self, x: Array) -> Array: + return self._bijector.forward(x) + + def forward_log_det_jacobian(self, x: Array) -> Array: + return self._bijector.forward_log_det_jacobian(x) + + def inverse(self, y: Array) -> Array: + return self._bijector.inverse(y) + + def inverse_log_det_jacobian(self, y: Array) -> Array: + return self._bijector.inverse_log_det_jacobian(y) + + def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + return self._bijector.inverse_and_log_det(y) @property def diag(self) -> Array: