[Feature] Unbalanced free-support Wasserstein barycenter with learnable weights #676

Open
huguesva wants to merge 2 commits into ott-jax:main from huguesva:feature/learn-barycenter-weights-v2

Conversation

@huguesva

@huguesva huguesva commented Feb 23, 2026

Summary

  • Add tau_a and tau_b parameters to FreeWassersteinBarycenter for unbalanced transport
  • Add learn_a option to jointly learn barycenter weights via block-coordinate descent. Relaxing the barycenter-side marginal (tau_a < 1) yields a closed-form weight update: the normalised arithmetic mean of transport-plan row marginals, derived as the exact block-coordinate minimiser via Danskin's theorem for the KL(r_i || a) penalty. No gradient computation or inner optimisation loop is needed.
  • Add a_init for custom initial barycenter weights (resolves existing TODO)
  • In unbalanced mode, the location update weights each measure by lambda_i * r_{ik} (row marginal) instead of lambda_i alone, down-weighting unreliable projections
  • Warning emitted when learn_a=True with tau_a=1.0 (no effect in balanced mode)
  • Fully backward compatible: all defaults produce identical behaviour to main
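
The unbalanced location update in the summary can be sketched as follows. This is a minimal NumPy illustration of one reading of that bullet, not the PR's actual code: it assumes a squared-Euclidean cost (so that each measure's projection is a plan-weighted mean), equal-sized input measures, and a hypothetical helper name `update_locations`.

```python
import numpy as np


def update_locations(plans, points, lambdas):
    """Location update weighted by lambda_i * r_{ik} (hypothetical helper).

    plans:   (M, K, N) transport plans T_i between K barycenter points
             and the N points of each of the M input measures.
    points:  (M, N, d) support points y_i of the input measures.
    lambdas: (M,) barycentric weights of the input measures.
    """
    # Row marginals r_{ik} = sum_j T_{ikj}, shape (M, K).
    row_marginals = plans.sum(axis=2)
    # Per-measure, per-point weights lambda_i * r_{ik}, shape (M, K).
    weights = lambdas[:, None] * row_marginals
    # Unnormalised projections sum_j T_{ikj} y_{ij}, shape (M, K, d).
    transported = np.einsum("mkn,mnd->mkd", plans, points)
    # Weighting each projection (transported / r) by lambda_i * r_{ik}
    # cancels r_{ik} in the numerator.
    numer = np.einsum("m,mkd->kd", lambdas, transported)
    denom = weights.sum(axis=0)[:, None]
    return numer / denom
```

In the balanced case every `r_{ik}` equals `a_k`, so the weights reduce to `lambda_i` (up to a common factor) and the sketch falls back to the usual average of barycentric projections.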

Test plan

  • 4 new tests in continuous_barycenter_mass_learning_test.py
  • All 20 existing continuous_barycenter_test.py tests pass (no regression)
  • Pre-commit passes

Allow joint optimisation of support locations and weights in
FreeWassersteinBarycenter via block-coordinate descent with
unbalanced OT.  New parameters: tau_a (marginal relaxation),
learn_a (enable weight learning), a_init (custom initial weights).
@huguesva huguesva closed this Feb 23, 2026
- Add tau_b parameter for relaxing input-measure marginal constraints
- Warn when learn_a=True with tau_a=1.0 (no effect in balanced mode)
- Thread tau_b through update(), body_fn, init_state, tree_flatten
@huguesva huguesva changed the title from "[Feature] Add learnable barycenter weights to free-support solver" to "[Feature] Unbalanced free-support Wasserstein barycenter with learnable weights" Feb 23, 2026
@huguesva huguesva reopened this Feb 23, 2026
@marcocuturi
Contributor

marcocuturi commented Mar 16, 2026

thanks a lot @huguesva and apologies for the delayed reply, i was traveling and on holidays.

i am trying to follow the proposal, @michalk8 will have some comments related to code. do you have a paper that describes this approach? In our old paper (https://proceedings.mlr.press/v32/cuturi14.html) we were doing this using alternate optimization (a_new was updated, and then x_new, the idea being that new point location would yield a different histogram, and that a new histogram would lead to a different OT computation to update locations) but you are indicating that this is simpler for the unbalanced case. Do you have a reference for the "closed-form weight update" you mention?

Thanks!

@codecov

codecov Bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 73.80952% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 87.27%. Comparing base (7ecebc9) to head (13c0f48).
⚠️ Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/ott/solvers/linear/continuous_barycenter.py | 73.80% | 9 Missing and 2 partials ⚠️ |

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #676      +/-   ##
==========================================
- Coverage   87.35%   87.27%   -0.08%     
==========================================
  Files          82       82              
  Lines        8476     8514      +38     
  Branches      581      589       +8     
==========================================
+ Hits         7404     7431      +27     
- Misses        922      931       +9     
- Partials      150      152       +2     

| Files with missing lines | Coverage Δ |
|---|---|
| src/ott/solvers/linear/continuous_barycenter.py | 88.57% <73.80%> (-6.53%) ⬇️ |


@huguesva
Author

Thanks @marcocuturi and @michalk8 for the answer and for building this great package!

Yes it's actually not super sophisticated: the idea is still to do block coordinate descent on a_new, x_new, and OT, but simply noting that in the unbalanced setting a_new also has a closed-form update.

If we write the joint objective as:

$$F(x, a) = \sum_{i=1}^{M} \lambda_i \min_{T_i \geq 0} \Big[ \langle C_i, T_i \rangle + \varepsilon \mathrm{KL}(T_i | a \otimes b_i) + \rho_a \mathrm{KL}(T_i\mathbf{1} | a) + \rho_b \mathrm{KL}(T_i^\top\mathbf{1} | b_i) \Big]$$

Unlike the classical balanced OT barycenter, here the inner feasible set $\lbrace T_i \geq 0 \rbrace$ does not depend on $a$, so by Danskin's theorem the gradient of $F$ with respect to $a_k$ is the partial derivative of the UOT objective at the current optimum $T_i^\star$. Only $\varepsilon \mathrm{KL}(T_i | a \otimes b_i)$ and $\rho_a \mathrm{KL}(T_i\mathbf{1} | a)$ depend on $a$, and both yield $\partial/\partial a_k = -r_{i,k}/a_k + 1$ where $r_{i,k} = \sum_j T_{i,k,j}$ (the first term uses $\sum_j b_{i,j} = 1$). Imposing stationarity in $a_k$ subject to $\sum_k a_k = 1$, with a Lagrange multiplier for the constraint, gives the closed form:

$$a_k^{\text{new}} = \frac{\sum_i \lambda_i r_{i,k}}{\sum_{\ell}\sum_i \lambda_i r_{i,\ell}}$$

i.e. the normalised weighted arithmetic mean of the transport-plan row marginals. I am not sure that this has been stated explicitly in a paper, at least I did not find it, but it is a straightforward application of Danskin.
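
The closed form above is short enough to check numerically. The sketch below is a NumPy illustration under stated assumptions, not the PR's implementation: the names `update_weights`, `gen_kl` and `a_dependent_objective` are hypothetical, KL is taken to be the generalised divergence $\sum p \log(p/q) - p + q$, and each $b_i$ is assumed to sum to one.

```python
import numpy as np


def update_weights(plans, lambdas):
    """Closed-form weight update (hypothetical helper).

    plans:   (M, K, N) current optimal plans T_i.
    lambdas: (M,) barycentric weights of the input measures.
    """
    row_marginals = plans.sum(axis=2)          # r_{ik}, shape (M, K)
    mean_marginal = lambdas @ row_marginals    # sum_i lambda_i r_{ik}
    return mean_marginal / mean_marginal.sum()


def gen_kl(p, q):
    """Generalised KL divergence between positive arrays of equal shape."""
    return np.sum(p * np.log(p / q) - p + q)


def a_dependent_objective(a, plans, b, lambdas, eps, rho_a):
    """Terms of the joint objective that depend on a, with plans held fixed."""
    total = 0.0
    for lam, T, b_i in zip(lambdas, plans, b):
        total += lam * (
            eps * gen_kl(T, np.outer(a, b_i))
            + rho_a * gen_kl(T.sum(axis=1), a)
        )
    return total
```

With the plans held fixed, evaluating `a_dependent_objective` at `update_weights(plans, lambdas)` and at other points of the simplex gives a quick sanity check that the closed form is the block minimiser.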

If you are interested, I would be happy to open a follow-up PR covering the balanced OT barycenter case, implementing projected gradient descent or mirror descent on the simplex to optimise the barycenter weights, as described in your original paper with Arnaud Doucet.

Thanks a lot!

@marcocuturi
Contributor

Thanks @huguesva !

I have a few questions:

  • in the balanced case, it's also possible to get access to the derivative of a transport cost w.r.t. weights a if we want to do a gradient update. Of course, in the balanced case, if we optimize exactly for a, we would get something that looks more like the distribution resulting from the application of a nearest neighbor rule (what k-means implements in the M maximization step of EM).
  • it feels your solution is leveraging one OT compute to do two updates: both x and a. This differs quite significantly from a standard "EM" approach, in which one OT compute would move x, one OT compute would move a. Have you tried sticking more closely to EM by reusing a fresh OT compute to change weights? As you mentioned, this is what we were advocating in the 2014 paper.
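
The two schedules discussed here (one OT solve shared by both updates, versus a fresh solve before the weight update) can be contrasted in a small NumPy sketch. Everything in it is illustrative and assumed rather than taken from the PR: the scaling-form unbalanced Sinkhorn, the mapping `tau = rho / (rho + eps)`, the squared-Euclidean cost, equal-sized input measures, and the names `unbalanced_sinkhorn`, `sq_dists` and `bcd_step`.

```python
import numpy as np


def unbalanced_sinkhorn(C, a, b, eps, tau_a, tau_b, n_iter=200):
    """Plan T = diag(a*u) K diag(b*v) for the KL-relaxed entropic problem.

    Assumes tau = rho / (rho + eps); tau = 1 enforces that marginal exactly.
    """
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = (1.0 / (K @ (b * v))) ** tau_a
        v = (1.0 / (K.T @ (a * u))) ** tau_b
    return (a * u)[:, None] * K * (b * v)[None, :]


def sq_dists(x, y):
    """Pairwise squared Euclidean distances, shape (len(x), len(y))."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def bcd_step(x, a, ys, bs, lambdas, eps, tau_a, tau_b, resolve=False):
    """One outer iteration: OT solve, location update, weight update.

    resolve=False reuses the same plans for both updates;
    resolve=True re-solves OT at the new locations before updating a.
    """
    def solve(x_cur):
        return np.stack([
            unbalanced_sinkhorn(sq_dists(x_cur, y), a, b, eps, tau_a, tau_b)
            for y, b in zip(ys, bs)
        ])

    plans = solve(x)
    # Location update: plan-weighted mean, normalised by sum_i lambda_i r_{ik}.
    mass = lambdas @ plans.sum(axis=2)
    x_new = np.einsum("m,mkn,mnd->kd", lambdas, plans, ys) / mass[:, None]
    if resolve:
        plans = solve(x_new)
        mass = lambdas @ plans.sum(axis=2)
    # Weight update: normalised weighted mean of the row marginals.
    a_new = mass / mass.sum()
    return x_new, a_new
```

Calling `bcd_step` in a loop with `resolve=False` and with `resolve=True` on the same data is one way to compare final costs and weights between the two variants, at the price of one extra OT solve per iteration in the second case.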

@huguesva
Author

huguesva commented Apr 7, 2026

Thanks @marcocuturi for the feedback and questions!

  • On the k-means connection: Yes! And more generally, relaxing the barycenter marginal constraint ($\tau_a < 1$) interpolates between the balanced entropic OT barycenter ($\tau_a = 1$) and multi-view soft k-means ($\tau_a \to 0$, $\varepsilon > 0$), where the OT solve reduces to a column-wise softmax. The current PR covers $\tau_a \in (0, 1]$. I can also add the $\tau_a = 0$ limit (softmax assignment, no Sinkhorn) in this PR if you think it's worth including.

  • On re-solving OT before the $a$ update: Interesting! The current approach is 3-block BCD on the joint objective. In the unbalanced case the feasible sets are independent, so standard BCD results apply (monotone decrease, convergence to a stationary point). Using fresher OT plans for the $a$ update could potentially improve on this. We will benchmark both variants and keep you posted. Thanks!
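
The $\tau_a \to 0$ limit mentioned in the first point can be written without any Sinkhorn iterations. The NumPy sketch below is an assumption-laden illustration, not code from the PR: it takes the input-side marginal to be enforced exactly ($\tau_b = 1$), keeps the $\log a$ prior that comes from the $\varepsilon \mathrm{KL}(T | a \otimes b)$ term, and uses a hypothetical function name `soft_assignment`.

```python
import numpy as np


def soft_assignment(C, a, b, eps):
    """Plan in the tau_a -> 0 limit: a softmax over barycenter points per column.

    C: (K, N) cost between K barycenter points and N input points.
    a: (K,) positive barycenter weights, acting as a prior on assignments.
    b: (N,) weights of the input measure.
    """
    logits = np.log(a)[:, None] - C / eps
    # Subtract the column-wise maximum for numerical stability.
    logits -= logits.max(axis=0, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=0, keepdims=True)
    # Each input point j spreads its mass b_j across the K barycenter points.
    return probs * b[None, :]
```

With uniform `a` the prior drops out and each column is a plain softmax of `-C / eps`, which is the soft k-means responsibility; as `eps` shrinks it approaches a nearest-neighbour assignment.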

@lf2684

lf2684 commented Apr 16, 2026

Thanks so much @marcocuturi for the very helpful feedback!
We have tested the EM-like approach (performing a second OT solve after the support update), and in our preliminary experiments on biological datasets the results agree very well with the original implementation. Final OT costs differ only marginally, and barycenter weights are highly correlated. We’re happy to include the EM-style update as an optional setting if you think that would be useful!
[Image: Screenshot 2026-04-16 at 10 58 25 AM]


3 participants