Skip to content

Saving loading hook state#400

Open
NaziaHossain066 wants to merge 3 commits into
tgm-team:mainfrom
NaziaHossain066:saving-loading-hook-state
Open

Saving loading hook state#400
NaziaHossain066 wants to merge 3 commits into
tgm-team:mainfrom
NaziaHossain066:saving-loading-hook-state

Conversation

@NaziaHossain066
Copy link
Copy Markdown

Summary / Description

Adds checkpointing support for stateful hooks in TGM, enabling training to be resumed from mid-epoch interruptions without losing hook state or RNG reproducibility. Demonstrates end-to-end checkpoint save/resume in examples/linkproppred/tgat.py.

Changes:

  • Added state_dict() / load_state_dict() to DGHook, StatelessHook, StatefulHook, RecencyNeighborHook, EdgeEventsSeenNodesTrackHook, NodeAnalyticsHook, and HookManager (with id()-based deduplication for shared hooks)
  • Added skip_batches to DGDataLoader to skip already-processed batches on resume without running hooks on them
  • Overrode _get_iterator in DGDataLoader to neutralize PyTorch's internal RNG consumption, ensuring resumed runs are bit-exact
  • Updated examples/linkproppred/tgat.py with save_checkpoint, load_checkpoint, and --resume / --checkpoint-dir args
  • Added 9 unit tests in test/unit/test_hooks/test_state_hook.py
    Related Issues: # (395)

Type of Change

  • Bug fix
  • New feature
  • Breaking Change
  • Refactoring
  • Documentation update

Test Evidence

Describe how this PR has been tested.

  • Unit tests
  • Integration tests
  • Performance tests

Questions / Discussion Points

Should checkpointing be extended to other example files beyond tgat.py

@ntgbaoo ntgbaoo linked an issue Apr 11, 2026 that may be closed by this pull request
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 11, 2026

Codecov Report

❌ Patch coverage is 77.90698% with 19 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
tgm/hooks/node_analytics.py 33.33% 12 Missing ⚠️
tgm/hooks/hook_manager.py 89.28% 3 Missing ⚠️
tgm/hooks/node_tracks.py 50.00% 2 Missing ⚠️
tgm/hooks/base.py 85.71% 1 Missing ⚠️
tgm/hooks/neighbors.py 91.66% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

Copy link
Copy Markdown
Member

@ntgbaoo ntgbaoo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @NaziaHossain066

Thanks for the PR. Great contribution!

Overall, the PR looks pretty good. I left a couple of comments and raised some concerns for us to discuss.

I just approved CI workflows to run for this PR. It is worth checking codecovand adding some unit tests to cover some lines that are missing tests rn, such as:

  • state_dict and load_state_dict for node analytics and node tracks

decoder: nn.Module,
opt: torch.optim.Optimizer,
epoch: int,
hm: object,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hm: object,
hm: HookManager,

opt: torch.optim.Optimizer,
epoch: int,
hm: object,
nbr_hook: object,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nbr_hook: object,
nbr_hook: StatefulHook,

Comment on lines +182 to +184
print(
f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}'
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(
f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}'
)

f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}'
)

if batch_idx % 100 == 0:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to have a variable called checkpoint_interval from the parser instead of hardcoding 100.

I think another flag is needed to indicate whether we want to checkpoint the model as well. Such as train_checkpoint, then the condition will have

if batch_idx %  checkpoint_interval == 0 && args.train_checkpoint:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we just checkpoint only after each epoch though this also makes sense for larger datasets


for epoch in range(1, args.epochs + 1):

def save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a type for each input variable for this function? And would be nice to have a comment to describe the purpose of this method as well.

Comment thread tgm/data/loader.py
Comment on lines +176 to +179
# PyTorch's _BaseDataLoaderIter.__init__ consumes 2 global RNG samples for
# an internal base_seed (used only for worker processes). Save and restore
# around iterator creation so the global RNG state is unaffected, making
# checkpoint resume produce identical results to an uninterrupted run.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# PyTorch's _BaseDataLoaderIter.__init__ consumes 2 global RNG samples for
# an internal base_seed (used only for worker processes). Save and restore
# around iterator creation so the global RNG state is unaffected, making
# checkpoint resume produce identical results to an uninterrupted run.
"""PyTorch's _BaseDataLoaderIter.__init__ consumes 2 global RNG samples for
an internal base_seed (used only for worker processes). Save and restore
around iterator creation so the global RNG state is unaffected, making
checkpoint resume produce identical results to an uninterrupted run.
"""

parser.add_argument(
'--checkpoint-dir',
type=str,
default='outputs/checkpoints',
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
default='outputs/checkpoints',
default='artifact/checkpoints',

Comment thread .gitignore
# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

outputs/
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
outputs/
artifact/

Comment thread tgm/hooks/base.py

def reset_state(self) -> None: ...

def state_dict(self) -> dict: ...
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def state_dict(self) -> dict: ...
@property
def state_dict(self) -> dict: ...

'_all_neighbors': self._all_neighbors,
'_engagement_sum': self._engagement_sum,
'_seen_edges': self._seen_edges,
'_tracked_mask': self._tracked_mask.cpu().clone(),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need this?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I agree

f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}'
)

if batch_idx % 100 == 0:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we just checkpoint only after each epoch though this also makes sense for larger datasets

return ckpt['epoch'], ckpt['batch_idx'], ckpt.get('best_val', 0.0)


def find_latest_checkpoint(directory):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a helper function in util as Bao suggested

)


def load_checkpoint(path, encoder, decoder, opt, hm):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially need to make this verbose by default, user should want to know when and where the checkpoint was resumed from

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Saving and Loading Hook State

3 participants