diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 37785a231..588f7cbe8 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -278,19 +278,27 @@ def to_LRCGeometry( the point cloud does not have squared Euclidean cost. Returns: - Returns the unmodified point cloud if :math:`n m \ge (n + m) d`, where - :math:`n, m` is the shape and :math:`d` is the dimension of the point - cloud with squared Euclidean cost. - Otherwise, returns the re-scaled low-rank geometry. + :class:`~ott.geometry.low_rank.LRCGeometry` or + :class:`~ott.geometry.pointcloud.PointCloud`: + The re-scaled low-rank geometry if: + + 1. :math:`n m > (n + m) d`, where + :math:`n, m` is the shape and :math:`d` is the + dimension of the point cloud with squared + Euclidean cost; or + 2. ``rank=0`` is passed as keyword argument. + + Otherwise, returns the unmodified point cloud. """ + force = (kwargs.get("rank") == 0) if self.is_squared_euclidean: - if self._check_LRC_dim: + if force or self._check_LRC_dim: return self._sqeucl_to_lr(scale) # we don't update the `scale_factor` because in GW, the linear cost # is first materialized and then scaled by `fused_penalty` afterwards return self if self.is_neg_dotp: - if self._check_LRC_dim: + if force or self._check_LRC_dim: return self._dotp_to_lr(scale) return self return super().to_LRCGeometry(scale=scale, **kwargs)