diff --git a/.gitignore b/.gitignore index d6cb73df..bcef43ca 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,7 @@ ipython_config.py # Remove previous ipynb_checkpoints # git rm -r .ipynb_checkpoints/ - +outputs/ ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/examples/linkproppred/tgat.py b/examples/linkproppred/tgat.py index 8d1ac197..3cb8c48e 100644 --- a/examples/linkproppred/tgat.py +++ b/examples/linkproppred/tgat.py @@ -1,4 +1,6 @@ import argparse +import os +import random import numpy as np import torch @@ -50,6 +52,18 @@ parser.add_argument( '--log-file-path', type=str, default=None, help='Optional path to write logs' ) +parser.add_argument( + '--checkpoint-dir', + type=str, + default='outputs/checkpoints', + help='Directory to save checkpoints', +) +parser.add_argument( + '--resume', + type=str, + default=None, + help='Path to a specific checkpoint to resume from', +) args = parser.parse_args() enable_logging(log_file_path=args.log_file_path) @@ -140,15 +154,19 @@ def train( encoder: nn.Module, decoder: nn.Module, opt: torch.optim.Optimizer, + epoch: int, + hm: object, + nbr_hook: object, + start_batch: int = -1, ) -> float: encoder.train() decoder.train() total_loss = 0 static_node_x = loader.dgraph.static_node_x - for batch in tqdm(loader): + for local_idx, batch in enumerate(tqdm(loader)): + batch_idx = local_idx + start_batch + 1 opt.zero_grad() - z = encoder(batch, static_node_x) z_src, z_dst, z_neg = torch.chunk(z, 3) @@ -160,6 +178,13 @@ def train( loss.backward() opt.step() total_loss += float(loss) + + print( + f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}' + ) + + if batch_idx % 100 == 0: + save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm) return total_loss @@ -235,7 +260,6 @@ def eval( train_key, val_key, test_key = hm.keys hm.register_shared(nbr_hook) -train_loader = DGDataLoader(train_dg, args.bsize, hook_manager=hm) val_loader = 
def save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm):
    """Persist everything needed to resume training after a given batch.

    The checkpoint is written to ``args.checkpoint_dir`` (created on demand)
    as ``ckpt_e{epoch}_b{batch_idx}.pt``. It stores the encoder, decoder,
    optimizer and hook-manager state dicts, the torch / NumPy / Python RNG
    states, and the module-level ``best_val``.

    Args:
        epoch: Epoch number recorded in the checkpoint and its file name.
        batch_idx: Batch index recorded in the checkpoint and its file name.
        encoder: Module whose ``state_dict()`` is saved under ``'encoder'``.
        decoder: Module whose ``state_dict()`` is saved under ``'decoder'``.
        opt: Optimizer whose ``state_dict()`` is saved under ``'opt'``.
        hm: Object whose ``state_dict()`` is saved under ``'hm'``.
    """
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    path = os.path.join(args.checkpoint_dir, f'ckpt_e{epoch}_b{batch_idx}.pt')
    payload = {
        'epoch': epoch,
        'batch_idx': batch_idx,
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'opt': opt.state_dict(),
        'hm': hm.state_dict(),
        'rng_torch': torch.get_rng_state(),
        'rng_numpy': np.random.get_state(),
        'rng_python': random.getstate(),
        'best_val': best_val,
    }
    # Write to a temporary sibling and rename it into place, so an interrupted
    # save never leaves a truncated ``.pt`` file behind. The ``.tmp`` suffix
    # is ignored by the ``.pt`` filter in find_latest_checkpoint.
    tmp_path = f'{path}.tmp'
    torch.save(payload, tmp_path)
    os.replace(tmp_path, path)


