15 changes: 13 additions & 2 deletions distrax/_src/distributions/categorical.py
@@ -165,6 +165,13 @@ def _kl_divergence_categorical_categorical(
) -> Array:
"""Obtain the KL divergence `KL(dist1 || dist2)` between two Categoricals.

This is useful if we want to, e.g., compute the cross entropy between labels and
logits as:

  H(labels, logits) = KL(labels || logits) + H(labels),

where we parametrize both labels and logits as Categorical distributions.

Args:
dist1: A Categorical distribution.
dist2: A Categorical distribution.
@@ -193,8 +200,12 @@ def _kl_divergence_categorical_categorical(
else:
probs1 = dist1.probs

log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
log_probs2 = jax.nn.log_softmax(logits2, axis=-1)
# The logits may be non-finite, e.g. -inf from log(prob=0) (or NaN). We ignore
# those components by setting their log-probabilities to 0, so that the
# corresponding 0 * log(0) terms below contribute 0 rather than NaN.
log_probs1 = jnp.where(
jnp.isfinite(logits1), jax.nn.log_softmax(logits1, axis=-1), 0.)
log_probs2 = jnp.where(
jnp.isfinite(logits2), jax.nn.log_softmax(logits2, axis=-1), 0.)

return jnp.sum((probs1 * (log_probs1 - log_probs2)), axis=-1)
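A minimal usage sketch (not part of this change) of the cross-entropy identity from the docstring above, assuming the public distrax.Categorical constructor and the kl_divergence method of distrax distributions; one-hot labels put zero probability on most classes, which is exactly the log(prob=0) case the masking above guards against:

import distrax
import jax.numpy as jnp

# One-hot "labels" distribution: probability 1 on class 2 and 0 elsewhere, so its
# logits contain -inf entries that the masking above has to handle.
labels = distrax.Categorical(probs=jnp.array([0.0, 0.0, 1.0]))
# Predicted distribution, parametrized by unnormalized classifier logits.
preds = distrax.Categorical(logits=jnp.array([0.1, 0.2, 2.0]))

# H(labels, preds) = KL(labels || preds) + H(labels); for one-hot labels the
# entropy term is 0, so the cross entropy is just the KL divergence.
cross_entropy = labels.kl_divergence(preds)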
