Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ ipython_config.py

# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

outputs/
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
outputs/
artifact/

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
110 changes: 104 additions & 6 deletions examples/linkproppred/tgat.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move the logic from load_checkpoint, save_checkpoint, and find_latest_checkpoint into a CheckPointHandler in tgm/util/, so that other examples can add checkpointing without duplicating this code.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I agree

Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse
import os
import random

import numpy as np
import torch
Expand Down Expand Up @@ -50,6 +52,18 @@
parser.add_argument(
'--log-file-path', type=str, default=None, help='Optional path to write logs'
)
parser.add_argument(
'--checkpoint-dir',
type=str,
default='outputs/checkpoints',
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
default='outputs/checkpoints',
default='artifact/checkpoints',

help='Directory to save checkpoints',
)
parser.add_argument(
'--resume',
type=str,
default=None,
help='Path to a specific checkpoint to resume from',
)

args = parser.parse_args()
enable_logging(log_file_path=args.log_file_path)
Expand Down Expand Up @@ -140,15 +154,19 @@ def train(
encoder: nn.Module,
decoder: nn.Module,
opt: torch.optim.Optimizer,
epoch: int,
hm: object,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hm: object,
hm: HookManager,

nbr_hook: object,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nbr_hook: object,
nbr_hook: StatefulHook,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nbr_hook: object,

start_batch: int = -1,
) -> float:
encoder.train()
decoder.train()
total_loss = 0
static_node_x = loader.dgraph.static_node_x

for batch in tqdm(loader):
for local_idx, batch in enumerate(tqdm(loader)):
batch_idx = local_idx + start_batch + 1
opt.zero_grad()

z = encoder(batch, static_node_x)
z_src, z_dst, z_neg = torch.chunk(z, 3)

Expand All @@ -160,6 +178,13 @@ def train(
loss.backward()
opt.step()
total_loss += float(loss)

print(
f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}'
)
Comment on lines +182 to +184
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(
f'epoch {epoch} batch {batch_idx} loss {float(loss):.6f} write_pos_sum {nbr_hook._write_pos.sum().item():.0f}'
)


if batch_idx % 100 == 0:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to read a checkpoint_interval value from the parser instead of hardcoding 100.

I think another flag is also needed to indicate whether we want to checkpoint the model at all, e.g. train_checkpoint. The condition would then be:

if batch_idx % checkpoint_interval == 0 and args.train_checkpoint:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just checkpoint only after each epoch? That said, mid-epoch checkpointing also makes sense for larger datasets.

save_checkpoint(epoch, batch_idx, encoder, decoder, opt, hm)
Comment thread
ntgbaoo marked this conversation as resolved.
return total_loss


