From 40f3115dfebab54d603d07dbc95aed5d932ca7f1 Mon Sep 17 00:00:00 2001 From: DhruvaRajwade Date: Wed, 11 Mar 2026 10:23:08 +0100 Subject: [PATCH 1/6] fix(cmonge_gap): replace jnp.unique loop with segment interface for JIT compatibility The original cmonge_gap_from_samples used a Python for-loop over jnp.unique(condition), which breaks JAX JIT compilation since jnp.unique returns a dynamically-sized array. Replace with _segment_interface which pads per-condition point clouds to a fixed max_measure_size and vmaps the per-segment Monge gap computation. This makes the function fully JIT-compatible. The eval_fn computes per-segment: displacement_cost - ent_reg_cost, matching the definition in monge_gap_from_samples. Padded entries have zero weight and do not affect the result. New parameters num_segments and max_measure_size are required for JIT (consistent with segment_sinkhorn API). Cost function parameters (cost_fn, epsilon, relative_epsilon, scale_cost) are now explicit rather than passed through **kwargs. --- src/ott/neural/methods/__init__.py | 2 +- .../neural/methods/conditional_monge_gap.py | 135 ++++++++++++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 src/ott/neural/methods/conditional_monge_gap.py diff --git a/src/ott/neural/methods/__init__.py b/src/ott/neural/methods/__init__.py index ea3e51b31..b3883052a 100644 --- a/src/ott/neural/methods/__init__.py +++ b/src/ott/neural/methods/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import expectile_neural_dual, flow_matching, monge_gap, neuraldual +from . import conditional_monge_gap, expectile_neural_dual, flow_matching, monge_gap, neuraldual diff --git a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py new file mode 100644 index 000000000..e039f63c2 --- /dev/null +++ b/src/ott/neural/methods/conditional_monge_gap.py @@ -0,0 +1,135 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import ( + Any, + Literal, + Optional, + Tuple, + Union, +) + +import jax +import jax.numpy as jnp + +from ott.geometry import costs, pointcloud, segment +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + +__all__ = ["cmonge_gap_from_samples"] + + +def cmonge_gap_from_samples( + source: jnp.ndarray, + target: jnp.ndarray, + condition: jnp.ndarray, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, + scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0, + return_output: bool = False, + num_segments: Optional[int] = None, + max_measure_size: Optional[int] = None, + **kwargs: Any, +) -> Union[float, Tuple[float, jnp.ndarray]]: + r"""Conditional Monge gap from samples using the segment interface. 
+ + Computes the average Monge gap across conditions: + + .. math:: + + \frac{1}{K} \sum_{k=1}^{K} \left[ + \frac{1}{n_k} \sum_{i:\, c_i = k} c(x_i, y_i) - + W_{c, \varepsilon}\!\bigl(\hat{\rho}_{n_k}^{(k)},\, + \hat{\nu}_{n_k}^{(k)}\bigr) \right] + + where :math:`W_{c, \varepsilon}` is the + :term:`entropy-regularized optimal transport` cost. + + This implementation uses :func:`~ott.geometry.segment._segment_interface` + to pad and ``vmap`` across conditions, making it fully JIT-compatible. + + Args: + source: samples from first measure, array of shape ``[n, d]``. + target: samples from second measure, array of shape ``[n, d]``. + Assumed paired with ``source``, i.e. ``target[i] = T(source[i])``. + condition: integer array of shape ``[n]`` indicating the condition + for each source-target pair. Values in ``range(num_segments)``. + cost_fn: a cost function between two points in dimension :math:`d`. + If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. + epsilon: regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud`. + relative_epsilon: when set, ``epsilon`` refers to a fraction of the + :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`. + scale_cost: option to rescale the cost matrix. Implemented scalings + are ``'median'``, ``'mean'`` and ``'max_cost'``. Alternatively, a + float factor can be given to rescale the cost such that + ``cost_matrix /= scale_cost``. + return_output: if :obj:`True`, also return per-condition Monge gaps. + num_segments: number of distinct conditions. Required for JIT. + max_measure_size: maximum number of points in any single condition + (used for padding). Required for JIT. + kwargs: keyword arguments for the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. + + Returns: + The average Monge gap across conditions and, when ``return_output`` + is :obj:`True`, a ``[num_segments]`` array of per-condition gaps. + """ + cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn + dim = source.shape[1] + padding_vector = cost_fn._padder(dim=dim) + + def eval_fn( + padded_x: jnp.ndarray, + padded_y: jnp.ndarray, + padded_weight_x: jnp.ndarray, + padded_weight_y: jnp.ndarray, + ) -> jnp.ndarray: + """Monge gap for a single (padded) condition segment.""" + # Displacement cost: weighted mean of pairwise costs c(x_i, T(x_i)). + # Padded entries have weight 0, so they do not contribute. + pairwise_costs = jax.vmap(cost_fn)(padded_x, padded_y) + displacement_cost = jnp.sum(pairwise_costs * padded_weight_x) + + # Entropy-regularized OT cost W_{c,ε}. 
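+ # Padded rows carry zero mass in this solve, but they still enter the
+ # cost matrix, so statistics-based options (relative_epsilon,
+ # scale_cost='mean'/'max_cost'/'median') can differ slightly from an
+ # unpadded per-condition monge_gap_from_samples call.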
+ geom = pointcloud.PointCloud( + padded_x, + padded_y, + cost_fn=cost_fn, + epsilon=epsilon, + relative_epsilon=relative_epsilon, + scale_cost=scale_cost, + ) + prob = linear_problem.LinearProblem( + geom, a=padded_weight_x, b=padded_weight_y + ) + solver = sinkhorn.Sinkhorn(**kwargs) + out = solver(prob) + + return displacement_cost - out.ent_reg_cost + + per_condition_gaps = segment._segment_interface( + x=source, + y=target, + eval_fn=eval_fn, + num_segments=num_segments, + max_measure_size=max_measure_size, + segment_ids_x=condition, + segment_ids_y=condition, + indices_are_sorted=False, + padding_vector=padding_vector, + ) + + avg_gap = jnp.mean(per_condition_gaps) + return (avg_gap, per_condition_gaps) if return_output else avg_gap From a14563c11adcadc0bf5be15dc340b5d87d45b609 Mon Sep 17 00:00:00 2001 From: DhruvaRajwade Date: Wed, 11 Mar 2026 15:37:05 +0100 Subject: [PATCH 2/6] feat: add ConditionalMongeGapEstimator for training condition-aware maps Add the estimator class that mirrors MongeGapEstimator but handles condition-dependent neural maps T(x, c) with per-condition Monge gap regularization via cmonge_gap_from_samples. Changes: - ConditionalMongeGapEstimator in conditional_monge_gap.py: training loop with 3-arg regularizer(source, mapped, labels), 4-iterator batch protocol, JIT-compiled step function - ConditionalDataset + create_conditional_gaussian_mixture_samplers in datasets.py: synchronized 4-iterator data pipeline for testing - Export conditional_perturbation_network from networks/__init__ - 16 tests: 8 unit tests for cmonge_gap_from_samples (non-negativity, JIT consistency, loop baseline match, identity vs random, cost fns, return shape) + 2 integration tests for the estimator (convergence, no-regularizer mode) --- src/ott/datasets.py | 143 ++++++++- .../neural/methods/conditional_monge_gap.py | 296 +++++++++++++++++- src/ott/neural/networks/__init__.py | 9 +- .../conditional_perturbation_network.py | 135 ++++++++ .../methods/conditional_monge_gap_test.py | 280 +++++++++++++++++ 5 files changed, 859 insertions(+), 4 deletions(-) create mode 100644 src/ott/neural/networks/conditional_perturbation_network.py create mode 100644 tests/neural/methods/conditional_monge_gap_test.py diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 206123b84..2e6ee4d42 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -18,7 +18,13 @@ import jax.numpy as jnp import numpy as np -__all__ = ["create_gaussian_mixture_samplers", "Dataset", "GaussianMixture"] +__all__ = [ + "create_gaussian_mixture_samplers", + "create_conditional_gaussian_mixture_samplers", + "ConditionalDataset", + "Dataset", + "GaussianMixture", +] from ott import utils @@ -36,6 +42,21 @@ class Dataset(NamedTuple): target_iter: Iterator[jnp.ndarray] +class ConditionalDataset(NamedTuple): + r"""Samplers from conditional source and target measures. + + Args: + source_iter: loader for the source measure, ``[batch, d]`` + target_iter: loader for the target measure, ``[batch, d]`` + condition_iter: loader for continuous condition vectors, ``[batch, dim_c]`` + label_iter: loader for integer condition labels, ``[batch]`` + """ + source_iter: Iterator[jnp.ndarray] + target_iter: Iterator[jnp.ndarray] + condition_iter: Iterator[jnp.ndarray] + label_iter: Iterator[jnp.ndarray] + + @dataclasses.dataclass class GaussianMixture: """A mixture of Gaussians. 
@@ -144,3 +165,123 @@ def create_gaussian_mixture_samplers( ) dim_data = 2 return train_dataset, valid_dataset, dim_data + + +@dataclasses.dataclass +class ConditionalGaussianMixture: + """Conditional Gaussian sampler for testing. + + For each condition *k*, draws source ~ N(0, I) and target ~ source + offset_k. + Condition vectors are one-hot encoded labels. + + Args: + num_conditions: number of distinct conditions. + batch_size: total batch size (divided equally among conditions). + dim: data dimensionality. + offsets: ``[num_conditions, dim]`` translation per condition. + rng: initial PRNG key. + """ + num_conditions: int + batch_size: int + dim: int + offsets: jnp.ndarray + rng: jax.Array + + def __iter__(self) -> Iterator[Tuple[jnp.ndarray, ...]]: + return self._generate() + + def _generate(self) -> Iterator[Tuple[jnp.ndarray, ...]]: + rng = self.rng + per_cond = self.batch_size // self.num_conditions + while True: + rng, rng_s = jax.random.split(rng) + sources, targets, conds, labels = [], [], [], [] + for k in range(self.num_conditions): + rng_s, rng_k = jax.random.split(rng_s) + s = jax.random.normal(rng_k, (per_cond, self.dim)) + t = s + self.offsets[k] + c = jnp.zeros((per_cond, self.num_conditions)).at[:, k].set(1.0) + lab = jnp.full((per_cond,), k, dtype=jnp.int32) + sources.append(s) + targets.append(t) + conds.append(c) + labels.append(lab) + yield ( + jnp.concatenate(sources), + jnp.concatenate(targets), + jnp.concatenate(conds), + jnp.concatenate(labels), + ) + + +def create_conditional_gaussian_mixture_samplers( + num_conditions: int = 3, + dim: int = 2, + train_batch_size: int = 90, + valid_batch_size: int = 90, + rng: Optional[jax.Array] = None, +) -> Tuple[ConditionalDataset, ConditionalDataset, int, int, int]: + """Create conditional Gaussian samplers for testing. + + Each condition defines a different translation of the source distribution. + + Args: + num_conditions: number of distinct conditions. + dim: data dimensionality. + train_batch_size: training batch size (should be divisible by + ``num_conditions``). + valid_batch_size: validation batch size. + rng: initial PRNG key. + + Returns: + ``(train_dataset, valid_dataset, dim_data, num_conditions, + max_measure_size)`` where ``max_measure_size = + batch_size // num_conditions``. + """ + rng = utils.default_prng_key(rng) + rng1, rng2, rng_off = jax.random.split(rng, 3) + + # Each condition has a different offset (translation) + offsets = jax.random.normal(rng_off, (num_conditions, dim)) * 3.0 + + def _make_dataset( + bs: int, key: jax.Array, + ) -> ConditionalDataset: + sampler = ConditionalGaussianMixture( + num_conditions=num_conditions, + batch_size=bs, + dim=dim, + offsets=offsets, + rng=key, + ) + gen = iter(sampler) + # Cache the current batch so all 4 iterators stay synchronized. + cache = {} + + def _next_batch(): + if "batch" not in cache: + cache["batch"] = next(gen) + return cache + + def _iter(idx: int) -> Iterator[jnp.ndarray]: + while True: + c = _next_batch() + val = c["batch"][idx] + # Mark this index as consumed; when all 4 are consumed, clear cache. 
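+ # NOTE: this assumes every training step advances all four iterators
+ # exactly once (as ConditionalMongeGapEstimator._generate_batch does);
+ # advancing a single iterator repeatedly replays the same cached batch.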
+ c.setdefault("consumed", set()) + c["consumed"].add(idx) + if len(c["consumed"]) == 4: + cache.clear() + yield val + + return ConditionalDataset( + source_iter=_iter(0), + target_iter=_iter(1), + condition_iter=_iter(2), + label_iter=_iter(3), + ) + + train_ds = _make_dataset(train_batch_size, rng1) + valid_ds = _make_dataset(valid_batch_size, rng2) + max_measure_size = train_batch_size // num_conditions + return train_ds, valid_ds, dim, num_conditions, max_measure_size diff --git a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py index e039f63c2..f31629d31 100644 --- a/src/ott/neural/methods/conditional_monge_gap.py +++ b/src/ott/neural/methods/conditional_monge_gap.py @@ -11,10 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections +import functools from typing import ( Any, + Callable, + Dict, + Iterator, Literal, Optional, + Sequence, Tuple, Union, ) @@ -22,11 +28,22 @@ import jax import jax.numpy as jnp +import optax +from flax.core import frozen_dict +from flax.training import train_state + +from ott import utils from ott.geometry import costs, pointcloud, segment +from ott.neural.networks.conditional_perturbation_network import ( + ConditionalPerturbationNetwork, +) from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn -__all__ = ["cmonge_gap_from_samples"] +__all__ = [ + "cmonge_gap_from_samples", + "ConditionalMongeGapEstimator", +] def cmonge_gap_from_samples( @@ -90,6 +107,9 @@ def cmonge_gap_from_samples( dim = source.shape[1] padding_vector = cost_fn._padder(dim=dim) + # NOTE: Eval function takes some logic from: + # ott.neural.methods.monge_gap.monge_gap_from_samples` + # as well as `ott.geometry.segment.py` def eval_fn( padded_x: jnp.ndarray, padded_y: jnp.ndarray, @@ -133,3 +153,277 @@ def eval_fn( avg_gap = jnp.mean(per_condition_gaps) return (avg_gap, per_condition_gaps) if return_output else avg_gap + + +class ConditionalMongeGapEstimator: + r"""Conditional map estimator between probability measures. + + Estimates a condition-dependent map :math:`T(\cdot, c)` by minimizing: + + .. math:: + + \min_\theta \; \Delta\bigl(T_\theta(\cdot, c) \sharp \mu,\, \nu\bigr) + + \lambda \; R_{\text{cond}}\bigl(T_\theta(\cdot, c) \sharp \rho,\, + \rho \mid c\bigr) + + where :math:`\Delta` is a fitting loss (e.g. + :func:`~ott.tools.sinkhorn_divergence.sinkdiv`), + :math:`R_{\text{cond}}` is the conditional Monge gap regularizer + :func:`cmonge_gap_from_samples`, and :math:`c` is a condition label. + + This mirrors :class:`~ott.neural.methods.monge_gap.MongeGapEstimator` + but handles condition-aware maps and per-condition regularization. + + Args: + dim_data: input dimensionality of the data. + model: a :class:`~ott.neural.networks.\ +conditional_perturbation_network.ConditionalPerturbationNetwork` or any + ``nn.Module`` whose ``__call__`` signature is ``(x, c)``. + optimizer: optimizer for the map parameters. + fitting_loss: callable ``(mapped, target) -> (loss, log)`` that + measures how well the pushforward matches the target distribution. + regularizer: callable ``(source, mapped, condition_labels) -> + (loss, log)`` that computes the conditional Monge gap or similar + per-condition regularizer. + regularizer_strength: scalar or schedule for :math:`\lambda`. + num_train_iters: number of training iterations. 
+ logging: whether to record train/eval metrics. + valid_freq: how often to evaluate on the validation set. + rng: random seed. + """ + + def __init__( + self, + dim_data: int, + model: ConditionalPerturbationNetwork, + optimizer: Optional[optax.OptState] = None, + fitting_loss: Optional[ + Callable[ + [jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]] + ] + ] = None, + regularizer: Optional[ + Callable[ + [jnp.ndarray, jnp.ndarray, jnp.ndarray], + Tuple[float, Optional[Any]], + ] + ] = None, + regularizer_strength: Union[float, Sequence[float]] = 1.0, + num_train_iters: int = 10_000, + logging: bool = False, + valid_freq: int = 500, + rng: Optional[jax.Array] = None, + ): + self._fitting_loss = fitting_loss + self._regularizer = regularizer + self.regularizer_strength = jnp.repeat( + jnp.atleast_2d(regularizer_strength), + num_train_iters, + total_repeat_length=num_train_iters, + axis=0, + ).ravel() + self.num_train_iters = num_train_iters + self.logging = logging + self.valid_freq = valid_freq + self.rng = utils.default_prng_key(rng) + + if optimizer is None: + optimizer = optax.adam(learning_rate=0.001) + + self.setup(dim_data, model, optimizer) + + def setup( + self, + dim_data: int, + neural_net: ConditionalPerturbationNetwork, + optimizer: optax.OptState, + ): + """Set up all components required to train the network.""" + self.state_neural_net = neural_net.create_train_state( + self.rng, optimizer, dim_data + ) + self.step_fn = self._get_step_fn() + + @property + def regularizer( + self, + ) -> Callable[ + [jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]] + ]: + """Conditional regularizer ``(source, mapped, labels) -> (loss, log)``. + + Defaults to zero if not provided. + """ + if self._regularizer is not None: + return self._regularizer + return lambda *_, **__: (0.0, None) + + @property + def fitting_loss( + self, + ) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]: + """Fitting loss ``(mapped, target) -> (loss, log)``. + + Defaults to zero if not provided. 
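+
+ A typical choice wraps :func:`~ott.tools.sinkhorn_divergence.sinkdiv`,
+ which already returns a ``(divergence, output)`` pair.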
+ """ + if self._fitting_loss is not None: + return self._fitting_loss + return lambda *_, **__: (0.0, None) + + @staticmethod + def _generate_batch( + loader_source: Iterator[jnp.ndarray], + loader_target: Iterator[jnp.ndarray], + loader_condition: Iterator[jnp.ndarray], + loader_label: Iterator[jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Generate a batch of samples from all four iterators.""" + return { + "source": next(loader_source), + "target": next(loader_target), + "condition": next(loader_condition), + "condition_labels": next(loader_label), + } + + def train_map_estimator( + self, + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + trainloader_condition: Iterator[jnp.ndarray], + trainloader_label: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], + validloader_condition: Iterator[jnp.ndarray], + validloader_label: Iterator[jnp.ndarray], + ) -> Tuple[train_state.TrainState, Dict[str, Any]]: + """Training loop.""" + logs = collections.defaultdict(lambda: collections.defaultdict(list)) + + try: + from tqdm import trange + + tbar = trange(self.num_train_iters, leave=True) + except ImportError: + tbar = range(self.num_train_iters) + + for step in tbar: + is_logging_step = self.logging and ( + (step % self.valid_freq == 0) + or (step == self.num_train_iters - 1) + ) + train_batch = self._generate_batch( + trainloader_source, + trainloader_target, + trainloader_condition, + trainloader_label, + ) + valid_batch = ( + None + if not is_logging_step + else self._generate_batch( + validloader_source, + validloader_target, + validloader_condition, + validloader_label, + ) + ) + self.state_neural_net, current_logs = self.step_fn( + self.state_neural_net, + train_batch, + valid_batch, + is_logging_step, + step, + ) + + if is_logging_step: + for log_key in current_logs: + for metric_key in current_logs[log_key]: + logs[log_key][metric_key].append( + current_logs[log_key][metric_key] + ) + if not isinstance(tbar, range): + reg_msg = ( + "NA" + if current_logs["eval"]["regularizer"] == 0.0 + else f"{current_logs['eval']['regularizer']:.4f}" + ) + postfix_str = ( + f"fitting_loss:" + f" {current_logs['eval']['fitting_loss']:.4f}, " + f"regularizer: {reg_msg} ," + f"total: {current_logs['eval']['total_loss']:.4f}" + ) + tbar.set_postfix_str(postfix_str) + + return self.state_neural_net, logs + + def _get_step_fn(self) -> Callable: + """Create a one-step training and evaluation function.""" + + def loss_fn( + params: frozen_dict.FrozenDict, + apply_fn: Callable, + batch: Dict[str, jnp.ndarray], + step: int, + ) -> Tuple[float, Dict[str, float]]: + """Loss function with conditional map and regularizer.""" + # Apply the conditional map: T(source, condition) + mapped_samples = apply_fn( + {"params": params}, batch["source"], batch["condition"] + ) + + # Fitting loss: Δ(T(x,c), y) + val_fitting_loss, log_fitting_loss = self.fitting_loss( + mapped_samples, batch["target"] + ) + + # Conditional regularizer: R(x, T(x,c), labels) + val_regularizer, log_regularizer = self.regularizer( + batch["source"], mapped_samples, batch["condition_labels"] + ) + + val_tot_loss = ( + val_fitting_loss + + self.regularizer_strength[step] * val_regularizer + ) + + loss_logs = { + "total_loss": val_tot_loss, + "fitting_loss": val_fitting_loss, + "regularizer": val_regularizer, + "log_regularizer": log_regularizer, + "log_fitting": log_fitting_loss, + } + + return val_tot_loss, loss_logs + + @functools.partial(jax.jit, 
static_argnums=3) + def step_fn( + state_neural_net: train_state.TrainState, + train_batch: Dict[str, jnp.ndarray], + valid_batch: Optional[Dict[str, jnp.ndarray]] = None, + is_logging_step: bool = False, + step: int = 0, + ) -> Tuple[train_state.TrainState, Dict[str, float]]: + """One step function.""" + grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True) + (_, current_train_logs), grads = grad_fn( + state_neural_net.params, + state_neural_net.apply_fn, + train_batch, + step, + ) + + current_logs = {"train": current_train_logs, "eval": {}} + if is_logging_step: + _, current_eval_logs = loss_fn( + params=state_neural_net.params, + apply_fn=state_neural_net.apply_fn, + batch=valid_batch, + step=step, + ) + current_logs["eval"] = current_eval_logs + + return state_neural_net.apply_gradients(grads=grads), current_logs + + return step_fn diff --git a/src/ott/neural/networks/__init__.py b/src/ott/neural/networks/__init__.py index b35fcb61a..0abb298a0 100644 --- a/src/ott/neural/networks/__init__.py +++ b/src/ott/neural/networks/__init__.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import icnn, layers, potentials +from . import ( + conditional_perturbation_network, + icnn, + layers, + potentials, +) -__all__ = ["icnn", "layers", "potentials"] +__all__ = ["conditional_perturbation_network", "icnn", "layers", "potentials"] diff --git a/src/ott/neural/networks/conditional_perturbation_network.py b/src/ott/neural/networks/conditional_perturbation_network.py new file mode 100644 index 000000000..3cfbcd826 --- /dev/null +++ b/src/ott/neural/networks/conditional_perturbation_network.py @@ -0,0 +1,135 @@ +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Sequence, + Tuple, + Union, +) + +import flax.linen as nn +import jax.numpy as jnp +import optax +from ott.neural.networks.potentials import ( + BasePotential, + PotentialTrainState, +) + + +class ConditionalPerturbationNetwork(BasePotential): + dim_hidden: Sequence[int] = None + dim_data: int = None + dim_cond: int = None # Full dimension of all context variables concatenated + # Same length as context_entity_bonds if embed_cond_equal is False + # (if True, first item is size of deep set layer, rest is ignored) + dim_cond_map: Iterable[int] = (50,) + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu + is_potential: bool = False + layer_norm: bool = False + embed_cond_equal: bool = ( + False # Whether all context variables should be treated as set or not + ) + context_entity_bonds: Iterable[Tuple[int, int]] = ( + (0, 10), + (0, 11), + ) # Start/stop index per modality + num_contexts: int = 2 + + @nn.compact + def __call__( + self, x: jnp.ndarray, c: Optional[jnp.ndarray] = None + ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: # noqa: D102 + """Args: + x (jnp.ndarray): The input data of shape bs x dim_data + c (jnp.ndarray): The context of shape bs x dim_cond with + possibly different modalities + concatenated, as can be specified via context_entity_bonds. 
+ + Returns: + jnp.ndarray: _description_ + """ + return_batch = False + if isinstance(x, dict): + c = x["c"] + x = x["X"] + return_batch = True + + n_input = x.shape[-1] + + # Chunk the inputs + contexts = [ + c[:, e[0] : e[1]] + for i, e in enumerate(self.context_entity_bonds) + if i < self.num_contexts + ] + + if not self.embed_cond_equal: + # Each context is processed by a different layer, + # good for combining modalities + assert len(self.context_entity_bonds) == len(self.dim_cond_map), ( + "Length of context entity bonds and context map sizes have to " + f"match: {self.context_entity_bonds} != {self.dim_cond_map}" + ) + + layers = [ + nn.Dense(self.dim_cond_map[i], use_bias=True) + for i in range(len(contexts)) + ] + embeddings = [ + self.act_fn(layers[i](context)) + for i, context in enumerate(contexts) + ] + cond_embedding = jnp.concatenate(embeddings, axis=1) + else: + # We can use any number of contexts from the same modality, + # via a permutation-invariant deep set layer. + sizes = [c.shape[-1] for c in contexts] + if not len(set(sizes)) == 1: + raise ValueError( + "For embedding a set, all contexts need same length ," + f"not {sizes}" + ) + layer = nn.Dense(self.dim_cond_map[0], use_bias=True) + embeddings = [self.act_fn(layer(context)) for context in contexts] + # Average along stacked dimension + # (alternatives like summing are possible) + cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0) + + z = jnp.concatenate((x, cond_embedding), axis=1) + if self.layer_norm: + n = nn.LayerNorm() + z = n(z) + + for n_hidden in self.dim_hidden: + wx = nn.Dense(n_hidden, use_bias=True) + z = self.act_fn(wx(z)) + wx = nn.Dense(n_input, use_bias=True) + + y = x + wx(z) + + if return_batch: + return {"X": y, "c": c} + else: + return y + + def create_train_state( + self, + rng: jnp.ndarray, + optimizer: optax.OptState, + dim_data: int, + **kwargs: Any, + ) -> PotentialTrainState: + """Create initial `TrainState`.""" + c = jnp.ones((1, self.dim_cond)) # (n_batch, embed_dim) + x = jnp.ones((1, dim_data)) # (n_batch, data_dim) + params = self.init(rng, x=x, c=c)["params"] + return PotentialTrainState.create( + apply_fn=self.apply, + params=params, + tx=optimizer, + potential_value_fn=self.potential_value_fn, + potential_gradient_fn=self.potential_gradient_fn, + **kwargs, + ) diff --git a/tests/neural/methods/conditional_monge_gap_test.py b/tests/neural/methods/conditional_monge_gap_test.py new file mode 100644 index 000000000..8098ac6f3 --- /dev/null +++ b/tests/neural/methods/conditional_monge_gap_test.py @@ -0,0 +1,280 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import jax +import jax.numpy as jnp +import numpy as np + +from ott import datasets +from ott.geometry import costs +from ott.neural.methods import conditional_monge_gap +from ott.neural.methods.monge_gap import monge_gap_from_samples +from ott.neural.networks.conditional_perturbation_network import ( + ConditionalPerturbationNetwork, +) +from ott.tools import sinkhorn_divergence + + +@pytest.mark.fast() +class TestConditionalMongeGap: + + @pytest.mark.parametrize("n_samples", [10, 30]) + @pytest.mark.parametrize("n_features", [4, 10]) + @pytest.mark.parametrize("num_conditions", [2, 3]) + def test_non_negativity( + self, rng: jax.Array, n_samples: int, n_features: int, + num_conditions: int, + ): + rng1, rng2 = jax.random.split(rng) + per_cond = n_samples // num_conditions + n = per_cond * num_conditions + + source = jax.random.normal(rng1, (n, n_features)) + target = source + 0.5 * jax.random.normal(rng2, (n, n_features)) + condition = jnp.repeat(jnp.arange(num_conditions), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=num_conditions, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(gap >= 0, True) + + def test_jit_consistency(self, rng: jax.Array): + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + eager_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=k, max_measure_size=per_cond, + ) + jit_gap = jax.jit( + lambda s, t, c: conditional_monge_gap.cmonge_gap_from_samples( + s, t, c, num_segments=k, max_measure_size=per_cond, + ) + )(source, target, condition) + + np.testing.assert_allclose(eager_gap, jit_gap, rtol=1e-3) + + def test_matches_loop_baseline(self, rng: jax.Array): + """Segment-based result matches manual per-condition loop.""" + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + new_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=k, max_measure_size=per_cond, + ) + + # Manual loop (the old approach) + manual_gaps = [] + for c in range(k): + mask = condition == c + gap = monge_gap_from_samples(source[mask], target[mask]) + manual_gaps.append(float(gap)) + manual_avg = sum(manual_gaps) / len(manual_gaps) + + np.testing.assert_allclose(float(new_gap), manual_avg, atol=1e-5) + + def test_identity_smaller_than_random(self, rng: jax.Array): + """Identity map should have smaller Monge gap than a random map.""" + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + identity_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, source, condition, + num_segments=k, max_measure_size=per_cond, + ) + random_target = jax.random.normal(rng2, (n, d)) * 3.0 + random_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, random_target, condition, + num_segments=k, max_measure_size=per_cond, + ) + assert identity_gap < random_gap + + @pytest.mark.parametrize("cost_fn", [ + costs.SqEuclidean(), + costs.PNormP(p=1), + ], ids=["sqeucl", "pnorm-1"]) + def test_different_costs(self, rng: jax.Array, cost_fn: 
costs.CostFn): + n, d, k = 30, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + jax.random.normal(rng2, (n, d)) * 0.5 + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, cost_fn=cost_fn, + num_segments=k, max_measure_size=per_cond, + ) + np.testing.assert_array_equal(jnp.isfinite(gap), True) + np.testing.assert_array_equal(gap >= 0, True) + + def test_return_output_shape(self, rng: jax.Array): + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + result = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=k, max_measure_size=per_cond, + return_output=True, + ) + assert isinstance(result, tuple) + avg_gap, per_cond_gaps = result + assert per_cond_gaps.shape == (k,) + np.testing.assert_allclose( + float(avg_gap), float(jnp.mean(per_cond_gaps)), rtol=1e-5, + ) + + +@pytest.mark.fast() +class TestConditionalMongeGapEstimator: + + def test_estimator_convergence(self): + """Train a conditional map and verify loss decreases.""" + num_conditions = 3 + dim_data = 2 + dim_cond = num_conditions # one-hot + batch_size = 30 + + train_ds, valid_ds, _, n_cond, max_ms = ( + datasets.create_conditional_gaussian_mixture_samplers( + num_conditions=num_conditions, + dim=dim_data, + train_batch_size=batch_size, + valid_batch_size=batch_size, + ) + ) + + def fitting_loss(mapped, target): + div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) + return div, None + + def regularizer(source, mapped, labels): + gap, per_cond = conditional_monge_gap.cmonge_gap_from_samples( + source, mapped, labels, + num_segments=n_cond, + max_measure_size=max_ms, + return_output=True, + ) + return gap, None + + model = ConditionalPerturbationNetwork( + dim_hidden=[16, 8], + dim_data=dim_data, + dim_cond=dim_cond, + dim_cond_map=(16,), + is_potential=False, + context_entity_bonds=((0, dim_cond),), + num_contexts=1, + ) + + solver = conditional_monge_gap.ConditionalMongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + regularizer=regularizer, + model=model, + regularizer_strength=1.0, + num_train_iters=15, + logging=True, + valid_freq=5, + ) + + neural_state, logs = solver.train_map_estimator( + *train_ds, *valid_ds, + ) + + # Loss should decrease + assert logs["train"]["total_loss"][0] > logs["train"]["total_loss"][-1] + + # Output shape should match input + source_batch = next(train_ds.source_iter) + cond_batch = next(train_ds.condition_iter) + mapped = neural_state.apply_fn( + {"params": neural_state.params}, source_batch, cond_batch, + ) + assert mapped.shape == source_batch.shape + np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) + + def test_estimator_no_regularizer(self): + """Training with regularizer_strength=0 still converges.""" + num_conditions = 2 + dim_data = 2 + dim_cond = num_conditions + batch_size = 20 + + train_ds, valid_ds, _, _, _ = ( + datasets.create_conditional_gaussian_mixture_samplers( + num_conditions=num_conditions, + dim=dim_data, + train_batch_size=batch_size, + valid_batch_size=batch_size, + ) + ) + + def fitting_loss(mapped, target): + div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) + return div, None + + model = ConditionalPerturbationNetwork( + dim_hidden=[8, 8], + 
dim_data=dim_data, + dim_cond=dim_cond, + dim_cond_map=(8,), + is_potential=False, + context_entity_bonds=((0, dim_cond),), + num_contexts=1, + ) + + solver = conditional_monge_gap.ConditionalMongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + model=model, + regularizer_strength=0.0, + num_train_iters=10, + logging=True, + valid_freq=5, + ) + + neural_state, logs = solver.train_map_estimator( + *train_ds, *valid_ds, + ) + + # Should have run without errors and logged metrics + assert len(logs["train"]["total_loss"]) > 0 + # Mapped output should be finite + source_batch = next(train_ds.source_iter) + cond_batch = next(train_ds.condition_iter) + mapped = neural_state.apply_fn( + {"params": neural_state.params}, source_batch, cond_batch, + ) + np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) From 758f81dcc7162798d36d60f6ae4252065903b79d Mon Sep 17 00:00:00 2001 From: DhruvaRajwade Date: Mon, 16 Mar 2026 15:55:23 +0100 Subject: [PATCH 3/6] test: expand cmonge_gap test suite with equivalence and cost function tests Add 5 new tests to TestConditionalMongeGap: - test_non_negativity_neural_map: PotentialMLP-based targets - test_different_costs_give_different_values: PNormP, RegTICost(L1), RegTICost(STVS) - test_uniform_conditions_equals_averaged_monge_gap: exact equivalence proof - test_unequal_conditions_shifts_average: structural properties with padding - test_per_condition_gaps_reflect_difficulty: monotonic gap ordering --- .../methods/conditional_monge_gap_test.py | 201 +++++++++++++++++- 1 file changed, 200 insertions(+), 1 deletion(-) diff --git a/tests/neural/methods/conditional_monge_gap_test.py b/tests/neural/methods/conditional_monge_gap_test.py index 8098ac6f3..6f5215c1f 100644 --- a/tests/neural/methods/conditional_monge_gap_test.py +++ b/tests/neural/methods/conditional_monge_gap_test.py @@ -19,9 +19,10 @@ import numpy as np from ott import datasets -from ott.geometry import costs +from ott.geometry import costs, regularizers from ott.neural.methods import conditional_monge_gap from ott.neural.methods.monge_gap import monge_gap_from_samples +from ott.neural.networks import potentials from ott.neural.networks.conditional_perturbation_network import ( ConditionalPerturbationNetwork, ) @@ -156,6 +157,204 @@ def test_return_output_shape(self, rng: jax.Array): ) + @pytest.mark.parametrize("n_samples", [10, 30]) + @pytest.mark.parametrize("n_features", [4, 10]) + def test_non_negativity_neural_map( + self, rng: jax.Array, n_samples: int, n_features: int, + ): + """Non-negativity with a learned nonlinear map (mirrors monge_gap_test).""" + k = 2 + per_cond = n_samples // k + n = per_cond * k + rng1, rng2 = jax.random.split(rng) + + source = jax.random.normal(rng1, (n, n_features)) + model = potentials.PotentialMLP(dim_hidden=[8, 8], is_potential=False) + params = model.init(rng2, x=source[0]) + target = model.apply(params, source) + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=k, max_measure_size=per_cond, + ) + np.testing.assert_array_equal(jnp.isfinite(gap), True) + np.testing.assert_array_equal(gap >= 0, True) + + @pytest.mark.parametrize("cost_fn", [ + costs.PNormP(p=1), + costs.RegTICost(regularizers.L1(), lam=2.0), + costs.RegTICost(regularizers.STVS(gamma=3.0), lam=1.0), + ], ids=["pnorm-1", "l1-lam2", "stvs-lam1"]) + def test_different_costs_give_different_values( + self, rng: jax.Array, cost_fn: costs.CostFn, + ): + """Non-Euclidean costs 
produce different cmonge_gap than Euclidean.""" + n, d, k = 30, 5, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = jax.random.normal(rng2, (n, d)) * 0.1 + 3.0 + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap_eucl = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, cost_fn=costs.Euclidean(), + num_segments=k, max_measure_size=per_cond, + ) + gap_other = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, cost_fn=cost_fn, + num_segments=k, max_measure_size=per_cond, + ) + + with pytest.raises(AssertionError, match=r"tolerance"): + np.testing.assert_allclose( + gap_eucl, gap_other, rtol=1e-1, atol=1e-1, + ) + np.testing.assert_array_equal(jnp.isfinite(gap_eucl), True) + np.testing.assert_array_equal(jnp.isfinite(gap_other), True) + + def test_uniform_conditions_equals_averaged_monge_gap( + self, rng: jax.Array, + ): + """cmonge_gap with equal-size conditions == mean of monge_gap calls.""" + k = 3 + per_cond = 20 + d = 5 + n = k * per_cond + + # Different offsets per condition so gaps are distinct + offsets = jnp.array([0.1, 1.0, 3.0]) + rngs = jax.random.split(rng, 2 * k) + sources, targets = [], [] + for c in range(k): + s = jax.random.normal(rngs[2 * c], (per_cond, d)) + t = s + offsets[c] + 0.05 * jax.random.normal( + rngs[2 * c + 1], (per_cond, d) + ) + sources.append(s) + targets.append(t) + + source = jnp.concatenate(sources, axis=0) + target = jnp.concatenate(targets, axis=0) + condition = jnp.repeat(jnp.arange(k), per_cond) + + # Segmented cmonge_gap + avg_gap, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=k, max_measure_size=per_cond, + return_output=True, + ) + + # Manual per-condition monge_gap calls + manual_gaps = [] + for c in range(k): + gap_c = monge_gap_from_samples(sources[c], targets[c]) + manual_gaps.append(float(gap_c)) + manual_avg = sum(manual_gaps) / k + + # Average should match + np.testing.assert_allclose(float(avg_gap), manual_avg, atol=1e-5) + # Per-condition gaps should match individual calls + for c in range(k): + np.testing.assert_allclose( + float(per_cond_gaps[c]), manual_gaps[c], atol=1e-5, + ) + + def test_unequal_conditions_shifts_average(self, rng: jax.Array): + """With unequal n_k, per-condition gaps change and shift the average. + + The segment interface pads all conditions to max_measure_size, so + per-condition gaps with padding do NOT exactly match non-padded + monge_gap_from_samples calls (the geometry differs). We verify + structural properties instead: gaps are finite, easy < hard, + average = mean(per_cond_gaps), and the average shifts when n_k changes. 
+ """ + d = 5 + rng_easy, rng_hard, rng_noise = jax.random.split(rng, 3) + + base_easy = jax.random.normal(rng_easy, (60, d)) + base_hard = jax.random.normal(rng_hard, (60, d)) + noise = 0.01 * jax.random.normal(rng_noise, (60, d)) + + target_easy = base_easy + noise + target_hard = base_hard + 5.0 + + # (a) Equal sizes: 30/30 + n_eq = 30 + src_eq = jnp.concatenate([base_easy[:n_eq], base_hard[:n_eq]]) + tgt_eq = jnp.concatenate([target_easy[:n_eq], target_hard[:n_eq]]) + cond_eq = jnp.repeat(jnp.arange(2), n_eq) + + avg_eq, gaps_eq = conditional_monge_gap.cmonge_gap_from_samples( + src_eq, tgt_eq, cond_eq, + num_segments=2, max_measure_size=n_eq, + return_output=True, + ) + + # (b) Unequal sizes: 50 easy / 10 hard + n_a, n_b = 50, 10 + src_uneq = jnp.concatenate([base_easy[:n_a], base_hard[:n_b]]) + tgt_uneq = jnp.concatenate([target_easy[:n_a], target_hard[:n_b]]) + cond_uneq = jnp.concatenate([ + jnp.zeros(n_a, dtype=jnp.int32), + jnp.ones(n_b, dtype=jnp.int32), + ]) + + avg_uneq, gaps_uneq = conditional_monge_gap.cmonge_gap_from_samples( + src_uneq, tgt_uneq, cond_uneq, + num_segments=2, max_measure_size=n_a, + return_output=True, + ) + + # All gaps are finite and non-negative + for gaps in [gaps_eq, gaps_uneq]: + np.testing.assert_array_equal(jnp.all(jnp.isfinite(gaps)), True) + np.testing.assert_array_equal(jnp.all(gaps >= 0), True) + + # Easy condition has smaller gap than hard condition + assert gaps_eq[0] < gaps_eq[1] + assert gaps_uneq[0] < gaps_uneq[1] + + # Average is the mean of per-condition gaps + np.testing.assert_allclose( + float(avg_eq), float(jnp.mean(gaps_eq)), rtol=1e-5, + ) + np.testing.assert_allclose( + float(avg_uneq), float(jnp.mean(gaps_uneq)), rtol=1e-5, + ) + + # Averages differ between equal and unequal splits (n_k affects + # the padded OT cost estimation, shifting per-condition gaps) + assert float(avg_eq) != float(avg_uneq) + + def test_per_condition_gaps_reflect_difficulty(self, rng: jax.Array): + """Per-condition gaps increase with offset magnitude.""" + k = 3 + per_cond = 25 + d = 4 + offsets = jnp.array([0.0, 1.5, 5.0]) + + rngs = jax.random.split(rng, 2 * k) + sources, targets = [], [] + for c in range(k): + s = jax.random.normal(rngs[2 * c], (per_cond, d)) + t = s + offsets[c] + sources.append(s) + targets.append(t) + + source = jnp.concatenate(sources, axis=0) + target = jnp.concatenate(targets, axis=0) + condition = jnp.repeat(jnp.arange(k), per_cond) + + _, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( + source, target, condition, + num_segments=k, max_measure_size=per_cond, + return_output=True, + ) + + assert per_cond_gaps[0] < per_cond_gaps[1] < per_cond_gaps[2] + + @pytest.mark.fast() class TestConditionalMongeGapEstimator: From 3f8e6a1334a079a89594b5a32335d51141671afa Mon Sep 17 00:00:00 2001 From: DhruvaRajwade Date: Wed, 18 Mar 2026 15:18:20 +0100 Subject: [PATCH 4/6] refactor: add padding warning, runtime timing, and shorten PR message - Add logger.warning in cmonge_gap_from_samples when any condition is padded >10x its actual size (skipped under JIT via try/except) - Add runtime timing to test_uniform_conditions_equals_averaged_monge_gap comparing segmented vs loop performance - Rewrite PR_MESSAGE.md to ~half page with concise overview and tutorial plot --- .../neural/methods/conditional_monge_gap.py | 22 ++++++++++++ .../conditional_perturbation_network.py | 22 +++++++----- .../methods/conditional_monge_gap_test.py | 36 +++++++++++++++++-- 3 files changed, 70 insertions(+), 10 deletions(-) diff --git 
a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py index f31629d31..f34c28ca6 100644 --- a/src/ott/neural/methods/conditional_monge_gap.py +++ b/src/ott/neural/methods/conditional_monge_gap.py @@ -13,6 +13,7 @@ # limitations under the License. import collections import functools +import logging from typing import ( Any, Callable, @@ -40,6 +41,8 @@ from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn +logger = logging.getLogger(__name__) + __all__ = [ "cmonge_gap_from_samples", "ConditionalMongeGapEstimator", @@ -107,6 +110,25 @@ def cmonge_gap_from_samples( dim = source.shape[1] padding_vector = cost_fn._padder(dim=dim) + # Warn if any condition is heavily padded (>10x below max_measure_size), + # which can cause numerical differences vs non-padded Sinkhorn solves. + # Skipped silently under JIT where condition values are traced. + if max_measure_size is not None and num_segments is not None: + try: + counts = jnp.bincount(condition, length=num_segments) + min_count = int(jnp.min(counts)) + if min_count > 0 and max_measure_size // min_count >= 10: + logger.warning( + "Condition with %d samples will be padded to %d " + "(%.0fx). Per-condition Monge gap values may differ " + "from non-padded monge_gap_from_samples calls.", + min_count, + max_measure_size, + max_measure_size / min_count, + ) + except jax.errors.ConcretizationTypeError: + pass + # NOTE: Eval function takes some logic from: # ott.neural.methods.monge_gap.monge_gap_from_samples` # as well as `ott.geometry.segment.py` diff --git a/src/ott/neural/networks/conditional_perturbation_network.py b/src/ott/neural/networks/conditional_perturbation_network.py index 3cfbcd826..e622f4bfa 100644 --- a/src/ott/neural/networks/conditional_perturbation_network.py +++ b/src/ott/neural/networks/conditional_perturbation_network.py @@ -33,22 +33,28 @@ class ConditionalPerturbationNetwork(BasePotential): ) context_entity_bonds: Iterable[Tuple[int, int]] = ( (0, 10), - (0, 11), - ) # Start/stop index per modality + (10, 20), + ) # (start, stop) slicing bounds per context modality in c; + # should be contiguous, non-overlapping by default. num_contexts: int = 2 @nn.compact def __call__( self, x: jnp.ndarray, c: Optional[jnp.ndarray] = None ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: # noqa: D102 - """Args: - x (jnp.ndarray): The input data of shape bs x dim_data - c (jnp.ndarray): The context of shape bs x dim_cond with - possibly different modalities - concatenated, as can be specified via context_entity_bonds. + """Forward pass: map (x, c) -> x + residual. + + Args: + x: Input data of shape ``(batch, dim_data)``. + c: Context vector of shape ``(batch, dim_cond)``. May + contain multiple modalities concatenated along the last + axis. ``context_entity_bonds`` specifies which slice + ``c[:, start:stop]`` belongs to each modality. Slices + should generally be contiguous and non-overlapping, e.g. + ``((0, 10), (10, 20))`` for two 10-dim modalities. Returns: - jnp.ndarray: _description_ + Mapped output of shape ``(batch, dim_data)``. """ return_batch = False if isinstance(x, dict): diff --git a/tests/neural/methods/conditional_monge_gap_test.py b/tests/neural/methods/conditional_monge_gap_test.py index 6f5215c1f..dba6dbb32 100644 --- a/tests/neural/methods/conditional_monge_gap_test.py +++ b/tests/neural/methods/conditional_monge_gap_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import time + import pytest import jax @@ -238,19 +240,49 @@ def test_uniform_conditions_equals_averaged_monge_gap( target = jnp.concatenate(targets, axis=0) condition = jnp.repeat(jnp.arange(k), per_cond) - # Segmented cmonge_gap + # Segmented cmonge_gap (single call, vmapped) + t0 = time.perf_counter() avg_gap, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( source, target, condition, num_segments=k, max_measure_size=per_cond, return_output=True, ) + # Force computation to complete before timing + avg_gap.block_until_ready() + t_cmonge = time.perf_counter() - t0 - # Manual per-condition monge_gap calls + # Manual per-condition monge_gap calls (K sequential calls) + t0 = time.perf_counter() manual_gaps = [] for c in range(k): gap_c = monge_gap_from_samples(sources[c], targets[c]) manual_gaps.append(float(gap_c)) manual_avg = sum(manual_gaps) / k + t_loop = time.perf_counter() - t0 + + # Single-condition overhead: cmonge_gap(K=1) vs monge_gap + t0 = time.perf_counter() + gap_single_cmonge = conditional_monge_gap.cmonge_gap_from_samples( + sources[0], targets[0], + jnp.zeros(per_cond, dtype=jnp.int32), + num_segments=1, max_measure_size=per_cond, + ) + gap_single_cmonge.block_until_ready() + t_cmonge_1 = time.perf_counter() - t0 + + t0 = time.perf_counter() + gap_single_monge = monge_gap_from_samples(sources[0], targets[0]) + float(gap_single_monge) # block + t_monge_1 = time.perf_counter() - t0 + + print( + f"\n K={k}: cmonge_gap: {t_cmonge:.3f}s | " + f"loop({k}x monge_gap): {t_loop:.3f}s | " + f"speedup: {t_loop / t_cmonge:.1f}x" + f"\n K=1: cmonge_gap: {t_cmonge_1:.3f}s | " + f"monge_gap: {t_monge_1:.3f}s | " + f"overhead: {t_cmonge_1 / t_monge_1:.1f}x" + ) # Average should match np.testing.assert_allclose(float(avg_gap), manual_avg, atol=1e-5) From 9e51b4022df9be6abf775b5eb3f5c3d75ab6b607 Mon Sep 17 00:00:00 2001 From: DhruvaRajwade Date: Mon, 23 Mar 2026 16:32:38 +0100 Subject: [PATCH 5/6] style: fix ruff/isort lint (2-space indent, line length, imports) --- src/ott/datasets.py | 137 +-- src/ott/neural/methods/__init__.py | 8 +- .../neural/methods/conditional_monge_gap.py | 575 +++++----- src/ott/neural/networks/__init__.py | 7 +- .../conditional_perturbation_network.py | 223 ++-- .../methods/conditional_monge_gap_test.py | 1011 +++++++++-------- 6 files changed, 1020 insertions(+), 941 deletions(-) diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 2e6ee4d42..05ea03ed7 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -34,10 +34,11 @@ class Dataset(NamedTuple): r"""Samplers from source and target measures. - Args: - source_iter: loader for the source measure - target_iter: loader for the target measure - """ + Args: + source_iter: loader for the source measure + target_iter: loader for the target measure + """ + source_iter: Iterator[jnp.ndarray] target_iter: Iterator[jnp.ndarray] @@ -45,12 +46,14 @@ class Dataset(NamedTuple): class ConditionalDataset(NamedTuple): r"""Samplers from conditional source and target measures. 
- Args: - source_iter: loader for the source measure, ``[batch, d]`` - target_iter: loader for the target measure, ``[batch, d]`` - condition_iter: loader for continuous condition vectors, ``[batch, dim_c]`` - label_iter: loader for integer condition labels, ``[batch]`` - """ + Args: + source_iter: loader for the source measure, ``[batch, d]`` + target_iter: loader for the target measure, ``[batch, d]`` + condition_iter: loader for condition vectors, + ``[batch, dim_c]`` + label_iter: loader for integer condition labels, ``[batch]`` + """ + source_iter: Iterator[jnp.ndarray] target_iter: Iterator[jnp.ndarray] condition_iter: Iterator[jnp.ndarray] @@ -61,21 +64,22 @@ class ConditionalDataset(NamedTuple): class GaussianMixture: """A mixture of Gaussians. - Args: - name: the name specifying the centers of the mixture components: - - - ``simple`` - data clustered in one center, - - ``circle`` - two-dimensional Gaussians arranged on a circle, - - ``square_five`` - two-dimensional Gaussians on a square with - one Gaussian in the center, and - - ``square_four`` - two-dimensional Gaussians in the corners of a - rectangle - - batch_size: batch size of the samples - rng: initial PRNG key - scale: scale of the Gaussian means - std: the standard deviation of the individual Gaussian samples - """ + Args: + name: the name specifying the centers of the mixture components: + + - ``simple`` - data clustered in one center, + - ``circle`` - two-dimensional Gaussians arranged on a circle, + - ``square_five`` - two-dimensional Gaussians on a square with + one Gaussian in the center, and + - ``square_four`` - two-dimensional Gaussians in the corners of a + rectangle + + batch_size: batch size of the samples + rng: initial PRNG key + scale: scale of the Gaussian means + std: the standard deviation of the individual Gaussian samples + """ + name: Name_t batch_size: int rng: jax.Array @@ -111,9 +115,9 @@ def __post_init__(self) -> None: def __iter__(self) -> Iterator[jnp.array]: """Random sample generator from Gaussian mixture. - Returns: - A generator of samples from the Gaussian mixture. - """ + Returns: + A generator of samples from the Gaussian mixture. + """ return self._create_sample_generators() def _create_sample_generators(self) -> Iterator[jnp.array]: @@ -135,16 +139,16 @@ def create_gaussian_mixture_samplers( ) -> Tuple[Dataset, Dataset, int]: """Gaussian samplers. - Args: - name_source: name of the source sampler - name_target: name of the target sampler - train_batch_size: the training batch size - valid_batch_size: the validation batch size - rng: initial PRNG key + Args: + name_source: name of the source sampler + name_target: name of the target sampler + train_batch_size: the training batch size + valid_batch_size: the validation batch size + rng: initial PRNG key - Returns: - The dataset and dimension of the data. - """ + Returns: + The dataset and dimension of the data. 
+ """ rng = utils.default_prng_key(rng) rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) train_dataset = Dataset( @@ -153,7 +157,7 @@ def create_gaussian_mixture_samplers( ), target_iter=iter( GaussianMixture(name_target, batch_size=train_batch_size, rng=rng2) - ) + ), ) valid_dataset = Dataset( source_iter=iter( @@ -161,7 +165,7 @@ def create_gaussian_mixture_samplers( ), target_iter=iter( GaussianMixture(name_target, batch_size=valid_batch_size, rng=rng4) - ) + ), ) dim_data = 2 return train_dataset, valid_dataset, dim_data @@ -171,16 +175,18 @@ def create_gaussian_mixture_samplers( class ConditionalGaussianMixture: """Conditional Gaussian sampler for testing. - For each condition *k*, draws source ~ N(0, I) and target ~ source + offset_k. - Condition vectors are one-hot encoded labels. + For each condition *k*, draws source ~ N(0, I) and + target ~ source + offset_k. + Condition vectors are one-hot encoded labels. + + Args: + num_conditions: number of distinct conditions. + batch_size: total batch size (divided equally among conditions). + dim: data dimensionality. + offsets: ``[num_conditions, dim]`` translation per condition. + rng: initial PRNG key. + """ - Args: - num_conditions: number of distinct conditions. - batch_size: total batch size (divided equally among conditions). - dim: data dimensionality. - offsets: ``[num_conditions, dim]`` translation per condition. - rng: initial PRNG key. - """ num_conditions: int batch_size: int dim: int @@ -223,21 +229,21 @@ def create_conditional_gaussian_mixture_samplers( ) -> Tuple[ConditionalDataset, ConditionalDataset, int, int, int]: """Create conditional Gaussian samplers for testing. - Each condition defines a different translation of the source distribution. - - Args: - num_conditions: number of distinct conditions. - dim: data dimensionality. - train_batch_size: training batch size (should be divisible by - ``num_conditions``). - valid_batch_size: validation batch size. - rng: initial PRNG key. - - Returns: - ``(train_dataset, valid_dataset, dim_data, num_conditions, - max_measure_size)`` where ``max_measure_size = - batch_size // num_conditions``. - """ + Each condition defines a different translation of the source distribution. + + Args: + num_conditions: number of distinct conditions. + dim: data dimensionality. + train_batch_size: training batch size (should be divisible by + ``num_conditions``). + valid_batch_size: validation batch size. + rng: initial PRNG key. + + Returns: + ``(train_dataset, valid_dataset, dim_data, num_conditions, + max_measure_size)`` where ``max_measure_size = + batch_size // num_conditions``. + """ rng = utils.default_prng_key(rng) rng1, rng2, rng_off = jax.random.split(rng, 3) @@ -245,7 +251,8 @@ def create_conditional_gaussian_mixture_samplers( offsets = jax.random.normal(rng_off, (num_conditions, dim)) * 3.0 def _make_dataset( - bs: int, key: jax.Array, + bs: int, + key: jax.Array, ) -> ConditionalDataset: sampler = ConditionalGaussianMixture( num_conditions=num_conditions, @@ -267,7 +274,7 @@ def _iter(idx: int) -> Iterator[jnp.ndarray]: while True: c = _next_batch() val = c["batch"][idx] - # Mark this index as consumed; when all 4 are consumed, clear cache. + # Mark consumed; when all 4 are done, clear cache. 
c.setdefault("consumed", set()) c["consumed"].add(idx) if len(c["consumed"]) == 4: diff --git a/src/ott/neural/methods/__init__.py b/src/ott/neural/methods/__init__.py index b3883052a..983c3e2ac 100644 --- a/src/ott/neural/methods/__init__.py +++ b/src/ott/neural/methods/__init__.py @@ -11,4 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import conditional_monge_gap, expectile_neural_dual, flow_matching, monge_gap, neuraldual +from . import ( + conditional_monge_gap, + expectile_neural_dual, + flow_matching, + monge_gap, + neuraldual, +) diff --git a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py index f34c28ca6..d47528a96 100644 --- a/src/ott/neural/methods/conditional_monge_gap.py +++ b/src/ott/neural/methods/conditional_monge_gap.py @@ -62,7 +62,7 @@ def cmonge_gap_from_samples( max_measure_size: Optional[int] = None, **kwargs: Any, ) -> Union[float, Tuple[float, jnp.ndarray]]: - r"""Conditional Monge gap from samples using the segment interface. + r"""Conditional Monge gap from samples using the segment interface. Computes the average Monge gap across conditions: @@ -106,79 +106,79 @@ def cmonge_gap_from_samples( The average Monge gap across conditions and, when ``return_output`` is :obj:`True`, a ``[num_segments]`` array of per-condition gaps. """ - cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn - dim = source.shape[1] - padding_vector = cost_fn._padder(dim=dim) - - # Warn if any condition is heavily padded (>10x below max_measure_size), - # which can cause numerical differences vs non-padded Sinkhorn solves. - # Skipped silently under JIT where condition values are traced. - if max_measure_size is not None and num_segments is not None: - try: - counts = jnp.bincount(condition, length=num_segments) - min_count = int(jnp.min(counts)) - if min_count > 0 and max_measure_size // min_count >= 10: - logger.warning( - "Condition with %d samples will be padded to %d " - "(%.0fx). Per-condition Monge gap values may differ " - "from non-padded monge_gap_from_samples calls.", - min_count, - max_measure_size, - max_measure_size / min_count, - ) - except jax.errors.ConcretizationTypeError: - pass - - # NOTE: Eval function takes some logic from: - # ott.neural.methods.monge_gap.monge_gap_from_samples` - # as well as `ott.geometry.segment.py` - def eval_fn( - padded_x: jnp.ndarray, - padded_y: jnp.ndarray, - padded_weight_x: jnp.ndarray, - padded_weight_y: jnp.ndarray, - ) -> jnp.ndarray: - """Monge gap for a single (padded) condition segment.""" - # Displacement cost: weighted mean of pairwise costs c(x_i, T(x_i)). - # Padded entries have weight 0, so they do not contribute. - pairwise_costs = jax.vmap(cost_fn)(padded_x, padded_y) - displacement_cost = jnp.sum(pairwise_costs * padded_weight_x) - - # Entropy-regularized OT cost W_{c,ε}. - geom = pointcloud.PointCloud( - padded_x, - padded_y, - cost_fn=cost_fn, - epsilon=epsilon, - relative_epsilon=relative_epsilon, - scale_cost=scale_cost, + cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn + dim = source.shape[1] + padding_vector = cost_fn._padder(dim=dim) + + # Warn if any condition is heavily padded (>10x below max_measure_size), + # which can cause numerical differences vs non-padded Sinkhorn solves. + # Skipped silently under JIT where condition values are traced. 
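+  # (Sketch of the rationale, inferred from the eval_fn below and the
+  # unequal-condition test: padded points carry zero weight, so they never
+  # contribute to the displacement cost; they can, however, still enter
+  # cost-matrix statistics such as a relative epsilon or a string-valued
+  # scale_cost, which is why heavily padded conditions may drift from
+  # unpadded monge_gap_from_samples results.)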
+ if max_measure_size is not None and num_segments is not None: + try: + counts = jnp.bincount(condition, length=num_segments) + min_count = int(jnp.min(counts)) + if min_count > 0 and max_measure_size // min_count >= 10: + logger.warning( + "Condition with %d samples will be padded to %d " + "(%.0fx). Per-condition Monge gap values may differ " + "from non-padded monge_gap_from_samples calls.", + min_count, + max_measure_size, + max_measure_size / min_count, ) - prob = linear_problem.LinearProblem( - geom, a=padded_weight_x, b=padded_weight_y - ) - solver = sinkhorn.Sinkhorn(**kwargs) - out = solver(prob) - - return displacement_cost - out.ent_reg_cost - - per_condition_gaps = segment._segment_interface( - x=source, - y=target, - eval_fn=eval_fn, - num_segments=num_segments, - max_measure_size=max_measure_size, - segment_ids_x=condition, - segment_ids_y=condition, - indices_are_sorted=False, - padding_vector=padding_vector, + except jax.errors.ConcretizationTypeError: + pass + + # NOTE: Eval function takes some logic from: + # ott.neural.methods.monge_gap.monge_gap_from_samples` + # as well as `ott.geometry.segment.py` + def eval_fn( + padded_x: jnp.ndarray, + padded_y: jnp.ndarray, + padded_weight_x: jnp.ndarray, + padded_weight_y: jnp.ndarray, + ) -> jnp.ndarray: + """Monge gap for a single (padded) condition segment.""" + # Displacement cost: weighted mean of pairwise costs c(x_i, T(x_i)). + # Padded entries have weight 0, so they do not contribute. + pairwise_costs = jax.vmap(cost_fn)(padded_x, padded_y) + displacement_cost = jnp.sum(pairwise_costs * padded_weight_x) + + # Entropy-regularized OT cost W_{c,ε}. + geom = pointcloud.PointCloud( + padded_x, + padded_y, + cost_fn=cost_fn, + epsilon=epsilon, + relative_epsilon=relative_epsilon, + scale_cost=scale_cost, + ) + prob = linear_problem.LinearProblem( + geom, a=padded_weight_x, b=padded_weight_y ) + solver = sinkhorn.Sinkhorn(**kwargs) + out = solver(prob) + + return displacement_cost - out.ent_reg_cost - avg_gap = jnp.mean(per_condition_gaps) - return (avg_gap, per_condition_gaps) if return_output else avg_gap + per_condition_gaps = segment._segment_interface( + x=source, + y=target, + eval_fn=eval_fn, + num_segments=num_segments, + max_measure_size=max_measure_size, + segment_ids_x=condition, + segment_ids_y=condition, + indices_are_sorted=False, + padding_vector=padding_vector, + ) + + avg_gap = jnp.mean(per_condition_gaps) + return (avg_gap, per_condition_gaps) if return_output else avg_gap class ConditionalMongeGapEstimator: - r"""Conditional map estimator between probability measures. + r"""Conditional map estimator between probability measures. Estimates a condition-dependent map :math:`T(\cdot, c)` by minimizing: @@ -214,238 +214,223 @@ class ConditionalMongeGapEstimator: rng: random seed. 
""" - def __init__( - self, - dim_data: int, - model: ConditionalPerturbationNetwork, - optimizer: Optional[optax.OptState] = None, - fitting_loss: Optional[ - Callable[ - [jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]] - ] - ] = None, - regularizer: Optional[ - Callable[ - [jnp.ndarray, jnp.ndarray, jnp.ndarray], - Tuple[float, Optional[Any]], - ] - ] = None, - regularizer_strength: Union[float, Sequence[float]] = 1.0, - num_train_iters: int = 10_000, - logging: bool = False, - valid_freq: int = 500, - rng: Optional[jax.Array] = None, - ): - self._fitting_loss = fitting_loss - self._regularizer = regularizer - self.regularizer_strength = jnp.repeat( - jnp.atleast_2d(regularizer_strength), - num_train_iters, - total_repeat_length=num_train_iters, - axis=0, - ).ravel() - self.num_train_iters = num_train_iters - self.logging = logging - self.valid_freq = valid_freq - self.rng = utils.default_prng_key(rng) - - if optimizer is None: - optimizer = optax.adam(learning_rate=0.001) - - self.setup(dim_data, model, optimizer) - - def setup( - self, - dim_data: int, - neural_net: ConditionalPerturbationNetwork, - optimizer: optax.OptState, - ): - """Set up all components required to train the network.""" - self.state_neural_net = neural_net.create_train_state( - self.rng, optimizer, dim_data - ) - self.step_fn = self._get_step_fn() + def __init__( + self, + dim_data: int, + model: ConditionalPerturbationNetwork, + optimizer: Optional[optax.OptState] = None, + fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray], + Tuple[float, Optional[Any]]]] = None, + regularizer: Optional[Callable[ + [jnp.ndarray, jnp.ndarray, jnp.ndarray], + Tuple[float, Optional[Any]], + ]] = None, + regularizer_strength: Union[float, Sequence[float]] = 1.0, + num_train_iters: int = 10_000, + logging: bool = False, + valid_freq: int = 500, + rng: Optional[jax.Array] = None, + ): + self._fitting_loss = fitting_loss + self._regularizer = regularizer + self.regularizer_strength = jnp.repeat( + jnp.atleast_2d(regularizer_strength), + num_train_iters, + total_repeat_length=num_train_iters, + axis=0, + ).ravel() + self.num_train_iters = num_train_iters + self.logging = logging + self.valid_freq = valid_freq + self.rng = utils.default_prng_key(rng) + + if optimizer is None: + optimizer = optax.adam(learning_rate=0.001) + + self.setup(dim_data, model, optimizer) + + def setup( + self, + dim_data: int, + neural_net: ConditionalPerturbationNetwork, + optimizer: optax.OptState, + ): + """Set up all components required to train the network.""" + self.state_neural_net = neural_net.create_train_state( + self.rng, optimizer, dim_data + ) + self.step_fn = self._get_step_fn() - @property - def regularizer( - self, - ) -> Callable[ - [jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]] - ]: - """Conditional regularizer ``(source, mapped, labels) -> (loss, log)``. + @property + def regularizer( + self, + ) -> Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[float, + Optional[Any]]]: + """Conditional regularizer ``(source, mapped, labels) -> (loss, log)``. Defaults to zero if not provided. """ - if self._regularizer is not None: - return self._regularizer - return lambda *_, **__: (0.0, None) + if self._regularizer is not None: + return self._regularizer + return lambda *_, **__: (0.0, None) - @property - def fitting_loss( - self, - ) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]: - """Fitting loss ``(mapped, target) -> (loss, log)``. 
+ @property + def fitting_loss( + self, + ) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]: + """Fitting loss ``(mapped, target) -> (loss, log)``. Defaults to zero if not provided. """ - if self._fitting_loss is not None: - return self._fitting_loss - return lambda *_, **__: (0.0, None) - - @staticmethod - def _generate_batch( - loader_source: Iterator[jnp.ndarray], - loader_target: Iterator[jnp.ndarray], - loader_condition: Iterator[jnp.ndarray], - loader_label: Iterator[jnp.ndarray], - ) -> Dict[str, jnp.ndarray]: - """Generate a batch of samples from all four iterators.""" - return { - "source": next(loader_source), - "target": next(loader_target), - "condition": next(loader_condition), - "condition_labels": next(loader_label), - } - - def train_map_estimator( - self, - trainloader_source: Iterator[jnp.ndarray], - trainloader_target: Iterator[jnp.ndarray], - trainloader_condition: Iterator[jnp.ndarray], - trainloader_label: Iterator[jnp.ndarray], - validloader_source: Iterator[jnp.ndarray], - validloader_target: Iterator[jnp.ndarray], - validloader_condition: Iterator[jnp.ndarray], - validloader_label: Iterator[jnp.ndarray], - ) -> Tuple[train_state.TrainState, Dict[str, Any]]: - """Training loop.""" - logs = collections.defaultdict(lambda: collections.defaultdict(list)) - - try: - from tqdm import trange - - tbar = trange(self.num_train_iters, leave=True) - except ImportError: - tbar = range(self.num_train_iters) - - for step in tbar: - is_logging_step = self.logging and ( - (step % self.valid_freq == 0) - or (step == self.num_train_iters - 1) - ) - train_batch = self._generate_batch( - trainloader_source, - trainloader_target, - trainloader_condition, - trainloader_label, - ) - valid_batch = ( - None - if not is_logging_step - else self._generate_batch( - validloader_source, - validloader_target, - validloader_condition, - validloader_label, - ) - ) - self.state_neural_net, current_logs = self.step_fn( - self.state_neural_net, - train_batch, - valid_batch, - is_logging_step, - step, - ) - - if is_logging_step: - for log_key in current_logs: - for metric_key in current_logs[log_key]: - logs[log_key][metric_key].append( - current_logs[log_key][metric_key] - ) - if not isinstance(tbar, range): - reg_msg = ( - "NA" - if current_logs["eval"]["regularizer"] == 0.0 - else f"{current_logs['eval']['regularizer']:.4f}" - ) - postfix_str = ( - f"fitting_loss:" - f" {current_logs['eval']['fitting_loss']:.4f}, " - f"regularizer: {reg_msg} ," - f"total: {current_logs['eval']['total_loss']:.4f}" - ) - tbar.set_postfix_str(postfix_str) - - return self.state_neural_net, logs - - def _get_step_fn(self) -> Callable: - """Create a one-step training and evaluation function.""" - - def loss_fn( - params: frozen_dict.FrozenDict, - apply_fn: Callable, - batch: Dict[str, jnp.ndarray], - step: int, - ) -> Tuple[float, Dict[str, float]]: - """Loss function with conditional map and regularizer.""" - # Apply the conditional map: T(source, condition) - mapped_samples = apply_fn( - {"params": params}, batch["source"], batch["condition"] - ) - - # Fitting loss: Δ(T(x,c), y) - val_fitting_loss, log_fitting_loss = self.fitting_loss( - mapped_samples, batch["target"] - ) - - # Conditional regularizer: R(x, T(x,c), labels) - val_regularizer, log_regularizer = self.regularizer( - batch["source"], mapped_samples, batch["condition_labels"] - ) - - val_tot_loss = ( - val_fitting_loss - + self.regularizer_strength[step] * val_regularizer - ) - - loss_logs = { - "total_loss": val_tot_loss, - 
"fitting_loss": val_fitting_loss, - "regularizer": val_regularizer, - "log_regularizer": log_regularizer, - "log_fitting": log_fitting_loss, - } - - return val_tot_loss, loss_logs - - @functools.partial(jax.jit, static_argnums=3) - def step_fn( - state_neural_net: train_state.TrainState, - train_batch: Dict[str, jnp.ndarray], - valid_batch: Optional[Dict[str, jnp.ndarray]] = None, - is_logging_step: bool = False, - step: int = 0, - ) -> Tuple[train_state.TrainState, Dict[str, float]]: - """One step function.""" - grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True) - (_, current_train_logs), grads = grad_fn( - state_neural_net.params, - state_neural_net.apply_fn, - train_batch, - step, - ) - - current_logs = {"train": current_train_logs, "eval": {}} - if is_logging_step: - _, current_eval_logs = loss_fn( - params=state_neural_net.params, - apply_fn=state_neural_net.apply_fn, - batch=valid_batch, - step=step, - ) - current_logs["eval"] = current_eval_logs - - return state_neural_net.apply_gradients(grads=grads), current_logs - - return step_fn + if self._fitting_loss is not None: + return self._fitting_loss + return lambda *_, **__: (0.0, None) + + @staticmethod + def _generate_batch( + loader_source: Iterator[jnp.ndarray], + loader_target: Iterator[jnp.ndarray], + loader_condition: Iterator[jnp.ndarray], + loader_label: Iterator[jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Generate a batch of samples from all four iterators.""" + return { + "source": next(loader_source), + "target": next(loader_target), + "condition": next(loader_condition), + "condition_labels": next(loader_label), + } + + def train_map_estimator( + self, + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + trainloader_condition: Iterator[jnp.ndarray], + trainloader_label: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], + validloader_condition: Iterator[jnp.ndarray], + validloader_label: Iterator[jnp.ndarray], + ) -> Tuple[train_state.TrainState, Dict[str, Any]]: + """Training loop.""" + logs = collections.defaultdict(lambda: collections.defaultdict(list)) + + try: + from tqdm import trange + + tbar = trange(self.num_train_iters, leave=True) + except ImportError: + tbar = range(self.num_train_iters) + + for step in tbar: + is_logging_step = self.logging and ((step % self.valid_freq == 0) or + (step == self.num_train_iters - 1)) + train_batch = self._generate_batch( + trainloader_source, + trainloader_target, + trainloader_condition, + trainloader_label, + ) + valid_batch = ( + None if not is_logging_step else self._generate_batch( + validloader_source, + validloader_target, + validloader_condition, + validloader_label, + ) + ) + self.state_neural_net, current_logs = self.step_fn( + self.state_neural_net, + train_batch, + valid_batch, + is_logging_step, + step, + ) + + if is_logging_step: + for log_key in current_logs: + for metric_key in current_logs[log_key]: + logs[log_key][metric_key].append(current_logs[log_key][metric_key]) + if not isinstance(tbar, range): + reg_msg = ( + "NA" if current_logs["eval"]["regularizer"] == 0.0 else + f"{current_logs['eval']['regularizer']:.4f}" + ) + postfix_str = ( + f"fitting_loss:" + f" {current_logs['eval']['fitting_loss']:.4f}, " + f"regularizer: {reg_msg} ," + f"total: {current_logs['eval']['total_loss']:.4f}" + ) + tbar.set_postfix_str(postfix_str) + + return self.state_neural_net, logs + + def _get_step_fn(self) -> Callable: + """Create a one-step training and 
evaluation function.""" + + def loss_fn( + params: frozen_dict.FrozenDict, + apply_fn: Callable, + batch: Dict[str, jnp.ndarray], + step: int, + ) -> Tuple[float, Dict[str, float]]: + """Loss function with conditional map and regularizer.""" + # Apply the conditional map: T(source, condition) + mapped_samples = apply_fn({"params": params}, batch["source"], + batch["condition"]) + + # Fitting loss: Δ(T(x,c), y) + val_fitting_loss, log_fitting_loss = self.fitting_loss( + mapped_samples, batch["target"] + ) + + # Conditional regularizer: R(x, T(x,c), labels) + val_regularizer, log_regularizer = self.regularizer( + batch["source"], mapped_samples, batch["condition_labels"] + ) + + val_tot_loss = ( + val_fitting_loss + self.regularizer_strength[step] * val_regularizer + ) + + loss_logs = { + "total_loss": val_tot_loss, + "fitting_loss": val_fitting_loss, + "regularizer": val_regularizer, + "log_regularizer": log_regularizer, + "log_fitting": log_fitting_loss, + } + + return val_tot_loss, loss_logs + + @functools.partial(jax.jit, static_argnums=3) + def step_fn( + state_neural_net: train_state.TrainState, + train_batch: Dict[str, jnp.ndarray], + valid_batch: Optional[Dict[str, jnp.ndarray]] = None, + is_logging_step: bool = False, + step: int = 0, + ) -> Tuple[train_state.TrainState, Dict[str, float]]: + """One step function.""" + grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True) + (_, current_train_logs), grads = grad_fn( + state_neural_net.params, + state_neural_net.apply_fn, + train_batch, + step, + ) + + current_logs = {"train": current_train_logs, "eval": {}} + if is_logging_step: + _, current_eval_logs = loss_fn( + params=state_neural_net.params, + apply_fn=state_neural_net.apply_fn, + batch=valid_batch, + step=step, + ) + current_logs["eval"] = current_eval_logs + + return state_neural_net.apply_gradients(grads=grads), current_logs + + return step_fn diff --git a/src/ott/neural/networks/__init__.py b/src/ott/neural/networks/__init__.py index 0abb298a0..568a02e37 100644 --- a/src/ott/neural/networks/__init__.py +++ b/src/ott/neural/networks/__init__.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import ( - conditional_perturbation_network, - icnn, - layers, - potentials, -) +from . 
import conditional_perturbation_network, icnn, layers, potentials __all__ = ["conditional_perturbation_network", "icnn", "layers", "potentials"] diff --git a/src/ott/neural/networks/conditional_perturbation_network.py b/src/ott/neural/networks/conditional_perturbation_network.py index e622f4bfa..76593b35c 100644 --- a/src/ott/neural/networks/conditional_perturbation_network.py +++ b/src/ott/neural/networks/conditional_perturbation_network.py @@ -9,40 +9,43 @@ Union, ) -import flax.linen as nn import jax.numpy as jnp + +import flax.linen as nn import optax -from ott.neural.networks.potentials import ( - BasePotential, - PotentialTrainState, -) + +from ott.neural.networks.potentials import BasePotential, PotentialTrainState class ConditionalPerturbationNetwork(BasePotential): - dim_hidden: Sequence[int] = None - dim_data: int = None - dim_cond: int = None # Full dimension of all context variables concatenated - # Same length as context_entity_bonds if embed_cond_equal is False - # (if True, first item is size of deep set layer, rest is ignored) - dim_cond_map: Iterable[int] = (50,) - act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu - is_potential: bool = False - layer_norm: bool = False - embed_cond_equal: bool = ( - False # Whether all context variables should be treated as set or not - ) - context_entity_bonds: Iterable[Tuple[int, int]] = ( - (0, 10), - (10, 20), - ) # (start, stop) slicing bounds per context modality in c; - # should be contiguous, non-overlapping by default. - num_contexts: int = 2 - - @nn.compact - def __call__( - self, x: jnp.ndarray, c: Optional[jnp.ndarray] = None - ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: # noqa: D102 - """Forward pass: map (x, c) -> x + residual. + """Condition-aware perturbation network for OT maps.""" + + dim_hidden: Sequence[int] = None + dim_data: int = None + dim_cond: int = None # Full dimension of all context variables concatenated + # Same length as context_entity_bonds if embed_cond_equal is False + # (if True, first item is size of deep set layer, rest is ignored) + dim_cond_map: Iterable[int] = (50,) + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu + is_potential: bool = False + layer_norm: bool = False + embed_cond_equal: bool = ( + False # Whether all context variables should be treated as set or not + ) + context_entity_bonds: Iterable[Tuple[int, int]] = ( + (0, 10), + (10, 20), + ) # (start, stop) slicing bounds per context modality in c; + # should be contiguous, non-overlapping by default. + num_contexts: int = 2 + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + c: Optional[jnp.ndarray] = None + ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: # noqa: D102 + """Forward pass: map (x, c) -> x + residual. Args: x: Input data of shape ``(batch, dim_data)``. @@ -56,86 +59,84 @@ def __call__( Returns: Mapped output of shape ``(batch, dim_data)``. 
""" - return_batch = False - if isinstance(x, dict): - c = x["c"] - x = x["X"] - return_batch = True - - n_input = x.shape[-1] - - # Chunk the inputs - contexts = [ - c[:, e[0] : e[1]] - for i, e in enumerate(self.context_entity_bonds) - if i < self.num_contexts - ] - - if not self.embed_cond_equal: - # Each context is processed by a different layer, - # good for combining modalities - assert len(self.context_entity_bonds) == len(self.dim_cond_map), ( - "Length of context entity bonds and context map sizes have to " - f"match: {self.context_entity_bonds} != {self.dim_cond_map}" - ) - - layers = [ - nn.Dense(self.dim_cond_map[i], use_bias=True) - for i in range(len(contexts)) - ] - embeddings = [ - self.act_fn(layers[i](context)) - for i, context in enumerate(contexts) - ] - cond_embedding = jnp.concatenate(embeddings, axis=1) - else: - # We can use any number of contexts from the same modality, - # via a permutation-invariant deep set layer. - sizes = [c.shape[-1] for c in contexts] - if not len(set(sizes)) == 1: - raise ValueError( - "For embedding a set, all contexts need same length ," - f"not {sizes}" - ) - layer = nn.Dense(self.dim_cond_map[0], use_bias=True) - embeddings = [self.act_fn(layer(context)) for context in contexts] - # Average along stacked dimension - # (alternatives like summing are possible) - cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0) - - z = jnp.concatenate((x, cond_embedding), axis=1) - if self.layer_norm: - n = nn.LayerNorm() - z = n(z) - - for n_hidden in self.dim_hidden: - wx = nn.Dense(n_hidden, use_bias=True) - z = self.act_fn(wx(z)) - wx = nn.Dense(n_input, use_bias=True) - - y = x + wx(z) - - if return_batch: - return {"X": y, "c": c} - else: - return y - - def create_train_state( - self, - rng: jnp.ndarray, - optimizer: optax.OptState, - dim_data: int, - **kwargs: Any, - ) -> PotentialTrainState: - """Create initial `TrainState`.""" - c = jnp.ones((1, self.dim_cond)) # (n_batch, embed_dim) - x = jnp.ones((1, dim_data)) # (n_batch, data_dim) - params = self.init(rng, x=x, c=c)["params"] - return PotentialTrainState.create( - apply_fn=self.apply, - params=params, - tx=optimizer, - potential_value_fn=self.potential_value_fn, - potential_gradient_fn=self.potential_gradient_fn, - **kwargs, + return_batch = False + if isinstance(x, dict): + c = x["c"] + x = x["X"] + return_batch = True + + n_input = x.shape[-1] + + # Chunk the inputs + contexts = [ + c[:, e[0]:e[1]] + for i, e in enumerate(self.context_entity_bonds) + if i < self.num_contexts + ] + + if not self.embed_cond_equal: + # Each context is processed by a different layer, + # good for combining modalities + assert len(self.context_entity_bonds) == len(self.dim_cond_map), ( + "Length of context entity bonds and context map sizes have to " + f"match: {self.context_entity_bonds} != {self.dim_cond_map}" + ) + + layers = [ + nn.Dense(self.dim_cond_map[i], use_bias=True) + for i in range(len(contexts)) + ] + embeddings = [ + self.act_fn(layers[i](context)) for i, context in enumerate(contexts) + ] + cond_embedding = jnp.concatenate(embeddings, axis=1) + else: + # We can use any number of contexts from the same modality, + # via a permutation-invariant deep set layer. 
+ sizes = [c.shape[-1] for c in contexts] + if not len(set(sizes)) == 1: + raise ValueError( + "For embedding a set, all contexts need same length ," + f"not {sizes}" ) + layer = nn.Dense(self.dim_cond_map[0], use_bias=True) + embeddings = [self.act_fn(layer(context)) for context in contexts] + # Average along stacked dimension + # (alternatives like summing are possible) + cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0) + + z = jnp.concatenate((x, cond_embedding), axis=1) + if self.layer_norm: + n = nn.LayerNorm() + z = n(z) + + for n_hidden in self.dim_hidden: + wx = nn.Dense(n_hidden, use_bias=True) + z = self.act_fn(wx(z)) + wx = nn.Dense(n_input, use_bias=True) + + y = x + wx(z) + + if return_batch: + return {"X": y, "c": c} + return y + + def create_train_state( + self, + rng: jnp.ndarray, + optimizer: optax.OptState, + dim_data: int, + **kwargs: Any, + ) -> PotentialTrainState: + """Create initial `TrainState`.""" + c = jnp.ones((1, self.dim_cond)) # (n_batch, embed_dim) + x = jnp.ones((1, dim_data)) # (n_batch, data_dim) + params = self.init(rng, x=x, c=c)["params"] + return PotentialTrainState.create( + apply_fn=self.apply, + params=params, + tx=optimizer, + potential_value_fn=self.potential_value_fn, + potential_gradient_fn=self.potential_gradient_fn, + **kwargs, + ) diff --git a/tests/neural/methods/conditional_monge_gap_test.py b/tests/neural/methods/conditional_monge_gap_test.py index dba6dbb32..bf97fda20 100644 --- a/tests/neural/methods/conditional_monge_gap_test.py +++ b/tests/neural/methods/conditional_monge_gap_test.py @@ -31,269 +31,333 @@ from ott.tools import sinkhorn_divergence -@pytest.mark.fast() +@pytest.mark.fast class TestConditionalMongeGap: - @pytest.mark.parametrize("n_samples", [10, 30]) - @pytest.mark.parametrize("n_features", [4, 10]) - @pytest.mark.parametrize("num_conditions", [2, 3]) - def test_non_negativity( - self, rng: jax.Array, n_samples: int, n_features: int, - num_conditions: int, - ): - rng1, rng2 = jax.random.split(rng) - per_cond = n_samples // num_conditions - n = per_cond * num_conditions - - source = jax.random.normal(rng1, (n, n_features)) - target = source + 0.5 * jax.random.normal(rng2, (n, n_features)) - condition = jnp.repeat(jnp.arange(num_conditions), per_cond) - - gap = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=num_conditions, + @pytest.mark.parametrize("n_samples", [10, 30]) + @pytest.mark.parametrize("n_features", [4, 10]) + @pytest.mark.parametrize("num_conditions", [2, 3]) + def test_non_negativity( + self, + rng: jax.Array, + n_samples: int, + n_features: int, + num_conditions: int, + ): + rng1, rng2 = jax.random.split(rng) + per_cond = n_samples // num_conditions + n = per_cond * num_conditions + + source = jax.random.normal(rng1, (n, n_features)) + target = source + 0.5 * jax.random.normal(rng2, (n, n_features)) + condition = jnp.repeat(jnp.arange(num_conditions), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=num_conditions, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(gap >= 0, True) + + def test_jit_consistency(self, rng: jax.Array): + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + eager_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + 
max_measure_size=per_cond, + ) + jit_gap = jax.jit( + lambda s, t, c: conditional_monge_gap.cmonge_gap_from_samples( + s, + t, + c, + num_segments=k, max_measure_size=per_cond, ) - np.testing.assert_array_equal(gap >= 0, True) - - def test_jit_consistency(self, rng: jax.Array): - n, d, k = 60, 4, 3 - per_cond = n // k - rng1, rng2 = jax.random.split(rng) - source = jax.random.normal(rng1, (n, d)) - target = source + 0.1 * jax.random.normal(rng2, (n, d)) - condition = jnp.repeat(jnp.arange(k), per_cond) - - eager_gap = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=k, max_measure_size=per_cond, - ) - jit_gap = jax.jit( - lambda s, t, c: conditional_monge_gap.cmonge_gap_from_samples( - s, t, c, num_segments=k, max_measure_size=per_cond, - ) - )(source, target, condition) - - np.testing.assert_allclose(eager_gap, jit_gap, rtol=1e-3) - - def test_matches_loop_baseline(self, rng: jax.Array): - """Segment-based result matches manual per-condition loop.""" - n, d, k = 60, 4, 3 - per_cond = n // k - rng1, rng2 = jax.random.split(rng) - source = jax.random.normal(rng1, (n, d)) - target = source + 0.1 * jax.random.normal(rng2, (n, d)) - condition = jnp.repeat(jnp.arange(k), per_cond) - - new_gap = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=k, max_measure_size=per_cond, - ) - - # Manual loop (the old approach) - manual_gaps = [] - for c in range(k): - mask = condition == c - gap = monge_gap_from_samples(source[mask], target[mask]) - manual_gaps.append(float(gap)) - manual_avg = sum(manual_gaps) / len(manual_gaps) - - np.testing.assert_allclose(float(new_gap), manual_avg, atol=1e-5) - - def test_identity_smaller_than_random(self, rng: jax.Array): - """Identity map should have smaller Monge gap than a random map.""" - n, d, k = 60, 4, 3 - per_cond = n // k - rng1, rng2 = jax.random.split(rng) - source = jax.random.normal(rng1, (n, d)) - condition = jnp.repeat(jnp.arange(k), per_cond) - - identity_gap = conditional_monge_gap.cmonge_gap_from_samples( - source, source, condition, - num_segments=k, max_measure_size=per_cond, - ) - random_target = jax.random.normal(rng2, (n, d)) * 3.0 - random_gap = conditional_monge_gap.cmonge_gap_from_samples( - source, random_target, condition, - num_segments=k, max_measure_size=per_cond, - ) - assert identity_gap < random_gap - - @pytest.mark.parametrize("cost_fn", [ - costs.SqEuclidean(), - costs.PNormP(p=1), - ], ids=["sqeucl", "pnorm-1"]) - def test_different_costs(self, rng: jax.Array, cost_fn: costs.CostFn): - n, d, k = 30, 4, 3 - per_cond = n // k - rng1, rng2 = jax.random.split(rng) - source = jax.random.normal(rng1, (n, d)) - target = source + jax.random.normal(rng2, (n, d)) * 0.5 - condition = jnp.repeat(jnp.arange(k), per_cond) - - gap = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, cost_fn=cost_fn, - num_segments=k, max_measure_size=per_cond, - ) - np.testing.assert_array_equal(jnp.isfinite(gap), True) - np.testing.assert_array_equal(gap >= 0, True) - - def test_return_output_shape(self, rng: jax.Array): - n, d, k = 60, 4, 3 - per_cond = n // k - rng1, rng2 = jax.random.split(rng) - source = jax.random.normal(rng1, (n, d)) - target = source + 0.1 * jax.random.normal(rng2, (n, d)) - condition = jnp.repeat(jnp.arange(k), per_cond) - - result = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=k, max_measure_size=per_cond, - return_output=True, - ) - assert isinstance(result, tuple) - avg_gap, 
per_cond_gaps = result - assert per_cond_gaps.shape == (k,) - np.testing.assert_allclose( - float(avg_gap), float(jnp.mean(per_cond_gaps)), rtol=1e-5, - ) - - - @pytest.mark.parametrize("n_samples", [10, 30]) - @pytest.mark.parametrize("n_features", [4, 10]) - def test_non_negativity_neural_map( - self, rng: jax.Array, n_samples: int, n_features: int, - ): - """Non-negativity with a learned nonlinear map (mirrors monge_gap_test).""" - k = 2 - per_cond = n_samples // k - n = per_cond * k - rng1, rng2 = jax.random.split(rng) - - source = jax.random.normal(rng1, (n, n_features)) - model = potentials.PotentialMLP(dim_hidden=[8, 8], is_potential=False) - params = model.init(rng2, x=source[0]) - target = model.apply(params, source) - condition = jnp.repeat(jnp.arange(k), per_cond) - - gap = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=k, max_measure_size=per_cond, - ) - np.testing.assert_array_equal(jnp.isfinite(gap), True) - np.testing.assert_array_equal(gap >= 0, True) - - @pytest.mark.parametrize("cost_fn", [ - costs.PNormP(p=1), - costs.RegTICost(regularizers.L1(), lam=2.0), - costs.RegTICost(regularizers.STVS(gamma=3.0), lam=1.0), - ], ids=["pnorm-1", "l1-lam2", "stvs-lam1"]) - def test_different_costs_give_different_values( - self, rng: jax.Array, cost_fn: costs.CostFn, - ): - """Non-Euclidean costs produce different cmonge_gap than Euclidean.""" - n, d, k = 30, 5, 3 - per_cond = n // k - rng1, rng2 = jax.random.split(rng) - source = jax.random.normal(rng1, (n, d)) - target = jax.random.normal(rng2, (n, d)) * 0.1 + 3.0 - condition = jnp.repeat(jnp.arange(k), per_cond) - - gap_eucl = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, cost_fn=costs.Euclidean(), - num_segments=k, max_measure_size=per_cond, - ) - gap_other = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, cost_fn=cost_fn, - num_segments=k, max_measure_size=per_cond, - ) - - with pytest.raises(AssertionError, match=r"tolerance"): - np.testing.assert_allclose( - gap_eucl, gap_other, rtol=1e-1, atol=1e-1, - ) - np.testing.assert_array_equal(jnp.isfinite(gap_eucl), True) - np.testing.assert_array_equal(jnp.isfinite(gap_other), True) - - def test_uniform_conditions_equals_averaged_monge_gap( - self, rng: jax.Array, - ): - """cmonge_gap with equal-size conditions == mean of monge_gap calls.""" - k = 3 - per_cond = 20 - d = 5 - n = k * per_cond - - # Different offsets per condition so gaps are distinct - offsets = jnp.array([0.1, 1.0, 3.0]) - rngs = jax.random.split(rng, 2 * k) - sources, targets = [], [] - for c in range(k): - s = jax.random.normal(rngs[2 * c], (per_cond, d)) - t = s + offsets[c] + 0.05 * jax.random.normal( - rngs[2 * c + 1], (per_cond, d) - ) - sources.append(s) - targets.append(t) - - source = jnp.concatenate(sources, axis=0) - target = jnp.concatenate(targets, axis=0) - condition = jnp.repeat(jnp.arange(k), per_cond) - - # Segmented cmonge_gap (single call, vmapped) - t0 = time.perf_counter() - avg_gap, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=k, max_measure_size=per_cond, - return_output=True, - ) - # Force computation to complete before timing - avg_gap.block_until_ready() - t_cmonge = time.perf_counter() - t0 - - # Manual per-condition monge_gap calls (K sequential calls) - t0 = time.perf_counter() - manual_gaps = [] - for c in range(k): - gap_c = monge_gap_from_samples(sources[c], targets[c]) - manual_gaps.append(float(gap_c)) - manual_avg = 
sum(manual_gaps) / k - t_loop = time.perf_counter() - t0 - - # Single-condition overhead: cmonge_gap(K=1) vs monge_gap - t0 = time.perf_counter() - gap_single_cmonge = conditional_monge_gap.cmonge_gap_from_samples( - sources[0], targets[0], - jnp.zeros(per_cond, dtype=jnp.int32), - num_segments=1, max_measure_size=per_cond, - ) - gap_single_cmonge.block_until_ready() - t_cmonge_1 = time.perf_counter() - t0 - - t0 = time.perf_counter() - gap_single_monge = monge_gap_from_samples(sources[0], targets[0]) - float(gap_single_monge) # block - t_monge_1 = time.perf_counter() - t0 - - print( - f"\n K={k}: cmonge_gap: {t_cmonge:.3f}s | " - f"loop({k}x monge_gap): {t_loop:.3f}s | " - f"speedup: {t_loop / t_cmonge:.1f}x" - f"\n K=1: cmonge_gap: {t_cmonge_1:.3f}s | " - f"monge_gap: {t_monge_1:.3f}s | " - f"overhead: {t_cmonge_1 / t_monge_1:.1f}x" - ) - - # Average should match - np.testing.assert_allclose(float(avg_gap), manual_avg, atol=1e-5) - # Per-condition gaps should match individual calls - for c in range(k): - np.testing.assert_allclose( - float(per_cond_gaps[c]), manual_gaps[c], atol=1e-5, - ) - - def test_unequal_conditions_shifts_average(self, rng: jax.Array): - """With unequal n_k, per-condition gaps change and shift the average. + )(source, target, condition) + + np.testing.assert_allclose(eager_gap, jit_gap, rtol=1e-3) + + def test_matches_loop_baseline(self, rng: jax.Array): + """Segment-based result matches manual per-condition loop.""" + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + new_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + + # Manual loop (the old approach) + manual_gaps = [] + for c in range(k): + mask = condition == c + gap = monge_gap_from_samples(source[mask], target[mask]) + manual_gaps.append(float(gap)) + manual_avg = sum(manual_gaps) / len(manual_gaps) + + np.testing.assert_allclose(float(new_gap), manual_avg, atol=1e-5) + + def test_identity_smaller_than_random(self, rng: jax.Array): + """Identity map should have smaller Monge gap than a random map.""" + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + identity_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + source, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + random_target = jax.random.normal(rng2, (n, d)) * 3.0 + random_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + random_target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + assert identity_gap < random_gap + + @pytest.mark.parametrize( + "cost_fn", + [ + costs.SqEuclidean(), + costs.PNormP(p=1), + ], + ids=["sqeucl", "pnorm-1"], + ) + def test_different_costs(self, rng: jax.Array, cost_fn: costs.CostFn): + n, d, k = 30, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + jax.random.normal(rng2, (n, d)) * 0.5 + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + cost_fn=cost_fn, + num_segments=k, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(jnp.isfinite(gap), True) + np.testing.assert_array_equal(gap >= 0, True) + + 
def test_return_output_shape(self, rng: jax.Array): + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + result = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + return_output=True, + ) + assert isinstance(result, tuple) + avg_gap, per_cond_gaps = result + assert per_cond_gaps.shape == (k,) + np.testing.assert_allclose( + float(avg_gap), + float(jnp.mean(per_cond_gaps)), + rtol=1e-5, + ) + + @pytest.mark.parametrize("n_samples", [10, 30]) + @pytest.mark.parametrize("n_features", [4, 10]) + def test_non_negativity_neural_map( + self, + rng: jax.Array, + n_samples: int, + n_features: int, + ): + """Non-negativity with a learned nonlinear map.""" + k = 2 + per_cond = n_samples // k + n = per_cond * k + rng1, rng2 = jax.random.split(rng) + + source = jax.random.normal(rng1, (n, n_features)) + model = potentials.PotentialMLP(dim_hidden=[8, 8], is_potential=False) + params = model.init(rng2, x=source[0]) + target = model.apply(params, source) + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(jnp.isfinite(gap), True) + np.testing.assert_array_equal(gap >= 0, True) + + @pytest.mark.parametrize( + "cost_fn", + [ + costs.PNormP(p=1), + costs.RegTICost(regularizers.L1(), lam=2.0), + costs.RegTICost(regularizers.STVS(gamma=3.0), lam=1.0), + ], + ids=["pnorm-1", "l1-lam2", "stvs-lam1"], + ) + def test_different_costs_give_different_values( + self, + rng: jax.Array, + cost_fn: costs.CostFn, + ): + """Non-Euclidean costs produce different cmonge_gap than Euclidean.""" + n, d, k = 30, 5, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = jax.random.normal(rng2, (n, d)) * 0.1 + 3.0 + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap_eucl = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + cost_fn=costs.Euclidean(), + num_segments=k, + max_measure_size=per_cond, + ) + gap_other = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + cost_fn=cost_fn, + num_segments=k, + max_measure_size=per_cond, + ) + + with pytest.raises(AssertionError, match=r"tolerance"): + np.testing.assert_allclose( + gap_eucl, + gap_other, + rtol=1e-1, + atol=1e-1, + ) + np.testing.assert_array_equal(jnp.isfinite(gap_eucl), True) + np.testing.assert_array_equal(jnp.isfinite(gap_other), True) + + def test_uniform_conditions_equals_averaged_monge_gap( + self, + rng: jax.Array, + ): + """cmonge_gap with equal-size conditions == mean of monge_gap calls.""" + k = 3 + per_cond = 20 + d = 5 + + # Different offsets per condition so gaps are distinct + offsets = jnp.array([0.1, 1.0, 3.0]) + rngs = jax.random.split(rng, 2 * k) + sources, targets = [], [] + for c in range(k): + s = jax.random.normal(rngs[2 * c], (per_cond, d)) + t = ( + s + offsets[c] + + 0.05 * jax.random.normal(rngs[2 * c + 1], (per_cond, d)) + ) + sources.append(s) + targets.append(t) + + source = jnp.concatenate(sources, axis=0) + target = jnp.concatenate(targets, axis=0) + condition = jnp.repeat(jnp.arange(k), per_cond) + + # Segmented cmonge_gap (single call, vmapped) + t0 = time.perf_counter() + avg_gap, per_cond_gaps = 
conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + return_output=True, + ) + # Force computation to complete before timing + avg_gap.block_until_ready() + t_cmonge = time.perf_counter() - t0 + + # Manual per-condition monge_gap calls (K sequential calls) + t0 = time.perf_counter() + manual_gaps = [] + for c in range(k): + gap_c = monge_gap_from_samples(sources[c], targets[c]) + manual_gaps.append(float(gap_c)) + manual_avg = sum(manual_gaps) / k + t_loop = time.perf_counter() - t0 + + # Single-condition overhead: cmonge_gap(K=1) vs monge_gap + t0 = time.perf_counter() + gap_single_cmonge = conditional_monge_gap.cmonge_gap_from_samples( + sources[0], + targets[0], + jnp.zeros(per_cond, dtype=jnp.int32), + num_segments=1, + max_measure_size=per_cond, + ) + gap_single_cmonge.block_until_ready() + t_cmonge_1 = time.perf_counter() - t0 + + t0 = time.perf_counter() + gap_single_monge = monge_gap_from_samples(sources[0], targets[0]) + float(gap_single_monge) # block + t_monge_1 = time.perf_counter() - t0 + + print( # noqa: T201 + f"\n K={k}: cmonge_gap: {t_cmonge:.3f}s | " + f"loop({k}x monge_gap): {t_loop:.3f}s | " + f"speedup: {t_loop / t_cmonge:.1f}x" + f"\n K=1: cmonge_gap: {t_cmonge_1:.3f}s | " + f"monge_gap: {t_monge_1:.3f}s | " + f"overhead: {t_cmonge_1 / t_monge_1:.1f}x" + ) + + # Average should match + np.testing.assert_allclose(float(avg_gap), manual_avg, atol=1e-5) + # Per-condition gaps should match individual calls + for c in range(k): + np.testing.assert_allclose( + float(per_cond_gaps[c]), + manual_gaps[c], + atol=1e-5, + ) + + def test_unequal_conditions_shifts_average(self, rng: jax.Array): + """With unequal n_k, per-condition gaps change and shift the average. The segment interface pads all conditions to max_measure_size, so per-condition gaps with padding do NOT exactly match non-padded @@ -301,211 +365,232 @@ def test_unequal_conditions_shifts_average(self, rng: jax.Array): structural properties instead: gaps are finite, easy < hard, average = mean(per_cond_gaps), and the average shifts when n_k changes. 
""" - d = 5 - rng_easy, rng_hard, rng_noise = jax.random.split(rng, 3) - - base_easy = jax.random.normal(rng_easy, (60, d)) - base_hard = jax.random.normal(rng_hard, (60, d)) - noise = 0.01 * jax.random.normal(rng_noise, (60, d)) - - target_easy = base_easy + noise - target_hard = base_hard + 5.0 - - # (a) Equal sizes: 30/30 - n_eq = 30 - src_eq = jnp.concatenate([base_easy[:n_eq], base_hard[:n_eq]]) - tgt_eq = jnp.concatenate([target_easy[:n_eq], target_hard[:n_eq]]) - cond_eq = jnp.repeat(jnp.arange(2), n_eq) - - avg_eq, gaps_eq = conditional_monge_gap.cmonge_gap_from_samples( - src_eq, tgt_eq, cond_eq, - num_segments=2, max_measure_size=n_eq, - return_output=True, - ) - - # (b) Unequal sizes: 50 easy / 10 hard - n_a, n_b = 50, 10 - src_uneq = jnp.concatenate([base_easy[:n_a], base_hard[:n_b]]) - tgt_uneq = jnp.concatenate([target_easy[:n_a], target_hard[:n_b]]) - cond_uneq = jnp.concatenate([ - jnp.zeros(n_a, dtype=jnp.int32), - jnp.ones(n_b, dtype=jnp.int32), - ]) - - avg_uneq, gaps_uneq = conditional_monge_gap.cmonge_gap_from_samples( - src_uneq, tgt_uneq, cond_uneq, - num_segments=2, max_measure_size=n_a, - return_output=True, - ) - - # All gaps are finite and non-negative - for gaps in [gaps_eq, gaps_uneq]: - np.testing.assert_array_equal(jnp.all(jnp.isfinite(gaps)), True) - np.testing.assert_array_equal(jnp.all(gaps >= 0), True) - - # Easy condition has smaller gap than hard condition - assert gaps_eq[0] < gaps_eq[1] - assert gaps_uneq[0] < gaps_uneq[1] - - # Average is the mean of per-condition gaps - np.testing.assert_allclose( - float(avg_eq), float(jnp.mean(gaps_eq)), rtol=1e-5, - ) - np.testing.assert_allclose( - float(avg_uneq), float(jnp.mean(gaps_uneq)), rtol=1e-5, - ) - - # Averages differ between equal and unequal splits (n_k affects - # the padded OT cost estimation, shifting per-condition gaps) - assert float(avg_eq) != float(avg_uneq) - - def test_per_condition_gaps_reflect_difficulty(self, rng: jax.Array): - """Per-condition gaps increase with offset magnitude.""" - k = 3 - per_cond = 25 - d = 4 - offsets = jnp.array([0.0, 1.5, 5.0]) - - rngs = jax.random.split(rng, 2 * k) - sources, targets = [], [] - for c in range(k): - s = jax.random.normal(rngs[2 * c], (per_cond, d)) - t = s + offsets[c] - sources.append(s) - targets.append(t) - - source = jnp.concatenate(sources, axis=0) - target = jnp.concatenate(targets, axis=0) - condition = jnp.repeat(jnp.arange(k), per_cond) - - _, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( - source, target, condition, - num_segments=k, max_measure_size=per_cond, - return_output=True, - ) - - assert per_cond_gaps[0] < per_cond_gaps[1] < per_cond_gaps[2] - - -@pytest.mark.fast() + d = 5 + rng_easy, rng_hard, rng_noise = jax.random.split(rng, 3) + + base_easy = jax.random.normal(rng_easy, (60, d)) + base_hard = jax.random.normal(rng_hard, (60, d)) + noise = 0.01 * jax.random.normal(rng_noise, (60, d)) + + target_easy = base_easy + noise + target_hard = base_hard + 5.0 + + # (a) Equal sizes: 30/30 + n_eq = 30 + src_eq = jnp.concatenate([base_easy[:n_eq], base_hard[:n_eq]]) + tgt_eq = jnp.concatenate([target_easy[:n_eq], target_hard[:n_eq]]) + cond_eq = jnp.repeat(jnp.arange(2), n_eq) + + avg_eq, gaps_eq = conditional_monge_gap.cmonge_gap_from_samples( + src_eq, + tgt_eq, + cond_eq, + num_segments=2, + max_measure_size=n_eq, + return_output=True, + ) + + # (b) Unequal sizes: 50 easy / 10 hard + n_a, n_b = 50, 10 + src_uneq = jnp.concatenate([base_easy[:n_a], base_hard[:n_b]]) + tgt_uneq = jnp.concatenate([target_easy[:n_a], 
target_hard[:n_b]]) + cond_uneq = jnp.concatenate([ + jnp.zeros(n_a, dtype=jnp.int32), + jnp.ones(n_b, dtype=jnp.int32), + ]) + + avg_uneq, gaps_uneq = conditional_monge_gap.cmonge_gap_from_samples( + src_uneq, + tgt_uneq, + cond_uneq, + num_segments=2, + max_measure_size=n_a, + return_output=True, + ) + + # All gaps are finite and non-negative + for gaps in [gaps_eq, gaps_uneq]: + np.testing.assert_array_equal(jnp.all(jnp.isfinite(gaps)), True) + np.testing.assert_array_equal(jnp.all(gaps >= 0), True) + + # Easy condition has smaller gap than hard condition + assert gaps_eq[0] < gaps_eq[1] + assert gaps_uneq[0] < gaps_uneq[1] + + # Average is the mean of per-condition gaps + np.testing.assert_allclose( + float(avg_eq), + float(jnp.mean(gaps_eq)), + rtol=1e-5, + ) + np.testing.assert_allclose( + float(avg_uneq), + float(jnp.mean(gaps_uneq)), + rtol=1e-5, + ) + + # Averages differ between equal and unequal splits (n_k affects + # the padded OT cost estimation, shifting per-condition gaps) + assert float(avg_eq) != float(avg_uneq) + + def test_per_condition_gaps_reflect_difficulty(self, rng: jax.Array): + """Per-condition gaps increase with offset magnitude.""" + k = 3 + per_cond = 25 + d = 4 + offsets = jnp.array([0.0, 1.5, 5.0]) + + rngs = jax.random.split(rng, 2 * k) + sources, targets = [], [] + for c in range(k): + s = jax.random.normal(rngs[2 * c], (per_cond, d)) + t = s + offsets[c] + sources.append(s) + targets.append(t) + + source = jnp.concatenate(sources, axis=0) + target = jnp.concatenate(targets, axis=0) + condition = jnp.repeat(jnp.arange(k), per_cond) + + _, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + return_output=True, + ) + + assert per_cond_gaps[0] < per_cond_gaps[1] < per_cond_gaps[2] + + +@pytest.mark.fast class TestConditionalMongeGapEstimator: - def test_estimator_convergence(self): - """Train a conditional map and verify loss decreases.""" - num_conditions = 3 - dim_data = 2 - dim_cond = num_conditions # one-hot - batch_size = 30 - - train_ds, valid_ds, _, n_cond, max_ms = ( - datasets.create_conditional_gaussian_mixture_samplers( - num_conditions=num_conditions, - dim=dim_data, - train_batch_size=batch_size, - valid_batch_size=batch_size, - ) - ) - - def fitting_loss(mapped, target): - div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) - return div, None - - def regularizer(source, mapped, labels): - gap, per_cond = conditional_monge_gap.cmonge_gap_from_samples( - source, mapped, labels, - num_segments=n_cond, - max_measure_size=max_ms, - return_output=True, - ) - return gap, None - - model = ConditionalPerturbationNetwork( - dim_hidden=[16, 8], - dim_data=dim_data, - dim_cond=dim_cond, - dim_cond_map=(16,), - is_potential=False, - context_entity_bonds=((0, dim_cond),), - num_contexts=1, - ) - - solver = conditional_monge_gap.ConditionalMongeGapEstimator( - dim_data=dim_data, - fitting_loss=fitting_loss, - regularizer=regularizer, - model=model, - regularizer_strength=1.0, - num_train_iters=15, - logging=True, - valid_freq=5, - ) - - neural_state, logs = solver.train_map_estimator( - *train_ds, *valid_ds, - ) - - # Loss should decrease - assert logs["train"]["total_loss"][0] > logs["train"]["total_loss"][-1] - - # Output shape should match input - source_batch = next(train_ds.source_iter) - cond_batch = next(train_ds.condition_iter) - mapped = neural_state.apply_fn( - {"params": neural_state.params}, source_batch, cond_batch, - ) - assert mapped.shape == 
source_batch.shape - np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) - - def test_estimator_no_regularizer(self): - """Training with regularizer_strength=0 still converges.""" - num_conditions = 2 - dim_data = 2 - dim_cond = num_conditions - batch_size = 20 - - train_ds, valid_ds, _, _, _ = ( - datasets.create_conditional_gaussian_mixture_samplers( - num_conditions=num_conditions, - dim=dim_data, - train_batch_size=batch_size, - valid_batch_size=batch_size, - ) - ) - - def fitting_loss(mapped, target): - div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) - return div, None - - model = ConditionalPerturbationNetwork( - dim_hidden=[8, 8], - dim_data=dim_data, - dim_cond=dim_cond, - dim_cond_map=(8,), - is_potential=False, - context_entity_bonds=((0, dim_cond),), - num_contexts=1, - ) - - solver = conditional_monge_gap.ConditionalMongeGapEstimator( - dim_data=dim_data, - fitting_loss=fitting_loss, - model=model, - regularizer_strength=0.0, - num_train_iters=10, - logging=True, - valid_freq=5, + def test_estimator_convergence(self): + """Train a conditional map and verify loss decreases.""" + num_conditions = 3 + dim_data = 2 + dim_cond = num_conditions # one-hot + batch_size = 30 + + train_ds, valid_ds, _, n_cond, max_ms = ( + datasets.create_conditional_gaussian_mixture_samplers( + num_conditions=num_conditions, + dim=dim_data, + train_batch_size=batch_size, + valid_batch_size=batch_size, ) - - neural_state, logs = solver.train_map_estimator( - *train_ds, *valid_ds, - ) - - # Should have run without errors and logged metrics - assert len(logs["train"]["total_loss"]) > 0 - # Mapped output should be finite - source_batch = next(train_ds.source_iter) - cond_batch = next(train_ds.condition_iter) - mapped = neural_state.apply_fn( - {"params": neural_state.params}, source_batch, cond_batch, + ) + + def fitting_loss(mapped, target): + div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) + return div, None + + def regularizer(source, mapped, labels): + gap, per_cond = conditional_monge_gap.cmonge_gap_from_samples( + source, + mapped, + labels, + num_segments=n_cond, + max_measure_size=max_ms, + return_output=True, + ) + return gap, None + + model = ConditionalPerturbationNetwork( + dim_hidden=[16, 8], + dim_data=dim_data, + dim_cond=dim_cond, + dim_cond_map=(16,), + is_potential=False, + context_entity_bonds=((0, dim_cond),), + num_contexts=1, + ) + + solver = conditional_monge_gap.ConditionalMongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + regularizer=regularizer, + model=model, + regularizer_strength=1.0, + num_train_iters=15, + logging=True, + valid_freq=5, + ) + + neural_state, logs = solver.train_map_estimator( + *train_ds, + *valid_ds, + ) + + # Loss should decrease + assert logs["train"]["total_loss"][0] > logs["train"]["total_loss"][-1] + + # Output shape should match input + source_batch = next(train_ds.source_iter) + cond_batch = next(train_ds.condition_iter) + mapped = neural_state.apply_fn( + {"params": neural_state.params}, + source_batch, + cond_batch, + ) + assert mapped.shape == source_batch.shape + np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) + + def test_estimator_no_regularizer(self): + """Training with regularizer_strength=0 still converges.""" + num_conditions = 2 + dim_data = 2 + dim_cond = num_conditions + batch_size = 20 + + train_ds, valid_ds, _, _, _ = ( + datasets.create_conditional_gaussian_mixture_samplers( + num_conditions=num_conditions, + dim=dim_data, + train_batch_size=batch_size, + 
valid_batch_size=batch_size, ) - np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) + ) + + def fitting_loss(mapped, target): + div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) + return div, None + + model = ConditionalPerturbationNetwork( + dim_hidden=[8, 8], + dim_data=dim_data, + dim_cond=dim_cond, + dim_cond_map=(8,), + is_potential=False, + context_entity_bonds=((0, dim_cond),), + num_contexts=1, + ) + + solver = conditional_monge_gap.ConditionalMongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + model=model, + regularizer_strength=0.0, + num_train_iters=10, + logging=True, + valid_freq=5, + ) + + neural_state, logs = solver.train_map_estimator( + *train_ds, + *valid_ds, + ) + + # Should have run without errors and logged metrics + assert len(logs["train"]["total_loss"]) > 0 + # Mapped output should be finite + source_batch = next(train_ds.source_iter) + cond_batch = next(train_ds.condition_iter) + mapped = neural_state.apply_fn( + {"params": neural_state.params}, + source_batch, + cond_batch, + ) + np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) From 4a60ff58ac0c54f23d5baa25b55be8a15a95fe5d Mon Sep 17 00:00:00 2001 From: DhruvaRajwade Date: Wed, 1 Apr 2026 13:25:20 +0200 Subject: [PATCH 6/6] style: fix docstring indentation and pytest mark style for ruff v0.4.10 --- src/ott/datasets.py | 122 +++++++-------- .../neural/methods/conditional_monge_gap.py | 146 +++++++++--------- .../conditional_perturbation_network.py | 24 +-- .../methods/conditional_monge_gap_test.py | 4 +- 4 files changed, 148 insertions(+), 148 deletions(-) diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 05ea03ed7..e19ac7182 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -34,10 +34,10 @@ class Dataset(NamedTuple): r"""Samplers from source and target measures. - Args: - source_iter: loader for the source measure - target_iter: loader for the target measure - """ + Args: + source_iter: loader for the source measure + target_iter: loader for the target measure + """ source_iter: Iterator[jnp.ndarray] target_iter: Iterator[jnp.ndarray] @@ -46,13 +46,13 @@ class Dataset(NamedTuple): class ConditionalDataset(NamedTuple): r"""Samplers from conditional source and target measures. - Args: - source_iter: loader for the source measure, ``[batch, d]`` - target_iter: loader for the target measure, ``[batch, d]`` - condition_iter: loader for condition vectors, - ``[batch, dim_c]`` - label_iter: loader for integer condition labels, ``[batch]`` - """ + Args: + source_iter: loader for the source measure, ``[batch, d]`` + target_iter: loader for the target measure, ``[batch, d]`` + condition_iter: loader for condition vectors, + ``[batch, dim_c]`` + label_iter: loader for integer condition labels, ``[batch]`` + """ source_iter: Iterator[jnp.ndarray] target_iter: Iterator[jnp.ndarray] @@ -64,21 +64,21 @@ class ConditionalDataset(NamedTuple): class GaussianMixture: """A mixture of Gaussians. 
- Args: - name: the name specifying the centers of the mixture components: + Args: + name: the name specifying the centers of the mixture components: - - ``simple`` - data clustered in one center, - - ``circle`` - two-dimensional Gaussians arranged on a circle, - - ``square_five`` - two-dimensional Gaussians on a square with - one Gaussian in the center, and - - ``square_four`` - two-dimensional Gaussians in the corners of a - rectangle + - ``simple`` - data clustered in one center, + - ``circle`` - two-dimensional Gaussians arranged on a circle, + - ``square_five`` - two-dimensional Gaussians on a square with + one Gaussian in the center, and + - ``square_four`` - two-dimensional Gaussians in the corners of a + rectangle - batch_size: batch size of the samples - rng: initial PRNG key - scale: scale of the Gaussian means - std: the standard deviation of the individual Gaussian samples - """ + batch_size: batch size of the samples + rng: initial PRNG key + scale: scale of the Gaussian means + std: the standard deviation of the individual Gaussian samples + """ name: Name_t batch_size: int @@ -115,9 +115,9 @@ def __post_init__(self) -> None: def __iter__(self) -> Iterator[jnp.array]: """Random sample generator from Gaussian mixture. - Returns: - A generator of samples from the Gaussian mixture. - """ + Returns: + A generator of samples from the Gaussian mixture. + """ return self._create_sample_generators() def _create_sample_generators(self) -> Iterator[jnp.array]: @@ -139,16 +139,16 @@ def create_gaussian_mixture_samplers( ) -> Tuple[Dataset, Dataset, int]: """Gaussian samplers. - Args: - name_source: name of the source sampler - name_target: name of the target sampler - train_batch_size: the training batch size - valid_batch_size: the validation batch size - rng: initial PRNG key + Args: + name_source: name of the source sampler + name_target: name of the target sampler + train_batch_size: the training batch size + valid_batch_size: the validation batch size + rng: initial PRNG key - Returns: - The dataset and dimension of the data. - """ + Returns: + The dataset and dimension of the data. + """ rng = utils.default_prng_key(rng) rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) train_dataset = Dataset( @@ -175,17 +175,17 @@ def create_gaussian_mixture_samplers( class ConditionalGaussianMixture: """Conditional Gaussian sampler for testing. - For each condition *k*, draws source ~ N(0, I) and - target ~ source + offset_k. - Condition vectors are one-hot encoded labels. + For each condition *k*, draws source ~ N(0, I) and + target ~ source + offset_k. + Condition vectors are one-hot encoded labels. - Args: - num_conditions: number of distinct conditions. - batch_size: total batch size (divided equally among conditions). - dim: data dimensionality. - offsets: ``[num_conditions, dim]`` translation per condition. - rng: initial PRNG key. - """ + Args: + num_conditions: number of distinct conditions. + batch_size: total batch size (divided equally among conditions). + dim: data dimensionality. + offsets: ``[num_conditions, dim]`` translation per condition. + rng: initial PRNG key. + """ num_conditions: int batch_size: int @@ -229,21 +229,21 @@ def create_conditional_gaussian_mixture_samplers( ) -> Tuple[ConditionalDataset, ConditionalDataset, int, int, int]: """Create conditional Gaussian samplers for testing. - Each condition defines a different translation of the source distribution. - - Args: - num_conditions: number of distinct conditions. - dim: data dimensionality. 
- train_batch_size: training batch size (should be divisible by - ``num_conditions``). - valid_batch_size: validation batch size. - rng: initial PRNG key. - - Returns: - ``(train_dataset, valid_dataset, dim_data, num_conditions, - max_measure_size)`` where ``max_measure_size = - batch_size // num_conditions``. - """ + Each condition defines a different translation of the source distribution. + + Args: + num_conditions: number of distinct conditions. + dim: data dimensionality. + train_batch_size: training batch size (should be divisible by + ``num_conditions``). + valid_batch_size: validation batch size. + rng: initial PRNG key. + + Returns: + ``(train_dataset, valid_dataset, dim_data, num_conditions, + max_measure_size)`` where ``max_measure_size = + batch_size // num_conditions``. + """ rng = utils.default_prng_key(rng) rng1, rng2, rng_off = jax.random.split(rng, 3) diff --git a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py index d47528a96..9ff80db5d 100644 --- a/src/ott/neural/methods/conditional_monge_gap.py +++ b/src/ott/neural/methods/conditional_monge_gap.py @@ -64,48 +64,48 @@ def cmonge_gap_from_samples( ) -> Union[float, Tuple[float, jnp.ndarray]]: r"""Conditional Monge gap from samples using the segment interface. - Computes the average Monge gap across conditions: - - .. math:: - - \frac{1}{K} \sum_{k=1}^{K} \left[ - \frac{1}{n_k} \sum_{i:\, c_i = k} c(x_i, y_i) - - W_{c, \varepsilon}\!\bigl(\hat{\rho}_{n_k}^{(k)},\, - \hat{\nu}_{n_k}^{(k)}\bigr) \right] - - where :math:`W_{c, \varepsilon}` is the - :term:`entropy-regularized optimal transport` cost. - - This implementation uses :func:`~ott.geometry.segment._segment_interface` - to pad and ``vmap`` across conditions, making it fully JIT-compatible. - - Args: - source: samples from first measure, array of shape ``[n, d]``. - target: samples from second measure, array of shape ``[n, d]``. - Assumed paired with ``source``, i.e. ``target[i] = T(source[i])``. - condition: integer array of shape ``[n]`` indicating the condition - for each source-target pair. Values in ``range(num_segments)``. - cost_fn: a cost function between two points in dimension :math:`d`. - If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. - epsilon: regularization parameter. See - :class:`~ott.geometry.pointcloud.PointCloud`. - relative_epsilon: when set, ``epsilon`` refers to a fraction of the - :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`. - scale_cost: option to rescale the cost matrix. Implemented scalings - are ``'median'``, ``'mean'`` and ``'max_cost'``. Alternatively, a - float factor can be given to rescale the cost such that - ``cost_matrix /= scale_cost``. - return_output: if :obj:`True`, also return per-condition Monge gaps. - num_segments: number of distinct conditions. Required for JIT. - max_measure_size: maximum number of points in any single condition - (used for padding). Required for JIT. - kwargs: keyword arguments for the - :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. - - Returns: - The average Monge gap across conditions and, when ``return_output`` - is :obj:`True`, a ``[num_segments]`` array of per-condition gaps. - """ + Computes the average Monge gap across conditions: + + .. 
math:: + + \frac{1}{K} \sum_{k=1}^{K} \left[ + \frac{1}{n_k} \sum_{i:\, c_i = k} c(x_i, y_i) - + W_{c, \varepsilon}\!\bigl(\hat{\rho}_{n_k}^{(k)},\, + \hat{\nu}_{n_k}^{(k)}\bigr) \right] + + where :math:`W_{c, \varepsilon}` is the + :term:`entropy-regularized optimal transport` cost. + + This implementation uses :func:`~ott.geometry.segment._segment_interface` + to pad and ``vmap`` across conditions, making it fully JIT-compatible. + + Args: + source: samples from first measure, array of shape ``[n, d]``. + target: samples from second measure, array of shape ``[n, d]``. + Assumed paired with ``source``, i.e. ``target[i] = T(source[i])``. + condition: integer array of shape ``[n]`` indicating the condition + for each source-target pair. Values in ``range(num_segments)``. + cost_fn: a cost function between two points in dimension :math:`d`. + If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. + epsilon: regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud`. + relative_epsilon: when set, ``epsilon`` refers to a fraction of the + :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`. + scale_cost: option to rescale the cost matrix. Implemented scalings + are ``'median'``, ``'mean'`` and ``'max_cost'``. Alternatively, a + float factor can be given to rescale the cost such that + ``cost_matrix /= scale_cost``. + return_output: if :obj:`True`, also return per-condition Monge gaps. + num_segments: number of distinct conditions. Required for JIT. + max_measure_size: maximum number of points in any single condition + (used for padding). Required for JIT. + kwargs: keyword arguments for the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. + + Returns: + The average Monge gap across conditions and, when ``return_output`` + is :obj:`True`, a ``[num_segments]`` array of per-condition gaps. + """ cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn dim = source.shape[1] padding_vector = cost_fn._padder(dim=dim) @@ -180,39 +180,39 @@ def eval_fn( class ConditionalMongeGapEstimator: r"""Conditional map estimator between probability measures. - Estimates a condition-dependent map :math:`T(\cdot, c)` by minimizing: + Estimates a condition-dependent map :math:`T(\cdot, c)` by minimizing: - .. math:: + .. math:: - \min_\theta \; \Delta\bigl(T_\theta(\cdot, c) \sharp \mu,\, \nu\bigr) - + \lambda \; R_{\text{cond}}\bigl(T_\theta(\cdot, c) \sharp \rho,\, - \rho \mid c\bigr) + \min_\theta \; \Delta\bigl(T_\theta(\cdot, c) \sharp \mu,\, \nu\bigr) + + \lambda \; R_{\text{cond}}\bigl(T_\theta(\cdot, c) \sharp \rho,\, + \rho \mid c\bigr) - where :math:`\Delta` is a fitting loss (e.g. - :func:`~ott.tools.sinkhorn_divergence.sinkdiv`), - :math:`R_{\text{cond}}` is the conditional Monge gap regularizer - :func:`cmonge_gap_from_samples`, and :math:`c` is a condition label. + where :math:`\Delta` is a fitting loss (e.g. + :func:`~ott.tools.sinkhorn_divergence.sinkdiv`), + :math:`R_{\text{cond}}` is the conditional Monge gap regularizer + :func:`cmonge_gap_from_samples`, and :math:`c` is a condition label. - This mirrors :class:`~ott.neural.methods.monge_gap.MongeGapEstimator` - but handles condition-aware maps and per-condition regularization. + This mirrors :class:`~ott.neural.methods.monge_gap.MongeGapEstimator` + but handles condition-aware maps and per-condition regularization. - Args: - dim_data: input dimensionality of the data. - model: a :class:`~ott.neural.networks.\ + Args: + dim_data: input dimensionality of the data. 
+ model: a :class:`~ott.neural.networks.\ conditional_perturbation_network.ConditionalPerturbationNetwork` or any - ``nn.Module`` whose ``__call__`` signature is ``(x, c)``. - optimizer: optimizer for the map parameters. - fitting_loss: callable ``(mapped, target) -> (loss, log)`` that - measures how well the pushforward matches the target distribution. - regularizer: callable ``(source, mapped, condition_labels) -> - (loss, log)`` that computes the conditional Monge gap or similar - per-condition regularizer. - regularizer_strength: scalar or schedule for :math:`\lambda`. - num_train_iters: number of training iterations. - logging: whether to record train/eval metrics. - valid_freq: how often to evaluate on the validation set. - rng: random seed. - """ + ``nn.Module`` whose ``__call__`` signature is ``(x, c)``. + optimizer: optimizer for the map parameters. + fitting_loss: callable ``(mapped, target) -> (loss, log)`` that + measures how well the pushforward matches the target distribution. + regularizer: callable ``(source, mapped, condition_labels) -> + (loss, log)`` that computes the conditional Monge gap or similar + per-condition regularizer. + regularizer_strength: scalar or schedule for :math:`\lambda`. + num_train_iters: number of training iterations. + logging: whether to record train/eval metrics. + valid_freq: how often to evaluate on the validation set. + rng: random seed. + """ def __init__( self, @@ -268,8 +268,8 @@ def regularizer( Optional[Any]]]: """Conditional regularizer ``(source, mapped, labels) -> (loss, log)``. - Defaults to zero if not provided. - """ + Defaults to zero if not provided. + """ if self._regularizer is not None: return self._regularizer return lambda *_, **__: (0.0, None) @@ -280,8 +280,8 @@ def fitting_loss( ) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]: """Fitting loss ``(mapped, target) -> (loss, log)``. - Defaults to zero if not provided. - """ + Defaults to zero if not provided. + """ if self._fitting_loss is not None: return self._fitting_loss return lambda *_, **__: (0.0, None) diff --git a/src/ott/neural/networks/conditional_perturbation_network.py b/src/ott/neural/networks/conditional_perturbation_network.py index 76593b35c..bafb73260 100644 --- a/src/ott/neural/networks/conditional_perturbation_network.py +++ b/src/ott/neural/networks/conditional_perturbation_network.py @@ -47,18 +47,18 @@ def __call__( ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: # noqa: D102 """Forward pass: map (x, c) -> x + residual. - Args: - x: Input data of shape ``(batch, dim_data)``. - c: Context vector of shape ``(batch, dim_cond)``. May - contain multiple modalities concatenated along the last - axis. ``context_entity_bonds`` specifies which slice - ``c[:, start:stop]`` belongs to each modality. Slices - should generally be contiguous and non-overlapping, e.g. - ``((0, 10), (10, 20))`` for two 10-dim modalities. - - Returns: - Mapped output of shape ``(batch, dim_data)``. - """ + Args: + x: Input data of shape ``(batch, dim_data)``. + c: Context vector of shape ``(batch, dim_cond)``. May + contain multiple modalities concatenated along the last + axis. ``context_entity_bonds`` specifies which slice + ``c[:, start:stop]`` belongs to each modality. Slices + should generally be contiguous and non-overlapping, e.g. + ``((0, 10), (10, 20))`` for two 10-dim modalities. + + Returns: + Mapped output of shape ``(batch, dim_data)``. 
+ """ return_batch = False if isinstance(x, dict): c = x["c"] diff --git a/tests/neural/methods/conditional_monge_gap_test.py b/tests/neural/methods/conditional_monge_gap_test.py index bf97fda20..320dc06ad 100644 --- a/tests/neural/methods/conditional_monge_gap_test.py +++ b/tests/neural/methods/conditional_monge_gap_test.py @@ -31,7 +31,7 @@ from ott.tools import sinkhorn_divergence -@pytest.mark.fast +@pytest.mark.fast() class TestConditionalMongeGap: @pytest.mark.parametrize("n_samples", [10, 30]) @@ -464,7 +464,7 @@ def test_per_condition_gaps_reflect_difficulty(self, rng: jax.Array): assert per_cond_gaps[0] < per_cond_gaps[1] < per_cond_gaps[2] -@pytest.mark.fast +@pytest.mark.fast() class TestConditionalMongeGapEstimator: def test_estimator_convergence(self):
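Usage note, outside the diffs above: a minimal sketch of calling the new cmonge_gap_from_samples under jax.jit. The toy shapes and data are illustrative; the arguments (source, target, condition, num_segments, max_measure_size, return_output) follow the docstring added in this series, and num_segments / max_measure_size are plain Python ints so the call stays JIT-compatible.

    import functools

    import jax
    import jax.numpy as jnp

    from ott.neural.methods.conditional_monge_gap import cmonge_gap_from_samples

    n, d, k = 30, 2, 3  # 30 paired samples in 2-d, spread over 3 conditions
    rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
    source = jax.random.normal(rng_x, (n, d))
    target = source + 0.1 * jax.random.normal(rng_y, (n, d))  # paired targets
    condition = jnp.arange(n) % k  # integer labels in range(k), 10 points each

    # num_segments and max_measure_size are baked in as static values via
    # functools.partial, so jit does not need to trace them.
    gap_fn = functools.partial(
        cmonge_gap_from_samples,
        num_segments=k,
        max_measure_size=n // k,
        return_output=True,
    )
    gap, per_condition = jax.jit(gap_fn)(source, target, condition)
    # gap: average Monge gap across conditions; per_condition: shape (k,)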
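A second illustrative sketch, clarifying the two condition representations used by this series: label_iter yields integer labels of shape [batch], which is what cmonge_gap_from_samples expects, while condition_iter yields one-hot context vectors of shape [batch, dim_c], which is what the network's __call__ expects. The helper name below is hypothetical.

    import jax
    import jax.numpy as jnp


    def labels_to_context(labels: jnp.ndarray, num_conditions: int) -> jnp.ndarray:
      # One-hot encode integer labels [batch] into context vectors [batch, dim_c].
      return jax.nn.one_hot(labels, num_conditions)


    labels = jnp.array([0, 2, 1, 2])        # as produced by label_iter
    context = labels_to_context(labels, 3)  # as consumed by the network
    assert context.shape == (4, 3)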
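Finally, a hedged sketch of instantiating ConditionalPerturbationNetwork with two context modalities concatenated along the last axis of c, following the context_entity_bonds semantics documented in __call__ above. The hidden sizes, and the assumption that dim_cond_map carries one embedding width per modality, are illustrative guesses rather than documented behaviour.

    from ott.neural.networks.conditional_perturbation_network import (
        ConditionalPerturbationNetwork,
    )

    model = ConditionalPerturbationNetwork(
        dim_hidden=[16, 16],
        dim_data=2,
        dim_cond=20,  # two 10-dim modalities concatenated
        dim_cond_map=(8, 8),  # assumed: one embedding width per modality
        is_potential=False,
        context_entity_bonds=((0, 10), (10, 20)),  # slice per modality
        num_contexts=2,
    )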