From 8d175bd04d321c2186d605d801cd9245abbe752f Mon Sep 17 00:00:00 2001 From: Massimiliano Viola Date: Wed, 6 May 2026 15:46:41 +0200 Subject: [PATCH 1/2] fix stochastic SDPA dropout in inference --- app/vjepa_2_1/models/utils/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/vjepa_2_1/models/utils/modules.py b/app/vjepa_2_1/models/utils/modules.py index fedf8a2e..20765662 100644 --- a/app/vjepa_2_1/models/utils/modules.py +++ b/app/vjepa_2_1/models/utils/modules.py @@ -285,7 +285,7 @@ def forward( if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, is_causal=self.is_causal ) attn = None else: @@ -338,7 +338,7 @@ def forward(self, x): if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, is_causal=self.is_causal ) attn = None else: From ac6ccbe605a0d0f8222109f128190639a168b03e Mon Sep 17 00:00:00 2001 From: Massimiliano Viola Date: Wed, 6 May 2026 16:02:24 +0200 Subject: [PATCH 2/2] also fix SDPA dropout in inference for vjepa2 --- src/models/utils/modules.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py index 4b36564f..65eab58e 100644 --- a/src/models/utils/modules.py +++ b/src/models/utils/modules.py @@ -248,7 +248,8 @@ def merge_(tx, ta): if attn_mask is not None or self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, + is_causal=self.is_causal, attn_mask=attn_mask ) attn = None else: @@ -372,7 +373,8 @@ def forward(self, x, mask=None, attn_mask=None, 
T=None, H_patches=None, W_patche if attn_mask is not None or self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, + is_causal=self.is_causal, attn_mask=attn_mask ) attn = None else: @@ -419,7 +421,8 @@ def forward(self, x, mask=None, attn_mask=None): if attn_mask is not None or self.use_sdpa: with torch.backends.cuda.sdp_kernel(): x = F.scaled_dot_product_attention( - q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + q, k, v, dropout_p=self.proj_drop_prob if self.training else 0.0, + is_causal=self.is_causal, attn_mask=attn_mask ) attn = None else: