Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions docs/api/hooks/hooks.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
::: tgm.hooks.PinMemoryHook
::: tgm.hooks.DeviceTransferHook
::: tgm.hooks.DeduplicationHook
::: tgm.hooks.TGBNegativeEdgeSamplerHook
::: tgm.hooks.NegativeEdgeSamplerHook
::: tgm.hooks.RecencyNeighborHook
::: tgm.hooks.BatchAnalyticsHook
::: tgm.hooks.NodeAnalyticsHook
::: tgm.hooks.negatives.TGBNegativeEdgeSamplerHook
::: tgm.hooks.negatives.TGBTHGNegativeEdgeSamplerHook
::: tgm.hooks.negatives.TGBTKGNegativeEdgeSamplerHook
::: tgm.hooks.negatives.RandomNegativeEdgeSamplerHook
::: tgm.hooks.negatives.HistoricalNegativeEdgeSamplerHook
::: tgm.hooks.neighbors.NeighborSamplerHook
::: tgm.hooks.neighbors.RecencyNeighborHook
::: tgm.hooks.analytics.BatchAnalyticsHook
::: tgm.hooks.analytics.NodeAnalyticsHook
2 changes: 1 addition & 1 deletion examples/analytics/node_analytics_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tgm import DGraph
from tgm.data import DGData, DGDataLoader
from tgm.hooks import HookManager
from tgm.hooks.node_analytics import NodeAnalyticsHook
from tgm.hooks.analytics.node_analytics import NodeAnalyticsHook
from tgm.util.logging import enable_logging, log_latency, log_metrics_dict
from tgm.util.seed import seed_everything

Expand Down
123 changes: 112 additions & 11 deletions test/unit/test_hooks/test_negative_edge_sampler_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
import torch

from tgm import DGBatch, DGraph
from tgm.constants import PADDED_NODE_ID
from tgm.data import DGData, DGDataLoader
from tgm.hooks import HookManager, NegativeEdgeSamplerHook
from tgm.hooks import (
HistoricalNegativeEdgeSamplerHook,
HookManager,
RandomNegativeEdgeSamplerHook,
)


@pytest.fixture
Expand All @@ -14,36 +19,48 @@ def data():


def test_hook_dependancies():
hook = NegativeEdgeSamplerHook(low=0, high=10)
hook = RandomNegativeEdgeSamplerHook(low=0, high=10)
assert hook.requires == {'edge_src', 'edge_dst', 'edge_time'}
assert hook.produces == {'neg', 'neg_time'}

hook_with_id = NegativeEdgeSamplerHook(low=0, high=10, id='foo')
hook_with_id = RandomNegativeEdgeSamplerHook(low=0, high=10, id='foo')
assert hook_with_id.requires == {'edge_src', 'edge_dst', 'edge_time'}
assert hook_with_id.produces == {'neg_foo', 'neg_time_foo'}

hook = HistoricalNegativeEdgeSamplerHook()
assert hook.requires == {'edge_src', 'edge_dst', 'edge_time'}
assert hook.produces == {'neg', 'neg_time', 'valid_neg_mask'}

hook_with_id = HistoricalNegativeEdgeSamplerHook(id='foo')
assert hook_with_id.requires == {'edge_src', 'edge_dst', 'edge_time'}
assert hook_with_id.produces == {'neg_foo', 'neg_time_foo', 'valid_neg_mask_foo'}


def test_hook_repre():
hook_with_id = NegativeEdgeSamplerHook(low=0, high=10, id='foo')
hook_with_id = RandomNegativeEdgeSamplerHook(low=0, high=10, id='foo')
assert 'foo' in hook_with_id.__repr__()

hook_with_id = HistoricalNegativeEdgeSamplerHook(id='foo')
assert 'foo' in hook_with_id.__repr__()


def test_hook_reset_state():
assert NegativeEdgeSamplerHook.has_state == False
assert RandomNegativeEdgeSamplerHook.has_state == False
assert HistoricalNegativeEdgeSamplerHook.has_state == True


def test_bad_negative_edge_sampler_init():
with pytest.raises(ValueError):
NegativeEdgeSamplerHook(low=0, high=0)
RandomNegativeEdgeSamplerHook(low=0, high=0)
with pytest.raises(ValueError):
NegativeEdgeSamplerHook(low=0, high=1, neg_ratio=0)
RandomNegativeEdgeSamplerHook(low=0, high=1, neg_ratio=0)
with pytest.raises(ValueError):
NegativeEdgeSamplerHook(low=0, high=1, neg_ratio=2)
RandomNegativeEdgeSamplerHook(low=0, high=1, neg_ratio=2)


def test_negative_edge_sampler(data):
dg = DGraph(data)
hook = NegativeEdgeSamplerHook(low=0, high=10)
hook = RandomNegativeEdgeSamplerHook(low=0, high=10)
batch = hook(dg, dg.materialize())
assert isinstance(batch, DGBatch)
assert torch.is_tensor(batch.neg)
Expand All @@ -54,7 +71,7 @@ def test_negative_edge_sampler(data):

