Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ To install `pyECT`, use pip:
pip install pyect
```

Gudhi alpha-complex support is optional:

```bash
pip install pyect[gudhi]
```

The Gudhi integration lives at `pyect.integrations.gudhi` so Gudhi is not
imported by the core package. In `alpha_complex_to_filtration_data`,
`point_weights` are passed to Gudhi to construct the alpha filtration. The
pyECT simplex weights are `1.0` by default; pass `simplex_weight_fn` to use a
custom weighting rule, such as the max of the simplex vertex weights.

## Usage

Here's a simple example of how to use `pyECT`:
Expand Down
37 changes: 37 additions & 0 deletions examples/gudhi_alpha_complex_wecf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Example WECF computation from a Gudhi alpha complex."""

import torch

from pyect import compute_wecfs_general
from pyect.integrations.gudhi import alpha_complex_to_filtration_data


def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

points = [
[1.0, 1.0],
[7.0, 0.0],
[4.0, 6.0],
[9.0, 6.0],
[0.0, 14.0],
[2.0, 19.0],
[9.0, 17.0],
]

point_weights = [0.0, 0.2, 0.1, 0.3, 0.0, 0.1, 0.2]

filtration_data, simplex_tree = alpha_complex_to_filtration_data(
points,
point_weights=point_weights,
device=device,
)

wecf = compute_wecfs_general(filtration_data, num_vals=200)

print(simplex_tree.num_simplices())
print(wecf.shape)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion pyect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .tensor_complex import Complex
from .directions import sample_directions_2d, sample_directions_3d
from .image_ecf import Image_ECF_2D, Image_ECF_3D
from .differentiable_wect import DWECT
from .general_filtrations import compute_wecfs_general
from .preprocessing.mesh_processing import mesh_to_complex
from .preprocessing.image_processing import (
weighted_freudenthal,
Expand Down
148 changes: 0 additions & 148 deletions pyect/differentiable_wect.py

This file was deleted.

59 changes: 59 additions & 0 deletions pyect/general_filtrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""For computing the WECF of an arbitrary (not neccesarily lower-star) filtration.
If your filtrations are lower-star (they usually are), then use wecfs.py instead.
"""

import torch
from typing import List, Tuple

def compute_wecfs_general(
filtration_data: List[Tuple[torch.Tensor, torch.Tensor]],
num_vals: int
) -> torch.Tensor:
"""Calculates WECFs for filtrations with values assigned to every simplex.

Args:
filtration_data: A weighted simplicial or cubical complex with a collection
of filter functions defined on each simplex, represented as a list of
pairs of tensors. The list index is the simplex dimension.

filtration_data[i] = (simplex_filters, simplex_weights):
simplex_filters (torch.Tensor): A tensor of shape (k_i, m), where
k_i is the number of i-simplices and m is the number of filter
functions. Each row contains the filter values of one simplex.

simplex_weights (torch.Tensor): A tensor of shape (k_i). Values
are the weights of the i-simplices.

Returns:
wecfs (torch.Tensor): A 2d tensor of shape (m, num_vals)
containing the WECFs.
"""

if num_vals <= 0:
raise ValueError("num_vals must be positive.")

if len(filtration_data) == 0:
raise ValueError("filtration_data must be non-empty.")

device = filtration_data[0][0].device
m = filtration_data[0][0].size(dim=1)
eps = torch.finfo(torch.float32).eps

max_val = torch.cat([f.reshape(-1) for f, _ in filtration_data]).max()
min_val = torch.cat([f.reshape(-1) for f, _ in filtration_data]).min()
val_range = torch.clamp(max_val - min_val, min=eps)

diff_wecfs = torch.zeros((m, num_vals), dtype=torch.float32, device=device)

for i, (simplex_filters, simplex_weights) in enumerate(filtration_data):
simplex_indices = torch.ceil(
(num_vals - 1) * (simplex_filters - min_val) / (val_range)
).clamp(0, num_vals-1).long()

expanded_simplex_weights = (
(-1) ** i * simplex_weights.unsqueeze(0).expand(m, -1)
)

diff_wecfs.scatter_add_(1, simplex_indices.T, expanded_simplex_weights)

return torch.cumsum(diff_wecfs, dim=1)
1 change: 1 addition & 0 deletions pyect/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Optional integrations for pyECT."""
Loading
Loading