Expand Down Expand Up @@ -235,7 +260,6 @@ def eval(
train_key, val_key, test_key = hm.keys
hm.register_shared(nbr_hook)

train_loader = DGDataLoader(train_dg, args.bsize, hook_manager=hm)
val_loader = DGDataLoader(val_dg, args.bsize, hook_manager=hm)
test_loader = DGDataLoader(test_dg, args.bsize, hook_manager=hm)

Expand All @@ -252,14 +276,87 @@ def eval(
args.device
)
opt = torch.optim.Adam(
set(encoder.parameters()) | set(decoder.parameters()), lr=float(args.lr)
list(encoder.parameters()) + list(decoder.parameters()), lr=float(args.lr)
)

best_val = 0.0

for epoch in range(1, args.epochs + 1):

def save_checkpoint(
    epoch: int,
    batch_idx: int,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    opt: torch.optim.Optimizer,
    hm: object,
) -> None:
    """Write a resumable training snapshot into ``args.checkpoint_dir``.

    The snapshot bundles the position in training (``epoch``, ``batch_idx``),
    the ``state_dict()`` of the encoder, decoder, optimizer and hook manager,
    the torch / numpy / Python RNG states, and the module-level ``best_val``.
    The directory is created if it does not exist yet.

    Args:
        epoch: Epoch number recorded in the snapshot and in the file name.
        batch_idx: Batch index recorded in the snapshot and in the file name.
        encoder: Module whose ``state_dict()`` is stored under ``'encoder'``.
        decoder: Module whose ``state_dict()`` is stored under ``'decoder'``.
        opt: Optimizer whose ``state_dict()`` is stored under ``'opt'``.
        hm: Object whose ``state_dict()`` is stored under ``'hm'``.
    """
    # TODO(review): a reviewer suggested the more descriptive file name
    # 'ckpt_epoch_{epoch}_batch_{batch_idx}.pt'; not applied here because it
    # would change the on-disk naming.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    target = os.path.join(args.checkpoint_dir, f'ckpt_e{epoch}_b{batch_idx}.pt')

    snapshot = {
        'epoch': epoch,
        'batch_idx': batch_idx,
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'opt': opt.state_dict(),
        'hm': hm.state_dict(),
        'rng_torch': torch.get_rng_state(),
        'rng_numpy': np.random.get_state(),
        'rng_python': random.getstate(),
        # best_val is read from module scope, not passed in.
        'best_val': best_val,
    }
    torch.save(snapshot, target)


def load_checkpoint(
    path: str,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    opt: torch.optim.Optimizer,
    hm: object,
) -> tuple:
    """Restore training state from the checkpoint file at ``path``.

    Loads the encoder, decoder, optimizer and hook-manager state in place,
    then restores the torch, numpy and Python ``random`` RNG states so that a
    resumed run continues the same random streams.

    Args:
        path: Checkpoint file to read.
        encoder: Module that receives ``ckpt['encoder']``.
        decoder: Module that receives ``ckpt['decoder']``.
        opt: Optimizer that receives ``ckpt['opt']`` (best effort, see below).
        hm: Object whose ``load_state_dict`` receives ``ckpt['hm']``.

    Returns:
        ``(epoch, batch_idx, best_val)`` as stored in the checkpoint;
        ``best_val`` falls back to ``0.0`` when the key is absent.
    """
    # TODO(review): reviewers asked for the warning below to go through the
    # project logger and for the resumed checkpoint path to be announced.

    # SECURITY: weights_only=False unpickles arbitrary Python objects (the
    # numpy / Python RNG states stored in the checkpoint are not plain
    # tensors). Only load checkpoints from a trusted source.
    ckpt = torch.load(path, weights_only=False)

    encoder.load_state_dict(ckpt['encoder'])
    decoder.load_state_dict(ckpt['decoder'])

    # Optimizer state is best effort: a mismatch is reported and training
    # continues with the optimizer's current state.
    try:
        opt.load_state_dict(ckpt['opt'])
    except (ValueError, RuntimeError) as e:
        print(
            f'Warning: skipping optimizer state (shape mismatch — delete outputs/checkpoints/ and rerun): {e}'
        )

    hm.load_state_dict(ckpt['hm'])

    torch.set_rng_state(ckpt['rng_torch'])
    np.random.set_state(ckpt['rng_numpy'])
    # The checkpoint also records Python's `random` state; restore it so the
    # stdlib RNG stream is resumed too. Tolerate checkpoints without the key.
    if 'rng_python' in ckpt:
        random.setstate(ckpt['rng_python'])

    return ckpt['epoch'], ckpt['batch_idx'], ckpt.get('best_val', 0.0)


def find_latest_checkpoint(directory):
    """Return the most recently modified ``.pt`` file inside ``directory``.

    Args:
        directory: Folder to scan (not searched recursively).

    Returns:
        The joined path of the ``.pt`` entry with the largest modification
        time, or ``None`` when ``directory`` is not a directory or holds no
        ``.pt`` entries.
    """
    # TODO(review): reviewers asked for this helper to live in a shared util
    # module rather than in the example script.
    if not os.path.isdir(directory):
        return None

    candidates = (
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.pt')
    )
    # default=None covers the "no checkpoint files" case.
    return max(candidates, key=os.path.getmtime, default=None)


start_epoch = 1
start_batch = -1
ckpt_path = (
find_latest_checkpoint(args.resume)
if args.resume and os.path.isdir(args.resume)
else args.resume
)
if ckpt_path is not None:
start_epoch, start_batch, best_val = load_checkpoint(
ckpt_path, encoder, decoder, opt, hm
)

for epoch in range(start_epoch, args.epochs + 1):
skip = (start_batch + 1) if epoch == start_epoch else 0
train_loader = DGDataLoader(
train_dg, args.bsize, hook_manager=hm, skip_batches=skip
)
with hm.activate(train_key):
loss = train(train_loader, encoder, decoder, opt)
loss = train(
train_loader,
encoder,
decoder,
opt,
epoch,
hm,
nbr_hook,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nbr_hook,

start_batch if epoch == start_epoch else -1,
)

with hm.activate(val_key):
val_mrr = eval(val_loader, encoder, decoder, evaluator)
Expand All @@ -268,6 +365,7 @@ def eval(

if val_mrr > best_val:
best_val = val_mrr
log_metric('Best Validation', best_val, epoch=epoch)
with hm.activate(test_key):
test_mrr = eval(test_loader, encoder, decoder, evaluator)
log_metric(f'Test {METRIC_TGB_LINKPROPPRED}', test_mrr, epoch=args.epochs)
Expand Down
185 changes: 185 additions & 0 deletions test/unit/test_hooks/test_state_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import pytest
import torch

from tgm import DGraph
from tgm.data import DGData, DGDataLoader
from tgm.hooks import HookManager, RecencyNeighborHook
from tgm.hooks.base import StatefulHook


@pytest.fixture
def dg():
    """Build a small DGraph for the hook tests.

    The raw data holds three edge events (times 1, 1, 2) with random 3x4 edge
    features, and three node-feature events (times 5, 5, 6) with random 3x3
    node features.
    """
    # Dict order keeps the edge features drawn before the node features, so
    # the random draws happen in the same sequence as before.
    raw = {
        'edge_time': torch.LongTensor([1, 1, 2]),
        'edge_index': torch.IntTensor([[1, 2], [1, 2], [2, 3]]),
        'edge_x': torch.rand(3, 4),
        'node_x_time': torch.LongTensor([5, 5, 6]),
        'node_x_nids': torch.IntTensor([2, 2, 3]),
        'node_x': torch.rand(3, 3),
    }
    return DGraph(DGData.from_raw(**raw))


@pytest.fixture
def recency_hook(dg):
    """Return a RecencyNeighborHook sized to the ``dg`` fixture.

    Configured with ``num_nbrs=[2]`` and seeded from the ``edge_src`` /
    ``edge_dst`` batch fields, both paired with ``edge_time``.
    """
    config = {
        'num_nbrs': [2],
        'num_nodes': dg.num_nodes,
        'seed_nodes_keys': ['edge_src', 'edge_dst'],
        'seed_times_keys': ['edge_time', 'edge_time'],
    }
    return RecencyNeighborHook(**config)


class _CounterHook(StatefulHook):
    """Minimal stateful hook that counts how many times it is invoked."""

    produces = set()
    requires = set()

    def __init__(self):
        # Start from the same value reset_state() restores.
        self.counter = 0

    def __call__(self, dg, batch):
        """Count the invocation and hand the batch back untouched."""
        self.counter = self.counter + 1
        return batch

    def reset_state(self):
        """Drop the running count back to zero."""
        self.counter = 0

    def state_dict(self):
        """Return the current count under the ``'counter'`` key."""
        return dict(counter=self.counter)

    def load_state_dict(self, state):
        """Overwrite the count with ``state['counter']``."""
        self.counter = state['counter']


def test_stateful_hook_state_dict_raises_if_not_implemented():
    """A StatefulHook subclass that omits state_dict raises NotImplementedError."""

    class BareHook(StatefulHook):
        pass

    # Construction stays inside the context, as in the original test.
    with pytest.raises(NotImplementedError):
        BareHook().state_dict()


def test_stateful_hook_load_state_dict_raises_if_not_implemented():
    """A StatefulHook subclass that omits load_state_dict raises NotImplementedError."""

    class BareHook(StatefulHook):
        pass

    empty_state = {}
    # Construction stays inside the context, as in the original test.
    with pytest.raises(NotImplementedError):
        BareHook().load_state_dict(empty_state)


def test_recency_hook_state_dict_contains_required_keys(dg, recency_hook):
    """After a couple of batches, the hook's state_dict exposes every expected key."""
    manager = HookManager(keys=['train'])
    manager.register_shared(recency_hook)
    batches = DGDataLoader(dg, batch_size=1, hook_manager=manager)

    with manager.activate('train'):
        # Pull at most two batches and stop, leaving the loader unexhausted.
        for _ in zip(range(2), batches):
            pass

    snapshot = recency_hook.state_dict()

    expected_keys = (
        '_nbr_ids',
        '_nbr_times',
        '_nbr_feats',
        '_write_pos',
        '_edge_x_dim',
        '_need_to_initialize_nbr_feats',
    )
    for key in expected_keys:
        assert key in snapshot


def test_recency_hook_state_dict_load_state_dict_roundtrip(dg, recency_hook):
    """state_dict followed by reset_state and load_state_dict restores the buffers."""
    manager = HookManager(keys=['train'])
    manager.register_shared(recency_hook)
    batches = DGDataLoader(dg, batch_size=1, hook_manager=manager)

    with manager.activate('train'):
        # Pull at most two batches and stop, leaving the loader unexhausted.
        for _ in zip(range(2), batches):
            pass

    # Clone first: the hook's tensors are compared against these after reload.
    expected_write_pos = recency_hook._write_pos.clone()
    expected_nbr_ids = recency_hook._nbr_ids.clone()
    snapshot = recency_hook.state_dict()

    # Resetting must actually change the write positions, otherwise the
    # round-trip check below would prove nothing.
    recency_hook.reset_state()
    assert not torch.equal(recency_hook._write_pos, expected_write_pos)

    recency_hook.load_state_dict(snapshot)

    assert torch.equal(recency_hook._write_pos, expected_write_pos)
    assert torch.equal(recency_hook._nbr_ids, expected_nbr_ids)


def test_hook_manager_state_dict_saves_stateful_hook():
    """HookManager.state_dict('train') carries the registered hook's counter."""
    manager = HookManager(keys=['train'])
    counter_hook = _CounterHook()
    counter_hook.counter = 42
    manager.register_shared(counter_hook)

    saved = manager.state_dict('train')

    assert len(saved) >= 1
    first_entry = next(iter(saved.values()))
    assert first_entry['counter'] == 42


def test_hook_manager_state_dict_no_duplicate_saves():
    """A single shared hook yields exactly one entry in the manager's state_dict."""
    manager = HookManager(keys=['train'])
    manager.register_shared(_CounterHook())

    saved = manager.state_dict('train')

    assert len(saved) == 1


def test_hook_manager_load_state_dict_restores_hook():
    """Loading a saved manager state puts the hook's counter back."""
    manager = HookManager(keys=['train'])
    counter_hook = _CounterHook()
    manager.register_shared(counter_hook)

    # Capture the state while the counter holds a distinctive value...
    counter_hook.counter = 99
    saved = manager.state_dict('train')

    # ...then clobber it and check that loading brings the value back.
    counter_hook.counter = 0
    manager.load_state_dict(saved, 'train')

    assert counter_hook.counter == 99


def test_skip_batches_reduces_yielded_count(dg):
    """skip_batches=1 makes the loader yield exactly one batch fewer."""
    baseline_count = len(list(DGDataLoader(dg, batch_size=1)))
    skipped_count = len(list(DGDataLoader(dg, batch_size=1, skip_batches=1)))

    assert skipped_count == baseline_count - 1


def test_skip_batches_hook_not_executed(dg):
    """Hooks run only for yielded batches, not for skipped ones."""
    manager = HookManager(keys=['train'])
    counter_hook = _CounterHook()
    manager.register('train', counter_hook)

    # Baseline batch count comes from a loader with no hook manager attached.
    n_batches = len(list(DGDataLoader(dg, batch_size=1)))

    n_skipped = 1
    skipping_loader = DGDataLoader(
        dg, batch_size=1, hook_manager=manager, skip_batches=n_skipped
    )

    with manager.activate('train'):
        for _ in skipping_loader:
            pass

    assert counter_hook.counter == n_batches - n_skipped
Loading
Loading