Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions hyperbench/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,73 @@
Dataset,
HIFConverter,
)

from .supported_datasets import (
AlgebraDataset,
AmazonDataset,
ContactHighSchoolDataset,
ContactPrimarySchoolDataset,
CoraDataset,
CourseraDataset,
DBLPDataset,
EmailEnronDataset,
EmailW3CDataset,
GeometryDataset,
GOTDataset,
IMDBDataset,
MusicBluesReviewsDataset,
NBADataset,
NDCClassesDataset,
NDCSubstancesDataset,
PatentDataset,
PubmedDataset,
RestaurantReviewsDataset,
ThreadsAskUbuntuDataset,
ThreadsMathsxDataset,
TwitterDataset,
VegasBarsReviewsDataset,
)

from .loader import DataLoader

from .sampling import (
BaseSampler,
HyperedgeSampler,
NodeSampler,
SamplingStrategy,
create_sampler_from_strategy,
)

# Public API of hyperbench.data.
# NOTE: kept free of duplicates (ruff would flag repeated entries) and in the
# file's existing mostly-alphabetical order, with the factory function last.
__all__ = [
    "AlgebraDataset",
    "AmazonDataset",
    "BaseSampler",
    "ContactHighSchoolDataset",
    "ContactPrimarySchoolDataset",
    "CoraDataset",
    "CourseraDataset",
    "Dataset",
    "DataLoader",
    "DBLPDataset",
    "EmailEnronDataset",
    "EmailW3CDataset",
    "GeometryDataset",
    "GOTDataset",
    "HIFConverter",
    "HyperedgeSampler",
    "IMDBDataset",
    "MusicBluesReviewsDataset",
    "NBADataset",
    "NDCClassesDataset",
    "NDCSubstancesDataset",
    "NodeSampler",
    "PatentDataset",
    "PubmedDataset",
    "RestaurantReviewsDataset",
    "SamplingStrategy",
    "ThreadsAskUbuntuDataset",
    "ThreadsMathsxDataset",
    "TwitterDataset",
    "VegasBarsReviewsDataset",
    "create_sampler_from_strategy",
]
139 changes: 43 additions & 96 deletions hyperbench/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import requests

from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from hyperbench.types import HData, HIFHypergraph
from hyperbench.utils import validate_hif_json

from .sampling import SamplingStrategy, create_sampler_from_strategy


class DatasetNames(Enum):
"""
Expand Down Expand Up @@ -97,53 +99,71 @@ def load_from_hif(dataset_name: Optional[str], save_on_disk: bool = False) -> HI


class Dataset(TorchDataset):
"""
A dataset class for loading and processing hypergraph data.

