Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/vjepa_2_1/models/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def forward(
if self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal
q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, is_causal=self.is_causal
)
attn = None
else:
Expand Down Expand Up @@ -338,7 +338,7 @@ def forward(self, x):
if self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal
q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, is_causal=self.is_causal
)
attn = None
else:
Expand Down
9 changes: 6 additions & 3 deletions src/models/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def merge_(tx, ta):
if attn_mask is not None or self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0,
is_causal=self.is_causal, attn_mask=attn_mask
)
attn = None
else:
Expand Down Expand Up @@ -372,7 +373,8 @@ def forward(self, x, mask=None, attn_mask=None, T=None, H_patches=None, W_patche
if attn_mask is not None or self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0,
is_causal=self.is_causal, attn_mask=attn_mask
)
attn = None
else:
Expand Down Expand Up @@ -419,7 +421,8 @@ def forward(self, x, mask=None, attn_mask=None):
if attn_mask is not None or self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0,
is_causal=self.is_causal, attn_mask=attn_mask
)
attn = None
else:
Expand Down