Add Dataset.get_subhypergraph() for self-contained sub-hypergraph sampling #67

@tizianocitro

Description

Dataset.__getitem__ currently returns an HData with global IDs and no features, designed for efficient batching in DataLoader.collate. This means dataset[0] returns something that isn't directly usable for inspection or standalone computation (e.g., the hyperedge_index has global IDs but x is empty).

I want to add a get_subhypergraph(index: int | List[int]) method that returns a fully self-contained HData:

  • hyperedge_index remapped to 0-based local IDs.
  • x subset to the nodes in the subgraph.
  • y subset to the hyperedges in the subgraph.
  • hyperedge_attr subset if present.

This keeps __getitem__ fast and minimal (global IDs only, used by DataLoader), while giving users a clean way to extract and inspect a usable subgraph.
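To make the remapping concrete, here is a minimal sketch of what "0-based local IDs" means, using plain Python lists in place of tensors. The helper name remap_to_local and the toy data are hypothetical and only for illustration; this is not the existing to_0based_ids helper, whose signature is not shown in this issue.

```python
from typing import Dict, List, Tuple


def remap_to_local(ids: List[int]) -> Tuple[List[int], List[int]]:
    """Map arbitrary global IDs to 0-based local IDs (hypothetical helper).

    Returns (local_ids, unique_global_ids), where unique_global_ids[k] is the
    global ID that local ID k stands for.
    """
    unique_global_ids = sorted(set(ids))
    global_to_local: Dict[int, int] = {g: k for k, g in enumerate(unique_global_ids)}
    return [global_to_local[g] for g in ids], unique_global_ids


# Toy incidence list with global IDs: row 0 = node IDs, row 1 = hyperedge IDs.
global_hyperedge_index = [[7, 9, 9, 12], [3, 3, 5, 5]]
# Toy global feature table: one row per global node ID (13 nodes, 1 feature each).
global_x = [[float(i)] for i in range(13)]

local_nodes, node_ids = remap_to_local(global_hyperedge_index[0])
local_edges, edge_ids = remap_to_local(global_hyperedge_index[1])
local_hyperedge_index = [local_nodes, local_edges]

# Subset the features to the nodes that appear in the subgraph, in local-ID order.
local_x = [global_x[g] for g in node_ids]

# Indexing the subset features with local node IDs is now valid.
features_per_incidence = [local_x[n] for n in local_nodes]
```

With global IDs, the same lookup against a 3-row local_x would go out of range (node ID 12), which is the problem get_subhypergraph is meant to avoid.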

For DataLoader (fast, global IDs, no features):

dataloader = DataLoader(dataset, batch_size=16)

For inspection/standalone use (self-contained, 0-based):

subgraph = dataset.get_subhypergraph(0)
subgraph.x[subgraph.hyperedge_index[0]]  # works correctly

Notes:

  • Reuse the existing __sample_hyperedge_index to filter incidences.
  • Remap to 0-based IDs using to_0based_ids (already available in hyperbench/utils/data_utils.py).
  • Subset x, y, and hyperedge_attr from self.hdata using the sampled global IDs.
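The filter-and-subset steps in these notes could look roughly like the sketch below. It assumes that "filtering incidences" means keeping every (node, hyperedge) pair whose node is in the sampled set; the actual behavior of __sample_hyperedge_index is not shown in this issue, so sample_incidences and subset_rows are hypothetical stand-ins written with plain Python lists.

```python
from typing import List, Sequence, Tuple


def sample_incidences(
    hyperedge_index: Sequence[Sequence[int]],
    sampled_node_ids: Sequence[int],
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Keep incidences whose node is sampled (assumed semantics, global IDs)."""
    keep = set(sampled_node_ids)
    nodes, edges = hyperedge_index
    kept = [(n, e) for n, e in zip(nodes, edges) if n in keep]
    kept_nodes = [n for n, _ in kept]
    kept_edges = [e for _, e in kept]
    # Also return the sorted unique global IDs that survive the filter.
    return [kept_nodes, kept_edges], sorted(set(kept_nodes)), sorted(set(kept_edges))


def subset_rows(table: Sequence, ids: Sequence[int]) -> list:
    """Pick the rows of a per-node or per-hyperedge table by global ID."""
    return [table[i] for i in ids]


# Toy data: 5 nodes, 3 hyperedges, global IDs throughout.
hyperedge_index = [[0, 1, 2, 2, 3, 4], [0, 0, 0, 1, 1, 2]]
x = [[0.0], [0.1], [0.2], [0.3], [0.4]]
y = ["a", "b", "c"]

sampled_index, node_ids, edge_ids = sample_incidences(hyperedge_index, [2, 3])
new_x = subset_rows(x, node_ids)
new_y = subset_rows(y, edge_ids)
```

The sampled incidences here still carry global IDs; remapping them to 0-based local IDs is the separate step listed in the second note.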

Possible implementation

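def get_subhypergraph(self, index: int | List[int]) -> HData:
    # Resolve the requested index (or indices) to global node IDs and validate them.
    sampled_node_ids_list = self.__get_node_ids_to_sample(index)
    self.__validate_node_ids(sampled_node_ids_list)

    # Filter the incidences for the sampled nodes.
    sampled_hyperedge_index, sampled_node_ids, sampled_hyperedge_ids = (
        self.__sample_hyperedge_index(sampled_node_ids_list)
    )

    # Subset node features and hyperedge labels using the sampled global IDs.
    new_x = self.hdata.x[sampled_node_ids]
    new_y = self.hdata.y[sampled_hyperedge_ids]

    # Hyperedge attributes are optional; subset them only when present.
    new_edge_attr = None
    if self.hdata.hyperedge_attr is not None and len(sampled_hyperedge_ids) > 0:
        new_edge_attr = self.hdata.hyperedge_attr[sampled_hyperedge_ids]

    # TODO: the remapping step from the notes is not shown here. Remap
    # sampled_hyperedge_index to 0-based local IDs with to_0based_ids
    # (hyperbench/utils/data_utils.py) before building the HData.
    return HData(
        x=new_x,
        hyperedge_index=sampled_hyperedge_index,
        hyperedge_attr=new_edge_attr,
        num_nodes=len(sampled_node_ids),
        num_hyperedges=len(sampled_hyperedge_ids),
        y=new_y,
    )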