Attributes:
DATASET_NAME: Class variable indicating the name of the dataset to load.
hypergraph: The loaded hypergraph in HIF format. Can be ``None`` if initialized from an HData object.
hdata: The processed hypergraph data in HData format.
sampling_strategy: The strategy used for sampling sub-hypergraphs (e.g., by node IDs or hyperedge IDs).
If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.
"""

DATASET_NAME = None

def __init__(
    self,
    hdata: Optional[HData] = None,
    sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
    is_original: Optional[bool] = True,
) -> None:
    """
    Initialize the dataset.

    Args:
        hdata: Pre-built hypergraph data. When ``None``, the hypergraph is
            obtained via ``self.download()`` and processed via ``self.process()``.
        sampling_strategy: Strategy used by ``__getitem__``/``__len__`` to
            sample sub-hypergraphs. Defaults to ``SamplingStrategy.HYPEREDGE``.
        is_original: Whether this instance wraps the original (non-derived)
            dataset; ``from_hdata`` and ``split`` pass ``False``.
    """
    self.__is_original = is_original
    self.__sampler = create_sampler_from_strategy(sampling_strategy)
    # Keep the strategy so derived datasets (e.g. splits) can inherit it.
    self.sampling_strategy = sampling_strategy

    # When hdata is supplied, skip the download entirely and keep an empty
    # HIF hypergraph as a placeholder.
    self.hypergraph: HIFHypergraph = self.download() if hdata is None else HIFHypergraph.empty()
    self.hdata: HData = self.process() if hdata is None else hdata

def __len__(self) -> int:
    """Return the dataset length as defined by the active sampler.

    Delegating to the sampler lets the length reflect the sampling
    strategy (presumably node count vs hyperedge count — confirm against
    the sampler implementations).
    """
    # The stale `return self.hdata.num_nodes` line (pre-sampler behavior)
    # has been removed; it made the sampler delegation unreachable.
    return self.__sampler.len(self.hdata)

def __getitem__(self, index: int | List[int]) -> HData:
    """
    Sample a sub-hypergraph based on the sampling strategy and return it as HData.

    If:
    - Sampling by node IDs, the sub-hypergraph will contain all hyperedges incident to the sampled nodes and all nodes incident to those hyperedges.
    - Sampling by hyperedge IDs, the sub-hypergraph will contain all nodes incident to the sampled hyperedges.

    Args:
        index: An integer or a list of integers representing node or hyperedge IDs to sample, depending on the sampling strategy.

    Returns:
        An HData instance containing the sampled sub-hypergraph.

    Raises:
        ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of nodes/hyperedges).
        IndexError: If any node/hyperedge ID is out of bounds.
    """
    # Sampling (including index validation) is fully delegated to the
    # strategy-specific sampler chosen in __init__; the old inline
    # node-based sampling path has been removed as stale/unreachable.
    return self.__sampler.sample(index, self.hdata)

@classmethod
def from_hdata(
    cls,
    hdata: HData,
    sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE,
) -> "Dataset":
    """
    Create a :class:`Dataset` instance from an :class:`HData` object.

    Args:
        hdata: :class:`HData` object containing the hypergraph data.
        sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``.

    Returns:
        The :class:`Dataset` instance with the provided :class:`HData`.
    """
    # Derived datasets are never "original": they wrap caller-supplied data.
    return cls(hdata=hdata, sampling_strategy=sampling_strategy, is_original=False)

def download(self) -> HIFHypergraph:
"""
Expand Down Expand Up @@ -281,7 +301,11 @@ def split(
# i=2 -> permutation[2:3] = [2] (1 edge)
split_hyperedge_ids = hyperedge_ids_permutation[start:end]
split_hdata = HData.split(self.hdata, split_hyperedge_ids).to(device=device)
split_dataset = self.__class__(hdata=split_hdata, is_original=False)
split_dataset = self.__class__(
hdata=split_hdata,
sampling_strategy=self.sampling_strategy,
is_original=False,
)
split_datasets.append(split_dataset)

start = end
Expand Down Expand Up @@ -389,30 +413,6 @@ def __get_hyperedge_ids_permutation(
ranged_hyperedge_ids_permutation = torch.arange(num_hyperedges, device=device)
return ranged_hyperedge_ids_permutation

def __get_node_ids_to_sample(self, id: int | List[int]) -> List[int]:
    """
    Normalize the requested index into a list of node IDs to sample.

    Args:
        id: A single node ID or a list of node IDs to sample.

    Returns:
        List of node IDs to sample (duplicates removed for list input).

    Raises:
        ValueError: If the provided index is invalid (e.g., empty list or
            list length exceeds number of nodes).
    """
    # Single integer: wrap it into a one-element list and return early.
    if not isinstance(id, list):
        return [id]

    if len(id) < 1:
        raise ValueError("Index list cannot be empty.")
    if len(id) > self.__len__():
        raise ValueError(
            "Index list length cannot exceed number of nodes in the hypergraph."
        )

    # Drop duplicate IDs; the resulting order is unspecified (set-based).
    return list(set(id))

def __process_hyperedge_attr(
self,
hyperedge_id_to_idx: Dict[Any, int],
Expand Down Expand Up @@ -469,56 +469,3 @@ def __process_x(self, num_nodes: int) -> Tensor:
x = torch.ones((num_nodes, 1), dtype=torch.float)

return x # shape [num_nodes, num_node_features]

def __sample_hyperedge_index(
    self,
    sampled_node_ids_list: List[int],
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Build the incidence index of the sub-hypergraph induced by the given nodes.

    Every hyperedge incident to at least one requested node is selected, and
    all incidences of those hyperedges are kept — including incidences of
    nodes that were not requested.

    Args:
        sampled_node_ids_list: Node IDs whose incident hyperedges are sampled.

    Returns:
        Tuple of (sampled hyperedge index, unique node IDs appearing in the
        sampled hyperedges, unique sampled hyperedge IDs).
    """
    incidence = self.hdata.hyperedge_index
    # Row 0 holds node IDs, row 1 holds hyperedge IDs of each incidence.
    rows, cols = incidence[0], incidence[1]

    requested_nodes = torch.tensor(sampled_node_ids_list, device=rows.device)

    # Mark incidences whose node is one of the requested nodes.
    # e.g. rows = [0, 0, 1, 2, 3, 4], requested = [0, 3]
    #   -> mask = [T, T, F, F, T, F]
    touches_requested = torch.isin(rows, requested_nodes)

    # Hyperedges touching at least one requested node.
    # e.g. cols = [0, 0, 0, 1, 2, 2] with the mask above -> {0, 2}
    selected_hyperedges = cols[touches_requested].unique()

    # Keep every incidence belonging to a selected hyperedge, even those of
    # nodes that were not requested themselves.
    # e.g. mask over cols = [T, T, T, F, T, T]
    in_selected = torch.isin(cols, selected_hyperedges)

    # All node IDs that participate in the selected hyperedges.
    # e.g. rows[mask].unique() = [0, 1, 3, 4]
    member_nodes = rows[in_selected].unique()

    # Column-filter the full incidence matrix down to the selected hyperedges.
    sub_incidence = incidence[:, in_selected]
    return sub_incidence, member_nodes, selected_hyperedges

def __validate_node_ids(self, node_ids: List[int]) -> None:
    """
    Validate that every node ID lies within the bounds of the hypergraph.

    Args:
        node_ids: List of node IDs to validate.

    Raises:
        IndexError: If any node ID is out of bounds.
    """
    for node_id in node_ids:
        # Valid IDs form the contiguous range [0, len(self) - 1].
        if not 0 <= node_id < self.__len__():
            raise IndexError(f"Node ID {node_id} is out of bounds (0, {self.__len__() - 1}).")
Loading