diff --git a/hyperbench/data/__init__.py b/hyperbench/data/__init__.py index c0c19e9..70ebe46 100644 --- a/hyperbench/data/__init__.py +++ b/hyperbench/data/__init__.py @@ -2,27 +2,73 @@ Dataset, HIFConverter, ) + from .supported_datasets import ( AlgebraDataset, + AmazonDataset, + ContactHighSchoolDataset, + ContactPrimarySchoolDataset, CoraDataset, CourseraDataset, DBLPDataset, + EmailEnronDataset, + EmailW3CDataset, + GeometryDataset, + GOTDataset, IMDBDataset, + MusicBluesReviewsDataset, + NBADataset, + NDCClassesDataset, + NDCSubstancesDataset, PatentDataset, + PubmedDataset, + RestaurantReviewsDataset, + ThreadsAskUbuntuDataset, ThreadsMathsxDataset, + TwitterDataset, + VegasBarsReviewsDataset, ) from .loader import DataLoader +from .sampling import ( + BaseSampler, + HyperedgeSampler, + NodeSampler, + SamplingStrategy, + create_sampler_from_strategy, +) + __all__ = [ - "Dataset", - "DataLoader", "AlgebraDataset", + "AmazonDataset", + "BaseSampler", + "ContactHighSchoolDataset", + "ContactPrimarySchoolDataset", "CoraDataset", "CourseraDataset", + "Dataset", + "DataLoader", "DBLPDataset", + "EmailEnronDataset", + "EmailW3CDataset", + "GeometryDataset", + "GOTDataset", + "HIFConverter", + "HyperedgeSampler", "IMDBDataset", + "MusicBluesReviewsDataset", + "NBADataset", + "NDCClassesDataset", + "NDCSubstancesDataset", + "NodeSampler", "PatentDataset", + "PubmedDataset", + "RestaurantReviewsDataset", + "SamplingStrategy", + "ThreadsAskUbuntuDataset", "ThreadsMathsxDataset", - "HIFConverter", + "TwitterDataset", + "VegasBarsReviewsDataset", + "create_sampler_from_strategy", ] diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index a22ac92..24b04eb 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -6,12 +6,14 @@ import requests from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional from torch import Tensor from torch.utils.data import Dataset as TorchDataset from hyperbench.types import HData, HIFHypergraph from hyperbench.utils import validate_hif_json +from .sampling import SamplingStrategy, create_sampler_from_strategy + class DatasetNames(Enum): """ @@ -97,53 +99,71 @@ def load_from_hif(dataset_name: Optional[str], save_on_disk: bool = False) -> HI class Dataset(TorchDataset): + """ + A dataset class for loading and processing hypergraph data. + + Attributes: + DATASET_NAME: Class variable indicating the name of the dataset to load. + hypergraph: The loaded hypergraph in HIF format. Can be ``None`` if initialized from an HData object. + hdata: The processed hypergraph data in HData format. + sampling_strategy: The strategy used for sampling sub-hypergraphs (e.g., by node IDs or hyperedge IDs). + If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. + """ + DATASET_NAME = None def __init__( self, hdata: Optional[HData] = None, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, is_original: Optional[bool] = True, ) -> None: self.__is_original = is_original + self.__sampler = create_sampler_from_strategy(sampling_strategy) - self.hypergraph: HIFHypergraph = self.download() if hdata is None else HIFHypergraph.empty() - self.hdata: HData = self.process() if hdata is None else hdata + self.sampling_strategy = sampling_strategy + self.hypergraph = self.download() if hdata is None else HIFHypergraph.empty() + self.hdata = self.process() if hdata is None else hdata def __len__(self) -> int: - return self.hdata.num_nodes + return self.__sampler.len(self.hdata) def __getitem__(self, index: int | List[int]) -> HData: """ - Sample a sub-hypergraph containing the specified node(s) and all hyperedges incident to those nodes. - - Note: - The returned :class:`HData` contains only the sampled nodes and their incident hyperedges. - Node features (x) and hyperedge attributes are not included in the returned :class:`HData` . + Sample a sub-hypergraph based on the sampling strategy and return it as HData. + If: + - Sampling by node IDs, the sub-hypergraph will contain all hyperedges incident to the sampled nodes and all nodes incident to those hyperedges. + - Sampling by hyperedge IDs, the sub-hypergraph will contain all nodes incident to the sampled hyperedges. Args: - index: An integer or a list of integers representing node IDs to sample. + index: An integer or a list of integers representing node or hyperedge IDs to sample, depending on the sampling strategy. Returns: - An :class:`HData` object containing the sub-hypergraph induced by the specified node(s) and their incident hyperedges. - """ - sampled_node_ids_list = self.__get_node_ids_to_sample(index) - self.__validate_node_ids(sampled_node_ids_list) + An HData instance containing the sampled sub-hypergraph. - sampled_hyperedge_index, _, _ = self.__sample_hyperedge_index(sampled_node_ids_list) - return HData.from_hyperedge_index(sampled_hyperedge_index) + Raises: + ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes/hyperedges). + IndexError: If any node/hyperedge ID is out of bounds. + """ + return self.__sampler.sample(index, self.hdata) @classmethod - def from_hdata(cls, hdata: HData) -> "Dataset": + def from_hdata( + cls, + hdata: HData, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + ) -> "Dataset": """ Create a :class:`Dataset` instance from an :class:`HData` object. Args: hdata: :class:`HData` object containing the hypergraph data. + sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. Returns: The :class:`Dataset` instance with the provided :class:`HData`. """ - return cls(hdata=hdata, is_original=False) + return cls(hdata=hdata, sampling_strategy=sampling_strategy, is_original=False) def download(self) -> HIFHypergraph: """ @@ -281,7 +301,11 @@ def split( # i=2 -> permutation[2:3] = [2] (1 edge) split_hyperedge_ids = hyperedge_ids_permutation[start:end] split_hdata = HData.split(self.hdata, split_hyperedge_ids).to(device=device) - split_dataset = self.__class__(hdata=split_hdata, is_original=False) + split_dataset = self.__class__( + hdata=split_hdata, + sampling_strategy=self.sampling_strategy, + is_original=False, + ) split_datasets.append(split_dataset) start = end @@ -389,30 +413,6 @@ def __get_hyperedge_ids_permutation( ranged_hyperedge_ids_permutation = torch.arange(num_hyperedges, device=device) return ranged_hyperedge_ids_permutation - def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]: - """ - Get a list of node IDs to sample based on the provided index. - - Args: - id: An integer or a list of integers representing node IDs to sample. - - Returns: - List of node IDs to sample. - - Raises: - ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes). - """ - if isinstance(id, list): - if len(id) < 1: - raise ValueError("Index list cannot be empty.") - elif len(id) > self.__len__(): - raise ValueError( - "Index list length cannot exceed number of nodes in the hypergraph." - ) - return list(set(id)) - - return [id] - def __process_hyperedge_attr( self, hyperedge_id_to_idx: Dict[Any, int], @@ -469,56 +469,3 @@ def __process_x(self, num_nodes: int) -> Tensor: x = torch.ones((num_nodes, 1), dtype=torch.float) return x # shape [num_nodes, num_node_features] - - def __sample_hyperedge_index( - self, - sampled_node_ids_list: List[int], - ) -> Tuple[Tensor, Tensor, Tensor]: - hyperedge_index = self.hdata.hyperedge_index - node_ids = hyperedge_index[0] - hyperedge_ids = hyperedge_index[1] - - sampled_node_ids = torch.tensor(sampled_node_ids_list, device=node_ids.device) - - # Find incidences where the node is in our sampled node set - # Example: hyperedge_index[0] = [0, 0, 1, 2, 3, 4], sampled_node_ids = [0, 3] - # -> node_incidence_mask = [True, True, False, False, True, False] - node_incidence_mask = torch.isin(node_ids, sampled_node_ids) - - # Get unique hyperedges that have at least one sampled node - # Example: hyperedge_index[1] = [0, 0, 0, 1, 2, 2], node_incidence_mask = [True, True, False, False, True, False] - # -> sampled_hyperedge_ids = [0, 2] as they connect to sampled nodes - sampled_hyperedge_ids = hyperedge_ids[node_incidence_mask].unique() - - # Find all incidences for the sampled hyperedges (not just sampled nodes) - # Example: hyperedge_index[1] = [0, 0, 0, 1, 2, 2], sampled_hyperedge_ids = [0, 2] - # -> hyperedge_incidence_mask = [True, True, True, False, True, True] - hyperedge_incidence_mask = torch.isin(hyperedge_ids, sampled_hyperedge_ids) - - # Collect all node IDs that appear in the sampled hyperedges - # Example: hyperedge_index[0] = [0, 0, 1, 2, 3, 4], hyperedge_incidence_mask = [True, True, True, False, True, True] - # -> node_ids_in_sampled_hyperedge = [0, 1, 3, 4] - node_ids_in_sampled_hyperedge = node_ids[hyperedge_incidence_mask].unique() - - # Keep all incidences belonging to the sampled hyperedges - # Example: hyperedge_index = [[0, 0, 1, 2, 3, 4], - # [0, 0, 0, 1, 2, 2]], - # hyperedge_incidence_mask = [True, True, True, False, True, True] - # -> sampled_hyperedge_index = [[0, 0, 1, 3, 4], - # [0, 0, 0, 2, 2]] - sampled_hyperedge_index = hyperedge_index[:, hyperedge_incidence_mask] - return sampled_hyperedge_index, node_ids_in_sampled_hyperedge, sampled_hyperedge_ids - - def __validate_node_ids(self, node_ids: List[int]) -> None: - """ - Validate that node IDs are within bounds of the hypergraph. - - Args: - node_ids: List of node IDs to validate. - - Raises: - IndexError: If any node ID is out of bounds. - """ - for id in node_ids: - if id < 0 or id >= self.__len__(): - raise IndexError(f"Node ID {id} is out of bounds (0, {self.__len__() - 1}).") diff --git a/hyperbench/data/sampling.py b/hyperbench/data/sampling.py new file mode 100644 index 0000000..5a69f3a --- /dev/null +++ b/hyperbench/data/sampling.py @@ -0,0 +1,254 @@ +import torch + +from abc import ABC, abstractmethod +from enum import Enum +from torch import Tensor +from typing import List +from hyperbench.types import HData + + +class SamplingStrategy(Enum): + NODE = "node" + HYPEREDGE = "hyperedge" + + +class BaseSampler(ABC): + @abstractmethod + def sample(self, index: int | List[int], hdata: HData) -> HData: + """ + Sample a sub-hypergraph and return HData with global IDs. + + Args: + index: An integer or list of integers specifying which items to sample. + hdata: The original HData to sample from. + + Returns: + A new HData instance containing only the sampled items and their associated data. + """ + raise NotImplementedError("Subclasses must implement the sample method.") + + @abstractmethod + def len(self, hdata: HData) -> int: + """ + Return the number of sampleable items (nodes or hyperedges). + + Args: + hdata: The HData to query for the number of sampleable items. + """ + raise NotImplementedError("Subclasses must implement the len method.") + + def _normalize_index(self, index: int | List[int], size: int) -> List[int]: + """ + Convert index to list, deduplicate, validate length. + + Args: + index: An integer or a list of integers representing IDs to sample. + size: The total number of sampleable items (e.g., nodes or hyperedges) for validation. + + Returns: + List of IDs to sample. + + Raises: + ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of sampleable items). + """ + if isinstance(index, list): + if len(index) < 1: + raise ValueError("Index list cannot be empty.") + if len(index) > size: + raise ValueError( + f"Index list length ({len(index)}) cannot exceed the number of sampleable items ({size})." + ) + return list(set(index)) + return [index] + + def _sample_hyperedge_index( + self, + hyperedge_index: Tensor, + sampled_hyperedge_ids: Tensor, + ) -> Tensor: + """ + Sample the hyperedge index to keep only incidences belonging to the specified sampled hyperedge IDs. + + Args: + hyperedge_index: The original hyperedge index tensor of shape ``[2, num_incidences]``. + sampled_hyperedge_ids: A tensor containing the IDs of hyperedges to sample. + + Returns: + A new hyperedge index tensor containing only the incidences of the sampled hyperedges. + """ + hyperedge_ids = hyperedge_index[1] + + # Find incidences where the hyperedge is in our sampled hyperedges + # Example: hyperedge_ids = [0, 0, 0, 1, 2, 2], sampled_hyperedge_ids = [0, 2] + # -> sampled_hyperedges_mask = [True, True, True, False, True, True] + sampled_hyperedges_mask = torch.isin(hyperedge_ids, sampled_hyperedge_ids) + + # Keep all incidences belonging to the sampled hyperedges + # Example: hyperedge_index = [[0, 0, 1, 2, 3, 4], + # [0, 0, 0, 1, 2, 2]], + # sampled_hyperedges_mask = [True, True, True, False, True, True] + # -> sampled_hyperedge_index = [[0, 0, 1, 3, 4], + # [0, 0, 0, 2, 2]] + sampled_hyperedge_index = hyperedge_index[:, sampled_hyperedges_mask] + return sampled_hyperedge_index + + def _validate_bounds(self, ids: List[int], size: int, label: str) -> None: + """ + Check all IDs are in [0, self.len). + + Args: + ids: List of IDs to validate. + size: The total number of sampleable items (e.g., nodes or hyperedges). + label: A string label for error messages (e.g., "Node ID" or "Hyperedge ID"). + + Raises: + IndexError: If any ID is out of bounds. + """ + for id in ids: + if id < 0 or id >= size: + raise IndexError(f"{label} {id} is out of bounds (0, {size - 1}).") + + +class HyperedgeSampler(BaseSampler): + def sample(self, index: int | List[int], hdata: HData) -> HData: + """ + Sample hyperedges by their IDs and return the sub-hypergraph containing only those hyperedges and their incident nodes. + + Examples: + >>> hyperedge_index = [[0, 0, 1, 2, 3, 4], + ... [0, 0, 0, 1, 2, 2]] + >>> hdata = HData.from_hyperedge_index(hyperedge_index) + >>> strategy = HyperedgeSampler() + >>> sampled_hdata = strategy.sample([0, 2], hdata) + >>> sampled_hdata.hyperedge_index + >>> tensor([[0, 0, 1, 3, 4], + ... [0, 0, 0, 2, 2]]) + + Args: + index: An integer or a list of integers representing hyperedge IDs to sample. + hdata: The original HData to sample from. + + Returns: + An HData instance containing only the sampled hyperedges and their incident nodes. + + Raises: + ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of hyperedges). + IndexError: If any hyperedge ID is out of bounds. + """ + ids = self._normalize_index(index, self.len(hdata)) + self._validate_bounds(ids, self.len(hdata), "Hyperedge ID") + + hyperedge_index = hdata.hyperedge_index + + sampled_hyperedge_ids = torch.tensor(ids, device=hyperedge_index.device) + + # Example: sampled_hyperedge_ids = [0, 2], + # hyperedge_index = [[0, 0, 1, 2, 3, 4], + # [0, 0, 0, 1, 2, 2]], + # -> sampled_hyperedges_mask = [True, True, True, False, True, True] + # -> sampled_hyperedge_index = [[0, 0, 1, 3, 4], + # [0, 0, 0, 2, 2]] + sampled_hyperedge_index = self._sample_hyperedge_index( + hyperedge_index, sampled_hyperedge_ids + ) + + return HData.from_hyperedge_index(sampled_hyperedge_index) + + def len(self, hdata: HData) -> int: + """ + Return the number of hyperedges in the given HData. + + Args: + hdata: The HData to query for the number of hyperedges. + + Returns: + The number of hyperedges in the HData. + """ + return hdata.num_hyperedges + + +class NodeSampler(BaseSampler): + def sample(self, index: int | List[int], hdata: HData) -> HData: + """ + Sample nodes by their IDs and return the sub-hypergraph containing only those nodes and their incident hyperedges. + + Examples: + >>> hyperedge_index = [[0, 0, 1, 2, 3, 4], + ... [0, 0, 0, 1, 2, 2]] + >>> hdata = HData.from_hyperedge_index(hyperedge_index) + >>> strategy = NodeSampler() + >>> sampled_hdata = strategy.sample([0, 3], hdata) + >>> sampled_hdata.hyperedge_index + >>> tensor([[0, 0, 1, 3, 4], + ... [0, 0, 0, 2, 2]]) + + Args: + index: An integer or a list of integers representing node IDs to sample. + hdata: The original HData to sample from. + + Returns: + An HData instance containing only the sampled nodes and their incident hyperedges. + + Raises: + ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes). + IndexError: If any node ID is out of bounds. + """ + ids = self._normalize_index(index, self.len(hdata)) + self._validate_bounds(ids, self.len(hdata), "Node ID") + + hyperedge_index = hdata.hyperedge_index + node_ids = hyperedge_index[0] + hyperedge_ids = hyperedge_index[1] + + sampled_node_ids = torch.tensor(ids, device=node_ids.device) + + # Find incidences where the node is in our sampled nodes + # Example: node_ids = [0, 0, 1, 2, 3, 4], sampled_node_ids = [0, 3] + # -> sampled_nodes_mask = [True, True, False, False, True, False] + sampled_nodes_mask = torch.isin(node_ids, sampled_node_ids) + + # Get unique hyperedges that have at least one sampled node + # Example: hyperedge_ids = [0, 0, 0, 1, 2, 2], sampled_nodes_mask = [True, True, False, False, True, False] + # -> sampled_hyperedge_ids = [0, 2] as they connect to sampled nodes + sampled_hyperedge_ids = hyperedge_ids[sampled_nodes_mask].unique() + + # Example: sampled_hyperedge_ids = [0, 2], + # hyperedge_index = [[0, 0, 1, 2, 3, 4], + # [0, 0, 0, 1, 2, 2]], + # -> sampled_hyperedges_mask = [True, True, True, False, True, True] + # -> sampled_hyperedge_index = [[0, 0, 1, 3, 4], + # [0, 0, 0, 2, 2]] + sampled_hyperedge_index = self._sample_hyperedge_index( + hyperedge_index, sampled_hyperedge_ids + ) + + return HData.from_hyperedge_index(sampled_hyperedge_index) + + def len(self, hdata: HData) -> int: + """ + Return the number of nodes in the given HData. + + Args: + hdata: The HData to query for the number of nodes. + + Returns: + The number of nodes in the HData. + """ + return hdata.num_nodes + + +def create_sampler_from_strategy(strategy: SamplingStrategy) -> BaseSampler: + """ + Factory function to create a sampler instance based on the provided sampling strategy type. + + Args: + strategy: An instance of SamplingStrategy enum indicating which sampling strategy to use. + + Returns: + An instance of a subclass of BaseSampler corresponding to the specified strategy. If strategy is not recognized, defaults to ``HyperedgeSampler``. + """ + match strategy: + case SamplingStrategy.NODE: + return NodeSampler() + case _: + return HyperedgeSampler() diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 58ad08a..a98d963 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -2,12 +2,13 @@ import torch import pytest from unittest.mock import patch, mock_open -from hyperbench.data import AlgebraDataset, Dataset, HIFConverter +from hyperbench.data import AlgebraDataset, Dataset, HIFConverter, SamplingStrategy from hyperbench.types import HData, HIFHypergraph from hyperbench.tests.mock import * +@pytest.fixture def mock_hdata() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) @@ -314,20 +315,40 @@ class FakeMockDataset(Dataset): FakeMockDataset() -def test_dataset_is_available(): +@pytest.mark.parametrize( + "strategy, expected_len", + [ + pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), + ], +) +def test_dataset_is_available(strategy, expected_len): mock_hypergraph = HIFHypergraph( network_type="undirected", - nodes=[{"node": str(i)} for i in range(423)], - edges=[{"edge": str(i)} for i in range(1268)], - incidences=[{"node": "0", "edge": "0"}], + nodes=[ + {"node": "0"}, + {"node": "1"}, + {"node": "2"}, + {"node": "3"}, + ], + edges=[ + {"edge": "0"}, + {"edge": "1"}, + ], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + {"node": "3", "edge": "1"}, + ], ) with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) assert dataset.DATASET_NAME == "ALGEBRA" assert dataset.hypergraph is not None - assert dataset.__len__() == dataset.hypergraph.num_nodes + assert len(dataset) == expected_len def test_download_already_downloaded_dataset_uses_local_value(): @@ -456,92 +477,170 @@ def test_dataset_process_random_ids(): assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each -def test_getitem_index_list_empty(mock_simple_hypergraph): - """Test __getitem__ with empty index list raises ValueError.""" +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): with patch.object(HIFConverter, "load_from_hif", return_value=mock_simple_hypergraph): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(ValueError, match="Index list cannot be empty."): dataset[[]] -def test_getitem_raises_when_index_list_larger_then_num_nodes(mock_five_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_five_node_hypergraph): - dataset = AlgebraDataset() - - with pytest.raises( - ValueError, - match="Index list length cannot exceed number of nodes in the hypergraph.", - ): - dataset[[0, 1, 2, 3, 4, 5]] - - -def test_getitem_raises_when_index_out_of_bounds(mock_four_node_hypergraph): +@pytest.mark.parametrize( + "strategy, index_list, expected_message", + [ + pytest.param( + SamplingStrategy.NODE, + [0, 1, 2, 3, 4], + r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", + id="node_strategy", + ), + pytest.param( + SamplingStrategy.HYPEREDGE, + [0, 1, 2], + r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", + id="hyperedge_strategy", + ), + ], +) +def test_getitem_raises_when_index_list_larger_than_max( + mock_four_node_hypergraph, strategy, index_list, expected_message +): with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) - with pytest.raises(IndexError, match="Node ID 4 is out of bounds."): - dataset[4] + with pytest.raises(ValueError, match=expected_message): + dataset[index_list] -def test_getitem_single_index(mock_sample_hypergraph): +@pytest.mark.parametrize( + "strategy, index, expected_message", + [ + pytest.param( + SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" + ), + pytest.param( + SamplingStrategy.HYPEREDGE, + 2, + r"Hyperedge ID 2 is out of bounds \(0, 1\)\.", + id="hyperedge_strategy", + ), + ], +) +def test_getitem_raises_when_index_out_of_bounds( + mock_four_node_hypergraph, strategy, index, expected_message +): + with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + with pytest.raises(IndexError, match=expected_message): + dataset[index] + + +@pytest.mark.parametrize( + "strategy, index, expected_shape, expected_num_hyperedges", + [ + # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge + pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), + # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge + pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), + ], +) +def test_getitem_single_index( + mock_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges +): with patch.object(HIFConverter, "load_from_hif", return_value=mock_sample_hypergraph): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) - data = dataset[1] + data = dataset[index] - # Node 1 is isolated (self-loop hyperedge), so hyperedge_index has shape [2, 1] - assert data.hyperedge_index.shape == (2, 1) - assert data.num_hyperedges == 1 + assert data.hyperedge_index.shape == expected_shape + assert data.num_hyperedges == expected_num_hyperedges -def test_getitem_when_list_index_provided(mock_four_node_hypergraph): +@pytest.mark.parametrize( + "strategy, index, expected_shape, expected_num_hyperedges", + [ + # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges + pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"), + # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges + pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"), + ], +) +def test_getitem_when_list_index_provided( + mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges +): with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) - data = dataset[[0, 2, 3]] + data = dataset[index] - # Node 1 is part of the hyperedge that contains node 0, - # so it's included in the hyperedge index (4 incidences across 2 hyperedges) - assert data.hyperedge_index.shape == (2, 4) - assert data.num_hyperedges == 2 + assert data.hyperedge_index.shape == expected_shape + assert data.num_hyperedges == expected_num_hyperedges -def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph): +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): with patch.object( HIFConverter, "load_from_hif", return_value=mock_three_node_weighted_hypergraph ): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] - # __getitem__ now returns only hyperedge_index with global IDs, - # x is empty and hyperedge_attr is None assert data.hyperedge_index.shape == (2, 2) assert data.num_hyperedges == 1 assert data.hyperedge_attr is None -def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph): +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) - node_data = dataset[0] - assert node_data.hyperedge_attr is None + data = dataset[0] + assert data.hyperedge_attr is None -def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph): +@pytest.mark.parametrize( + "strategy, index", + [ + # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges + pytest.param(SamplingStrategy.NODE, [0, 2], id="node_strategy"), + # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges + pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), + ], +) +def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index): with patch.object( HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph ): - dataset = AlgebraDataset() + dataset = AlgebraDataset(sampling_strategy=strategy) - node_data = dataset[[0, 2]] - assert node_data.num_hyperedges == 2 + data = dataset[index] + assert data.num_hyperedges == 2 # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None # as the hyperedge attributes are handled by the loader's collate function during batching - assert node_data.hyperedge_attr is None + assert data.hyperedge_attr is None def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges(): @@ -690,52 +789,6 @@ class TestDataset(Dataset): assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) -def test_getitem_returns_global_ids(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"weight": 1.0}}, - {"node": "1", "attrs": {"weight": 2.0}}, - {"node": "2", "attrs": {"weight": 3.0}}, - ], - edges=[ - {"edge": "0", "attrs": {}}, - {"edge": "1", "attrs": {}}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - data = dataset[0] - # Node 0 and node 1 are in hyperedge 0 - assert data.num_hyperedges == 1 - assert data.hyperedge_index.shape == (2, 2) - # Global IDs preserved: nodes 0,1 and hyperedge 0 - assert torch.equal(data.hyperedge_index[0], torch.tensor([0, 1])) - assert torch.equal(data.hyperedge_index[1], torch.tensor([0, 0])) - - data = dataset[[0, 2]] - # All 3 nodes, both hyperedges - assert data.num_hyperedges == 2 - assert data.hyperedge_index.shape == (2, 3) - - data = dataset[2] - # Only node 2 in hyperedge 1 - assert data.num_hyperedges == 1 - assert torch.equal(data.hyperedge_index[0], torch.tensor([2])) - assert torch.equal(data.hyperedge_index[1], torch.tensor([1])) - - def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): mock_hypergraph = HIFHypergraph( network_type="undirected", @@ -771,26 +824,30 @@ class TestDataset(Dataset): assert torch.allclose(result, torch.tensor([1.5, 0.8])) # weight, score (insertion order) -def test_from_hdata(): - hdata = mock_hdata() +@pytest.mark.parametrize( + "strategy, expected_len", + [ + # mock_hdata: 3 nodes, 2 hyperedges + pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), + ], +) +def test_from_hdata(strategy, expected_len, mock_hdata): + dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) - dataset = Dataset.from_hdata(hdata) + assert dataset.hdata is mock_hdata + assert len(dataset) == expected_len - assert dataset.hdata is hdata - assert len(dataset) == hdata.num_nodes - -def test_from_hdata_download_raises(): - hdata = mock_hdata() - dataset = Dataset.from_hdata(hdata) +def test_from_hdata_download_raises(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata) with pytest.raises(ValueError, match="download can only be called for the original dataset."): dataset.download() -def test_from_hdata_process_raises(): - hdata = mock_hdata() - dataset = Dataset.from_hdata(hdata) +def test_from_hdata_process_raises(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata) with pytest.raises(ValueError, match="process can only be called for the original dataset."): dataset.process() @@ -893,11 +950,10 @@ def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): assert split.hdata.hyperedge_attr is None -def test_to_device(): +def test_to_device(mock_hdata): device = torch.device("cpu") - hdata = mock_hdata() - dataset = Dataset.from_hdata(hdata) + dataset = Dataset.from_hdata(mock_hdata) result = dataset.to(device) @@ -941,3 +997,45 @@ def test_load_from_hif_skips_download_when_file_exists(): result = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) mock_get.assert_not_called() assert result == mock_hypergraph + + +def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): + with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + dataset = AlgebraDataset() + + # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) + assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE + assert len(dataset) == 2 + + +def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): + with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) + + # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) + assert dataset.sampling_strategy == SamplingStrategy.NODE + assert len(dataset) == 4 + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): + with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + splits = dataset.split([0.5, 0.5]) + + for split in splits: + assert split.sampling_strategy == strategy + + +def test_from_hdata_with_explicit_strategy(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) + + assert dataset.sampling_strategy == SamplingStrategy.NODE + assert len(dataset) == 3 # mock_hdata has 3 nodes diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index 2a9457a..1a51aa4 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -301,3 +301,70 @@ def test_iteration_over_dataloader(): assert batch_count == 3 # 5 samples with batch_size=2 -> 3 batches (2 + 2 + 1) assert dataset.__getitem__.call_count == n_samples # Ensure all samples were accessed + + +def test_collate_with_hyperedge_sampled_batch(): + # Full dataset with 4 nodes and 2 hyperedges + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]]) + hyperedge_attr = torch.tensor([[0.5], [0.7], [0.9]]) + hdata = HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) + + # hyperedge 0 -> nodes 0 and 1 + sample0 = HData.from_hyperedge_index(torch.tensor([[0, 1], [0, 0]])) + # hyperedge 2 -> node 3 + sample1 = HData.from_hyperedge_index(torch.tensor([[3], [2]])) + + dataset = MagicMock(spec=Dataset) + dataset.hdata = hdata + + loader = DataLoader(dataset, batch_size=2) + batched = loader.collate([sample0, sample1]) + + assert batched.num_nodes == 3 # Nodes 0, 1 from sample0 and node 3 from sample1 + assert batched.num_hyperedges == 2 # Hyperedges 0 from sample0 and 2 from sample1 + + expected_x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [7.0, 8.0]]) + assert torch.equal(batched.x, expected_x) + + # 0-based rebasing: + # - global nodes [0, 1, 3] -> local nodes [0, 1, 2] + # - global hyperedges [0, 2] -> local hyperedges [0, 1] + expected_hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) + assert torch.equal(batched.hyperedge_index, expected_hyperedge_index) + + expected_hyperedge_attr = torch.tensor([[0.5], [0.9]]) # From global hyperedges 0 and 2 + assert torch.equal(utils.to_non_empty_edgeattr(batched.hyperedge_attr), expected_hyperedge_attr) + + +def test_collate_with_node_sampled_batch(): + # Full dataset with 4 nodes and 2 hyperedges + x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]]) + hdata = HData(x=x, hyperedge_index=hyperedge_index) + + # Samples contain node's incident hyperedges to simulate NODE strategy + # Node 0 -> hyperedge 0 (nodes 0, 1) + sample0 = HData.from_hyperedge_index(torch.tensor([[0, 1], [0, 0]])) + # Node 3 -> hyperedge 2 (nodes 3) + sample1 = HData.from_hyperedge_index(torch.tensor([[3], [2]])) + + dataset = MagicMock(spec=Dataset) + dataset.hdata = hdata + + loader = DataLoader(dataset, batch_size=2) + batched = loader.collate([sample0, sample1]) + + assert batched.num_nodes == 3 # Nodes 0, 1 from sample0 and node 3 from sample1 + assert batched.num_hyperedges == 2 # Hyperedges 0 from sample0 and 2 from sample1 + + expected_x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [7.0, 8.0]]) + assert torch.equal(batched.x, expected_x) + + # 0-based rebasing: + # - global nodes [0, 1, 3] -> local nodes [0, 1, 2] + # - global hyperedges [0, 2] -> local hyperedges [0, 1] + expected_hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) + assert torch.equal(batched.hyperedge_index, expected_hyperedge_index) + + assert batched.hyperedge_attr is None diff --git a/hyperbench/tests/data/sampling_test.py b/hyperbench/tests/data/sampling_test.py new file mode 100644 index 0000000..faa6453 --- /dev/null +++ b/hyperbench/tests/data/sampling_test.py @@ -0,0 +1,233 @@ +import torch +import pytest + +from hyperbench.data import ( + BaseSampler, + HyperedgeSampler, + NodeSampler, + SamplingStrategy, + create_sampler_from_strategy, +) +from hyperbench.types import HData + + +@pytest.fixture +def mock_four_node_two_hyperedge_hdata(): + x = torch.ones((4, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long) + return HData(x, hyperedge_index, num_nodes=4, num_hyperedges=2) + + +@pytest.fixture +def mock_single_node_single_hyperedge_hdata(): + x = torch.ones((1, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0], [0]], dtype=torch.long) + return HData(x, hyperedge_index, num_nodes=1, num_hyperedges=1) + + +@pytest.fixture +def mock_empty_hdata(): + x = torch.ones((0, 0), dtype=torch.float) + hyperedge_index = torch.zeros((2, 0), dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index, num_nodes=0, num_hyperedges=0) + + +def test_base_sampler_cannot_be_instantiated(): + with pytest.raises(TypeError): + BaseSampler() + + +def test_create_sampler_from_strategy_hyperedge(): + sampler = create_sampler_from_strategy(SamplingStrategy.HYPEREDGE) + assert isinstance(sampler, HyperedgeSampler) + + +def test_create_sampler_from_strategy_node(): + sampler = create_sampler_from_strategy(SamplingStrategy.NODE) + assert isinstance(sampler, NodeSampler) + + +def test_hyperedge_sampling_single_index(mock_four_node_two_hyperedge_hdata): + sampler = HyperedgeSampler() + result = sampler.sample(0, mock_four_node_two_hyperedge_hdata) + + # hyperedge 0 has nodes 0 and 1 + assert result.hyperedge_index.shape == (2, 2) + assert result.num_hyperedges == 1 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0, 1])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0, 0])) + assert result.hyperedge_attr is None + assert result.x.shape == (0, 0) + + +def test_hyperedge_sampling_index_list(mock_four_node_two_hyperedge_hdata): + sampler = HyperedgeSampler() + result = sampler.sample([0, 1], mock_four_node_two_hyperedge_hdata) + + # hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) + assert result.hyperedge_index.shape == (2, 4) + assert result.num_hyperedges == 2 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) + assert result.hyperedge_attr is None + assert result.x.shape == (0, 0) + + +def test_hyperedge_sampling_len(mock_four_node_two_hyperedge_hdata): + sampler = HyperedgeSampler() + assert sampler.len(mock_four_node_two_hyperedge_hdata) == 2 + + +def test_node_sampling_single_index(mock_four_node_two_hyperedge_hdata): + sampler = NodeSampler() + result = sampler.sample(0, mock_four_node_two_hyperedge_hdata) + + # Node 0 is in hyperedge 0 (nodes 0, 1), so we get all incidences of hyperedge 0 + assert result.hyperedge_index.shape == (2, 2) + assert result.num_hyperedges == 1 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0, 1])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0, 0])) + assert result.hyperedge_attr is None + assert result.x.shape == (0, 0) + + +def test_node_sampling_index_list(mock_four_node_two_hyperedge_hdata): + sampler = NodeSampler() + result = sampler.sample([0, 2], mock_four_node_two_hyperedge_hdata) + + # Node 0 -> hyperedge 0 (nodes 0, 1) + # Node 2 -> hyperedge 1 (nodes 2, 3) + assert result.hyperedge_index.shape == (2, 4) + assert result.num_hyperedges == 2 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) + assert result.hyperedge_attr is None + assert result.x.shape == (0, 0) + + +def test_node_sampling_len(mock_four_node_two_hyperedge_hdata): + sampler = NodeSampler() + assert sampler.len(mock_four_node_two_hyperedge_hdata) == 4 + + +@pytest.mark.parametrize( + "sampler", + [ + pytest.param(NodeSampler(), id="node_sampler"), + pytest.param(HyperedgeSampler(), id="hyperedge_sampler"), + ], +) +def test_sample_empty_index_raises(mock_four_node_two_hyperedge_hdata, sampler): + with pytest.raises(ValueError, match="Index list cannot be empty."): + sampler.sample([], mock_four_node_two_hyperedge_hdata) + + +@pytest.mark.parametrize( + "sampler, label", + [ + pytest.param(NodeSampler(), "Node ID", id="node_sampler_node_id_invalid"), + pytest.param( + HyperedgeSampler(), "Hyperedge ID", id="hyperedge_sampler_hyperedge_id_invalid" + ), + ], +) +def test_sample_index_out_of_bounds_raises(mock_four_node_two_hyperedge_hdata, sampler, label): + with pytest.raises(IndexError, match=rf"{label} 99 is out of bounds"): + sampler.sample(99, mock_four_node_two_hyperedge_hdata) + + +@pytest.mark.parametrize( + "sampler, index_list", + [ + pytest.param( + NodeSampler(), [0, 1, 2, 3, 4], id="node_sampler_index_list_too_large" + ), # 5 > 4 nodes + pytest.param( + HyperedgeSampler(), [0, 1, 2], id="hyperedge_sampler_index_list_too_large" + ), # 3 > 2 hyperedges + ], +) +def test_sample_index_list_too_large_raises( + mock_four_node_two_hyperedge_hdata, sampler, index_list +): + with pytest.raises(ValueError, match="Index list length .* cannot exceed"): + sampler.sample(index_list, mock_four_node_two_hyperedge_hdata) + + +@pytest.mark.parametrize( + "sampler", + [ + pytest.param(NodeSampler(), id="node_sampler"), + pytest.param(HyperedgeSampler(), id="hyperedge_sampler"), + ], +) +def test_sample_returns_correct_hdata(mock_four_node_two_hyperedge_hdata, sampler): + result = sampler.sample(0, mock_four_node_two_hyperedge_hdata) + + assert result.x.shape == (0, 0) + assert result.hyperedge_attr is None + + +def test_sample_single_node_graph_node_sampler(mock_single_node_single_hyperedge_hdata): + sampler = NodeSampler() + result = sampler.sample(0, mock_single_node_single_hyperedge_hdata) + + assert result.hyperedge_index.shape == (2, 1) + assert result.num_hyperedges == 1 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0])) + + +def test_sample_single_node_graph_hyperedge_sampler(mock_single_node_single_hyperedge_hdata): + sampler = HyperedgeSampler() + result = sampler.sample(0, mock_single_node_single_hyperedge_hdata) + + assert result.hyperedge_index.shape == (2, 1) + assert result.num_hyperedges == 1 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0])) + + +def test_sample_hyperedge_of_size_one_node_sampler(mock_single_node_single_hyperedge_hdata): + sampler = NodeSampler() + result = sampler.sample(0, mock_single_node_single_hyperedge_hdata) + + # Node 0 is in hyperedge 0 which has only node 0 + assert result.hyperedge_index.shape == (2, 1) + assert result.num_hyperedges == 1 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0])) + + +def test_sample_hyperedge_of_size_one_hyperedge_sampler(mock_single_node_single_hyperedge_hdata): + sampler = HyperedgeSampler() + result = sampler.sample(0, mock_single_node_single_hyperedge_hdata) + + # Hyperedge 0 has only node 0 + assert result.hyperedge_index.shape == (2, 1) + assert result.num_hyperedges == 1 + assert torch.equal(result.hyperedge_index[0], torch.tensor([0])) + assert torch.equal(result.hyperedge_index[1], torch.tensor([0])) + + +@pytest.mark.parametrize( + "sampler", + [ + pytest.param(NodeSampler(), id="node_sampler"), + pytest.param(HyperedgeSampler(), id="hyperedge_sampler"), + ], +) +def test_sample_empty_hyperedge_index_len(mock_empty_hdata, sampler): + assert sampler.len(mock_empty_hdata) == 0 + + +@pytest.mark.parametrize( + "sampler", + [ + pytest.param(NodeSampler(), id="node_sampler"), + pytest.param(HyperedgeSampler(), id="hyperedge_sampler"), + ], +) +def test_sample_empty_hyperedge_index_empty_list_raises(mock_empty_hdata, sampler): + with pytest.raises(ValueError, match="Index list cannot be empty."): + sampler.sample([], mock_empty_hdata) diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 9f91ce3..6741a2e 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -444,33 +444,33 @@ def test_split_handles_none_edge_attr(): ], ) def test_with_y_to_sets_all_labels_to_value(mock_hdata, value): - result = mock_hdata.with_y_to(value) + hdata = mock_hdata.with_y_to(value) expected_y = torch.full((mock_hdata.num_hyperedges,), value, dtype=torch.float) - assert torch.equal(result.y, expected_y) + assert torch.equal(hdata.y, expected_y) def test_with_y_to_preserves_other_fields(mock_hdata): - result = mock_hdata.with_y_to(0.5) + hdata = mock_hdata.with_y_to(0.5) expected_y = torch.full((mock_hdata.num_hyperedges,), 0.5, dtype=torch.float) - assert torch.equal(result.x, mock_hdata.x) - assert torch.equal(result.hyperedge_index, mock_hdata.hyperedge_index) - assert torch.equal(result.y, expected_y) - assert result.num_nodes == mock_hdata.num_nodes - assert result.num_hyperedges == mock_hdata.num_hyperedges + assert torch.equal(hdata.x, mock_hdata.x) + assert torch.equal(hdata.hyperedge_index, mock_hdata.hyperedge_index) + assert torch.equal(hdata.y, expected_y) + assert hdata.num_nodes == mock_hdata.num_nodes + assert hdata.num_hyperedges == mock_hdata.num_hyperedges def test_with_y_ones_returns_all_ones(mock_hdata): - result = mock_hdata.with_y_ones() + hdata = mock_hdata.with_y_ones() - assert torch.equal(result.y, torch.ones(mock_hdata.num_hyperedges, dtype=torch.float)) + assert torch.equal(hdata.y, torch.ones(mock_hdata.num_hyperedges, dtype=torch.float)) def test_with_y_zeros_returns_all_zeros(mock_hdata): - result = mock_hdata.with_y_zeros() + hdata = mock_hdata.with_y_zeros() - assert torch.equal(result.y, torch.zeros(mock_hdata.num_hyperedges, dtype=torch.float)) + assert torch.equal(hdata.y, torch.zeros(mock_hdata.num_hyperedges, dtype=torch.float)) def test_get_device_if_all_consistent_returns_device_when_all_consistent(): @@ -522,3 +522,110 @@ def test_raises_on_inconsistent_device_placement_on_mps(): with pytest.raises(ValueError, match="Inconsistent device placement"): HData(x=x, hyperedge_index=hyperedge_index) + + +def test_shuffle_preserves_num_nodes_and_num_hyperedges(mock_hdata): + shuffled_hdata = mock_hdata.shuffle(seed=42) + + assert shuffled_hdata.num_nodes == mock_hdata.num_nodes + assert shuffled_hdata.num_hyperedges == mock_hdata.num_hyperedges + + +def test_shuffle_preserves_incidence_structure(mock_hdata): + shuffled_hdata = mock_hdata.shuffle(seed=7) + + def nodes_per_hyperegde(hyperedge_index, num_hyperedge): + hyperedges = set() + for hyperedge_id in range(num_hyperedge): + hyperedge_mask = hyperedge_index[1] == hyperedge_id + nodes_in_hyperedge = tuple(sorted(hyperedge_index[0][hyperedge_mask].tolist())) + hyperedges.add(nodes_in_hyperedge) + return hyperedges + + original_hyperedges = nodes_per_hyperegde(mock_hdata.hyperedge_index, mock_hdata.num_hyperedges) + shuffled_hyperedges = nodes_per_hyperegde( + shuffled_hdata.hyperedge_index, shuffled_hdata.num_hyperedges + ) + + assert original_hyperedges == shuffled_hyperedges + + +def test_shuffle_matches_labels_and_attr_with_correct_hyperedge(): + x = torch.randn(4, 2) + # Hyperedge 0 has nodes {0, 1}, hyperedge 1 has nodes {2, 3} + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + y = torch.tensor([1.0, 0.0]) + hyperedge_attr = torch.tensor([[10.0], [20.0]]) + hdata = HData(x=x, hyperedge_index=hyperedge_index, y=y, hyperedge_attr=hyperedge_attr) + + shuffled_hdata = hdata.shuffle(seed=42) + + # For each new hyperedge ID, find which nodes it has and verify the label/attr match + for new_hyperedge_id in range(shuffled_hdata.num_hyperedges): + new_hyperedge_mask = shuffled_hdata.hyperedge_index[1] == new_hyperedge_id + new_nodes = set(shuffled_hdata.hyperedge_index[0][new_hyperedge_mask].tolist()) + + # Find the original hyperedge with the same nodes + for old_hyperedge_id in range(hdata.num_hyperedges): + old_hyperedge_mask = hdata.hyperedge_index[1] == old_hyperedge_id + old_nodes = set(hdata.hyperedge_index[0][old_hyperedge_mask].tolist()) + if old_nodes == new_nodes: + assert shuffled_hdata.y[new_hyperedge_id] == hdata.y[old_hyperedge_id] + assert torch.equal( + utils.to_non_empty_edgeattr(shuffled_hdata.hyperedge_attr)[new_hyperedge_id], + utils.to_non_empty_edgeattr(hdata.hyperedge_attr)[old_hyperedge_id], + ) + break + + +def test_shuffle_permutes_labels(mock_hdata): + mock_hdata.y = torch.tensor([1.0, 0.0, 0.5]) + shuffled_hdata = mock_hdata.shuffle(seed=42) + + # Same multiset of labels + assert sorted(shuffled_hdata.y.tolist()) == sorted(mock_hdata.y.tolist()) + + +def test_shuffle_permutes_hyperedge_attr(mock_hdata): + mock_hdata.hyperedge_attr = torch.tensor([[10.0], [20.0], [30.0]]) + shuffled_hdata = mock_hdata.shuffle(seed=42) + + # Same multiset of attribute rows + original_attr = {tuple(attrs.tolist()) for attrs in mock_hdata.hyperedge_attr} + shuffled_attr = {tuple(attrs.tolist()) for attrs in shuffled_hdata.hyperedge_attr} + + assert original_attr == shuffled_attr + + +def test_shuffle_handles_none_hyperedge_attr(mock_hdata): + mock_hdata.hyperedge_attr = None + shuffled_hdata = mock_hdata.shuffle(seed=42) + + assert shuffled_hdata.hyperedge_attr is None + + +def test_shuffle_does_not_modify_x(mock_hdata): + shuffled_hdata = mock_hdata.shuffle(seed=42) + + assert torch.equal(shuffled_hdata.x, mock_hdata.x) + + +def test_shuffle_with_seed_is_reproducible(): + x = torch.randn(5, 4) + hyperedge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 1, 2]]) + y = torch.tensor([1.0, 0.0, 0.5]) + hdata = HData(x=x, hyperedge_index=hyperedge_index, y=y) + + shuffled_hdata1 = hdata.shuffle(seed=123) + shuffled_hdata2 = hdata.shuffle(seed=123) + + assert torch.equal(shuffled_hdata1.hyperedge_index, shuffled_hdata2.hyperedge_index) + assert torch.equal(shuffled_hdata1.y, shuffled_hdata2.y) + + +def test_shuffle_with_no_seed_set(mock_hdata): + shuffled_hdata1 = mock_hdata.shuffle() + + assert shuffled_hdata1.num_nodes == mock_hdata.num_nodes + assert shuffled_hdata1.num_hyperedges == mock_hdata.num_hyperedges + assert shuffled_hdata1.hyperedge_index.shape == mock_hdata.hyperedge_index.shape diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 48ebcc9..335fbe0 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -265,6 +265,73 @@ def get_device_if_all_consistent(self) -> torch.device: return devices.pop() if len(devices) == 1 else torch.device("cpu") + def shuffle(self, seed: Optional[int] = None) -> "HData": + """ + Return a new :class:`HData` instance with hyperedge IDs randomly reassigned. + + Each hyperedge keeps its original set of nodes, but is assigned a new ID via a random permutation. + ``y`` and ``hyperedge_attr`` are reordered to match, so that ``y[new_id]`` still corresponds to the correct hyperedge. + Same for ``hyperedge_attr[new_id]`` if hyperedge attributes are present. + + Examples: + >>> hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + >>> y = torch.tensor([1, 0]) + >>> hdata = HData(x, hyperedge_index=hyperedge_index, y=y) + >>> shuffled_hdata = hdata.shuffle(seed=42) + >>> shuffled_hdata.hyperedge_index # hyperedges may be reassigned + ... # e.g., + ... [[0, 1, 2, 3], + ... [1, 1, 0, 0]] + >>> shuffled_hdata.y # labels are permuted to match new hyperedge IDs, e.g., [0, 1] + + Args: + seed: Optional random seed for reproducibility. If ``None``, the shuffle will be non-deterministic. + + Returns: + A new :class:`HData` instance with hyperedge IDs, ``y``, and ``hyperedge_attr`` permuted. + """ + generator = torch.Generator(device=self.device) + if seed is not None: + generator.manual_seed(seed) + + permutation = torch.randperm(self.num_hyperedges, generator=generator, device=self.device) + + # permutation[new_id] = old_id, so y[permutation] puts old labels into new slots + # inverse_permutation[old_id] = new_id, used to remap hyperedge IDs in incidences + # Example: permutation = [1, 2, 0] means new_id 0 gets old_id 1, new_id 1 gets old_id 2, new_id 2 gets old_id 0 + # -> inverse_permutation = [2, 0, 1] means old_id 0 gets new_id 2, old_id 1 gets new_id 0, old_id 2 gets new_id 1 + inverse_permutation = torch.empty_like(permutation) + inverse_permutation[permutation] = torch.arange(self.num_hyperedges, device=self.device) + + new_hyperedge_index = self.hyperedge_index.clone() + + # Example: hyperedge_index = [[0, 1, 2, 3, 4], + # [0, 0, 1, 1, 2]], + # inverse_permutation = [2, 0, 1] (new_id 0 -> old_id 2, new_id 1 -> old_id 0, new_id 2 -> old_id 1) + # -> new_hyperedge_index = [[0, 1, 2, 3, 4], + # [2, 2, 0, 0, 1]] + old_hyperedge_ids = self.hyperedge_index[1] + new_hyperedge_index[1] = inverse_permutation[old_hyperedge_ids] + + # Example: hyperedge_attr = [attr_0, attr_1, attr_2], permutation = [1, 2, 0] + # -> new_hyperedge_attr = [attr_1 (attr of old_id 1), attr_2 (attr of old_id 2), attr_0 (attr of old_id 0)] + new_hyperedge_attr = ( + self.hyperedge_attr[permutation] if self.hyperedge_attr is not None else None + ) + + # Example: y = [1, 1, 0], permutation = [1, 2, 0] + # -> new_y = [y[1], y[2], y[0]] = [1, 0, 1] + new_y = self.y[permutation] + + return HData( + x=self.x, + hyperedge_index=new_hyperedge_index, + hyperedge_attr=new_hyperedge_attr, + num_nodes=self.num_nodes, + num_hyperedges=self.num_hyperedges, + y=new_y, + ) + def to(self, device: torch.device | str, non_blocking: bool = False) -> "HData": """ Move all tensors to the specified device.