Saving loading hook state #400
ntgbaoo
left a comment
Hey @NaziaHossain066
Thanks for the PR. Great contribution!
Overall, the PR looks pretty good. I left a couple of comments and raised some concerns for us to discuss.
I just approved CI workflows to run for this PR. It is worth checking codecov and adding some unit tests to cover the lines that are currently missing tests, such as:
state_dict and load_state_dict for node analytics and node tracks
| decoder: nn.Module, | ||
| opt: torch.optim.Optimizer, | ||
| epoch: int, | ||
| hm: object, |
Suggested change:
| hm: object, | |
| hm: HookManager, |
| opt: torch.optim.Optimizer, | ||
| epoch: int, | ||
| hm: object, | ||
| nbr_hook: object, |
Suggested change:
| nbr_hook: object, | |
| nbr_hook: StatefulHook, |
| print( | ||
| f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}' | ||
| ) |
Suggested change:
| print( | |
| f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}' | |
| ) |
| f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}' | ||
| ) | ||
|
|
||
| if batch_idx % 100 == 0: |
Would be nice to have a variable called checkpoint_interval from the parser instead of hardcoding 100.
I think another flag is needed to indicate whether we want to checkpoint the model at all, such as train_checkpoint. The condition would then be:
if batch_idx % checkpoint_interval == 0 and args.train_checkpoint:
Should we just checkpoint only after each epoch? That said, mid-epoch checkpointing also makes sense for larger datasets.
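A minimal sketch of how the two flags discussed above could be wired up, assuming an argparse-based parser. The flag names `--checkpoint-interval` and `--train-checkpoint` and the helper `should_checkpoint` are hypothetical and only follow the naming proposed in this thread; they are not part of the PR.

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flags; the names follow the suggestion in this thread.
parser.add_argument(
    '--checkpoint-interval',
    type=int,
    default=100,
    help='save a checkpoint every N batches',
)
parser.add_argument(
    '--train-checkpoint',
    action='store_true',
    help='enable checkpointing during training',
)


def should_checkpoint(batch_idx: int, args: argparse.Namespace) -> bool:
    """Return True when a checkpoint should be written for this batch."""
    return args.train_checkpoint and batch_idx % args.checkpoint_interval == 0


# Example: checkpointing enabled, every 50 batches.
args = parser.parse_args(['--train-checkpoint', '--checkpoint-interval', '50'])
```

With this shape, per-epoch-only checkpointing is the case where the flag is left off and the save happens at the end of the epoch loop instead.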
|
|
||
| for epoch in range(1, args.epochs + 1): | ||
|
|
||
| def save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm): |
Can we add a type for each input variable for this function? And would be nice to have a comment to describe the purpose of this method as well.
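For illustration, a typed and documented version might look roughly like the sketch below. It is not the PR's implementation: the `path` parameter and the `HasStateDict` protocol are assumptions added for the example, and `pickle` stands in for `torch.save` so the snippet runs without PyTorch. In the real example the annotations would more likely be `nn.Module`, `torch.optim.Optimizer`, and `HookManager`.

```python
import pickle
from pathlib import Path
from typing import Any, Dict, Protocol


class HasStateDict(Protocol):
    """Anything exposing state_dict(), e.g. a module, optimizer, or hook manager."""

    def state_dict(self) -> Dict[str, Any]: ...


def save_checkpoint(
    path: Path,  # hypothetical parameter, not in the original signature
    epoch: int,
    batch_idx: int,
    encoder: HasStateDict,
    decoder: HasStateDict,
    opt: HasStateDict,
    hm: HasStateDict,
) -> None:
    """Write model, optimizer, and hook state so training can resume mid-epoch."""
    ckpt = {
        'epoch': epoch,
        'batch_idx': batch_idx,
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'opt': opt.state_dict(),
        'hm': hm.state_dict(),
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    # pickle is a stand-in here; the real example would presumably use torch.save.
    with path.open('wb') as f:
        pickle.dump(ckpt, f)
```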
| # PyTorch's _BaseDataLoaderIter.__init__ consumes 2 global RNG samples for | ||
| # an internal base_seed (used only for worker processes). Save and restore | ||
| # around iterator creation so the global RNG state is unaffected, making | ||
| # checkpoint resume produce identical results to an uninterrupted run. |
Suggested change:
| # PyTorch's _BaseDataLoaderIter.__init__ consumes 2 global RNG samples for | |
| # an internal base_seed (used only for worker processes). Save and restore | |
| # around iterator creation so the global RNG state is unaffected, making | |
| # checkpoint resume produce identical results to an uninterrupted run. | |
| """PyTorch's _BaseDataLoaderIter.__init__ consumes 2 global RNG samples for | |
| an internal base_seed (used only for worker processes). Save and restore | |
| around iterator creation so the global RNG state is unaffected, making | |
| checkpoint resume produce identical results to an uninterrupted run. | |
| """ |
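The save-and-restore pattern this comment describes can be sketched with Python's standard `random` module as a stand-in for the global torch RNG. In PyTorch the corresponding calls would be `torch.get_rng_state()` and `torch.set_rng_state()`. The function names below are hypothetical and the iterator is a toy, not `DGDataLoader`.

```python
import random
from typing import Iterable, Iterator


def make_iterator(data: Iterable[int]) -> Iterator[int]:
    """Stand-in for iterator creation that draws from the global RNG as a side effect."""
    random.random()  # mimics the internal base_seed draw
    return iter(data)


def make_iterator_preserving_rng(data: Iterable[int]) -> Iterator[int]:
    """Create the iterator, then restore the global RNG state captured beforehand."""
    state = random.getstate()
    try:
        return make_iterator(data)
    finally:
        random.setstate(state)


random.seed(0)
expected = random.random()  # first draw of an undisturbed stream

random.seed(0)
make_iterator_preserving_rng([1, 2, 3])
# Iterator creation left the stream untouched, so the next draw matches.
assert random.random() == expected
```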
| parser.add_argument( | ||
| '--checkpoint-dir', | ||
| type=str, | ||
| default='outputs/checkpoints', |
nit:
| default='outputs/checkpoints', | |
| default='artifact/checkpoints', |
| # Remove previous ipynb_checkpoints | ||
| # git rm -r .ipynb_checkpoints/ | ||
|
|
||
| outputs/ |
nit
| outputs/ | |
| artifact/ |
|
|
||
| def reset_state(self) -> None: ... | ||
|
|
||
| def state_dict(self) -> dict: ... |
Suggested change:
| def state_dict(self) -> dict: ... | |
| @property | |
| def state_dict(self) -> dict: ... |
| '_all_neighbors': self._all_neighbors, | ||
| '_engagement_sum': self._engagement_sum, | ||
| '_seen_edges': self._seen_edges, | ||
| '_tracked_mask': self._tracked_mask.cpu().clone(), |
| return ckpt['epoch'], ckpt['batch_idx'], ckpt.get('best_val', 0.0) | ||
|
|
||
|
|
||
| def find_latest_checkpoint(directory): |
This should be a helper function in util, as Bao suggested.
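As a rough idea of what such a util helper could look like, here is a sketch that picks the checkpoint with the highest (epoch, batch) pair. The file naming scheme `ckpt_epoch{E}_batch{B}.pt` is an assumption made for the example; the PR's actual naming may differ.

```python
import re
from pathlib import Path
from typing import List, Optional, Tuple

# Hypothetical naming scheme: ckpt_epoch{E}_batch{B}.pt
_CKPT_RE = re.compile(r'ckpt_epoch(\d+)_batch(\d+)\.pt$')


def find_latest_checkpoint(directory: str) -> Optional[Path]:
    """Return the checkpoint with the highest (epoch, batch), or None if there is none."""
    root = Path(directory)
    if not root.is_dir():
        return None
    candidates: List[Tuple[int, int, Path]] = []
    for p in root.iterdir():
        m = _CKPT_RE.search(p.name)
        if m:
            candidates.append((int(m.group(1)), int(m.group(2)), p))
    if not candidates:
        return None
    # Compare numerically so that epoch 10 sorts after epoch 2.
    return max(candidates, key=lambda c: (c[0], c[1]))[2]
```

Parsing the numbers rather than sorting file names avoids the lexicographic pitfall where `epoch10` would sort before `epoch2`.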
| ) | ||
|
|
||
|
|
||
| def load_checkpoint(path, encoder, decoder, opt, hm): |
We potentially need to make this verbose by default; the user will want to know when and where the checkpoint was resumed from.
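One possible shape for this, sketched with the standard `logging` module: the helper `log_resume` and its `verbose` parameter are hypothetical, and the `epoch` and `batch_idx` keys are taken from the checkpoint dict returned in the diff above.

```python
import logging
from typing import Any, Dict

logger = logging.getLogger('checkpoint')


def log_resume(path: str, ckpt: Dict[str, Any], verbose: bool = True) -> str:
    """Build (and, by default, log) a message saying where training resumed from."""
    msg = f"Resumed from {path} (epoch {ckpt['epoch']}, batch {ckpt['batch_idx']})"
    if verbose:
        logger.info(msg)
    return msg
```

load_checkpoint could call this right after reading the file, with verbose defaulting to True as suggested.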
Summary / Description
Adds checkpointing support for stateful hooks in TGM, enabling training to be resumed from mid-epoch interruptions without losing hook state or RNG reproducibility. Demonstrates end-to-end checkpoint save/resume in examples/linkproppred/tgat.py.
Changes:
- state_dict() / load_state_dict() to DGHook, StatelessHook, StatefulHook, RecencyNeighborHook, EdgeEventsSeenNodesTrackHook, NodeAnalyticsHook, and HookManager (with id()-based deduplication for shared hooks)
- skip_batches to DGDataLoader to skip already-processed batches on resume without running hooks on them
- _get_iterator in DGDataLoader to neutralize PyTorch's internal RNG consumption, ensuring resumed runs are bit-exact
- examples/linkproppred/tgat.py with save_checkpoint, load_checkpoint, and --resume / --checkpoint-dir args
- test/unit/test_hooks/test_state_hook.py
Related Issues: # (395)
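To illustrate what id()-based deduplication for shared hooks can mean in practice, here is a toy sketch. `CounterHook` and `ToyHookManager` are made-up classes for the example and do not mirror the actual `HookManager` API. The idea is that a hook instance registered under several keys is serialized once and restored once.

```python
from typing import Any, Dict, List


class CounterHook:
    """Toy stateful hook used only for this illustration."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {'count': self.count}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.count = state['count']


class ToyHookManager:
    """Holds hooks per key; the same hook object may appear under several keys."""

    def __init__(self, hooks_by_key: Dict[str, List[CounterHook]]) -> None:
        self.hooks_by_key = hooks_by_key

    def _unique_hooks(self) -> List[CounterHook]:
        # Walk keys in a fixed order and keep each object once, keyed by id().
        seen = set()
        unique: List[CounterHook] = []
        for key in sorted(self.hooks_by_key):
            for hook in self.hooks_by_key[key]:
                if id(hook) not in seen:
                    seen.add(id(hook))
                    unique.append(hook)
        return unique

    def state_dict(self) -> Dict[str, Any]:
        return {'hooks': [h.state_dict() for h in self._unique_hooks()]}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        for hook, hook_state in zip(self._unique_hooks(), state['hooks']):
            hook.load_state_dict(hook_state)
```

Restoring relies on the manager being rebuilt with the same hook layout, so that the deterministic traversal order lines up with the saved list.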
Type of Change
Test Evidence
Describe how this PR has been tested.
Questions / Discussion Points
Should checkpointing be extended to other example files beyond tgat.py?