Historical negative sampling#406
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| self._memory[0, self._count : self._count + batch_size] = batch.edge_src | ||
| self._memory[1, self._count : self._count + batch_size] = batch.edge_dst |
There was a problem hiding this comment.
One drawback of this is that the amortized space complexity is O(number of observed edge events) rather than O(number of edges).
There was a problem hiding this comment.
isn't this better? you mean we have to keep a buffer in space now?
There was a problem hiding this comment.
Ideally, we want the space complexity of this to scale linearly w.r.t the number of unique edges in the graphs. However, the amortized space complexity of this implementation is O(number of observed edge events), which means the memory will contain duplicated edges. One benefit of this implementation: it naturally enforces that edges, which appear more frequently in the past, would have higher probabilities to be sampled.
shenyangHuang
left a comment
There was a problem hiding this comment.
Thanks Bao, two major things:
- historical negative and random negative should be two hooks, one is stateful, one is stateless
- I think we should now seriously consider having a folder or a base class that handles negative sampling and shouldn't throw everything in a single script anymore
|
|
||
|
|
||
| class NegativeEdgeSamplerHook(StatelessHook): | ||
| class NegativeEdgeSamplerHook(StatefulHook): |
There was a problem hiding this comment.
I think we should have two classes, one for random negative, one for historical negative, it is easier conceptually for me to understand as well. Then I can mix and match the two by having two hooks in my hm.
There was a problem hiding this comment.
and also one of them is stateless, one of them is stateful, we should definitely separate them. Otherwise, by same logic, we should merge uniform sampling with recency sampling
|
|
||
| For each source node in the batch, randomly selects a destination node from | ||
| its past interactions stored in memory. If a source node has no recorded past | ||
| interactions, its corresponding negative sample is set to PADDED_NODE_ID as |
There was a problem hiding this comment.
negative sample set to PADDED_NODE_ID is fine, but we need to remind users to mask those out correctly. Alternatively the hook can let you know how much is padded?
There was a problem hiding this comment.
This function will return PADDED_NODE_ID for nodes that don't have past interactions. However, PADDED_NODE_ID will be replaced with random dsts before returning. Here is the logic from __call__:
elif self.strategy == 'hist_rnd':
if self._count == 0:
neg, neg_time = self._random_sampling(dg, batch)
neg, neg_time = neg[: size[0]], neg_time[: size[0]]
else: #replace PADDED_NODE_ID with random dst
rnd_size = round(size[0] * 0.5)
hst_size = size[0] - rnd_size
neg_rnd, neg_time_rnd = self._random_sampling(dg, batch)
neg_hst, neg_time_hst = self._random_hist_sampling(dg, batch)
original_valid_mask = neg_hst != PADDED_NODE_ID
valid_idx = torch.where(original_valid_mask)[0]
cutoff = min(hst_size, valid_idx.size(0))
neg = neg_rnd.clone()
neg_time = neg_time_rnd.clone()
chosen = valid_idx[:cutoff]
neg[chosen] = neg_hst[chosen]
neg_time[chosen] = neg_time_hst[chosen]So PADDED_NODE_ID won't be propagated to downstream, and for nodes that don't have past interactions, we use random sampling
| self._memory[0, self._count : self._count + batch_size] = batch.edge_src | ||
| self._memory[1, self._count : self._count + batch_size] = batch.edge_dst |
There was a problem hiding this comment.
isn't this better? you mean we have to keep a buffer in space now?
shenyangHuang
left a comment
There was a problem hiding this comment.
Thanks Bao, looks good on my end
Summary / Description
This PR handles:
tgbnegative samplers (tgbl,tkgl,thgl) to avoid duplicated codeRelated Issues: #405
Type of Change
Test Evidence
Describe how this PR has been tested.
Questions / Discussion Points
List any areas where you’d like reviewer input or have open questions.