def test_negative_edge_sampler_with_id(data):
dg = DGraph(data)
hook = NegativeEdgeSamplerHook(low=0, high=10, id='foo')
hook = RandomNegativeEdgeSamplerHook(low=0, high=10, id='foo')
batch = hook(dg, dg.materialize())
assert isinstance(batch, DGBatch)
assert torch.is_tensor(batch.neg_foo)
Expand Down Expand Up @@ -82,7 +99,7 @@ def node_only_data():
def test_node_only_batch_negative_edge_sampler(node_only_data):
dg = DGraph(node_only_data)
hm = HookManager(keys=['unit'])
hm.register('unit', NegativeEdgeSamplerHook(low=0, high=6))
hm.register('unit', RandomNegativeEdgeSamplerHook(low=0, high=6))
loader = DGDataLoader(dg, batch_size=3, hook_manager=hm)
with hm.activate('unit'):
batch_iter = iter(loader)
Expand All @@ -97,3 +114,87 @@ def test_node_only_batch_negative_edge_sampler(node_only_data):
assert isinstance(batch_2, DGBatch)
assert batch_2.neg.shape == (0,)
assert batch_2.neg_time.shape == (0,)


@pytest.fixture
def data_test_hst_sampling():
edge_index = torch.IntTensor(
[
# 1st batch
[1, 5],
[7, 6],
[2, 8],
[7, 8],
# 2nd batch
[1, 7],
[9, 10],
[3, 10],
[1, 9],
# 3rd batch
[3, 11],
[2, 10],
[7, 2],
[3, 5],
]
)
edge_time = torch.arange(edge_index.size(0))
return DGData.from_raw(edge_time, edge_index)


def test_hst_sampling(data_test_hst_sampling):
dg = DGraph(data_test_hst_sampling)

hm = HookManager(keys=['unit'])
sampler = HistoricalNegativeEdgeSamplerHook()

hm.register('unit', sampler)
loader = DGDataLoader(dg, batch_size=4, hook_manager=hm)

with hm.activate('unit'):
batch_iter = iter(loader)
batch_1 = next(batch_iter)
assert batch_1.neg.shape == (4,)
assert torch.equal(
batch_1.neg,
torch.Tensor(
[PADDED_NODE_ID, PADDED_NODE_ID, PADDED_NODE_ID, PADDED_NODE_ID]
),
)
assert torch.equal(
batch_1.valid_neg_mask,
torch.Tensor([False, False, False, False]),
)
assert sampler._memory is not None
assert sampler._memory.shape == (2, 8)
assert sampler._count == 4

batch_2 = next(batch_iter)
assert batch_2.neg.shape == (4,)
assert torch.equal(
batch_2.neg, torch.Tensor([5, PADDED_NODE_ID, PADDED_NODE_ID, 5])
)
assert torch.equal(
batch_2.valid_neg_mask,
torch.Tensor([True, False, False, True]),
)
assert sampler._memory is not None
assert sampler._memory.shape == (2, 8)
assert sampler._count == 8

batch_3 = next(batch_iter)
assert batch_3.neg.shape == (4,)
assert torch.equal(batch_3.neg, torch.Tensor([10, 8, 8, 10])) or torch.equal(
batch_3.neg, torch.Tensor([10, 8, 6, 10])
)
assert torch.equal(
batch_3.valid_neg_mask,
torch.Tensor([True, True, True, True]),
)
assert sampler._memory is not None
assert sampler._memory.shape == (2, 24)
assert sampler._count != 0
assert sampler._count == 12

sampler.reset_state()
assert sampler._memory is None
assert sampler._count == 0
2 changes: 1 addition & 1 deletion test/unit/test_hooks/test_node_analytics_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tgm import DGBatch, DGraph
from tgm.data import DGData
from tgm.hooks.node_analytics import NodeAnalyticsHook
from tgm.hooks.analytics.node_analytics import NodeAnalyticsHook


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_hooks/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tgm.exceptions import UndefinedRecipe
from tgm.hooks import (
HookManager,
NegativeEdgeSamplerHook,
RandomNegativeEdgeSamplerHook,
RecipeRegistry,
TGBNegativeEdgeSamplerHook,
)
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_build_recipe_tgb_link_pred(mock_dataset_cls, tgb_dataset_factory, dg):
and register_keys[2] == 'test'
)
assert len(train_hooks) == len(val_hooks) == len(test_hooks) == 1
assert isinstance(train_hooks[0], NegativeEdgeSamplerHook)
assert isinstance(train_hooks[0], RandomNegativeEdgeSamplerHook)
assert isinstance(val_hooks[0], TGBNegativeEdgeSamplerHook)
assert isinstance(test_hooks[0], TGBNegativeEdgeSamplerHook)

Expand Down
6 changes: 3 additions & 3 deletions tgm/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from .dedup import DeduplicationHook
from .device import DeviceTransferHook, PinMemoryHook
from .negatives import (
NegativeEdgeSamplerHook,
RandomNegativeEdgeSamplerHook,
HistoricalNegativeEdgeSamplerHook,
TGBNegativeEdgeSamplerHook,
TGBTHGNegativeEdgeSamplerHook,
TGBTKGNegativeEdgeSamplerHook,
Expand All @@ -11,5 +12,4 @@
from .hook_manager import HookManager
from .recipe import RecipeRegistry
from .node_tracks import EdgeEventsSeenNodesTrackHook
from .batch_analytics import BatchAnalyticsHook
from .node_analytics import NodeAnalyticsHook
from .analytics import BatchAnalyticsHook, NodeAnalyticsHook
2 changes: 2 additions & 0 deletions tgm/hooks/analytics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .batch_analytics import BatchAnalyticsHook
from .node_analytics import NodeAnalyticsHook
File renamed without changes.
File renamed without changes.
Loading
Loading