diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4bd4e09d..b0350a95 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
# Changelog
+
+## [0.0.2] - 2025-03-16
+
+Release of V-JEPA 2.1
+
+
## [0.0.1] - 2025-06-05
-Initial release of V-JEPA 2 codebase
\ No newline at end of file
+Initial release of V-JEPA 2 codebase
diff --git a/README.md b/README.md
index 0f0698dd..18fab914 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,9 @@
+
+🆕 **[2026-03-16]:** :fire: V-JEPA 2.1 is released :fire: A new familly of models trained with a novel recipe that learns high quality and temporolly consistent dense features !!!
+
+**[2025-06-25]:** V-JEPA 2 is released. [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)]
+
+
# V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning
### [Meta FAIR](https://ai.meta.com/research/)
@@ -13,7 +19,7 @@ Rabbat*, Nicolas Ballas*
[[`Paper`](https://arxiv.org/abs/2506.09985)] [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)] [[`BibTex`](#Citation)]
-Official Pytorch codebase for V-JEPA 2 and V-JEPA 2-AC.
+Official Pytorch codebase for V-JEPA 2, V-JEPA 2-AC, V-JEPA 2.1.
V-JEPA 2 is a self-supervised approach to training video encoders, using internet-scale video data, that attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.
@@ -21,11 +27,37 @@ V-JEPA 2 is a self-supervised approach to training video encoders, using interne
-
+
+## V-JEPA 2.1 Pre-training
+
+Lorenzo Mur-Labadia, Matthew Muckley, Amir Bar, Mahmoud Assran, Koustuv Sinha, Michael
+Rabbat, Yann LeCun, Nicolas Ballas, Adrien Bardes
+
+[[`Paper`](https://arxiv.org/abs/TODO)] [[`BibTex`](#Citation)]
+
+V-JEPA 2.1 improves the training recipe to focus on learning high-quality and temporally consistent dense features, as higlighted by PCA visualizations:
+
+
+
+
+
+The V-JEPA 2.1 approach leverages: (1) **Dense Predictive Loss**, a masking-based
+self-supervision objective where all tokens (both visible/context and masked tokens) contribute to the
+self-supervised training loss; (2) **Deep Self-Supervision**, which applies the self-supervised loss at multiple
+intermediate representations of the encoder models; (3) **Multi-Modal Tokenizers** for images and videos;
+and we show that our approach benefit from (4) **Model and data scaling**.
+
+
+
+
+
+V-JEPA 2.1 performance across dense and global prediction tasks:
+
+
+
+
+
## V-JEPA 2 Pre-training
@@ -35,7 +67,7 @@ V-JEPA 2 is a self-supervised approach to training video encoders, using interne
| Benchmark |
- VJEPA 2 |
+ V-JEPA 2 |
Previous Best |
@@ -111,15 +143,19 @@ V-JEPA 2 is a self-supervised approach to training video encoders, using interne
+
+
+
+
## Models
-### V-JEPA 2
+### V-JEPA 2 and V-JEPA 2.1
#### HuggingFace
-See our [HuggingFace collection](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) for V-JEPA 2.
+See our HuggingFace [collection](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) for V-JEPA 2.
-#### Pretrained Checkpoints
+#### V-JEPA 2 Pretrained Checkpoints
@@ -159,6 +195,51 @@ See our [HuggingFace collection](https://huggingface.co/collections/facebook/v-j
+#### V-JEPA 2.1 Pretrained Checkpoints
+
+
+
+
#### Pretrained backbones (via PyTorch Hub)
Please install [Pytorch](https://pytorch.org/get-started/locally/), [timm](https://pypi.org/project/timm/) and [einops](https://pypi.org/project/einops/) locally, then run the following to load each model. Installing Pytorch with CUDA support is strongly recommended.
@@ -169,16 +250,22 @@ import torch
# preprocessor
processor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_preprocessor')
# models
+# V-JEPA 2
vjepa2_vit_large = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_large')
vjepa2_vit_huge = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_huge')
vjepa2_vit_giant = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant')
vjepa2_vit_giant_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')
+# V-JEPA 2.1
+vjepa2_1_vit_base_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_base_384')
+vjepa2_1_vit_large_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_large_384')
+vjepa2_1_vit_giant_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_giant_384')
+vjepa2_1_vit_gigantic_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_gigantic_384')
```
#### Pretrained checkpoints on Huggingface
-You can also use our pretrained checkpoints on [Huggingface](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6).
+You can also use our pretrained checkpoints on [Huggingface for V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6).
```python
from transformers import AutoVideoProcessor, AutoModel
@@ -189,7 +276,6 @@ hf_repo = "facebook/vjepa2-vitg-fpc64-256"
# facebook/vjepa2-vitg-fpc64-256
# facebook/vjepa2-vitg-fpc64-384
-
model = AutoModel.from_pretrained(hf_repo)
processor = AutoVideoProcessor.from_pretrained(hf_repo)
```
@@ -283,6 +369,7 @@ See [energy_landscape_example.ipynb](notebooks/energy_landscape_example.ipynb) f
To run this notebook, you'll need to additionally install [Jupyter](https://jupyter.org/install) and [Scipy](https://scipy.org/install/) in your conda environment.
+
## Getting Started
### Setup
@@ -400,13 +487,16 @@ python -m app.main_distributed \
```
.
├── app # training loops
-│ ├── vjepa # video JEPA pre-training
+│ ├── vjepa # V-JEPA 2 pre-training
+│ ├── vjepa_2_1 # V-JEPA 2.1 pre-training
│ ├── vjepa_droid # training the action-conditioned model
│ ├── main_distributed.py # entrypoint for launch app on slurm cluster
│ └── main.py # entrypoint for launch app locally on your machine
├── configs # config files with experiment params for training and evaluation
-│ ├── train # pretraining (phase 1), cooldown (phase 2), and action-conditioned training
+│ ├── train # pretraining with V-JEPA 2 (phase 1), cooldown (phase 2), and action-conditioned training
+│ ├── train_2_1 # pretraining with V-JEPA 2.1 (phase 1), cooldown (phase 2)
│ └── eval # frozen evaluations
+│ └── inference # inference only frozen evaluations
├── evals # evaluation loops training an attentive probe with frozen backbone...
│ ├── action_anticipation_frozen # action anticipation
│ ├── image_classification_frozen # image understanding
@@ -434,7 +524,8 @@ are licensed under the Apache 2.0 license.
## Citation
-If you find this repository useful in your research, please consider giving a star :star: and a citation
+If you find this repository useful in your research, please consider giving a star :star: and cite the papers:
+
```bibtex
@article{assran2025vjepa2,
title={V-JEPA~2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
@@ -448,3 +539,13 @@ Rabbat, Michael and Ballas, Nicolas},
year={2025}
}
```
+
+```bibtex
+@article{murlabadia2026vjepa2_1,
+ title={V-JEPA 2.1: Unlocking Dense Features in Video Self-Supervised Learning},
+ author={Mur-Labadia, Lorenzo and Muckley, Matthew and Bar, Amir and Assran, Mahmoud and
+Sinha, Koustuv and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas and Bardes, Adrien},
+ journal={arXiv preprint arXiv:2603.14482},
+ year={2026}
+}
+```
diff --git a/app/vjepa_2_1/models/predictor.py b/app/vjepa_2_1/models/predictor.py
new file mode 100644
index 00000000..a8e7f968
--- /dev/null
+++ b/app/vjepa_2_1/models/predictor.py
@@ -0,0 +1,302 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from src.masks.utils import apply_masks
+from src.utils.tensors import repeat_interleave_batch, trunc_normal_
+
+from app.vjepa_2_1.models.utils.modules import Block
+
+
+class VisionTransformerPredictor(nn.Module):
+ """Vision Transformer Predictor"""
+
+ def __init__(
+ self,
+ img_size=(224, 224),
+ patch_size=16,
+ num_frames=1,
+ tubelet_size=2,
+ embed_dim=768,
+ predictor_embed_dim=384,
+ out_embed_dim=None,
+ depth=6,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ init_std=0.02,
+ uniform_power=False,
+ use_mask_tokens=False,
+ num_mask_tokens=2,
+ zero_init_mask_tokens=True,
+ use_silu=False,
+ wide_silu=True,
+ is_causal=False,
+ use_activation_checkpointing=False,
+ return_all_tokens=False,
+ chop_last_n_tokens=0,
+ use_rope=False,
+ n_registers=0,
+ has_cls_first=False,
+ interpolate_rope=False,
+ modality_embedding=True,
+ img_temporal_dim_size=None,
+ teacher_embed_dim=None,
+ **kwargs
+ ):
+ super().__init__()
+ self.return_all_tokens = return_all_tokens
+ self.chop_last_n_tokens = chop_last_n_tokens
+ self.has_cls_first = has_cls_first
+
+ if depth == 4:
+ all_hierarchical_layers = [0, 1, 2, 3]
+ elif depth == 8:
+ all_hierarchical_layers = [1, 3, 5, 7]
+ elif depth == 12:
+ all_hierarchical_layers = [2, 5, 8, 11]
+ elif depth == 20:
+ all_hierarchical_layers = [4, 9, 14, 19]
+ elif depth == 24:
+ all_hierarchical_layers = [4, 11, 17, 23]
+ elif depth == 40:
+ all_hierarchical_layers = [9, 19, 29, 39]
+
+ n_output_distillation = kwargs.get("n_output_distillation", len(all_hierarchical_layers))
+ self.hierarchical_layers = all_hierarchical_layers[-n_output_distillation:]
+
+ act_layer_mlp = nn.SiLU if use_silu else nn.GELU
+ if len(self.hierarchical_layers) == 1:
+ self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
+ else:
+ self.predictor_embed = nn.Sequential(
+ nn.Linear(embed_dim * len(self.hierarchical_layers), embed_dim, bias=True),
+ act_layer_mlp(),
+ nn.Linear(embed_dim, predictor_embed_dim, bias=True),
+ )
+
+ self.mask_tokens = None
+ self.num_mask_tokens = 0
+ if use_mask_tokens:
+ self.num_mask_tokens = num_mask_tokens
+ self.mask_tokens = nn.ParameterList(
+ [
+ nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
+ for i in range(num_mask_tokens)
+ ]
+ )
+
+ if type(img_size) is int:
+ img_size = (img_size, img_size)
+ self.img_height, self.img_width = img_size
+ self.patch_size = patch_size
+ self.num_frames = num_frames
+ self.tubelet_size = tubelet_size
+ self.is_video = num_frames > 1
+
+ self.grid_height = img_size[0] // self.patch_size
+ self.grid_width = img_size[1] // self.patch_size
+ self.grid_depth = num_frames // self.tubelet_size
+ self.use_activation_checkpointing = use_activation_checkpointing
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ if self.is_video:
+ self.num_patches = (
+ (num_frames // tubelet_size)
+ * (img_size[0] // patch_size)
+ * (img_size[1] // patch_size)
+ )
+ else:
+ self.num_patches = (img_size[0] // patch_size) * (
+ img_size[1] // patch_size
+ )
+
+ self.modality_embedding = False
+ if img_temporal_dim_size is not None:
+ if modality_embedding:
+ self.video_mod_embed = nn.Parameter(
+ torch.zeros(1, 1, predictor_embed_dim)
+ )
+ nn.init.normal_(self.video_mod_embed, std=1e-6)
+ self.img_mod_embed = nn.Parameter(
+ torch.zeros(1, 1, predictor_embed_dim)
+ )
+ nn.init.normal_(self.img_mod_embed, std=1e-6)
+ self.modality_embedding = True
+
+ self.uniform_power = uniform_power
+
+ self.use_rope = use_rope
+ self.predictor_blocks = nn.ModuleList(
+ [
+ Block(
+ use_rope=use_rope,
+ grid_size=self.grid_height,
+ grid_depth=self.grid_depth,
+ dim=predictor_embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=nn.SiLU if use_silu else nn.GELU,
+ is_causal=is_causal,
+ wide_silu=wide_silu,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ n_registers=n_registers,
+ has_cls_first=has_cls_first,
+ interpolate_rope=interpolate_rope,
+ patch_size=patch_size,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ if out_embed_dim is None:
+ if teacher_embed_dim is not None:
+ out_embed_dim = teacher_embed_dim // len(self.hierarchical_layers)
+ else:
+ out_embed_dim = embed_dim
+ self.predictor_norm = norm_layer(predictor_embed_dim)
+ self.predictor_proj = nn.Linear(
+ predictor_embed_dim,
+ len(self.hierarchical_layers) * out_embed_dim,
+ bias=True,
+ )
+ if self.return_all_tokens:
+ self.predictor_proj_context = nn.Linear(
+ predictor_embed_dim,
+ out_embed_dim * len(self.hierarchical_layers),
+ bias=True,
+ )
+
+ self.init_std = init_std
+ if not zero_init_mask_tokens:
+ for mt in self.mask_tokens:
+ trunc_normal_(mt, std=init_std)
+
+ self.apply(self._init_weights)
+ self._rescale_blocks()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=self.init_std)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def _rescale_blocks(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.predictor_blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def forward(self, x, masks_x, masks_y, mod="video", mask_index=1):
+ """
+ :param x: context tokens
+ :param masks_x: indices of context tokens in input
+ :params masks_y: indices of target tokens in input
+ """
+ assert (masks_x is not None) and (
+ masks_y is not None
+ ), "Cannot run predictor without mask indices"
+ if not isinstance(masks_x, list):
+ masks_x = [masks_x]
+ if not isinstance(masks_y, list):
+ masks_y = [masks_y]
+
+ B = len(x) // len(masks_x)
+
+ x = self.predictor_embed(x)
+ _, N_ctxt, D = x.shape
+
+ if not self.use_rope:
+ x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
+ x += apply_masks(x_pos_embed, masks_x)
+
+ mask_index = mask_index % self.num_mask_tokens
+ pred_tokens = self.mask_tokens[mask_index]
+ pred_tokens = pred_tokens.repeat(B, self.num_patches, 1)
+ pred_tokens = apply_masks(pred_tokens, masks_y)
+
+ if not self.use_rope:
+ pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
+ pos_embs = apply_masks(pos_embs, masks_y)
+ pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
+ pred_tokens += pos_embs
+
+ x = x.repeat(len(masks_x), 1, 1)
+ x = torch.cat([x, pred_tokens], dim=1)
+
+ masks_x = torch.cat(masks_x, dim=0)
+ masks_y = torch.cat(masks_y, dim=0)
+ masks = torch.cat([masks_x, masks_y], dim=1)
+
+ argsort = torch.argsort(masks, dim=1)
+ masks = torch.stack([masks[i, row] for i, row in enumerate(argsort)], dim=0)
+ x = torch.stack([x[i, row, :] for i, row in enumerate(argsort)], dim=0)
+
+ if self.chop_last_n_tokens > 0:
+ x = x[:, : -self.chop_last_n_tokens]
+ masks = masks[:, : -self.chop_last_n_tokens]
+
+ if self.modality_embedding:
+ if mod == "image":
+ x += self.img_mod_embed.repeat(B, 1, 1)
+ else:
+ x += self.video_mod_embed.repeat(B, 1, 1)
+
+ for i, blk in enumerate(self.predictor_blocks):
+ if self.use_activation_checkpointing:
+ x, attn = torch.utils.checkpoint.checkpoint(
+ blk, x, masks, use_reentrant=False
+ )
+ else:
+ x, attn = blk(x, mask=masks)
+ x = self.predictor_norm(x)
+
+ if not self.return_all_tokens:
+ reverse_argsort = torch.argsort(argsort, dim=1)
+ x = torch.stack(
+ [x[i, row, :] for i, row in enumerate(reverse_argsort)], dim=0
+ )
+ x = x[:, N_ctxt:, :]
+ x = self.predictor_proj(x)
+ return x, None
+ else:
+ reverse_argsort = torch.argsort(argsort, dim=1)
+ x = torch.stack(
+ [x[i, row, :] for i, row in enumerate(reverse_argsort)], dim=0
+ )
+ x_pred = x[:, N_ctxt:, :]
+ x_context = x[:, :N_ctxt, :]
+ x_pred = self.predictor_proj(x_pred)
+ x_context = self.predictor_proj_context(x_context)
+ return x_pred, x_context
+
+
+def vit_predictor(**kwargs):
+ model = VisionTransformerPredictor(
+ mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
+ )
+ return model
diff --git a/app/vjepa_2_1/models/utils/masks_dist.py b/app/vjepa_2_1/models/utils/masks_dist.py
new file mode 100644
index 00000000..216f5738
--- /dev/null
+++ b/app/vjepa_2_1/models/utils/masks_dist.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+
+def _get_frame_pos(ids, H_patches=None, W_patches=None, grid_size=None):
+ if H_patches is None or W_patches is None:
+ tokens_per_frame = int(grid_size * grid_size)
+ else:
+ tokens_per_frame = int(H_patches * W_patches)
+ return ids // tokens_per_frame
+
+
+def _get_height_pos(ids, H_patches=None, W_patches=None, grid_size=None):
+ # Remove frame component from ids
+ if H_patches is None or W_patches is None:
+ tokens_per_frame = int(grid_size * grid_size)
+ tokens_per_row = grid_size
+ else:
+ tokens_per_frame = int(H_patches * W_patches)
+ tokens_per_row = W_patches
+ frame_ids = _get_frame_pos(ids, H_patches, W_patches, grid_size)
+ ids = ids - tokens_per_frame * frame_ids
+ # --
+ return ids // tokens_per_row
+
+
+def separate_positions(ids, H_patches=None, W_patches=None, grid_size=None):
+ if H_patches is None or W_patches is None:
+ tokens_per_frame = int(grid_size * grid_size)
+ tokens_per_row = grid_size
+ else:
+ tokens_per_frame = int(H_patches * W_patches)
+ tokens_per_row = W_patches
+ frame_ids = _get_frame_pos(ids, H_patches, W_patches, grid_size)
+ # --
+ height_ids = _get_height_pos(ids, H_patches, W_patches, grid_size)
+ # --
+ # Remove frame component from ids (1st term) and height component (2nd term)
+ width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
+ return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
+
+
+def compute_mask_distance(masks_pred, masks_enc, grid_size, offset_context_loss):
+ # masks_pred: [fpc][mask] where each mask is [B, N_pred]
+ # masks_enc: [fpc][mask] where each mask is [B, N_enc]
+ distances = []
+ for masks_pred_i, masks_enc_i in zip(masks_pred, masks_enc):
+ row_distances = []
+ for masks_pred_ij, masks_enc_ij in zip(masks_pred_i, masks_enc_i):
+ N_enc_tokens = masks_enc_ij.shape[1]
+ d_enc, h_enc, w_enc = separate_positions(
+ masks_enc_ij, grid_size=grid_size
+ ) # (BS, N_enc)
+ d_pred, h_pred, w_pred = separate_positions(
+ masks_pred_ij, grid_size=grid_size
+ ) # (BS, N_pred)
+ pred = torch.stack([d_pred, h_pred, w_pred], dim=-1) # (BS, N_pred, 3)
+ enc_distances = []
+ for enc_token in range(N_enc_tokens):
+ enc_position = torch.stack(
+ [d_enc[:, enc_token], h_enc[:, enc_token], w_enc[:, enc_token]],
+ dim=-1,
+ ).unsqueeze(
+ 1
+ ) # (BS, 1, 3)
+ dist = torch.cdist(enc_position, pred, p=2) # (BS, N_enc)
+ dmin, argmin = dist.min(dim=-1)
+ if offset_context_loss:
+ coeff = grid_size // 16 # Which is the default value of grid_size
+ dmin = dmin * (1.0 / coeff)
+ dmin = dmin**0.5 # We want that it decreases less agressive
+ enc_distances.append(dmin)
+ enc_distances = torch.stack(enc_distances, dim=-1).squeeze() # (BS, N_enc)
+ row_distances.append(enc_distances)
+ distances.append(row_distances)
+ return distances
diff --git a/app/vjepa_2_1/models/utils/modules.py b/app/vjepa_2_1/models/utils/modules.py
new file mode 100644
index 00000000..fedf8a2e
--- /dev/null
+++ b/app/vjepa_2_1/models/utils/modules.py
@@ -0,0 +1,544 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.models.layers import drop_path
+
+
+def rotate_queries_or_keys(x, pos, n_registers, has_cls_first):
+ B, num_heads, N, D = x.size()
+ assert (
+ D % 2 == 0
+ ), "Embedding dimension must be a multiple of 2 for block matrix rotation"
+
+ n_cls = 1 if has_cls_first else 0
+ start_ctx = n_cls
+ end_ctx = N - n_registers
+
+ x_cls = x[..., :n_cls, :] if n_cls else None
+ x_ctx = x[..., start_ctx:end_ctx, :]
+ x_reg = x[..., end_ctx:, :] if n_registers > 0 else None
+
+ omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
+ omega /= D / 2.0
+ omega = 1.0 / 10000**omega
+ freq = torch.einsum("..., f -> ... f", pos, omega)
+
+ emb_sin = freq.sin()
+ emb_cos = freq.cos()
+
+ emb_sin = emb_sin.repeat_interleave(2, dim=-1)
+ emb_cos = emb_cos.repeat_interleave(2, dim=-1)
+
+ y = x_ctx.unflatten(-1, (-1, 2))
+ y1, y2 = y.unbind(dim=-1)
+ y = torch.stack((-y2, y1), dim=-1)
+ y = y.flatten(-2)
+
+ out_ctx = (x_ctx * emb_cos) + (y * emb_sin)
+
+ parts = []
+ if n_cls:
+ parts.append(x_cls)
+ parts.append(out_ctx)
+ if n_registers:
+ parts.append(x_reg)
+ out = torch.cat(parts, dim=-2)
+
+ return out
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.SiLU,
+ drop=0.0,
+ wide_silu=True,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ swiglu_hidden_features = hidden_features = hidden_features or in_features
+ if wide_silu:
+ swiglu_hidden_features = int(2 * hidden_features / 3)
+ align_as = 8
+ swiglu_hidden_features = (
+ (swiglu_hidden_features + align_as - 1) // align_as * align_as
+ )
+ self.fc1 = nn.Linear(in_features, swiglu_hidden_features)
+ self.fc2 = nn.Linear(in_features, swiglu_hidden_features)
+ self.act = act_layer()
+ self.fc3 = nn.Linear(swiglu_hidden_features, out_features)
+
+ def forward(self, x):
+ x1 = self.fc1(x)
+ x2 = self.fc2(x)
+ hidden = F.silu(x1) * x2
+ return self.fc3(hidden)
+
+
+class RoPEAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ use_sdpa=True,
+ grid_size=14,
+ is_causal=False,
+ n_registers=0,
+ has_cls_first=False,
+ interpolate_rope=False,
+ patch_size=16,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop_prob = proj_drop
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.use_sdpa = use_sdpa
+ self.d_dim = int(2 * ((head_dim // 3) // 2))
+ self.h_dim = int(2 * ((head_dim // 3) // 2))
+ self.w_dim = int(2 * ((head_dim // 3) // 2))
+ self.grid_size = grid_size
+ self.is_causal = is_causal
+ self.n_registers = n_registers
+ self.has_cls_first = has_cls_first
+ self.interpolate_rope = interpolate_rope
+ self.pretrained_patch_size = patch_size
+ if patch_size == 14:
+ self.pretrained_grid_size = int(252 / patch_size)
+ elif patch_size == 16:
+ self.pretrained_grid_size = int(256 / patch_size)
+
+ def _get_frame_pos(self, ids, H_patches=None, W_patches=None):
+ if H_patches is None or W_patches is None:
+ tokens_per_frame = int(self.grid_size * self.grid_size)
+ else:
+ tokens_per_frame = int(H_patches * W_patches)
+ return ids // tokens_per_frame
+
+ def _get_height_pos(self, ids, H_patches=None, W_patches=None):
+ if H_patches is None or W_patches is None:
+ tokens_per_frame = int(self.grid_size * self.grid_size)
+ tokens_per_row = self.grid_size
+ else:
+ tokens_per_frame = int(H_patches * W_patches)
+ tokens_per_row = W_patches
+ frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
+ ids = ids - tokens_per_frame * frame_ids
+ return ids // tokens_per_row
+
+ def separate_positions(self, ids, H_patches=None, W_patches=None):
+ if H_patches is None or W_patches is None:
+ tokens_per_frame = int(self.grid_size * self.grid_size)
+ tokens_per_row = self.grid_size
+ else:
+ tokens_per_frame = int(H_patches * W_patches)
+ tokens_per_row = W_patches
+ frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
+ height_ids = self._get_height_pos(ids, H_patches, W_patches)
+ width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
+ return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
+
+ def forward(
+ self,
+ x,
+ mask=None,
+ T=None,
+ H_patches=None,
+ W_patches=None,
+ return_attn=False,
+ ):
+ B, N, C = x.size()
+ N_ctx = N - self.n_registers
+ grid_depth = int(N_ctx // (self.grid_size * self.grid_size))
+
+ qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ if mask is not None:
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
+ d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)
+ else:
+ if T is None or H_patches is None or W_patches is None:
+ mask = torch.arange(
+ int(grid_depth * self.grid_size * self.grid_size), device=x.device
+ )
+ else:
+ mask = torch.arange(int(T * H_patches * W_patches), device=x.device)
+ d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)
+
+ if self.interpolate_rope:
+ if H_patches is None:
+ H_patches = int(self.grid_size)
+ if W_patches is None:
+ W_patches = int(self.grid_size)
+ h_mask = h_mask * (self.pretrained_grid_size - 1) / (H_patches - 1)
+ w_mask = w_mask * (self.pretrained_grid_size - 1) / (W_patches - 1)
+
+ s = 0
+ qd = rotate_queries_or_keys(
+ q[..., s : s + self.d_dim],
+ pos=d_mask,
+ n_registers=self.n_registers,
+ has_cls_first=self.has_cls_first,
+ )
+ kd = rotate_queries_or_keys(
+ k[..., s : s + self.d_dim],
+ pos=d_mask,
+ n_registers=self.n_registers,
+ has_cls_first=self.has_cls_first,
+ )
+ s += self.d_dim
+ qh = rotate_queries_or_keys(
+ q[..., s : s + self.h_dim],
+ pos=h_mask,
+ n_registers=self.n_registers,
+ has_cls_first=self.has_cls_first,
+ )
+ kh = rotate_queries_or_keys(
+ k[..., s : s + self.h_dim],
+ pos=h_mask,
+ n_registers=self.n_registers,
+ has_cls_first=self.has_cls_first,
+ )
+ s += self.h_dim
+ qw = rotate_queries_or_keys(
+ q[..., s : s + self.w_dim],
+ pos=w_mask,
+ n_registers=self.n_registers,
+ has_cls_first=self.has_cls_first,
+ )
+ kw = rotate_queries_or_keys(
+ k[..., s : s + self.w_dim],
+ pos=w_mask,
+ n_registers=self.n_registers,
+ has_cls_first=self.has_cls_first,
+ )
+ s += self.w_dim
+
+ if s < self.head_dim:
+ qr = q[..., s:]
+ kr = k[..., s:]
+ q = torch.cat([qd, qh, qw, qr], dim=-1)
+ k = torch.cat([kd, kh, kw, kr], dim=-1)
+ else:
+ q = torch.cat([qd, qh, qw], dim=-1)
+ k = torch.cat([kd, kh, kw], dim=-1)
+
+ if self.use_sdpa:
+ with torch.backends.cuda.sdp_kernel():
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal
+ )
+ attn = None
+ else:
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ if return_attn:
+ return x, attn
+ else:
+ return x, None
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ use_sdpa=True,
+ is_causal=False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop_prob = proj_drop
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.use_sdpa = use_sdpa
+ self.is_causal = is_causal
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ if self.use_sdpa:
+ with torch.backends.cuda.sdp_kernel():
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal
+ )
+ attn = None
+ else:
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ wide_silu=True,
+ norm_layer=nn.LayerNorm,
+ use_sdpa=True,
+ is_causal=False,
+ grid_size=16,
+ use_rope=False,
+ n_registers=0,
+ has_cls_first=False,
+ interpolate_rope=False,
+ patch_size=16,
+ **kwargs,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.use_rope = use_rope
+ if use_rope:
+ self.attn = RoPEAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ use_sdpa=use_sdpa,
+ is_causal=is_causal,
+ grid_size=grid_size,
+ proj_drop=drop,
+ n_registers=n_registers,
+ has_cls_first=has_cls_first,
+ interpolate_rope=interpolate_rope,
+ patch_size=patch_size,
+ )
+ else:
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ use_sdpa=use_sdpa,
+ is_causal=is_causal,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if act_layer is nn.SiLU:
+ self.mlp = SwiGLUFFN(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ wide_silu=wide_silu,
+ drop=drop,
+ )
+ else:
+ self.mlp = MLP(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(
+ self,
+ x,
+ mask=None,
+ T=None,
+ H_patches=None,
+ W_patches=None,
+ return_attn=False,
+ mode="video",
+ ):
+ if self.use_rope:
+ y, attn = self.attn(
+ self.norm1(x),
+ mask=mask,
+ T=T,
+ H_patches=H_patches,
+ W_patches=W_patches,
+ return_attn=return_attn,
+ )
+ else:
+ y = self.attn(self.norm1(x))
+ attn = None
+ x = x + self.drop_path(y)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ if return_attn:
+ return x, attn
+ else:
+ return x, None
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, num_heads=12, qkv_bias=False, use_sdpa=True):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias)
+ self.use_sdpa = use_sdpa
+
+ def forward(self, q, x):
+ B, n, C = q.shape
+ q = (
+ self.q(q)
+ .reshape(B, n, self.num_heads, C // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+
+ B, N, C = x.shape
+ kv = (
+ self.kv(x)
+ .reshape(B, N, 2, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ k, v = kv[0], kv[1]
+
+ if self.use_sdpa:
+ with torch.backends.cuda.sdp_kernel():
+ q = F.scaled_dot_product_attention(q, k, v)
+ else:
+ xattn = (q @ k.transpose(-2, -1)) * self.scale
+ xattn = xattn.softmax(dim=-1)
+ q = xattn @ v
+
+ q = q.transpose(1, 2).reshape(B, n, C)
+ return q
+
+
+class CrossAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = MLP(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer
+ )
+
+ def forward(self, q, x):
+ y = self.xattn(q, self.norm1(x))
+ q = q + y
+ q = q + self.mlp(self.norm2(q))
+ return q
+
+
+class Lambda_LinearWarmupHold:
+ """
+ Linear warmup from 0 to lambda_value between [start_iter, end_iter],
+ 0 before start_iter and constant (=lambda_value) from end_iter onwards.
+ """
+
+ def __init__(
+ self, lambda_value: float, start_iter: int = 15_000, end_iter: int = 30_000
+ ):
+ assert end_iter > start_iter, "end_iter must be > start_iter"
+ self.lambda_value = float(lambda_value)
+ self.start = int(start_iter)
+ self.end = int(end_iter)
+ self.span = self.end - self.start
+
+ def value(self, global_iter: int) -> float:
+ if global_iter < self.start:
+ return 0.0
+ if global_iter >= self.end:
+ return self.lambda_value
+ alpha = (global_iter - self.start) / self.span
+ return self.lambda_value * alpha
diff --git a/app/vjepa_2_1/models/utils/patch_embed.py b/app/vjepa_2_1/models/utils/patch_embed.py
new file mode 100644
index 00000000..9dc34352
--- /dev/null
+++ b/app/vjepa_2_1/models/utils/patch_embed.py
@@ -0,0 +1,72 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+from einops import rearrange
+
+import torch.nn as nn
+
+
+class AudioPatchEmbed(nn.Module):
+ """
+ Audio to Patch Embedding
+ """
+
+ def __init__(self, freq_bands=128, tubelet_size=2, embed_dim=768):
+ super().__init__()
+ self.freq_bands = freq_bands
+ self.tubelet_size = tubelet_size
+ self.proj = nn.Conv2d(1, embed_dim, kernel_size=(freq_bands, tubelet_size), stride=(freq_bands, tubelet_size))
+
+ def forward(self, x):
+ x = rearrange(x, "b t c f -> b c f t")
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding
+ """
+
+ def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class PatchEmbed3D(nn.Module):
+ """
+ Image to Patch Embedding
+ """
+
+ def __init__(
+ self,
+ patch_size=16,
+ tubelet_size=2,
+ in_chans=3,
+ embed_dim=768,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans,
+ out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size, patch_size),
+ stride=(tubelet_size, patch_size, patch_size),
+ )
+
+ def forward(self, x, **kwargs):
+ B, C, T, H, W = x.shape
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
diff --git a/app/vjepa_2_1/models/utils/pos_embs.py b/app/vjepa_2_1/models/utils/pos_embs.py
new file mode 100644
index 00000000..85036dd3
--- /dev/null
+++ b/app/vjepa_2_1/models/utils/pos_embs.py
@@ -0,0 +1,95 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import numpy as np
+
+
+def get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False):
+ """
+ grid_size: int of the grid height and width
+ grid_depth: int of the grid depth
+ returns:
+ pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token)
+ or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token)
+ """
+ grid_d = np.arange(grid_depth, dtype=float)
+ grid_h = np.arange(grid_size, dtype=float)
+ grid_w = np.arange(grid_size, dtype=float)
+ grid_h, grid_d, grid_w = np.meshgrid(
+ grid_h, grid_d, grid_w
+ ) # order of meshgrid is very important for indexing as [d,h,w]
+
+ if not uniform_power:
+ h_embed_dim = embed_dim // 4
+ w_embed_dim = embed_dim // 4
+ d_embed_dim = embed_dim // 2
+ else:
+ h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2)
+
+ emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1)
+ emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2)
+ emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3)
+ pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1)
+ pos_embed = pos_embed[:, :embed_dim]
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ returns:
+ pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token)
+ or [1+grid_size*grid_size, embed_dim] (w/ cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=float)
+ grid_w = np.arange(grid_size, dtype=float)
+ grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w]
+
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2)
+ pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ embed_dim: output dimension for each position
+ grid_size: int of the grid length
+ returns:
+ pos_embed: [grid_size, embed_dim] (w/o cls_token)
+ or [1+grid_size, embed_dim] (w/ cls_token)
+ """
+ grid = np.arange(grid_size, dtype=float)
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ returns: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
diff --git a/app/vjepa_2_1/models/vision_transformer.py b/app/vjepa_2_1/models/vision_transformer.py
new file mode 100644
index 00000000..c8797fbb
--- /dev/null
+++ b/app/vjepa_2_1/models/vision_transformer.py
@@ -0,0 +1,608 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from src.masks.utils import apply_masks
+from src.utils.tensors import trunc_normal_
+
+from app.vjepa_2_1.models.utils.modules import Block
+from app.vjepa_2_1.models.utils.patch_embed import PatchEmbed, PatchEmbed3D
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer"""
+
+ def __init__(
+ self,
+ img_size=(224, 224),
+ patch_size=16,
+ num_frames=1,
+ tubelet_size=2,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ init_std=0.02,
+ out_layers=None,
+ uniform_power=False,
+ use_silu=False,
+ wide_silu=True,
+ use_sdpa=True,
+ use_activation_checkpointing=False,
+ is_causal=False,
+ use_rope=False,
+ init_type: str = "default",
+ handle_nonsquare_inputs=True,
+ img_temporal_dim_size=None,
+ n_registers=0,
+ has_cls_first=False,
+ interpolate_rope=False,
+ modality_embedding=True,
+ n_output_distillation=4,
+ **kwargs,
+ ):
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.out_layers = out_layers
+ self.init_type = init_type
+ self.handle_nonsquare_inputs = handle_nonsquare_inputs
+ self.img_temporal_dim_size = img_temporal_dim_size
+
+ if type(img_size) is int:
+ img_size = (img_size, img_size)
+ self.img_height, self.img_width = img_size
+ self.patch_size = patch_size
+ self.num_frames = num_frames
+ self.tubelet_size = tubelet_size
+ self.is_video = num_frames > 1
+
+ self.use_activation_checkpointing = use_activation_checkpointing
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ if self.is_video:
+ self.patch_embed = PatchEmbed3D(
+ patch_size=patch_size,
+ tubelet_size=tubelet_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ self.num_patches = (
+ (num_frames // tubelet_size)
+ * (img_size[0] // patch_size)
+ * (img_size[1] // patch_size)
+ )
+ else:
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
+ )
+ self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+
+ if self.img_temporal_dim_size is not None:
+ if not isinstance(self.img_temporal_dim_size, int):
+ raise ValueError(
+ f"img_temporal_dim_size must be an int, got {self.img_temporal_dim_size}"
+ )
+ self.patch_embed_img = PatchEmbed3D(
+ patch_size=patch_size,
+ tubelet_size=1,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ else:
+ self.patch_embed_img = None
+
+ self.uniform_power = uniform_power
+
+ self.use_rope = use_rope
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ use_rope=use_rope,
+ grid_size=img_size[0] // patch_size,
+ grid_depth=num_frames // tubelet_size,
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ use_sdpa=use_sdpa,
+ is_causal=is_causal,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=nn.SiLU if use_silu else nn.GELU,
+ wide_silu=wide_silu,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ n_registers=n_registers,
+ has_cls_first=has_cls_first,
+ interpolate_rope=interpolate_rope,
+ patch_size=patch_size,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ self.attn_out = False
+ self.init_std = init_std
+ self.apply(self._init_weights)
+ self._rescale_blocks()
+
+ if depth == 12:
+ self.hierarchical_layers = [2, 5, 8, 11]
+ if n_output_distillation == 4:
+ self.out_layers_distillation = [2, 5, 8, 11]
+ elif n_output_distillation == 1:
+ self.out_layers_distillation = [11]
+
+ elif depth == 24:
+ self.hierarchical_layers = [5, 11, 17, 23]
+ if n_output_distillation == 4:
+ self.out_layers_distillation = [5, 11, 17, 23]
+ elif n_output_distillation == 1:
+ self.out_layers_distillation = [23]
+
+ elif depth == 40:
+ self.hierarchical_layers = [9, 19, 29, 39]
+ if n_output_distillation == 4:
+ self.out_layers_distillation = [9, 19, 29, 39]
+ elif n_output_distillation == 1:
+ self.out_layers_distillation = [39]
+
+ elif depth == 48:
+ self.hierarchical_layers = [11, 23, 37, 47]
+ if n_output_distillation == 4:
+ self.out_layers_distillation = [11, 23, 37, 47]
+ elif n_output_distillation == 1:
+ self.out_layers_distillation = [47]
+ else:
+ print("Check the code! ;)")
+ self.norms_block = nn.ModuleList(
+ [norm_layer(embed_dim) for _ in range(len(self.hierarchical_layers))]
+ )
+
+ self.cls_token = None
+ self.return_hierarchical = False
+
+ self.modality_embedding = False
+ if modality_embedding:
+ self.img_mod_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.video_mod_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ nn.init.normal_(self.img_mod_embed, std=1e-6)
+ nn.init.normal_(self.video_mod_embed, std=1e-6)
+ self.modality_embedding = True
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ return
+ if self.init_type == "default":
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=self.init_std)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ trunc_normal_(m.weight, std=self.init_std)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv3d):
+ trunc_normal_(m.weight, std=self.init_std)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ elif self.init_type == "xavier_uniform":
+ if (
+ isinstance(m, nn.Linear)
+ or isinstance(m, nn.Conv2d)
+ or isinstance(m, nn.Conv3d)
+ ):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif self.init_type == "xavier_normal":
+ if (
+ isinstance(m, nn.Linear)
+ or isinstance(m, nn.Conv2d)
+ or isinstance(m, nn.Conv3d)
+ ):
+ nn.init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ else:
+ raise ValueError(f"Unknown init type {self.init_type}")
+
+ def _rescale_blocks(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ def no_weight_decay(self):
+ return {}
+
+ def check_temporal_dim(self, shape) -> bool:
+ if self.img_temporal_dim_size is not None:
+ if shape[2] == self.img_temporal_dim_size:
+ return True
+
+ return False
+
+ def forward(self, x, masks=None, training=False):
+ """
+ :param x: input image/video
+ :param masks: indices of patch tokens to mask (remove)
+ """
+ if masks is not None and not isinstance(masks, list):
+ masks = [masks]
+
+ if x.ndim == 4:
+ _, _, H, W = x.shape
+ T = 1
+ elif x.ndim == 5:
+ _, _, T, H, W = x.shape
+ if self.check_temporal_dim(x.shape):
+ T = T // 1
+ else:
+ T = T // self.tubelet_size
+
+ H_patches = H // self.patch_size
+ W_patches = W // self.patch_size
+ if not self.handle_nonsquare_inputs:
+ T = H_patches = W_patches = None
+
+ if not self.use_rope:
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+
+ if self.check_temporal_dim(x.shape):
+ assert self.patch_embed_img is not None
+ x = self.patch_embed_img(x)
+ mode = "img"
+ if self.modality_embedding:
+ x += self.img_mod_embed.repeat(x.shape[0], 1, 1)
+ else:
+ x = self.patch_embed(x)
+ mode = "video"
+ if self.modality_embedding:
+ x += self.video_mod_embed.repeat(x.shape[0], 1, 1)
+
+ if not self.use_rope:
+ x += pos_embed
+
+ if masks is not None:
+ x = apply_masks(x, masks)
+ masks = torch.cat(masks, dim=0)
+
+ outs = []
+ hier = []
+ for i, blk in enumerate(self.blocks):
+ if self.use_activation_checkpointing:
+ x, attn = torch.utils.checkpoint.checkpoint(
+ blk,
+ x,
+ masks,
+ T=T,
+ H_patches=H_patches,
+ W_patches=W_patches,
+ use_reentrant=False,
+ return_attn=self.attn_out,
+ mode=mode,
+ )
+ else:
+ x, attn = blk(
+ x,
+ mask=masks,
+ T=T,
+ H_patches=H_patches,
+ W_patches=W_patches,
+ return_attn=self.attn_out,
+ mode=mode,
+ )
+
+ if self.out_layers is not None and i in self.out_layers:
+ out_idx = self.hierarchical_layers.index(i)
+ out_norm = self.norms_block[out_idx](x)
+ outs.append(out_norm)
+
+ if i in self.out_layers_distillation:
+ out_idx = self.hierarchical_layers.index(i)
+ hier.append(self.norms_block[out_idx](x))
+
+ if self.out_layers is not None:
+ return outs
+
+ if training or self.return_hierarchical:
+ hier = torch.cat(hier, dim=2)
+ return hier
+ else:
+ x = self.norms_block[-1](x)
+ return x
+
+ def interpolate_pos_encoding(self, x, pos_embed):
+
+ _, N, dim = pos_embed.shape
+
+ if self.is_video:
+
+ _, _, T, H, W = x.shape
+ if H == self.img_height and W == self.img_width and T == self.num_frames:
+ return pos_embed
+
+ elif H == self.img_height and W == self.img_width and T < self.num_frames:
+ new_N = int(
+ (T // self.tubelet_size)
+ * (H // self.patch_size)
+ * (W // self.patch_size)
+ )
+ return pos_embed[:, :new_N, :]
+
+ T = T // self.tubelet_size
+ H = H // self.patch_size
+ W = W // self.patch_size
+
+ N_t = self.num_frames // self.tubelet_size
+ N_h = self.img_height // self.patch_size
+ N_w = self.img_width // self.patch_size
+ assert N_h * N_w * N_t == N, "Positional embedding initialized incorrectly"
+
+ scale_factor = (T / N_t, H / N_h, W / N_w)
+
+ pos_embed = nn.functional.interpolate(
+ pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3),
+ scale_factor=scale_factor,
+ mode="trilinear",
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
+ return pos_embed
+
+ else:
+
+ _, _, H, W = x.shape
+ if H == self.img_height and W == self.img_width:
+ return pos_embed
+
+ npatch = (H // self.patch_size) * (W // self.patch_size)
+ scale_factor = math.sqrt(npatch / N)
+
+ pos_embed = nn.functional.interpolate(
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
+ 0, 3, 1, 2
+ ),
+ scale_factor=scale_factor,
+ mode="bicubic",
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return pos_embed
+
+
+def vit_synthetic(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1,
+ depth=1,
+ num_heads=1,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_tiny(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=192,
+ depth=12,
+ num_heads=3,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_small(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large_rope(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ use_rope=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_huge(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_huge_rope(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ use_rope=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1408,
+ depth=40,
+ num_heads=16,
+ mlp_ratio=48 / 11,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant_rope(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1408,
+ depth=40,
+ num_heads=16,
+ mlp_ratio=48 / 11,
+ qkv_bias=True,
+ use_rope=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant_xformers(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1408,
+ depth=40,
+ num_heads=22,
+ mlp_ratio=48 / 11,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant_xformers_rope(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1408,
+ depth=40,
+ num_heads=22,
+ mlp_ratio=48 / 11,
+ qkv_bias=True,
+ use_rope=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_gigantic(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1664,
+ depth=48,
+ num_heads=16,
+ mlp_ratio=64 / 13,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def vit_gigantic_xformers(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1664,
+ depth=48,
+ num_heads=26,
+ mlp_ratio=64 / 13,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+VIT_EMBED_DIMS = {
+ "vit_synthetic": 1,
+ "vit_tiny": 192,
+ "vit_small": 384,
+ "vit_base": 768,
+ "vit_large": 1024,
+ "vit_huge": 1280,
+ "vit_giant": 1408,
+ "vit_gigantic": 1664,
+}
diff --git a/app/vjepa_2_1/train.py b/app/vjepa_2_1/train.py
new file mode 100644
index 00000000..f8fe5949
--- /dev/null
+++ b/app/vjepa_2_1/train.py
@@ -0,0 +1,835 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
+try:
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
+except Exception:
+ pass
+
+import copy
+import gc
+import random
+import time
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+from app.vjepa_2_1.models.utils.masks_dist import compute_mask_distance
+from app.vjepa_2_1.models.utils.modules import Lambda_LinearWarmupHold
+from app.vjepa_2_1.transforms import make_transforms
+from app.vjepa_2_1.utils import (
+ init_opt,
+ init_video_model,
+ load_checkpoint,
+ normalize_nested,
+)
+from src.datasets.data_manager import init_data
+from src.masks.multiseq_multiblock3d import MaskCollator
+from src.masks.utils import apply_masks
+from src.utils.distributed import init_distributed
+from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer
+from torch.nn.parallel import DistributedDataParallel
+
+
+log_timings = True
+log_freq = 10
+CHECKPOINT_FREQ = 1
+GARBAGE_COLLECT_ITR_FREQ = 50
+MAX_REPEAT_COUNTS = 10
+
+_GLOBAL_SEED = 0
+random.seed(_GLOBAL_SEED)
+np.random.seed(_GLOBAL_SEED)
+torch.manual_seed(_GLOBAL_SEED)
+torch.backends.cudnn.benchmark = True
+
+
+logger = get_logger(__name__, force=True)
+
+
+def main(args, resume_preempt=False):
+ # ----------------------------------------------------------------------- #
+ # PASSED IN PARAMS FROM CONFIG FILE
+ # ----------------------------------------------------------------------- #
+
+ # -- META
+ folder = args.get("folder")
+ cfgs_meta = args.get("meta")
+ load_model = cfgs_meta.get("load_checkpoint") or resume_preempt
+ r_file = cfgs_meta.get("read_checkpoint", None)
+ seed = cfgs_meta.get("seed", _GLOBAL_SEED)
+ save_every_freq = cfgs_meta.get("save_every_freq", -1)
+ skip_batches = cfgs_meta.get("skip_batches", -1)
+ use_sdpa = cfgs_meta.get("use_sdpa", False)
+ sync_gc = cfgs_meta.get("sync_gc", False)
+ logger.info(f"LD_PRELOAD: {os.environ.get('LD_PRELOAD')}")
+ which_dtype = cfgs_meta.get("dtype")
+ logger.info(f"{which_dtype=}")
+ if which_dtype.lower() == "bfloat16":
+ dtype = torch.bfloat16
+ mixed_precision = True
+ elif which_dtype.lower() == "float16":
+ dtype = torch.float16
+ mixed_precision = True
+ else:
+ dtype = torch.float32
+ mixed_precision = False
+
+ # -- MASK
+ cfgs_mask = args.get("mask")
+
+ # -- MODEL
+ cfgs_model = args.get("model")
+ compile_model = cfgs_model.get("compile_model", False)
+ use_activation_checkpointing = cfgs_model.get("use_activation_checkpointing", False)
+ model_name = cfgs_model.get("model_name")
+ pred_depth = cfgs_model.get("pred_depth")
+ pred_num_heads = cfgs_model.get("pred_num_heads", None)
+ pred_embed_dim = cfgs_model.get("pred_embed_dim")
+ uniform_power = cfgs_model.get("uniform_power", False)
+ use_mask_tokens = cfgs_model.get("use_mask_tokens", False)
+ zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True)
+ use_rope = cfgs_model.get("use_rope", False)
+ use_silu = cfgs_model.get("use_silu", False)
+ use_pred_silu = cfgs_model.get("use_pred_silu", False)
+ wide_silu = cfgs_model.get("wide_silu", True)
+ is_causal = cfgs_model.get("is_causal", False)
+ pred_is_causal = cfgs_model.get("pred_is_causal", False)
+ init_type = cfgs_model.get("init_type", "default")
+ img_temporal_dim_size = cfgs_model.get("img_temporal_dim_size", None)
+ n_registers = cfgs_model.get("n_registers", 0)
+ has_cls_first = cfgs_model.get("has_cls_first", False)
+ interpolate_rope = cfgs_model.get("interpolate_rope", False)
+ lambda_value_img = cfgs_model.get("lambda_value_img", 0.0)
+ lambda_value_vid = cfgs_model.get("lambda_value_vid", 0.0)
+ n_registers_predictor = cfgs_model.get("n_registers_predictor", 0)
+ lambda_progressive = cfgs_model.get("lambda_progressive", True)
+ normalize_predictor = cfgs_model.get("normalize_predictor", False)
+ modality_embedding = cfgs_model.get("modality_embedding", False)
+ levels_predictor = cfgs_model.get("levels_predictor", 4)
+ if model_name == "vit_large":
+ embed_dim_encoder = 1024
+ elif model_name == "vit_giant_xformers":
+ embed_dim_encoder = 1408
+ elif model_name == "vit_gigantic_xformers":
+ embed_dim_encoder = 1664
+ else:
+ print("Model name not recognized :(")
+
+ # -- DATA
+ cfgs_data = args.get("data")
+ dataset_type = cfgs_data.get("dataset_type", "videodataset")
+ dataset_paths = cfgs_data.get("datasets", [])
+ datasets_weights = cfgs_data.get("datasets_weights")
+ dataset_fpcs = cfgs_data.get("dataset_fpcs")
+ max_num_frames = max(dataset_fpcs)
+ batch_size = cfgs_data.get("batch_size")
+ tubelet_size = cfgs_data.get("tubelet_size")
+ fps = cfgs_data.get("fps")
+ crop_size = cfgs_data.get("crop_size", 224)
+ patch_size = cfgs_data.get("patch_size")
+ grid_size = crop_size // patch_size
+ pin_mem = cfgs_data.get("pin_mem", False)
+ num_workers = cfgs_data.get("num_workers", 1)
+
+ # -- IMG DATA
+ cfgs_img_data = args.get("img_data")
+ img_rank_ratio = 0.25
+ img_mask = None
+ if cfgs_img_data is not None:
+ img_dataset_type = cfgs_img_data.get("dataset_type", "imagenet")
+ img_dataset_paths = cfgs_img_data.get("datasets", [])
+ img_dataset_weights = cfgs_img_data.get("datasets_weights", [])
+ img_dataset_fpcs = cfgs_img_data.get("dataset_fpcs")
+ img_dataset_batch_size = cfgs_img_data.get("batch_size")
+ img_rank_ratio = cfgs_img_data.get("rank_ratio", img_rank_ratio)
+ img_num_workers = cfgs_img_data.get("num_workers", num_workers)
+
+ img_mask = args.get("img_mask", img_mask)
+
+ # -- DATA AUGS
+ cfgs_data_aug = args.get("data_aug")
+ ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3])
+ rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0])
+ motion_shift = cfgs_data_aug.get("motion_shift", False)
+ reprob = cfgs_data_aug.get("reprob", 0.0)
+ use_aa = cfgs_data_aug.get("auto_augment", False)
+
+ # -- LOSS
+ cfgs_loss = args.get("loss")
+ loss_exp = cfgs_loss.get("loss_exp")
+ shift_by_n = cfgs_loss.get("shift_by_n")
+ predict_all = cfgs_loss.get("predict_all", True)
+ weight_distance_loss = cfgs_loss.get("weight_distance_loss", False)
+ offset_context_loss = cfgs_loss.get("offset_context_loss", False)
+
+ # -- OPTIMIZATION
+ cfgs_opt = args.get("optimization")
+ is_anneal = cfgs_opt.get("is_anneal", False)
+ anneal_ckpt = cfgs_opt.get("anneal_ckpt", None)
+ if is_anneal and anneal_ckpt is None:
+ raise ValueError("Must specify anneal_ckpt if is_anneal is True")
+ resume_anneal = cfgs_opt.get("resume_anneal", False) or (
+ is_anneal and resume_preempt
+ )
+ ipe = cfgs_opt.get("ipe", None)
+ ipe_scale = cfgs_opt.get("ipe_scale", 1.0)
+ wd = float(cfgs_opt.get("weight_decay"))
+ final_wd = float(cfgs_opt.get("final_weight_decay"))
+ num_epochs = cfgs_opt.get("epochs")
+ warmup = cfgs_opt.get("warmup")
+ start_lr = cfgs_opt.get("start_lr")
+ lr = cfgs_opt.get("lr")
+ final_lr = cfgs_opt.get("final_lr")
+ ema = cfgs_opt.get("ema")
+ use_radamw = cfgs_opt.get("use_radamw", False)
+ betas = cfgs_opt.get("betas", (0.9, 0.999))
+ eps = cfgs_opt.get("eps", 1.0e-8)
+ loss_reg_std_mult = cfgs_opt.get("loss_reg_std_mult", None)
+ loss_reg_num_tracking_steps = cfgs_opt.get("loss_reg_num_tracking_steps", 300)
+ loss_reg_min_epoch = cfgs_opt.get("loss_reg_min_epoch", 50)
+ if loss_reg_std_mult is not None:
+ logger.info("Loss regulation activated")
+ # ----------------------------------------------------------------------- #
+
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.backends.cudnn.benchmark = True
+ try:
+ mp.set_start_method("spawn")
+ except Exception:
+ pass
+
+ # -- init torch distributed backend
+ world_size, rank = init_distributed()
+ data_world_size, data_rank = world_size, rank
+ logger.info(f"Initialized (rank/world-size) {rank}/{world_size}")
+ img_world_size = 0
+
+ # make adjustments to batch size for image data
+ model_fpcs = dataset_fpcs
+ model_cfgs_mask = cfgs_mask
+ model_tubelet_size = tubelet_size
+ if cfgs_img_data is not None:
+ img_world_size = int(world_size * img_rank_ratio)
+ num_video_ranks = world_size - img_world_size
+ img_total_batch_size = img_dataset_batch_size * world_size
+ video_total_batch_size = batch_size * world_size
+
+ if img_total_batch_size % img_world_size != 0:
+ raise ValueError(
+ f"img_total_batch_size ({img_total_batch_size}) must be divisible by num_img_ranks ({img_world_size})"
+ )
+ if video_total_batch_size % num_video_ranks != 0:
+ raise ValueError(
+ f"video_total_batch_size ({video_total_batch_size}) must be divisible by num_video_ranks ({num_video_ranks})"
+ )
+
+ # img_dataset_batch_size = img_total_batch_size // img_world_size
+ batch_size = video_total_batch_size // num_video_ranks
+
+ if rank < int(world_size * img_rank_ratio):
+ crop_size = cfgs_img_data.get("crop_size", 512)
+ grid_size = crop_size // patch_size
+
+ if rank < int(world_size * img_rank_ratio):
+ logger.info(
+ f"On rank {rank}, updating dataset with dataset type {img_dataset_type}"
+ )
+ if img_temporal_dim_size is not None:
+ if img_dataset_fpcs[0] != 1:
+ raise NotImplementedError(
+ "Image loader only supports 1 frame per clip with img_temporal_dim_size=1"
+ )
+ tubelet_size = 1
+ else:
+ tubelet_size = tubelet_size
+
+ dataset_type = img_dataset_type
+ dataset_paths = img_dataset_paths
+ datasets_weights = img_dataset_weights
+ dataset_fpcs = img_dataset_fpcs
+ batch_size = img_dataset_batch_size
+ num_workers = img_num_workers
+ if img_mask is not None:
+ logger.info("Using image mask")
+ cfgs_mask = img_mask
+
+ data_rank = rank
+ data_world_size = img_world_size
+ lambda_value = lambda_value_img # We select a different lambda value depending on video vs. image
+ else:
+ data_rank = rank - img_world_size
+ data_world_size = world_size - img_world_size
+ lambda_value = lambda_value_vid # We select a different lambda value depending on video vs. image
+
+ logger.info(
+ f"For rank {rank} with world size {world_size}, "
+ f"we have total image batch size {img_total_batch_size}, total video batch size {video_total_batch_size}, "
+ f"image ranks: {img_world_size}, video ranks: {num_video_ranks}, "
+ f"using the following params: "
+ f"dataset_type: {dataset_type}, "
+ f"dataset_paths: {dataset_paths}, "
+ f"datasets_weights: {datasets_weights}, "
+ f"dataset_fpcs: {dataset_fpcs}, "
+ f"batch_size: {batch_size}, "
+ f"num_workers: {num_workers}, "
+ f"data_rank: {data_rank}, "
+ f"data_world_size: {data_world_size}"
+ f"lambda_value for the context loss: {lambda_value}"
+ )
+ else:
+ lambda_value = lambda_value_vid
+
+ # -- set device
+ if not torch.cuda.is_available():
+ device = torch.device("cpu")
+ else:
+ device = torch.device("cuda:0")
+ torch.cuda.set_device(device)
+
+ # -- log/checkpointing paths
+ log_file = os.path.join(folder, f"log_r{rank}.csv")
+ latest_file = "latest.pth.tar"
+ latest_path = os.path.join(folder, latest_file)
+
+ load_path = None
+ if load_model:
+ if is_anneal:
+ if os.path.exists(latest_path) and resume_anneal:
+ load_path = latest_path
+ else:
+ load_path = anneal_ckpt
+ resume_anneal = False
+ else:
+ load_path = r_file if r_file is not None else latest_path
+ if not os.path.exists(load_path):
+ load_path = None
+ load_model = False
+
+ # -- make csv_logger
+ csv_logger = CSVLogger(
+ log_file,
+ ("%d", "epoch"),
+ ("%d", "itr"),
+ ("%.5f", "loss"),
+ ("%d", "iter-time(ms)"),
+ ("%d", "gpu-time(ms)"),
+ ("%d", "dataload-time(ms)"),
+ )
+
+ # -- init model
+ encoder, predictor = init_video_model(
+ uniform_power=uniform_power,
+ use_mask_tokens=use_mask_tokens,
+ num_mask_tokens=int(len(model_cfgs_mask) * len(model_fpcs)),
+ zero_init_mask_tokens=zero_init_mask_tokens,
+ device=device,
+ patch_size=patch_size,
+ max_num_frames=max_num_frames,
+ tubelet_size=model_tubelet_size,
+ model_name=model_name,
+ crop_size=crop_size,
+ pred_depth=pred_depth,
+ pred_num_heads=pred_num_heads,
+ pred_embed_dim=pred_embed_dim,
+ is_causal=is_causal,
+ pred_is_causal=pred_is_causal,
+ use_sdpa=use_sdpa,
+ use_silu=use_silu,
+ use_pred_silu=use_pred_silu,
+ wide_silu=wide_silu,
+ use_rope=use_rope,
+ use_activation_checkpointing=use_activation_checkpointing,
+ return_all_tokens=predict_all,
+ chop_last_n_tokens=shift_by_n,
+ init_type=init_type,
+ img_temporal_dim_size=img_temporal_dim_size,
+ n_registers=n_registers,
+ n_registers_predictor=n_registers_predictor,
+ has_cls_first=has_cls_first,
+ interpolate_rope=interpolate_rope,
+ modality_embedding=modality_embedding,
+ )
+ target_encoder = copy.deepcopy(encoder)
+
+ if compile_model:
+ logger.info("Compiling encoder, target_encoder, and predictor.")
+ torch._dynamo.config.optimize_ddp = False
+ encoder.compile()
+ target_encoder.compile()
+ predictor.compile()
+
+ mask_collator = MaskCollator(
+ cfgs_mask=cfgs_mask,
+ dataset_fpcs=dataset_fpcs,
+ crop_size=crop_size,
+ patch_size=patch_size,
+ tubelet_size=tubelet_size,
+ )
+
+ transform = make_transforms(
+ random_horizontal_flip=True,
+ random_resize_aspect_ratio=ar_range,
+ random_resize_scale=rr_scale,
+ reprob=reprob,
+ auto_augment=use_aa,
+ motion_shift=motion_shift,
+ crop_size=crop_size,
+ )
+
+ # -- init data-loaders/samplers
+ (unsupervised_loader, unsupervised_sampler) = init_data(
+ data=dataset_type,
+ root_path=dataset_paths,
+ batch_size=batch_size,
+ training=True,
+ # clip_len=clip_len,
+ dataset_fpcs=dataset_fpcs,
+ fps=fps,
+ transform=transform,
+ rank=data_rank,
+ world_size=data_world_size,
+ datasets_weights=datasets_weights,
+ collator=mask_collator,
+ num_workers=num_workers,
+ pin_mem=pin_mem,
+ log_dir=None,
+ )
+ try:
+ _dlen = len(unsupervised_loader)
+ except Exception:
+ try:
+ _dlen = unsupervised_loader.num_batches
+ except Exception:
+ _dlen = -1
+ if ipe is None:
+ ipe = _dlen
+ logger.info(f"Using batch size of {batch_size}, fpcs of {dataset_fpcs}")
+ logger.info(f"iterations per epoch/dataset length: {ipe}/{_dlen}")
+
+ # zizi
+
+ # -- init optimizer and scheduler
+ optimizer, scaler, scheduler, wd_scheduler = init_opt(
+ is_anneal=is_anneal,
+ encoder=encoder,
+ predictor=predictor,
+ use_radamw=use_radamw,
+ wd=wd,
+ final_wd=final_wd,
+ start_lr=start_lr,
+ ref_lr=lr,
+ final_lr=final_lr,
+ iterations_per_epoch=ipe,
+ warmup=warmup,
+ num_epochs=num_epochs,
+ ipe_scale=ipe_scale,
+ mixed_precision=mixed_precision,
+ betas=betas,
+ eps=eps,
+ )
+ encoder = DistributedDataParallel(encoder, static_graph=True)
+ predictor = DistributedDataParallel(
+ predictor, static_graph=False, find_unused_parameters=True
+ )
+ target_encoder = DistributedDataParallel(target_encoder)
+ for p in target_encoder.parameters():
+ p.requires_grad = False
+
+ # -- momentum schedule
+ momentum_scheduler = (
+ ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
+ for i in range(int(ipe * num_epochs) + 1)
+ )
+ lambda_sched = Lambda_LinearWarmupHold(lambda_value=lambda_value)
+
+ start_epoch = 0
+ # -- load training checkpoint
+ print("Loadind checkpoint from: ", load_path)
+ if load_model or os.path.exists(latest_path):
+ (
+ encoder,
+ predictor,
+ target_encoder,
+ optimizer,
+ scaler,
+ start_epoch,
+ ) = load_checkpoint(
+ r_path=load_path,
+ encoder=encoder,
+ predictor=predictor,
+ target_encoder=target_encoder,
+ opt=optimizer,
+ scaler=scaler,
+ is_anneal=is_anneal and not resume_anneal,
+ )
+ if not is_anneal or resume_anneal:
+ for _ in range(start_epoch * ipe):
+ scheduler.step()
+ wd_scheduler.step()
+ next(momentum_scheduler)
+ mask_collator.step()
+
+ def save_checkpoint(epoch, path):
+ if rank != 0:
+ return
+ save_dict = {
+ "encoder": encoder.state_dict(),
+ "predictor": predictor.state_dict(),
+ "opt": optimizer.state_dict(),
+ "scaler": None if scaler is None else scaler.state_dict(),
+ "target_encoder": target_encoder.state_dict(),
+ "epoch": epoch,
+ "loss": loss_meter.avg,
+ "batch_size": batch_size,
+ "world_size": world_size,
+ "lr": lr,
+ }
+ try:
+ torch.save(save_dict, path)
+ except Exception as e:
+ logger.info(f"Encountered exception when saving checkpoint: {e}")
+
+ logger.info("Initializing loader...")
+ unsupervised_sampler.set_epoch(start_epoch)
+ loader = iter(unsupervised_loader)
+
+ if skip_batches > 0:
+ logger.info(f"Skip {skip_batches} batches")
+
+ for itr in range(skip_batches):
+ if itr % 10 == 0:
+ logger.info(f"Skip {itr}/{skip_batches} batches")
+ try:
+ _ = next(loader)
+ except Exception:
+ loader = iter(unsupervised_loader)
+ _ = next(loader)
+
+ if sync_gc:
+ gc.disable()
+ gc.collect()
+
+ trailing_losses = []
+ step_count = 0
+
+ # -- TRAINING LOOP
+ for epoch in range(start_epoch, num_epochs):
+ logger.info("Epoch %d" % (epoch + 1))
+
+ loss_meter = AverageMeter()
+ mask_meters = {fpc: AverageMeter() for fpc in dataset_fpcs}
+ iter_time_meter = AverageMeter()
+ gpu_time_meter = AverageMeter()
+ data_elapsed_time_meter = AverageMeter()
+
+ for itr in range(ipe):
+ itr_start_time = time.time()
+
+ iter_retries = 0
+ iter_successful = False
+ while not iter_successful:
+ try:
+ sample = next(loader)
+ iter_successful = True
+ except StopIteration:
+ logger.info("Exhausted data loaders. Refreshing...")
+ if "airstore" in dataset_type.lower():
+ unsupervised_sampler.increase_epoch()
+ else:
+ unsupervised_sampler.set_epoch(epoch)
+ loader = iter(unsupervised_loader)
+ except Exception as e:
+ NUM_RETRIES = 5
+ if iter_retries < NUM_RETRIES:
+ logger.warning(
+ f"Encountered exception when loading data (num retries {iter_retries}):\n{e}"
+ )
+ iter_retries += 1
+ time.sleep(5)
+ else:
+ raise RuntimeError(
+ f"Exceeded max retries ({NUM_RETRIES}) when loading data."
+ ) from e
+
+ for _fpc_sample in sample:
+ bs, fpc = _fpc_sample[0][-1][0].size()
+ mask_meters[fpc].update(bs / batch_size)
+
+ def load_clips():
+ all_clips, all_masks_enc, all_masks_pred = [], [], []
+ for fpc_sample in sample:
+ udata, masks_enc, masks_pred = fpc_sample
+ all_clips += [udata[0][0].to(device, non_blocking=True)]
+ all_masks_enc += [
+ [m.to(device, non_blocking=True) for m in masks_enc]
+ ]
+ all_masks_pred += [
+ [m.to(device, non_blocking=True) for m in masks_pred]
+ ]
+ return all_clips, all_masks_enc, all_masks_pred
+
+ clips, masks_enc, masks_pred = load_clips()
+ data_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
+
+ if sync_gc and (itr + 1) % GARBAGE_COLLECT_ITR_FREQ == 0:
+ logger.info("Running garbage collection...")
+ gc.collect()
+
+ def train_step():
+ _new_lr = scheduler.step()
+ _new_wd = wd_scheduler.step()
+
+ def forward_target(c, embed_dim=embed_dim_encoder):
+ with torch.no_grad():
+ h = target_encoder(c, gram_mode=False, training_mode=True)
+ new_h = []
+ for hi in h:
+ if levels_predictor > 1:
+ hi_0 = F.layer_norm(hi[:, :, :embed_dim], (embed_dim,))
+ hi_1 = F.layer_norm(
+ hi[:, :, embed_dim : embed_dim * 2],
+ (embed_dim,),
+ )
+ hi_2 = F.layer_norm(
+ hi[:, :, embed_dim * 2 : embed_dim * 3],
+ (embed_dim,),
+ )
+ hi_3 = F.layer_norm(hi[:, :, -embed_dim:], (embed_dim,))
+ hi_norm = torch.cat([hi_0, hi_1, hi_2, hi_3], dim=2)
+ new_h.append(hi_norm)
+ else:
+ new_h.append(F.layer_norm(hi, (hi.size(-1),)))
+ return new_h
+
+ def forward_context(clips, embed_dim=embed_dim_encoder):
+ modality = "video"
+ if img_temporal_dim_size is not None:
+ if clips[0].shape[2] == img_temporal_dim_size:
+ modality = "image"
+ z = encoder(clips, masks_enc, gram_mode=False, training_mode=True)
+ z_pred, z_context = predictor(
+ z, masks_enc, masks_pred, mod=modality
+ )
+ if normalize_predictor:
+ z_pred = normalize_nested(z_pred, embed_dim)
+
+ if predict_all:
+ z_context = normalize_nested(z_context, embed_dim)
+ return z_pred, z_context
+
+ def loss_fn(z, h, masks_to_apply, cls_loss, d_weights):
+ if cls_loss:
+ h_cls = [hi[:, 0].unsqueeze(1) for hi in h]
+ h = [
+ apply_masks(hi[:, 1:], mi, concat=False)
+ for hi, mi in zip(h, masks_to_apply)
+ ]
+ loss, n = 0, 0
+ for zi, hi, hi_cls in zip(z, h, h_cls):
+ for zij, hij in zip(zi, hi):
+ h_term = torch.cat([hi_cls, hij], dim=1)
+ loss += (
+ torch.mean(torch.abs(zij - h_term) ** loss_exp)
+ / loss_exp
+ )
+ n += 1
+
+ loss /= n
+ return loss
+ else:
+ h = [
+ apply_masks(hi, mi, concat=False)
+ for hi, mi in zip(h, masks_to_apply)
+ ]
+
+ if d_weights is not None:
+ loss, n = 0, 0
+ for zi, hi, d_i in zip(z, h, d_weights):
+ for zij, hij, d_ij in zip(zi, hi, d_i):
+ loss_n = torch.abs(zij - hij) ** loss_exp * (
+ 1 / d_ij.unsqueeze(2)
+ )
+ loss += torch.mean(loss_n) / loss_exp
+ n += 1
+ loss /= n
+ return loss
+ else:
+ loss, n = 0, 0
+ for zi, hi in zip(z, h):
+ for zij, hij in zip(zi, hi):
+ loss += (
+ torch.mean(torch.abs(zij - hij) ** loss_exp)
+ / loss_exp
+ )
+ n += 1
+ loss /= n
+ return loss
+
+ # Step 1. Forward
+ with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
+ h = forward_target(clips)
+ z_pred, z_context = forward_context(clips)
+ loss = 0
+ loss_pred = loss_fn(
+ z_pred, h, masks_pred, cls_loss=has_cls_first, d_weights=None
+ )
+ loss += loss_pred
+
+ # Context loss
+ if predict_all:
+ distance_weights = compute_mask_distance(
+ masks_pred, masks_enc, grid_size, offset_context_loss
+ )
+ if weight_distance_loss:
+ d_weights = distance_weights
+ else:
+ d_weights = None
+ loss_context = loss_fn(
+ z_context, h, masks_enc, cls_loss=False, d_weights=d_weights
+ )
+ if lambda_progressive:
+ lambda_value_step = lambda_sched.value(epoch * ipe + itr)
+ else:
+ lambda_value_step = lambda_value
+ loss += loss_context * lambda_value_step
+
+ # Step 2. Backward & step
+ run_step = True
+ if loss_reg_std_mult is not None:
+ meanval = np.mean(trailing_losses)
+ stdval = np.std(trailing_losses)
+ max_bound = meanval + loss_reg_std_mult * stdval
+ if (
+ loss > max_bound
+ and epoch > loss_reg_min_epoch
+ and len(trailing_losses)
+ > int(0.5 * loss_reg_num_tracking_steps)
+ ):
+ run_step = False
+ loss.backward()
+ logger.info(
+ f"Loss {loss} is above bound {meanval} + {loss_reg_std_mult} * {stdval}. Skipping step."
+ )
+
+ if run_step:
+ if mixed_precision:
+ scaler.scale(loss).backward()
+ scaler.unscale_(optimizer)
+ else:
+ loss.backward()
+ if mixed_precision:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Step 3. momentum update of target encoder
+ m = min(next(momentum_scheduler), ema[1])
+ with torch.no_grad():
+ params_k = []
+ params_q = []
+ for param_q, param_k in zip(
+ encoder.parameters(), target_encoder.parameters()
+ ):
+ params_k.append(param_k)
+ params_q.append(param_q)
+ torch._foreach_mul_(params_k, m)
+ torch._foreach_add_(params_k, params_q, alpha=1 - m)
+
+ return (
+ float(loss),
+ _new_lr,
+ _new_wd,
+ run_step,
+ )
+
+ (
+ loss,
+ _new_lr,
+ _new_wd,
+ run_step,
+ ), gpu_etime_ms = gpu_timer(train_step)
+ iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0
+ loss_meter.update(loss)
+ iter_time_meter.update(iter_elapsed_time_ms)
+ gpu_time_meter.update(gpu_etime_ms)
+ data_elapsed_time_meter.update(data_elapsed_time_ms)
+
+ if loss_reg_std_mult is not None:
+ if run_step:
+ trailing_losses.append(loss)
+ if len(trailing_losses) > loss_reg_num_tracking_steps:
+ trailing_losses = trailing_losses[1:]
+ else:
+ step_count += 1
+ if step_count > MAX_REPEAT_COUNTS:
+ raise RuntimeError(
+ "Loss is above bound for too many tries. Exiting."
+ )
+
+ # -- Logging
+ def log_stats():
+ csv_logger.log(
+ epoch + 1,
+ itr,
+ loss,
+ iter_elapsed_time_ms,
+ gpu_etime_ms,
+ data_elapsed_time_ms,
+ )
+ if (
+ (itr % log_freq == 0)
+ or (itr == ipe - 1)
+ or np.isnan(loss)
+ or np.isinf(loss)
+ ):
+ logger.info(
+ "[%d, %5d] loss: %.3f "
+ "masks: %s "
+ "[wd: %.2e] [lr: %.2e] "
+ "[mem: %.2e] "
+ "[iter: %.1f ms] "
+ "[gpu: %.1f ms] "
+ "[data: %.1f ms]"
+ % (
+ epoch + 1,
+ itr,
+ loss_meter.avg,
+ "["
+ + ", ".join(
+ [
+ f"{k}: " + "%.1f" % mask_meters[k].avg
+ for k in mask_meters
+ ]
+ )
+ + "]",
+ _new_wd,
+ _new_lr,
+ torch.cuda.max_memory_allocated() / 1024.0**2,
+ iter_time_meter.avg,
+ gpu_time_meter.avg,
+ data_elapsed_time_meter.avg,
+ )
+ )
+
+ log_stats()
+ assert not np.isnan(loss), "loss is nan"
+
+ # -- Save Checkpoint
+ logger.info("avg. loss %.3f" % loss_meter.avg)
+ if (epoch + 1) % CHECKPOINT_FREQ == 0 or epoch == (num_epochs - 1):
+ save_checkpoint(epoch + 1, latest_path)
+ if save_every_freq > 0 and (epoch + 1) % save_every_freq == 0:
+ save_every_file = f"e{epoch}.pth.tar"
+ save_every_path = os.path.join(folder, save_every_file)
+ save_checkpoint(epoch + 1, save_every_path)
diff --git a/app/vjepa_2_1/transforms.py b/app/vjepa_2_1/transforms.py
new file mode 100644
index 00000000..0fad7ef3
--- /dev/null
+++ b/app/vjepa_2_1/transforms.py
@@ -0,0 +1,139 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+import src.datasets.utils.video.transforms as video_transforms
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from src.datasets.utils.video.randerase import RandomErasing
+
+
+def make_transforms(
+ random_horizontal_flip=True,
+ random_resize_aspect_ratio=(3 / 4, 4 / 3),
+ random_resize_scale=(0.3, 1.0),
+ reprob=0.0,
+ auto_augment=False,
+ motion_shift=False,
+ crop_size=224,
+ normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+):
+ _frames_augmentation = VideoTransform(
+ random_horizontal_flip=random_horizontal_flip,
+ random_resize_aspect_ratio=random_resize_aspect_ratio,
+ random_resize_scale=random_resize_scale,
+ reprob=reprob,
+ auto_augment=auto_augment,
+ motion_shift=motion_shift,
+ crop_size=crop_size,
+ normalize=normalize,
+ )
+ return _frames_augmentation
+
+
+class VideoTransform(object):
+ def __init__(
+ self,
+ random_horizontal_flip=True,
+ random_resize_aspect_ratio=(3 / 4, 4 / 3),
+ random_resize_scale=(0.3, 1.0),
+ reprob=0.0,
+ auto_augment=False,
+ motion_shift=False,
+ crop_size=224,
+ normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ):
+ self.random_horizontal_flip = random_horizontal_flip
+ self.random_resize_aspect_ratio = random_resize_aspect_ratio
+ self.random_resize_scale = random_resize_scale
+ self.auto_augment = auto_augment
+ self.motion_shift = motion_shift
+ self.crop_size = crop_size
+ self.mean = torch.tensor(normalize[0], dtype=torch.float32)
+ self.std = torch.tensor(normalize[1], dtype=torch.float32)
+ if not self.auto_augment:
+ self.mean *= 255.0
+ self.std *= 255.0
+
+ self.autoaug_transform = video_transforms.create_random_augment(
+ input_size=(crop_size, crop_size),
+ auto_augment="rand-m7-n4-mstd0.5-inc1",
+ interpolation="bicubic",
+ )
+
+ self.spatial_transform = (
+ video_transforms.random_resized_crop_with_shift
+ if motion_shift
+ else video_transforms.random_resized_crop
+ )
+
+ self.reprob = reprob
+ self.erase_transform = RandomErasing(
+ reprob,
+ mode="pixel",
+ max_count=1,
+ num_splits=1,
+ device="cpu",
+ )
+
+ def __call__(self, buffer):
+ # Handle PIL Image input (from ImageNet datasets)
+ if isinstance(buffer, Image.Image):
+ # Convert PIL Image to tensor with shape T H W C (T=1 for single image)
+ buffer = np.array(buffer) # H W C
+ buffer = np.expand_dims(buffer, axis=0) # T H W C where T=1
+ buffer = torch.tensor(buffer, dtype=torch.float32)
+
+ if self.auto_augment:
+ buffer = [transforms.ToPILImage()(frame) for frame in buffer]
+ buffer = self.autoaug_transform(buffer)
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+ elif torch.is_tensor(buffer):
+ buffer = buffer.to(torch.float32)
+ else:
+ buffer = torch.tensor(buffer, dtype=torch.float32)
+
+ buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W
+
+ buffer = self.spatial_transform(
+ images=buffer,
+ target_height=self.crop_size,
+ target_width=self.crop_size,
+ scale=self.random_resize_scale,
+ ratio=self.random_resize_aspect_ratio,
+ )
+ if self.random_horizontal_flip:
+ buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
+
+ buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
+ if self.reprob > 0:
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = self.erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+
+def _tensor_normalize_inplace(tensor, mean, std):
+ """
+ Normalize a given tensor by subtracting the mean and dividing the std.
+ Args:
+ tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
+ mean (tensor): mean value to subtract (in 0 to 255 floats).
+ std (tensor): std to divide (in 0 to 255 floats).
+ """
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float()
+
+ C, T, H, W = tensor.shape
+ tensor = tensor.view(C, -1).permute(1, 0)
+ tensor.sub_(mean).div_(std)
+ tensor = tensor.permute(1, 0).view(C, T, H, W)
+ return tensor
diff --git a/app/vjepa_2_1/utils.py b/app/vjepa_2_1/utils.py
new file mode 100644
index 00000000..d5d32fd6
--- /dev/null
+++ b/app/vjepa_2_1/utils.py
@@ -0,0 +1,368 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import sys
+
+import app.vjepa_2_1.models.predictor as vit_pred
+import app.vjepa_2_1.models.vision_transformer as video_vit
+import torch
+import torch.nn.functional as F
+import yaml
+from app.vjepa_2_1.wrappers import MultiSeqWrapper, PredictorMultiSeqWrapper
+from src.utils.checkpoint_loader import robust_checkpoint_loader
+from src.utils.schedulers import (
+ CosineWDSchedule,
+ LinearDecaySchedule,
+ WarmupCosineSchedule,
+)
+
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logger = logging.getLogger()
+
+
+def normalize_and_concat(tensor, embed_dim):
+ """Split tensor into 4 chunks of size embed_dim along the last axis,
+ apply LayerNorm to each chunk, then concatenate back."""
+ chunks = [
+ F.layer_norm(tensor[:, :, i * embed_dim : (i + 1) * embed_dim], (embed_dim,))
+ for i in range(4)
+ ]
+ return torch.cat(chunks, dim=2)
+
+
+def normalize_nested(nested, embed_dim):
+ """Apply normalize_and_concat recursively over nested lists."""
+ return [
+ [[normalize_and_concat(z, embed_dim) for z in inner] for inner in outer]
+ for outer in nested
+ ]
+
+
+def build_eval_args(
+ model_name,
+ patch_size,
+ tubelet_size,
+ num_frames,
+ logging_folder,
+ checkpoint,
+ write_tag,
+ eval_cfg_paths,
+ uniform_power=False,
+ use_sdpa=False,
+ clip_duration=None,
+ use_silu=False,
+ wide_silu=True,
+ tag=None,
+):
+ """
+ Helper function to parse the pre-training configs to construct the
+ evaluation configs, return as a list of eval configs.
+ """
+ import warnings
+
+ if eval_cfg_paths is None:
+ logger.info("No evaluations specified!")
+ return
+
+ eval_nodes = None
+ eval_tasks_per_node = None
+ args_eval = []
+ for i, f in enumerate(eval_cfg_paths):
+ with open(f, "r") as y_file:
+ _args = yaml.load(y_file, Loader=yaml.FullLoader)
+ _tag = _args.get("tag", "")
+ _args["tag"] = f"{tag}-{_tag}"
+ _nodes = _args.get("nodes", None)
+ _tasks = _args.get("tasks_per_node", 8)
+ eval_nodes = _nodes if eval_nodes is None else eval_nodes
+ eval_tasks_per_node = (
+ _tasks if eval_tasks_per_node is None else eval_tasks_per_node
+ )
+ if (eval_nodes != _nodes) or (eval_tasks_per_node != _tasks):
+ warnings.warn(
+ "Configs for online evals must use same number of nodes for slurm-batch processing"
+ )
+
+ _args["pretrain"] = {}
+ _args["pretrain"]["model_name"] = model_name
+ _args["pretrain"]["patch_size"] = patch_size
+ _args["pretrain"]["tubelet_size"] = tubelet_size
+ _args["pretrain"]["uniform_power"] = uniform_power
+ _args["pretrain"]["use_sdpa"] = use_sdpa
+ _args["pretrain"]["clip_duration"] = clip_duration
+ _args["pretrain"]["use_silu"] = use_silu
+ _args["pretrain"]["wide_silu"] = wide_silu
+ _args["pretrain"]["frames_per_clip"] = num_frames
+ _args["pretrain"]["folder"] = logging_folder
+ _args["pretrain"]["checkpoint"] = checkpoint
+ _args["pretrain"]["write_tag"] = write_tag
+
+ args_eval += [_args]
+
+ return eval_nodes, eval_tasks_per_node, args_eval
+
+
+def load_checkpoint(
+ r_path,
+ encoder,
+ predictor,
+ target_encoder,
+ opt,
+ scaler,
+ is_anneal=False,
+):
+ logger.info(f"Loading {r_path}")
+ checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))
+
+ epoch = 0
+ if not is_anneal:
+ epoch = checkpoint["epoch"]
+
+ pretrained_dict = checkpoint["encoder"]
+ for k, v in encoder.state_dict().items():
+ if k not in pretrained_dict:
+ logger.info(f'key "{k}" could not be found in loaded state dict')
+ elif pretrained_dict[k].shape != v.shape:
+ logger.info(
+ f'key "{k}" is of different shape in model and loaded state dict'
+ )
+ pretrained_dict[k] = v
+ msg = encoder.load_state_dict(pretrained_dict, strict=False)
+ logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")
+
+ pretrained_dict = checkpoint["predictor"]
+ for k, v in predictor.state_dict().items():
+ if k not in pretrained_dict:
+ logger.info(f'key "{k}" could not be found in loaded state dict')
+ elif pretrained_dict[k].shape != v.shape:
+ logger.info(
+ f'key "{k}" is of different shape in model and loaded state dict'
+ )
+ pretrained_dict[k] = v
+ msg = predictor.load_state_dict(pretrained_dict, strict=False)
+ logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")
+
+ if target_encoder is not None:
+ pretrained_dict = checkpoint["target_encoder"]
+ for k, v in target_encoder.state_dict().items():
+ if k not in pretrained_dict:
+ logger.info(f'key "{k}" could not be found in loaded state dict')
+ elif pretrained_dict[k].shape != v.shape:
+ logger.info(
+ f'key "{k}" is of different shape in model and loaded state dict'
+ )
+ pretrained_dict[k] = v
+ msg = target_encoder.load_state_dict(pretrained_dict, strict=False)
+ logger.info(
+ f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}"
+ )
+
+ try:
+ opt.load_state_dict(checkpoint["opt"])
+ except ValueError:
+ print("[warn] Optimizer groups mismatch; reinitializing optimizer.")
+ if scaler is not None:
+ scaler.load_state_dict(checkpoint["scaler"])
+ logger.info(f"loaded optimizers from epoch {epoch}")
+ logger.info(f"read-path: {r_path}")
+ del checkpoint
+
+ return (
+ encoder,
+ predictor,
+ target_encoder,
+ opt,
+ scaler,
+ epoch,
+ )
+
+
+def init_video_model(
+ device,
+ patch_size=16,
+ max_num_frames=16,
+ tubelet_size=2,
+ model_name="vit_base",
+ crop_size=224,
+ pred_depth=6,
+ pred_num_heads=None,
+ pred_embed_dim=384,
+ uniform_power=False,
+ use_mask_tokens=False,
+ num_mask_tokens=2,
+ zero_init_mask_tokens=True,
+ use_sdpa=False,
+ use_rope=False,
+ use_silu=False,
+ use_pred_silu=False,
+ wide_silu=False,
+ is_causal=False,
+ pred_is_causal=False,
+ use_activation_checkpointing=False,
+ return_all_tokens=False,
+ chop_last_n_tokens=0,
+ init_type="default",
+ img_temporal_dim_size=None,
+ n_registers=0,
+ n_registers_predictor=0,
+ has_cls_first=False,
+ interpolate_rope=False,
+ modality_embedding=False,
+):
+ encoder = video_vit.__dict__[model_name](
+ img_size=crop_size,
+ patch_size=patch_size,
+ num_frames=max_num_frames,
+ tubelet_size=tubelet_size,
+ uniform_power=uniform_power,
+ use_sdpa=use_sdpa,
+ use_silu=use_silu,
+ wide_silu=wide_silu,
+ use_activation_checkpointing=use_activation_checkpointing,
+ is_causal=is_causal,
+ use_rope=use_rope,
+ init_type=init_type,
+ img_temporal_dim_size=img_temporal_dim_size,
+ n_registers=n_registers,
+ has_cls_first=has_cls_first,
+ interpolate_rope=interpolate_rope,
+ modality_embedding=modality_embedding,
+ )
+ encoder = MultiSeqWrapper(encoder)
+ predictor = vit_pred.__dict__["vit_predictor"](
+ img_size=crop_size,
+ use_mask_tokens=use_mask_tokens,
+ patch_size=patch_size,
+ num_frames=max_num_frames,
+ tubelet_size=tubelet_size,
+ embed_dim=encoder.backbone.embed_dim,
+ predictor_embed_dim=pred_embed_dim,
+ depth=pred_depth,
+ num_heads=(
+ encoder.backbone.num_heads if pred_num_heads is None else pred_num_heads
+ ),
+ uniform_power=uniform_power,
+ num_mask_tokens=num_mask_tokens,
+ zero_init_mask_tokens=zero_init_mask_tokens,
+ use_rope=use_rope,
+ use_sdpa=use_sdpa,
+ is_causal=pred_is_causal,
+ use_silu=use_pred_silu,
+ wide_silu=wide_silu,
+ use_activation_checkpointing=use_activation_checkpointing,
+ return_all_tokens=return_all_tokens,
+ chop_last_n_tokens=chop_last_n_tokens,
+ n_registers=n_registers_predictor,
+ has_cls_first=has_cls_first,
+ interpolate_rope=interpolate_rope,
+ modality_embedding=modality_embedding,
+ img_temporal_dim_size=img_temporal_dim_size,
+ )
+ predictor = PredictorMultiSeqWrapper(predictor)
+
+ encoder.to(device)
+ predictor.to(device)
+ logger.info(encoder)
+ logger.info(predictor)
+
+ def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.info(f"Encoder number of parameters: {count_parameters(encoder)}")
+ logger.info(f"Predictor number of parameters: {count_parameters(predictor)}")
+
+ return encoder, predictor
+
+
+def init_opt(
+ is_anneal,
+ encoder,
+ predictor,
+ iterations_per_epoch,
+ start_lr,
+ ref_lr,
+ warmup,
+ num_epochs,
+ use_radamw=False,
+ wd=1e-6,
+ final_wd=1e-6,
+ final_lr=0.0,
+ mixed_precision=False,
+ ipe_scale=1.25,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ zero_init_bias_wd=True,
+):
+ param_groups = [
+ {
+ "params": (
+ p
+ for n, p in encoder.named_parameters()
+ if ("bias" not in n) and (len(p.shape) != 1)
+ )
+ },
+ {
+ "params": (
+ p
+ for n, p in predictor.named_parameters()
+ if ("bias" not in n) and (len(p.shape) != 1)
+ )
+ },
+ {
+ "params": (
+ p
+ for n, p in encoder.named_parameters()
+ if ("bias" in n) or (len(p.shape) == 1)
+ ),
+ "WD_exclude": zero_init_bias_wd,
+ "weight_decay": 0,
+ },
+ {
+ "params": (
+ p
+ for n, p in predictor.named_parameters()
+ if ("bias" in n) or (len(p.shape) == 1)
+ ),
+ "WD_exclude": zero_init_bias_wd,
+ "weight_decay": 0,
+ },
+ ]
+
+ if use_radamw:
+ from src.utils.adamw import AdamW as RAdamW
+
+ logger.info("Using Rescaled-AdamW")
+ optimizer = RAdamW(param_groups, betas=betas, eps=eps)
+ else:
+ logger.info("Using AdamW")
+ optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
+
+ if not is_anneal:
+ scheduler = WarmupCosineSchedule(
+ optimizer,
+ warmup_steps=int(warmup * iterations_per_epoch),
+ start_lr=start_lr,
+ ref_lr=ref_lr,
+ final_lr=final_lr,
+ T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+ )
+ else:
+ scheduler = LinearDecaySchedule(
+ optimizer,
+ ref_lr=ref_lr,
+ final_lr=final_lr,
+ T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+ )
+ wd_scheduler = CosineWDSchedule(
+ optimizer,
+ ref_wd=wd,
+ final_wd=final_wd,
+ T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
+ )
+
+ scaler = torch.cuda.amp.GradScaler() if mixed_precision else None
+ return optimizer, scaler, scheduler, wd_scheduler
diff --git a/app/vjepa_2_1/wrappers.py b/app/vjepa_2_1/wrappers.py
new file mode 100644
index 00000000..d31dfb80
--- /dev/null
+++ b/app/vjepa_2_1/wrappers.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MultiSeqWrapper(nn.Module):
+
+ def __init__(self, backbone):
+ super().__init__()
+ self.backbone = backbone
+ self.embed_dim = backbone.embed_dim
+
+ def forward(self, x, masks=None, gram_mode=False, training_mode=False):
+ """
+ :param x: [list] List of Tensors of different seq lengths
+ :param masks: [list] List of Tensors (index: masks for given seq length)
+ """
+ if masks is None:
+ outputs = []
+ for x_fpc in x:
+ if gram_mode:
+ # First we make the image bigger
+ B, C, T, H, W = x_fpc.shape
+ x_2d = x_fpc.permute(0, 2, 1, 3, 4).reshape(
+ B * T, C, H, W
+ ) # (B*T, C, H, W)
+ x_up = F.interpolate(
+ x_2d, scale_factor=2, mode="bicubic", align_corners=False
+ )
+ _, _, H_up, W_up = x_up.shape
+ x_up = x_up.view(B, T, C, H_up, W_up).permute(
+ 0, 2, 1, 3, 4
+ ) # (B,C,T,H_up,W_up)
+
+ # Then we make the pass through the backbone
+ out = self.backbone(x_up)
+ B, N, D = out.shape
+ H_up_patches = W_up_patches = int(
+ H_up // 16
+ ) # We have hardcoded this to the patch size
+ if T == 1:
+ T_up_patches = 1 # In this case, it is a LVD image
+ else:
+ T_up_patches = int(
+ T // 2
+ ) # We have hardcoded this to the tubelet size
+ out_3d = out.view(
+ B, T_up_patches, H_up_patches, W_up_patches, D
+ ) # (bs, T, H, W, D)
+ out_3d = out_3d.permute(0, 4, 1, 2, 3) # (bs, D, T, H, W)
+
+ # Downscale to original 2D size
+ out_2d = out_3d.permute(0, 2, 1, 3, 4).reshape(
+ B * T_up_patches, D, H_up_patches, W_up_patches
+ ) # (B*T, C, H_up, W_up)
+ out = F.interpolate(
+ out_2d,
+ size=(int(H_up_patches // 2), int(W_up_patches // 2)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ out = out.view(
+ B,
+ T_up_patches,
+ D,
+ int(H_up_patches // 2),
+ int(W_up_patches // 2),
+ ).permute(
+ 0, 2, 1, 3, 4
+ ) # (B,C,T,H,W)
+ out = out.permute(0, 2, 3, 4, 1).reshape(
+ B,
+ T_up_patches * int(H_up_patches // 2) * int(W_up_patches // 2),
+ D,
+ ) # (B,C,T,H,W) -> (B, T, H, W, C) -> (B, T * H * W, C)
+ outputs.append(out)
+ else:
+ outputs.append(self.backbone(x_fpc, training=training_mode))
+ return outputs
+
+ outs = [[] for _ in x]
+ for i, (x_fpc, m_fpc) in enumerate(zip(x, masks)):
+ for m in m_fpc:
+ outs[i] += [self.backbone(x_fpc, masks=m, training=training_mode)]
+ return outs
+
+
+class PredictorMultiSeqWrapper(nn.Module):
+
+ def __init__(self, backbone):
+ super().__init__()
+ self.backbone = backbone
+
+ def forward(self, x, masks_x, masks_y, mod="video"):
+ """
+ :param x: [list] List of encoder outputs for different seq lengths
+ :param masks_x: [list] List of encoder masks
+ :param masks_y: [list] List of predictor masks
+ """
+ n = 0
+ outs_pred = [[] for _ in x]
+ outs_context = [[] for _ in x]
+ for i, (x_fpc, mx_fpc, my_fpc) in enumerate(zip(x, masks_x, masks_y)):
+ for xij, mx, my in zip(x_fpc, mx_fpc, my_fpc):
+ x_pred, x_context = self.backbone(xij, mx, my, mask_index=i, mod=mod)
+ outs_pred[i] += [x_pred]
+ outs_context[i] += [x_context]
+ n += 1
+ return outs_pred, outs_context
diff --git a/assets/architecture_vjepa2_1.jpg b/assets/architecture_vjepa2_1.jpg
new file mode 100644
index 00000000..3e75ac7c
Binary files /dev/null and b/assets/architecture_vjepa2_1.jpg differ
diff --git a/assets/bars_teaser_tikz-1.png b/assets/bars_teaser_tikz-1.png
new file mode 100644
index 00000000..682c74ac
Binary files /dev/null and b/assets/bars_teaser_tikz-1.png differ
diff --git a/assets/teaser_screenshot_5dice.png b/assets/teaser_screenshot_5dice.png
new file mode 100644
index 00000000..25bf2556
Binary files /dev/null and b/assets/teaser_screenshot_5dice.png differ
diff --git a/configs/eval/vitg-384/in1k.yaml b/configs/eval/vitg-384/in1k.yaml
index 1b5c70ec..99513dae 100644
--- a/configs/eval/vitg-384/in1k.yaml
+++ b/configs/eval/vitg-384/in1k.yaml
@@ -13,8 +13,7 @@ experiment:
data:
dataset_name: ImageNet
num_classes: 1000
- root_path: /datasets/
- image_folder: ImageNet_FullSize/240712/061417/
+ root_path: /your_data_path/ImageNet_FullSize/240712/061417/
resolution: 384
optimization:
batch_size: 8
diff --git a/configs/eval/vitl/in1k.yaml b/configs/eval/vitl/in1k.yaml
index 9500666e..961a3913 100644
--- a/configs/eval/vitl/in1k.yaml
+++ b/configs/eval/vitl/in1k.yaml
@@ -13,8 +13,7 @@ experiment:
data:
dataset_name: ImageNet
num_classes: 1000
- root_path: /datasets/
- image_folder: ImageNet_FullSize/240712/061417/
+ root_path: /your_data_path/ImageNet_FullSize/240712/061417/
resolution: 256
optimization:
batch_size: 16
diff --git a/configs/eval_2_1/vitG-384/coin.yaml b/configs/eval_2_1/vitG-384/coin.yaml
new file mode 100644
index 00000000..31195de4
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/coin.yaml
@@ -0,0 +1,163 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitG-384/coin
+mem_per_gpu: 220G
+nodes: 16
+resume_checkpoint: true
+tag: coin-vitG16-384-16x8x3
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_folder/COIN/train_paths.csv
+ dataset_val: /your_data_folder/COIN/val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 180
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 1
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_gigantic_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitG-384/diving48.yaml b/configs/eval_2_1/vitG-384/diving48.yaml
new file mode 100644
index 00000000..21d62862
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/diving48.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: diving48-vitG16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitG-384/diving48
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
+ dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
+ num_classes: 48
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [11, 23, 37, 47]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_gigantic_xformers
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitG-384/ek100.yaml b/configs/eval_2_1/vitG-384/ek100.yaml
new file mode 100644
index 00000000..f51c8a5d
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/ek100.yaml
@@ -0,0 +1,199 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+tag: ek100-vitG16-384
+eval_name: action_anticipation_frozen
+folder: /your_folder/evals_2_1/vitG-384/ek100
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ anticipation_time_sec:
+ - 1.0
+ - 1.0
+ auto_augment: true
+ file_format: 0
+ dataset: EK100
+ # e.g. /home/username/EPIC-KITCHENS
+ base_path: /your_ek100_root_dir/
+ dataset_train: /your_data/EPIC_100_train.csv
+ dataset_val: /your_data/EPIC_100_validation.csv
+ frames_per_clip: 32
+ frames_per_second: 8
+ motion_shift: false
+ num_workers: 2
+ pin_memory: true
+ random_resize_scale:
+ - 0.08
+ - 1.0
+ reprob: 0.25
+ resolution: 384
+ train_anticipation_point:
+ - 0.0
+ - 0.25
+ train_anticipation_time_sec:
+ - 0.25
+ - 1.75
+ optimization:
+ num_epochs: 20
+ batch_size: 2
+ use_bfloat16: true
+ use_focal_loss: true
+ multihead_kwargs:
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
+ wrapper_kwargs:
+ no_predictor: false
+ num_output_frames: 2
+ num_steps: 1
+ pretrain_kwargs:
+ use_v2_1: true
+ encoder:
+ model_name: vit_gigantic_xformers
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
+ predictor:
+ model_name: vit_predictor
+ checkpoint_key: predictor
+ num_frames: 64
+ depth: 24
+ num_heads: 12
+ predictor_embed_dim: 384
+ num_mask_tokens: 8
+ img_temporal_dim_size: 1
+ uniform_power: true
+ use_mask_tokens: true
+ use_sdpa: true
+ use_silu: false
+ wide_silu: false
+ use_rope: true
diff --git a/configs/eval_2_1/vitG-384/in1k.yaml b/configs/eval_2_1/vitG-384/in1k.yaml
new file mode 100644
index 00000000..43c69869
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/in1k.yaml
@@ -0,0 +1,162 @@
+cpus_per_task: 16
+eval_name: image_classification_frozen
+folder: /your_folder/evals_2_1/vitG-384/in1k
+mem_per_gpu: 220G
+nodes: 16
+resume_checkpoint: true
+tag: in1k-vitG16-384-18f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_name: ImageNet
+ num_classes: 1000
+ root_path: /your_data_path/ImageNet_FullSize/240712/061417/
+ resolution: 384
+ optimization:
+ batch_size: 8
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.0
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ num_epochs: 20
+ use_bfloat16: true
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.image_classification_frozen.modelcustom.vit_encoder
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_gigantic_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ img_as_video_nframes: 18
diff --git a/configs/eval_2_1/vitG-384/jester.yaml b/configs/eval_2_1/vitG-384/jester.yaml
new file mode 100644
index 00000000..fd38a394
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/jester.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: jester-vitG16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitG-384/jester
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv
+ dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv
+ num_classes: 27
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [11, 23, 37, 47]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_gigantic_xformers
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitG-384/k400.yaml b/configs/eval_2_1/vitG-384/k400.yaml
new file mode 100644
index 00000000..2e256315
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/k400.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitG-384/k400
+mem_per_gpu: 220G
+nodes: 32
+num_workers: 8
+resume_checkpoint: true
+tag: k400-vitG16-384-16x8x3-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/k400_train_paths.csv
+ dataset_val: /your_data_path/k400_val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 400
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 1
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_gigantic_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitG-384/ssv2.yaml b/configs/eval_2_1/vitG-384/ssv2.yaml
new file mode 100644
index 00000000..206dfdd5
--- /dev/null
+++ b/configs/eval_2_1/vitG-384/ssv2.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitG-384/ssv2
+mem_per_gpu: 220G
+nodes: 16
+num_workers: 8
+resume_checkpoint: true
+tag: ssv2-vitG16-384-64x2x3
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/ssv2_train_paths.csv
+ dataset_val: /your_data_path/ssv2_val_paths.csv
+ frame_step: 2
+ frames_per_clip: 64
+ num_classes: 174
+ num_segments: 2
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 1
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_gigantic_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitb-384/coin.yaml b/configs/eval_2_1/vitb-384/coin.yaml
new file mode 100644
index 00000000..e30343d2
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/coin.yaml
@@ -0,0 +1,163 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitb-384/coin
+mem_per_gpu: 220G
+nodes: 8
+resume_checkpoint: true
+tag: coin-vitb16-384-16x8x3
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_folder/COIN/train_paths.csv
+ dataset_val: /your_data_folder/COIN/val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 180
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 2
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_base
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitb-384/diving48.yaml b/configs/eval_2_1/vitb-384/diving48.yaml
new file mode 100644
index 00000000..4d0f9ace
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/diving48.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: diving48-vitb16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitb-384/diving48
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
+ dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
+ num_classes: 48
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [2, 5, 8, 11]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_base
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitb-384/ek100.yaml b/configs/eval_2_1/vitb-384/ek100.yaml
new file mode 100644
index 00000000..6c445e19
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/ek100.yaml
@@ -0,0 +1,202 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+tag: ek100-vitb16-384
+eval_name: action_anticipation_frozen
+folder: /your_folder/evals_2_1/vitb-384/ek100
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ anticipation_time_sec:
+ - 1.0
+ - 1.0
+ auto_augment: true
+ file_format: 0
+ dataset: EK100
+ # e.g. /home/username/EPIC-KITCHENS
+ base_path: /your_ek100_root_dir/
+ dataset_train: /your_data/EPIC_100_train.csv
+ dataset_val: /your_data/EPIC_100_validation.csv
+ frames_per_clip: 32
+ frames_per_second: 8
+ motion_shift: false
+ num_workers: 2
+ pin_memory: true
+ random_resize_scale:
+ - 0.08
+ - 1.0
+ reprob: 0.25
+ resolution: 384
+ train_anticipation_point:
+ - 0.0
+ - 0.25
+ train_anticipation_time_sec:
+ - 0.25
+ - 1.75
+ optimization:
+ num_epochs: 20
+ batch_size: 2
+ use_bfloat16: true
+ use_focal_loss: true
+ multihead_kwargs:
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
+ wrapper_kwargs:
+ no_predictor: false
+ num_output_frames: 2
+ num_steps: 1
+ pretrain_kwargs:
+ use_v2_1: true
+ encoder:
+ model_name: vit_base
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
+ predictor:
+ model_name: vit_predictor
+ checkpoint_key: predictor
+ num_frames: 64
+ depth: 12
+ num_heads: 12
+ predictor_embed_dim: 384
+ teacher_embed_dim: 1664
+ num_mask_tokens: 8
+ n_output_distillation: 1
+ return_all_tokens: true
+ img_temporal_dim_size: 1
+ uniform_power: true
+ use_mask_tokens: true
+ use_sdpa: true
+ use_silu: false
+ wide_silu: false
+ use_rope: true
diff --git a/configs/eval_2_1/vitb-384/in1k.yaml b/configs/eval_2_1/vitb-384/in1k.yaml
new file mode 100644
index 00000000..049e332f
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/in1k.yaml
@@ -0,0 +1,162 @@
+cpus_per_task: 16
+eval_name: image_classification_frozen
+folder: /your_folder/evals_2_1/vitb-384/in1k
+mem_per_gpu: 220G
+nodes: 8
+resume_checkpoint: true
+tag: in1k-vitb16-384-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_name: ImageNet
+ num_classes: 1000
+ root_path: /your_data_path/ImageNet_FullSize/240712/061417/
+ resolution: 384
+ optimization:
+ batch_size: 16
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.0
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ num_epochs: 20
+ use_bfloat16: true
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.image_classification_frozen.modelcustom.vit_encoder
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_base
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ img_as_video_nframes: 16
diff --git a/configs/eval_2_1/vitb-384/jester.yaml b/configs/eval_2_1/vitb-384/jester.yaml
new file mode 100644
index 00000000..5958a266
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/jester.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: jester-vitb16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitb-384/jester
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv
+ dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv
+ num_classes: 27
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [2, 5, 8, 11]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_base
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitb-384/k400.yaml b/configs/eval_2_1/vitb-384/k400.yaml
new file mode 100644
index 00000000..da3f5128
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/k400.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitb-384/k400
+mem_per_gpu: 220G
+nodes: 8
+num_workers: 8
+resume_checkpoint: true
+tag: k400-vitb16-384-16x8x3-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/k400_train_paths.csv
+ dataset_val: /your_data_path/k400_val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 400
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 4
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_base
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitb-384/ssv2.yaml b/configs/eval_2_1/vitb-384/ssv2.yaml
new file mode 100644
index 00000000..66aebaae
--- /dev/null
+++ b/configs/eval_2_1/vitb-384/ssv2.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitb-384/ssv2
+mem_per_gpu: 220G
+nodes: 8
+max_workers: 8
+resume_checkpoint: true
+tag: ssv2-vitb16-384-16x2x3-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/ssv2_train_paths.csv
+ dataset_val: /your_data_path/ssv2_val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 174
+ num_segments: 2
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 4
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_base
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitg-384/coin.yaml b/configs/eval_2_1/vitg-384/coin.yaml
new file mode 100644
index 00000000..fdd89dec
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/coin.yaml
@@ -0,0 +1,163 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitg-384/coin
+mem_per_gpu: 220G
+nodes: 16
+resume_checkpoint: true
+tag: coin-vitg16-384-16x8x3
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_folder/COIN/train_paths.csv
+ dataset_val: /your_data_folder/COIN/val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 180
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 1
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_giant_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitg-384/diving48.yaml b/configs/eval_2_1/vitg-384/diving48.yaml
new file mode 100644
index 00000000..ac68b44a
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/diving48.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: diving48-vitg16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitg-384/diving48
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
+ dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
+ num_classes: 48
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [9, 19, 29, 39]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_giant_xformers
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitg-384/ek100.yaml b/configs/eval_2_1/vitg-384/ek100.yaml
new file mode 100644
index 00000000..faa31913
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/ek100.yaml
@@ -0,0 +1,199 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+tag: ek100-vitg16-384
+eval_name: action_anticipation_frozen
+folder: /your_folder/evals_2_1/vitg-384/ek100
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ anticipation_time_sec:
+ - 1.0
+ - 1.0
+ auto_augment: true
+ file_format: 0
+ dataset: EK100
+ # e.g. /home/username/EPIC-KITCHENS
+ base_path: /your_ek100_root_dir/
+ dataset_train: /your_data/EPIC_100_train.csv
+ dataset_val: /your_data/EPIC_100_validation.csv
+ frames_per_clip: 32
+ frames_per_second: 8
+ motion_shift: false
+ num_workers: 2
+ pin_memory: true
+ random_resize_scale:
+ - 0.08
+ - 1.0
+ reprob: 0.25
+ resolution: 384
+ train_anticipation_point:
+ - 0.0
+ - 0.25
+ train_anticipation_time_sec:
+ - 0.25
+ - 1.75
+ optimization:
+ num_epochs: 20
+ batch_size: 2
+ use_bfloat16: true
+ use_focal_loss: true
+ multihead_kwargs:
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
+ wrapper_kwargs:
+ no_predictor: false
+ num_output_frames: 2
+ num_steps: 1
+ pretrain_kwargs:
+ use_v2_1: true
+ encoder:
+ model_name: vit_giant_xformers
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
+ predictor:
+ model_name: vit_predictor
+ checkpoint_key: predictor
+ num_frames: 64
+ depth: 24
+ num_heads: 12
+ predictor_embed_dim: 384
+ num_mask_tokens: 8
+ img_temporal_dim_size: 1
+ uniform_power: true
+ use_mask_tokens: true
+ use_sdpa: true
+ use_silu: false
+ wide_silu: false
+ use_rope: true
diff --git a/configs/eval_2_1/vitg-384/in1k.yaml b/configs/eval_2_1/vitg-384/in1k.yaml
new file mode 100644
index 00000000..73e4321d
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/in1k.yaml
@@ -0,0 +1,162 @@
+cpus_per_task: 16
+eval_name: image_classification_frozen
+folder: /your_folder/evals_2_1/vitg-384/in1k
+mem_per_gpu: 220G
+nodes: 16
+resume_checkpoint: true
+tag: in1k-vitg16-384-18f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_name: ImageNet
+ num_classes: 1000
+ root_path: /your_data_path/ImageNet_FullSize/240712/061417/
+ resolution: 384
+ optimization:
+ batch_size: 8
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.0
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ num_epochs: 20
+ use_bfloat16: true
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.image_classification_frozen.modelcustom.vit_encoder
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_giant_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ img_as_video_nframes: 18
diff --git a/configs/eval_2_1/vitg-384/jester.yaml b/configs/eval_2_1/vitg-384/jester.yaml
new file mode 100644
index 00000000..78cb2aac
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/jester.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: jester-vitg16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitg-384/jester
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv
+ dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv
+ num_classes: 27
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [9, 19, 29, 39]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_giant_xformers
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitg-384/k400.yaml b/configs/eval_2_1/vitg-384/k400.yaml
new file mode 100644
index 00000000..3cc6677f
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/k400.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitg-384/k400
+mem_per_gpu: 220G
+nodes: 32
+num_workers: 8
+resume_checkpoint: true
+tag: k400-vitg16-384-16x8x3-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/k400_train_paths.csv
+ dataset_val: /your_data_path/k400_val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 400
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 1
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_giant_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitg-384/ssv2.yaml b/configs/eval_2_1/vitg-384/ssv2.yaml
new file mode 100644
index 00000000..97c85f24
--- /dev/null
+++ b/configs/eval_2_1/vitg-384/ssv2.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitg-384/ssv2
+mem_per_gpu: 220G
+nodes: 16
+num_workers: 8
+resume_checkpoint: true
+tag: ssv2-vitg16-384-64x2x3
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/ssv2_train_paths.csv
+ dataset_val: /your_data_path/ssv2_val_paths.csv
+ frame_step: 2
+ frames_per_clip: 64
+ num_classes: 174
+ num_segments: 2
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 1
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: target_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_giant_xformers
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitl-384/coin.yaml b/configs/eval_2_1/vitl-384/coin.yaml
new file mode 100644
index 00000000..d4e68613
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/coin.yaml
@@ -0,0 +1,163 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitl-384/coin
+mem_per_gpu: 220G
+nodes: 8
+resume_checkpoint: true
+tag: coin-vitl16-384-16x8x3
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_folder/COIN/train_paths.csv
+ dataset_val: /your_data_folder/COIN/val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 180
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 2
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_large
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitl-384/diving48.yaml b/configs/eval_2_1/vitl-384/diving48.yaml
new file mode 100644
index 00000000..144846bc
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/diving48.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: diving48-vitl16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitl-384/diving48
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
+ dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
+ num_classes: 48
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [5, 11, 17, 23]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_large
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitl-384/ek100.yaml b/configs/eval_2_1/vitl-384/ek100.yaml
new file mode 100644
index 00000000..c57e4126
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/ek100.yaml
@@ -0,0 +1,202 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+tag: ek100-vitl16-384
+eval_name: action_anticipation_frozen
+folder: /your_folder/evals_2_1/vitl-384/ek100
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ anticipation_time_sec:
+ - 1.0
+ - 1.0
+ auto_augment: true
+ file_format: 0
+ dataset: EK100
+ # e.g. /home/username/EPIC-KITCHENS
+ base_path: /your_ek100_root_dir/
+ dataset_train: /your_data/EPIC_100_train.csv
+ dataset_val: /your_data/EPIC_100_validation.csv
+ frames_per_clip: 32
+ frames_per_second: 8
+ motion_shift: false
+ num_workers: 2
+ pin_memory: true
+ random_resize_scale:
+ - 0.08
+ - 1.0
+ reprob: 0.25
+ resolution: 384
+ train_anticipation_point:
+ - 0.0
+ - 0.25
+ train_anticipation_time_sec:
+ - 0.25
+ - 1.75
+ optimization:
+ num_epochs: 20
+ batch_size: 2
+ use_bfloat16: true
+ use_focal_loss: true
+ multihead_kwargs:
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.0001
+ final_weight_decay: 0.0001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.001
+ final_weight_decay: 0.001
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.01
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.1
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
+ wrapper_kwargs:
+ no_predictor: false
+ num_output_frames: 2
+ num_steps: 1
+ pretrain_kwargs:
+ use_v2_1: true
+ encoder:
+ model_name: vit_large
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
+ predictor:
+ model_name: vit_predictor
+ checkpoint_key: predictor
+ num_frames: 64
+ depth: 12
+ num_heads: 12
+ predictor_embed_dim: 384
+ teacher_embed_dim: 1664
+ num_mask_tokens: 8
+ n_output_distillation: 1
+ return_all_tokens: true
+ img_temporal_dim_size: 1
+ uniform_power: true
+ use_mask_tokens: true
+ use_sdpa: true
+ use_silu: false
+ wide_silu: false
+ use_rope: true
diff --git a/configs/eval_2_1/vitl-384/in1k.yaml b/configs/eval_2_1/vitl-384/in1k.yaml
new file mode 100644
index 00000000..408d466a
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/in1k.yaml
@@ -0,0 +1,162 @@
+cpus_per_task: 16
+eval_name: image_classification_frozen
+folder: /your_folder/evals_2_1/vitl-384/in1k
+mem_per_gpu: 220G
+nodes: 8
+resume_checkpoint: true
+tag: in1k-vitl16-384-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_name: ImageNet
+ num_classes: 1000
+ root_path: /your_data_path/ImageNet_FullSize/240712/061417/
+ resolution: 384
+ optimization:
+ batch_size: 16
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.0
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0005
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.001
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.0015
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ - final_lr: 0.0
+ final_weight_decay: 0.008
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.008
+ - final_lr: 0.0
+ final_weight_decay: 0.004
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.004
+ - final_lr: 0.0
+ final_weight_decay: 0.002
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.002
+ - final_lr: 0.0
+ final_weight_decay: 0.001
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.001
+ - final_lr: 0.0
+ final_weight_decay: 0.0005
+ lr: 0.002
+ start_lr: 0.0002
+ warmup: 5
+ weight_decay: 0.0005
+ num_epochs: 20
+ use_bfloat16: true
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.image_classification_frozen.modelcustom.vit_encoder
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_large
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ img_as_video_nframes: 16
diff --git a/configs/eval_2_1/vitl-384/jester.yaml b/configs/eval_2_1/vitl-384/jester.yaml
new file mode 100644
index 00000000..e00ca607
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/jester.yaml
@@ -0,0 +1,62 @@
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 12
+mem_per_gpu: 200G
+tag: jester-vitl16-384-32x4x3
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitl-384/jester
+resume_checkpoint: true
+experiment:
+ classifier:
+ num_probe_blocks: 4
+ num_heads: 16
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv
+ dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv
+ num_classes: 27
+ resolution: 384
+ frames_per_clip: 32
+ frame_step: 2
+ num_segments: 4
+ num_views_per_segment: 3
+ optimization:
+ use_pos_embed: false
+ num_epochs: 100
+ batch_size: 2
+ use_bfloat16: true
+ multihead_kwargs:
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ final_lr: 0.0
+ warmup: 0.
+ - weight_decay: 0.8
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ final_lr: 0.0
+ warmup: 0.
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
+ out_layers: [5, 11, 17, 23]
+ pretrain_kwargs:
+ encoder:
+ model_name: vit_large
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ tubelet_size: 2
+ patch_size: 16
+ uniform_power: true
+ use_rope: true
diff --git a/configs/eval_2_1/vitl-384/k400.yaml b/configs/eval_2_1/vitl-384/k400.yaml
new file mode 100644
index 00000000..5cd8bd67
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/k400.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitl-384/k400
+mem_per_gpu: 220G
+nodes: 8
+num_workers: 8
+resume_checkpoint: true
+tag: k400-vitl16-384-16x8x3-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/k400_train_paths.csv
+ dataset_val: /your_data_path/k400_val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 400
+ num_segments: 8
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 4
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_large
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/eval_2_1/vitl-384/ssv2.yaml b/configs/eval_2_1/vitl-384/ssv2.yaml
new file mode 100644
index 00000000..38056e1d
--- /dev/null
+++ b/configs/eval_2_1/vitl-384/ssv2.yaml
@@ -0,0 +1,164 @@
+cpus_per_task: 16
+eval_name: video_classification_frozen
+folder: /your_folder/evals_2_1/vitl-384/ssv2
+mem_per_gpu: 220G
+nodes: 8
+max_workers: 8
+resume_checkpoint: true
+tag: ssv2-vitl16-384-16x2x3-16f
+tasks_per_node: 8
+experiment:
+ classifier:
+ num_heads: 16
+ num_probe_blocks: 4
+ data:
+ dataset_type: VideoDataset
+ dataset_train: /your_data_path/ssv2_train_paths.csv
+ dataset_val: /your_data_path/ssv2_val_paths.csv
+ frame_step: 4
+ frames_per_clip: 16
+ num_classes: 174
+ num_segments: 2
+ num_views_per_segment: 3
+ resolution: 384
+ optimization:
+ batch_size: 4
+ multihead_kwargs:
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.01
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.01
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.1
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.1
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.4
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.4
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.005
+ start_lr: 0.005
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.003
+ start_lr: 0.003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.001
+ start_lr: 0.001
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0003
+ start_lr: 0.0003
+ warmup: 0.0
+ weight_decay: 0.8
+ - final_lr: 0.0
+ final_weight_decay: 0.8
+ lr: 0.0001
+ start_lr: 0.0001
+ warmup: 0.0
+ weight_decay: 0.8
+ num_epochs: 20
+ use_bfloat16: true
+ use_pos_embed: false
+model_kwargs:
+ checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt
+ module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
+ pretrain_kwargs:
+ encoder:
+ checkpoint_key: ema_encoder
+ img_temporal_dim_size: 1
+ model_name: vit_large
+ patch_size: 16
+ tubelet_size: 2
+ uniform_power: true
+ use_rope: true
+ wrapper_kwargs:
+ max_frames: 128
+ use_pos_embed: false
diff --git a/configs/train_2_1/vitG16/cooldown-256px-64f.yaml b/configs/train_2_1/vitG16/cooldown-256px-64f.yaml
new file mode 100644
index 00000000..eca1e592
--- /dev/null
+++ b/configs/train_2_1/vitG16/cooldown-256px-64f.yaml
@@ -0,0 +1,144 @@
+app: vjepa_2_1
+nodes: 64
+tasks_per_node: 8
+cpus_per_task: 64
+mem_per_gpu: 220G
+folder: /your_folder/cooldown_2_1/16.8.vitG.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 6
+ crop_size: 256
+ dataset_fpcs:
+ - 64
+ - 64
+ - 64
+ fps: 4
+ num_workers: 12
+ patch_size: 16
+ persistent_workers: true
+ pin_mem: false
+ tubelet_size: 2
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+
+img_data:
+ batch_size: 36
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+loss:
+ computing_gram_loss: false
+ gram_HighRes: false
+ gram_ckpt: None
+ gram_loss_weight: 10.0
+ loss_exp: 1.0
+ predict_all: true
+ shift_by_n: 0
+ weight_distance_loss: false
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ lambda_progressive: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ modality_embedding: true
+ model_name: vit_gigantic_xformers
+ n_registers: 0
+ normalize_predictor: false
+ pred_depth: 24
+ pred_embed_dim: 384
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ zero_init_mask_tokens: true
+optimization:
+ anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitG.256px.16f/latest.pth.tar
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 40
+ final_lr: 1.0e-06
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ is_anneal: true
+ lr: 0.0006
+ resume_anneal: true
+ start_lr: 0.0001
+ warmup: 0
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitG16/pretrain-256px-16f.yaml b/configs/train_2_1/vitG16/pretrain-256px-16f.yaml
new file mode 100644
index 00000000..288c045f
--- /dev/null
+++ b/configs/train_2_1/vitG16/pretrain-256px-16f.yaml
@@ -0,0 +1,149 @@
+app: vjepa_2_1
+nodes: 32
+tasks_per_node: 8
+cpus_per_task: 32
+mem_per_gpu: 220G
+folder: /your_folder/pretrain_2_1/16.8.vitG.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 12
+ crop_size: 256
+ patch_size: 16
+ dataset_fpcs:
+ - 16
+ - 16
+ - 16
+ tubelet_size: 2
+ fps: 4
+ num_workers: 8
+ persistent_workers: true
+ pin_mem: true
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+loss:
+ loss_exp: 1.0
+ predict_all: true
+ reg_coeff: 0.0
+ shift_by_n: 0
+ weight_distance_loss: true
+
+img_data:
+ batch_size: 36
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ is_causal: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ local_window:
+ - -1
+ - -1
+ - -1
+ modality_embedding: true
+ model_name: vit_gigantic_xformers
+ n_registers: 0
+ n_registers_predictor: 0
+ normalize_predictor: false
+ pred_depth: 24
+ pred_embed_dim: 384
+ pred_is_causal: false
+ pred_local_window:
+ - -1
+ - -1
+ - -1
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ vit_conv: false
+ zero_init_mask_tokens: true
+optimization:
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 1000
+ final_lr: 0.0006
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ lr: 0.0006
+ start_lr: 0.0001
+ warmup: 40
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitb16/cooldown-256px-64f.yaml b/configs/train_2_1/vitb16/cooldown-256px-64f.yaml
new file mode 100644
index 00000000..343ae9ba
--- /dev/null
+++ b/configs/train_2_1/vitb16/cooldown-256px-64f.yaml
@@ -0,0 +1,144 @@
+app: vjepa_2_1
+nodes: 16
+tasks_per_node: 8
+cpus_per_task: 16
+mem_per_gpu: 220G
+folder: /your_folder/cooldown_2_1/16.8.vitb.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 24
+ crop_size: 256
+ dataset_fpcs:
+ - 64
+ - 64
+ - 64
+ fps: 4
+ num_workers: 12
+ patch_size: 16
+ persistent_workers: true
+ pin_mem: false
+ tubelet_size: 2
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+
+img_data:
+ batch_size: 144
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+loss:
+ computing_gram_loss: false
+ gram_HighRes: false
+ gram_ckpt: None
+ gram_loss_weight: 10.0
+ loss_exp: 1.0
+ predict_all: true
+ shift_by_n: 0
+ weight_distance_loss: false
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ lambda_progressive: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ modality_embedding: true
+ model_name: vit_base
+ n_registers: 0
+ normalize_predictor: false
+ pred_depth: 12
+ pred_embed_dim: 384
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ zero_init_mask_tokens: true
+optimization:
+ anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitb.256px.16f/latest.pth.tar
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 40
+ final_lr: 1.0e-06
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ is_anneal: true
+ lr: 0.0006
+ resume_anneal: true
+ start_lr: 0.0001
+ warmup: 0
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitb16/pretrain-256px-16f.yaml b/configs/train_2_1/vitb16/pretrain-256px-16f.yaml
new file mode 100644
index 00000000..dca536f3
--- /dev/null
+++ b/configs/train_2_1/vitb16/pretrain-256px-16f.yaml
@@ -0,0 +1,149 @@
+app: vjepa_2_1
+nodes: 8
+tasks_per_node: 8
+cpus_per_task: 8
+mem_per_gpu: 220G
+folder: /your_folder/pretrain_2_1/16.8.vitb.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 48
+ crop_size: 256
+ patch_size: 16
+ dataset_fpcs:
+ - 16
+ - 16
+ - 16
+ tubelet_size: 2
+ fps: 4
+ num_workers: 8
+ persistent_workers: true
+ pin_mem: true
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+loss:
+ loss_exp: 1.0
+ predict_all: true
+ reg_coeff: 0.0
+ shift_by_n: 0
+ weight_distance_loss: true
+
+img_data:
+ batch_size: 144
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ is_causal: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ local_window:
+ - -1
+ - -1
+ - -1
+ modality_embedding: true
+ model_name: vit_base
+ n_registers: 0
+ n_registers_predictor: 0
+ normalize_predictor: false
+ pred_depth: 12
+ pred_embed_dim: 384
+ pred_is_causal: false
+ pred_local_window:
+ - -1
+ - -1
+ - -1
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ vit_conv: false
+ zero_init_mask_tokens: true
+optimization:
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 1000
+ final_lr: 0.0006
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ lr: 0.0006
+ start_lr: 0.0001
+ warmup: 40
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitg16/cooldown-256px-64f.yaml b/configs/train_2_1/vitg16/cooldown-256px-64f.yaml
new file mode 100644
index 00000000..c7a40962
--- /dev/null
+++ b/configs/train_2_1/vitg16/cooldown-256px-64f.yaml
@@ -0,0 +1,144 @@
+app: vjepa_2_1
+nodes: 32
+tasks_per_node: 8
+cpus_per_task: 32
+mem_per_gpu: 220G
+folder: /your_folder/cooldown_2_1/16.8.vitg.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 12
+ crop_size: 256
+ dataset_fpcs:
+ - 64
+ - 64
+ - 64
+ fps: 4
+ num_workers: 12
+ patch_size: 16
+ persistent_workers: true
+ pin_mem: false
+ tubelet_size: 2
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+
+img_data:
+ batch_size: 72
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+loss:
+ computing_gram_loss: false
+ gram_HighRes: false
+ gram_ckpt: None
+ gram_loss_weight: 10.0
+ loss_exp: 1.0
+ predict_all: true
+ shift_by_n: 0
+ weight_distance_loss: false
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ lambda_progressive: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ modality_embedding: true
+ model_name: vit_giant_xformers
+ n_registers: 0
+ normalize_predictor: false
+ pred_depth: 24
+ pred_embed_dim: 384
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ zero_init_mask_tokens: true
+optimization:
+ anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitg.256px.16f/latest.pth.tar
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 40
+ final_lr: 1.0e-06
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ is_anneal: true
+ lr: 0.0006
+ resume_anneal: true
+ start_lr: 0.0001
+ warmup: 0
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitg16/pretrain-256px-16f.yaml b/configs/train_2_1/vitg16/pretrain-256px-16f.yaml
new file mode 100644
index 00000000..170cd2f5
--- /dev/null
+++ b/configs/train_2_1/vitg16/pretrain-256px-16f.yaml
@@ -0,0 +1,149 @@
+app: vjepa_2_1
+nodes: 16
+tasks_per_node: 8
+cpus_per_task: 16
+mem_per_gpu: 220G
+folder: /your_folder/pretrain_2_1/16.8.vitg.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 24
+ crop_size: 256
+ patch_size: 16
+ dataset_fpcs:
+ - 16
+ - 16
+ - 16
+ tubelet_size: 2
+ fps: 4
+ num_workers: 8
+ persistent_workers: true
+ pin_mem: true
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+loss:
+ loss_exp: 1.0
+ predict_all: true
+ reg_coeff: 0.0
+ shift_by_n: 0
+ weight_distance_loss: true
+
+img_data:
+ batch_size: 72
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ is_causal: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ local_window:
+ - -1
+ - -1
+ - -1
+ modality_embedding: true
+ model_name: vit_giant_xformers
+ n_registers: 0
+ n_registers_predictor: 0
+ normalize_predictor: false
+ pred_depth: 24
+ pred_embed_dim: 384
+ pred_is_causal: false
+ pred_local_window:
+ - -1
+ - -1
+ - -1
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ vit_conv: false
+ zero_init_mask_tokens: true
+optimization:
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 1000
+ final_lr: 0.0006
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ lr: 0.0006
+ start_lr: 0.0001
+ warmup: 40
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitl16/cooldown-256px-64f.yaml b/configs/train_2_1/vitl16/cooldown-256px-64f.yaml
new file mode 100644
index 00000000..73308856
--- /dev/null
+++ b/configs/train_2_1/vitl16/cooldown-256px-64f.yaml
@@ -0,0 +1,144 @@
+app: vjepa_2_1
+nodes: 32
+tasks_per_node: 8
+cpus_per_task: 32
+mem_per_gpu: 220G
+folder: /your_folder/cooldown_2_1/16.8.vitl.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 12
+ crop_size: 256
+ dataset_fpcs:
+ - 64
+ - 64
+ - 64
+ fps: 4
+ num_workers: 12
+ patch_size: 16
+ persistent_workers: true
+ pin_mem: false
+ tubelet_size: 2
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+
+img_data:
+ batch_size: 72
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+loss:
+ computing_gram_loss: false
+ gram_HighRes: false
+ gram_ckpt: None
+ gram_loss_weight: 10.0
+ loss_exp: 1.0
+ predict_all: true
+ shift_by_n: 0
+ weight_distance_loss: false
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ lambda_progressive: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ modality_embedding: true
+ model_name: vit_large
+ n_registers: 0
+ normalize_predictor: false
+ pred_depth: 24
+ pred_embed_dim: 384
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ zero_init_mask_tokens: true
+optimization:
+ anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitl.256px.16f/latest.pth.tar
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 40
+ final_lr: 1.0e-06
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ is_anneal: true
+ lr: 0.0006
+ resume_anneal: true
+ start_lr: 0.0001
+ warmup: 0
+ weight_decay: 0.04
diff --git a/configs/train_2_1/vitl16/pretrain-256px-16f.yaml b/configs/train_2_1/vitl16/pretrain-256px-16f.yaml
new file mode 100644
index 00000000..3988cf9a
--- /dev/null
+++ b/configs/train_2_1/vitl16/pretrain-256px-16f.yaml
@@ -0,0 +1,149 @@
+app: vjepa_2_1
+nodes: 16
+tasks_per_node: 8
+cpus_per_task: 16
+mem_per_gpu: 220G
+folder: /your_folder/pretrain_2_1/16.8.vitl.256px.16f
+data:
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/k710_train_paths.csv
+ - /your_data/ssv2_train_paths.csv
+ - /your_data/howto_320p.csv
+ datasets_weights:
+ - 0.335
+ - 0.100
+ - 0.565
+ batch_size: 24
+ crop_size: 256
+ patch_size: 16
+ dataset_fpcs:
+ - 16
+ - 16
+ - 16
+ tubelet_size: 2
+ fps: 4
+ num_workers: 8
+ persistent_workers: true
+ pin_mem: true
+data_aug:
+ auto_augment: false
+ motion_shift: false
+ random_resize_aspect_ratio:
+ - 0.75
+ - 1.35
+ random_resize_scale:
+ - 0.3
+ - 1.0
+ reprob: 0.0
+loss:
+ loss_exp: 1.0
+ predict_all: true
+ reg_coeff: 0.0
+ shift_by_n: 0
+ weight_distance_loss: true
+
+img_data:
+ batch_size: 72
+ crop_size: 256
+ dataset_fpcs:
+ - 1
+ dataset_type: VideoDataset
+ datasets:
+ - /your_data/imagenet1k.csv
+ datasets_weights:
+ - 1.0
+ rank_ratio: 0.5
+img_mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 10
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+
+mask:
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 8
+ spatial_scale:
+ - 0.15
+ - 0.15
+ temporal_scale:
+ - 1.0
+ - 1.0
+- aspect_ratio:
+ - 0.75
+ - 1.5
+ full_complement: false
+ max_keep: null
+ max_temporal_keep: 1.0
+ num_blocks: 2
+ spatial_scale:
+ - 0.7
+ - 0.7
+ temporal_scale:
+ - 1.0
+ - 1.0
+meta:
+ dtype: bfloat16
+ eval_freq: 100
+ load_checkpoint: true
+ read_checkpoint: null
+ save_every_freq: 50
+ seed: 239
+ use_sdpa: true
+model:
+ has_cls_first: false
+ img_temporal_dim_size: 1
+ interpolate_rope: true
+ is_causal: false
+ lambda_value_img: 0.5
+ lambda_value_vid: 0.5
+ local_window:
+ - -1
+ - -1
+ - -1
+ modality_embedding: true
+ model_name: vit_large
+ n_registers: 0
+ n_registers_predictor: 0
+ normalize_predictor: false
+ pred_depth: 24
+ pred_embed_dim: 384
+ pred_is_causal: false
+ pred_local_window:
+ - -1
+ - -1
+ - -1
+ pred_num_heads: 12
+ uniform_power: true
+ use_activation_checkpointing: true
+ use_mask_tokens: true
+ use_rope: true
+ vit_conv: false
+ zero_init_mask_tokens: true
+optimization:
+ ema:
+ - 0.99925
+ - 0.99925
+ epochs: 1000
+ final_lr: 0.0006
+ final_weight_decay: 0.04
+ ipe: 300
+ ipe_scale: 1.25
+ lr: 0.0006
+ start_lr: 0.0001
+ warmup: 40
+ weight_decay: 0.04
diff --git a/evals/action_anticipation_frozen/eval.py b/evals/action_anticipation_frozen/eval.py
index 9a15ee04..5dcce775 100644
--- a/evals/action_anticipation_frozen/eval.py
+++ b/evals/action_anticipation_frozen/eval.py
@@ -743,7 +743,7 @@ def load_checkpoint(device, r_path, classifiers, opt, scaler, val_only=False):
[o.load_state_dict(c) for o, c in zip(opt, checkpoint["opt"])]
if scaler is not None:
- [s.load_state_dict(c) for s, c in zip(opt, checkpoint["scaler"])]
+ [s.load_state_dict(c) for s, c in zip(scaler, checkpoint["scaler"])]
logger.info(f"loaded optimizers from epoch {epoch}")
return classifiers, opt, scaler, epoch
diff --git a/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py b/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py
index 90424702..975a908b 100644
--- a/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py
+++ b/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py
@@ -26,14 +26,23 @@
import torch
-import src.models.predictor as vit_pred
-import src.models.vision_transformer as vit
-
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
+def _get_model_modules(pretrain_kwargs):
+ """Import encoder/predictor modules based on config."""
+ use_v2_1 = pretrain_kwargs.get("use_v2_1", False)
+ if use_v2_1:
+ import app.vjepa_2_1.models.predictor as vit_pred
+ import app.vjepa_2_1.models.vision_transformer as vit
+ else:
+ import src.models.predictor as vit_pred
+ import src.models.vision_transformer as vit
+ return vit, vit_pred
+
+
def init_module(
frames_per_clip: int,
frames_per_second: int,
@@ -50,6 +59,8 @@ def init_module(
# ----------------------------------------------------------------------- #
# Initialize Encoder
# ----------------------------------------------------------------------- #
+ vit, vit_pred = _get_model_modules(model_kwargs)
+
enc_kwargs = model_kwargs["encoder"]
enc_ckp_key = enc_kwargs.get("checkpoint_key")
enc_model_name = enc_kwargs.get("model_name")
@@ -76,13 +87,21 @@ def init_module(
prd_ckp_key = prd_kwargs.get("checkpoint_key")
prd_model_name = prd_kwargs.get("model_name")
+ teacher_embed_dim = prd_kwargs.get("teacher_embed_dim")
+ n_output_distillation = prd_kwargs.get("n_output_distillation", 4)
+ prd_out_embed_dim = teacher_embed_dim // n_output_distillation if teacher_embed_dim is not None else None
+ print(f"[DEBUG] teacher_embed_dim={teacher_embed_dim}, n_output_distillation={n_output_distillation}, prd_out_embed_dim={prd_out_embed_dim}")
+ print(f"[DEBUG] vit_pred module: {vit_pred.__name__}")
+
predictor = vit_pred.__dict__[prd_model_name](
img_size=resolution,
embed_dim=encoder.embed_dim,
patch_size=encoder.patch_size,
tubelet_size=encoder.tubelet_size,
+ out_embed_dim=prd_out_embed_dim,
**prd_kwargs,
)
+ print(f"[DEBUG] predictor_proj: {predictor.predictor_proj}")
pretrained_dict = checkpoint[prd_ckp_key]
# --
pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
@@ -92,7 +111,7 @@ def init_module(
logger.info(f'key "{k}" could not be found in loaded state dict')
elif pretrained_dict[k].shape != v.shape:
logger.info(
- f'key "{k}" is of different shape in model and loaded state dict', pretrained_dict[k].shape, v.shape
+ f'key "{k}" is of different shape in model and loaded state dict: {pretrained_dict[k].shape} vs {v.shape}'
)
pretrained_dict[k] = v
msg = predictor.load_state_dict(pretrained_dict, strict=False)
@@ -113,6 +132,10 @@ def init_module(
)
model.embed_dim = encoder.embed_dim
+ # Enable hierarchical feature output for non-distilled predictors
+ if hasattr(predictor, 'hierarchical_layers') and len(predictor.hierarchical_layers) > 1:
+ encoder.return_hierarchical = True
+
return model
@@ -151,20 +174,24 @@ def forward(self, x, anticipation_times):
:param x: (Tensor) video of shape [B, C, T, H, W]
:param anticipation_time: (Tensor) [B] seconds into the future to predict for each sample in batch
"""
- x = self.encoder(x)
+ x_full = self.encoder(x)
- # determine 1D position of context tokens (x)
- # determine 1D position of prediction tokens
- # forward predictor with ctxt=x, tgt=None, masks_ctxt, masks_tgt
if self.no_predictor:
- return x
+ return x_full
+
+ B, N, D_full = x_full.size()
+ embed_dim = self.encoder.embed_dim
+ use_hierarchical = D_full > embed_dim
- # Will output representations of $num_output_frames, that are
- # $anticipation_time seconds into the future.
- B, N, D = x.size()
+ # For the accumulator/classifier, use last-layer features (embed_dim)
+ # For the predictor, use full hierarchical features (D_full)
+ if use_hierarchical:
+ x = x_full[:, :, -embed_dim:]
+ else:
+ x = x_full
if self.no_encoder:
- x_accumulate = torch.rand(x.size(0), 0, x.size(2)).to(x.device)
+ x_accumulate = torch.rand(B, 0, embed_dim).to(x.device)
else:
x_accumulate = x.clone()
@@ -180,9 +207,18 @@ def forward(self, x, anticipation_times):
tgt_positions = torch.arange(N_pred).unsqueeze(0).repeat(B, 1).to(x.device)
tgt_positions += skip_positions.unsqueeze(1).repeat(1, N_pred)
+ x_pred_input = x_full
for _ in range(self.num_steps):
- x_pred = self.predictor(x, masks_x=ctxt_positions, masks_y=tgt_positions)
+ pred_out = self.predictor(x_pred_input, masks_x=ctxt_positions, masks_y=tgt_positions)
+ x_pred_full = pred_out[0] if isinstance(pred_out, tuple) else pred_out
+
+ if x_pred_full.size(-1) != embed_dim:
+ x_pred = x_pred_full[:, :, -embed_dim:]
+ else:
+ x_pred = x_pred_full
+
x_accumulate = torch.cat([x_accumulate, x_pred], dim=1)
- x = torch.cat([x[:, N_pred:, :], x_pred], dim=1)
+ x_pred_for_input = x_pred_full if x_pred_full.size(-1) == x_pred_input.size(-1) else x_pred
+ x_pred_input = torch.cat([x_pred_input[:, N_pred:, :], x_pred_for_input], dim=1)
return x_accumulate
diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py
index df0bf958..9017cd25 100644
--- a/evals/image_classification_frozen/eval.py
+++ b/evals/image_classification_frozen/eval.py
@@ -81,7 +81,6 @@ def main(args_eval, resume_preempt=False):
dataset_name = args_data.get("dataset_name")
num_classes = args_data.get("num_classes")
root_path = args_data.get("root_path", None)
- image_folder = args_data.get("image_folder", None)
resolution = args_data.get("resolution", 224)
normalization = args_data.get("normalization", None)
@@ -157,7 +156,6 @@ def main(args_eval, resume_preempt=False):
dataset_name=dataset_name,
root_path=root_path,
img_size=resolution,
- image_folder=image_folder,
batch_size=batch_size,
world_size=world_size,
rank=rank,
@@ -168,7 +166,6 @@ def main(args_eval, resume_preempt=False):
dataset_name=dataset_name,
root_path=root_path,
img_size=resolution,
- image_folder=image_folder,
batch_size=batch_size,
world_size=world_size,
rank=rank,
@@ -354,7 +351,6 @@ def load_checkpoint(device, r_path, classifiers, opt, scaler, val_only=False):
def make_dataloader(
dataset_name,
root_path,
- image_folder,
batch_size,
world_size,
rank,
@@ -396,7 +392,6 @@ def make_dataloader(
world_size=world_size,
rank=rank,
root_path=root_path,
- image_folder=image_folder,
training=training,
drop_last=False,
subset_file=subset_file,
diff --git a/hubconf.py b/hubconf.py
index 1aa48b64..39f303d5 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -10,6 +10,10 @@
vjepa2_vit_giant_384,
vjepa2_vit_huge,
vjepa2_vit_large,
+ vjepa2_1_vit_base_384,
+ vjepa2_1_vit_large_384,
+ vjepa2_1_vit_giant_384,
+ vjepa2_1_vit_gigantic_384,
)
dependencies = ["torch", "timm", "einops"]
diff --git a/setup.py b/setup.py
index 35af1634..da9c84c1 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
from setuptools import setup
NAME = "vjepa2"
-VERSION = "0.0.1"
+VERSION = "0.0.2"
DESCRIPTION = "PyTorch code and models for V-JEPA 2."
URL = "https://github.com/facebookresearch/vjepa2"
diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py
index e3548639..d4e1a534 100644
--- a/src/datasets/data_manager.py
+++ b/src/datasets/data_manager.py
@@ -20,7 +20,6 @@ def init_data(
world_size=1,
rank=0,
root_path=None,
- image_folder=None,
training=True,
drop_last=True,
subset_file=None,
@@ -52,7 +51,6 @@ def init_data(
world_size=world_size,
rank=rank,
root_path=root_path,
- image_folder=image_folder,
persistent_workers=persistent_workers,
drop_last=drop_last,
subset_file=subset_file,
diff --git a/src/datasets/imagenet1k.py b/src/datasets/imagenet1k.py
index 3597e3a1..55603632 100644
--- a/src/datasets/imagenet1k.py
+++ b/src/datasets/imagenet1k.py
@@ -21,7 +21,6 @@ class ImageNet(torchvision.datasets.ImageFolder):
def __init__(
self,
root,
- image_folder="imagenet_full_size/061417/",
tar_file="imagenet_full_size-061417.tar.gz",
transform=None,
train=True,
@@ -35,7 +34,6 @@ def __init__(
Dataset wrapper
:param root: root network directory for ImageNet data
- :param image_folder: path to images inside root network directory
:param tar_file: zipped image_folder inside root network directory
:param train: whether to load train data (or validation)
:param job_id: scheduler job-id used to create dir on local machine
@@ -43,7 +41,7 @@ def __init__(
"""
suffix = "train/" if train else "val/"
- data_path = os.path.join(root, image_folder, suffix)
+ data_path = os.path.join(root, suffix)
logger.info(f"data-path {data_path}")
super(ImageNet, self).__init__(root=data_path, transform=transform)
@@ -120,7 +118,6 @@ def make_imagenet1k(
world_size=1,
rank=0,
root_path=None,
- image_folder=None,
training=True,
drop_last=True,
persistent_workers=False,
@@ -128,7 +125,6 @@ def make_imagenet1k(
):
dataset = ImageNet(
root=root_path,
- image_folder=image_folder,
transform=transform,
train=training,
index_targets=False,
@@ -136,7 +132,9 @@ def make_imagenet1k(
if subset_file is not None:
dataset = ImageNetSubset(dataset, subset_file)
logger.info("ImageNet dataset created")
- dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank)
+ dist_sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset=dataset, num_replicas=world_size, rank=rank
+ )
data_loader = torch.utils.data.DataLoader(
dataset,
collate_fn=collator,
diff --git a/src/datasets/utils/weighted_sampler.py b/src/datasets/utils/weighted_sampler.py
index 9b4105fa..63a67b4e 100644
--- a/src/datasets/utils/weighted_sampler.py
+++ b/src/datasets/utils/weighted_sampler.py
@@ -8,9 +8,9 @@
import numpy as np
import torch
-from torch.utils.data import DistributedSampler, RandomSampler
from src.utils.logging import get_logger
+from torch.utils.data import DistributedSampler, RandomSampler
logger = get_logger("WeightedSampler")
@@ -35,7 +35,9 @@ def __init__(
seed: int = 0,
drop_last: bool = False,
):
- logger.info(f"Using DistributedWeightedSampler with rank {rank} / {num_replicas}")
+ logger.info(
+ f"Using DistributedWeightedSampler with rank {rank} / {num_replicas}"
+ )
assert hasattr(
dataset, "sample_weights"
), "Dataset must have sample_weights property for using DistributedWeightedSampler"
@@ -78,7 +80,9 @@ def __iter__(self) -> Iterator:
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
- indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+ indices += (indices * math.ceil(padding_size / len(indices)))[
+ :padding_size
+ ]
else:
# remove tail of data to make it evenly divisible
indices = indices[: self.total_size]
@@ -106,7 +110,9 @@ def __init__(
shuffle: bool = True,
seed: int = 0,
):
- logger.info(f"Using MemoryEfficientDistributedWeightedSampler with rank {rank} / {num_replicas}")
+ logger.info(
+ f"Using MemoryEfficientDistributedWeightedSampler with rank {rank} / {num_replicas}"
+ )
assert hasattr(
dataset, "dataset_weights"
), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSampler"
@@ -129,10 +135,14 @@ def __init__(
if self.shuffle:
self.rng = np.random.default_rng(self.seed + self.rank + self.epoch)
total_weights = sum(self.dataset_weights)
- self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights])
+ self.dataset_probablities = np.array(
+ [w / total_weights for w in self.dataset_weights]
+ )
else:
if any([not isinstance(w, int) for w in self.dataset_weights]):
- raise ValueError("Dataset weights must be integers when shuffle is False")
+ raise ValueError(
+ "Dataset weights must be integers when shuffle is False"
+ )
self.dataset_orders = []
for i, w in enumerate(self.dataset_weights):
@@ -145,7 +155,9 @@ def __iter__(self) -> Iterator:
def __next__(self) -> int:
if self.shuffle:
- selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities)
+ selected_dataset_idx = self.rng.choice(
+ range(len(self.dataset_weights)), p=self.dataset_probablities
+ )
# In order to avoid sampling the same example multiple times between the ranks,
# we limit each rank to a subset of the total number of samples in the dataset.
@@ -171,12 +183,14 @@ def __next__(self) -> int:
else:
# Iterate through the dataset orders in a round-robin fashion, offset by the rank
- dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders)
+ dataset_orders_idx = (self.rank + self.drawn_samples) % len(
+ self.dataset_orders
+ )
selected_dataset_idx = self.dataset_orders[dataset_orders_idx]
# Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples
- sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[
- selected_dataset_idx
- ]
+ sample_idx_in_dataset = (
+ self.drawn_samples * self.num_replicas + self.rank
+ ) % self.dataset_sizes[selected_dataset_idx]
self.drawn_samples += 1
# Getting the index of the sample in the whole dataset
@@ -224,7 +238,9 @@ def __init__(
shuffle: bool = True,
seed: int = 0,
):
- logger.info(f"Using MemoryEfficientDistributedWeightedSamplerLessRepeat with rank {rank} / {num_replicas}")
+ logger.info(
+ f"Using MemoryEfficientDistributedWeightedSamplerLessRepeat with rank {rank} / {num_replicas}"
+ )
assert hasattr(
dataset, "dataset_weights"
), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSamplerLessRepeat"
@@ -250,11 +266,15 @@ def __init__(
if self.shuffle:
self.rng = np.random.default_rng(self.seed + self.rank + self.epoch)
total_weights = sum(self.dataset_weights)
- self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights])
+ self.dataset_probablities = np.array(
+ [w / total_weights for w in self.dataset_weights]
+ )
# For each dataset we generate a permutation of the indices that will be processed by that rank.
# This is going to be the subset of indices, selected by the steps sizes of the world size.
- logger.info(f"Generating dataset indices for rank {self.rank} / {self.num_replicas}")
+ logger.info(
+ f"Generating dataset indices for rank {self.rank} / {self.num_replicas}"
+ )
# Getting a RandomSampler for indices assigned to each dataset.
self.individual_dataset_sampler = []
@@ -263,11 +283,15 @@ def __init__(
# NOTE: this may effectively drop the last batch,
# but given the sample sizes that we use this sampler with, it should not be an issue.
num_samples_in_rank = ds // self.num_replicas
- self.individual_dataset_sampler.append(self._new_sampler(num_samples_in_rank))
+ self.individual_dataset_sampler.append(
+ self._new_sampler(num_samples_in_rank)
+ )
else:
if any([not isinstance(w, int) for w in self.dataset_weights]):
- raise ValueError("Dataset weights must be integers when shuffle is False")
+ raise ValueError(
+ "Dataset weights must be integers when shuffle is False"
+ )
self.dataset_orders = []
for i, w in enumerate(self.dataset_weights):
@@ -295,7 +319,9 @@ def _in_rank_next_index_for_dataset(self, dataset_idx: int) -> int:
if next_sampler_idx is None:
# We have reached the end of the dataset, we need to reset the sampler.
num_samples_in_rank = self.dataset_sizes[dataset_idx] // self.num_replicas
- self.individual_dataset_sampler[dataset_idx] = self._new_sampler(num_samples_in_rank)
+ self.individual_dataset_sampler[dataset_idx] = self._new_sampler(
+ num_samples_in_rank
+ )
next_sampler_idx = safe_next(self.individual_dataset_sampler[dataset_idx])
assert next_sampler_idx is not None
@@ -303,7 +329,9 @@ def _in_rank_next_index_for_dataset(self, dataset_idx: int) -> int:
def __next__(self) -> int:
if self.shuffle:
- selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities)
+ selected_dataset_idx = self.rng.choice(
+ range(len(self.dataset_weights)), p=self.dataset_probablities
+ )
in_rank_sample = self._in_rank_next_index_for_dataset(selected_dataset_idx)
# 2) Getting sample index in the dataset.
@@ -311,12 +339,14 @@ def __next__(self) -> int:
else:
# Iterate through the dataset orders in a round-robin fashion, offset by the rank
- dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders)
+ dataset_orders_idx = (self.rank + self.drawn_samples) % len(
+ self.dataset_orders
+ )
selected_dataset_idx = self.dataset_orders[dataset_orders_idx]
# Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples
- sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[
- selected_dataset_idx
- ]
+ sample_idx_in_dataset = (
+ self.drawn_samples * self.num_replicas + self.rank
+ ) % self.dataset_sizes[selected_dataset_idx]
self.drawn_samples += 1
# Getting the index of the sample in the whole dataset
diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py
index b0c47f66..05ec2a38 100644
--- a/src/datasets/video_dataset.py
+++ b/src/datasets/video_dataset.py
@@ -13,9 +13,13 @@
import pandas as pd
import torch
import torchvision
-from decord import VideoReader, cpu
+from decord import cpu, VideoReader
-from src.datasets.utils.dataloader import ConcatIndices, MonitoredDataset, NondeterministicDataLoader
+from src.datasets.utils.dataloader import (
+ ConcatIndices,
+ MonitoredDataset,
+ NondeterministicDataLoader,
+)
from src.datasets.utils.weighted_sampler import DistributedWeightedSampler
_GLOBAL_SEED = 0
@@ -79,7 +83,9 @@ def make_videodataset(
logger.info("VideoDataset dataset created")
if datasets_weights is not None:
- dist_sampler = DistributedWeightedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
+ dist_sampler = DistributedWeightedSampler(
+ dataset, num_replicas=world_size, rank=rank, shuffle=True
+ )
else:
dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank, shuffle=True
@@ -146,7 +152,9 @@ def __init__(
self.fps = fps
if sum([v is not None for v in (fps, duration, frame_step)]) != 1:
- raise ValueError(f"Must specify exactly one of either {fps=}, {duration=}, or {frame_step=}.")
+ raise ValueError(
+ f"Must specify exactly one of either {fps=}, {duration=}, or {frame_step=}."
+ )
if isinstance(data_paths, str):
data_paths = [data_paths]
@@ -155,11 +163,15 @@ def __init__(
self.dataset_fpcs = [frames_per_clip for _ in data_paths]
else:
if len(dataset_fpcs) != len(data_paths):
- raise ValueError("Frames per clip not properly specified for NFS data paths")
+ raise ValueError(
+ "Frames per clip not properly specified for NFS data paths"
+ )
self.dataset_fpcs = dataset_fpcs
if VideoReader is None:
- raise ImportError('Unable to import "decord" which is required to read videos.')
+ raise ImportError(
+ 'Unable to import "decord" which is required to read videos.'
+ )
# Load video paths and labels
samples, labels = [], []
@@ -222,7 +234,9 @@ def get_item_video(self, index):
dataset_idx, _ = self.per_dataset_indices[index]
frames_per_clip = self.dataset_fpcs[dataset_idx]
- buffer, clip_indices = self.loadvideo_decord(sample, frames_per_clip) # [T H W 3]
+ buffer, clip_indices = self.loadvideo_decord(
+ sample, frames_per_clip
+ ) # [T H W 3]
loaded_video = len(buffer) > 0
if not loaded_video:
return
@@ -251,7 +265,9 @@ def get_item_image(self, index):
fpc = self.dataset_fpcs[dataset_idx]
try:
- image_tensor = torchvision.io.read_image(path=sample, mode=torchvision.io.ImageReadMode.RGB)
+ image_tensor = torchvision.io.read_image(
+ path=sample, mode=torchvision.io.ImageReadMode.RGB
+ )
except Exception:
return
label = self.labels[index]
diff --git a/src/hub/backbones.py b/src/hub/backbones.py
index 35308f86..435c72ea 100644
--- a/src/hub/backbones.py
+++ b/src/hub/backbones.py
@@ -5,17 +5,23 @@
import torch
-VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2"
+# VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2"
# for testing
-# VJEPA_BASE_URL = "http://localhost:8300"
+VJEPA_BASE_URL = "http://localhost:8300"
ARCH_NAME_MAP = {
+ # V-JEPA 2
"vit_large": ("vit_large", "vitl"),
"vit_huge": ("vit_huge", "vith"),
"vit_giant": ("vit_giant_xformers", "vitg"),
"vit_ac_giant": ("vit_giant_xformers", "vjepa2-ac-vitg"),
"vit_giant_384": ("vit_giant_xformers", "vitg-384"),
+ # V-JEPA 2.1
+ "vjepa2_1_vit_base_384": ("vit_base", "vjepa2_1_vitb_dist_vitG_384"),
+ "vjepa2_1_vit_large_384": ("vit_large", "vjepa2_1_vitl_dist_vitG_384"),
+ "vjepa2_1_vit_giant_384": ("vit_giant_xformers", "vjepa2_1_vitg_384"),
+ "vjepa2_1_vit_gigantic_384": ("vit_gigantic_xformers", "vjepa2_1_vitG_384"),
}
@@ -38,8 +44,10 @@ def _make_vjepa2_ac_model(
pretrained: bool = True,
**kwargs,
):
- from ..models import ac_predictor as vit_ac_predictor
- from ..models import vision_transformer as vit_encoder
+ from ..models import (
+ ac_predictor as vit_ac_predictor,
+ vision_transformer as vit_encoder,
+ )
vit_encoder_kwargs = dict(
patch_size=patch_size,
@@ -83,15 +91,17 @@ def _make_vjepa2_ac_model(
def _make_vjepa2_model(
*,
model_name: str = "vit_large",
+ checkpoint_key="target_encoder",
img_size=256,
patch_size=16,
tubelet_size=2,
num_frames=64,
+ predictor_embed_dim=384,
+ predictor_out_embed_dim=None,
pretrained: bool = True,
**kwargs,
):
- from ..models import predictor as vit_predictor
- from ..models import vision_transformer as vit_encoder
+ from ..models import predictor as vit_predictor, vision_transformer as vit_encoder
vit_encoder_kwargs = dict(
patch_size=patch_size,
@@ -114,7 +124,8 @@ def _make_vjepa2_model(
patch_size=patch_size,
use_mask_tokens=True,
embed_dim=encoder.embed_dim,
- predictor_embed_dim=384,
+ predictor_embed_dim=predictor_embed_dim,
+ out_embed_dim=predictor_out_embed_dim,
num_frames=num_frames,
tubelet_size=tubelet_size,
depth=12,
@@ -134,10 +145,14 @@ def _make_vjepa2_model(
model_file = ARCH_NAME_MAP[model_name][-1]
url = VJEPA_BASE_URL + f"/{model_file}.pt"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
- encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
- encoder.load_state_dict(encoder_state_dict, strict=False) # state_dict has pos_embed but we use RoPE
+ encoder_state_dict = _clean_backbone_key(state_dict[checkpoint_key])
+ encoder.load_state_dict(
+ encoder_state_dict, strict=False
+ ) # state_dict has pos_embed but we use RoPE
predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
- predictor.load_state_dict(predictor_state_dict, strict=False) # state_dict has pos_embed but we use RoPE
+ predictor.load_state_dict(
+ predictor_state_dict, strict=False
+ ) # state_dict has pos_embed but we use RoPE
return encoder, predictor
@@ -146,32 +161,179 @@ def vjepa2_vit_large(*, pretrained: bool = True, **kwargs):
"""
VJEPA 2 ViT-Large model
"""
- return _make_vjepa2_model(model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs)
+ return _make_vjepa2_model(
+ model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs
+ )
def vjepa2_vit_huge(*, pretrained: bool = True, **kwargs):
"""
VJEPA 2 ViT-Huge model
"""
- return _make_vjepa2_model(model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs)
+ return _make_vjepa2_model(
+ model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs
+ )
def vjepa2_vit_giant(*, pretrained: bool = True, **kwargs):
"""
VJEPA 2 ViT-giant model
"""
- return _make_vjepa2_model(model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs)
+ return _make_vjepa2_model(
+ model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs
+ )
def vjepa2_vit_giant_384(*, pretrained: bool = True, **kwargs):
"""
VJEPA 2 ViT-giant-384 model
"""
- return _make_vjepa2_model(model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs)
+ return _make_vjepa2_model(
+ model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs
+ )
def vjepa2_ac_vit_giant(*, pretrained: bool = True, **kwargs):
"""
VJEPA 2-AC ViT-giant model
"""
- return _make_vjepa2_ac_model(model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs)
+ return _make_vjepa2_ac_model(
+ model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs
+ )
+
+
+# ########## V-JEPA 2.1 ##########
+
+
+vjepa2_1_teacher_embed_dim = 1664
+
+
+def _make_vjepa2_1_model(
+ model_name: str = "vjepa2_1_vit_large_384",
+ checkpoint_key="target_encoder",
+ img_size=384,
+ patch_size=16,
+ tubelet_size=2,
+ num_frames=64,
+ predictor_embed_dim=384,
+ predictor_depth=24,
+ predictor_num_mask_tokens=10,
+ n_output_distillation=4,
+ return_all_tokens=False,
+ teacher_embed_dim=None,
+ pretrained: bool = True,
+ **kwargs,
+):
+ from app.vjepa_2_1.models import predictor as vit_predictor, vision_transformer as vit_encoder
+
+ vit_encoder_kwargs = dict(
+ patch_size=patch_size,
+ img_size=(img_size, img_size),
+ num_frames=num_frames,
+ tubelet_size=tubelet_size,
+ use_sdpa=True,
+ use_SiLU=False,
+ wide_SiLU=True,
+ uniform_power=False,
+ use_rope=True,
+ img_temporal_dim_size=1,
+ interpolate_rope=True,
+ )
+ vit_encoder_kwargs.update(**kwargs)
+
+ arch_name = ARCH_NAME_MAP[model_name][0]
+ encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)
+
+ vit_predictor_kwargs = dict(
+ img_size=(img_size, img_size),
+ patch_size=patch_size,
+ use_mask_tokens=True,
+ embed_dim=encoder.embed_dim,
+ predictor_embed_dim=predictor_embed_dim,
+ teacher_embed_dim=teacher_embed_dim,
+ num_frames=num_frames,
+ tubelet_size=tubelet_size,
+ depth=predictor_depth,
+ num_heads=12,
+ num_mask_tokens=predictor_num_mask_tokens,
+ use_rope=True,
+ uniform_power=False,
+ use_sdpa=True,
+ use_silu=False,
+ wide_silu=True,
+ n_output_distillation=n_output_distillation,
+ return_all_tokens=return_all_tokens,
+ img_temporal_dim_size=1,
+ )
+ vit_predictor_kwargs.update(**kwargs)
+
+ predictor = vit_predictor.__dict__["vit_predictor"](**vit_predictor_kwargs)
+
+ if pretrained:
+ model_file = ARCH_NAME_MAP[model_name][-1]
+ url = VJEPA_BASE_URL + f"/{model_file}.pt"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ encoder_state_dict = _clean_backbone_key(state_dict[checkpoint_key])
+ encoder.load_state_dict(
+ encoder_state_dict, strict=True
+ ) # state_dict has pos_embed but we use RoPE
+ predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
+ predictor.load_state_dict(
+ predictor_state_dict, strict=True
+ ) # state_dict has pos_embed but we use RoPE
+
+ return encoder, predictor
+
+
+def vjepa2_1_vit_base_384(*, pretrained: bool = True, **kwargs):
+ return _make_vjepa2_1_model(
+ model_name="vjepa2_1_vit_base_384",
+ checkpoint_key="ema_encoder",
+ img_size=384,
+ predictor_depth=12,
+ predictor_num_mask_tokens=8,
+ n_output_distillation=1,
+ return_all_tokens=True,
+ teacher_embed_dim=vjepa2_1_teacher_embed_dim,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def vjepa2_1_vit_large_384(*, pretrained: bool = True, **kwargs):
+ return _make_vjepa2_1_model(
+ model_name="vjepa2_1_vit_large_384",
+ checkpoint_key="ema_encoder",
+ img_size=384,
+ predictor_depth=12,
+ predictor_num_mask_tokens=8,
+ n_output_distillation=1,
+ return_all_tokens=True,
+ teacher_embed_dim=vjepa2_1_teacher_embed_dim,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def vjepa2_1_vit_giant_384(*, pretrained: bool = True, **kwargs):
+ return _make_vjepa2_1_model(
+ model_name="vjepa2_1_vit_giant_384",
+ img_size=384,
+ predictor_num_mask_tokens=8,
+ n_output_distillation=4,
+ return_all_tokens=True,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def vjepa2_1_vit_gigantic_384(*, pretrained: bool = True, **kwargs):
+ return _make_vjepa2_1_model(
+ model_name="vjepa2_1_vit_gigantic_384",
+ img_size=384,
+ predictor_num_mask_tokens=8,
+ n_output_distillation=4,
+ return_all_tokens=True,
+ pretrained=pretrained,
+ **kwargs,
+ )
diff --git a/src/masks/multiseq_multiblock3d.py b/src/masks/multiseq_multiblock3d.py
index 85d9b049..cd8e2f2f 100644
--- a/src/masks/multiseq_multiblock3d.py
+++ b/src/masks/multiseq_multiblock3d.py
@@ -53,11 +53,23 @@ def step(self):
def __call__(self, batch):
- # Batch: [buffer, label, clip_indices]
+ # Batch: [buffer, label, clip_indices] for video
+ # or [buffer, label] for images
filtered_batches = {fpc: [] for fpc in self.mask_generators}
for sample in batch:
- fpc = len(sample[-1][-1])
- filtered_batches[fpc] += [sample]
+ # Check if sample is from video dataset (has clip_indices) or image dataset
+ if len(sample) >= 3 and isinstance(sample[-1], (list, tuple)):
+ # Video sample: sample[-1] is clip_indices, sample[-1][-1] contains frame indices
+ try:
+ fpc = len(sample[-1][-1])
+ except (TypeError, IndexError):
+ # Fallback: assume single frame if structure is unexpected
+ fpc = 1
+ else:
+ # Image sample: single frame
+ fpc = 1
+ if fpc in filtered_batches:
+ filtered_batches[fpc] += [sample]
fpc_collations = []
for fpc in filtered_batches:
@@ -71,7 +83,9 @@ def __call__(self, batch):
masks_enc, masks_pred = mask_generator(batch_size)
collated_masks_enc.append(masks_enc)
collated_masks_pred.append(masks_pred)
- fpc_collations += [(collated_batch, collated_masks_enc, collated_masks_pred)]
+ fpc_collations += [
+ (collated_batch, collated_masks_enc, collated_masks_pred)
+ ]
return fpc_collations
@@ -100,7 +114,9 @@ def __init__(
if not isinstance(spatial_patch_size, tuple):
spatial_patch_size = (spatial_patch_size,) * 2
self.crop_size = crop_size
- self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)]
+ self.height, self.width = [
+ crop_size[i] // spatial_patch_size[i] for i in (0, 1)
+ ]
self.duration = num_frames // temporal_patch_size
self.full_complement = full_complement
self.pred_full_complement = pred_full_complement
@@ -126,7 +142,9 @@ def step(self):
v = i.value
return v
- def _sample_block_size(self, generator, temporal_scale, spatial_scale, aspect_ratio_scale):
+ def _sample_block_size(
+ self, generator, temporal_scale, spatial_scale, aspect_ratio_scale
+ ):
# -- Sample temporal block mask scale
_rand = torch.rand(1, generator=generator).item()
min_t, max_t = temporal_scale
@@ -193,7 +211,9 @@ def __call__(self, batch_size):
empty_context = True
while empty_context:
- mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
+ mask_e = torch.ones(
+ (self.duration, self.height, self.width), dtype=torch.int32
+ )
for _ in range(self.npred):
mask_e *= self._sample_block_mask(p_size)
mask_e = mask_e.flatten()
@@ -216,7 +236,12 @@ def __call__(self, batch_size):
if self.full_complement: # predictor mask is just complement of encoder mask
collated_masks_pred = [
torch.tensor(
- sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
+ sorted(
+ list(
+ set(range(int(self.duration * self.height * self.width)))
+ - set(cm.tolist())
+ )
+ ),
dtype=cm.dtype,
)
for cm in collated_masks_enc
@@ -224,7 +249,12 @@ def __call__(self, batch_size):
elif self.pred_full_complement:
collated_masks_enc = [
torch.tensor(
- sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
+ sorted(
+ list(
+ set(range(int(self.duration * self.height * self.width)))
+ - set(cm.tolist())
+ )
+ ),
dtype=cm.dtype,
)
for cm in collated_masks_pred
diff --git a/src/models/predictor.py b/src/models/predictor.py
index 1b00189f..968c5cc5 100644
--- a/src/models/predictor.py
+++ b/src/models/predictor.py
@@ -26,6 +26,7 @@ def __init__(
tubelet_size=2,
embed_dim=768,
predictor_embed_dim=384,
+ out_embed_dim=None,
depth=6,
num_heads=12,
mlp_ratio=4.0,
@@ -120,9 +121,16 @@ def __init__(
]
)
+ if out_embed_dim is None:
+ teacher_embed_dim = kwargs.get("teacher_embed_dim", None)
+ if teacher_embed_dim is not None:
+ out_embed_dim = teacher_embed_dim
+ else:
+ out_embed_dim = embed_dim
+
# Normalize & project back to input dimension
self.predictor_norm = norm_layer(predictor_embed_dim)
- self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
+ self.predictor_proj = nn.Linear(predictor_embed_dim, out_embed_dim, bias=True)
# ------ initialize weights
if self.predictor_pos_embed is not None:
diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py
index e2c43592..be2bbfc5 100644
--- a/src/models/vision_transformer.py
+++ b/src/models/vision_transformer.py
@@ -453,7 +453,7 @@ def vit_gigantic(patch_size=16, **kwargs):
embed_dim=1664,
depth=48,
num_heads=16,
- mpl_ratio=64 / 13,
+ mlp_ratio=64 / 13,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs
@@ -467,7 +467,7 @@ def vit_gigantic_xformers(patch_size=16, **kwargs):
embed_dim=1664,
depth=48,
num_heads=26,
- mpl_ratio=64 / 13,
+ mlp_ratio=64 / 13,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs
diff --git a/src/utils/wrappers.py b/src/utils/wrappers.py
index 7e08ec2d..cb24a136 100644
--- a/src/utils/wrappers.py
+++ b/src/utils/wrappers.py
@@ -11,6 +11,7 @@ class MultiSeqWrapper(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
+ self.embed_dim = backbone.embed_dim
def forward(self, x, masks=None):
"""