Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,6 @@ submitit*

# Benchmarks
.benchmarks

# TGB-seq data cache
data/
2 changes: 1 addition & 1 deletion examples/graphproppred/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def eval(
log_metrics_dict(train_results, epoch=epoch)
log_metrics_dict(val_results, epoch=epoch)

val_score = val_results['BinaryAUROC']
val_score = val_results['ValidationBinaryAUROC']
if val_score > best_val:
best_val = val_score
test_results = eval(test_loader, test_labels, encoder, decoder, test_metrics)
Expand Down
2 changes: 1 addition & 1 deletion examples/graphproppred/tgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def eval(
log_metrics_dict(train_results, epoch=epoch)
log_metrics_dict(val_results, epoch=epoch)

val_score = val_results['BinaryAUROC']
val_score = val_results['ValidationBinaryAUROC']
if val_score > best_val:
best_val = val_score
test_results, h_0 = eval(
Expand Down
2 changes: 1 addition & 1 deletion examples/linkproppred/ctan.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def compute_delta_t_stats(train_dg: DGraph) -> Tuple[float, float]:
)
train_key, val_key, test_key = hm.keys
hm.register_shared(nbr_hook)
hm.register_shared(DeduplicationHook())
hm.register_shared(DeduplicationHook(seed_nodes_keys=['neg', 'nbr_nids']))

train_loader = DGDataLoader(train_dg, args.bsize, hook_manager=hm)
val_loader = DGDataLoader(val_dg, args.bsize, hook_manager=hm)
Expand Down
2 changes: 1 addition & 1 deletion examples/linkproppred/roland.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def eval(
except StopIteration:
pass

z = z.detach()
z[0], z[1] = z[0].detach(), z[1].detach()
return float(np.mean(perf_list)), z


Expand Down
9 changes: 8 additions & 1 deletion examples/linkproppred/tgb_seq/edgebank.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from typing import Set

import numpy as np
import torch
Expand Down Expand Up @@ -64,7 +65,13 @@ def eval(


class TGBSEQ_NegativeEdgeSamplerHook(StatelessHook):
produces = {'neg', 'neg_time'}
@property
def produces(self) -> Set[str]:
return {'neg', 'neg_time'}

@property
def requires(self) -> Set[str]:
return {'edge_src', 'edge_dst', 'edge_time'}

def __init__(
self, dataset_name: str, split_mode: str, dgraph: DGraph, root: str = './data'
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_base3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'--partition=main',
'--cpus-per-task=2',
'--mem=4G',
'--time=0:05:00',
'--time=0:20:00',
]
)
def test_base3_linkprop_inf_EB_memory(slurm_job_runner, dataset):
Expand Down
6 changes: 3 additions & 3 deletions test/integration/test_edgebank.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_edgebank_linkprop_pred_fixed_memory(slurm_job_runner, dataset):
'--partition=main',
'--cpus-per-task=2',
'--mem=8G',
'--time=0:15:00',
'--time=0:25:00',
]
)
def test_edgebank_tgb_seq_unlimited_memory(slurm_job_runner, dataset):
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_edgebank_linkprop_pred_fixed_memory_thgl(slurm_job_runner, dataset):
'--partition=main',
'--cpus-per-task=2',
'--mem=8G',
'--time=1:15:00',
'--time=3:00:00',
]
)
def test_edgebank_linkprop_pred_unlimited_memory_tkgl(slurm_job_runner, dataset):
Expand All @@ -116,7 +116,7 @@ def test_edgebank_linkprop_pred_unlimited_memory_tkgl(slurm_job_runner, dataset)
'--partition=main',
'--cpus-per-task=2',
'--mem=8G',
'--time=2:00:00',
'--time=4:00:00',
]
)
def test_edgebank_linkprop_pred_fixed_memory_tkgl(slurm_job_runner, dataset):
Expand Down
2 changes: 1 addition & 1 deletion tgm/hooks/hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def activate(self, key: str) -> Iterator[None]:
def _ensure_valid_hook(self, hook: Any) -> None:
if not isinstance(hook, DGHook):
raise BadHookProtocolError(
f'Cannot register hook {type(hook).__name__}: must implement __call__(dg: DGraph, batch: DGBatch) -> DGBatch and reset_state()'
f'Cannot register hook {type(hook).__name__}: must implement __call__(dg: DGraph, batch: DGBatch) -> DGBatch, reset_state(), requires and produces properties.'
)

def _ensure_no_active_key(self) -> None:
Expand Down
Loading