Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 159 additions & 26 deletions src/ott/solvers/linear/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, NamedTuple, Optional, Tuple, Union
import warnings
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -58,8 +59,15 @@ def set(self, **kwargs: Any) -> "FreeBarycenterState":
return self._replace(**kwargs)

def update(
self, iteration: int, bar_prob: barycenter_problem.FreeBarycenterProblem,
linear_solver: Any, store_errors: bool
self,
iteration: int,
bar_prob: barycenter_problem.FreeBarycenterProblem,
linear_solver: Any,
store_errors: bool,
*,
tau_a: float = 1.0,
tau_b: float = 1.0,
learn_a: bool = False,
) -> "FreeBarycenterState":
"""Update the state of the solver.

Expand All @@ -68,6 +76,18 @@ def update(
bar_prob: the barycenter problem.
linear_solver: the linear OT solver to use.
store_errors: whether to store the errors of the inner loop.
tau_a: Relaxation parameter for the barycenter marginal constraint.
Must lie in ``(0, 1]``. When ``tau_a < 1``, unbalanced transport
is used on the barycenter side, allowing the transport row marginals
to differ from the current barycenter weights ``a``.
tau_b: Relaxation parameter for the input-measure marginal constraint.
Must lie in ``(0, 1]``. When ``tau_b < 1``, the input measures
do not need to be fully transported, providing robustness to
outliers.
learn_a: If ``True``, update barycenter weights at every iteration
using the weighted arithmetic mean of transport-plan row marginals.
Requires ``tau_a < 1`` to be effective, since balanced transport
forces row marginals to equal ``a``.

Returns:
The updated state.
Expand All @@ -81,9 +101,14 @@ def solve_linear_ot(
geom = pointcloud.PointCloud(
x, y, cost_fn=bar_prob.cost_fn, epsilon=bar_prob.epsilon
)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
# Preserve exact old behaviour when both taus are 1.0.
if tau_a < 1.0 or tau_b < 1.0:
prob = linear_problem.LinearProblem(
geom, a=a, b=b, tau_a=tau_a, tau_b=tau_b
)
else:
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = linear_solver(prob)
# instantiate matrix since it is a property of out.
return out, out.matrix

outs, matrices = solve_linear_ot(self.a, self.x, seg_b, seg_y)
Expand All @@ -102,17 +127,43 @@ def solve_linear_ot(
errors = None

# Approximation of barycenter as barycenter of barycenters per measure.

barycenters_per_measure = mu.barycentric_projection(
matrices, seg_y, bar_prob.cost_fn
)

x_new = jax.vmap(
lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1]
)(bar_prob.weights, barycenters_per_measure)
row_marginals = jnp.sum(matrices, axis=2) # [num_measures, bar_size]

if tau_a < 1.0:
# In unbalanced mode each atom k aggregates measure i with weight
# proportional to lambda_i * r_{ik} (measure weight times the row
# marginal of the plan at that atom).
atom_weights = bar_prob.weights[:, None] * row_marginals
normalizer = jnp.sum(atom_weights, axis=0, keepdims=True)
atom_weights = jnp.where(
normalizer > 0.0,
atom_weights / normalizer,
bar_prob.weights[:, None],
)
x_new = jax.vmap(
lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[1, 1]
)(atom_weights, barycenters_per_measure)
else:
x_new = jax.vmap(
lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1]
)(bar_prob.weights, barycenters_per_measure)

# Optionally learn barycenter weights from transport-plan row marginals.
# The weighted arithmetic mean is the exact block-coordinate minimiser
# for the joint objective with KL(r_i || a) penalty.
if learn_a:
a_new = jnp.sum(row_marginals * bar_prob.weights[:, None], axis=0)
a_new = jnp.clip(a_new, min=1e-12)
a_new = a_new / jnp.sum(a_new)
else:
a_new = self.a

return self.set(
x=x_new,
a=a_new,
costs=updated_costs,
linear_convergence=linear_convergence,
linear_outputs=outs,
Expand Down Expand Up @@ -181,15 +232,80 @@ def linear_output_at_index(self, i: int) -> LinearOutput:

@jax.tree_util.register_pytree_node_class
class FreeWassersteinBarycenter(was_solver.WassersteinSolver):
"""Continuous Wasserstein barycenter solver :cite:`cuturi:14`."""
"""Continuous Wasserstein barycenter solver :cite:`cuturi:14`.