def load_checkpoint(path, encoder, decoder, opt, hm):
    """Restore training state from a checkpoint written by save_checkpoint.

    Args:
        path: Path of the checkpoint file to load.
        encoder: Module that receives ``ckpt['encoder']``.
        decoder: Module that receives ``ckpt['decoder']``.
        opt: Optimizer that receives ``ckpt['opt']``; a ``ValueError`` or
            ``RuntimeError`` while loading it is reported and skipped.
        hm: Object that receives ``ckpt['hm']`` via ``load_state_dict``.

    Returns:
        Tuple ``(epoch, batch_idx, best_val)``; ``best_val`` falls back to
        ``0.0`` when the checkpoint does not contain it.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects (needed for
    # the NumPy / Python RNG states). Only load checkpoints you trust.
    ckpt = torch.load(path, weights_only=False)
    encoder.load_state_dict(ckpt['encoder'])
    decoder.load_state_dict(ckpt['decoder'])
    try:
        opt.load_state_dict(ckpt['opt'])
    except (ValueError, RuntimeError) as e:
        print(
            f'Warning: skipping optimizer state (shape mismatch — delete outputs/checkpoints/ and rerun): {e}'
        )
    hm.load_state_dict(ckpt['hm'])
    torch.set_rng_state(ckpt['rng_torch'])
    np.random.set_state(ckpt['rng_numpy'])
    # save_checkpoint also records Python's ``random`` state; restore it so
    # all three RNG streams continue from where the checkpoint was taken.
    if 'rng_python' in ckpt:
        random.setstate(ckpt['rng_python'])
    return ckpt['epoch'], ckpt['batch_idx'], ckpt.get('best_val', 0.0)


def find_latest_checkpoint(directory):
    """Return the most recently modified ``*.pt`` file in ``directory``.

    Returns:
        The path of the ``.pt`` file with the greatest modification time, or
        ``None`` when ``directory`` is not a directory or holds no ``.pt`` file.
    """
    if not os.path.isdir(directory):
        return None
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.pt')
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
@pytest.fixture
def dg():
    """Build a three-edge temporal graph with random edge and node features."""
    raw = DGData.from_raw(
        edge_time=torch.LongTensor([1, 1, 2]),
        edge_index=torch.IntTensor([[1, 2], [1, 2], [2, 3]]),
        edge_x=torch.rand(3, 4),
        node_x_time=torch.LongTensor([5, 5, 6]),
        node_x_nids=torch.IntTensor([2, 2, 3]),
        node_x=torch.rand(3, 3),
    )
    return DGraph(raw)


@pytest.fixture
def recency_hook(dg):
    """RecencyNeighborHook sized for the ``dg`` fixture, seeded on both endpoints."""
    return RecencyNeighborHook(
        num_nbrs=[2],
        num_nodes=dg.num_nodes,
        seed_nodes_keys=['edge_src', 'edge_dst'],
        seed_times_keys=['edge_time', 'edge_time'],
    )


class _CounterHook(StatefulHook):
    """Minimal stateful hook that counts how many times it has been called."""

    produces = set()
    requires = set()

    def __init__(self):
        self.counter = 0

    def __call__(self, dg, batch):
        self.counter += 1
        return batch

    def reset_state(self):
        self.counter = 0

    def state_dict(self):
        return {'counter': self.counter}

    def load_state_dict(self, state):
        self.counter = state['counter']


def _advance_two_batches(loader):
    """Pull the first two batches from ``loader`` and stop."""
    for step, _batch in enumerate(loader):
        if step == 1:
            break


def test_stateful_hook_state_dict_raises_if_not_implemented():
    class ForgottenHook(StatefulHook):
        pass

    hook = ForgottenHook()
    with pytest.raises(NotImplementedError):
        hook.state_dict()


def test_stateful_hook_load_state_dict_raises_if_not_implemented():
    class ForgottenHook(StatefulHook):
        pass

    hook = ForgottenHook()
    with pytest.raises(NotImplementedError):
        hook.load_state_dict({})


def test_recency_hook_state_dict_contains_required_keys(dg, recency_hook):
    manager = HookManager(keys=['train'])
    manager.register_shared(recency_hook)
    loader = DGDataLoader(dg, batch_size=1, hook_manager=manager)

    with manager.activate('train'):
        _advance_two_batches(loader)

    snapshot = recency_hook.state_dict()

    for required in (
        '_nbr_ids',
        '_nbr_times',
        '_nbr_feats',
        '_write_pos',
        '_edge_x_dim',
        '_need_to_initialize_nbr_feats',
    ):
        assert required in snapshot


def test_recency_hook_state_dict_load_state_dict_roundtrip(dg, recency_hook):
    manager = HookManager(keys=['train'])
    manager.register_shared(recency_hook)
    loader = DGDataLoader(dg, batch_size=1, hook_manager=manager)

    with manager.activate('train'):
        _advance_two_batches(loader)

    expected_write_pos = recency_hook._write_pos.clone()
    expected_nbr_ids = recency_hook._nbr_ids.clone()

    snapshot = recency_hook.state_dict()

    # Resetting must actually change the buffers, otherwise the restore
    # assertions below would pass trivially.
    recency_hook.reset_state()
    assert not torch.equal(recency_hook._write_pos, expected_write_pos)

    recency_hook.load_state_dict(snapshot)

    assert torch.equal(recency_hook._write_pos, expected_write_pos)
    assert torch.equal(recency_hook._nbr_ids, expected_nbr_ids)


def test_hook_manager_state_dict_saves_stateful_hook():
    manager = HookManager(keys=['train'])
    counter_hook = _CounterHook()
    counter_hook.counter = 42
    manager.register_shared(counter_hook)

    saved = manager.state_dict('train')

    assert len(saved) >= 1
    first_entry = next(iter(saved.values()))
    assert first_entry['counter'] == 42


def test_hook_manager_state_dict_no_duplicate_saves():
    manager = HookManager(keys=['train'])
    manager.register_shared(_CounterHook())

    saved = manager.state_dict('train')

    assert len(saved) == 1


def test_hook_manager_load_state_dict_restores_hook():
    manager = HookManager(keys=['train'])
    counter_hook = _CounterHook()
    manager.register_shared(counter_hook)

    counter_hook.counter = 99
    saved = manager.state_dict('train')

    counter_hook.counter = 0
    manager.load_state_dict(saved, 'train')

    assert counter_hook.counter == 99


def test_skip_batches_reduces_yielded_count(dg):
    num_full = len(list(DGDataLoader(dg, batch_size=1)))
    num_after_skip = len(list(DGDataLoader(dg, batch_size=1, skip_batches=1)))

    assert num_after_skip == num_full - 1


def test_skip_batches_hook_not_executed(dg):
    manager = HookManager(keys=['train'])
    counter_hook = _CounterHook()
    manager.register('train', counter_hook)

    total = len(list(DGDataLoader(dg, batch_size=1)))

    skip = 1
    loader = DGDataLoader(dg, batch_size=1, hook_manager=manager, skip_batches=skip)

    with manager.activate('train'):
        list(loader)

    # The hook must run only for batches that were not skipped.
    assert counter_hook.counter == total - skip
def _get_iterator(self) -> Any:
    """Create the underlying iterator without disturbing the global torch RNG.

    Building PyTorch's base iterator draws from the global RNG for an internal
    base seed. The RNG state is snapshotted before and restored after the
    iterator is built, so a resumed run sees the same random stream as an
    uninterrupted one.
    """
    snapshot = torch.get_rng_state()
    base_iterator = super()._get_iterator()
    torch.set_rng_state(snapshot)
    return base_iterator


def __iter__(self) -> Iterator[DGBatch]:  # type: ignore[override]
    """Yield batches, dropping the first ``self._skip_batches`` of them.

    ``self._batch_idx`` is reset to 0 at the start of every pass and counts
    the batches received so far, skipped ones included.
    """
    # NOTE(review): ``__call__`` compares ``self._batch_idx`` with
    # ``self._skip_batches`` to decide whether to run hooks. That presumably
    # relies on batches being collated in this process (num_workers == 0);
    # confirm before using worker processes.
    self._batch_idx = 0
    for batch in super().__iter__():
        is_skipped = self._batch_idx < self._skip_batches
        self._batch_idx += 1
        if not is_skipped:
            yield batch
+ def state_dict(self) -> dict: ... + + def load_state_dict(self, state: dict) -> None: ... + class StatelessHook: """Base class for hooks without internal state.""" @@ -31,6 +35,12 @@ def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: def reset_state(self) -> None: pass + def state_dict(self) -> dict: + return {} + + def load_state_dict(self, state: dict) -> None: + pass + class StatefulHook: """Base class for hooks that maintain internal state.""" @@ -38,3 +48,19 @@ class StatefulHook: requires: Set[str] = set() produces: Set[str] = set() has_state: bool = True + + def state_dict(self) -> dict: + """Return the hook's state as a serializable dict.""" + raise NotImplementedError( + f'{self.__class__.__name__} has has_state=True ' + f'but did not implement state_dict(). ' + f'implement state_dict() to support checkpointing.' + ) + + def load_state_dict(self, state: dict) -> None: + """Restore the hook's state from a dict returned by state_dict().""" + raise NotImplementedError( + f'{self.__class__.__name__} has has_state=True ' + f'but did not implement load_state_dict(). ' + f'implement load_state_dict() to support checkpointing.' + ) diff --git a/tgm/hooks/hook_manager.py b/tgm/hooks/hook_manager.py index ef54d0a1..e4b420bd 100644 --- a/tgm/hooks/hook_manager.py +++ b/tgm/hooks/hook_manager.py @@ -168,6 +168,60 @@ def reset_state(self, key: str | None = None) -> None: logger.debug('Resetting state for keyed hook: %s', h.__class__.__name__) h.reset_state() + def state_dict(self, key: str | None = None) -> Dict[str, Any]: + """Returns the state of all stateful hooks for the given key as a single dict. + + Calls resolve_hooks(key) first to ensure the merged execution list is up to date. + Uses id() to avoid saving the same hook instance twice (e.g. shared hooks that also + appear in _key_to_hooks[key] after resolution). + + Args: + key (str | None): The split key whose hooks to save. Defaults to the first + registered key (typically 'train'). 
+ + Returns: + Dict[str, Any]: A dict mapping hook keys to their state dicts. + """ + key = key if key is not None else self._registered_key[0] + self._ensure_valid_key(key) + if self._dirty[key]: + self.resolve_hooks(key) + + states: Dict[str, Any] = {} + seen: set = set() + for i, hook in enumerate(self._key_to_hooks[key]): + if id(hook) in seen: + continue + seen.add(id(hook)) + if hook.has_state: + states[f'{i}_{hook.__class__.__name__}'] = hook.state_dict() + return states + + def load_state_dict(self, states: Dict[str, Any], key: str | None = None) -> None: + """Restores the state of all stateful hooks from a dict produced by state_dict(). + + Calls resolve_hooks(key) first to ensure the merged execution list is up to date. + + Args: + states (Dict[str, Any]): A dict mapping hook keys to their state dicts. + key (str | None): The split key whose hooks to restore. Defaults to the first + registered key (typically 'train'). + """ + key = key if key is not None else self._registered_key[0] + self._ensure_valid_key(key) + if self._dirty[key]: + self.resolve_hooks(key) + + seen: set = set() + for i, hook in enumerate(self._key_to_hooks[key]): + if id(hook) in seen: + continue + seen.add(id(hook)) + if hook.has_state: + hook_key = f'{i}_{hook.__class__.__name__}' + if hook_key in states: + hook.load_state_dict(states[hook_key]) + def resolve_hooks(self, key: str | None = None) -> None: """Resolves hook execution order by topologically sorting them based on dependencies. 
def state_dict(self) -> dict:
    """Return a CPU snapshot of the neighbor buffers and bookkeeping flags.

    Every tensor is cloned, so the returned dict does not share storage with
    the hook. ``'_nbr_feats'`` is ``None`` when the feature buffer is unset.
    """
    nbr_feats: torch.Tensor | None = None
    if self._nbr_feats is not None:
        nbr_feats = self._nbr_feats.cpu().clone()
    return {
        '_nbr_ids': self._nbr_ids.cpu().clone(),
        '_nbr_times': self._nbr_times.cpu().clone(),
        '_nbr_feats': nbr_feats,
        '_write_pos': self._write_pos.cpu().clone(),
        '_edge_x_dim': self._edge_x_dim,
        '_need_to_initialize_nbr_feats': self._need_to_initialize_nbr_feats,
    }


