diff --git a/docs/api/hooks/hooks.md b/docs/api/hooks/hooks.md index 7013832b..25506af0 100644 --- a/docs/api/hooks/hooks.md +++ b/docs/api/hooks/hooks.md @@ -9,8 +9,12 @@ ::: tgm.hooks.PinMemoryHook ::: tgm.hooks.DeviceTransferHook ::: tgm.hooks.DeduplicationHook -::: tgm.hooks.TGBNegativeEdgeSamplerHook -::: tgm.hooks.NegativeEdgeSamplerHook -::: tgm.hooks.RecencyNeighborHook -::: tgm.hooks.BatchAnalyticsHook -::: tgm.hooks.NodeAnalyticsHook +::: tgm.hooks.negatives.TGBNegativeEdgeSamplerHook +::: tgm.hooks.negatives.TGBTHGNegativeEdgeSamplerHook +::: tgm.hooks.negatives.TGBTKGNegativeEdgeSamplerHook +::: tgm.hooks.negatives.RandomNegativeEdgeSamplerHook +::: tgm.hooks.negatives.HistoricalNegativeEdgeSamplerHook +::: tgm.hooks.neighbors.NeighborSamplerHook +::: tgm.hooks.neighbors.RecencyNeighborHook +::: tgm.hooks.analytics.BatchAnalyticsHook +::: tgm.hooks.analytics.NodeAnalyticsHook diff --git a/examples/analytics/node_analytics_example.py b/examples/analytics/node_analytics_example.py index d2c31147..b6aad4d5 100644 --- a/examples/analytics/node_analytics_example.py +++ b/examples/analytics/node_analytics_example.py @@ -9,7 +9,7 @@ from tgm import DGraph from tgm.data import DGData, DGDataLoader from tgm.hooks import HookManager -from tgm.hooks.node_analytics import NodeAnalyticsHook +from tgm.hooks.analytics.node_analytics import NodeAnalyticsHook from tgm.util.logging import enable_logging, log_latency, log_metrics_dict from tgm.util.seed import seed_everything diff --git a/test/unit/test_hooks/test_negative_edge_sampler_hook.py b/test/unit/test_hooks/test_negative_edge_sampler_hook.py index f1379196..485f8c8b 100644 --- a/test/unit/test_hooks/test_negative_edge_sampler_hook.py +++ b/test/unit/test_hooks/test_negative_edge_sampler_hook.py @@ -2,8 +2,13 @@ import torch from tgm import DGBatch, DGraph +from tgm.constants import PADDED_NODE_ID from tgm.data import DGData, DGDataLoader -from tgm.hooks import HookManager, NegativeEdgeSamplerHook +from tgm.hooks 
import ( + HistoricalNegativeEdgeSamplerHook, + HookManager, + RandomNegativeEdgeSamplerHook, +) @pytest.fixture @@ -14,36 +19,48 @@ def data(): def test_hook_dependancies(): - hook = NegativeEdgeSamplerHook(low=0, high=10) + hook = RandomNegativeEdgeSamplerHook(low=0, high=10) assert hook.requires == {'edge_src', 'edge_dst', 'edge_time'} assert hook.produces == {'neg', 'neg_time'} - hook_with_id = NegativeEdgeSamplerHook(low=0, high=10, id='foo') + hook_with_id = RandomNegativeEdgeSamplerHook(low=0, high=10, id='foo') assert hook_with_id.requires == {'edge_src', 'edge_dst', 'edge_time'} assert hook_with_id.produces == {'neg_foo', 'neg_time_foo'} + hook = HistoricalNegativeEdgeSamplerHook() + assert hook.requires == {'edge_src', 'edge_dst', 'edge_time'} + assert hook.produces == {'neg', 'neg_time', 'valid_neg_mask'} + + hook_with_id = HistoricalNegativeEdgeSamplerHook(id='foo') + assert hook_with_id.requires == {'edge_src', 'edge_dst', 'edge_time'} + assert hook_with_id.produces == {'neg_foo', 'neg_time_foo', 'valid_neg_mask_foo'} + def test_hook_repre(): - hook_with_id = NegativeEdgeSamplerHook(low=0, high=10, id='foo') + hook_with_id = RandomNegativeEdgeSamplerHook(low=0, high=10, id='foo') + assert 'foo' in hook_with_id.__repr__() + + hook_with_id = HistoricalNegativeEdgeSamplerHook(id='foo') assert 'foo' in hook_with_id.__repr__() def test_hook_reset_state(): - assert NegativeEdgeSamplerHook.has_state == False + assert RandomNegativeEdgeSamplerHook.has_state == False + assert HistoricalNegativeEdgeSamplerHook.has_state == True def test_bad_negative_edge_sampler_init(): with pytest.raises(ValueError): - NegativeEdgeSamplerHook(low=0, high=0) + RandomNegativeEdgeSamplerHook(low=0, high=0) with pytest.raises(ValueError): - NegativeEdgeSamplerHook(low=0, high=1, neg_ratio=0) + RandomNegativeEdgeSamplerHook(low=0, high=1, neg_ratio=0) with pytest.raises(ValueError): - NegativeEdgeSamplerHook(low=0, high=1, neg_ratio=2) + RandomNegativeEdgeSamplerHook(low=0, high=1, 
neg_ratio=2) def test_negative_edge_sampler(data): dg = DGraph(data) - hook = NegativeEdgeSamplerHook(low=0, high=10) + hook = RandomNegativeEdgeSamplerHook(low=0, high=10) batch = hook(dg, dg.materialize()) assert isinstance(batch, DGBatch) assert torch.is_tensor(batch.neg) @@ -54,7 +71,7 @@ def test_negative_edge_sampler(data): def test_negative_edge_sampler_with_id(data): dg = DGraph(data) - hook = NegativeEdgeSamplerHook(low=0, high=10, id='foo') + hook = RandomNegativeEdgeSamplerHook(low=0, high=10, id='foo') batch = hook(dg, dg.materialize()) assert isinstance(batch, DGBatch) assert torch.is_tensor(batch.neg_foo) @@ -82,7 +99,7 @@ def node_only_data(): def test_node_only_batch_negative_edge_sampler(node_only_data): dg = DGraph(node_only_data) hm = HookManager(keys=['unit']) - hm.register('unit', NegativeEdgeSamplerHook(low=0, high=6)) + hm.register('unit', RandomNegativeEdgeSamplerHook(low=0, high=6)) loader = DGDataLoader(dg, batch_size=3, hook_manager=hm) with hm.activate('unit'): batch_iter = iter(loader) @@ -97,3 +114,87 @@ def test_node_only_batch_negative_edge_sampler(node_only_data): assert isinstance(batch_2, DGBatch) assert batch_2.neg.shape == (0,) assert batch_2.neg_time.shape == (0,) + + +@pytest.fixture +def data_test_hst_sampling(): + edge_index = torch.IntTensor( + [ + # 1st batch + [1, 5], + [7, 6], + [2, 8], + [7, 8], + # 2nd batch + [1, 7], + [9, 10], + [3, 10], + [1, 9], + # 3rd batch + [3, 11], + [2, 10], + [7, 2], + [3, 5], + ] + ) + edge_time = torch.arange(edge_index.size(0)) + return DGData.from_raw(edge_time, edge_index) + + +def test_hst_sampling(data_test_hst_sampling): + dg = DGraph(data_test_hst_sampling) + + hm = HookManager(keys=['unit']) + sampler = HistoricalNegativeEdgeSamplerHook() + + hm.register('unit', sampler) + loader = DGDataLoader(dg, batch_size=4, hook_manager=hm) + + with hm.activate('unit'): + batch_iter = iter(loader) + batch_1 = next(batch_iter) + assert batch_1.neg.shape == (4,) + assert torch.equal( + 
batch_1.neg, + torch.Tensor( + [PADDED_NODE_ID, PADDED_NODE_ID, PADDED_NODE_ID, PADDED_NODE_ID] + ), + ) + assert torch.equal( + batch_1.valid_neg_mask, + torch.Tensor([False, False, False, False]), + ) + assert sampler._memory is not None + assert sampler._memory.shape == (2, 8) + assert sampler._count == 4 + + batch_2 = next(batch_iter) + assert batch_2.neg.shape == (4,) + assert torch.equal( + batch_2.neg, torch.Tensor([5, PADDED_NODE_ID, PADDED_NODE_ID, 5]) + ) + assert torch.equal( + batch_2.valid_neg_mask, + torch.Tensor([True, False, False, True]), + ) + assert sampler._memory is not None + assert sampler._memory.shape == (2, 8) + assert sampler._count == 8 + + batch_3 = next(batch_iter) + assert batch_3.neg.shape == (4,) + assert torch.equal(batch_3.neg, torch.Tensor([10, 8, 8, 10])) or torch.equal( + batch_3.neg, torch.Tensor([10, 8, 6, 10]) + ) + assert torch.equal( + batch_3.valid_neg_mask, + torch.Tensor([True, True, True, True]), + ) + assert sampler._memory is not None + assert sampler._memory.shape == (2, 24) + assert sampler._count != 0 + assert sampler._count == 12 + + sampler.reset_state() + assert sampler._memory is None + assert sampler._count == 0 diff --git a/test/unit/test_hooks/test_node_analytics_hook.py b/test/unit/test_hooks/test_node_analytics_hook.py index 4c17304a..b208a952 100644 --- a/test/unit/test_hooks/test_node_analytics_hook.py +++ b/test/unit/test_hooks/test_node_analytics_hook.py @@ -3,7 +3,7 @@ from tgm import DGBatch, DGraph from tgm.data import DGData -from tgm.hooks.node_analytics import NodeAnalyticsHook +from tgm.hooks.analytics.node_analytics import NodeAnalyticsHook @pytest.fixture diff --git a/test/unit/test_hooks/test_recipe.py b/test/unit/test_hooks/test_recipe.py index f8c4f982..11e22638 100644 --- a/test/unit/test_hooks/test_recipe.py +++ b/test/unit/test_hooks/test_recipe.py @@ -10,7 +10,7 @@ from tgm.exceptions import UndefinedRecipe from tgm.hooks import ( HookManager, - NegativeEdgeSamplerHook, + 
RandomNegativeEdgeSamplerHook, RecipeRegistry, TGBNegativeEdgeSamplerHook, ) @@ -59,7 +59,7 @@ def test_build_recipe_tgb_link_pred(mock_dataset_cls, tgb_dataset_factory, dg): and register_keys[2] == 'test' ) assert len(train_hooks) == len(val_hooks) == len(test_hooks) == 1 - assert isinstance(train_hooks[0], NegativeEdgeSamplerHook) + assert isinstance(train_hooks[0], RandomNegativeEdgeSamplerHook) assert isinstance(val_hooks[0], TGBNegativeEdgeSamplerHook) assert isinstance(test_hooks[0], TGBNegativeEdgeSamplerHook) diff --git a/tgm/hooks/__init__.py b/tgm/hooks/__init__.py index e2b8c781..8a6e6565 100644 --- a/tgm/hooks/__init__.py +++ b/tgm/hooks/__init__.py @@ -2,7 +2,8 @@ from .dedup import DeduplicationHook from .device import DeviceTransferHook, PinMemoryHook from .negatives import ( - NegativeEdgeSamplerHook, + RandomNegativeEdgeSamplerHook, + HistoricalNegativeEdgeSamplerHook, TGBNegativeEdgeSamplerHook, TGBTHGNegativeEdgeSamplerHook, TGBTKGNegativeEdgeSamplerHook, @@ -11,5 +12,4 @@ from .hook_manager import HookManager from .recipe import RecipeRegistry from .node_tracks import EdgeEventsSeenNodesTrackHook -from .batch_analytics import BatchAnalyticsHook -from .node_analytics import NodeAnalyticsHook +from .analytics import BatchAnalyticsHook, NodeAnalyticsHook diff --git a/tgm/hooks/analytics/__init__.py b/tgm/hooks/analytics/__init__.py new file mode 100644 index 00000000..dbbc2f99 --- /dev/null +++ b/tgm/hooks/analytics/__init__.py @@ -0,0 +1,2 @@ +from .batch_analytics import BatchAnalyticsHook +from .node_analytics import NodeAnalyticsHook diff --git a/tgm/hooks/batch_analytics.py b/tgm/hooks/analytics/batch_analytics.py similarity index 100% rename from tgm/hooks/batch_analytics.py rename to tgm/hooks/analytics/batch_analytics.py diff --git a/tgm/hooks/node_analytics.py b/tgm/hooks/analytics/node_analytics.py similarity index 100% rename from tgm/hooks/node_analytics.py rename to tgm/hooks/analytics/node_analytics.py diff --git 
a/tgm/hooks/negatives.py b/tgm/hooks/negatives.py deleted file mode 100644 index daebfc2b..00000000 --- a/tgm/hooks/negatives.py +++ /dev/null @@ -1,441 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -import torch - -from tgm import DGBatch, DGraph -from tgm.hooks import StatelessHook -from tgm.util.logging import _get_logger - -logger = _get_logger(__name__) - - -class NegativeEdgeSamplerHook(StatelessHook): - """Sample negative edges for dynamic link prediction. - - Args: - low (int): The minimum node id to sample - high (int) : The maximum node id to sample - neg_ratio (float): The ratio of sampled negative destination nodes - to the number of positive destination nodes (default = 1.0). - id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. - """ - - _cls_requires = {'edge_src', 'edge_dst', 'edge_time'} - _cls_produces = {'neg', 'neg_time'} - - def __init__( - self, low: int, high: int, neg_ratio: float = 1.0, id: str | None = None - ) -> None: - super().__init__() - if not 0 < neg_ratio <= 1: - raise ValueError(f'neg_ratio must be in (0, 1], got: {neg_ratio}') - if not low < high: - raise ValueError(f'low ({low}) must be strictly less than high ({high})') - self.low = low - self.high = high - self.neg_ratio = neg_ratio - self._id = id - self.__post_init__() - - # TODO: Historical vs. 
random - def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: - size = (round(self.neg_ratio * batch.edge_dst.size(0)),) - if size[0] == 0: - self.add_batch_attribute( - batch, 'neg', torch.empty(size, dtype=torch.int32, device=dg.device) - ) - self.add_batch_attribute( - batch, - 'neg_time', - torch.empty(size, dtype=torch.int64, device=dg.device), - ) - else: - self.add_batch_attribute( - batch, - 'neg', - torch.randint( - self.low, self.high, size, dtype=torch.int32, device=dg.device - ), - ) - self.add_batch_attribute(batch, 'neg_time', batch.edge_time.clone()) - return batch - - -class TGBNegativeEdgeSamplerHook(StatelessHook): - """Load data from DGraph using pre-generated TGB negative samples. - Make sure to perform `dataset.load_val_ns()` or `dataset.load_test_ns()` before using this hook. - - Args: - dataset_name (str): The name of the TGB dataset to produce sampler for. - split_mode (str): The split mode to use for sampling, either 'val' or 'test'. - id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. - - Raises: - ValueError: If neg_sampler is not provided. 
- """ - - _cls_requires = {'edge_src', 'edge_dst', 'edge_time'} - _cls_produces = {'neg', 'neg_batch_list', 'neg_time'} - - def __init__( - self, dataset_name: str, split_mode: str, id: str | None = None - ) -> None: - super().__init__() - if split_mode not in ['val', 'test']: - raise ValueError(f'split_mode must be "val" or "test", got: {split_mode}') - - try: - from tgb.linkproppred.negative_sampler import NegativeEdgeSampler - from tgb.utils.info import DATA_VERSION_DICT, PROJ_DIR - except ImportError: - raise ImportError( - f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`' - ) - - if not dataset_name.startswith('tgbl-'): - raise ValueError( - 'TGBNegativeEdgeSamplerHook should only be registered for ' - f'"tgbl-xxx" datasets, but got: {dataset_name}' - ) - - neg_sampler = NegativeEdgeSampler(dataset_name=dataset_name) - - # Load evaluation sets - root = Path(PROJ_DIR + 'datasets') / dataset_name.replace('-', '_') - if DATA_VERSION_DICT.get(dataset_name, 1) > 1: - version_suffix = f'_v{DATA_VERSION_DICT[dataset_name]}' - else: - version_suffix = '' - - ns_fname = root / f'{dataset_name}_{split_mode}_ns{version_suffix}.pkl' - logger.debug( - 'Loading %s split (neg_sampler.load_eval_set) for dataset: %s from file: %s', - split_mode, - dataset_name, - ns_fname, - ) - neg_sampler.load_eval_set(fname=str(ns_fname), split_mode=split_mode) - - self.neg_sampler = neg_sampler - self.split_mode = split_mode - self._id = id - self.__post_init__() - - def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: - if batch.edge_src.size(0) == 0: - batch_neg = torch.empty( - batch.edge_src.size(0), dtype=torch.int32, device=dg.device - ) - batch_neg_time = torch.empty( - batch.edge_src.size(0), dtype=torch.int64, device=dg.device - ) - batch_neg_batch_list = [] - # return batch # empty batch - else: - try: - neg_batch_list = self.neg_sampler.query_batch( - batch.edge_src, - batch.edge_dst, - batch.edge_time, - split_mode=self.split_mode, - ) - except 
ValueError as e: - raise ValueError( - f'Negative sampling failed for split_mode={self.split_mode}. Try updating your TGB package: `pip install --upgrade py-tgb`' - ) from e - - batch_neg_batch_list = [ - torch.tensor(neg_batch, dtype=torch.int32, device=dg.device) - for neg_batch in neg_batch_list - ] - batch_neg = torch.unique(torch.cat(batch_neg_batch_list)) - - # This is a heuristic. For our fake (negative) link times, - # we pick random time stamps within [batch.start_time, batch.end_time]. - # Using random times on the whole graph will likely produce information - # leakage, making the prediction easier than it should be. - - # Use generator to local constrain rng for reproducibility - gen = torch.Generator(device=dg.device) - gen.manual_seed(0) - batch_neg_time = torch.randint( - int(batch.edge_time.min().item()), - int(batch.edge_time.max().item()) + 1, - (batch_neg.size(0),), - device=dg.device, - generator=gen, - ) - - self.add_batch_attribute(batch, 'neg', batch_neg) - self.add_batch_attribute(batch, 'neg_batch_list', batch_neg_batch_list) - self.add_batch_attribute(batch, 'neg_time', batch_neg_time) - return batch - - -class TGBTHGNegativeEdgeSamplerHook(StatelessHook): - """Load data from DGraph using pre-generated TGB negative samples for heterogeneous graph. - Make sure to perform `dataset.load_val_ns()` or `dataset.load_test_ns()` before using this hook. - - Args: - dataset_name (str): The name of the TGB dataset to produce sampler for. - split_mode (str): The split mode to use for sampling, either 'val' or 'test'. - first_node_id (int): identity of the first node - last_node_id (int): identity of the last destination node - node_type (Tensor): the node type of each node - id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. - - - - Raises: - ValueError: If neg_sampler is not provided. 
- """ - - _cls_requires = {'edge_src', 'edge_dst', 'edge_time', 'edge_type'} - _cls_produces = {'neg', 'neg_batch_list', 'neg_time'} - - def __init__( - self, - dataset_name: str, - split_mode: str, - first_node_id: int, - last_node_id: int, - node_type: torch.Tensor, - id: str | None = None, - ) -> None: - super().__init__() - if split_mode not in ['val', 'test']: - raise ValueError(f'split_mode must be "val" or "test", got: {split_mode}') - - if first_node_id < 0 or last_node_id < 0: - raise ValueError('First and last ID of node must be positive') - - if node_type is None: - raise ValueError('Node type must not be None') - - if node_type.shape[0] < last_node_id: - raise ValueError(f'last_node_id {last_node_id} must be within node_type') - - try: - from tgb.linkproppred.thg_negative_sampler import ( - THGNegativeEdgeSampler, - ) - from tgb.utils.info import DATA_VERSION_DICT, PROJ_DIR - except ImportError: - raise ImportError( - f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`' - ) - - if not dataset_name.startswith('thgl-'): - raise ValueError( - 'TGBTHGNegativeEdgeSamplerHook should only be registered for ' - f'"thgl-xxx" datasets, but got: {dataset_name}' - ) - - neg_sampler = THGNegativeEdgeSampler( - dataset_name=dataset_name, - first_node_id=first_node_id, - last_node_id=last_node_id, - node_type=node_type.numpy(), - ) - - # Load evaluation sets - root = Path(PROJ_DIR + 'datasets') / dataset_name.replace('-', '_') - if DATA_VERSION_DICT.get(dataset_name, 1) > 1: - version_suffix = f'_v{DATA_VERSION_DICT[dataset_name]}' - else: - version_suffix = '' - - ns_fname = root / f'{dataset_name}_{split_mode}_ns{version_suffix}.pkl' - logger.debug( - 'Loading %s split (neg_sampler.load_eval_set) for dataset: %s from file: %s', - split_mode, - dataset_name, - ns_fname, - ) - neg_sampler.load_eval_set(fname=str(ns_fname), split_mode=split_mode) - - self.neg_sampler = neg_sampler - self.split_mode = split_mode - - self._id = id - - 
self.__post_init__() - - def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: - if batch.edge_src.size(0) == 0: - batch_neg = torch.empty( - batch.edge_src.size(0), dtype=torch.int32, device=dg.device - ) - batch_neg_time = torch.empty( - batch.edge_src.size(0), dtype=torch.int64, device=dg.device - ) - batch_neg_batch_list = [] - else: - try: - neg_batch_list = self.neg_sampler.query_batch( - batch.edge_src, - batch.edge_dst, - batch.edge_time, - batch.edge_type, - split_mode=self.split_mode, - ) - except ValueError as e: - raise ValueError( - f'THGL Negative sampling failed for split_mode={self.split_mode}. Try updating your TGB package: `pip install --upgrade py-tgb`' - ) from e - - batch_neg_batch_list = [ - torch.tensor(neg_batch, dtype=torch.int32, device=dg.device) - for neg_batch in neg_batch_list - ] - batch_neg = torch.unique(torch.cat(batch_neg_batch_list)) - - # This is a heuristic. For our fake (negative) link times, - # we pick random time stamps within [batch.start_time, batch.end_time]. - # Using random times on the whole graph will likely produce information - # leakage, making the prediction easier than it should be. - - # Use generator to local constrain rng for reproducibility - gen = torch.Generator(device=dg.device) - gen.manual_seed(0) - batch_neg_time = torch.randint( - int(batch.edge_time.min().item()), - int(batch.edge_time.max().item()) + 1, - (batch_neg.size(0),), - device=dg.device, - generator=gen, - ) - - self.add_batch_attribute(batch, 'neg', batch_neg) - self.add_batch_attribute(batch, 'neg_batch_list', batch_neg_batch_list) - self.add_batch_attribute(batch, 'neg_time', batch_neg_time) - return batch - - -class TGBTKGNegativeEdgeSamplerHook(StatelessHook): - """Load data from DGraph using pre-generated TGB negative samples for knowledge graph. - Make sure to perform `dataset.load_val_ns()` or `dataset.load_test_ns()` before using this hook. - - Args: - dataset_name (str): The name of the TGB dataset to produce sampler for. 
- split_mode (str): The split mode to use for sampling, either 'val' or 'test'. - first_dst_id (int): identity of the first destination node - last_dst_id (int): identity of the last destination node - id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. - - - Raises: - ValueError: If neg_sampler is not provided. - """ - - _cls_requires = {'edge_src', 'edge_dst', 'edge_time', 'edge_type'} - _cls_produces = {'neg', 'neg_batch_list', 'neg_time'} - - def __init__( - self, - dataset_name: str, - split_mode: str, - first_dst_id: int, - last_dst_id: int, - id: str | None = None, - ) -> None: - super().__init__() - if split_mode not in ['val', 'test']: - raise ValueError(f'split_mode must be "val" or "test", got: {split_mode}') - - if first_dst_id < 0 or last_dst_id < 0: - raise ValueError('First and last ID of node must be positive') - - try: - from tgb.linkproppred.tkg_negative_sampler import ( - TKGNegativeEdgeSampler, - ) - from tgb.utils.info import DATA_VERSION_DICT, PROJ_DIR - except ImportError: - raise ImportError( - f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`' - ) - - if not dataset_name.startswith('tkgl-'): - raise ValueError( - 'TGBTKGNegativeEdgeSamplerHook should only be registered for ' - f'"tkgl-xxx" datasets, but got: {dataset_name}' - ) - - neg_sampler = TKGNegativeEdgeSampler( - dataset_name=dataset_name, - first_dst_id=first_dst_id, - last_dst_id=last_dst_id, - ) - - # Load evaluation sets - root = Path(PROJ_DIR + 'datasets') / dataset_name.replace('-', '_') - if DATA_VERSION_DICT.get(dataset_name, 1) > 1: - version_suffix = f'_v{DATA_VERSION_DICT[dataset_name]}' - else: - version_suffix = '' - - ns_fname = root / f'{dataset_name}_{split_mode}_ns{version_suffix}.pkl' - logger.debug( - 'Loading %s split (neg_sampler.load_eval_set) for dataset: %s from file: %s', - split_mode, - dataset_name, - ns_fname, - ) - neg_sampler.load_eval_set(fname=str(ns_fname), 
split_mode=split_mode) - - self.neg_sampler = neg_sampler - self.split_mode = split_mode - self._id = id - self.__post_init__() - - def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: - if batch.edge_src.size(0) == 0: - batch_neg = torch.empty( - batch.edge_src.size(0), dtype=torch.int32, device=dg.device - ) - batch_neg_time = torch.empty( - batch.edge_src.size(0), dtype=torch.int64, device=dg.device - ) - batch_neg_batch_list = [] - else: - try: - neg_batch_list = self.neg_sampler.query_batch( - batch.edge_src, - batch.edge_dst, - batch.edge_time, - batch.edge_type, - split_mode=self.split_mode, - ) - except ValueError as e: - raise ValueError( - f'TKGL Negative sampling failed for split_mode={self.split_mode}. Try updating your TGB package: `pip install --upgrade py-tgb`' - ) from e - - batch_neg_batch_list = [ - torch.tensor(neg_batch, dtype=torch.int32, device=dg.device) - for neg_batch in neg_batch_list - ] - batch_neg = torch.unique(torch.cat(batch_neg_batch_list)) - # This is a heuristic. For our fake (negative) link times, - # we pick random time stamps within [batch.start_time, batch.end_time]. - # Using random times on the whole graph will likely produce information - # leakage, making the prediction easier than it should be. 
- - # Use generator to local constrain rng for reproducibility - gen = torch.Generator(device=dg.device) - gen.manual_seed(0) - batch_neg_time = torch.randint( - int(batch.edge_time.min().item()), - int(batch.edge_time.max().item()) + 1, - (batch_neg.size(0),), - device=dg.device, - generator=gen, - ) - - self.add_batch_attribute(batch, 'neg', batch_neg) - self.add_batch_attribute(batch, 'neg_batch_list', batch_neg_batch_list) - self.add_batch_attribute(batch, 'neg_time', batch_neg_time) - return batch diff --git a/tgm/hooks/negatives/__init__.py b/tgm/hooks/negatives/__init__.py new file mode 100644 index 00000000..68c76bfc --- /dev/null +++ b/tgm/hooks/negatives/__init__.py @@ -0,0 +1,6 @@ +from .sampler import RandomNegativeEdgeSamplerHook, HistoricalNegativeEdgeSamplerHook +from .tgb_sampler import ( + TGBNegativeEdgeSamplerHook, + TGBTHGNegativeEdgeSamplerHook, + TGBTKGNegativeEdgeSamplerHook, +) diff --git a/tgm/hooks/negatives/sampler.py b/tgm/hooks/negatives/sampler.py new file mode 100644 index 00000000..b327b14d --- /dev/null +++ b/tgm/hooks/negatives/sampler.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import torch + +from tgm.constants import PADDED_NODE_ID +from tgm.core import DGBatch, DGraph +from tgm.hooks.base import StatefulHook, StatelessHook +from tgm.util.logging import _get_logger + +logger = _get_logger(__name__) + + +class RandomNegativeEdgeSamplerHook(StatelessHook): + """Random sampling negative edges for dynamic link prediction. + + Args: + low (int): The minimum node id to sample + high (int) : The maximum node id to sample + neg_ratio (float): The ratio of sampled negative destination nodes + to the number of positive destination nodes (default = 1.0). + id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. 
+ """ + + _cls_requires = {'edge_src', 'edge_dst', 'edge_time'} + _cls_produces = {'neg', 'neg_time'} + + def __init__( + self, low: int, high: int, neg_ratio: float = 1.0, id: str | None = None + ) -> None: + super().__init__() + if not 0 < neg_ratio <= 1: + raise ValueError(f'neg_ratio must be in (0, 1], got: {neg_ratio}') + if not low < high: + raise ValueError(f'low ({low}) must be strictly less than high ({high})') + self.low = low + self.high = high + self.neg_ratio = neg_ratio + self._id = id + self.__post_init__() + + def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: + size = (round(self.neg_ratio * batch.edge_dst.size(0)),) + if size[0] == 0: + self.add_batch_attribute( + batch, 'neg', torch.empty(size, dtype=torch.int32, device=dg.device) + ) + self.add_batch_attribute( + batch, + 'neg_time', + torch.empty(size, dtype=torch.int64, device=dg.device), + ) + else: + self.add_batch_attribute( + batch, + 'neg', + torch.randint( + self.low, self.high, size, dtype=torch.int32, device=dg.device + ), + ) + self.add_batch_attribute(batch, 'neg_time', batch.edge_time.clone()) + return batch + + +class HistoricalNegativeEdgeSamplerHook(StatefulHook): + """Sample negative edges from past interactions for dynamic link prediction. + + Args: + id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. + + Notes: + If a node doesn't have past interactions, we return `PADDED_NODE_ID`(-1) as the negative destination. + `valid_neg_mask` (BoolTensor): Boolean mask of shape ``(num_neg,)`` indicating + which entries in ``neg`` are real negative samples. ``True`` means the + corresponding node id is a valid negative; ``False`` means the entry is + a padding placeholder (``PADDED_NODE_ID``) and should be excluded from + loss computation and evaluation. 
+ """ + + _cls_requires = {'edge_src', 'edge_dst', 'edge_time'} + _cls_produces = {'neg', 'neg_time', 'valid_neg_mask'} + + def __init__( + self, + id: str | None = None, + ) -> None: + super().__init__() + + self._id = id + + self._memory: torch.Tensor | None = None + self._count: int = 0 + self.__post_init__() + + def reset_state(self) -> None: + self._memory = None + self._count = 0 + + def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: + if self._count == 0 or batch.edge_src.numel() == 0: + neg = torch.full( + (batch.edge_dst.size(0),), + PADDED_NODE_ID, + dtype=batch.edge_dst.dtype, + device=dg.device, + ) + else: + neg = self._hist_sampling(dg, batch) + + neg_time = batch.edge_time.clone() + valid_neg_mask = neg != PADDED_NODE_ID + self._update_hst_memory(dg, batch) + + self.add_batch_attribute(batch, 'neg', neg) + self.add_batch_attribute(batch, 'neg_time', neg_time) + self.add_batch_attribute(batch, 'valid_neg_mask', valid_neg_mask) + return batch + + def _hist_sampling(self, dg: DGraph, batch: DGBatch) -> torch.Tensor: + """Sample negative destination nodes from each source node's historical interactions. + + For each source node in the batch, randomly selects a destination node from + its past interactions stored in memory. If a source node has no recorded past + interactions, its corresponding negative sample is set to PADDED_NODE_ID as + a sentinel value indicating no history is available. + + The random selection is performed via a vectorized scatter-max over random + weights assigned to each historical edge, avoiding explicit loops. + + Args: + dg (DGraph): The dynamic graph, used to determine the target device. + batch (DGBatch): The current batch of edges. + + Returns: + neg (torch.Tensor): Historically sampled negative destination nodes + of shape (batch_size,) and dtype int32. Nodes with no historical + interactions are set to PADDED_NODE_ID. 
+ + Note: + self._memory has shape (2, capacity); only its first self._count columns + hold observed edges (row 0 sources, row 1 destinations). The zero-filled + tail is unused capacity and is excluded from sampling. + """ + assert self._memory is not None + # Columns past self._count are zero-filled spare capacity, not real edges. + observed = self._memory[:, : self._count] + sampling_edges = observed[:, torch.isin(observed[0], batch.edge_src)] + + # Group duplicate srcs: for each unique src, collect all batch positions + unique_srcs, inverse = torch.unique(batch.edge_src, return_inverse=True) + + unique_src_to_idx = torch.zeros( + (int(batch.edge_src.max().item()) + 1,), dtype=torch.long, device=dg.device + ) + unique_src_to_idx[unique_srcs] = torch.arange( + unique_srcs.size(0), device=dg.device + ) + + edge_to_unique_idx = unique_src_to_idx[sampling_edges[0]] + + sampling_edges_random_weights = torch.rand( + sampling_edges.size(1), device=dg.device + ) + best_weights = torch.full((unique_srcs.size(0),), -1.0, device=dg.device) + + best_weights.scatter_reduce_( + 0, edge_to_unique_idx, sampling_edges_random_weights, reduce='amax' + ) + best_edge_mask = ( + sampling_edges_random_weights == best_weights[edge_to_unique_idx] + ) + + # Sample one neg per unique src + neg_per_unique = torch.full( + (unique_srcs.size(0),), + PADDED_NODE_ID, + dtype=sampling_edges.dtype, + device=dg.device, + ) + neg_per_unique[edge_to_unique_idx[best_edge_mask]] = sampling_edges[ + 1, best_edge_mask + ] + + # Broadcast back to all batch positions (duplicates get the same sampled neg) + neg = neg_per_unique[inverse] + return neg + + def _update_hst_memory(self, dg: DGraph, batch: DGBatch) -> None: + """Append the current batch of edges to the historical memory buffer. + + Maintains a dynamically resizing memory buffer of observed edges for use + in historical negative sampling. The buffer doubles in size when capacity + is exceeded, giving amortized O(1) insertion time and O(E) space complexity, + where E is the total number of observed edges. 
+ + Args: + dg (DGraph): The dynamic graph, used to determine the target device. + batch (DGBatch): The current batch of edges whose source and destination + nodes will be appended to memory. + + Note: + - Memory is lazily initialized on the first call with twice the initial + batch size as the starting capacity. + - When the buffer is full, it is expanded to the maximum of twice its + current size or twice the required size, ensuring no immediate + back-to-back resizes even for large batches. + - Only source and destination nodes are stored; edge timestamps are + not retained in memory. + - Memory grows linearly with the number of interaction events rather than the + number of distinct edges, since _memory can contain duplicated edges. + """ + batch_size = batch.edge_src.size(0) + + if self._memory is None: + self._memory = torch.zeros( + (2, batch_size * 2), dtype=torch.int32, device=dg.device + ) + + if self._count + batch_size > self._memory.size(1): + new_size = max(self._memory.size(1) * 2, (self._count + batch_size) * 2) + + edge_buffer = torch.zeros( + (2, new_size - self._memory.size(1)), + dtype=torch.int32, + device=dg.device, + ) + self._memory = torch.cat([self._memory, edge_buffer], dim=1) + + self._memory[0, self._count : self._count + batch_size] = batch.edge_src + self._memory[1, self._count : self._count + batch_size] = batch.edge_dst + + self._count += batch_size diff --git a/tgm/hooks/negatives/tgb_sampler.py b/tgm/hooks/negatives/tgb_sampler.py new file mode 100644 index 00000000..e605af28 --- /dev/null +++ b/tgm/hooks/negatives/tgb_sampler.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import torch + +from tgm.core import DGBatch, DGraph +from tgm.hooks.base import StatelessHook +from tgm.util.logging import _get_logger + +logger = _get_logger(__name__) + + +class TGBNegativeEdgeSamplerBase(StatelessHook): + """Base class for TGB pre-generated negative edge sampler hooks. 
class TGBNegativeEdgeSamplerBase(StatelessHook):
    """Base class for TGB pre-generated negative edge sampler hooks.

    Handles common logic for loading evaluation sets, querying negative samples,
    and assembling batch attributes. Subclasses must implement ``_build_sampler``
    and ``_query_batch`` and define ``_dataset_prefix``.

    Args:
        dataset_name (str): The name of the TGB dataset. Must start with
            ``f'{_dataset_prefix}-'``.
        split_mode (str): The split mode to use for sampling, either 'val' or 'test'.
        id (str | None): A unique identifier for the hook.

    Attributes produced:
        neg (Tensor[int32]): Unique negative destination node ids across the batch.
        neg_batch_list (list[Tensor[int32]]): Per-edge negative candidate lists,
            in the order returned by ``_query_batch``.
        neg_time (Tensor[int64]): Randomly sampled timestamps, one per entry of
            ``neg``, drawn uniformly from
            ``[batch.edge_time.min(), batch.edge_time.max()]`` with a fixed seed
            for reproducibility.

    Raises:
        ValueError: If ``split_mode`` is not 'val' or 'test', or if
            ``dataset_name`` does not carry the subclass's dataset prefix.
        ImportError: If the ``tgb`` package is not installed.
    """

    # Dataset family prefix WITHOUT the trailing dash, e.g. 'tgbl', 'thgl',
    # 'tkgl'. The dash is appended when validating ``dataset_name``.
    _dataset_prefix: str

    def __init__(
        self, dataset_name: str, split_mode: str, id: str | None = None
    ) -> None:
        super().__init__()
        if split_mode not in ['val', 'test']:
            raise ValueError(f'split_mode must be "val" or "test", got: {split_mode}')

        try:
            from tgb.utils.info import DATA_VERSION_DICT, PROJ_DIR
        except ImportError as e:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`'
            ) from e

        if not dataset_name.startswith(f'{self._dataset_prefix}-'):
            # Report the concrete hook class: this check is shared by every
            # subclass, not just TGBNegativeEdgeSamplerHook.
            raise ValueError(
                f'{self.__class__.__name__} should only be registered for '
                f'"{self._dataset_prefix}-xxx" datasets, but got: {dataset_name}'
            )

        neg_sampler = self._build_sampler(dataset_name)

        # Locate the pre-generated negative-sample file for this split.
        root = Path(PROJ_DIR + 'datasets') / dataset_name.replace('-', '_')
        version = DATA_VERSION_DICT.get(dataset_name, 1)
        version_suffix = f'_v{version}' if version > 1 else ''

        ns_fname = root / f'{dataset_name}_{split_mode}_ns{version_suffix}.pkl'
        logger.debug(
            'Loading %s split (neg_sampler.load_eval_set) for dataset: %s from file: %s',
            split_mode,
            dataset_name,
            ns_fname,
        )
        neg_sampler.load_eval_set(fname=str(ns_fname), split_mode=split_mode)

        self.neg_sampler = neg_sampler
        self.split_mode = split_mode
        self._id = id
        self.__post_init__()

    def _build_sampler(self, dataset_name: str) -> Any:
        """Instantiate and return the TGB negative sampler. Must be implemented by subclasses."""
        raise NotImplementedError

    def _query_batch(self, batch: DGBatch) -> list:
        """Query the sampler for a batch. Must be implemented by subclasses."""
        raise NotImplementedError

    def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch:
        """Attach ``neg``, ``neg_batch_list`` and ``neg_time`` to ``batch``.

        An empty batch yields empty tensors and an empty list. Otherwise the
        sampler is queried through ``_query_batch``.

        Raises:
            ValueError: If the underlying sampler raises ``ValueError`` while
                querying; the original error is chained.
        """
        if batch.edge_src.size(0) == 0:
            batch_neg = torch.empty(
                batch.edge_src.size(0), dtype=torch.int32, device=dg.device
            )
            batch_neg_time = torch.empty(
                batch.edge_src.size(0), dtype=torch.int64, device=dg.device
            )
            batch_neg_batch_list = []
        else:
            try:
                neg_batch_list = self._query_batch(batch)
            except ValueError as e:
                raise ValueError(
                    f'{self._dataset_prefix.upper()} Negative sampling failed for split_mode={self.split_mode}. Try updating your TGB package: `pip install --upgrade py-tgb`'
                ) from e

            batch_neg_batch_list = [
                torch.tensor(neg_batch, dtype=torch.int32, device=dg.device)
                for neg_batch in neg_batch_list
            ]
            batch_neg = torch.unique(torch.cat(batch_neg_batch_list))

            # This is a heuristic. For our fake (negative) link times,
            # we pick random time stamps within the time range of this batch.
            # Using random times on the whole graph will likely produce information
            # leakage, making the prediction easier than it should be.

            # A local generator with a fixed seed keeps the draw reproducible
            # without touching the global RNG state.
            gen = torch.Generator(device=dg.device)
            gen.manual_seed(0)
            batch_neg_time = torch.randint(
                int(batch.edge_time.min().item()),
                int(batch.edge_time.max().item()) + 1,  # randint's high is exclusive
                (batch_neg.size(0),),
                device=dg.device,
                generator=gen,
            )

        self.add_batch_attribute(batch, 'neg', batch_neg)
        self.add_batch_attribute(batch, 'neg_batch_list', batch_neg_batch_list)
        self.add_batch_attribute(batch, 'neg_time', batch_neg_time)
        return batch
class TGBNegativeEdgeSamplerHook(TGBNegativeEdgeSamplerBase):
    """Negative edge sampler hook backed by pre-generated TGB link-prediction negatives.

    Make sure to perform `dataset.load_val_ns()` or `dataset.load_test_ns()` before using this hook.

    Args:
        dataset_name (str): The name of the TGB dataset to produce sampler for
            (expected to start with ``'tgbl-'``).
        split_mode (str): The split mode to use for sampling, either 'val' or 'test'.
        id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`.

    Raises:
        ImportError: If the ``tgb`` package cannot be imported.
    """

    _cls_requires = {'edge_src', 'edge_dst', 'edge_time'}
    _cls_produces = {'neg', 'neg_batch_list', 'neg_time'}
    _dataset_prefix = 'tgbl'

    def _build_sampler(self, dataset_name: str) -> Any:
        """Create the TGB ``NegativeEdgeSampler`` for ``dataset_name``."""
        try:
            from tgb.linkproppred.negative_sampler import NegativeEdgeSampler
        except ImportError:
            hook_name = self.__class__.__name__
            raise ImportError(
                f'TGB required for {hook_name}, try `pip install py-tgb`'
            )
        sampler = NegativeEdgeSampler(dataset_name=dataset_name)
        return sampler

    def _query_batch(self, batch: DGBatch) -> list:
        """Return the sampler's negatives for the positive edges in ``batch``."""
        positives = (batch.edge_src, batch.edge_dst, batch.edge_time)
        return self.neg_sampler.query_batch(*positives, split_mode=self.split_mode)
class TGBTHGNegativeEdgeSamplerHook(TGBNegativeEdgeSamplerBase):
    """Load data from DGraph using pre-generated TGB negative samples for heterogeneous graph.

    Make sure to perform `dataset.load_val_ns()` or `dataset.load_test_ns()` before using this hook.

    Args:
        dataset_name (str): The name of the TGB dataset to produce sampler for
            (expected to start with ``'thgl-'``).
        split_mode (str): The split mode to use for sampling, either 'val' or 'test'.
        first_node_id (int): Identity of the first node. Must be >= 0.
        last_node_id (int): Identity of the last node. Must be >= 0.
        node_type (Tensor): The node type of each node.
        id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`.

    Raises:
        ValueError: If ``first_node_id`` or ``last_node_id`` is negative, if
            ``node_type`` is None, or if ``node_type`` has fewer than
            ``last_node_id`` entries.
        ImportError: If the ``tgb`` package cannot be imported.
    """

    _cls_requires = {'edge_src', 'edge_dst', 'edge_time', 'edge_type'}
    _cls_produces = {'neg', 'neg_batch_list', 'neg_time'}
    _dataset_prefix = 'thgl'

    def __init__(
        self,
        dataset_name: str,
        split_mode: str,
        first_node_id: int,
        last_node_id: int,
        node_type: torch.Tensor,
        id: str | None = None,
    ) -> None:
        if first_node_id < 0 or last_node_id < 0:
            # Zero is a valid id, so the requirement is "non-negative".
            raise ValueError('First and last ID of node must be non-negative')

        if node_type is None:
            raise ValueError('Node type must not be None')

        # NOTE(review): if node ids index into node_type (0-based), this should
        # probably be `<=`; left unchanged pending confirmation against TGB.
        if node_type.shape[0] < last_node_id:
            raise ValueError(f'last_node_id {last_node_id} must be within node_type')

        # These must be set before the base constructor runs, because it calls
        # _build_sampler(), which reads them.
        self._first_node_id = first_node_id
        self._last_node_id = last_node_id
        self._node_type = node_type
        super().__init__(dataset_name, split_mode, id)

    def _build_sampler(self, dataset_name: str) -> Any:
        """Create the TGB ``THGNegativeEdgeSampler`` for ``dataset_name``."""
        try:
            from tgb.linkproppred.thg_negative_sampler import (
                THGNegativeEdgeSampler,
            )
        except ImportError as e:
            raise ImportError(
                f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`'
            ) from e

        return THGNegativeEdgeSampler(
            dataset_name=dataset_name,
            first_node_id=self._first_node_id,
            last_node_id=self._last_node_id,
            # detach()/cpu() make the conversion safe for CUDA or
            # grad-tracking tensors; for plain CPU tensors it is a no-op.
            node_type=self._node_type.detach().cpu().numpy(),
        )

    def _query_batch(self, batch: DGBatch) -> list:
        """Return the sampler's negatives for the typed positive edges in ``batch``."""
        return self.neg_sampler.query_batch(
            batch.edge_src,
            batch.edge_dst,
            batch.edge_time,
            batch.edge_type,
            split_mode=self.split_mode,
        )
class TGBTKGNegativeEdgeSamplerHook(TGBNegativeEdgeSamplerBase):
    """Load data from DGraph using pre-generated TGB negative samples for knowledge graph.

    Make sure to perform `dataset.load_val_ns()` or `dataset.load_test_ns()` before using this hook.

    Args:
        dataset_name (str): The name of the TGB dataset to produce sampler for
            (expected to start with ``'tkgl-'``).
        split_mode (str): The split mode to use for sampling, either 'val' or 'test'.
        first_dst_id (int): Identity of the first destination node. Must be >= 0.
        last_dst_id (int): Identity of the last destination node. Must be >= 0.
        id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`.

    Raises:
        ValueError: If ``first_dst_id`` or ``last_dst_id`` is negative.
        ImportError: If the ``tgb`` package cannot be imported.
    """

    _cls_requires = {'edge_src', 'edge_dst', 'edge_time', 'edge_type'}
    _cls_produces = {'neg', 'neg_batch_list', 'neg_time'}
    _dataset_prefix = 'tkgl'

    def __init__(
        self,
        dataset_name: str,
        split_mode: str,
        first_dst_id: int,
        last_dst_id: int,
        id: str | None = None,
    ) -> None:
        if first_dst_id < 0 or last_dst_id < 0:
            # Zero is a valid id, so the requirement is "non-negative".
            raise ValueError('First and last ID of node must be non-negative')
        # These must be set before the base constructor runs, because it calls
        # _build_sampler(), which reads them.
        self._first_dst_id = first_dst_id
        self._last_dst_id = last_dst_id
        super().__init__(dataset_name, split_mode, id)

    def _build_sampler(self, dataset_name: str) -> Any:
        """Create the TGB ``TKGNegativeEdgeSampler`` for ``dataset_name``."""
        try:
            from tgb.linkproppred.tkg_negative_sampler import (
                TKGNegativeEdgeSampler,
            )
        except ImportError as e:
            raise ImportError(
                f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`'
            ) from e
        return TKGNegativeEdgeSampler(
            dataset_name=dataset_name,
            first_dst_id=self._first_dst_id,
            last_dst_id=self._last_dst_id,
        )

    def _query_batch(self, batch: DGBatch) -> list:
        """Return the sampler's negatives for the typed positive edges in ``batch``."""
        return self.neg_sampler.query_batch(
            batch.edge_src,
            batch.edge_dst,
            batch.edge_time,
            batch.edge_type,
            split_mode=self.split_mode,
        )
b/tgm/hooks/neighbors/__init__.py @@ -0,0 +1,2 @@ +from .recency import RecencyNeighborHook +from .uniform import NeighborSamplerHook diff --git a/tgm/hooks/neighbors.py b/tgm/hooks/neighbors/recency.py similarity index 66% rename from tgm/hooks/neighbors.py rename to tgm/hooks/neighbors/recency.py index 53e2e298..3fbcc351 100644 --- a/tgm/hooks/neighbors.py +++ b/tgm/hooks/neighbors/recency.py @@ -7,210 +7,12 @@ from tgm import DGBatch, DGraph from tgm.constants import PADDED_NODE_ID -from tgm.core._storage import DGSliceTracker -from tgm.hooks import SeedableHook, StatefulHook, StatelessHook +from tgm.hooks import SeedableHook, StatefulHook from tgm.util.logging import _get_logger logger = _get_logger(__name__) -class NeighborSamplerHook(StatelessHook, SeedableHook): - """Load data from DGraph using a memory based sampling function. - - Args: - num_nbrs (List[int]): Number of neighbors to sample at each hop (-1 to keep all) - directed (bool): If true, aggregates interactions in edge_src->edge_dst direction only (default=False). - seed_nodes_keys ([List[str]): List of batch attribute keys to identify the initial seed nodes to sample for. - seed_times_keys ([List[str]): List of batch attribute keys to identify the initial seed times to sample for. - id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`. - - - Note: - The order of the output tensors respect the order of seed_nodes_keys. - For instance, for seed node keys ['edge_src', 'edge_dst', 'neg'] will have the first output index (hop 0) contain the concatenation - of batch.edge_src, batch.edge_dst, batch.neg (in that order). The next index (hop 1) will contain first-hop neighbors of batch.edge_src - followed by first-hop neighbors of batch.edge_dst, and then those of batch.neg. This pattern repeats for deeper hops. - - Raises: - ValueError: If the num_nbrs list is empty or has non-positive entries. 
- ValueError: If len(seed_nodes_keys) != len(seed_times_keys). - """ - - _cls_requires = {'edge_src', 'edge_dst', 'edge_time'} - _cls_produces = { - 'seed_nids', - 'seed_times', - 'nbr_nids', - 'nbr_edge_time', - 'nbr_edge_x', - 'seed_node_nbr_mask', - } - - def __init__( - self, - num_nbrs: List[int], - seed_nodes_keys: List[str], - seed_times_keys: List[str], - directed: bool = False, - id: str | None = None, - ) -> None: - super().__init__() - if not len(num_nbrs): - raise ValueError('num_nbrs must be non-empty') - if not all([isinstance(x, int) and (x > 0) for x in num_nbrs]): - raise ValueError('Each value in num_nbrs must be a positive integer') - self._num_nbrs = num_nbrs - self._directed = directed - - if len(seed_nodes_keys) != len(seed_times_keys): - raise ValueError( - f'len(seed_nodes_keys) ({len(seed_nodes_keys)}) ' - f'!= len(seed_times_keys) ({len(seed_times_keys)})\n' - f'seed_nodes_keys={seed_nodes_keys}, ' - f'seed_times_keys={seed_times_keys}' - ) - self._seed_nodes_keys = seed_nodes_keys - self._seed_times_keys = seed_times_keys - logger.debug( - 'Seed nodes keys: %s, Seed times keys: %s', - self._seed_nodes_keys, - self._seed_times_keys, - ) - self._warned_seed_None = False - self._id = id - self.seed_keys = seed_nodes_keys - self.__post_init__() - - @property - def num_nbrs(self) -> List[int]: - return self._num_nbrs - - def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: - batch_seed_nids, batch_seed_times = [], [] - batch_nbr_nids, batch_nbr_edge_time = [], [] - batch_nbr_edge_x = [] - - def _append_empty_hop() -> None: - batch_seed_nids.append(torch.empty(0, dtype=torch.int32)) - batch_seed_times.append(torch.empty(0, dtype=torch.int64)) - batch_nbr_nids.append(torch.empty(0, dtype=torch.int32)) - batch_nbr_edge_time.append(torch.empty(0, dtype=torch.int64)) - batch_nbr_edge_x.append( - torch.empty(0, dg.edge_x_dim).float() # type: ignore - ) - - seed_nodes, seed_times, seed_node_nbr_mask = self._get_seed_tensors(batch) - if not 
seed_nodes.numel(): - logger.debug('No seed_nodes found, appending empty hop information') - for _ in self.num_nbrs: - _append_empty_hop() - - else: - for hop, num_nbrs in enumerate(self.num_nbrs): - if hop > 0: - seed_nodes = batch_nbr_nids[hop - 1].flatten() - seed_times = batch_nbr_edge_time[hop - 1].flatten() - - # TODO: Storage needs to use the right device - - # We slice on batch.start_time so that we only consider neighbor events - # that occurred strictly before this batch - logger.debug( - 'Getting uniform nbrs for hop %d with %d seed nodes', - hop, - seed_nodes.numel(), - ) - nbr_nids, nbr_edge_time, nbr_edge_x = dg._storage.get_nbrs( - seed_nodes, - num_nbrs=num_nbrs, - slice=DGSliceTracker(end_time=int(batch.edge_time.min()) - 1), - directed=self._directed, - ) - - batch_seed_nids.append(seed_nodes) - batch_seed_times.append(seed_times) - batch_nbr_nids.append(nbr_nids) - batch_nbr_edge_time.append(nbr_edge_time) - batch_nbr_edge_x.append(nbr_edge_x) - - self.add_batch_attribute(batch, 'seed_nids', batch_seed_nids) - self.add_batch_attribute(batch, 'seed_times', batch_seed_times) - self.add_batch_attribute(batch, 'nbr_nids', batch_nbr_nids) - self.add_batch_attribute(batch, 'nbr_edge_time', batch_nbr_edge_time) - self.add_batch_attribute(batch, 'nbr_edge_x', batch_nbr_edge_x) - self.add_batch_attribute(batch, 'seed_node_nbr_mask', seed_node_nbr_mask) - - return batch - - def _get_seed_tensors( - self, batch: DGBatch - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - device = batch.edge_src.device - seeds, seed_times = [], [] - seed_node_mask = dict() - - offset = 0 - for node_attr, time_attr in zip(self._seed_nodes_keys, self._seed_times_keys): - missing = [ - attr for attr in (node_attr, time_attr) if not hasattr(batch, attr) - ] - if missing: - raise ValueError(f'Missing seed attributes {missing} on batch') - - seed = getattr(batch, node_attr) - time = getattr(batch, time_attr) - - for name, tensor in [(node_attr, seed), (time_attr, time)]: - # We 
recover from tensor = None, since the current batch could just - # be missing certain attributes (e.g. dynamic node events), but for - # non-Tensor and non-None attrs we explicitly raise - if tensor is None: - logger.debug( - 'Seed attribute %s is None on this batch, skipping', name - ) - if not self._warned_seed_None: - warnings.warn( - f'Seed attribute {name} is None on this batch, skipping this batch. ' - 'Future occurrences will also be skipped but the warning will be suppressed', - UserWarning, - ) - self._warned_seed_None = True - break - if not isinstance(tensor, torch.Tensor): - raise ValueError(f'{name} must be a Tensor, got {type(tensor)}') - if tensor.ndim != 1: - raise ValueError(f'{name} must be 1-D, got shape {tensor.shape}') - - # Bounds checks - # TODO: Infer self._num_nodes from underlying graph - self._num_nodes = float('inf') - if name == node_attr: - if (tensor < 0).any() or (tensor >= self._num_nodes).any(): - raise ValueError( - f'Seed nodes in {name} must satisfy 0 <= x < {self._num_nodes}, ' - f'got values in range [{tensor.min().item()}, {tensor.max().item()}]' - ) - seeds.append(seed.to(device)) - num_seed_nodes = tensor.shape[0] - seed_node_mask[name] = torch.arange( - offset, offset + num_seed_nodes, device=device - ) - offset += num_seed_nodes - elif name == time_attr: - if (tensor < 0).any(): - raise ValueError( - f'Seed times in {name} must be >= 0, got min value: {tensor.min().item()}' - ) - seed_times.append(time.to(device)) - - if seeds and seed_times: - seed_nodes, seed_times = torch.cat(seeds), torch.cat(seed_times) # type: ignore - else: - seed_nodes = torch.empty(0, dtype=torch.int32, device=device) - seed_times = torch.empty(0, dtype=torch.int64, device=device) # type: ignore - return seed_nodes, seed_times, seed_node_mask # type: ignore - - class RecencyNeighborHook(StatefulHook, SeedableHook): """Load neighbors from DGraph using a recency sampling. Each node maintains a fixed number of recent neighbors. 
class NeighborSamplerHook(StatelessHook, SeedableHook):
    """Load neighbors from DGraph using a memory based sampling function.

    Args:
        num_nbrs (List[int]): Number of neighbors to sample at each hop. Every
            entry must be a positive integer.
        directed (bool): If true, aggregates interactions in edge_src->edge_dst direction only (default=False).
        seed_nodes_keys (List[str]): List of batch attribute keys to identify the initial seed nodes to sample for.
        seed_times_keys (List[str]): List of batch attribute keys to identify the initial seed times to sample for.
        id (str): A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this `id`.

    Note:
        The order of the output tensors respects the order of seed_nodes_keys.
        For instance, seed node keys ['edge_src', 'edge_dst', 'neg'] will have the first output index (hop 0) contain the concatenation
        of batch.edge_src, batch.edge_dst, batch.neg (in that order). The next index (hop 1) will contain first-hop neighbors of batch.edge_src
        followed by first-hop neighbors of batch.edge_dst, and then those of batch.neg. This pattern repeats for deeper hops.

    Raises:
        ValueError: If the num_nbrs list is empty or has non-positive entries.
        ValueError: If len(seed_nodes_keys) != len(seed_times_keys).
    """

    _cls_requires = {'edge_src', 'edge_dst', 'edge_time'}
    _cls_produces = {
        'seed_nids',
        'seed_times',
        'nbr_nids',
        'nbr_edge_time',
        'nbr_edge_x',
        'seed_node_nbr_mask',
    }

    def __init__(
        self,
        num_nbrs: List[int],
        seed_nodes_keys: List[str],
        seed_times_keys: List[str],
        directed: bool = False,
        id: str | None = None,
    ) -> None:
        super().__init__()
        if not len(num_nbrs):
            raise ValueError('num_nbrs must be non-empty')
        if not all([isinstance(x, int) and (x > 0) for x in num_nbrs]):
            raise ValueError('Each value in num_nbrs must be a positive integer')
        self._num_nbrs = num_nbrs
        self._directed = directed

        if len(seed_nodes_keys) != len(seed_times_keys):
            raise ValueError(
                f'len(seed_nodes_keys) ({len(seed_nodes_keys)}) '
                f'!= len(seed_times_keys) ({len(seed_times_keys)})\n'
                f'seed_nodes_keys={seed_nodes_keys}, '
                f'seed_times_keys={seed_times_keys}'
            )
        self._seed_nodes_keys = seed_nodes_keys
        self._seed_times_keys = seed_times_keys
        logger.debug(
            'Seed nodes keys: %s, Seed times keys: %s',
            self._seed_nodes_keys,
            self._seed_times_keys,
        )
        # Ensures the "seed attribute is None" warning is emitted only once.
        self._warned_seed_None = False
        self._id = id
        self.seed_keys = seed_nodes_keys
        self.__post_init__()

    @property
    def num_nbrs(self) -> List[int]:
        """Number of neighbors sampled at each hop."""
        return self._num_nbrs

    def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch:
        """Sample neighbors hop by hop and attach the per-hop lists to ``batch``.

        Hop 0 is seeded from the configured seed keys; each later hop is seeded
        with the flattened neighbors and edge times of the previous hop. When
        no seed nodes are available, every hop receives empty tensors.
        """
        batch_seed_nids, batch_seed_times = [], []
        batch_nbr_nids, batch_nbr_edge_time = [], []
        batch_nbr_edge_x = []

        def _append_empty_hop() -> None:
            batch_seed_nids.append(torch.empty(0, dtype=torch.int32))
            batch_seed_times.append(torch.empty(0, dtype=torch.int64))
            batch_nbr_nids.append(torch.empty(0, dtype=torch.int32))
            batch_nbr_edge_time.append(torch.empty(0, dtype=torch.int64))
            batch_nbr_edge_x.append(
                torch.empty(0, dg.edge_x_dim).float()  # type: ignore
            )

        seed_nodes, seed_times, seed_node_nbr_mask = self._get_seed_tensors(batch)
        if not seed_nodes.numel():
            logger.debug('No seed_nodes found, appending empty hop information')
            for _ in self.num_nbrs:
                _append_empty_hop()

        else:
            for hop, num_nbrs in enumerate(self.num_nbrs):
                if hop > 0:
                    seed_nodes = batch_nbr_nids[hop - 1].flatten()
                    seed_times = batch_nbr_edge_time[hop - 1].flatten()

                # TODO: Storage needs to use the right device

                # We slice up to (min edge time of this batch) - 1 so that we
                # only consider neighbor events that occurred strictly before
                # this batch.
                logger.debug(
                    'Getting uniform nbrs for hop %d with %d seed nodes',
                    hop,
                    seed_nodes.numel(),
                )
                nbr_nids, nbr_edge_time, nbr_edge_x = dg._storage.get_nbrs(
                    seed_nodes,
                    num_nbrs=num_nbrs,
                    slice=DGSliceTracker(end_time=int(batch.edge_time.min()) - 1),
                    directed=self._directed,
                )

                batch_seed_nids.append(seed_nodes)
                batch_seed_times.append(seed_times)
                batch_nbr_nids.append(nbr_nids)
                batch_nbr_edge_time.append(nbr_edge_time)
                batch_nbr_edge_x.append(nbr_edge_x)

        self.add_batch_attribute(batch, 'seed_nids', batch_seed_nids)
        self.add_batch_attribute(batch, 'seed_times', batch_seed_times)
        self.add_batch_attribute(batch, 'nbr_nids', batch_nbr_nids)
        self.add_batch_attribute(batch, 'nbr_edge_time', batch_nbr_edge_time)
        self.add_batch_attribute(batch, 'nbr_edge_x', batch_nbr_edge_x)
        self.add_batch_attribute(batch, 'seed_node_nbr_mask', seed_node_nbr_mask)

        return batch

    def _get_seed_tensors(
        self, batch: DGBatch
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Collect and validate the seed node/time tensors named by the seed keys.

        Returns:
            A tuple ``(seed_nodes, seed_times, seed_node_mask)`` where the first
            two are the concatenation of all usable seed tensors (empty tensors
            when none are usable) and ``seed_node_mask`` maps each used node key
            to the index range it occupies inside ``seed_nodes``.

        Raises:
            ValueError: If a seed attribute is missing from the batch, is a
                non-None non-Tensor, is not 1-D, or contains negative values.
        """
        device = batch.edge_src.device
        seeds, seed_times = [], []
        seed_node_mask = dict()

        # Bounds checks
        # TODO: Infer self._num_nodes from underlying graph
        self._num_nodes = float('inf')

        offset = 0
        for node_attr, time_attr in zip(self._seed_nodes_keys, self._seed_times_keys):
            missing = [
                attr for attr in (node_attr, time_attr) if not hasattr(batch, attr)
            ]
            if missing:
                raise ValueError(f'Missing seed attributes {missing} on batch')

            seed = getattr(batch, node_attr)
            time = getattr(batch, time_attr)

            skip_pair = False
            for name, tensor in [(node_attr, seed), (time_attr, time)]:
                # We recover from tensor = None, since the current batch could just
                # be missing certain attributes (e.g. dynamic node events), but for
                # non-Tensor and non-None attrs we explicitly raise
                if tensor is None:
                    logger.debug(
                        'Seed attribute %s is None on this batch, skipping', name
                    )
                    if not self._warned_seed_None:
                        warnings.warn(
                            f'Seed attribute {name} is None on this batch, skipping this batch. '
                            'Future occurrences will also be skipped but the warning will be suppressed',
                            UserWarning,
                        )
                        self._warned_seed_None = True
                    skip_pair = True
                    break
                if not isinstance(tensor, torch.Tensor):
                    raise ValueError(f'{name} must be a Tensor, got {type(tensor)}')
                if tensor.ndim != 1:
                    raise ValueError(f'{name} must be 1-D, got shape {tensor.shape}')

                if name == node_attr:
                    if (tensor < 0).any() or (tensor >= self._num_nodes).any():
                        raise ValueError(
                            f'Seed nodes in {name} must satisfy 0 <= x < {self._num_nodes}, '
                            f'got values in range [{tensor.min().item()}, {tensor.max().item()}]'
                        )
                elif name == time_attr:
                    if (tensor < 0).any():
                        raise ValueError(
                            f'Seed times in {name} must be >= 0, got min value: {tensor.min().item()}'
                        )

            if skip_pair:
                # Skip the node/time pair atomically: recording the nodes
                # without their times would leave seed_nodes and seed_times
                # misaligned.
                continue

            seeds.append(seed.to(device))
            seed_times.append(time.to(device))
            num_seed_nodes = seed.shape[0]
            seed_node_mask[node_attr] = torch.arange(
                offset, offset + num_seed_nodes, device=device
            )
            offset += num_seed_nodes

        if seeds and seed_times:
            seed_nodes, seed_times = torch.cat(seeds), torch.cat(seed_times)  # type: ignore
        else:
            seed_nodes = torch.empty(0, dtype=torch.int32, device=device)
            seed_times = torch.empty(0, dtype=torch.int64, device=device)  # type: ignore
        return seed_nodes, seed_times, seed_node_mask  # type: ignore
+ RandomNegativeEdgeSamplerHook, TGBNegativeEdgeSamplerHook, hook_manager, ) @@ -62,7 +62,8 @@ def build_tgb_link_pred(dataset_name: str, train_dg: DGraph) -> HookManager: dst = train_dg.edge_dst hm = HookManager(keys=['train', 'val', 'test']) hm.register( - 'train', NegativeEdgeSamplerHook(low=int(dst.min()), high=int(dst.max())) + 'train', + RandomNegativeEdgeSamplerHook(low=int(dst.min()), high=int(dst.max())), ) hm.register('val', TGBNegativeEdgeSamplerHook(dataset_name, split_mode='val')) hm.register('test', TGBNegativeEdgeSamplerHook(dataset_name, split_mode='test'))