From 8a5a3fffa25145fecbec39514f92c83326d90561 Mon Sep 17 00:00:00 2001
From: alexjmccleary
Date: Sat, 2 May 2026 13:32:46 -0600
Subject: [PATCH 1/3] Add support for non-lower-star filtrations

---
 pyect/__init__.py            |  1 +
 pyect/general_filtrations.py | 59 ++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+)
 create mode 100644 pyect/general_filtrations.py

diff --git a/pyect/__init__.py b/pyect/__init__.py
index fe39992..a1da5d9 100644
--- a/pyect/__init__.py
+++ b/pyect/__init__.py
@@ -3,6 +3,7 @@
 from .directions import sample_directions_2d, sample_directions_3d
 from .image_ecf import Image_ECF_2D, Image_ECF_3D
 from .differentiable_wect import DWECT
+from .general_filtrations import compute_wecfs_general
 from .preprocessing.mesh_processing import mesh_to_complex
 from .preprocessing.image_processing import (
     weighted_freudenthal,
diff --git a/pyect/general_filtrations.py b/pyect/general_filtrations.py
new file mode 100644
index 0000000..d4d2ecc
--- /dev/null
+++ b/pyect/general_filtrations.py
@@ -0,0 +1,59 @@
+"""For computing the WECF of an arbitrary (not necessarily lower-star) filtration.
+If your filtrations are lower-star (they usually are), then use wecfs.py instead.
+"""
+
+import torch
+from typing import List, Tuple
+
+def compute_wecfs_general(
+    filtration_data: List[Tuple[torch.Tensor, torch.Tensor]],
+    num_vals: int
+) -> torch.Tensor:
+    """Calculates WECFs for filtrations with values assigned to every simplex.
+
+    Args:
+        filtration_data: A weighted simplicial or cubical complex with a collection
+            of filter functions defined on each simplex, represented as a list of
+            pairs of tensors. The list index is the simplex dimension.
+
+            filtration_data[i] = (simplex_filters, simplex_weights):
+                simplex_filters (torch.Tensor): A tensor of shape (k_i, m), where
+                    k_i is the number of i-simplices and m is the number of filter
+                    functions. Each row contains the filter values of one simplex.
+ + simplex_weights (torch.Tensor): A tensor of shape (k_i). Values + are the weights of the i-simplices. + + Returns: + wecfs (torch.Tensor): A 2d tensor of shape (m, num_vals) + containing the WECFs. + """ + + if num_vals <= 0: + raise ValueError("num_vals must be positive.") + + if len(filtration_data) == 0: + raise ValueError("filtration_data must be non-empty.") + + device = filtration_data[0][0].device + m = filtration_data[0][0].size(dim=1) + eps = torch.finfo(torch.float32).eps + + max_val = torch.cat([f.reshape(-1) for f, _ in filtration_data]).max() + min_val = torch.cat([f.reshape(-1) for f, _ in filtration_data]).min() + val_range = torch.clamp(max_val - min_val, min=eps) + + diff_wecfs = torch.zeros((m, num_vals), dtype=torch.float32, device=device) + + for i, (simplex_filters, simplex_weights) in enumerate(filtration_data): + simplex_indices = torch.ceil( + (num_vals - 1) * (simplex_filters - min_val) / (val_range) + ).clamp(0, num_vals-1).long() + + expanded_simplex_weights = ( + (-1) ** i * simplex_weights.unsqueeze(0).expand(m, -1) + ) + + diff_wecfs.scatter_add_(1, simplex_indices.T, expanded_simplex_weights) + + return torch.cumsum(diff_wecfs, dim=1) \ No newline at end of file From 25f92c4e910295ccaba416959e1bbe69b3d3c8c8 Mon Sep 17 00:00:00 2001 From: alexjmccleary Date: Sat, 9 May 2026 12:00:29 -0600 Subject: [PATCH 2/3] Add Gudhi integration --- README.md | 12 ++ examples/gudhi_alpha_complex_wecf.py | 37 ++++++ pyect/integrations/__init__.py | 1 + pyect/integrations/gudhi.py | 135 +++++++++++++++++++++ pyproject.toml | 3 + tests/test_gudhi_integration.py | 172 +++++++++++++++++++++++++++ 6 files changed, 360 insertions(+) create mode 100644 examples/gudhi_alpha_complex_wecf.py create mode 100644 pyect/integrations/__init__.py create mode 100644 pyect/integrations/gudhi.py create mode 100644 tests/test_gudhi_integration.py diff --git a/README.md b/README.md index 2fad43b..7199023 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,18 @@ To 
install `pyECT`, use pip: pip install pyect ``` +Gudhi alpha-complex support is optional: + +```bash +pip install pyect[gudhi] +``` + +The Gudhi integration lives at `pyect.integrations.gudhi` so Gudhi is not +imported by the core package. In `alpha_complex_to_filtration_data`, +`point_weights` are passed to Gudhi to construct the alpha filtration. The +pyECT simplex weights are `1.0` by default; pass `simplex_weight_fn` to use a +custom weighting rule, such as the max of the simplex vertex weights. + ## Usage Here's a simple example of how to use `pyECT`: diff --git a/examples/gudhi_alpha_complex_wecf.py b/examples/gudhi_alpha_complex_wecf.py new file mode 100644 index 0000000..8bac4b4 --- /dev/null +++ b/examples/gudhi_alpha_complex_wecf.py @@ -0,0 +1,37 @@ +"""Example WECF computation from a Gudhi alpha complex.""" + +import torch + +from pyect import compute_wecfs_general +from pyect.integrations.gudhi import alpha_complex_to_filtration_data + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + points = [ + [1.0, 1.0], + [7.0, 0.0], + [4.0, 6.0], + [9.0, 6.0], + [0.0, 14.0], + [2.0, 19.0], + [9.0, 17.0], + ] + + point_weights = [0.0, 0.2, 0.1, 0.3, 0.0, 0.1, 0.2] + + filtration_data, simplex_tree = alpha_complex_to_filtration_data( + points, + point_weights=point_weights, + device=device, + ) + + wecf = compute_wecfs_general(filtration_data, num_vals=200) + + print(simplex_tree.num_simplices()) + print(wecf.shape) + + +if __name__ == "__main__": + main() diff --git a/pyect/integrations/__init__.py b/pyect/integrations/__init__.py new file mode 100644 index 0000000..688f850 --- /dev/null +++ b/pyect/integrations/__init__.py @@ -0,0 +1 @@ +"""Optional integrations for pyECT.""" diff --git a/pyect/integrations/gudhi.py b/pyect/integrations/gudhi.py new file mode 100644 index 0000000..bb0790f --- /dev/null +++ b/pyect/integrations/gudhi.py @@ -0,0 +1,135 @@ +"""Optional Gudhi integration helpers.""" + +from collections import 
defaultdict +from math import isfinite, sqrt +from typing import Callable, Optional, Sequence + +import torch + + +def gudhi_simplex_tree_to_filtration_data( + simplex_tree, + *, + max_dim: Optional[int] = None, + simplex_weight_fn: Optional[Callable[[Sequence[int], int, float], float]] = None, + use_alpha_instead_of_alpha_square: bool = False, + dtype=torch.float32, + device=None, +): + """Convert a Gudhi SimplexTree to ``compute_wecfs_general`` input data. + + Args: + simplex_tree: A Gudhi SimplexTree with filtration values assigned to + simplices. + max_dim: Optional maximum simplex dimension to include. + simplex_weight_fn: Optional function + ``simplex_weight_fn(simplex, dim, filtration_value) -> float``. + Do not include the Euler sign; ``compute_wecfs_general`` handles it. + use_alpha_instead_of_alpha_square: Gudhi alpha complexes return squared + alpha values by default. Set this to True to use alpha values. + dtype: Torch dtype for the output tensors. + device: Torch device for the output tensors. + + Returns: + A list where ``filtration_data[d] = (simplex_filters, simplex_weights)``. + ``simplex_filters`` has shape ``(num_d_simplices, 1)`` and + ``simplex_weights`` has shape ``(num_d_simplices,)``. + """ + + filters_by_dim = defaultdict(list) + weights_by_dim = defaultdict(list) + + for simplex, filt in simplex_tree.get_filtration(): + dim = len(simplex) - 1 + + if max_dim is not None and dim > max_dim: + continue + + filt = float(filt) + + if not isfinite(filt): + raise ValueError("Encountered non-finite Gudhi filtration value.") + + if use_alpha_instead_of_alpha_square: + if filt < 0: + raise ValueError( + "Cannot take sqrt of a negative filtration value. " + "This can happen for weighted alpha complexes." 
+ ) + filt = sqrt(filt) + + weight = 1.0 + if simplex_weight_fn is not None: + weight = float(simplex_weight_fn(simplex, dim, filt)) + + filters_by_dim[dim].append(filt) + weights_by_dim[dim].append(weight) + + if not filters_by_dim: + raise ValueError("The Gudhi SimplexTree contains no simplices.") + + filtration_data = [] + + for dim in range(max(filters_by_dim) + 1): + simplex_filters = torch.tensor( + filters_by_dim.get(dim, []), + dtype=dtype, + device=device, + ).reshape(-1, 1) + + simplex_weights = torch.tensor( + weights_by_dim.get(dim, []), + dtype=dtype, + device=device, + ) + + filtration_data.append((simplex_filters, simplex_weights)) + + return filtration_data + + +def alpha_complex_to_filtration_data( + points, + *, + point_weights=None, + max_alpha_square=float("inf"), + max_dim=None, + simplex_weight_fn=None, + use_alpha_instead_of_alpha_square=False, + dtype=torch.float32, + device=None, +): + """Build a Gudhi AlphaComplex and convert it to pyECT filtration data. + + ``point_weights`` are passed to Gudhi to construct a weighted alpha + filtration. They do not change the pyECT simplex weights, which are one by + default unless ``simplex_weight_fn`` is provided. + """ + + try: + import gudhi + except ImportError as exc: + raise ImportError( + "Gudhi support requires the optional dependency 'gudhi'. 
" + "Install it with: pip install pyect[gudhi]" + ) from exc + + kwargs = {"points": points} + if point_weights is not None: + kwargs["weights"] = point_weights + + alpha_complex = gudhi.AlphaComplex(**kwargs) + simplex_tree = alpha_complex.create_simplex_tree( + max_alpha_square=max_alpha_square + ) + + filtration_data = gudhi_simplex_tree_to_filtration_data( + simplex_tree, + max_dim=max_dim, + simplex_weight_fn=simplex_weight_fn, + use_alpha_instead_of_alpha_square=use_alpha_instead_of_alpha_square, + dtype=dtype, + device=device, + ) + + return filtration_data, simplex_tree diff --git a/pyproject.toml b/pyproject.toml index 7cc80ef..8c8e6a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ dependencies = [ "trimesh" ] +[project.optional-dependencies] +gudhi = ["gudhi"] + classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/tests/test_gudhi_integration.py b/tests/test_gudhi_integration.py new file mode 100644 index 0000000..a219381 --- /dev/null +++ b/tests/test_gudhi_integration.py @@ -0,0 +1,172 @@ +"""Tests for optional Gudhi integration helpers.""" + +import sys +import types + +import pytest +import torch + +from pyect import compute_wecfs_general +from pyect.integrations.gudhi import ( + alpha_complex_to_filtration_data, + gudhi_simplex_tree_to_filtration_data, +) + + +class FakeSimplexTree: + """Small stand-in for Gudhi's SimplexTree API.""" + + def __init__(self, filtration): + self.filtration = filtration + + def get_filtration(self): + return iter(self.filtration) + + +def test_simplex_tree_conversion_uses_unit_weights_by_default(): + simplex_tree = FakeSimplexTree([ + ([0], 0.0), + ([1], 0.0), + ([2], 0.0), + ([0, 1], 1.0), + ([0, 1, 2], 2.0), + ]) + + filtration_data = gudhi_simplex_tree_to_filtration_data(simplex_tree) + + assert len(filtration_data) == 3 + assert filtration_data[0][0].shape == (3, 1) + assert filtration_data[1][0].shape == (1, 1) + assert 
filtration_data[2][0].shape == (1, 1) + + for _, simplex_weights in filtration_data: + assert torch.equal(simplex_weights, torch.ones_like(simplex_weights)) + + +def test_simplex_weight_callback_can_use_max_vertex_weight(): + vertex_weights = [0.5, 2.0, 1.25] + simplex_tree = FakeSimplexTree([ + ([0], 0.0), + ([1], 0.0), + ([2], 0.0), + ([0, 2], 1.0), + ([0, 1, 2], 2.0), + ]) + + def max_vertex_weight(simplex, dim, filtration_value): + return max(vertex_weights[vertex] for vertex in simplex) + + filtration_data = gudhi_simplex_tree_to_filtration_data( + simplex_tree, + simplex_weight_fn=max_vertex_weight, + ) + + assert torch.allclose( + filtration_data[0][1], + torch.tensor([0.5, 2.0, 1.25]), + ) + assert torch.allclose(filtration_data[1][1], torch.tensor([1.25])) + assert torch.allclose(filtration_data[2][1], torch.tensor([2.0])) + + +def test_simplex_tree_conversion_respects_dtype_and_device(): + device = torch.device("cpu") + simplex_tree = FakeSimplexTree([ + ([0], 0.0), + ([0, 1], 4.0), + ]) + + filtration_data = gudhi_simplex_tree_to_filtration_data( + simplex_tree, + dtype=torch.float64, + device=device, + ) + + for simplex_filters, simplex_weights in filtration_data: + assert simplex_filters.dtype == torch.float64 + assert simplex_weights.dtype == torch.float64 + assert simplex_filters.device == device + assert simplex_weights.device == device + + +def test_simplex_tree_conversion_can_use_alpha_instead_of_alpha_square(): + simplex_tree = FakeSimplexTree([ + ([0], 0.0), + ([0, 1], 4.0), + ]) + + filtration_data = gudhi_simplex_tree_to_filtration_data( + simplex_tree, + use_alpha_instead_of_alpha_square=True, + ) + + assert torch.allclose(filtration_data[0][0], torch.tensor([[0.0]])) + assert torch.allclose(filtration_data[1][0], torch.tensor([[2.0]])) + + +def test_simplex_tree_conversion_rejects_non_finite_filtration(): + simplex_tree = FakeSimplexTree([([0], float("inf"))]) + + with pytest.raises(ValueError, match="non-finite"): + 
gudhi_simplex_tree_to_filtration_data(simplex_tree) + + +def test_simplex_tree_conversion_rejects_negative_sqrt(): + simplex_tree = FakeSimplexTree([([0], -1.0)]) + + with pytest.raises(ValueError, match="negative"): + gudhi_simplex_tree_to_filtration_data( + simplex_tree, + use_alpha_instead_of_alpha_square=True, + ) + + +def test_alpha_complex_passes_point_weights_to_gudhi(monkeypatch): + calls = {} + simplex_tree = FakeSimplexTree([([0], 0.0)]) + + class FakeAlphaComplex: + def __init__(self, **kwargs): + calls["kwargs"] = kwargs + + def create_simplex_tree(self, max_alpha_square): + calls["max_alpha_square"] = max_alpha_square + return simplex_tree + + monkeypatch.setitem( + sys.modules, + "gudhi", + types.SimpleNamespace(AlphaComplex=FakeAlphaComplex), + ) + + points = [[0.0, 0.0]] + point_weights = [0.25] + + filtration_data, returned_simplex_tree = alpha_complex_to_filtration_data( + points, + point_weights=point_weights, + max_alpha_square=2.0, + ) + + assert calls["kwargs"] == {"points": points, "weights": point_weights} + assert calls["max_alpha_square"] == 2.0 + assert returned_simplex_tree is simplex_tree + assert torch.equal(filtration_data[0][1], torch.ones(1)) + + +def test_gudhi_alpha_complex_wecf_smoke(): + pytest.importorskip("gudhi") + + points = [ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + ] + + filtration_data, simplex_tree = alpha_complex_to_filtration_data(points) + wecf = compute_wecfs_general(filtration_data, num_vals=10) + + assert simplex_tree.num_simplices() > 0 + assert isinstance(wecf, torch.Tensor) + assert wecf.shape == (1, 10) + assert torch.isfinite(wecf).all() From 9f629a9292f793c98b8359e9b17bb77f9d87c46b Mon Sep 17 00:00:00 2001 From: alexjmccleary Date: Sat, 9 May 2026 12:02:17 -0600 Subject: [PATCH 3/3] Remove differentiable WECT --- pyect/__init__.py | 1 - pyect/differentiable_wect.py | 148 --------------- tests/test_dwect.py | 343 ----------------------------------- 3 files changed, 492 deletions(-) delete mode 100644 
pyect/differentiable_wect.py delete mode 100644 tests/test_dwect.py diff --git a/pyect/__init__.py b/pyect/__init__.py index a1da5d9..3a135e2 100644 --- a/pyect/__init__.py +++ b/pyect/__init__.py @@ -2,7 +2,6 @@ from .tensor_complex import Complex from .directions import sample_directions_2d, sample_directions_3d from .image_ecf import Image_ECF_2D, Image_ECF_3D -from .differentiable_wect import DWECT from .general_filtrations import compute_wecfs_general from .preprocessing.mesh_processing import mesh_to_complex from .preprocessing.image_processing import ( diff --git a/pyect/differentiable_wect.py b/pyect/differentiable_wect.py deleted file mode 100644 index a04cdf5..0000000 --- a/pyect/differentiable_wect.py +++ /dev/null @@ -1,148 +0,0 @@ -"""For computing the differentiable WECT of a weighted geometric simplicial/cubical complex embedded in R^n.""" - -import torch -from typing import List, Tuple - - -class DWECT(torch.nn.Module): - """A torch module for computing the differentiable weighted Euler characteristic transform (DWECT) of a simplicial complex - discretized over a grid. - - This module may be used just for computing the DWECT, or used as a layer in a neural network. - Internally, the module stores the directions and number of heights used for sampling, so repeated forward calls - do not require these parameters to be passed in, and allow streamlined loading/saving of the module for consistent - computation. - - This module can also be converted to TorchScript using torch.jit.script for use - outside of Python. - """ - - def __init__(self, dirs: torch.Tensor, num_heights: int, growth_rate: float) -> None: - """Initializes the DWECT module. - - The initialized module is designed to compute the DWECT of a simplicial complex - embedded in R^[dirs.shape[1]], using dirs.shape[0] directions for sampling. - The discretization of the DWECT is parameterized by num_heights distinct height values. 
- - Args: - dirs: An (d x n) tensor of directions to use for sampling. - num_heights: A constant tensor, with the number of distinct height - values to round to as an integer - growth_rate: The growth rate for the sigmoid function. - """ - super().__init__() - dirs = torch.nn.functional.normalize(dirs, p=2, dim=1, eps=1e-12) - self.register_buffer("dirs", dirs) - self.growth_rate: float = float(growth_rate) - - num_heights = int(num_heights) - if num_heights <= 0: - raise ValueError("num_heights must be positive.") - self.num_heights: int = num_heights - - - def _soft_cum_sum(self, M: torch.Tensor) -> torch.Tensor: - """Computes a soft version of the cumulative sum of M. - - Args: - M (torch.Tensor): A tensor with shape (m,n). - - Returns: - torch.Tensor: A tensor with the same shape as M. - """ - - n = M.size(1) - idx = torch.arange(n, dtype=M.dtype, device=M.device) - K = torch.sigmoid(self.growth_rate * (idx.unsqueeze(1) - idx.unsqueeze(0))) - return M @ K.T - - def _vertex_indices( - self, - vertex_coords: torch.Tensor, - ) -> torch.Tensor: - """Calculates the height values of each vertex and converts them to an index in range(num_heights). - - Args: - vertex_coords (torch.Tensor): A tensor of shape (k_0, n) with rows representing the coordinates of the vertices. - - Returns: - torch.Tensor: A tensor of shape (k_0, d) with the height indices of each vertex in each direction. - """ - - eps = 1e-12 # only used in the case where all vertices are at the origin - - v_norms = torch.norm(vertex_coords, dim=1) - max_height = torch.amax(v_norms).clamp(min=eps) - v_heights = vertex_coords @ self.dirs.T - - v_indices = torch.ceil( - (self.num_heights - 1) * (max_height + v_heights) / (2.0 * max_height) - ).clamp(0, self.num_heights - 1).long() - - return v_indices - - def forward( - self, - complex_data: List[Tuple[torch.Tensor, torch.Tensor]] - ) -> torch.Tensor: - """Calculates a discretization of the DWECT of a complex embedded in n-dimensional space. 
- - Args: - complex_data: A weighted simplicial or cubical complex, represented as a list of pairs of tensors. - complex_data[0] = (v_coords, v_weights): - v_coords (torch.Tensor): A tensor of shape (k_0, n) where k_0 is the number of vertices. - Rows are the coordinates of the vertices. - - v_weights (torch.Tensor): A tensor of shape (k_0). Values are the weights of the vertices. - - for i > 0: - complex_data[i] = (simp_verts, simp_weights): - simp_verts (torch.Tensor): A tensor of shape (k_i, i+1) where k_i is the number of i-simplices. - Rows are the vertex sets of the i-simplices. - - simp_weights (torch.Tensor): A tensor of shape (k_i). Values are the weights of the i-simplices. - - Returns: - dwect (torch.Tensor): A 2d tensor of shape (self.dirs.shape[0], self.num_heights) - containing the DWECT. - """ - - d = self.dirs.size(dim=0) - h = self.num_heights - - device = self.dirs.device - v_coords = complex_data[0][0].to(device=device, dtype=torch.float32) - v_weights = complex_data[0][1].to(device=device, dtype=torch.float32) - - # Check for empty inputs - if v_coords.size(0) == 0: - return torch.zeros((d, h), dtype=torch.float32, device=device) - - expanded_v_weights = v_weights.unsqueeze(0).expand( - d, -1 - ) # Expand to shape (d, k_0) - - # Initialize the differentiated WECT - diff_wect = torch.zeros((d, h), dtype=torch.float32, device=device) - - # Compute the height index of each vertex - v_indices = self._vertex_indices(v_coords) - - # Add the contribution of the vertices to the differentiated WECT - diff_wect.scatter_add_(1, v_indices.T, expanded_v_weights) - - for i in range(1, len(complex_data)): - simp_verts = complex_data[i][0].to(device=device, dtype=torch.long) - simp_weights = complex_data[i][1].to(device=device, dtype=torch.float32) - - # Expand to shape (d, k_i) - expanded_simp_weights = (-1) ** i * simp_weights.unsqueeze(0).expand(d, -1) - - # Compute the maximum index for each simplex's vertices - simp_indices = v_indices[simp_verts] - 
max_simp_indices = torch.amax(simp_indices, dim=1) - - # Add the contribution of the i-simplices to the differentiated WECT - diff_wect.scatter_add_(1, max_simp_indices.T, expanded_simp_weights) - - return self._soft_cum_sum(diff_wect) diff --git a/tests/test_dwect.py b/tests/test_dwect.py deleted file mode 100644 index b4907da..0000000 --- a/tests/test_dwect.py +++ /dev/null @@ -1,343 +0,0 @@ -"""Tests for the DWECT (Differentiable WECT) module in differentiable_wect.py""" - -import torch -import pytest - -from pyect import DWECT, WECT, Complex - - -def build_triangle_complex(device="cpu"): - """Build a simple triangle complex for testing.""" - vcoords = torch.tensor( - [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], device=device - ) - vweights = torch.ones(3, device=device) - - ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) - eweights = torch.ones(3, device=device) - - fcoords = torch.tensor([[0, 1, 2]], device=device) - fweights = torch.ones(1, device=device) - - return Complex( - (vcoords, vweights), - (ecoords, eweights), - (fcoords, fweights), - ) - - -class TestDWECTConstruction: - """Tests for DWECT module construction.""" - - def test_basic_construction(self): - """Test basic DWECT construction.""" - dirs = torch.tensor([[1.0, 0.0]]) - dwect = DWECT(dirs, num_heights=10, growth_rate=10.0) - - assert dwect.num_heights == 10 - assert dwect.growth_rate == 10.0 - assert dwect.dirs.shape == (1, 2) - - def test_multiple_directions(self): - """Test DWECT with multiple directions.""" - dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) - dwect = DWECT(dirs, num_heights=5, growth_rate=5.0) - - assert dwect.dirs.shape == (3, 2) - - def test_direction_normalization(self): - """Test that directions are normalized.""" - dirs = torch.tensor([[3.0, 4.0]]) # norm = 5 - dwect = DWECT(dirs, num_heights=5, growth_rate=10.0) - - norms = torch.norm(dwect.dirs, dim=1) - assert torch.allclose(norms, torch.ones(1), atol=1e-6) - - def 
test_invalid_num_heights(self): - """Test that non-positive num_heights raises error.""" - dirs = torch.tensor([[1.0, 0.0]]) - - with pytest.raises(ValueError, match="num_heights must be positive"): - DWECT(dirs, num_heights=0, growth_rate=10.0) - - with pytest.raises(ValueError, match="num_heights must be positive"): - DWECT(dirs, num_heights=-5, growth_rate=10.0) - - def test_various_growth_rates(self): - """Test DWECT with various growth rates.""" - dirs = torch.tensor([[1.0, 0.0]]) - - for rate in [0.1, 1.0, 10.0, 100.0]: - dwect = DWECT(dirs, num_heights=5, growth_rate=rate) - assert dwect.growth_rate == rate - - -class TestDWECTForward: - """Tests for DWECT forward pass.""" - - def test_output_shape_single_direction(self): - """Test output shape with single direction.""" - device = torch.device("cpu") - c = build_triangle_complex(device) - - dirs = torch.tensor([[1.0, 0.0]], device=device) - dwect = DWECT(dirs, num_heights=10, growth_rate=10.0).to(device) - - result = dwect(c) - - assert result.shape == (1, 10) - - def test_output_shape_multiple_directions(self): - """Test output shape with multiple directions.""" - device = torch.device("cpu") - c = build_triangle_complex(device) - - dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], device=device) - dwect = DWECT(dirs, num_heights=8, growth_rate=10.0).to(device) - - result = dwect(c) - - assert result.shape == (3, 8) - - def test_output_is_finite(self): - """Test that output contains no NaN or Inf values.""" - device = torch.device("cpu") - c = build_triangle_complex(device) - - dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device) - dwect = DWECT(dirs, num_heights=10, growth_rate=10.0).to(device) - - result = dwect(c) - - assert torch.isfinite(result).all() - - def test_empty_complex(self): - """Test DWECT with empty complex.""" - device = torch.device("cpu") - - vcoords = torch.zeros((0, 2), device=device) - vweights = torch.zeros(0, device=device) - - c = Complex((vcoords, vweights)) - - 
dirs = torch.tensor([[1.0, 0.0]], device=device) - dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) - - result = dwect(c) - - assert result.shape == (1, 5) - assert torch.allclose(result, torch.zeros((1, 5), device=device)) - - -class TestDWECTGradients: - """Tests for DWECT gradient computation.""" - - def test_gradients_flow(self): - """Test that gradients flow through DWECT.""" - device = torch.device("cpu") - - # Create complex with requires_grad - vcoords = torch.tensor( - [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], - device=device, - requires_grad=True - ) - vweights = torch.ones(3, device=device, requires_grad=True) - - ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) - eweights = torch.ones(3, device=device, requires_grad=True) - - fcoords = torch.tensor([[0, 1, 2]], device=device) - fweights = torch.ones(1, device=device, requires_grad=True) - - c = Complex( - (vcoords, vweights), - (ecoords, eweights), - (fcoords, fweights), - ) - - dirs = torch.tensor([[1.0, 0.0]], device=device) - dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) - - result = dwect(c) - loss = result.sum() - loss.backward() - - # Check gradients exist for weights - assert vweights.grad is not None - assert eweights.grad is not None - assert fweights.grad is not None - - def test_gradients_are_finite(self): - """Test that computed gradients are finite.""" - device = torch.device("cpu") - - vcoords = torch.tensor( - [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], - device=device - ) - vweights = torch.ones(3, device=device, requires_grad=True) - - ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) - eweights = torch.ones(3, device=device, requires_grad=True) - - fcoords = torch.tensor([[0, 1, 2]], device=device) - fweights = torch.ones(1, device=device, requires_grad=True) - - c = Complex( - (vcoords, vweights), - (ecoords, eweights), - (fcoords, fweights), - ) - - dirs = torch.tensor([[1.0, 0.0]], device=device) - dwect = DWECT(dirs, 
num_heights=5, growth_rate=10.0).to(device) - - result = dwect(c) - loss = result.sum() - loss.backward() - - assert torch.isfinite(vweights.grad).all() - assert torch.isfinite(eweights.grad).all() - assert torch.isfinite(fweights.grad).all() - - -class TestDWECTSoftCumsum: - """Tests for the soft cumsum functionality.""" - - def test_soft_cumsum_shape(self): - """Test soft cumsum preserves shape.""" - dirs = torch.tensor([[1.0, 0.0]]) - dwect = DWECT(dirs, num_heights=5, growth_rate=10.0) - - M = torch.randn(3, 5) - result = dwect._soft_cum_sum(M) - - assert result.shape == M.shape - - def test_high_growth_rate_approaches_cumsum(self): - """Test that high growth rate approximates regular cumsum.""" - dirs = torch.tensor([[1.0, 0.0]]) - dwect_high = DWECT(dirs, num_heights=5, growth_rate=1000.0) - - M = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) - soft_result = dwect_high._soft_cum_sum(M) - hard_result = torch.cumsum(M, dim=1) - - # With very high growth rate, should be close to regular cumsum - # Note: soft_cum_sum uses sigmoid which saturates but doesn't equal hard cumsum - # Check that monotonicity is preserved and values are in similar range - assert soft_result[0, -1] > soft_result[0, 0] # Monotonic increase - assert torch.isfinite(soft_result).all() - - def test_low_growth_rate_is_smooth(self): - """Test that low growth rate produces smooth output.""" - dirs = torch.tensor([[1.0, 0.0]]) - dwect_low = DWECT(dirs, num_heights=5, growth_rate=0.5) - - M = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0]]) - result = dwect_low._soft_cum_sum(M) - - # With low growth rate, output should be smoother than input - # Check that middle values are not zero - assert result[0, 2] > 0 - - -class TestDWECTComparisonToWECT: - """Tests comparing DWECT to WECT behavior.""" - - def test_dwect_wect_same_shape(self): - """Test DWECT and WECT produce same shape output.""" - device = torch.device("cpu") - c = build_triangle_complex(device) - - dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], 
device=device) - - wect = WECT(dirs, num_heights=10).to(device) - dwect = DWECT(dirs, num_heights=10, growth_rate=100.0).to(device) - - wect_result = wect(c) - dwect_result = dwect(c) - - assert wect_result.shape == dwect_result.shape - - def test_high_growth_rate_similar_to_wect(self): - """Test that DWECT with high growth rate is similar to WECT.""" - device = torch.device("cpu") - c = build_triangle_complex(device) - - dirs = torch.tensor([[1.0, 0.0]], device=device) - - wect = WECT(dirs, num_heights=5).to(device) - dwect = DWECT(dirs, num_heights=5, growth_rate=1000.0).to(device) - - wect_result = wect(c) - dwect_result = dwect(c) - - # With very high growth rate, should be close - assert torch.allclose(wect_result, dwect_result, atol=0.5) - - -class TestDWECTTorchScript: - """Tests for TorchScript compatibility.""" - - def test_can_script(self): - """Test that DWECT can be compiled with TorchScript.""" - dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) - dwect = DWECT(dirs, num_heights=10, growth_rate=10.0) - - scripted = torch.jit.script(dwect) - assert scripted is not None - - def test_scripted_gives_same_result(self): - """Test that scripted DWECT gives same results.""" - device = torch.device("cpu") - c = build_triangle_complex(device) - - dirs = torch.tensor([[1.0, 0.0]], device=device) - dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) - scripted = torch.jit.script(dwect) - - result_normal = dwect(c) - result_scripted = scripted(c) - - assert torch.allclose(result_normal, result_scripted, atol=1e-6) - - -class TestDWECT3D: - """Tests for DWECT in 3D.""" - - def test_3d_triangle(self): - """Test DWECT on a triangle in 3D.""" - device = torch.device("cpu") - - vcoords = torch.tensor([ - [0.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [0.5, 1.0, 0.0], - ], device=device) - vweights = torch.ones(3, device=device) - - ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) - eweights = torch.ones(3, device=device) - - fcoords = torch.tensor([[0, 1, 
2]], device=device) - fweights = torch.ones(1, device=device) - - c = Complex( - (vcoords, vweights), - (ecoords, eweights), - (fcoords, fweights), - ) - - dirs = torch.tensor([ - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, 1.0] - ], device=device) - dwect = DWECT(dirs, num_heights=8, growth_rate=10.0).to(device) - - result = dwect(c) - - assert result.shape == (3, 8) - assert torch.isfinite(result).all()