I was looking through tribev2/model.py and noticed a performance bottleneck in the temporal dropout implementation.
Currently, the dropout mask is generated in a Python for loop over the batch dimension, and torch.rand creates it on the CPU because no device argument is passed:
    if self.config.temporal_dropout > 0 and self.training:
        for batch_idx in range(out.shape[0]):
            mask = torch.rand(out.shape[1]) < self.config.temporal_dropout
            out[batch_idx, mask, :] = torch.zeros_like(out[batch_idx, mask, :])
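When out lives on the GPU, every iteration has to copy the CPU mask to the device, and boolean-mask indexing generally forces a host-device synchronization, so the GPU stalls once per item in the batch instead of running one fused operation.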
We can completely vectorize this directly on the GPU using masked_fill, which should give a nice bump to training speed:
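    if self.config.temporal_dropout > 0 and self.training:
        # One (batch, time) mask, created on the same device as out
        mask = torch.rand(out.shape[:2], device=out.device) < self.config.temporal_dropout
        # Out-of-place masked_fill; the mask broadcasts over the feature dimension
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)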
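To sanity-check that the vectorized version behaves like the loop, here is a small standalone sketch. The function names and the (batch, time, features) test shape are my own assumptions for illustration, not code from tribev2/model.py. It checks that every timestep is either fully zeroed or left untouched, and that the two versions agree at the deterministic extremes p=0 and p=1:

```python
import torch


def temporal_dropout_loop(out: torch.Tensor, p: float) -> torch.Tensor:
    """Loop-based reference mirroring the current implementation (hypothetical helper)."""
    out = out.clone()  # keep the caller's tensor intact for comparison
    for batch_idx in range(out.shape[0]):
        mask = torch.rand(out.shape[1]) < p  # mask is created on the CPU
        out[batch_idx, mask, :] = 0.0
    return out


def temporal_dropout_vectorized(out: torch.Tensor, p: float) -> torch.Tensor:
    """Vectorized variant: a single (batch, time) mask on out's device (hypothetical helper)."""
    mask = torch.rand(out.shape[:2], device=out.device) < p
    return out.masked_fill(mask.unsqueeze(-1), 0.0)


torch.manual_seed(0)
x = torch.ones(4, 16, 8)  # assumed layout: (batch, time, features)
y = temporal_dropout_vectorized(x, p=0.5)

# Each timestep should be dropped as a whole or kept as a whole.
dropped = (y == 0).all(dim=-1)
kept = (y == 1).all(dim=-1)
assert bool((dropped | kept).all())

# torch.rand samples from [0, 1), so p=0 drops nothing and p=1 drops everything.
assert torch.equal(temporal_dropout_vectorized(x, 0.0), temporal_dropout_loop(x, 0.0))
assert torch.equal(temporal_dropout_vectorized(x, 1.0), temporal_dropout_loop(x, 1.0))
```

For intermediate values of p the two versions draw different random numbers, so the outputs match in distribution rather than element by element.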
Happy to open a quick PR for this if it looks good to you!