From 91bb4b6dd84d274bce3d21c9acef8371bfb46e48 Mon Sep 17 00:00:00 2001 From: Surya Bhupatiraju Date: Thu, 10 Feb 2022 10:35:28 -0800 Subject: [PATCH] rlax: Replace rlax categorical cross entropy computation with distrax components. PiperOrigin-RevId: 427788283 --- distrax/_src/distributions/categorical.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/distrax/_src/distributions/categorical.py b/distrax/_src/distributions/categorical.py index bc2bb57c..5eea3e4a 100644 --- a/distrax/_src/distributions/categorical.py +++ b/distrax/_src/distributions/categorical.py @@ -165,6 +165,13 @@ def _kl_divergence_categorical_categorical( ) -> Array: """Obtain the KL divergence `KL(dist1 || dist2)` between two Categoricals. + This is useful if we want to e.g. compute the cross entropy between logits and + labels as: + + H(labels, logits) = KL(labels | logits) - H(labels) + + where we parametrize labels and logits as categorical distributions. + Args: dist1: A Categorical distribution. dist2: A Categorical distribution. @@ -193,8 +200,12 @@ def _kl_divergence_categorical_categorical( else: probs1 = dist1.probs - log_probs1 = jax.nn.log_softmax(logits1, axis=-1) - log_probs2 = jax.nn.log_softmax(logits2, axis=-1) + # When there are NaNs in the logits, e.g. when we take log(prob=0), we ignore + # those components and set them to 0. + log_probs1 = jnp.where( + jnp.isfinite(logits1), jax.nn.log_softmax(logits1, axis=-1), 0.) + log_probs2 = jnp.where( + jnp.isfinite(logits2), jax.nn.log_softmax(logits2, axis=-1), 0.) return jnp.sum((probs1 * (log_probs1 - log_probs2)), axis=-1)