[Feature] Unbalanced free-support Wasserstein barycenter with learnable weights #676
huguesva wants to merge 2 commits into ott-jax:main
Conversation
Allow joint optimisation of support locations and weights in FreeWassersteinBarycenter via block-coordinate descent with unbalanced OT. New parameters: tau_a (marginal relaxation), learn_a (enable weight learning), a_init (custom initial weights).
- Add `tau_b` parameter for relaxing input-measure marginal constraints
- Warn when `learn_a=True` with `tau_a=1.0` (no effect in balanced mode)
- Thread `tau_b` through `update()`, `body_fn`, `init_state`, `tree_flatten`
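As a small illustration of the warning described above, here is a minimal sketch of the guard; the helper name `check_learn_a` is hypothetical, not the actual ott-jax code:

```python
import warnings

# Hypothetical sketch (not the ott-jax implementation) of the guard:
# learning the weights has no effect in balanced mode (tau_a == 1.0),
# because the barycenter-side marginal constraint then fixes them exactly.
def check_learn_a(learn_a: bool, tau_a: float) -> None:
    if learn_a and tau_a == 1.0:
        warnings.warn(
            "`learn_a=True` has no effect when `tau_a=1.0` (balanced OT): "
            "the barycenter-side marginal constraint fixes the weights.",
            UserWarning,
        )
```

The check is cheap and purely diagnostic, so it can run once at construction time rather than inside the solver loop.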
thanks a lot @huguesva and apologies for the delayed reply, i was traveling and on holidays. i am trying to follow the proposal, @michalk8 will have some comments related to the code. do you have a paper that describes this approach? In our old paper (https://proceedings.mlr.press/v32/cuturi14.html) we were doing this using alternate optimization. Thanks!
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #676 +/- ##
==========================================
- Coverage 87.35% 87.27% -0.08%
==========================================
Files 82 82
Lines 8476 8514 +38
Branches 581 589 +8
==========================================
+ Hits 7404 7431 +27
- Misses 922 931 +9
- Partials 150 152 +2
Thanks @marcocuturi and @michalk8 for the answer and for building this great package! Yes, it's actually not super sophisticated: the idea is still to do block-coordinate descent, alternating between the support locations and the barycenter weights $a$.

If we write the joint objective as

$$\min_{a \in \Delta_n,\; x} \; \sum_i \lambda_i \, \mathrm{OT}_{\tau_a}(a, b_i),$$

where $\mathrm{OT}_{\tau_a}$ relaxes the barycenter-side marginal constraint through a KL penalty, then, unlike the classical balanced OT barycenter where $a$ enters as a hard marginal constraint, the inner minimisation over $a$ admits a closed-form solution:

$$a \;\propto\; \sum_i \lambda_i \, r_i, \qquad r_i = P_i \mathbf{1},$$

i.e. the normalised weighted arithmetic mean of the transport-plan row marginals. I am not sure that this has been stated explicitly in a paper, at least I did not find it, but it is a straightforward application of Danskin's theorem. If you are interested, I would be happy to open a follow-up PR covering the balanced OT barycenter case, implementing projected gradient descent or mirror descent on the simplex to optimise the barycenter weights, as described in your original paper with Arnaud Doucet. Thanks a lot!
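To sanity-check the closed form numerically, here is a small NumPy sketch (illustrative only, not the ott-jax code, with random stand-ins for the row marginals): the weighted mean of the row marginals zeroes the gradient of the weight-dependent block $\sum_i \lambda_i\,\mathrm{KL}(r_i \,\|\, a)$, and random perturbations never improve it.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n = 4, 6                        # number of input measures, barycenter size
lam = rng.random(k)
lam /= lam.sum()                   # measure weights on the simplex
r = rng.random((k, n))             # stand-ins for transport-plan row marginals

def kl(p, q):
    # unnormalised KL divergence: sum p*log(p/q) - p + q
    return np.sum(p * np.log(p / q) - p + q)

def obj(a):
    # the block of the barycenter objective that depends on the weights a
    return sum(l * kl(ri, a) for l, ri in zip(lam, r))

# claimed closed-form block-coordinate minimiser: the weighted
# arithmetic mean of the row marginals (lam already sums to 1)
a_star = lam @ r

# first-order condition holds: d obj / d a_k = -(lam @ r)_k / a_k + 1 = 0
grad = -(lam @ r) / a_star + 1.0
assert np.allclose(grad, 0.0)

# multiplicative perturbations (which preserve positivity) never help
for _ in range(100):
    a_pert = a_star * np.exp(1e-2 * rng.standard_normal(n))
    assert obj(a_star) <= obj(a_pert)
```

Since each coordinate term $-c \log a_k + a_k$ is convex, the stationary point is the global minimiser over the positive orthant, which is why no inner optimisation loop is needed.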
Thanks @huguesva! I have a few questions:
Thanks @marcocuturi for the feedback and questions!
Thanks so much @marcocuturi for the very helpful feedback!

Summary
- Add `tau_a` and `tau_b` parameters to `FreeWassersteinBarycenter` for unbalanced transport
- Add `learn_a` option to jointly learn barycenter weights via block-coordinate descent. Relaxing the barycenter-side marginal (`tau_a < 1`) yields a closed-form weight update: the normalised arithmetic mean of transport-plan row marginals, derived as the exact block-coordinate minimiser via Danskin's theorem for the `KL(r_i || a)` penalty. No gradient computation or inner optimisation loop is needed.
- Add `a_init` for custom initial barycenter weights (resolves an existing TODO)
- Weight each point in the support update by `lambda_i * r_{ik}` (row marginal) instead of `lambda_i` alone, down-weighting unreliable projections
- Warn when `learn_a=True` with `tau_a=1.0` (no effect in balanced mode)
Test plan
- New tests in `continuous_barycenter_mass_learning_test.py`
- Existing `continuous_barycenter_test.py` tests pass (no regression)
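The support-update weighting from the summary can be sketched as follows; shapes and names are illustrative assumptions, not the ott-jax implementation. Each barycentric projection is averaged with weight `lam[i] * r[i, k]` (the transport-plan row marginal) rather than `lam[i]` alone, so rows that received little mass contribute less to the new support.

```python
import numpy as np

rng = np.random.default_rng(1)
k, n, m, d = 3, 4, 7, 2          # measures, barycenter points, input points, dim
lam = np.full(k, 1.0 / k)        # measure weights
P = rng.random((k, n, m))        # transport plans, barycenter rows -> input cols
y = rng.random((k, m, d))        # input support locations

r = P.sum(axis=-1)               # row marginals r[i, k']

# barycentric projection of each plan: mass-weighted mean input location
# per barycenter row, normalised by that row's marginal
proj = np.einsum("inm,imd->ind", P, y) / r[..., None]

# weighted support update: weights lam[i] * r[i, :] instead of lam[i]
w = lam[:, None] * r             # (k, n)
x_new = np.einsum("in,ind->nd", w, proj) / w.sum(axis=0)[:, None]
```

Since each projected point is a convex combination of input locations, and the update is a convex combination of projections, the new support stays inside the convex hull of the inputs.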