From b75c879a913f3828b2292109adae26a80737e9f6 Mon Sep 17 00:00:00 2001 From: huguesva Date: Mon, 23 Feb 2026 21:30:37 +0000 Subject: [PATCH 1/2] [Feature] Add learnable barycenter weights to free-support solver Allow joint optimisation of support locations and weights in FreeWassersteinBarycenter via block-coordinate descent with unbalanced OT. New parameters: tau_a (marginal relaxation), learn_a (enable weight learning), a_init (custom initial weights). --- .../solvers/linear/continuous_barycenter.py | 160 +++++++++++++++--- ...ontinuous_barycenter_mass_learning_test.py | 91 ++++++++++ 2 files changed, 225 insertions(+), 26 deletions(-) create mode 100644 tests/solvers/linear/continuous_barycenter_mass_learning_test.py diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 84dc46cc1..595d8c4d0 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -58,8 +58,14 @@ def set(self, **kwargs: Any) -> "FreeBarycenterState": return self._replace(**kwargs) def update( - self, iteration: int, bar_prob: barycenter_problem.FreeBarycenterProblem, - linear_solver: Any, store_errors: bool + self, + iteration: int, + bar_prob: barycenter_problem.FreeBarycenterProblem, + linear_solver: Any, + store_errors: bool, + *, + tau_a: float = 1.0, + learn_a: bool = False, ) -> "FreeBarycenterState": """Update the state of the solver. @@ -68,6 +74,14 @@ def update( bar_prob: the barycenter problem. linear_solver: the linear OT solver to use. store_errors: whether to store the errors of the inner loop. + tau_a: Relaxation parameter for the barycenter marginal constraint. + Must lie in ``(0, 1]``. When ``tau_a < 1``, unbalanced transport + is used on the barycenter side, allowing the transport row marginals + to differ from the current barycenter weights ``a``. + learn_a: If ``True``, update barycenter weights at every iteration + using the weighted arithmetic mean of transport-plan row marginals. + Requires ``tau_a < 1`` to be effective, since balanced transport + forces row marginals to equal ``a``. Returns: The updated state. @@ -81,9 +95,14 @@ def solve_linear_ot( geom = pointcloud.PointCloud( x, y, cost_fn=bar_prob.cost_fn, epsilon=bar_prob.epsilon ) - prob = linear_problem.LinearProblem(geom, a=a, b=b) + # Preserve exact old behaviour when tau_a == 1.0. + if tau_a < 1.0: + prob = linear_problem.LinearProblem( + geom, a=a, b=b, tau_a=tau_a, tau_b=1.0 + ) + else: + prob = linear_problem.LinearProblem(geom, a=a, b=b) out = linear_solver(prob) - # instantiate matrix since it is a property of out. return out, out.matrix outs, matrices = solve_linear_ot(self.a, self.x, seg_b, seg_y) @@ -102,17 +121,43 @@ def solve_linear_ot( errors = None # Approximation of barycenter as barycenter of barycenters per measure. - barycenters_per_measure = mu.barycentric_projection( matrices, seg_y, bar_prob.cost_fn ) - - x_new = jax.vmap( - lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1] - )(bar_prob.weights, barycenters_per_measure) + row_marginals = jnp.sum(matrices, axis=2) # [num_measures, bar_size] + + if tau_a < 1.0: + # In unbalanced mode each atom k aggregates measure i with weight + # proportional to lambda_i * r_{ik} (measure weight times the row + # marginal of the plan at that atom). + atom_weights = bar_prob.weights[:, None] * row_marginals + normalizer = jnp.sum(atom_weights, axis=0, keepdims=True) + atom_weights = jnp.where( + normalizer > 0.0, + atom_weights / normalizer, + bar_prob.weights[:, None], + ) + x_new = jax.vmap( + lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[1, 1] + )(atom_weights, barycenters_per_measure) + else: + x_new = jax.vmap( + lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1] + )(bar_prob.weights, barycenters_per_measure) + + # Optionally learn barycenter weights from transport-plan row marginals. + # The weighted arithmetic mean is the exact block-coordinate minimiser + # for the joint objective with KL(r_i || a) penalty. + if learn_a: + a_new = jnp.sum(row_marginals * bar_prob.weights[:, None], axis=0) + a_new = jnp.clip(a_new, min=1e-12) + a_new = a_new / jnp.sum(a_new) + else: + a_new = self.a return self.set( x=x_new, + a=a_new, costs=updated_costs, linear_convergence=linear_convergence, linear_outputs=outs, @@ -181,7 +226,55 @@ def linear_output_at_index(self, i: int) -> LinearOutput: @jax.tree_util.register_pytree_node_class class FreeWassersteinBarycenter(was_solver.WassersteinSolver): - """Continuous Wasserstein barycenter solver :cite:`cuturi:14`.""" + """Continuous Wasserstein barycenter solver :cite:`cuturi:14`. + + Args: + linear_solver: Inner linear OT solver. + tau_a: Relaxation parameter for the barycenter marginal constraint. + Must lie in ``(0, 1]``. When ``tau_a < 1``, unbalanced transport is + used on the barycenter side, so that the transport row marginals can + differ from ``a``. Default ``1.0`` (balanced). + learn_a: If ``True``, update barycenter weights at every outer + iteration using the weighted arithmetic mean of transport-plan row + marginals. Requires ``tau_a < 1`` to be effective, since balanced + transport forces row marginals to equal ``a``. + a_init: Optional initial barycenter weights of shape ``(bar_size,)``. + If ``None``, weights are initialised uniformly. + kwargs: Forwarded to + :class:`~ott.solvers.was_solver.WassersteinSolver`. + """ + + def __init__( + self, + linear_solver, + *, + tau_a: float = 1.0, + learn_a: bool = False, + a_init: Optional[jnp.ndarray] = None, + **kwargs, + ): + super().__init__(linear_solver=linear_solver, **kwargs) + if not (0.0 < tau_a <= 1.0): + raise ValueError("tau_a must be in (0, 1].") + self.tau_a = tau_a + self.learn_a = learn_a + self.a_init = a_init + + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 + return ([self.linear_solver, self.threshold, self.a_init], { + "min_iterations": self.min_iterations, + "max_iterations": self.max_iterations, + "store_inner_errors": self.store_inner_errors, + "tau_a": self.tau_a, + "learn_a": self.learn_a, + }) + + @classmethod + def tree_unflatten( # noqa: D102 + cls, aux_data: Dict[str, Any], children: Sequence[Any] + ) -> "FreeWassersteinBarycenter": + linear_solver, threshold, a_init = children + return cls(linear_solver, threshold=threshold, a_init=a_init, **aux_data) def __call__( # noqa: D102 self, @@ -189,7 +282,7 @@ def __call__( # noqa: D102 bar_size: int = 100, x_init: Optional[jnp.ndarray] = None, rng: Optional[jax.Array] = None, - ) -> FreeBarycenterState: + ) -> FreeBarycenterOutput: rng = utils.default_prng_key(rng) return self.iterations(bar_size, bar_prob, x_init, rng) @@ -218,7 +311,6 @@ def init_state( assert bar_size == x_init.shape[0] x = x_init else: - # sample randomly points in the support of the y measures rng = utils.default_prng_key(rng) indices_subset = jax.random.choice( rng, @@ -229,8 +321,15 @@ def init_state( ) x = bar_prob.flattened_y[indices_subset, :] - # TODO(cuturi) expand to non-uniform weights for barycenter. - a = jnp.ones((bar_size,)) / bar_size + if self.a_init is not None: + a = jnp.asarray(self.a_init) + if a.shape != (bar_size,): + raise ValueError("a_init must have shape (bar_size,).") + a = jnp.clip(a, min=0.0) + a = a / jnp.sum(a) + else: + a = jnp.ones((bar_size,)) / bar_size + num_iter = self.max_iterations if self.store_inner_errors: errors = -jnp.ones(( @@ -246,8 +345,12 @@ def init_state( errors=errors ) abstract_tree = jax.eval_shape( - functools.partial(state.update, store_errors=self.store_inner_errors), - 0, bar_prob, self.linear_solver + functools.partial( + state.update, + store_errors=self.store_inner_errors, + tau_a=self.tau_a, + learn_a=self.learn_a, + ), 0, bar_prob, self.linear_solver ) linear_outputs = jax.tree.map(jnp.zeros_like, abstract_tree.linear_outputs) @@ -278,28 +381,33 @@ def output_from_state( # noqa: D102 ) def iterations( - self, bar_size: int, bar_prob: barycenter_problem.FreeBarycenterProblem, - x_init: jnp.ndarray, rng: jax.Array - ) -> FreeBarycenterState: + self, + bar_size: int, + bar_prob: barycenter_problem.FreeBarycenterProblem, + x_init: jnp.ndarray, + rng: jax.Array, + ) -> FreeBarycenterOutput: """Wasserstein barycenter outer loop.""" def cond_fn( - iteration: int, - constants: Tuple[FreeWassersteinBarycenter, - barycenter_problem.FreeBarycenterProblem], + iteration: int, constants: barycenter_problem.FreeBarycenterProblem, state: FreeBarycenterState ) -> bool: return self._continue(state, iteration) def body_fn( - iteration, constants: Tuple[FreeWassersteinBarycenter, - barycenter_problem.FreeBarycenterProblem], + iteration: int, constants: barycenter_problem.FreeBarycenterProblem, state: FreeBarycenterState, compute_error: bool ) -> FreeBarycenterState: del compute_error # Always assumed True bar_prob = constants return state.update( - iteration, bar_prob, self.linear_solver, self.store_inner_errors + iteration, + bar_prob, + self.linear_solver, + self.store_inner_errors, + tau_a=self.tau_a, + learn_a=self.learn_a, ) state = fixed_point_loop.fixpoint_iter( diff --git a/tests/solvers/linear/continuous_barycenter_mass_learning_test.py b/tests/solvers/linear/continuous_barycenter_mass_learning_test.py new file mode 100644 index 000000000..9f5b36029 --- /dev/null +++ b/tests/solvers/linear/continuous_barycenter_mass_learning_test.py @@ -0,0 +1,91 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import jax +import jax.numpy as jnp +import numpy as np + +from ott.problems.linear import barycenter_problem +from ott.solvers.linear import continuous_barycenter, sinkhorn + + +def _toy_two_atom_problem(): + """Two identical measures on [-1, +1] with skewed weights [0.9, 0.1].""" + y = jnp.array([ + [[-1.0], [1.0]], + [[-1.0], [1.0]], + ]) + b = jnp.array([ + [0.9, 0.1], + [0.9, 0.1], + ]) + return barycenter_problem.FreeBarycenterProblem(y=y, b=b) + + +class TestLearnBarycenterWeights: + + @pytest.mark.fast() + def test_default_keeps_uniform_a(self, rng: jax.Array): + """Without learn_a the barycenter weights stay uniform.""" + bar_prob = _toy_two_atom_problem() + solver = continuous_barycenter.FreeWassersteinBarycenter( + sinkhorn.Sinkhorn(), + max_iterations=8, + ) + out = solver(bar_prob, bar_size=2, rng=rng) + np.testing.assert_allclose(out.a, jnp.array([0.5, 0.5]), atol=1e-6) + + @pytest.mark.fast() + def test_learn_a_balanced_stays_uniform(self, rng: jax.Array): + """learn_a=True with tau_a=1 (balanced) cannot change weights.""" + bar_prob = _toy_two_atom_problem() + solver = continuous_barycenter.FreeWassersteinBarycenter( + sinkhorn.Sinkhorn(), + learn_a=True, + tau_a=1.0, + max_iterations=8, + ) + out = solver(bar_prob, bar_size=2, rng=rng) + np.testing.assert_allclose(out.a, jnp.array([0.5, 0.5]), atol=1e-6) + + @pytest.mark.fast() + def test_learn_a_unbalanced_recovers_skewed_weights(self, rng: jax.Array): + """Unbalanced + learn_a recovers the skewed input weights.""" + bar_prob = _toy_two_atom_problem() + solver = continuous_barycenter.FreeWassersteinBarycenter( + sinkhorn.Sinkhorn(), + learn_a=True, + tau_a=0.9, + max_iterations=25, + ) + out = solver(bar_prob, bar_size=2, rng=rng) + + got = jnp.sort(out.a)[::-1] + target = jnp.array([0.9, 0.1]) + + np.testing.assert_allclose(got, target, atol=0.15) + assert jnp.isfinite(out.a).all() + assert jnp.all(out.a > 0.0) + np.testing.assert_allclose(jnp.sum(out.a), 1.0, atol=1e-6) + # Weights should have moved meaningfully away from uniform. + assert jnp.max(jnp.abs(out.a - 0.5)) > 1e-2 + + @pytest.mark.fast() + def test_invalid_tau_a(self): + with pytest.raises(ValueError, match="tau_a"): + continuous_barycenter.FreeWassersteinBarycenter( + sinkhorn.Sinkhorn(), + tau_a=0.0, + ) From 13c0f48a7b9fd003af63ec18759b4ff79e58af8d Mon Sep 17 00:00:00 2001 From: huguesva Date: Mon, 23 Feb 2026 23:53:46 +0000 Subject: [PATCH 2/2] Add tau_b support and learn_a warning to free-support barycenter - Add tau_b parameter for relaxing input-measure marginal constraints - Warn when learn_a=True with tau_a=1.0 (no effect in balanced mode) - Thread tau_b through update(), body_fn, init_state, tree_flatten --- .../solvers/linear/continuous_barycenter.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 595d8c4d0..221efe396 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import warnings from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, Union import jax @@ -65,6 +66,7 @@ def update( store_errors: bool, *, tau_a: float = 1.0, + tau_b: float = 1.0, learn_a: bool = False, ) -> "FreeBarycenterState": """Update the state of the solver. @@ -78,6 +80,10 @@ def update( Must lie in ``(0, 1]``. When ``tau_a < 1``, unbalanced transport is used on the barycenter side, allowing the transport row marginals to differ from the current barycenter weights ``a``. + tau_b: Relaxation parameter for the input-measure marginal constraint. + Must lie in ``(0, 1]``. When ``tau_b < 1``, the input measures + do not need to be fully transported, providing robustness to + outliers. learn_a: If ``True``, update barycenter weights at every iteration using the weighted arithmetic mean of transport-plan row marginals. Requires ``tau_a < 1`` to be effective, since balanced transport @@ -95,10 +101,10 @@ def solve_linear_ot( geom = pointcloud.PointCloud( x, y, cost_fn=bar_prob.cost_fn, epsilon=bar_prob.epsilon ) - # Preserve exact old behaviour when tau_a == 1.0. - if tau_a < 1.0: + # Preserve exact old behaviour when both taus are 1.0. + if tau_a < 1.0 or tau_b < 1.0: prob = linear_problem.LinearProblem( - geom, a=a, b=b, tau_a=tau_a, tau_b=1.0 + geom, a=a, b=b, tau_a=tau_a, tau_b=tau_b ) else: prob = linear_problem.LinearProblem(geom, a=a, b=b) @@ -234,6 +240,10 @@ class FreeWassersteinBarycenter(was_solver.WassersteinSolver): Must lie in ``(0, 1]``. When ``tau_a < 1``, unbalanced transport is used on the barycenter side, so that the transport row marginals can differ from ``a``. Default ``1.0`` (balanced). + tau_b: Relaxation parameter for the input-measure marginal constraint. + Must lie in ``(0, 1]``. When ``tau_b < 1``, the input measures do + not need to be fully transported, providing robustness to outliers. + Default ``1.0`` (balanced). learn_a: If ``True``, update barycenter weights at every outer iteration using the weighted arithmetic mean of transport-plan row marginals. Requires ``tau_a < 1`` to be effective, since balanced @@ -249,6 +259,7 @@ def __init__( linear_solver, *, tau_a: float = 1.0, + tau_b: float = 1.0, learn_a: bool = False, a_init: Optional[jnp.ndarray] = None, **kwargs, @@ -256,7 +267,18 @@ def __init__( super().__init__(linear_solver=linear_solver, **kwargs) if not (0.0 < tau_a <= 1.0): raise ValueError("tau_a must be in (0, 1].") + if not (0.0 < tau_b <= 1.0): + raise ValueError("tau_b must be in (0, 1].") + if learn_a and tau_a == 1.0: + warnings.warn( + "`learn_a=True` has no effect when `tau_a=1.0` (balanced). " + "The current implementation learns weights from unbalanced " + "row marginals, which coincide with `a` when balanced. " + "Set `tau_a < 1` to enable weight learning.", + stacklevel=2, + ) self.tau_a = tau_a + self.tau_b = tau_b self.learn_a = learn_a self.a_init = a_init @@ -266,6 +288,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 "max_iterations": self.max_iterations, "store_inner_errors": self.store_inner_errors, "tau_a": self.tau_a, + "tau_b": self.tau_b, "learn_a": self.learn_a, }) @@ -349,6 +372,7 @@ def init_state( state.update, store_errors=self.store_inner_errors, tau_a=self.tau_a, + tau_b=self.tau_b, learn_a=self.learn_a, ), 0, bar_prob, self.linear_solver ) @@ -407,6 +431,7 @@ def body_fn( self.linear_solver, self.store_inner_errors, tau_a=self.tau_a, + tau_b=self.tau_b, learn_a=self.learn_a, )