From ff9712a9b38c754d72ef8b3882913db5c4e0ae2a Mon Sep 17 00:00:00 2001 From: Andy Huang Date: Fri, 15 May 2026 16:06:09 +0100 Subject: [PATCH] adding tgnv2 --- .gitignore | 1 + examples/linkproppred/tgnv2.py | 266 ++++++++++++++++++ examples/nodeproppred/tgnv2.py | 261 +++++++++++++++++ test/integration/test_tgnv2.py | 43 +++ .../test_hooks/test_deduplication_hook.py | 24 ++ test/unit/test_nn/test_tgn.py | 76 ++++- tgm/nn/__init__.py | 14 +- tgm/nn/encoder/__init__.py | 6 +- tgm/nn/encoder/tgn.py | 81 ++++++ tgm/nn/modules/time_encoding.py | 14 +- tgnv2.md | 28 ++ 11 files changed, 805 insertions(+), 9 deletions(-) create mode 100644 examples/linkproppred/tgnv2.py create mode 100644 examples/nodeproppred/tgnv2.py create mode 100644 test/integration/test_tgnv2.py create mode 100644 tgnv2.md diff --git a/.gitignore b/.gitignore index d6cb73df..85b80f9b 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ submitit* # Benchmarks .benchmarks +.claude/ diff --git a/examples/linkproppred/tgnv2.py b/examples/linkproppred/tgnv2.py new file mode 100644 index 00000000..d985113f --- /dev/null +++ b/examples/linkproppred/tgnv2.py @@ -0,0 +1,266 @@ +import argparse + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tgb.linkproppred.evaluate import Evaluator +from tqdm import tqdm + +from tgm import DGraph +from tgm.constants import ( + METRIC_TGB_LINKPROPPRED, + PADDED_NODE_ID, + RECIPE_TGB_LINK_PRED, +) +from tgm.data import DGData, DGDataLoader +from tgm.hooks import DeduplicationHook, RecencyNeighborHook, RecipeRegistry +from tgm.nn import LinkPredictor, TGNv2Memory +from tgm.nn.encoder.tgn import ( + EncodeIndexMessage, + GraphAttentionEmbedding, + LastAggregator, +) +from tgm.util.logging import enable_logging, log_gpu, log_latency, log_metric +from tgm.util.seed import seed_everything + +parser = argparse.ArgumentParser( + description='TGNv2 LinkPropPred Example', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument('--seed', type=int, default=1337, help='random seed to use') +parser.add_argument('--dataset', type=str, default='tgbl-wiki', help='Dataset name') +parser.add_argument('--bsize', type=int, default=200, help='batch size') +parser.add_argument('--device', type=str, default='cpu', help='torch device') +parser.add_argument('--epochs', type=int, default=30, help='number of epochs') +parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') +parser.add_argument('--time-dim', type=int, default=100, help='time encoding dimension') +parser.add_argument('--embed-dim', type=int, default=100, help='attention dimension') +parser.add_argument('--memory-dim', type=int, default=100, help='memory dimension') +parser.add_argument( + '--index-dim', + type=int, + default=None, + help='source-target node ID encoding dimension; defaults to memory_dim', +) +parser.add_argument( + '--n-nbrs', + type=int, + nargs='+', + default=[10], + help='num sampled nbrs at each hop', +) +parser.add_argument( + '--log-file-path', type=str, default=None, help='Optional path to write logs' +) + +args = parser.parse_args() +enable_logging(log_file_path=args.log_file_path) + + +@log_gpu +@log_latency +def train( + loader: DGDataLoader, + memory: nn.Module, + encoder: nn.Module, + decoder: nn.Module, + opt: torch.optim.Optimizer, +) -> float: + memory.train() + encoder.train() + decoder.train() + total_loss = 0 + + memory.reset_state() + + for batch in tqdm(loader): + opt.zero_grad() + + nbr_nodes = batch.nbr_nids[0].flatten() + nbr_mask = nbr_nodes != PADDED_NODE_ID + + num_nbrs = len(nbr_nodes) // ( + len(batch.edge_src) + len(batch.edge_dst) + len(batch.neg) + ) + src_nodes = torch.cat( + [ + batch.edge_src.repeat_interleave(num_nbrs), + batch.edge_dst.repeat_interleave(num_nbrs), + batch.neg.repeat_interleave(num_nbrs), + ] + ) + nbr_edge_index = torch.stack( + [ + batch.global_to_local(src_nodes[nbr_mask]), + batch.global_to_local(nbr_nodes[nbr_mask]), + ] + ).to(dtype=torch.int64) + + nbr_edge_time = batch.nbr_edge_time[0].flatten()[nbr_mask] + nbr_edge_x = batch.nbr_edge_x[0].flatten(0, -2).float()[nbr_mask] + + z, last_update = memory(batch.unique_nids) + z = encoder(z, last_update, nbr_edge_index, nbr_edge_time, nbr_edge_x) + + inv_src = batch.global_to_local(batch.edge_src) + inv_dst = batch.global_to_local(batch.edge_dst) + inv_neg = batch.global_to_local(batch.neg) + pos_out = decoder(z[inv_src], z[inv_dst]) + neg_out = decoder(z[inv_src], z[inv_neg]) + + loss = F.binary_cross_entropy_with_logits(pos_out, torch.ones_like(pos_out)) + loss += F.binary_cross_entropy_with_logits(neg_out, torch.zeros_like(neg_out)) + + memory.update_state( + batch.edge_src, batch.edge_dst, batch.edge_time, batch.edge_x.float() + ) + + loss.backward() + opt.step() + total_loss += float(loss) + + memory.detach() + + return total_loss + + +@log_gpu +@log_latency +@torch.no_grad() +def eval( + loader: DGDataLoader, + memory: nn.Module, + encoder: nn.Module, + decoder: nn.Module, + evaluator: Evaluator, +) -> float: + memory.eval() + encoder.eval() + decoder.eval() + perf_list = [] + + for batch in tqdm(loader): + nbr_nodes = batch.nbr_nids[0].flatten() + nbr_mask = nbr_nodes != PADDED_NODE_ID + + num_nbrs = len(nbr_nodes) // ( + len(batch.edge_src) + len(batch.edge_dst) + len(batch.neg) + ) + src_nodes = torch.cat( + [ + batch.edge_src.repeat_interleave(num_nbrs), + batch.edge_dst.repeat_interleave(num_nbrs), + batch.neg.repeat_interleave(num_nbrs), + ] + ) + nbr_edge_index = torch.stack( + [ + batch.global_to_local(src_nodes[nbr_mask]), + batch.global_to_local(nbr_nodes[nbr_mask]), + ] + ).to(dtype=torch.int64) + nbr_edge_time = batch.nbr_edge_time[0].flatten()[nbr_mask] + nbr_edge_x = batch.nbr_edge_x[0].flatten(0, -2).float()[nbr_mask] + + z, last_update = memory(batch.unique_nids) + z = encoder(z, last_update, nbr_edge_index, nbr_edge_time, nbr_edge_x) + + for idx, neg_batch in enumerate(batch.neg_batch_list): + dst_ids = torch.cat([batch.edge_dst[idx].unsqueeze(0), neg_batch]) + src_ids = batch.edge_src[idx].repeat(len(dst_ids)) + + inv_src = batch.global_to_local(src_ids) + inv_dst = batch.global_to_local(dst_ids) + y_pred = decoder(z[inv_src], z[inv_dst]).sigmoid() + + input_dict = { + 'y_pred_pos': y_pred[0], + 'y_pred_neg': y_pred[1:], + 'eval_metric': [METRIC_TGB_LINKPROPPRED], + } + perf_list.append(evaluator.eval(input_dict)[METRIC_TGB_LINKPROPPRED]) + + memory.update_state( + batch.edge_src, batch.edge_dst, batch.edge_time, batch.edge_x.float() + ) + + return float(np.mean(perf_list)) + + +seed_everything(args.seed) +evaluator = Evaluator(name=args.dataset) + +full_data = DGData.from_tgb(args.dataset) +train_data, val_data, test_data = full_data.split() +train_dg = DGraph(train_data, device=args.device) +val_dg = DGraph(val_data, device=args.device) +test_dg = DGraph(test_data, device=args.device) + +nbr_hook = RecencyNeighborHook( + num_nbrs=args.n_nbrs, + num_nodes=full_data.num_nodes, + seed_nodes_keys=['edge_src', 'edge_dst', 'neg'], + seed_times_keys=['edge_time', 'edge_time', 'neg_time'], +) + +hm = RecipeRegistry.build( + RECIPE_TGB_LINK_PRED, dataset_name=args.dataset, train_dg=train_dg +) +train_key, val_key, test_key = hm.keys +hm.register_shared(nbr_hook) +hm.register_shared(DeduplicationHook(seed_nodes_keys=['neg', 'nbr_nids'])) + +train_loader = DGDataLoader(train_dg, args.bsize, hook_manager=hm) +val_loader = DGDataLoader(val_dg, args.bsize, hook_manager=hm) +test_loader = DGDataLoader(test_dg, args.bsize, hook_manager=hm) + +index_dim = args.memory_dim if args.index_dim is None else args.index_dim +message_module = EncodeIndexMessage( + test_dg.edge_x_dim, + args.memory_dim, + args.time_dim, + index_dim, +) +memory = TGNv2Memory( + full_data.num_nodes, + test_dg.edge_x_dim, + args.memory_dim, + args.time_dim, + index_dim, + message_module=message_module, + aggregator_module=LastAggregator(), +).to(args.device) +encoder = GraphAttentionEmbedding( + in_channels=args.memory_dim, + out_channels=args.embed_dim, + msg_dim=test_dg.edge_x_dim, + time_enc=memory.time_enc, +).to(args.device) +decoder = LinkPredictor(node_dim=args.embed_dim, hidden_dim=args.embed_dim).to( + args.device +) +opt = torch.optim.Adam( + set(memory.parameters()) | set(encoder.parameters()) | set(decoder.parameters()), + lr=args.lr, +) + +best_val = 0.0 + +for epoch in range(1, args.epochs + 1): + with hm.activate(train_key): + loss = train(train_loader, memory, encoder, decoder, opt) + + with hm.activate(val_key): + val_mrr = eval(val_loader, memory, encoder, decoder, evaluator) + log_metric('Loss', loss, epoch=epoch) + log_metric(f'Validation {METRIC_TGB_LINKPROPPRED}', val_mrr, epoch=epoch) + + if val_mrr > best_val: + best_val = val_mrr + with hm.activate(test_key): + test_mrr = eval(test_loader, memory, encoder, decoder, evaluator) + log_metric(f'Test {METRIC_TGB_LINKPROPPRED}', test_mrr, epoch=args.epochs) + + if epoch < args.epochs: + hm.reset_state() diff --git a/examples/nodeproppred/tgnv2.py b/examples/nodeproppred/tgnv2.py new file mode 100644 index 00000000..5043e8ce --- /dev/null +++ b/examples/nodeproppred/tgnv2.py @@ -0,0 +1,261 @@ +import argparse + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tgb.nodeproppred.evaluate import Evaluator +from tqdm import tqdm + +from tgm import DGraph +from tgm.constants import METRIC_TGB_NODEPROPPRED, PADDED_NODE_ID +from tgm.data import DGData, DGDataLoader +from tgm.hooks import DeduplicationHook, HookManager, RecencyNeighborHook +from tgm.nn import NodePredictor, TGNv2Memory +from tgm.nn.encoder.tgn import ( + EncodeIndexMessage, + GraphAttentionEmbedding, + LastAggregator, +) +from tgm.util.logging import enable_logging, log_gpu, log_latency, log_metric +from tgm.util.seed import seed_everything + +parser = argparse.ArgumentParser( + description='TGNv2 NodePropPred Example', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument('--seed', type=int, default=1337, help='random seed to use') +parser.add_argument('--dataset', type=str, default='tgbn-trade', help='Dataset name') +parser.add_argument('--bsize', type=int, default=200, help='batch size') +parser.add_argument('--device', type=str, default='cpu', help='torch device') +parser.add_argument('--epochs', type=int, default=30, help='number of epochs') +parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') +parser.add_argument('--time-dim', type=int, default=100, help='time encoding dimension') +parser.add_argument('--embed-dim', type=int, default=100, help='attention dimension') +parser.add_argument('--memory-dim', type=int, default=100, help='memory dimension') +parser.add_argument( + '--index-dim', + type=int, + default=None, + help='source-target node ID encoding dimension; defaults to memory_dim', +) +parser.add_argument( + '--n-nbrs', + type=int, + nargs='+', + default=[10], + help='num sampled nbrs at each hop', +) +parser.add_argument( + '--time-gran', + type=str, + default=None, + help='raw time granularity for dataset', +) +parser.add_argument( + '--log-file-path', type=str, default=None, help='Optional path to write logs' +) + +args = parser.parse_args() +enable_logging(log_file_path=args.log_file_path) + + +@log_gpu +@log_latency +def train( + loader: DGDataLoader, + memory: nn.Module, + encoder: nn.Module, + decoder: nn.Module, + opt: torch.optim.Optimizer, +) -> tuple[float, float]: + memory.train() + encoder.train() + decoder.train() + total_loss = 0 + + perf_list = [] + memory.reset_state() + for batch in tqdm(loader): + opt.zero_grad() + y_labels = batch.node_y + if y_labels is not None: + nbr_nodes = batch.nbr_nids[0].flatten() + nbr_mask = nbr_nodes != PADDED_NODE_ID + + num_nbrs = len(nbr_nodes) // (len(batch.node_y_nids)) + src_nodes = batch.node_y_nids.repeat_interleave(num_nbrs) + nbr_edge_index = torch.stack( + [ + batch.global_to_local(src_nodes[nbr_mask]), + batch.global_to_local(nbr_nodes[nbr_mask]), + ] + ).to(dtype=torch.int64) + + nbr_edge_time = batch.nbr_edge_time[0].flatten()[nbr_mask] + nbr_edge_x = batch.nbr_edge_x[0].flatten(0, -2).float()[nbr_mask] + + z, last_update = memory(batch.unique_nids) + z = encoder(z, last_update, nbr_edge_index, nbr_edge_time, nbr_edge_x) + + inv_src = batch.global_to_local(batch.node_y_nids) + y_pred = decoder(z[inv_src]) + loss = F.cross_entropy(y_pred, y_labels) + loss.backward() + opt.step() + total_loss += float(loss) + + input_dict = { + 'y_true': y_labels, + 'y_pred': y_pred, + 'eval_metric': [METRIC_TGB_NODEPROPPRED], + } + perf = evaluator.eval(input_dict)[METRIC_TGB_NODEPROPPRED] + perf_list.append(perf) + + if len(batch.edge_src) > 0: + memory.update_state( + batch.edge_src, batch.edge_dst, batch.edge_time, batch.edge_x.float() + ) + memory.detach() + + return total_loss, float(np.mean(perf_list)) + + +@log_gpu +@log_latency +@torch.no_grad() +def eval( + loader: DGDataLoader, + memory: nn.Module, + encoder: nn.Module, + decoder: nn.Module, + evaluator: Evaluator, +) -> float: + memory.eval() + encoder.eval() + decoder.eval() + perf_list = [] + + for batch in tqdm(loader): + y_labels = batch.node_y + if y_labels is not None: + nbr_nodes = batch.nbr_nids[0].flatten() + nbr_mask = nbr_nodes != PADDED_NODE_ID + + num_nbrs = len(nbr_nodes) // (len(batch.node_y_nids)) + src_nodes = batch.node_y_nids.repeat_interleave(num_nbrs) + nbr_edge_index = torch.stack( + [ + batch.global_to_local(src_nodes[nbr_mask]), + batch.global_to_local(nbr_nodes[nbr_mask]), + ] + ).to(dtype=torch.int64) + + nbr_edge_time = batch.nbr_edge_time[0].flatten()[nbr_mask] + nbr_edge_x = batch.nbr_edge_x[0].flatten(0, -2).float()[nbr_mask] + + z, last_update = memory(batch.unique_nids) + z = encoder(z, last_update, nbr_edge_index, nbr_edge_time, nbr_edge_x) + + inv_src = batch.global_to_local(batch.node_y_nids) + y_pred = decoder(z[inv_src]) + + input_dict = { + 'y_true': y_labels, + 'y_pred': y_pred, + 'eval_metric': [METRIC_TGB_NODEPROPPRED], + } + perf_list.append(evaluator.eval(input_dict)[METRIC_TGB_NODEPROPPRED]) + + if len(batch.edge_src) > 0: + memory.update_state( + batch.edge_src, batch.edge_dst, batch.edge_time, batch.edge_x.float() + ) + + return float(np.mean(perf_list)) + + +seed_everything(args.seed) +evaluator = Evaluator(name=args.dataset) + +full_data = DGData.from_tgb(args.dataset) +train_data, val_data, test_data = full_data.split() + +if args.time_gran is not None: + train_data = train_data.discretize(args.time_gran) + val_data = val_data.discretize(args.time_gran) + test_data = test_data.discretize(args.time_gran) + +train_dg = DGraph(train_data, device=args.device) +val_dg = DGraph(val_data, device=args.device) +test_dg = DGraph(test_data, device=args.device) + +num_classes = train_dg.node_y_dim + +nbr_hook = RecencyNeighborHook( + num_nbrs=args.n_nbrs, + num_nodes=full_data.num_nodes, + seed_nodes_keys=['node_y_nids'], + seed_times_keys=['node_y_time'], +) + +hm = HookManager(keys=['train', 'val', 'test']) +hm.register_shared(nbr_hook) +hm.register_shared(DeduplicationHook(seed_nodes_keys=['node_y_nids', 'nbr_nids'])) + +train_loader = DGDataLoader(train_dg, args.bsize, hook_manager=hm) +val_loader = DGDataLoader(val_dg, args.bsize, hook_manager=hm) +test_loader = DGDataLoader(test_dg, args.bsize, hook_manager=hm) + +index_dim = args.memory_dim if args.index_dim is None else args.index_dim +message_module = EncodeIndexMessage( + test_dg.edge_x_dim, + args.memory_dim, + args.time_dim, + index_dim, +) +memory = TGNv2Memory( + full_data.num_nodes, + test_dg.edge_x_dim, + args.memory_dim, + args.time_dim, + index_dim, + message_module=message_module, + aggregator_module=LastAggregator(), +).to(args.device) +encoder = GraphAttentionEmbedding( + in_channels=args.memory_dim, + out_channels=args.embed_dim, + msg_dim=test_dg.edge_x_dim, + time_enc=memory.time_enc, +).to(args.device) +decoder = NodePredictor( + in_dim=args.embed_dim, out_dim=num_classes, hidden_dim=args.embed_dim +).to(args.device) +opt = torch.optim.Adam( + set(memory.parameters()) | set(encoder.parameters()) | set(decoder.parameters()), + lr=args.lr, +) + +best_val = 0.0 + +for epoch in range(1, args.epochs + 1): + with hm.activate('train'): + loss, train_metric = train(train_loader, memory, encoder, decoder, opt) + + with hm.activate('val'): + val_metric = eval(val_loader, memory, encoder, decoder, evaluator) + + log_metric('Loss', loss, epoch=epoch) + log_metric(f'Train {METRIC_TGB_NODEPROPPRED}', train_metric, epoch=epoch) + log_metric(f'Validation {METRIC_TGB_NODEPROPPRED}', val_metric, epoch=epoch) + + if val_metric > best_val: + best_val = val_metric + with hm.activate('test'): + test_metric = eval(test_loader, memory, encoder, decoder, evaluator) + log_metric(f'Test {METRIC_TGB_NODEPROPPRED}', test_metric, epoch=args.epochs) + + if epoch < args.epochs: + hm.reset_state() diff --git a/test/integration/test_tgnv2.py b/test/integration/test_tgnv2.py new file mode 100644 index 00000000..18c31735 --- /dev/null +++ b/test/integration/test_tgnv2.py @@ -0,0 +1,43 @@ +import pytest + + +@pytest.mark.integration +@pytest.mark.parametrize('dataset', ['tgbl-wiki']) +@pytest.mark.slurm( + resources=[ + '--partition=main', + '--cpus-per-task=2', + '--mem=8G', + '--time=1:00:00', + '--gres=gpu:a100l:1', + ] +) +def test_tgnv2_linkprop_pred(slurm_job_runner, dataset): + cmd = f""" +python "$ROOT_DIR/examples/linkproppred/tgnv2.py" \ + --dataset {dataset} \ + --device cuda \ + --epochs 1""" + state = slurm_job_runner(cmd) + assert state == 'COMPLETED' + + +@pytest.mark.integration +@pytest.mark.parametrize('dataset', ['tgbn-trade']) +@pytest.mark.slurm( + resources=[ + '--partition=main', + '--cpus-per-task=2', + '--mem=4G', + '--time=1:00:00', + '--gres=gpu:a100l:1', + ] +) +def test_tgnv2_nodeprop_pred(slurm_job_runner, dataset): + cmd = f""" +python "$ROOT_DIR/examples/nodeproppred/tgnv2.py" \ + --dataset {dataset} \ + --device cuda \ + --epochs 1""" + state = slurm_job_runner(cmd) + assert state == 'COMPLETED' diff --git a/test/unit/test_hooks/test_deduplication_hook.py b/test/unit/test_hooks/test_deduplication_hook.py index a184d253..da0a6c5b 100644 --- a/test/unit/test_hooks/test_deduplication_hook.py +++ b/test/unit/test_hooks/test_deduplication_hook.py @@ -2,6 +2,7 @@ import torch from tgm import DGraph +from tgm.constants import PADDED_NODE_ID from tgm.data import DGData, DGDataLoader from tgm.hooks import DeduplicationHook from tgm.hooks.hook_manager import HookManager @@ -111,6 +112,29 @@ def test_dedup_with_nbrs(dg): ) +def test_dedup_with_node_labels_and_nbrs(dg): + hook = DeduplicationHook(seed_nodes_keys=['node_y_nids', 'nbr_nids']) + batch = dg.materialize() + batch.node_y_nids = torch.IntTensor([5, 10]) + batch.nbr_nids = [ + torch.IntTensor([[1, 5], [10, PADDED_NODE_ID]]), + ] + + processed_batch = hook(dg, batch) + torch.testing.assert_close( + processed_batch.unique_nids, torch.IntTensor([1, 2, 4, 5, 8, 10]) + ) + torch.testing.assert_close( + processed_batch.global_to_local(batch.node_y_nids), torch.IntTensor([3, 5]) + ) + torch.testing.assert_close( + processed_batch.global_to_local( + batch.nbr_nids[0][batch.nbr_nids[0] != PADDED_NODE_ID] + ), + torch.IntTensor([0, 3, 5]), + ) + + @pytest.fixture def node_only_graph(): edge_index = torch.IntTensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) diff --git a/test/unit/test_nn/test_tgn.py b/test/unit/test_nn/test_tgn.py index 2d45367d..7dddaf5a 100644 --- a/test/unit/test_nn/test_tgn.py +++ b/test/unit/test_nn/test_tgn.py @@ -2,10 +2,12 @@ from tgm.nn import TGNMemory from tgm.nn.encoder.tgn import ( + EncodeIndexMessage, GraphAttentionEmbedding, IdentityMessage, LastAggregator, MeanAggregator, + TGNv2Memory, ) @@ -13,6 +15,7 @@ def test_tgn_last_aggre(): B, Z = 10, 100 E, M, T = 7, 5, 2 + n_id = torch.arange(B) edge_index = torch.randint(0, B, size=(2, B)) edge_time = torch.randint(0, B, size=(B,)) edge_feat = torch.randint(0, B, size=(B, E)) @@ -32,7 +35,7 @@ def test_tgn_last_aggre(): ) memory.train() encoder.train() - z, last_update = memory(torch.unique(edge_index)) + z, last_update = memory(n_id) z = encoder(z, last_update, edge_index, edge_time, edge_feat) memory.detach() memory.reset_parameters() @@ -42,7 +45,7 @@ def test_tgn_last_aggre(): memory.eval() encoder.eval() - z, last_update = memory(torch.unique(edge_index)) + z, last_update = memory(n_id) z = encoder(z, last_update, edge_index, edge_time, edge_feat) memory.detach() @@ -54,6 +57,7 @@ def test_tgn_mean_aggre(): B, Z = 10, 100 E, M, T = 7, 5, 2 + n_id = torch.arange(B) edge_index = torch.randint(0, B, size=(2, B)) edge_time = torch.randint(0, B, size=(B,)) edge_feat = torch.randint(0, B, size=(B, E)) @@ -73,7 +77,7 @@ def test_tgn_mean_aggre(): ) memory.train() encoder.train() - z, last_update = memory(torch.unique(edge_index)) + z, last_update = memory(n_id) z = encoder(z, last_update, edge_index, edge_time, edge_feat) memory.update_state(edge_index[0], edge_index[1], edge_time, edge_feat.float()) @@ -84,10 +88,74 @@ def test_tgn_mean_aggre(): memory.eval() encoder.eval() - z, last_update = memory(torch.unique(edge_index)) + z, last_update = memory(n_id) z = encoder(z, last_update, edge_index, edge_time, edge_feat) memory.update_state(edge_index[0], edge_index[1], edge_time, edge_feat.float()) memory.detach() assert z.shape == (B, Z) assert not torch.isnan(z).any() + + +def test_tgnv2_encode_index_message(): + B, E, M, T, I = 3, 2, 4, 5, 6 + msg_module = EncodeIndexMessage(E, M, T, I) + + z_src = torch.randn(B, M) + z_dst = torch.randn(B, M) + raw_msg = torch.randn(B, E) + t_enc = torch.randn(B, T) + src_enc = torch.randn(B, I) + dst_enc = torch.randn(B, I) + + out = msg_module(z_src, z_dst, raw_msg, t_enc, src_enc, dst_enc) + expected = torch.cat([z_src, z_dst, raw_msg, src_enc, dst_enc, t_enc], dim=-1) + + assert msg_module.out_channels == 2 * M + E + 2 * I + T + assert out.shape == (B, msg_module.out_channels) + torch.testing.assert_close(out, expected) + + +def test_tgnv2_memory(): + B, Z = 10, 100 + E, M, T, I = 7, 5, 2, 3 + + n_id = torch.arange(B) + edge_index = torch.randint(0, B, size=(2, B)) + edge_time = torch.randint(0, B, size=(B,)) + edge_feat = torch.randn(B, E) + memory = TGNv2Memory( + B, + E, + M, + T, + I, + message_module=EncodeIndexMessage(E, M, T, I), + aggregator_module=LastAggregator(), + ) + encoder = GraphAttentionEmbedding( + in_channels=M, + out_channels=Z, + msg_dim=E, + time_enc=memory.time_enc, + ) + + memory.train() + encoder.train() + z, last_update = memory(n_id) + z = encoder(z, last_update, edge_index, edge_time, edge_feat) + memory.update_state(edge_index[0], edge_index[1], edge_time, edge_feat) + memory.detach() + + assert z.shape == (B, Z) + assert not torch.isnan(z).any() + + memory.eval() + encoder.eval() + z, last_update = memory(n_id) + z = encoder(z, last_update, edge_index, edge_time, edge_feat) + memory.update_state(edge_index[0], edge_index[1], edge_time, edge_feat) + memory.detach() + + assert z.shape == (B, Z) + assert not torch.isnan(z).any() diff --git a/tgm/nn/__init__.py b/tgm/nn/__init__.py index 93ddb5e3..f4c6f184 100644 --- a/tgm/nn/__init__.py +++ b/tgm/nn/__init__.py @@ -1,4 +1,14 @@ -from .encoder import CTAN, CTANMemory, DyGFormer, TPNet, TGCN, GCLSTM, TGNMemory +from .encoder import ( + CTAN, + CTANMemory, + DyGFormer, + EncodeIndexMessage, + TPNet, + TGCN, + GCLSTM, + TGNMemory, + TGNv2Memory, +) from .decoder import GraphPredictor, NodePredictor, LinkPredictor, NCNPredictor from .encoder import DyGFormer, TPNet, TGCN, GCLSTM, RandomProjectionModule, ROLAND from .modules import ( @@ -15,6 +25,7 @@ 'CTAN', 'CTANMemory', 'DyGFormer', + 'EncodeIndexMessage', 'EdgeBankPredictor', 'GCLSTM', 'GraphPredictor', @@ -27,6 +38,7 @@ 'Time2Vec', 'tCoMemPredictor', 'TGNMemory', + 'TGNv2Memory', 'NCNPredictor', 'PopTrackPredictor', ] diff --git a/tgm/nn/encoder/__init__.py b/tgm/nn/encoder/__init__.py index ecae2fa9..895b24a1 100644 --- a/tgm/nn/encoder/__init__.py +++ b/tgm/nn/encoder/__init__.py @@ -4,11 +4,13 @@ from .tgcn import TGCN from .gclstm import GCLSTM from .tgn import ( + EncodeIndexMessage, GraphAttentionEmbedding, + IdentityMessage, LastAggregator, MeanAggregator, - IdentityMessage, TGNMemory, + TGNv2Memory, ) from .roland import ROLAND @@ -22,10 +24,12 @@ 'RandomProjectionModule', 'TGCN', 'TPNet', + 'EncodeIndexMessage', 'GraphAttentionEmbedding', 'LastAggregator', 'MeanAggregator', 'IdentityMessage', 'TGNMemory', + 'TGNv2Memory', 'ROLAND', ] diff --git a/tgm/nn/encoder/tgn.py b/tgm/nn/encoder/tgn.py index e6ec4620..17dea59f 100644 --- a/tgm/nn/encoder/tgn.py +++ b/tgm/nn/encoder/tgn.py @@ -74,6 +74,31 @@ def forward( return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1) +class EncodeIndexMessage(torch.nn.Module): + def __init__( + self, + raw_msg_dim: int, + memory_dim: int, + time_dim: int, + index_dim: int, + ) -> None: + super().__init__() + if index_dim <= 0: + raise ValueError('index_dim must be positive') + self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim + 2 * index_dim + + def forward( + self, + z_src: Tensor, + z_dst: Tensor, + raw_msg: Tensor, + t_enc: Tensor, + src_enc: Tensor, + dst_enc: Tensor, + ) -> torch.Tensor: + return torch.cat([z_src, z_dst, raw_msg, src_enc, dst_enc, t_enc], dim=-1) + + TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]] @@ -249,3 +274,59 @@ def train(self, mode: bool = True) -> 'TGNMemory': self._reset_message_store() super().train(mode) return self + + +class TGNv2Memory(TGNMemory): + def __init__( + self, + num_nodes: int, + raw_msg_dim: int, + memory_dim: int, + time_dim: int, + index_dim: int, + message_module: Callable, + aggregator_module: Callable, + ): + if index_dim <= 0: + raise ValueError('index_dim must be positive') + self.index_dim = index_dim + super().__init__( + num_nodes, + raw_msg_dim, + memory_dim, + time_dim, + message_module, + aggregator_module, + ) + self.index_enc = Time2Vec(time_dim=index_dim) + + def reset_parameters(self) -> None: + super().reset_parameters() + if hasattr(self, 'index_enc'): + self.index_enc.reset_parameters() + + def _compute_msg( + self, n_id: Tensor, msg_store: TGNMessageStoreType, msg_module: Callable + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + data = [msg_store[i] for i in n_id.tolist()] + src_store, dst_store, t_store, raw_msg_store = list(zip(*data)) + src_tensor = torch.cat(src_store, dim=0).to(self.device) + dst_tensor = torch.cat(dst_store, dim=0).to(self.device) + t_tensor = torch.cat(t_store, dim=0).to(self.device) + raw_msg_tensor = torch.cat(raw_msg_store, dim=0).to(self.device) + t_rel = t_tensor - self.last_update[src_tensor] + t_enc = self.time_enc(t_rel.to(raw_msg_tensor.dtype)) + src_enc = self.index_enc(src_tensor.to(raw_msg_tensor.dtype)) + dst_enc = self.index_enc(dst_tensor.to(raw_msg_tensor.dtype)) + msg = msg_module( + self.memory[src_tensor], + self.memory[dst_tensor], + raw_msg_tensor, + t_enc, + src_enc, + dst_enc, + ) + return cast( + Tuple[Tensor, Tensor, Tensor, Tensor], + (msg, t_tensor, src_tensor, dst_tensor), + ) diff --git a/tgm/nn/modules/time_encoding.py b/tgm/nn/modules/time_encoding.py index 608102a7..1528c6d6 100644 --- a/tgm/nn/modules/time_encoding.py +++ b/tgm/nn/modules/time_encoding.py @@ -13,11 +13,19 @@ def __init__(self, time_dim: int) -> None: super().__init__() self.time_dim = time_dim self.w = torch.nn.Linear(1, time_dim) + self.reset_parameters() + def reset_parameters(self) -> None: # Initialization from: https://github.com/yule-BUAA/DyGLib/blob/master/models/modules.py - w = (1 / 10 ** np.linspace(0, 9, time_dim)).reshape(time_dim, 1) - self.w.weight = torch.nn.Parameter(torch.from_numpy(w).float()) - self.w.bias = torch.nn.Parameter(torch.zeros(time_dim)) + w = (1 / 10 ** np.linspace(0, 9, self.time_dim)).reshape(self.time_dim, 1) + w_tensor = torch.as_tensor( + w, + dtype=self.w.weight.dtype, + device=self.w.weight.device, + ) + with torch.no_grad(): + self.w.weight.copy_(w_tensor) + self.w.bias.zero_() def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.unsqueeze(dim=-1).float() # (batch_size, seq_len, 1) diff --git a/tgnv2.md b/tgnv2.md new file mode 100644 index 00000000..622febb0 --- /dev/null +++ b/tgnv2.md @@ -0,0 +1,28 @@ +# TGNv2 Port Plan + +## Summary + +- Add TGNv2 as reusable TGM modules instead of keeping the implementation example-local. +- Keep the existing TGN APIs backward compatible. +- Draft runnable TGNv2 examples for link prediction and node property prediction using TGM data loaders, hooks, logging, and decoders. + +## Key Changes + +- Add `EncodeIndexMessage` to concatenate source memory, destination memory, raw edge message, source ID encoding, destination ID encoding, and time encoding. +- Add `TGNv2Memory` as a separate memory module that mirrors `TGNMemory` but passes trainable source and destination node-ID encodings into the message module. +- Export the new classes from `tgm.nn.encoder` and `tgm.nn`. +- Implement `examples/linkproppred/tgnv2.py` from the existing TGM TGN link-prediction example, replacing `TGNMemory + IdentityMessage` with `TGNv2Memory + EncodeIndexMessage`. +- Add `examples/nodeproppred/tgnv2.py` using the same reusable modules for node-property parity. + +## Test Plan + +- Extend TGN unit tests with `EncodeIndexMessage` shape/content checks. +- Add `TGNv2Memory` train/eval and `update_state` smoke tests. +- Add an integration smoke test for the link prediction example on `tgbl-wiki`. +- Run focused verification with `pytest test/unit/test_nn/test_tgn.py`. + +## Assumptions + +- The upstream behavior is implemented in TGM style rather than copied verbatim. +- `index_dim` defaults to `memory_dim` in examples when not explicitly provided. +- Existing TGN tests and examples should continue to work unchanged.