Optimize temporal dropout to remove CPU/GPU bottleneck #50

@dearabhin

I was looking through tribev2/model.py and noticed a performance bottleneck in the temporal dropout implementation.

Currently, the dropout mask is generated in a Python for loop over the batch dimension, and because torch.rand is called without a device argument, the mask is created on the CPU by default:

if self.config.temporal_dropout > 0 and self.training:
    for batch_idx in range(out.shape[0]):
        mask = torch.rand(out.shape[1]) < self.config.temporal_dropout
        out[batch_idx, mask, :] = torch.zeros_like(out[batch_idx, mask, :])

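When out lives on the GPU, every iteration has to copy the CPU mask to the device, and boolean-mask indexing typically forces a host/device synchronization. The batch is therefore processed as a series of small, serialized operations, one set per item, instead of a single batched operation.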
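
A quick way to check where the mask ends up is to inspect its device. This is a minimal sketch, assuming the default device has not been changed (for example via torch.set_default_device); the sequence length 8 and rate 0.1 are arbitrary illustrative values, not taken from the repository.

```python
import torch

# torch.rand without a device argument allocates on the default device,
# which is the CPU unless it has been changed globally.
mask = torch.rand(8) < 0.1

print(mask.device)  # expected to be the CPU under the assumption above
print(mask.dtype)   # comparison yields a boolean mask
```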

We can completely vectorize this directly on the GPU using masked_fill, which should give a nice bump to training speed:

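if self.config.temporal_dropout > 0 and self.training:
    # One (batch, time) mask, created directly on the same device as `out`.
    mask = torch.rand(out.shape[:2], device=out.device) < self.config.temporal_dropout
    # In-place like the original loop; unsqueeze broadcasts the mask over the feature dim.
    out.masked_fill_(mask.unsqueeze(-1), 0.0)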
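
To sanity-check the change outside the model, the two variants can be compared as standalone functions. This is only a sketch: the function names, the parameter p (standing in for self.config.temporal_dropout), and the tensor shapes are hypothetical, and the vectorized version uses the out-of-place masked_fill so the input is left untouched. The random masks differ between the two versions, so the check looks at the structure of the output (each time step is either fully zeroed or unchanged) rather than exact equality.

```python
import torch


def temporal_dropout_loop(out: torch.Tensor, p: float) -> torch.Tensor:
    """Per-item loop mirroring the current code (mask created on the CPU)."""
    out = out.clone()
    for batch_idx in range(out.shape[0]):
        mask = torch.rand(out.shape[1]) < p
        out[batch_idx, mask, :] = 0.0
    return out


def temporal_dropout_vectorized(out: torch.Tensor, p: float) -> torch.Tensor:
    """Single (batch, time) mask created on the same device as `out`."""
    mask = torch.rand(out.shape[:2], device=out.device) < p
    # unsqueeze(-1) broadcasts the mask across the feature dimension.
    return out.masked_fill(mask.unsqueeze(-1), 0.0)


def steps_zeroed_or_kept(x: torch.Tensor, y: torch.Tensor) -> bool:
    """True if every (batch, time) slice of y is all zeros or equal to x."""
    zeroed = (y == 0).all(dim=-1)
    kept = (y == x).all(dim=-1)
    return bool((zeroed | kept).all())


x = torch.ones(4, 10, 3)  # (batch, time, features), illustrative shape
y_loop = temporal_dropout_loop(x, 0.5)
y_vec = temporal_dropout_vectorized(x, 0.5)
print(steps_zeroed_or_kept(x, y_loop), steps_zeroed_or_kept(x, y_vec))
```

Timing the two functions on a CUDA tensor (with torch.cuda.synchronize() around the measurement) would be the natural next step to quantify the speedup on real batch sizes.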

Happy to open a quick PR for this if it looks good to you!