Args:
linear_solver: Inner linear OT solver.
tau_a: Relaxation parameter for the barycenter marginal constraint.
Must lie in ``(0, 1]``. When ``tau_a < 1``, unbalanced transport is
used on the barycenter side, so that the transport row marginals can
differ from ``a``. Default ``1.0`` (balanced).
tau_b: Relaxation parameter for the input-measure marginal constraint.
Must lie in ``(0, 1]``. When ``tau_b < 1``, the input measures do
not need to be fully transported, providing robustness to outliers.
Default ``1.0`` (balanced).
learn_a: If ``True``, update barycenter weights at every outer
iteration using the weighted arithmetic mean of transport-plan row
marginals. Requires ``tau_a < 1`` to be effective, since balanced
transport forces row marginals to equal ``a``.
a_init: Optional initial barycenter weights of shape ``(bar_size,)``.
If ``None``, weights are initialised uniformly.
kwargs: Forwarded to
:class:`~ott.solvers.was_solver.WassersteinSolver`.
"""

def __init__(
self,
linear_solver,
*,
tau_a: float = 1.0,
tau_b: float = 1.0,
learn_a: bool = False,
a_init: Optional[jnp.ndarray] = None,
**kwargs,
):
super().__init__(linear_solver=linear_solver, **kwargs)
if not (0.0 < tau_a <= 1.0):
raise ValueError("tau_a must be in (0, 1].")
if not (0.0 < tau_b <= 1.0):
raise ValueError("tau_b must be in (0, 1].")
if learn_a and tau_a == 1.0:
warnings.warn(
"`learn_a=True` has no effect when `tau_a=1.0` (balanced). "
"The current implementation learns weights from unbalanced "
"row marginals, which coincide with `a` when balanced. "
"Set `tau_a < 1` to enable weight learning.",
stacklevel=2,
)
self.tau_a = tau_a
self.tau_b = tau_b
self.learn_a = learn_a
self.a_init = a_init

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return ([self.linear_solver, self.threshold, self.a_init], {
"min_iterations": self.min_iterations,
"max_iterations": self.max_iterations,
"store_inner_errors": self.store_inner_errors,
"tau_a": self.tau_a,
"tau_b": self.tau_b,
"learn_a": self.learn_a,
})

@classmethod
def tree_unflatten( # noqa: D102
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "FreeWassersteinBarycenter":
linear_solver, threshold, a_init = children
return cls(linear_solver, threshold=threshold, a_init=a_init, **aux_data)

def __call__( # noqa: D102
self,
bar_prob: barycenter_problem.FreeBarycenterProblem,
bar_size: int = 100,
x_init: Optional[jnp.ndarray] = None,
rng: Optional[jax.Array] = None,
) -> FreeBarycenterState:
) -> FreeBarycenterOutput:
rng = utils.default_prng_key(rng)
return self.iterations(bar_size, bar_prob, x_init, rng)

Expand Down Expand Up @@ -218,7 +334,6 @@ def init_state(
assert bar_size == x_init.shape[0]
x = x_init
else:
# sample randomly points in the support of the y measures
rng = utils.default_prng_key(rng)
indices_subset = jax.random.choice(
rng,
Expand All @@ -229,8 +344,15 @@ def init_state(
)
x = bar_prob.flattened_y[indices_subset, :]

# TODO(cuturi) expand to non-uniform weights for barycenter.
a = jnp.ones((bar_size,)) / bar_size
if self.a_init is not None:
a = jnp.asarray(self.a_init)
if a.shape != (bar_size,):
raise ValueError("a_init must have shape (bar_size,).")
a = jnp.clip(a, min=0.0)
a = a / jnp.sum(a)
else:
a = jnp.ones((bar_size,)) / bar_size

num_iter = self.max_iterations
if self.store_inner_errors:
errors = -jnp.ones((
Expand All @@ -246,8 +368,13 @@ def init_state(
errors=errors
)
abstract_tree = jax.eval_shape(
functools.partial(state.update, store_errors=self.store_inner_errors),
0, bar_prob, self.linear_solver
functools.partial(
state.update,
store_errors=self.store_inner_errors,
tau_a=self.tau_a,
tau_b=self.tau_b,
learn_a=self.learn_a,
), 0, bar_prob, self.linear_solver
)

linear_outputs = jax.tree.map(jnp.zeros_like, abstract_tree.linear_outputs)
Expand Down Expand Up @@ -278,28 +405,34 @@ def output_from_state( # noqa: D102
)

def iterations(
self, bar_size: int, bar_prob: barycenter_problem.FreeBarycenterProblem,
x_init: jnp.ndarray, rng: jax.Array
) -> FreeBarycenterState:
self,
bar_size: int,
bar_prob: barycenter_problem.FreeBarycenterProblem,
x_init: jnp.ndarray,
rng: jax.Array,
) -> FreeBarycenterOutput:
"""Wasserstein barycenter outer loop."""

def cond_fn(
iteration: int,
constants: Tuple[FreeWassersteinBarycenter,
barycenter_problem.FreeBarycenterProblem],
iteration: int, constants: barycenter_problem.FreeBarycenterProblem,
state: FreeBarycenterState
) -> bool:
return self._continue(state, iteration)

def body_fn(
iteration, constants: Tuple[FreeWassersteinBarycenter,
barycenter_problem.FreeBarycenterProblem],
iteration: int, constants: barycenter_problem.FreeBarycenterProblem,
state: FreeBarycenterState, compute_error: bool
) -> FreeBarycenterState:
del compute_error # Always assumed True
bar_prob = constants
return state.update(
iteration, bar_prob, self.linear_solver, self.store_inner_errors
iteration,
bar_prob,
self.linear_solver,
self.store_inner_errors,
tau_a=self.tau_a,
tau_b=self.tau_b,
learn_a=self.learn_a,
)

state = fixed_point_loop.fixpoint_iter(
Expand Down
91 changes: 91 additions & 0 deletions tests/solvers/linear/continuous_barycenter_mass_learning_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.problems.linear import barycenter_problem
from ott.solvers.linear import continuous_barycenter, sinkhorn


def _toy_two_atom_problem():
"""Two identical measures on [-1, +1] with skewed weights [0.9, 0.1]."""
y = jnp.array([
[[-1.0], [1.0]],
[[-1.0], [1.0]],
])
b = jnp.array([
[0.9, 0.1],
[0.9, 0.1],
])
return barycenter_problem.FreeBarycenterProblem(y=y, b=b)


class TestLearnBarycenterWeights:

@pytest.mark.fast()
def test_default_keeps_uniform_a(self, rng: jax.Array):
"""Without learn_a the barycenter weights stay uniform."""
bar_prob = _toy_two_atom_problem()
solver = continuous_barycenter.FreeWassersteinBarycenter(
sinkhorn.Sinkhorn(),
max_iterations=8,
)
out = solver(bar_prob, bar_size=2, rng=rng)
np.testing.assert_allclose(out.a, jnp.array([0.5, 0.5]), atol=1e-6)

@pytest.mark.fast()
def test_learn_a_balanced_stays_uniform(self, rng: jax.Array):
"""learn_a=True with tau_a=1 (balanced) cannot change weights."""
bar_prob = _toy_two_atom_problem()
solver = continuous_barycenter.FreeWassersteinBarycenter(
sinkhorn.Sinkhorn(),
learn_a=True,
tau_a=1.0,
max_iterations=8,
)
out = solver(bar_prob, bar_size=2, rng=rng)
np.testing.assert_allclose(out.a, jnp.array([0.5, 0.5]), atol=1e-6)

@pytest.mark.fast()
def test_learn_a_unbalanced_recovers_skewed_weights(self, rng: jax.Array):
"""Unbalanced + learn_a recovers the skewed input weights."""
bar_prob = _toy_two_atom_problem()
solver = continuous_barycenter.FreeWassersteinBarycenter(
sinkhorn.Sinkhorn(),
learn_a=True,
tau_a=0.9,
max_iterations=25,
)
out = solver(bar_prob, bar_size=2, rng=rng)

got = jnp.sort(out.a)[::-1]
target = jnp.array([0.9, 0.1])

np.testing.assert_allclose(got, target, atol=0.15)
assert jnp.isfinite(out.a).all()
assert jnp.all(out.a > 0.0)
np.testing.assert_allclose(jnp.sum(out.a), 1.0, atol=1e-6)
# Weights should have moved meaningfully away from uniform.
assert jnp.max(jnp.abs(out.a - 0.5)) > 1e-2

@pytest.mark.fast()
def test_invalid_tau_a(self):
with pytest.raises(ValueError, match="tau_a"):
continuous_barycenter.FreeWassersteinBarycenter(
sinkhorn.Sinkhorn(),
tau_a=0.0,
)
Loading