Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,27 @@ def to_LRCGeometry(
the point cloud does not have squared Euclidean cost.

Returns:
Returns the unmodified point cloud if :math:`n m \ge (n + m) d`, where
:math:`n, m` is the shape and :math:`d` is the dimension of the point
cloud with squared Euclidean cost.
Otherwise, returns the re-scaled low-rank geometry.
:class:`~ott.geometry.low_rank.LRCGeometry` or
:class:`~ott.geometry.pointcloud.PointCloud`:
The re-scaled low-rank geometry if:

1. :math:`n m > (n + m) d`, where
:math:`n, m` is the shape and :math:`d` is the
dimension of the point cloud with squared
Euclidean cost; or
2. ``rank=0`` is passed as keyword argument.

Otherwise, returns the unmodified point cloud.
"""
force = (kwargs.get("rank") == 0)
if self.is_squared_euclidean:
if self._check_LRC_dim:
if force or self._check_LRC_dim:
return self._sqeucl_to_lr(scale)
# we don't update the `scale_factor` because in GW, the linear cost
# is first materialized and then scaled by `fused_penalty` afterwards
return self
if self.is_neg_dotp:
if self._check_LRC_dim:
if force or self._check_LRC_dim:
return self._dotp_to_lr(scale)
return self
return super().to_LRCGeometry(scale=scale, **kwargs)
Expand Down