def load_state_dict(self, state: dict) -> None:
    """Restore the hook from a dict produced by ``state_dict()``.

    Tensors are copied onto ``self._device``. ``Tensor.to`` returns the same
    tensor when it is already on the target device, so ``copy=True`` is
    passed: the hook's buffers then never alias ``state``, and ``state`` stays
    reusable even if the hook later modifies its buffers in place.
    """
    device = self._device
    self._nbr_ids = state['_nbr_ids'].to(device, copy=True)
    self._nbr_times = state['_nbr_times'].to(device, copy=True)
    saved_feats = state['_nbr_feats']
    self._nbr_feats = (
        saved_feats.to(device, copy=True) if saved_feats is not None else None
    )
    self._write_pos = state['_write_pos'].to(device, copy=True)
    self._edge_x_dim = state['_edge_x_dim']
    self._need_to_initialize_nbr_feats = state['_need_to_initialize_nbr_feats']
+ '_appearances': self._appearances, + '_total_timesteps': self._total_timesteps, + '_node_timesteps': self._node_timesteps, + '_all_neighbors': self._all_neighbors, + '_engagement_sum': self._engagement_sum, + '_seen_edges': self._seen_edges, + '_tracked_mask': self._tracked_mask.cpu().clone(), + } + + def load_state_dict(self, state: dict) -> None: + self._first_seen = state['_first_seen'] + self._last_seen = state['_last_seen'] + self._appearances = state['_appearances'] + self._total_timesteps = state['_total_timesteps'] + self._node_timesteps = state['_node_timesteps'] + self._all_neighbors = state['_all_neighbors'] + self._engagement_sum = state['_engagement_sum'] + self._seen_edges = state['_seen_edges'] + self._tracked_mask = state['_tracked_mask'].to(self._device) + def reset_state(self) -> None: """Reset internal state.""" self._first_seen.clear() @@ -224,8 +249,14 @@ def reset_state(self) -> None: self._engagement_sum.clear() self._seen_edges.clear() + def _move_to_device_if_needed(self, device: torch.device) -> None: + if device != self._device: + self._device = device + self._tracked_mask = self._tracked_mask.to(device) + def __call__(self, dg: DGraph, batch: DGBatch) -> DGBatch: """Compute node-centric statistics for tracked nodes in the batch.""" + self._move_to_device_if_needed(dg.device) # Get current timestamp current_time = self._get_batch_timestamp(batch) diff --git a/tgm/hooks/node_tracks.py b/tgm/hooks/node_tracks.py index ff6a003f..f80adf1d 100644 --- a/tgm/hooks/node_tracks.py +++ b/tgm/hooks/node_tracks.py @@ -29,6 +29,12 @@ def __init__(self, num_nodes: int) -> None: self._seen_mask = torch.zeros(num_nodes, dtype=torch.bool) self._device = torch.device('cpu') + def state_dict(self) -> dict: + return {'_seen_mask': self._seen_mask.cpu().clone()} + + def load_state_dict(self, state: dict) -> None: + self._seen_mask = state['_seen_mask'].to(self._device) + def reset_state(self) -> None: logger.debug('Reset state of the hook') 
self._seen_mask.fill_(False)