From b934cce215a1200e9743ee87c087ac210cb33eec Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:26:16 +0200 Subject: [PATCH 1/7] add DGL integration (explicit conversion of transforms in `perturbations.py`) --- rings/integrations/__init__.py | 24 ++++- rings/integrations/dgl.py | 167 +++++++++++++++++++++++++++++++++ rings/integrations/study.py | 36 ++++--- 3 files changed, 213 insertions(+), 14 deletions(-) create mode 100644 rings/integrations/dgl.py diff --git a/rings/integrations/__init__.py b/rings/integrations/__init__.py index e14abfa..6af5393 100644 --- a/rings/integrations/__init__.py +++ b/rings/integrations/__init__.py @@ -1,6 +1,15 @@ from rings.integrations.study import SeparabilityStudy -__all__ = ["SeparabilityStudy", "SeparabilityCallback"] +__all__ = [ + "SeparabilityStudy", + "SeparabilityCallback", + "DGLOriginal", + "DGLEmptyFeatures", + "DGLRandomFeatures", + "DGLEmptyGraph", + "DGLCompleteGraph", + "DGLRandomGraph", +] def __getattr__(name): @@ -8,4 +17,17 @@ def __getattr__(name): from rings.integrations.lightning import SeparabilityCallback return SeparabilityCallback + + if name in [ + "DGLOriginal", + "DGLEmptyFeatures", + "DGLRandomFeatures", + "DGLEmptyGraph", + "DGLCompleteGraph", + "DGLRandomGraph", + ]: + from rings.integrations import dgl + + return getattr(dgl, name) + raise AttributeError(f"module 'rings.integrations' has no attribute {name!r}") diff --git a/rings/integrations/dgl.py b/rings/integrations/dgl.py new file mode 100644 index 0000000..5d7b407 --- /dev/null +++ b/rings/integrations/dgl.py @@ -0,0 +1,167 @@ +"""DGL compatibility wrappers for core RINGS perturbations. + +This module keeps :mod:`rings.perturbations` as the source of truth by +converting ``dgl.DGLGraph`` objects to temporary PyG ``Data`` objects, applying +the perturbation, then converting back to DGL. +""" + +from typing import Callable, Optional + +import torch +from torch_geometric.data import Data + +from rings.perturbations import ( + CompleteGraph, + EmptyFeatures, + EmptyGraph, + Original, + RandomFeatures, + RandomGraph, +) + +_DGL_IMPORT_ERROR = None +try: + import dgl +except Exception as exc: # pragma: no cover - optional dependency import surface + dgl = None + _DGL_IMPORT_ERROR = exc + + +def _check_dgl(): + if dgl is None: + raise ImportError( + "DGL is required for these perturbations. " + "Install it with 'pip install dgl'." + ) from _DGL_IMPORT_ERROR + + +def dgl_to_pyg(g: "dgl.DGLGraph", feat_name: str = "x") -> Data: + """Convert a homogeneous DGL graph to a PyG ``Data`` object.""" + _check_dgl() + src, dst = g.edges() + data = Data( + edge_index=torch.stack([src, dst], dim=0), + num_nodes=g.num_nodes(), + ) + + for key, value in g.ndata.items(): + if key == feat_name: + data.x = value.clone() + else: + data[f"ndata__{key}"] = value.clone() + return data + + +def pyg_to_dgl(data: Data, feat_name: str = "x", device=None) -> "dgl.DGLGraph": + """Convert a PyG ``Data`` object to a homogeneous DGL graph.""" + _check_dgl() + if data.edge_index is None: + src = torch.empty((0,), dtype=torch.int64) + dst = torch.empty((0,), dtype=torch.int64) + else: + src, dst = data.edge_index[0], data.edge_index[1] + + if device is not None: + src = src.to(device) + dst = dst.to(device) + + new_g = dgl.graph((src, dst), num_nodes=data.num_nodes, device=device) + if getattr(data, "x", None) is not None: + x = data.x if device is None else data.x.to(device) + new_g.ndata[feat_name] = x + + for key, value in data.to_dict().items(): + if not key.startswith("ndata__"): + continue + out_key = key.split("ndata__", 1)[1] + new_g.ndata[out_key] = value if device is None else value.to(device) + return new_g + + +def as_dgl_transform( + pyg_transform: Callable, feat_name: str = "x" +) -> Callable[["dgl.DGLGraph"], "dgl.DGLGraph"]: + """Wrap a PyG perturbation for DGL graphs via round-trip conversion.""" + + def _transform(g: "dgl.DGLGraph") -> "dgl.DGLGraph": + _check_dgl() + data = dgl_to_pyg(g, feat_name=feat_name) + out = pyg_transform(data) + return pyg_to_dgl(out, feat_name=feat_name, device=g.device) + + return _transform + + +class DGLOriginal: + """DGL wrapper for :class:`rings.perturbations.Original`.""" + + def __init__(self, feat_name: str = "x"): + self._transform = as_dgl_transform(Original(), feat_name=feat_name) + + def __call__(self, g: "dgl.DGLGraph") -> "dgl.DGLGraph": + return self._transform(g) + + +class DGLEmptyFeatures: + """DGL wrapper for :class:`rings.perturbations.EmptyFeatures`.""" + + def __init__(self, feat_name: str = "x"): + self.feat_name = feat_name + self._transform = as_dgl_transform(EmptyFeatures(), feat_name=feat_name) + + def __call__(self, g: "dgl.DGLGraph") -> "dgl.DGLGraph": + return self._transform(g) + + +class DGLRandomFeatures: + """DGL wrapper for :class:`rings.perturbations.RandomFeatures`.""" + + def __init__( + self, + shuffle: bool = False, + feat_name: str = "x", + generator: Optional[torch.Generator] = None, + ): + self.shuffle = shuffle + self.feat_name = feat_name + self.generator = generator + self._transform = RandomFeatures(shuffle=shuffle) + if self.generator is not None: + self._transform.generator = self.generator + self._dgl_transform = as_dgl_transform(self._transform, feat_name=feat_name) + + def __call__(self, g: "dgl.DGLGraph") -> "dgl.DGLGraph": + return self._dgl_transform(g) + + +class DGLEmptyGraph: + """DGL wrapper for :class:`rings.perturbations.EmptyGraph`.""" + + def __init__(self, feat_name: str = "x"): + self._transform = as_dgl_transform(EmptyGraph(), feat_name=feat_name) + + def __call__(self, g: "dgl.DGLGraph") -> "dgl.DGLGraph": + return self._transform(g) + + +class DGLCompleteGraph: + """DGL wrapper for :class:`rings.perturbations.CompleteGraph`.""" + + def __init__(self, feat_name: str = "x"): + self._transform = as_dgl_transform(CompleteGraph(), feat_name=feat_name) + + def __call__(self, g: "dgl.DGLGraph") -> "dgl.DGLGraph": + return self._transform(g) + + +class DGLRandomGraph: + """DGL wrapper for :class:`rings.perturbations.RandomGraph`.""" + + def __init__(self, p: float = 0.1, generator: Optional[torch.Generator] = None): + self._transform = RandomGraph(p=p) + if generator is not None: + self._transform.generator = generator + self._dgl_transform = as_dgl_transform(self._transform) + + def __call__(self, g: "dgl.DGLGraph") -> "dgl.DGLGraph": + return self._dgl_transform(g) diff --git a/rings/integrations/study.py b/rings/integrations/study.py index 290fe5a..9776f54 100644 --- a/rings/integrations/study.py +++ b/rings/integrations/study.py @@ -92,21 +92,31 @@ def runs(self) -> Iterator[Tuple[str, Callable, int]]: @staticmethod def apply(data: Any, transform: Callable) -> Any: - """Apply ``transform`` to a single ``Data`` object or to a PyG ``Dataset``. + """Apply ``transform`` to a PyG ``Data``/``Dataset`` or a DGL ``DGLGraph``. - For a ``Dataset``, this sets ``dataset.transform`` so that the transform is - applied lazily on each ``__getitem__`` call — the PyG-idiomatic pattern. For - a single ``Data`` object, the transform is called directly and the result - is returned. Other inputs are passed straight through ``transform(data)`` - as a fallback. + For a PyG ``Dataset``, this sets ``dataset.transform`` for lazy application. + For PyG ``Data`` or DGL ``DGLGraph``, the transform is called directly. """ - from torch_geometric.data import Data, Dataset - - if isinstance(data, Dataset): - data.transform = transform - return data - if isinstance(data, Data): - return transform(data) + # Try PyG + try: + from torch_geometric.data import Data, Dataset + if isinstance(data, Dataset): + data.transform = transform + return data + if isinstance(data, Data): + return transform(data) + except ImportError: + pass + + # Try DGL + try: + import dgl + if isinstance(data, dgl.DGLGraph): + return transform(data) + except ImportError: + pass + + # Fallback for generic objects return transform(data) def record(self, name: str, score: float) -> None: From c6d816094155d80f8787e82e7df9ae63a308ff71 Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:26:39 +0200 Subject: [PATCH 2/7] add unit tests for DGL integrations --- tests/test_dgl_integration.py | 137 ++++++++++++++++++ tests/test_lightning_framework_integration.py | 115 +++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 tests/test_dgl_integration.py create mode 100644 tests/test_lightning_framework_integration.py diff --git a/tests/test_dgl_integration.py b/tests/test_dgl_integration.py new file mode 100644 index 0000000..a7c9013 --- /dev/null +++ b/tests/test_dgl_integration.py @@ -0,0 +1,137 @@ +from collections import Counter + +import pytest +import torch + +try: + import dgl +except Exception as exc: + pytest.skip(f"DGL unavailable in test runtime: {exc}", allow_module_level=True) + +from rings.integrations import ( # noqa: E402 + DGLCompleteGraph, + DGLEmptyFeatures, + DGLEmptyGraph, + DGLOriginal, + DGLRandomFeatures, + DGLRandomGraph, + SeparabilityStudy, +) +from rings.integrations.dgl import as_dgl_transform # noqa: E402 +from rings.perturbations import EmptyFeatures # noqa: E402 + + +def _citation_like_graph(): + # Directed "citation-like" toy graph: + # paper 0 cites 1 and 2; paper 3 cites 0 and 2; plus reciprocal links in a cluster. + src = torch.tensor([0, 0, 1, 2, 3, 3, 4, 5], dtype=torch.int64) + dst = torch.tensor([1, 2, 2, 1, 0, 2, 5, 4], dtype=torch.int64) + g = dgl.graph((src, dst), num_nodes=6) + g.ndata["x"] = torch.tensor( + [ + [1.0, 0.2, 0.1], + [0.5, 1.1, 0.3], + [0.7, 0.9, 0.4], + [1.3, 0.1, 0.8], + [0.2, 0.3, 1.2], + [0.1, 0.4, 1.1], + ], + dtype=torch.float32, + ) + return g + + +class TestDGLIntegration: + def test_original_matches_identity_semantics(self): + g = _citation_like_graph() + out = DGLOriginal()(g) + assert out is not g + assert torch.equal(out.ndata["x"], g.ndata["x"]) + out_src, out_dst = out.edges() + src, dst = g.edges() + assert torch.equal(out_src, src) + assert torch.equal(out_dst, dst) + + def test_empty_features_sets_single_zero_feature(self): + g = _citation_like_graph() + out = DGLEmptyFeatures()(g) + assert out.num_edges() == g.num_edges() + assert out.ndata["x"].shape == (g.num_nodes(), 1) + assert torch.all(out.ndata["x"] == 0) + + def test_random_features_shuffle_preserves_feature_multiset(self): + g = _citation_like_graph() + out = DGLRandomFeatures(shuffle=True)(g) + assert out.ndata["x"].shape == g.ndata["x"].shape + + original_rows = Counter(tuple(row.tolist()) for row in g.ndata["x"]) + shuffled_rows = Counter(tuple(row.tolist()) for row in out.ndata["x"]) + assert original_rows == shuffled_rows + + def test_random_features_respects_seeded_global_rng(self): + g = _citation_like_graph() + + torch.manual_seed(123) + out1 = DGLRandomFeatures(shuffle=False)(g).ndata["x"] + torch.manual_seed(123) + out2 = DGLRandomFeatures(shuffle=False)(g).ndata["x"] + + assert torch.allclose(out1, out2) + + def test_random_features_shuffle_with_missing_feature_is_noop(self): + g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])), num_nodes=2) + out = DGLRandomFeatures(shuffle=True, feat_name="missing")(g) + assert out.num_nodes() == 2 + assert "missing" not in out.ndata + + def test_empty_graph_removes_edges_but_keeps_node_features(self): + g = _citation_like_graph() + out = DGLEmptyGraph()(g) + assert out.num_nodes() == g.num_nodes() + assert out.num_edges() == 0 + assert torch.equal(out.ndata["x"], g.ndata["x"]) + + def test_complete_graph_has_all_directed_edges_without_self_loops(self): + g = _citation_like_graph() + out = DGLCompleteGraph()(g) + n = g.num_nodes() + + assert out.num_edges() == n * (n - 1) + src, dst = out.edges() + assert torch.all(src != dst) + assert torch.equal(out.ndata["x"], g.ndata["x"]) + + def test_random_graph_has_no_self_loops_and_keeps_features(self): + g = _citation_like_graph() + out = DGLRandomGraph(p=0.4)(g) + src, dst = out.edges() + assert out.num_nodes() == g.num_nodes() + assert torch.all(src != dst) + assert torch.equal(out.ndata["x"], g.ndata["x"]) + + def test_study_apply_supports_dgl_graph(self): + g = _citation_like_graph() + out = SeparabilityStudy.apply(g, DGLEmptyGraph()) + assert out.num_nodes() == g.num_nodes() + assert out.num_edges() == 0 + + def test_adapter_preserves_non_feature_ndata(self): + g = _citation_like_graph() + g.ndata["y"] = torch.arange(g.num_nodes(), dtype=torch.float32).unsqueeze(1) + out = DGLEmptyGraph()(g) + assert "y" in out.ndata + assert torch.equal(out.ndata["y"], g.ndata["y"]) + + def test_as_dgl_transform_supports_custom_feature_key(self): + g = dgl.graph( + (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])), + num_nodes=3, + ) + g.ndata["h"] = torch.tensor( + [[3.0, 1.0], [1.0, 4.0], [2.0, 2.0]], dtype=torch.float32 + ) + wrapped = as_dgl_transform(EmptyFeatures(), feat_name="h") + out = wrapped(g) + assert "h" in out.ndata + assert out.ndata["h"].shape == (3, 1) + assert torch.all(out.ndata["h"] == 0) diff --git a/tests/test_lightning_framework_integration.py b/tests/test_lightning_framework_integration.py new file mode 100644 index 0000000..3a8c806 --- /dev/null +++ b/tests/test_lightning_framework_integration.py @@ -0,0 +1,115 @@ +import pytest + +pl = pytest.importorskip("pytorch_lightning") +torch = pytest.importorskip("torch") +pytest.importorskip("torch_geometric") + +from torch.utils.data import DataLoader, Dataset # noqa: E402 +from torch_geometric.datasets import KarateClub # noqa: E402 + +from rings import EmptyGraph, Original # noqa: E402 +from rings.integrations import ( # noqa: E402 + DGLEmptyGraph, + DGLOriginal, + SeparabilityCallback, + SeparabilityStudy, +) + + +class _SingleGraphDataset(Dataset): + def __init__(self, graph): + self.graph = graph + + def __len__(self): + return 1 + + def __getitem__(self, idx): + return self.graph + + +def _single_graph_loader(graph): + dataset = _SingleGraphDataset(graph) + return DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0]) + + +class _GraphScoreModule(pl.LightningModule): + def __init__(self): + super().__init__() + self._dummy = torch.nn.Parameter(torch.zeros(1)) + + def configure_optimizers(self): + return torch.optim.SGD([self._dummy], lr=0.1) + + def test_step(self, batch, batch_idx): + if hasattr(batch, "edge_index"): + num_nodes = int(batch.num_nodes) + num_edges = int(batch.edge_index.size(1)) + else: + num_nodes = int(batch.num_nodes()) + num_edges = int(batch.num_edges()) + score = float(num_edges) / max(float(num_nodes), 1.0) + self.log("test_acc", score, prog_bar=False, on_step=False, on_epoch=True) + return score + + +def _run_study_with_lightning(base_graph, perturbations, num_seeds=2): + study = SeparabilityStudy(perturbations=perturbations, num_seeds=num_seeds) + model = _GraphScoreModule() + + for name, transform, seed in study.runs(): + torch.manual_seed(seed) + graph = study.apply(base_graph, transform) + callback = SeparabilityCallback(study, perturbation_name=name, metric_key="test_acc") + trainer = pl.Trainer( + accelerator="cpu", + devices=1, + max_epochs=1, + logger=False, + enable_checkpointing=False, + enable_model_summary=False, + callbacks=[callback], + ) + trainer.test(model, dataloaders=_single_graph_loader(graph), verbose=False) + return study + + +def _load_dgl_karateclub_or_skip(): + try: + import dgl # noqa: F401 + from dgl.data import KarateClubDataset + except Exception as exc: + pytest.skip(f"DGL unavailable in test runtime: {exc}") + return KarateClubDataset()[0] + + +def test_lightning_callback_with_pyg_karateclub(): + base_graph = KarateClub()[0] + study = _run_study_with_lightning( + base_graph=base_graph, + perturbations={"Original": Original(), "EmptyGraph": EmptyGraph()}, + num_seeds=2, + ) + + assert len(study.scores["Original"]) == 2 + assert len(study.scores["EmptyGraph"]) == 2 + assert all(score > 0 for score in study.scores["Original"]) + assert all(score == 0 for score in study.scores["EmptyGraph"]) + + +def test_lightning_callback_with_dgl_karateclub(): + base_graph = _load_dgl_karateclub_or_skip() + if "x" not in base_graph.ndata: + if "feat" not in base_graph.ndata: + pytest.skip("KarateClubDataset does not expose expected node features.") + base_graph.ndata["x"] = base_graph.ndata["feat"].float() + + study = _run_study_with_lightning( + base_graph=base_graph, + perturbations={"Original": DGLOriginal(), "EmptyGraph": DGLEmptyGraph()}, + num_seeds=2, + ) + + assert len(study.scores["Original"]) == 2 + assert len(study.scores["EmptyGraph"]) == 2 + assert all(score > 0 for score in study.scores["Original"]) + assert all(score == 0 for score in study.scores["EmptyGraph"]) From 9625779a4a0fa7eef6a9bcdb48d577fee33a79b2 Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:26:49 +0200 Subject: [PATCH 3/7] refactor dependencies in `pyproject.toml` to streamline optional dependencies and remove unused packages --- pyproject.toml | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f93c240..6777ced 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,14 +28,25 @@ classifiers = [ ] dependencies = [ "networkx>=3.4.2", + "pandas>=3.0.3", "pot>=0.9.5", "scikit-learn>=1.7.0", - "seaborn>=0.13.2", "torch>=2.7.0", "torch-geometric>=2.6.1", "torchvision>=0.22.0", ] +[project.optional-dependencies] +lightning = [ + "lightning>=2.6.1", +] +dgl = [ + "dgl>=1.1,<3; python_version < '3.13'", +] +integrations = [ + "rings-evaluation[lightning,dgl]", +] + [project.urls] Documentation = "https://aidos.group/rings/" Repository = "https://github.com/aidos-lab/rings" @@ -59,11 +70,20 @@ build-backend = "hatchling.build" [dependency-groups] dev = [ "ipykernel>=6.29.5", - "lightning>=2.6.1", "pytest>=8.4.1", "pytest-cov>=6.2.1", "ruff>=0.15.13", ] +lightning = [ + "lightning>=2.6.1", +] +dgl = [ + "dgl>=1.1,<3; python_version < '3.13'", +] +integrations = [ + "lightning>=2.6.1", + "dgl>=1.1,<3; python_version < '3.13'", +] docs = [ "furo>=2024.8.6", "sphinx>=8.2.3", From f0d6420bbf7bffc5d0aec63d356e53c95f3cfeba Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:32:51 +0200 Subject: [PATCH 4/7] update documentation to include optional DGL and PyTorch Lightning integration instructions --- README.md | 44 +++++++++++++++++++++++++++ docs/source/index.rst | 24 +++++++++++++++ docs/source/integrations.rst | 59 ++++++++++++++++++++++++++++++++++++ 3 files changed, 127 insertions(+) diff --git a/README.md b/README.md index 413ca5b..4aa530c 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,21 @@ pip install rings-evaluation Requires Python 3.11+. +### Optional integrations + +Install only what you need: + +```bash +# PyTorch Lightning integration +pip install "rings-evaluation[lightning]" + +# DGL integration (available for Python < 3.13) +pip install "rings-evaluation[dgl]" + +# Both integrations +pip install "rings-evaluation[integrations]" +``` + ### From source To contribute or run the examples in this repo: @@ -26,6 +41,14 @@ git clone https://github.com/aidos-lab/rings.git && cd rings uv sync && source .venv/bin/activate ``` +Enable optional integration groups as needed: + +```bash +uv sync --group lightning +uv sync --group dgl +uv sync --group integrations +``` + --- ## Quickstart @@ -78,6 +101,27 @@ for name, transform, seed in study.runs(): results = study.evaluate() ``` +**DGL** — use the DGL wrappers from `rings.integrations` (backed by the same core perturbation logic): + +```python +from rings.integrations import DGLOriginal, DGLEmptyGraph, SeparabilityStudy + +study = SeparabilityStudy( + perturbations={ + "Original": DGLOriginal(), + "EmptyGraph": DGLEmptyGraph(), + }, + num_seeds=5, +) + +for name, transform, seed in study.runs(): + perturbed = study.apply(base_dgl_graph, transform) + score = train_and_eval_dgl(perturbed, seed=seed) # your code + study.record(name, score) + +results = study.evaluate() +``` + **Custom evaluator** (GraphBench, OGB, anything that returns a scalar): just pass the number to `study.record(name, score)`. ### Runnable examples diff --git a/docs/source/index.rst b/docs/source/index.rst index e9213b1..bc2791c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,22 @@ Install Requires Python 3.11+. Package on `PyPI `__. +Optional integrations +~~~~~~~~~~~~~~~~~~~~~ + +Install only what you need: + +.. code-block:: bash + + # PyTorch Lightning integration + pip install "rings-evaluation[lightning]" + + # DGL integration (available for Python < 3.13) + pip install "rings-evaluation[dgl]" + + # Both integrations + pip install "rings-evaluation[integrations]" + From source ~~~~~~~~~~~ @@ -32,6 +48,14 @@ To contribute or run the examples in this repo: git clone https://github.com/aidos-lab/rings.git && cd rings uv sync && source .venv/bin/activate +Enable optional integration groups as needed: + +.. code-block:: bash + + uv sync --group lightning + uv sync --group dgl + uv sync --group integrations + Quickstart ---------------------------------------------- diff --git a/docs/source/integrations.rst b/docs/source/integrations.rst index 91e5a49..9a597d6 100644 --- a/docs/source/integrations.rst +++ b/docs/source/integrations.rst @@ -57,6 +57,57 @@ Lightning Your ``LightningModule.test_step`` must call ``self.log("test_acc", acc)`` (or whatever ``metric_key`` you pass to ``SeparabilityCallback``). +DGL +--- + +Install the optional integration dependencies: + +.. code-block:: bash + + pip install "rings-evaluation[dgl]" + # or with uv: + uv sync --group dgl + +RINGS keeps PyG perturbations as the source of truth and exposes DGL-compatible wrappers +through ``rings.integrations``: + +- ``DGLOriginal`` +- ``DGLEmptyFeatures`` +- ``DGLRandomFeatures`` +- ``DGLEmptyGraph`` +- ``DGLCompleteGraph`` +- ``DGLRandomGraph`` + +These wrappers convert a ``dgl.DGLGraph`` to a temporary PyG ``Data`` object, apply the +underlying RINGS perturbation, and convert the result back to DGL. + +.. code-block:: python + + from rings.integrations import ( + DGLOriginal, + DGLEmptyGraph, + DGLRandomFeatures, + SeparabilityStudy, + ) + + study = SeparabilityStudy( + perturbations={ + "Original": DGLOriginal(), + "EmptyGraph": DGLEmptyGraph(), + "RandomFeatures": DGLRandomFeatures(shuffle=True), + }, + num_seeds=5, + comparator="ks", + alpha=0.05, + ) + + for name, transform, seed in study.runs(): + perturbed = study.apply(base_dgl_graph, transform) + score = train_and_eval_dgl(perturbed, seed=seed) # your code + study.record(name, score) + + results = study.evaluate() + Custom evaluators ----------------- @@ -89,3 +140,11 @@ SeparabilityCallback :members: :undoc-members: :show-inheritance: + +DGL perturbation wrappers +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: rings.integrations.dgl + :members: + :undoc-members: + :show-inheritance: From 2d0edaa90ecc8d2f705848429b00846af13285d5 Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:37:30 +0200 Subject: [PATCH 5/7] formatting is ruff! --- rings/integrations/__init__.py | 2 +- rings/integrations/study.py | 2 ++ tests/test_lightning_framework_integration.py | 4 +++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/rings/integrations/__init__.py b/rings/integrations/__init__.py index 6af5393..7a3ff54 100644 --- a/rings/integrations/__init__.py +++ b/rings/integrations/__init__.py @@ -17,7 +17,7 @@ def __getattr__(name): from rings.integrations.lightning import SeparabilityCallback return SeparabilityCallback - + if name in [ "DGLOriginal", "DGLEmptyFeatures", diff --git a/rings/integrations/study.py b/rings/integrations/study.py index 9776f54..00af71a 100644 --- a/rings/integrations/study.py +++ b/rings/integrations/study.py @@ -100,6 +100,7 @@ def apply(data: Any, transform: Callable) -> Any: # Try PyG try: from torch_geometric.data import Data, Dataset + if isinstance(data, Dataset): data.transform = transform return data @@ -111,6 +112,7 @@ def apply(data: Any, transform: Callable) -> Any: # Try DGL try: import dgl + if isinstance(data, dgl.DGLGraph): return transform(data) except ImportError: diff --git a/tests/test_lightning_framework_integration.py b/tests/test_lightning_framework_integration.py index 3a8c806..fb02351 100644 --- a/tests/test_lightning_framework_integration.py +++ b/tests/test_lightning_framework_integration.py @@ -59,7 +59,9 @@ def _run_study_with_lightning(base_graph, perturbations, num_seeds=2): for name, transform, seed in study.runs(): torch.manual_seed(seed) graph = study.apply(base_graph, transform) - callback = SeparabilityCallback(study, perturbation_name=name, metric_key="test_acc") + callback = SeparabilityCallback( + study, perturbation_name=name, metric_key="test_acc" + ) trainer = pl.Trainer( accelerator="cpu", devices=1, From 907ca733ce1c1497639c085e0388856b67792cd1 Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:40:50 +0200 Subject: [PATCH 6/7] bump version! --- pyproject.toml | 2 +- rings/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6777ced..7cc3208 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rings-evaluation" -version = "0.1.0" +version = "0.1.1" description = "RINGS: Relevant Information in Node Features and Graph Structure. An evaluation framework for graph-learning based on first principles." readme = "README.md" license = "BSD-3-Clause" diff --git a/rings/__init__.py b/rings/__init__.py index a9d08be..24ec492 100644 --- a/rings/__init__.py +++ b/rings/__init__.py @@ -8,7 +8,7 @@ ) from rings import integrations -__version__ = "0.1.0" +__version__ = "0.1.1" __all__ = [ "Original", From 163a384da54c9b7fdf9e37c806997c903fbd77af Mon Sep 17 00:00:00 2001 From: jeremy-wayland Date: Wed, 3 Jun 2026 14:44:31 +0200 Subject: [PATCH 7/7] add changelog for DGL integration and update dependencies --- CHANGELOG.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..a2f3a62 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,24 @@ +# Changelog + +All notable changes to this project are documented in this file. + +## v0.1.1 + +### Added +- DGL integration wrappers in `rings.integrations` (`DGLOriginal`, `DGLEmptyFeatures`, `DGLRandomFeatures`, `DGLEmptyGraph`, `DGLCompleteGraph`, `DGLRandomGraph`). +- DGL/PyG conversion helpers in `rings.integrations.dgl` to round-trip homogeneous DGL graphs through existing RINGS perturbations. +- Test coverage for DGL integration behavior and seed/feature handling in `tests/test_dgl_integration.py`. +- Lightning integration tests that validate both PyG and DGL usage in `tests/test_lightning_framework_integration.py`. + +### Changed +- Optional dependency model now includes explicit install targets for `lightning`, `dgl`, and combined `integrations` extras/groups in `pyproject.toml`. +- Integration documentation now includes DGL setup and usage examples in `README.md` and `docs/source/integrations.rst`. +- Package version updated to `0.1.1`. + +## v0.1.0 + +### Added +- Initial `rings-evaluation` release on `main` with the core perturbation framework for graph-learning dataset evaluation. +- Support for integration into existing GNN workflows via `SeparabilityStudy`. +- PyTorch Lightning callback support via `SeparabilityCallback`. +- Documentation and runnable examples for integrating RINGS into model evaluation pipelines.