diff --git a/src/ott/datasets.py b/src/ott/datasets.py index 206123b84..e19ac7182 100644 --- a/src/ott/datasets.py +++ b/src/ott/datasets.py @@ -18,7 +18,13 @@ import jax.numpy as jnp import numpy as np -__all__ = ["create_gaussian_mixture_samplers", "Dataset", "GaussianMixture"] +__all__ = [ + "create_gaussian_mixture_samplers", + "create_conditional_gaussian_mixture_samplers", + "ConditionalDataset", + "Dataset", + "GaussianMixture", +] from ott import utils @@ -32,8 +38,26 @@ class Dataset(NamedTuple): source_iter: loader for the source measure target_iter: loader for the target measure """ + + source_iter: Iterator[jnp.ndarray] + target_iter: Iterator[jnp.ndarray] + + +class ConditionalDataset(NamedTuple): + r"""Samplers from conditional source and target measures. + + Args: + source_iter: loader for the source measure, ``[batch, d]`` + target_iter: loader for the target measure, ``[batch, d]`` + condition_iter: loader for condition vectors, + ``[batch, dim_c]`` + label_iter: loader for integer condition labels, ``[batch]`` + """ + source_iter: Iterator[jnp.ndarray] target_iter: Iterator[jnp.ndarray] + condition_iter: Iterator[jnp.ndarray] + label_iter: Iterator[jnp.ndarray] @dataclasses.dataclass @@ -55,6 +79,7 @@ class GaussianMixture: scale: scale of the Gaussian means std: the standard deviation of the individual Gaussian samples """ + name: Name_t batch_size: int rng: jax.Array @@ -132,7 +157,7 @@ def create_gaussian_mixture_samplers( ), target_iter=iter( GaussianMixture(name_target, batch_size=train_batch_size, rng=rng2) - ) + ), ) valid_dataset = Dataset( source_iter=iter( @@ -140,7 +165,130 @@ def create_gaussian_mixture_samplers( ), target_iter=iter( GaussianMixture(name_target, batch_size=valid_batch_size, rng=rng4) - ) + ), ) dim_data = 2 return train_dataset, valid_dataset, dim_data + + +@dataclasses.dataclass +class ConditionalGaussianMixture: + """Conditional Gaussian sampler for testing. + + For each condition *k*, draws source ~ N(0, I) and + target ~ source + offset_k. + Condition vectors are one-hot encoded labels. + + Args: + num_conditions: number of distinct conditions. + batch_size: total batch size (divided equally among conditions). + dim: data dimensionality. + offsets: ``[num_conditions, dim]`` translation per condition. + rng: initial PRNG key. + """ + + num_conditions: int + batch_size: int + dim: int + offsets: jnp.ndarray + rng: jax.Array + + def __iter__(self) -> Iterator[Tuple[jnp.ndarray, ...]]: + return self._generate() + + def _generate(self) -> Iterator[Tuple[jnp.ndarray, ...]]: + rng = self.rng + per_cond = self.batch_size // self.num_conditions + while True: + rng, rng_s = jax.random.split(rng) + sources, targets, conds, labels = [], [], [], [] + for k in range(self.num_conditions): + rng_s, rng_k = jax.random.split(rng_s) + s = jax.random.normal(rng_k, (per_cond, self.dim)) + t = s + self.offsets[k] + c = jnp.zeros((per_cond, self.num_conditions)).at[:, k].set(1.0) + lab = jnp.full((per_cond,), k, dtype=jnp.int32) + sources.append(s) + targets.append(t) + conds.append(c) + labels.append(lab) + yield ( + jnp.concatenate(sources), + jnp.concatenate(targets), + jnp.concatenate(conds), + jnp.concatenate(labels), + ) + + +def create_conditional_gaussian_mixture_samplers( + num_conditions: int = 3, + dim: int = 2, + train_batch_size: int = 90, + valid_batch_size: int = 90, + rng: Optional[jax.Array] = None, +) -> Tuple[ConditionalDataset, ConditionalDataset, int, int, int]: + """Create conditional Gaussian samplers for testing. 
+ + Each condition defines a different translation of the source distribution. + + Args: + num_conditions: number of distinct conditions. + dim: data dimensionality. + train_batch_size: training batch size (should be divisible by + ``num_conditions``). + valid_batch_size: validation batch size. + rng: initial PRNG key. + + Returns: + ``(train_dataset, valid_dataset, dim_data, num_conditions, + max_measure_size)`` where ``max_measure_size = + batch_size // num_conditions``. + """ + rng = utils.default_prng_key(rng) + rng1, rng2, rng_off = jax.random.split(rng, 3) + + # Each condition has a different offset (translation) + offsets = jax.random.normal(rng_off, (num_conditions, dim)) * 3.0 + + def _make_dataset( + bs: int, + key: jax.Array, + ) -> ConditionalDataset: + sampler = ConditionalGaussianMixture( + num_conditions=num_conditions, + batch_size=bs, + dim=dim, + offsets=offsets, + rng=key, + ) + gen = iter(sampler) + # Cache the current batch so all 4 iterators stay synchronized. + cache = {} + + def _next_batch(): + if "batch" not in cache: + cache["batch"] = next(gen) + return cache + + def _iter(idx: int) -> Iterator[jnp.ndarray]: + while True: + c = _next_batch() + val = c["batch"][idx] + # Mark consumed; when all 4 are done, clear cache. + c.setdefault("consumed", set()) + c["consumed"].add(idx) + if len(c["consumed"]) == 4: + cache.clear() + yield val + + return ConditionalDataset( + source_iter=_iter(0), + target_iter=_iter(1), + condition_iter=_iter(2), + label_iter=_iter(3), + ) + + train_ds = _make_dataset(train_batch_size, rng1) + valid_ds = _make_dataset(valid_batch_size, rng2) + max_measure_size = train_batch_size // num_conditions + return train_ds, valid_ds, dim, num_conditions, max_measure_size diff --git a/src/ott/neural/methods/__init__.py b/src/ott/neural/methods/__init__.py index ea3e51b31..983c3e2ac 100644 --- a/src/ott/neural/methods/__init__.py +++ b/src/ott/neural/methods/__init__.py @@ -11,4 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import expectile_neural_dual, flow_matching, monge_gap, neuraldual +from . import ( + conditional_monge_gap, + expectile_neural_dual, + flow_matching, + monge_gap, + neuraldual, +) diff --git a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py new file mode 100644 index 000000000..9ff80db5d --- /dev/null +++ b/src/ott/neural/methods/conditional_monge_gap.py @@ -0,0 +1,436 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import collections +import functools +import logging +from typing import ( + Any, + Callable, + Dict, + Iterator, + Literal, + Optional, + Sequence, + Tuple, + Union, +) + +import jax +import jax.numpy as jnp + +import optax +from flax.core import frozen_dict +from flax.training import train_state + +from ott import utils +from ott.geometry import costs, pointcloud, segment +from ott.neural.networks.conditional_perturbation_network import ( + ConditionalPerturbationNetwork, +) +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn + +logger = logging.getLogger(__name__) + +__all__ = [ + "cmonge_gap_from_samples", + "ConditionalMongeGapEstimator", +] + + +def cmonge_gap_from_samples( + source: jnp.ndarray, + target: jnp.ndarray, + condition: jnp.ndarray, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[float] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, + scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0, + return_output: bool = False, + num_segments: Optional[int] = None, + max_measure_size: Optional[int] = None, + **kwargs: Any, +) -> Union[float, Tuple[float, jnp.ndarray]]: + r"""Conditional Monge gap from samples using the segment interface. + + Computes the average Monge gap across conditions: + + .. math:: + + \frac{1}{K} \sum_{k=1}^{K} \left[ + \frac{1}{n_k} \sum_{i:\, c_i = k} c(x_i, y_i) - + W_{c, \varepsilon}\!\bigl(\hat{\rho}_{n_k}^{(k)},\, + \hat{\nu}_{n_k}^{(k)}\bigr) \right] + + where :math:`W_{c, \varepsilon}` is the + :term:`entropy-regularized optimal transport` cost. + + This implementation uses :func:`~ott.geometry.segment._segment_interface` + to pad and ``vmap`` across conditions, making it fully JIT-compatible. + + Args: + source: samples from first measure, array of shape ``[n, d]``. + target: samples from second measure, array of shape ``[n, d]``. + Assumed paired with ``source``, i.e. ``target[i] = T(source[i])``. + condition: integer array of shape ``[n]`` indicating the condition + for each source-target pair. Values in ``range(num_segments)``. + cost_fn: a cost function between two points in dimension :math:`d`. + If :obj:`None`, :class:`~ott.geometry.costs.SqEuclidean` is used. + epsilon: regularization parameter. See + :class:`~ott.geometry.pointcloud.PointCloud`. + relative_epsilon: when set, ``epsilon`` refers to a fraction of the + :attr:`~ott.geometry.pointcloud.PointCloud.mean_cost_matrix`. + scale_cost: option to rescale the cost matrix. Implemented scalings + are ``'median'``, ``'mean'`` and ``'max_cost'``. Alternatively, a + float factor can be given to rescale the cost such that + ``cost_matrix /= scale_cost``. + return_output: if :obj:`True`, also return per-condition Monge gaps. + num_segments: number of distinct conditions. Required for JIT. + max_measure_size: maximum number of points in any single condition + (used for padding). Required for JIT. + kwargs: keyword arguments for the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. + + Returns: + The average Monge gap across conditions and, when ``return_output`` + is :obj:`True`, a ``[num_segments]`` array of per-condition gaps. + """ + cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn + dim = source.shape[1] + padding_vector = cost_fn._padder(dim=dim) + + # Warn if any condition is heavily padded (>10x below max_measure_size), + # which can cause numerical differences vs non-padded Sinkhorn solves. + # Skipped silently under JIT where condition values are traced. 
+ if max_measure_size is not None and num_segments is not None: + try: + counts = jnp.bincount(condition, length=num_segments) + min_count = int(jnp.min(counts)) + if min_count > 0 and max_measure_size // min_count >= 10: + logger.warning( + "Condition with %d samples will be padded to %d " + "(%.0fx). Per-condition Monge gap values may differ " + "from non-padded monge_gap_from_samples calls.", + min_count, + max_measure_size, + max_measure_size / min_count, + ) + except jax.errors.ConcretizationTypeError: + pass + + # NOTE: Eval function takes some logic from: + # ott.neural.methods.monge_gap.monge_gap_from_samples` + # as well as `ott.geometry.segment.py` + def eval_fn( + padded_x: jnp.ndarray, + padded_y: jnp.ndarray, + padded_weight_x: jnp.ndarray, + padded_weight_y: jnp.ndarray, + ) -> jnp.ndarray: + """Monge gap for a single (padded) condition segment.""" + # Displacement cost: weighted mean of pairwise costs c(x_i, T(x_i)). + # Padded entries have weight 0, so they do not contribute. + pairwise_costs = jax.vmap(cost_fn)(padded_x, padded_y) + displacement_cost = jnp.sum(pairwise_costs * padded_weight_x) + + # Entropy-regularized OT cost W_{c,ε}. + geom = pointcloud.PointCloud( + padded_x, + padded_y, + cost_fn=cost_fn, + epsilon=epsilon, + relative_epsilon=relative_epsilon, + scale_cost=scale_cost, + ) + prob = linear_problem.LinearProblem( + geom, a=padded_weight_x, b=padded_weight_y + ) + solver = sinkhorn.Sinkhorn(**kwargs) + out = solver(prob) + + return displacement_cost - out.ent_reg_cost + + per_condition_gaps = segment._segment_interface( + x=source, + y=target, + eval_fn=eval_fn, + num_segments=num_segments, + max_measure_size=max_measure_size, + segment_ids_x=condition, + segment_ids_y=condition, + indices_are_sorted=False, + padding_vector=padding_vector, + ) + + avg_gap = jnp.mean(per_condition_gaps) + return (avg_gap, per_condition_gaps) if return_output else avg_gap + + +class ConditionalMongeGapEstimator: + r"""Conditional map estimator between probability measures. + + Estimates a condition-dependent map :math:`T(\cdot, c)` by minimizing: + + .. math:: + + \min_\theta \; \Delta\bigl(T_\theta(\cdot, c) \sharp \mu,\, \nu\bigr) + + \lambda \; R_{\text{cond}}\bigl(T_\theta(\cdot, c) \sharp \rho,\, + \rho \mid c\bigr) + + where :math:`\Delta` is a fitting loss (e.g. + :func:`~ott.tools.sinkhorn_divergence.sinkdiv`), + :math:`R_{\text{cond}}` is the conditional Monge gap regularizer + :func:`cmonge_gap_from_samples`, and :math:`c` is a condition label. + + This mirrors :class:`~ott.neural.methods.monge_gap.MongeGapEstimator` + but handles condition-aware maps and per-condition regularization. + + Args: + dim_data: input dimensionality of the data. + model: a :class:`~ott.neural.networks.\ +conditional_perturbation_network.ConditionalPerturbationNetwork` or any + ``nn.Module`` whose ``__call__`` signature is ``(x, c)``. + optimizer: optimizer for the map parameters. + fitting_loss: callable ``(mapped, target) -> (loss, log)`` that + measures how well the pushforward matches the target distribution. + regularizer: callable ``(source, mapped, condition_labels) -> + (loss, log)`` that computes the conditional Monge gap or similar + per-condition regularizer. + regularizer_strength: scalar or schedule for :math:`\lambda`. + num_train_iters: number of training iterations. + logging: whether to record train/eval metrics. + valid_freq: how often to evaluate on the validation set. + rng: random seed. 
+ """ + + def __init__( + self, + dim_data: int, + model: ConditionalPerturbationNetwork, + optimizer: Optional[optax.OptState] = None, + fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray], + Tuple[float, Optional[Any]]]] = None, + regularizer: Optional[Callable[ + [jnp.ndarray, jnp.ndarray, jnp.ndarray], + Tuple[float, Optional[Any]], + ]] = None, + regularizer_strength: Union[float, Sequence[float]] = 1.0, + num_train_iters: int = 10_000, + logging: bool = False, + valid_freq: int = 500, + rng: Optional[jax.Array] = None, + ): + self._fitting_loss = fitting_loss + self._regularizer = regularizer + self.regularizer_strength = jnp.repeat( + jnp.atleast_2d(regularizer_strength), + num_train_iters, + total_repeat_length=num_train_iters, + axis=0, + ).ravel() + self.num_train_iters = num_train_iters + self.logging = logging + self.valid_freq = valid_freq + self.rng = utils.default_prng_key(rng) + + if optimizer is None: + optimizer = optax.adam(learning_rate=0.001) + + self.setup(dim_data, model, optimizer) + + def setup( + self, + dim_data: int, + neural_net: ConditionalPerturbationNetwork, + optimizer: optax.OptState, + ): + """Set up all components required to train the network.""" + self.state_neural_net = neural_net.create_train_state( + self.rng, optimizer, dim_data + ) + self.step_fn = self._get_step_fn() + + @property + def regularizer( + self, + ) -> Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[float, + Optional[Any]]]: + """Conditional regularizer ``(source, mapped, labels) -> (loss, log)``. + + Defaults to zero if not provided. + """ + if self._regularizer is not None: + return self._regularizer + return lambda *_, **__: (0.0, None) + + @property + def fitting_loss( + self, + ) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[float, Optional[Any]]]: + """Fitting loss ``(mapped, target) -> (loss, log)``. + + Defaults to zero if not provided. 
+ """ + if self._fitting_loss is not None: + return self._fitting_loss + return lambda *_, **__: (0.0, None) + + @staticmethod + def _generate_batch( + loader_source: Iterator[jnp.ndarray], + loader_target: Iterator[jnp.ndarray], + loader_condition: Iterator[jnp.ndarray], + loader_label: Iterator[jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Generate a batch of samples from all four iterators.""" + return { + "source": next(loader_source), + "target": next(loader_target), + "condition": next(loader_condition), + "condition_labels": next(loader_label), + } + + def train_map_estimator( + self, + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + trainloader_condition: Iterator[jnp.ndarray], + trainloader_label: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], + validloader_condition: Iterator[jnp.ndarray], + validloader_label: Iterator[jnp.ndarray], + ) -> Tuple[train_state.TrainState, Dict[str, Any]]: + """Training loop.""" + logs = collections.defaultdict(lambda: collections.defaultdict(list)) + + try: + from tqdm import trange + + tbar = trange(self.num_train_iters, leave=True) + except ImportError: + tbar = range(self.num_train_iters) + + for step in tbar: + is_logging_step = self.logging and ((step % self.valid_freq == 0) or + (step == self.num_train_iters - 1)) + train_batch = self._generate_batch( + trainloader_source, + trainloader_target, + trainloader_condition, + trainloader_label, + ) + valid_batch = ( + None if not is_logging_step else self._generate_batch( + validloader_source, + validloader_target, + validloader_condition, + validloader_label, + ) + ) + self.state_neural_net, current_logs = self.step_fn( + self.state_neural_net, + train_batch, + valid_batch, + is_logging_step, + step, + ) + + if is_logging_step: + for log_key in current_logs: + for metric_key in current_logs[log_key]: + logs[log_key][metric_key].append(current_logs[log_key][metric_key]) + if not isinstance(tbar, range): + reg_msg = ( + "NA" if current_logs["eval"]["regularizer"] == 0.0 else + f"{current_logs['eval']['regularizer']:.4f}" + ) + postfix_str = ( + f"fitting_loss:" + f" {current_logs['eval']['fitting_loss']:.4f}, " + f"regularizer: {reg_msg} ," + f"total: {current_logs['eval']['total_loss']:.4f}" + ) + tbar.set_postfix_str(postfix_str) + + return self.state_neural_net, logs + + def _get_step_fn(self) -> Callable: + """Create a one-step training and evaluation function.""" + + def loss_fn( + params: frozen_dict.FrozenDict, + apply_fn: Callable, + batch: Dict[str, jnp.ndarray], + step: int, + ) -> Tuple[float, Dict[str, float]]: + """Loss function with conditional map and regularizer.""" + # Apply the conditional map: T(source, condition) + mapped_samples = apply_fn({"params": params}, batch["source"], + batch["condition"]) + + # Fitting loss: Δ(T(x,c), y) + val_fitting_loss, log_fitting_loss = self.fitting_loss( + mapped_samples, batch["target"] + ) + + # Conditional regularizer: R(x, T(x,c), labels) + val_regularizer, log_regularizer = self.regularizer( + batch["source"], mapped_samples, batch["condition_labels"] + ) + + val_tot_loss = ( + val_fitting_loss + self.regularizer_strength[step] * val_regularizer + ) + + loss_logs = { + "total_loss": val_tot_loss, + "fitting_loss": val_fitting_loss, + "regularizer": val_regularizer, + "log_regularizer": log_regularizer, + "log_fitting": log_fitting_loss, + } + + return val_tot_loss, loss_logs + + @functools.partial(jax.jit, static_argnums=3) + def 
step_fn( + state_neural_net: train_state.TrainState, + train_batch: Dict[str, jnp.ndarray], + valid_batch: Optional[Dict[str, jnp.ndarray]] = None, + is_logging_step: bool = False, + step: int = 0, + ) -> Tuple[train_state.TrainState, Dict[str, float]]: + """One step function.""" + grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True) + (_, current_train_logs), grads = grad_fn( + state_neural_net.params, + state_neural_net.apply_fn, + train_batch, + step, + ) + + current_logs = {"train": current_train_logs, "eval": {}} + if is_logging_step: + _, current_eval_logs = loss_fn( + params=state_neural_net.params, + apply_fn=state_neural_net.apply_fn, + batch=valid_batch, + step=step, + ) + current_logs["eval"] = current_eval_logs + + return state_neural_net.apply_gradients(grads=grads), current_logs + + return step_fn diff --git a/src/ott/neural/networks/__init__.py b/src/ott/neural/networks/__init__.py index b35fcb61a..568a02e37 100644 --- a/src/ott/neural/networks/__init__.py +++ b/src/ott/neural/networks/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import icnn, layers, potentials +from . import conditional_perturbation_network, icnn, layers, potentials -__all__ = ["icnn", "layers", "potentials"] +__all__ = ["conditional_perturbation_network", "icnn", "layers", "potentials"] diff --git a/src/ott/neural/networks/conditional_perturbation_network.py b/src/ott/neural/networks/conditional_perturbation_network.py new file mode 100644 index 000000000..bafb73260 --- /dev/null +++ b/src/ott/neural/networks/conditional_perturbation_network.py @@ -0,0 +1,142 @@ +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Sequence, + Tuple, + Union, +) + +import jax.numpy as jnp + +import flax.linen as nn +import optax + +from ott.neural.networks.potentials import BasePotential, PotentialTrainState + + +class ConditionalPerturbationNetwork(BasePotential): + """Condition-aware perturbation network for OT maps.""" + + dim_hidden: Sequence[int] = None + dim_data: int = None + dim_cond: int = None # Full dimension of all context variables concatenated + # Same length as context_entity_bonds if embed_cond_equal is False + # (if True, first item is size of deep set layer, rest is ignored) + dim_cond_map: Iterable[int] = (50,) + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu + is_potential: bool = False + layer_norm: bool = False + embed_cond_equal: bool = ( + False # Whether all context variables should be treated as set or not + ) + context_entity_bonds: Iterable[Tuple[int, int]] = ( + (0, 10), + (10, 20), + ) # (start, stop) slicing bounds per context modality in c; + # should be contiguous, non-overlapping by default. + num_contexts: int = 2 + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + c: Optional[jnp.ndarray] = None + ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: # noqa: D102 + """Forward pass: map (x, c) -> x + residual. + + Args: + x: Input data of shape ``(batch, dim_data)``. + c: Context vector of shape ``(batch, dim_cond)``. May + contain multiple modalities concatenated along the last + axis. ``context_entity_bonds`` specifies which slice + ``c[:, start:stop]`` belongs to each modality. Slices + should generally be contiguous and non-overlapping, e.g. + ``((0, 10), (10, 20))`` for two 10-dim modalities. + + Returns: + Mapped output of shape ``(batch, dim_data)``. 
+ """ + return_batch = False + if isinstance(x, dict): + c = x["c"] + x = x["X"] + return_batch = True + + n_input = x.shape[-1] + + # Chunk the inputs + contexts = [ + c[:, e[0]:e[1]] + for i, e in enumerate(self.context_entity_bonds) + if i < self.num_contexts + ] + + if not self.embed_cond_equal: + # Each context is processed by a different layer, + # good for combining modalities + assert len(self.context_entity_bonds) == len(self.dim_cond_map), ( + "Length of context entity bonds and context map sizes have to " + f"match: {self.context_entity_bonds} != {self.dim_cond_map}" + ) + + layers = [ + nn.Dense(self.dim_cond_map[i], use_bias=True) + for i in range(len(contexts)) + ] + embeddings = [ + self.act_fn(layers[i](context)) for i, context in enumerate(contexts) + ] + cond_embedding = jnp.concatenate(embeddings, axis=1) + else: + # We can use any number of contexts from the same modality, + # via a permutation-invariant deep set layer. + sizes = [c.shape[-1] for c in contexts] + if not len(set(sizes)) == 1: + raise ValueError( + "For embedding a set, all contexts need same length ," + f"not {sizes}" + ) + layer = nn.Dense(self.dim_cond_map[0], use_bias=True) + embeddings = [self.act_fn(layer(context)) for context in contexts] + # Average along stacked dimension + # (alternatives like summing are possible) + cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0) + + z = jnp.concatenate((x, cond_embedding), axis=1) + if self.layer_norm: + n = nn.LayerNorm() + z = n(z) + + for n_hidden in self.dim_hidden: + wx = nn.Dense(n_hidden, use_bias=True) + z = self.act_fn(wx(z)) + wx = nn.Dense(n_input, use_bias=True) + + y = x + wx(z) + + if return_batch: + return {"X": y, "c": c} + return y + + def create_train_state( + self, + rng: jnp.ndarray, + optimizer: optax.OptState, + dim_data: int, + **kwargs: Any, + ) -> PotentialTrainState: + """Create initial `TrainState`.""" + c = jnp.ones((1, self.dim_cond)) # (n_batch, embed_dim) + x = jnp.ones((1, dim_data)) # (n_batch, data_dim) + params = self.init(rng, x=x, c=c)["params"] + return PotentialTrainState.create( + apply_fn=self.apply, + params=params, + tx=optimizer, + potential_value_fn=self.potential_value_fn, + potential_gradient_fn=self.potential_gradient_fn, + **kwargs, + ) diff --git a/tests/neural/methods/conditional_monge_gap_test.py b/tests/neural/methods/conditional_monge_gap_test.py new file mode 100644 index 000000000..320dc06ad --- /dev/null +++ b/tests/neural/methods/conditional_monge_gap_test.py @@ -0,0 +1,596 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np + +from ott import datasets +from ott.geometry import costs, regularizers +from ott.neural.methods import conditional_monge_gap +from ott.neural.methods.monge_gap import monge_gap_from_samples +from ott.neural.networks import potentials +from ott.neural.networks.conditional_perturbation_network import ( + ConditionalPerturbationNetwork, +) +from ott.tools import sinkhorn_divergence + + +@pytest.mark.fast() +class TestConditionalMongeGap: + + @pytest.mark.parametrize("n_samples", [10, 30]) + @pytest.mark.parametrize("n_features", [4, 10]) + @pytest.mark.parametrize("num_conditions", [2, 3]) + def test_non_negativity( + self, + rng: jax.Array, + n_samples: int, + n_features: int, + num_conditions: int, + ): + rng1, rng2 = jax.random.split(rng) + per_cond = n_samples // num_conditions + n = per_cond * num_conditions + + source = jax.random.normal(rng1, (n, n_features)) + target = source + 0.5 * jax.random.normal(rng2, (n, n_features)) + condition = jnp.repeat(jnp.arange(num_conditions), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=num_conditions, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(gap >= 0, True) + + def test_jit_consistency(self, rng: jax.Array): + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + eager_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + jit_gap = jax.jit( + lambda s, t, c: conditional_monge_gap.cmonge_gap_from_samples( + s, + t, + c, + num_segments=k, + max_measure_size=per_cond, + ) + )(source, target, condition) + + np.testing.assert_allclose(eager_gap, jit_gap, rtol=1e-3) + + def test_matches_loop_baseline(self, rng: jax.Array): + """Segment-based result matches manual per-condition loop.""" + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + new_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + + # Manual loop (the old approach) + manual_gaps = [] + for c in range(k): + mask = condition == c + gap = monge_gap_from_samples(source[mask], target[mask]) + manual_gaps.append(float(gap)) + manual_avg = sum(manual_gaps) / len(manual_gaps) + + np.testing.assert_allclose(float(new_gap), manual_avg, atol=1e-5) + + def test_identity_smaller_than_random(self, rng: jax.Array): + """Identity map should have smaller Monge gap than a random map.""" + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + identity_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + source, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + random_target = jax.random.normal(rng2, (n, d)) * 3.0 + random_gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + random_target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + assert identity_gap < random_gap + + @pytest.mark.parametrize( + "cost_fn", + [ + costs.SqEuclidean(), 
+ costs.PNormP(p=1), + ], + ids=["sqeucl", "pnorm-1"], + ) + def test_different_costs(self, rng: jax.Array, cost_fn: costs.CostFn): + n, d, k = 30, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + jax.random.normal(rng2, (n, d)) * 0.5 + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + cost_fn=cost_fn, + num_segments=k, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(jnp.isfinite(gap), True) + np.testing.assert_array_equal(gap >= 0, True) + + def test_return_output_shape(self, rng: jax.Array): + n, d, k = 60, 4, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = source + 0.1 * jax.random.normal(rng2, (n, d)) + condition = jnp.repeat(jnp.arange(k), per_cond) + + result = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + return_output=True, + ) + assert isinstance(result, tuple) + avg_gap, per_cond_gaps = result + assert per_cond_gaps.shape == (k,) + np.testing.assert_allclose( + float(avg_gap), + float(jnp.mean(per_cond_gaps)), + rtol=1e-5, + ) + + @pytest.mark.parametrize("n_samples", [10, 30]) + @pytest.mark.parametrize("n_features", [4, 10]) + def test_non_negativity_neural_map( + self, + rng: jax.Array, + n_samples: int, + n_features: int, + ): + """Non-negativity with a learned nonlinear map.""" + k = 2 + per_cond = n_samples // k + n = per_cond * k + rng1, rng2 = jax.random.split(rng) + + source = jax.random.normal(rng1, (n, n_features)) + model = potentials.PotentialMLP(dim_hidden=[8, 8], is_potential=False) + params = model.init(rng2, x=source[0]) + target = model.apply(params, source) + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + ) + np.testing.assert_array_equal(jnp.isfinite(gap), True) + np.testing.assert_array_equal(gap >= 0, True) + + @pytest.mark.parametrize( + "cost_fn", + [ + costs.PNormP(p=1), + costs.RegTICost(regularizers.L1(), lam=2.0), + costs.RegTICost(regularizers.STVS(gamma=3.0), lam=1.0), + ], + ids=["pnorm-1", "l1-lam2", "stvs-lam1"], + ) + def test_different_costs_give_different_values( + self, + rng: jax.Array, + cost_fn: costs.CostFn, + ): + """Non-Euclidean costs produce different cmonge_gap than Euclidean.""" + n, d, k = 30, 5, 3 + per_cond = n // k + rng1, rng2 = jax.random.split(rng) + source = jax.random.normal(rng1, (n, d)) + target = jax.random.normal(rng2, (n, d)) * 0.1 + 3.0 + condition = jnp.repeat(jnp.arange(k), per_cond) + + gap_eucl = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + cost_fn=costs.Euclidean(), + num_segments=k, + max_measure_size=per_cond, + ) + gap_other = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + cost_fn=cost_fn, + num_segments=k, + max_measure_size=per_cond, + ) + + with pytest.raises(AssertionError, match=r"tolerance"): + np.testing.assert_allclose( + gap_eucl, + gap_other, + rtol=1e-1, + atol=1e-1, + ) + np.testing.assert_array_equal(jnp.isfinite(gap_eucl), True) + np.testing.assert_array_equal(jnp.isfinite(gap_other), True) + + def test_uniform_conditions_equals_averaged_monge_gap( + self, + rng: jax.Array, + ): + """cmonge_gap with equal-size conditions == mean of monge_gap calls.""" + k = 3 + 
per_cond = 20 + d = 5 + + # Different offsets per condition so gaps are distinct + offsets = jnp.array([0.1, 1.0, 3.0]) + rngs = jax.random.split(rng, 2 * k) + sources, targets = [], [] + for c in range(k): + s = jax.random.normal(rngs[2 * c], (per_cond, d)) + t = ( + s + offsets[c] + + 0.05 * jax.random.normal(rngs[2 * c + 1], (per_cond, d)) + ) + sources.append(s) + targets.append(t) + + source = jnp.concatenate(sources, axis=0) + target = jnp.concatenate(targets, axis=0) + condition = jnp.repeat(jnp.arange(k), per_cond) + + # Segmented cmonge_gap (single call, vmapped) + t0 = time.perf_counter() + avg_gap, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + return_output=True, + ) + # Force computation to complete before timing + avg_gap.block_until_ready() + t_cmonge = time.perf_counter() - t0 + + # Manual per-condition monge_gap calls (K sequential calls) + t0 = time.perf_counter() + manual_gaps = [] + for c in range(k): + gap_c = monge_gap_from_samples(sources[c], targets[c]) + manual_gaps.append(float(gap_c)) + manual_avg = sum(manual_gaps) / k + t_loop = time.perf_counter() - t0 + + # Single-condition overhead: cmonge_gap(K=1) vs monge_gap + t0 = time.perf_counter() + gap_single_cmonge = conditional_monge_gap.cmonge_gap_from_samples( + sources[0], + targets[0], + jnp.zeros(per_cond, dtype=jnp.int32), + num_segments=1, + max_measure_size=per_cond, + ) + gap_single_cmonge.block_until_ready() + t_cmonge_1 = time.perf_counter() - t0 + + t0 = time.perf_counter() + gap_single_monge = monge_gap_from_samples(sources[0], targets[0]) + float(gap_single_monge) # block + t_monge_1 = time.perf_counter() - t0 + + print( # noqa: T201 + f"\n K={k}: cmonge_gap: {t_cmonge:.3f}s | " + f"loop({k}x monge_gap): {t_loop:.3f}s | " + f"speedup: {t_loop / t_cmonge:.1f}x" + f"\n K=1: cmonge_gap: {t_cmonge_1:.3f}s | " + f"monge_gap: {t_monge_1:.3f}s | " + f"overhead: {t_cmonge_1 / t_monge_1:.1f}x" + ) + + # Average should match + np.testing.assert_allclose(float(avg_gap), manual_avg, atol=1e-5) + # Per-condition gaps should match individual calls + for c in range(k): + np.testing.assert_allclose( + float(per_cond_gaps[c]), + manual_gaps[c], + atol=1e-5, + ) + + def test_unequal_conditions_shifts_average(self, rng: jax.Array): + """With unequal n_k, per-condition gaps change and shift the average. + + The segment interface pads all conditions to max_measure_size, so + per-condition gaps with padding do NOT exactly match non-padded + monge_gap_from_samples calls (the geometry differs). We verify + structural properties instead: gaps are finite, easy < hard, + average = mean(per_cond_gaps), and the average shifts when n_k changes. 
+ """ + d = 5 + rng_easy, rng_hard, rng_noise = jax.random.split(rng, 3) + + base_easy = jax.random.normal(rng_easy, (60, d)) + base_hard = jax.random.normal(rng_hard, (60, d)) + noise = 0.01 * jax.random.normal(rng_noise, (60, d)) + + target_easy = base_easy + noise + target_hard = base_hard + 5.0 + + # (a) Equal sizes: 30/30 + n_eq = 30 + src_eq = jnp.concatenate([base_easy[:n_eq], base_hard[:n_eq]]) + tgt_eq = jnp.concatenate([target_easy[:n_eq], target_hard[:n_eq]]) + cond_eq = jnp.repeat(jnp.arange(2), n_eq) + + avg_eq, gaps_eq = conditional_monge_gap.cmonge_gap_from_samples( + src_eq, + tgt_eq, + cond_eq, + num_segments=2, + max_measure_size=n_eq, + return_output=True, + ) + + # (b) Unequal sizes: 50 easy / 10 hard + n_a, n_b = 50, 10 + src_uneq = jnp.concatenate([base_easy[:n_a], base_hard[:n_b]]) + tgt_uneq = jnp.concatenate([target_easy[:n_a], target_hard[:n_b]]) + cond_uneq = jnp.concatenate([ + jnp.zeros(n_a, dtype=jnp.int32), + jnp.ones(n_b, dtype=jnp.int32), + ]) + + avg_uneq, gaps_uneq = conditional_monge_gap.cmonge_gap_from_samples( + src_uneq, + tgt_uneq, + cond_uneq, + num_segments=2, + max_measure_size=n_a, + return_output=True, + ) + + # All gaps are finite and non-negative + for gaps in [gaps_eq, gaps_uneq]: + np.testing.assert_array_equal(jnp.all(jnp.isfinite(gaps)), True) + np.testing.assert_array_equal(jnp.all(gaps >= 0), True) + + # Easy condition has smaller gap than hard condition + assert gaps_eq[0] < gaps_eq[1] + assert gaps_uneq[0] < gaps_uneq[1] + + # Average is the mean of per-condition gaps + np.testing.assert_allclose( + float(avg_eq), + float(jnp.mean(gaps_eq)), + rtol=1e-5, + ) + np.testing.assert_allclose( + float(avg_uneq), + float(jnp.mean(gaps_uneq)), + rtol=1e-5, + ) + + # Averages differ between equal and unequal splits (n_k affects + # the padded OT cost estimation, shifting per-condition gaps) + assert float(avg_eq) != float(avg_uneq) + + def test_per_condition_gaps_reflect_difficulty(self, rng: jax.Array): + """Per-condition gaps increase with offset magnitude.""" + k = 3 + per_cond = 25 + d = 4 + offsets = jnp.array([0.0, 1.5, 5.0]) + + rngs = jax.random.split(rng, 2 * k) + sources, targets = [], [] + for c in range(k): + s = jax.random.normal(rngs[2 * c], (per_cond, d)) + t = s + offsets[c] + sources.append(s) + targets.append(t) + + source = jnp.concatenate(sources, axis=0) + target = jnp.concatenate(targets, axis=0) + condition = jnp.repeat(jnp.arange(k), per_cond) + + _, per_cond_gaps = conditional_monge_gap.cmonge_gap_from_samples( + source, + target, + condition, + num_segments=k, + max_measure_size=per_cond, + return_output=True, + ) + + assert per_cond_gaps[0] < per_cond_gaps[1] < per_cond_gaps[2] + + +@pytest.mark.fast() +class TestConditionalMongeGapEstimator: + + def test_estimator_convergence(self): + """Train a conditional map and verify loss decreases.""" + num_conditions = 3 + dim_data = 2 + dim_cond = num_conditions # one-hot + batch_size = 30 + + train_ds, valid_ds, _, n_cond, max_ms = ( + datasets.create_conditional_gaussian_mixture_samplers( + num_conditions=num_conditions, + dim=dim_data, + train_batch_size=batch_size, + valid_batch_size=batch_size, + ) + ) + + def fitting_loss(mapped, target): + div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) + return div, None + + def regularizer(source, mapped, labels): + gap, per_cond = conditional_monge_gap.cmonge_gap_from_samples( + source, + mapped, + labels, + num_segments=n_cond, + max_measure_size=max_ms, + return_output=True, + ) + return gap, None + + model = 
ConditionalPerturbationNetwork( + dim_hidden=[16, 8], + dim_data=dim_data, + dim_cond=dim_cond, + dim_cond_map=(16,), + is_potential=False, + context_entity_bonds=((0, dim_cond),), + num_contexts=1, + ) + + solver = conditional_monge_gap.ConditionalMongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + regularizer=regularizer, + model=model, + regularizer_strength=1.0, + num_train_iters=15, + logging=True, + valid_freq=5, + ) + + neural_state, logs = solver.train_map_estimator( + *train_ds, + *valid_ds, + ) + + # Loss should decrease + assert logs["train"]["total_loss"][0] > logs["train"]["total_loss"][-1] + + # Output shape should match input + source_batch = next(train_ds.source_iter) + cond_batch = next(train_ds.condition_iter) + mapped = neural_state.apply_fn( + {"params": neural_state.params}, + source_batch, + cond_batch, + ) + assert mapped.shape == source_batch.shape + np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True) + + def test_estimator_no_regularizer(self): + """Training with regularizer_strength=0 still converges.""" + num_conditions = 2 + dim_data = 2 + dim_cond = num_conditions + batch_size = 20 + + train_ds, valid_ds, _, _, _ = ( + datasets.create_conditional_gaussian_mixture_samplers( + num_conditions=num_conditions, + dim=dim_data, + train_batch_size=batch_size, + valid_batch_size=batch_size, + ) + ) + + def fitting_loss(mapped, target): + div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target) + return div, None + + model = ConditionalPerturbationNetwork( + dim_hidden=[8, 8], + dim_data=dim_data, + dim_cond=dim_cond, + dim_cond_map=(8,), + is_potential=False, + context_entity_bonds=((0, dim_cond),), + num_contexts=1, + ) + + solver = conditional_monge_gap.ConditionalMongeGapEstimator( + dim_data=dim_data, + fitting_loss=fitting_loss, + model=model, + regularizer_strength=0.0, + num_train_iters=10, + logging=True, + valid_freq=5, + ) + + neural_state, logs = solver.train_map_estimator( + *train_ds, + *valid_ds, + ) + + # Should have run without errors and logged metrics + assert len(logs["train"]["total_loss"]) > 0 + # Mapped output should be finite + source_batch = next(train_ds.source_iter) + cond_batch = next(train_ds.condition_iter) + mapped = neural_state.apply_fn( + {"params": neural_state.params}, + source_batch, + cond_batch, + ) + np.testing.assert_array_equal(jnp.all(jnp.isfinite(mapped)), True)
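
To make the intended wiring of the new pieces explicit, the following is a minimal end-to-end sketch, not part of the patch itself, that mirrors the estimator test above: it builds the synthetic conditional samplers, a ConditionalPerturbationNetwork, and trains a ConditionalMongeGapEstimator with sinkdiv as the fitting loss and cmonge_gap_from_samples as the per-condition regularizer. It assumes the modules land under the import paths used in this diff; hyperparameters are illustrative only.

from ott import datasets
from ott.neural.methods import conditional_monge_gap
from ott.neural.networks.conditional_perturbation_network import (
    ConditionalPerturbationNetwork,
)
from ott.tools import sinkhorn_divergence

num_conditions, dim_data = 3, 2

# Synthetic conditional samplers; max_ms = batch_size // num_conditions.
train_ds, valid_ds, _, n_cond, max_ms = (
    datasets.create_conditional_gaussian_mixture_samplers(
        num_conditions=num_conditions,
        dim=dim_data,
        train_batch_size=30,
        valid_batch_size=30,
    )
)


def fitting_loss(mapped, target):
  # Sinkhorn divergence between the pushforward and the target samples.
  div, _ = sinkhorn_divergence.sinkdiv(x=mapped, y=target)
  return div, None


def regularizer(source, mapped, labels):
  # Average Monge gap across conditions, computed in a single segmented call.
  gap = conditional_monge_gap.cmonge_gap_from_samples(
      source, mapped, labels, num_segments=n_cond, max_measure_size=max_ms
  )
  return gap, None


model = ConditionalPerturbationNetwork(
    dim_hidden=[16, 8],
    dim_data=dim_data,
    dim_cond=num_conditions,  # one-hot condition vectors
    dim_cond_map=(16,),
    is_potential=False,
    context_entity_bonds=((0, num_conditions),),
    num_contexts=1,
)

estimator = conditional_monge_gap.ConditionalMongeGapEstimator(
    dim_data=dim_data,
    model=model,
    fitting_loss=fitting_loss,
    regularizer=regularizer,
    regularizer_strength=1.0,
    num_train_iters=100,
    logging=True,
    valid_freq=10,
)

# ConditionalDataset unpacks as (source, target, condition, label) iterators.
state, logs = estimator.train_map_estimator(*train_ds, *valid_ds)

# Apply the learned conditional map T(x, c) to a fresh batch.
mapped = state.apply_fn(
    {"params": state.params},
    next(train_ds.source_iter),
    next(train_ds.condition_iter),
)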