diff --git a/app/vjepa_2_1/models/utils/modules.py b/app/vjepa_2_1/models/utils/modules.py index fedf8a2e..20765662 100644 --- a/app/vjepa_2_1/models/utils/modules.py +++ b/app/vjepa_2_1/models/utils/modules.py @@ -285,7 +285,7 @@ def forward( if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, is_causal=self.is_causal ) attn = None else: @@ -338,7 +338,7 @@ def forward(self, x): if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, is_causal=self.is_causal ) attn = None else: diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py index 4b36564f..65eab58e 100644 --- a/src/models/utils/modules.py +++ b/src/models/utils/modules.py @@ -248,7 +248,8 @@ def merge_(tx, ta): if attn_mask is not None or self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, + is_causal=self.is_causal, attn_mask=attn_mask ) attn = None else: @@ -372,7 +373,8 @@ def forward(self, x, mask=None, attn_mask=None, T=None, H_patches=None, W_patche if attn_mask is not None or self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, + is_causal=self.is_causal, attn_mask=attn_mask ) attn = None else: @@ -419,7 +421,8 @@ def forward(self, x, mask=None, attn_mask=None): if attn_mask is not None or self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, + is_causal=self.is_causal, attn_mask=attn_mask ) attn = None else: