-
Notifications
You must be signed in to change notification settings - Fork 20
Saving loading hook state #400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7927840
93490a2
cbbf899
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should add the logic from
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes I agree |
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,6 @@ | ||||||||
| import argparse | ||||||||
| import os | ||||||||
| import random | ||||||||
|
|
||||||||
| import numpy as np | ||||||||
| import torch | ||||||||
|
|
@@ -50,6 +52,18 @@ | |||||||
| parser.add_argument( | ||||||||
| '--log-file-path', type=str, default=None, help='Optional path to write logs' | ||||||||
| ) | ||||||||
| parser.add_argument( | ||||||||
| '--checkpoint-dir', | ||||||||
| type=str, | ||||||||
| default='outputs/checkpoints', | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
|
||||||||
| help='Directory to save checkpoints', | ||||||||
| ) | ||||||||
| parser.add_argument( | ||||||||
| '--resume', | ||||||||
| type=str, | ||||||||
| default=None, | ||||||||
| help='Path to a specific checkpoint to resume from', | ||||||||
| ) | ||||||||
|
|
||||||||
| args = parser.parse_args() | ||||||||
| enable_logging(log_file_path=args.log_file_path) | ||||||||
|
|
@@ -140,15 +154,19 @@ def train( | |||||||
| encoder: nn.Module, | ||||||||
| decoder: nn.Module, | ||||||||
| opt: torch.optim.Optimizer, | ||||||||
| epoch: int, | ||||||||
| hm: object, | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| nbr_hook: object, | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| start_batch: int = -1, | ||||||||
| ) -> float: | ||||||||
| encoder.train() | ||||||||
| decoder.train() | ||||||||
| total_loss = 0 | ||||||||
| static_node_x = loader.dgraph.static_node_x | ||||||||
|
|
||||||||
| for batch in tqdm(loader): | ||||||||
| for local_idx, batch in enumerate(tqdm(loader)): | ||||||||
| batch_idx = local_idx + start_batch + 1 | ||||||||
| opt.zero_grad() | ||||||||
|
|
||||||||
| z = encoder(batch, static_node_x) | ||||||||
| z_src, z_dst, z_neg = torch.chunk(z, 3) | ||||||||
|
|
||||||||
|
|
@@ -160,6 +178,13 @@ def train( | |||||||
| loss.backward() | ||||||||
| opt.step() | ||||||||
| total_loss += float(loss) | ||||||||
|
|
||||||||
| print( | ||||||||
| f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}' | ||||||||
| ) | ||||||||
|
Comment on lines
+182
to
+184
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
|
||||||||
| if batch_idx % 100 == 0: | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be nice to have a variable called `checkpoint_interval` instead of the hard-coded `100`. I think another flag is also needed to indicate whether we want to checkpoint the model as well, such as: `if batch_idx % checkpoint_interval == 0 and args.train_checkpoint:`
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we just checkpoint only after each epoch? That said, checkpointing within an epoch also makes sense for larger datasets. |
||||||||
| save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm) | ||||||||
|
ntgbaoo marked this conversation as resolved.
|
||||||||
| return total_loss | ||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -235,7 +260,6 @@ def eval( | |||||||
| train_key, val_key, test_key = hm.keys | ||||||||
| hm.register_shared(nbr_hook) | ||||||||
|
|
||||||||
| train_loader = DGDataLoader(train_dg, args.bsize, hook_manager=hm) | ||||||||
| val_loader = DGDataLoader(val_dg, args.bsize, hook_manager=hm) | ||||||||
| test_loader = DGDataLoader(test_dg, args.bsize, hook_manager=hm) | ||||||||
|
|
||||||||
|
|
@@ -252,14 +276,87 @@ def eval( | |||||||
| args.device | ||||||||
| ) | ||||||||
| opt = torch.optim.Adam( | ||||||||
| set(encoder.parameters()) | set(decoder.parameters()), lr=float(args.lr) | ||||||||
| list(encoder.parameters()) + list(decoder.parameters()), lr=float(args.lr) | ||||||||
| ) | ||||||||
|
|
||||||||
| best_val = 0.0 | ||||||||
|
|
||||||||
| for epoch in range(1, args.epochs + 1): | ||||||||
|
|
||||||||
| def save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm): | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add a type annotation for each input parameter of this function? It would also be nice to have a docstring describing the purpose of this function. |
||||||||
| os.makedirs(args.checkpoint_dir, exist_ok=True) | ||||||||
| path = os.path.join(args.checkpoint_dir, f'ckpt_e{epoch}_b{batch_idx}.pt') | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| torch.save( | ||||||||
| { | ||||||||
| 'epoch': epoch, | ||||||||
| 'batch_idx': batch_idx, | ||||||||
| 'encoder': encoder.state_dict(), | ||||||||
| 'decoder': decoder.state_dict(), | ||||||||
| 'opt': opt.state_dict(), | ||||||||
| 'hm': hm.state_dict(), | ||||||||
| 'rng_torch': torch.get_rng_state(), | ||||||||
| 'rng_numpy': np.random.get_state(), | ||||||||
| 'rng_python': random.getstate(), | ||||||||
| 'best_val': best_val, | ||||||||
| }, | ||||||||
| path, | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| def load_checkpoint(path, encoder, decoder, opt, hm): | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same for this function: a type annotation for each parameter and a function docstring are also necessary here.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We potentially need to make this verbose by default; the user will want to know when and where the checkpoint was resumed from. |
||||||||
| ckpt = torch.load(path, weights_only=False) | ||||||||
| encoder.load_state_dict(ckpt['encoder']) | ||||||||
| decoder.load_state_dict(ckpt['decoder']) | ||||||||
| try: | ||||||||
| opt.load_state_dict(ckpt['opt']) | ||||||||
| except (ValueError, RuntimeError) as e: | ||||||||
| print( | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| f'Warning: skipping optimizer state (shape mismatch — delete outputs/checkpoints/ and rerun): {e}' | ||||||||
| ) | ||||||||
| hm.load_state_dict(ckpt['hm']) | ||||||||
| torch.set_rng_state(ckpt['rng_torch']) | ||||||||
| np.random.set_state(ckpt['rng_numpy']) | ||||||||
| return ckpt['epoch'], ckpt['batch_idx'], ckpt.get('best_val', 0.0) | ||||||||
|
|
||||||||
|
|
||||||||
| def find_latest_checkpoint(directory): | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be a helper function in util, as Bao suggested. |
||||||||
| if not os.path.isdir(directory): | ||||||||
| return None | ||||||||
| pts = [ | ||||||||
| os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.pt') | ||||||||
| ] | ||||||||
| if not pts: | ||||||||
| return None | ||||||||
| return max(pts, key=os.path.getmtime) | ||||||||
|
|
||||||||
|
|
||||||||
| start_epoch = 1 | ||||||||
| start_batch = -1 | ||||||||
| ckpt_path = ( | ||||||||
| find_latest_checkpoint(args.resume) | ||||||||
| if args.resume and os.path.isdir(args.resume) | ||||||||
| else args.resume | ||||||||
| ) | ||||||||
| if ckpt_path is not None: | ||||||||
| start_epoch, start_batch, best_val = load_checkpoint( | ||||||||
| ckpt_path, encoder, decoder, opt, hm | ||||||||
| ) | ||||||||
|
|
||||||||
| for epoch in range(start_epoch, args.epochs + 1): | ||||||||
| skip = (start_batch + 1) if epoch == start_epoch else 0 | ||||||||
| train_loader = DGDataLoader( | ||||||||
| train_dg, args.bsize, hook_manager=hm, skip_batches=skip | ||||||||
| ) | ||||||||
| with hm.activate(train_key): | ||||||||
| loss = train(train_loader, encoder, decoder, opt) | ||||||||
| loss = train( | ||||||||
| train_loader, | ||||||||
| encoder, | ||||||||
| decoder, | ||||||||
| opt, | ||||||||
| epoch, | ||||||||
| hm, | ||||||||
| nbr_hook, | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| start_batch if epoch == start_epoch else -1, | ||||||||
| ) | ||||||||
|
|
||||||||
| with hm.activate(val_key): | ||||||||
| val_mrr = eval(val_loader, encoder, decoder, evaluator) | ||||||||
|
|
@@ -268,6 +365,7 @@ def eval( | |||||||
|
|
||||||||
| if val_mrr > best_val: | ||||||||
| best_val = val_mrr | ||||||||
| log_metric('Best Validation', best_val, epoch=epoch) | ||||||||
| with hm.activate(test_key): | ||||||||
| test_mrr = eval(test_loader, encoder, decoder, evaluator) | ||||||||
| log_metric(f'Test {METRIC_TGB_LINKPROPPRED}', test_mrr, epoch=args.epochs) | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| import pytest | ||
| import torch | ||
|
|
||
| from tgm import DGraph | ||
| from tgm.data import DGData, DGDataLoader | ||
| from tgm.hooks import HookManager, RecencyNeighborHook | ||
| from tgm.hooks.base import StatefulHook | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def dg(): | ||
| edge_index = torch.IntTensor( | ||
| [ | ||
| [1, 2], | ||
| [1, 2], | ||
| [2, 3], | ||
| ] | ||
| ) | ||
| edge_time = torch.LongTensor([1, 1, 2]) | ||
| edge_x = torch.rand(3, 4) | ||
|
|
||
| node_x_time = torch.LongTensor([5, 5, 6]) | ||
| node_x_nids = torch.IntTensor([2, 2, 3]) | ||
| node_x = torch.rand(3, 3) | ||
|
|
||
| data = DGData.from_raw( | ||
| edge_time=edge_time, | ||
| edge_index=edge_index, | ||
| edge_x=edge_x, | ||
| node_x_time=node_x_time, | ||
| node_x_nids=node_x_nids, | ||
| node_x=node_x, | ||
| ) | ||
| return DGraph(data) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def recency_hook(dg): | ||
| return RecencyNeighborHook( | ||
| num_nbrs=[2], | ||
| num_nodes=dg.num_nodes, | ||
| seed_nodes_keys=['edge_src', 'edge_dst'], | ||
| seed_times_keys=['edge_time', 'edge_time'], | ||
| ) | ||
|
|
||
|
|
||
| class _CounterHook(StatefulHook): | ||
| produces = set() | ||
| requires = set() | ||
|
|
||
| def __init__(self): | ||
| self.counter = 0 | ||
|
|
||
| def __call__(self, dg, batch): | ||
| self.counter += 1 | ||
| return batch | ||
|
|
||
| def reset_state(self): | ||
| self.counter = 0 | ||
|
|
||
| def state_dict(self): | ||
| return {'counter': self.counter} | ||
|
|
||
| def load_state_dict(self, state): | ||
| self.counter = state['counter'] | ||
|
|
||
|
|
||
| def test_stateful_hook_state_dict_raises_if_not_implemented(): | ||
| class ForgottenHook(StatefulHook): | ||
| pass | ||
|
|
||
| with pytest.raises(NotImplementedError): | ||
| ForgottenHook().state_dict() | ||
|
|
||
|
|
||
| def test_stateful_hook_load_state_dict_raises_if_not_implemented(): | ||
| class ForgottenHook(StatefulHook): | ||
| pass | ||
|
|
||
| with pytest.raises(NotImplementedError): | ||
| ForgottenHook().load_state_dict({}) | ||
|
|
||
|
|
||
| def test_recency_hook_state_dict_contains_required_keys(dg, recency_hook): | ||
| hm = HookManager(keys=['train']) | ||
| hm.register_shared(recency_hook) | ||
| loader = DGDataLoader(dg, batch_size=1, hook_manager=hm) | ||
|
|
||
| with hm.activate('train'): | ||
| for i, _ in enumerate(loader): | ||
| if i == 1: | ||
| break | ||
|
|
||
| state = recency_hook.state_dict() | ||
|
|
||
| assert '_nbr_ids' in state | ||
| assert '_nbr_times' in state | ||
| assert '_nbr_feats' in state | ||
| assert '_write_pos' in state | ||
| assert '_edge_x_dim' in state | ||
| assert '_need_to_initialize_nbr_feats' in state | ||
|
|
||
|
|
||
| def test_recency_hook_state_dict_load_state_dict_roundtrip(dg, recency_hook): | ||
| hm = HookManager(keys=['train']) | ||
| hm.register_shared(recency_hook) | ||
| loader = DGDataLoader(dg, batch_size=1, hook_manager=hm) | ||
|
|
||
| with hm.activate('train'): | ||
| for i, _ in enumerate(loader): | ||
| if i == 1: | ||
| break | ||
|
|
||
| write_pos_before = recency_hook._write_pos.clone() | ||
| nbr_ids_before = recency_hook._nbr_ids.clone() | ||
|
|
||
| state = recency_hook.state_dict() | ||
|
|
||
| recency_hook.reset_state() | ||
| assert not torch.equal(recency_hook._write_pos, write_pos_before) | ||
|
|
||
| recency_hook.load_state_dict(state) | ||
|
|
||
| assert torch.equal(recency_hook._write_pos, write_pos_before) | ||
| assert torch.equal(recency_hook._nbr_ids, nbr_ids_before) | ||
|
|
||
|
|
||
| def test_hook_manager_state_dict_saves_stateful_hook(): | ||
| hm = HookManager(keys=['train']) | ||
| hook = _CounterHook() | ||
| hook.counter = 42 | ||
| hm.register_shared(hook) | ||
|
|
||
| states = hm.state_dict('train') | ||
|
|
||
| assert len(states) >= 1 | ||
| saved_counter = list(states.values())[0]['counter'] | ||
| assert saved_counter == 42 | ||
|
|
||
|
|
||
| def test_hook_manager_state_dict_no_duplicate_saves(): | ||
| hm = HookManager(keys=['train']) | ||
| hook = _CounterHook() | ||
| hm.register_shared(hook) | ||
|
|
||
| states = hm.state_dict('train') | ||
|
|
||
| assert len(states) == 1 | ||
|
|
||
|
|
||
| def test_hook_manager_load_state_dict_restores_hook(): | ||
| hm = HookManager(keys=['train']) | ||
| hook = _CounterHook() | ||
| hm.register_shared(hook) | ||
|
|
||
| hook.counter = 99 | ||
| states = hm.state_dict('train') | ||
|
|
||
| hook.counter = 0 | ||
| hm.load_state_dict(states, 'train') | ||
|
|
||
| assert hook.counter == 99 | ||
|
|
||
|
|
||
| def test_skip_batches_reduces_yielded_count(dg): | ||
| full_batches = list(DGDataLoader(dg, batch_size=1)) | ||
| skip_batches = list(DGDataLoader(dg, batch_size=1, skip_batches=1)) | ||
|
|
||
| assert len(skip_batches) == len(full_batches) - 1 | ||
|
|
||
|
|
||
| def test_skip_batches_hook_not_executed(dg): | ||
| hm = HookManager(keys=['train']) | ||
| hook = _CounterHook() | ||
| hm.register('train', hook) | ||
|
|
||
| total = len(list(DGDataLoader(dg, batch_size=1))) | ||
|
|
||
| skip = 1 | ||
| loader = DGDataLoader(dg, batch_size=1, hook_manager=hm, skip_batches=skip) | ||
|
|
||
| with hm.activate('train'): | ||
| list(loader) | ||
|
|
||
| assert hook.counter == total - skip |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit