diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bd4e09d..b0350a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog + +## [0.0.2] - 2026-03-16 + +Release of V-JEPA 2.1 + + ## [0.0.1] - 2025-06-05 -Initial release of V-JEPA 2 codebase \ No newline at end of file +Initial release of V-JEPA 2 codebase diff --git a/README.md b/README.md index 0f0698dd..18fab914 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,9 @@ + +🆕 **[2026-03-16]:** :fire: V-JEPA 2.1 is released :fire: A new family of models trained with a novel recipe that learns high-quality and temporally consistent dense features !!! + +**[2025-06-25]:** V-JEPA 2 is released. [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)] + + # V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning ### [Meta FAIR](https://ai.meta.com/research/) @@ -13,7 +19,7 @@ Rabbat*, Nicolas Ballas* [[`Paper`](https://arxiv.org/abs/2506.09985)] [[`Blog`](https://ai.meta.com/blog/v-jepa-2-world-model-benchmarks)] [[`BibTex`](#Citation)] -Official Pytorch codebase for V-JEPA 2 and V-JEPA 2-AC. +Official Pytorch codebase for V-JEPA 2, V-JEPA 2-AC, and V-JEPA 2.1. V-JEPA 2 is a self-supervised approach to training video encoders, using internet-scale video data, that attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration. @@ -21,11 +27,37 @@ V-JEPA 2 is a self-supervised approach to training video encoders, using interne

- + +## V-JEPA 2.1 Pre-training + +Lorenzo Mur-Labadia, Matthew Muckley, Amir Bar, Mahmoud Assran, Koustuv Sinha, Michael +Rabbat, Yann LeCun, Nicolas Ballas, Adrien Bardes + +[[`Paper`](https://arxiv.org/abs/2603.14482)] [[`BibTex`](#Citation)] + +V-JEPA 2.1 improves the training recipe to focus on learning high-quality and temporally consistent dense features, as highlighted by PCA visualizations: + +

+ +

+ +The V-JEPA 2.1 approach leverages: (1) **Dense Predictive Loss**, a masking-based +self-supervision objective where all tokens (both visible/context and masked tokens) contribute to the +self-supervised training loss; (2) **Deep Self-Supervision**, which applies the self-supervised loss at multiple +intermediate representations of the encoder models; (3) **Multi-Modal Tokenizers** for images and videos; +and we show that our approach benefits from (4) **Model and data scaling**. + +

+ +

+ +V-JEPA 2.1 performance across dense and global prediction tasks: + +

+ +

+ ## V-JEPA 2 Pre-training @@ -35,7 +67,7 @@ V-JEPA 2 is a self-supervised approach to training video encoders, using interne - + @@ -111,15 +143,19 @@ V-JEPA 2 is a self-supervised approach to training video encoders, using interne
BenchmarkVJEPA 2V-JEPA 2 Previous Best
+ + + + ## Models -### V-JEPA 2 +### V-JEPA 2 and V-JEPA 2.1 #### HuggingFace -See our [HuggingFace collection](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) for V-JEPA 2. +See our HuggingFace [collection](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) for V-JEPA 2. -#### Pretrained Checkpoints +#### V-JEPA 2 Pretrained Checkpoints @@ -159,6 +195,51 @@ See our [HuggingFace collection](https://huggingface.co/collections/facebook/v-j
+#### V-JEPA 2.1 Pretrained Checkpoints + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model#ParametersResolutionDownload LinkPretraining Config
ViT-B/1680M384checkpointconfigs
ViT-L/16300M384checkpointconfigs
ViT-g/161B384checkpointconfigs
ViT-G/162B384checkpointconfigs
+ + #### Pretrained backbones (via PyTorch Hub) Please install [Pytorch](https://pytorch.org/get-started/locally/), [timm](https://pypi.org/project/timm/) and [einops](https://pypi.org/project/einops/) locally, then run the following to load each model. Installing Pytorch with CUDA support is strongly recommended. @@ -169,16 +250,22 @@ import torch # preprocessor processor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_preprocessor') # models +# V-JEPA 2 vjepa2_vit_large = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_large') vjepa2_vit_huge = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_huge') vjepa2_vit_giant = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant') vjepa2_vit_giant_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384') +# V-JEPA 2.1 +vjepa2_1_vit_base_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_base_384') +vjepa2_1_vit_large_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_large_384') +vjepa2_1_vit_giant_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_giant_384') +vjepa2_1_vit_gigantic_384 = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_gigantic_384') ``` #### Pretrained checkpoints on Huggingface -You can also use our pretrained checkpoints on [Huggingface](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6). +You can also use our pretrained checkpoints on [Huggingface for V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6). 
```python from transformers import AutoVideoProcessor, AutoModel @@ -189,7 +276,6 @@ hf_repo = "facebook/vjepa2-vitg-fpc64-256" # facebook/vjepa2-vitg-fpc64-256 # facebook/vjepa2-vitg-fpc64-384 - model = AutoModel.from_pretrained(hf_repo) processor = AutoVideoProcessor.from_pretrained(hf_repo) ``` @@ -283,6 +369,7 @@ See [energy_landscape_example.ipynb](notebooks/energy_landscape_example.ipynb) f To run this notebook, you'll need to additionally install [Jupyter](https://jupyter.org/install) and [Scipy](https://scipy.org/install/) in your conda environment. + ## Getting Started ### Setup @@ -400,13 +487,16 @@ python -m app.main_distributed \ ``` . ├── app # training loops -│ ├── vjepa # video JEPA pre-training +│ ├── vjepa # V-JEPA 2 pre-training +│ ├── vjepa_2_1 # V-JEPA 2.1 pre-training │ ├── vjepa_droid # training the action-conditioned model │ ├── main_distributed.py # entrypoint for launch app on slurm cluster │ └── main.py # entrypoint for launch app locally on your machine ├── configs # config files with experiment params for training and evaluation -│ ├── train # pretraining (phase 1), cooldown (phase 2), and action-conditioned training +│ ├── train # pretraining with V-JEPA 2 (phase 1), cooldown (phase 2), and action-conditioned training +│ ├── train_2_1 # pretraining with V-JEPA 2.1 (phase 1), cooldown (phase 2) │ └── eval # frozen evaluations +│ └── inference # inference only frozen evaluations ├── evals # evaluation loops training an attentive probe with frozen backbone... │ ├── action_anticipation_frozen # action anticipation │ ├── image_classification_frozen # image understanding @@ -434,7 +524,8 @@ are licensed under the Apache 2.0 license. 
## Citation -If you find this repository useful in your research, please consider giving a star :star: and a citation +If you find this repository useful in your research, please consider giving a star :star: and cite the papers: + ```bibtex @article{assran2025vjepa2, title={V-JEPA~2: Self-Supervised Video Models Enable Understanding, Prediction and Planning}, @@ -448,3 +539,13 @@ Rabbat, Michael and Ballas, Nicolas}, year={2025} } ``` + +```bibtex +@article{murlabadia2026vjepa2_1, + title={V-JEPA 2.1: Unlocking Dense Features in Video Self-Supervised Learning}, + author={Mur-Labadia, Lorenzo and Muckley, Matthew and Bar, Amir and Assran, Mahmoud and +Sinha, Koustuv and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas and Bardes, Adrien}, + journal={arXiv preprint arXiv:2603.14482}, + year={2026} +} +``` diff --git a/app/vjepa_2_1/models/predictor.py b/app/vjepa_2_1/models/predictor.py new file mode 100644 index 00000000..a8e7f968 --- /dev/null +++ b/app/vjepa_2_1/models/predictor.py @@ -0,0 +1,302 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 

import math
from functools import partial

import torch
import torch.nn as nn

from src.masks.utils import apply_masks
from src.utils.tensors import repeat_interleave_batch, trunc_normal_

from app.vjepa_2_1.models.utils.modules import Block

# Hierarchical layer indices for the depths the original code listed
# explicitly.  Inside this module only the *number* of selected layers is
# used (to size the input embedding and the output projections).
_HIERARCHICAL_LAYERS = {
    4: [0, 1, 2, 3],
    8: [1, 3, 5, 7],
    12: [2, 5, 8, 11],
    20: [4, 9, 14, 19],
    24: [4, 11, 17, 23],
    40: [9, 19, 29, 39],
}
_NUM_HIERARCHICAL_LEVELS = 4


def _hierarchical_layers_for_depth(depth):
    """Return the hierarchical layer indices for a transformer of ``depth``.

    Depths present in ``_HIERARCHICAL_LAYERS`` keep their exact table entry.
    Any other depth previously left the list undefined (crashing the
    constructor, including for the default ``depth=6``); it now falls back to
    up to four evenly spaced indices that always end on the last layer.

    :raises ValueError: if ``depth`` is smaller than 1.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if depth in _HIERARCHICAL_LAYERS:
        return list(_HIERARCHICAL_LAYERS[depth])
    if depth <= _NUM_HIERARCHICAL_LEVELS:
        return list(range(depth))
    return [
        (level + 1) * depth // _NUM_HIERARCHICAL_LEVELS - 1
        for level in range(_NUM_HIERARCHICAL_LEVELS)
    ]


def _gather_tokens(x, index):
    """Reorder ``x`` of shape (B, N, D) along dim 1 using ``index`` of shape (B, N).

    Equivalent to ``torch.stack([x[i, row, :] for i, row in enumerate(index)])``
    without the per-sample Python loop.
    """
    return torch.gather(x, 1, index.unsqueeze(-1).expand(-1, -1, x.size(-1)))


class VisionTransformerPredictor(nn.Module):
    """Vision Transformer Predictor.

    Embeds context tokens into the predictor width, appends learnable mask
    tokens for the target positions, runs the sequence (sorted by position
    index) through a stack of ``Block`` layers and projects the result back
    to ``len(hierarchical_layers) * out_embed_dim`` features.
    """

    def __init__(
        self,
        img_size=(224, 224),
        patch_size=16,
        num_frames=1,
        tubelet_size=2,
        embed_dim=768,
        predictor_embed_dim=384,
        out_embed_dim=None,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        uniform_power=False,
        use_mask_tokens=False,
        num_mask_tokens=2,
        zero_init_mask_tokens=True,
        use_silu=False,
        wide_silu=True,
        is_causal=False,
        use_activation_checkpointing=False,
        return_all_tokens=False,
        chop_last_n_tokens=0,
        use_rope=False,
        n_registers=0,
        has_cls_first=False,
        interpolate_rope=False,
        modality_embedding=True,
        img_temporal_dim_size=None,
        teacher_embed_dim=None,
        **kwargs
    ):
        super().__init__()
        self.return_all_tokens = return_all_tokens
        self.chop_last_n_tokens = chop_last_n_tokens
        self.has_cls_first = has_cls_first

        all_hierarchical_layers = _hierarchical_layers_for_depth(depth)
        n_output_distillation = kwargs.get(
            "n_output_distillation", len(all_hierarchical_layers)
        )
        # Keep only the last `n_output_distillation` levels.
        self.hierarchical_layers = all_hierarchical_layers[-n_output_distillation:]
        n_levels = len(self.hierarchical_layers)

        act_layer_mlp = nn.SiLU if use_silu else nn.GELU
        if n_levels == 1:
            self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        else:
            # Input is the concatenation of `n_levels` feature sets, fused by a
            # small MLP before entering the predictor.
            self.predictor_embed = nn.Sequential(
                nn.Linear(embed_dim * n_levels, embed_dim, bias=True),
                act_layer_mlp(),
                nn.Linear(embed_dim, predictor_embed_dim, bias=True),
            )

        self.mask_tokens = None
        self.num_mask_tokens = 0
        if use_mask_tokens:
            self.num_mask_tokens = num_mask_tokens
            self.mask_tokens = nn.ParameterList(
                [
                    nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
                    for _ in range(num_mask_tokens)
                ]
            )

        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.img_height, self.img_width = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.is_video = num_frames > 1

        self.grid_height = img_size[0] // self.patch_size
        self.grid_width = img_size[1] // self.patch_size
        self.grid_depth = num_frames // self.tubelet_size
        self.use_activation_checkpointing = use_activation_checkpointing

        # Stochastic-depth rate grows linearly from 0 to `drop_path_rate`.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        patches_per_frame = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        if self.is_video:
            self.num_patches = (num_frames // tubelet_size) * patches_per_frame
        else:
            self.num_patches = patches_per_frame

        self.modality_embedding = False
        if img_temporal_dim_size is not None and modality_embedding:
            self.video_mod_embed = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
            nn.init.normal_(self.video_mod_embed, std=1e-6)
            self.img_mod_embed = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
            nn.init.normal_(self.img_mod_embed, std=1e-6)
            self.modality_embedding = True

        self.uniform_power = uniform_power

        # NOTE(review): no `predictor_pos_embed` is created in this class, so
        # only `use_rope=True` can run `forward` unless the attribute is
        # attached externally.
        self.use_rope = use_rope
        self.predictor_blocks = nn.ModuleList(
            [
                Block(
                    use_rope=use_rope,
                    grid_size=self.grid_height,
                    grid_depth=self.grid_depth,
                    dim=predictor_embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=nn.SiLU if use_silu else nn.GELU,
                    is_causal=is_causal,
                    wide_silu=wide_silu,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    n_registers=n_registers,
                    has_cls_first=has_cls_first,
                    interpolate_rope=interpolate_rope,
                    patch_size=patch_size,
                )
                for i in range(depth)
            ]
        )

        if out_embed_dim is None:
            if teacher_embed_dim is not None:
                out_embed_dim = teacher_embed_dim // n_levels
            else:
                out_embed_dim = embed_dim
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(
            predictor_embed_dim,
            n_levels * out_embed_dim,
            bias=True,
        )
        if self.return_all_tokens:
            self.predictor_proj_context = nn.Linear(
                predictor_embed_dim,
                out_embed_dim * n_levels,
                bias=True,
            )

        self.init_std = init_std
        # Guard: without mask tokens there is nothing to initialise (the
        # original iterated over None here).
        if not zero_init_mask_tokens and self.mask_tokens is not None:
            for mt in self.mask_tokens:
                trunc_normal_(mt, std=init_std)

        self.apply(self._init_weights)
        self._rescale_blocks()

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _rescale_blocks(self):
        """Divide each block's `attn.proj` and `mlp.fc2` weights by sqrt(2 * layer)."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        # NOTE(review): for a SwiGLUFFN MLP, `fc2` is the multiplicative branch
        # and `fc3` the output projection; confirm `fc2` is the intended target.
        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def forward(self, x, masks_x, masks_y, mod="video", mask_index=1):
        """
        :param x: context tokens
        :param masks_x: indices of context tokens in input
        :param masks_y: indices of target tokens in input
        :param mod: "image" adds the image modality embedding; any other value
            adds the video one (only when modality embeddings are enabled)
        :param mask_index: which mask token to use, taken modulo the number of
            mask tokens
        :returns: ``(pred, None)``, or ``(pred, context)`` when
            ``return_all_tokens`` is set
        :raises RuntimeError: if the predictor has no mask tokens, or if
            ``use_rope`` is False and no ``predictor_pos_embed`` is available
        """
        assert (masks_x is not None) and (
            masks_y is not None
        ), "Cannot run predictor without mask indices"
        if not self.num_mask_tokens:
            # Previously surfaced as ZeroDivisionError on `mask_index % 0`.
            raise RuntimeError(
                "VisionTransformerPredictor.forward requires mask tokens; "
                "construct it with use_mask_tokens=True and num_mask_tokens > 0"
            )
        pos_embed = None
        if not self.use_rope:
            pos_embed = getattr(self, "predictor_pos_embed", None)
            if pos_embed is None:
                # Previously surfaced as a bare AttributeError.
                raise RuntimeError(
                    "use_rope=False requires `predictor_pos_embed`, which this "
                    "module does not create; construct it with use_rope=True"
                )

        if not isinstance(masks_x, list):
            masks_x = [masks_x]
        if not isinstance(masks_y, list):
            masks_y = [masks_y]

        B = len(x) // len(masks_x)

        x = self.predictor_embed(x)
        _, N_ctxt, _ = x.shape

        if pos_embed is not None:
            x_pos_embed = pos_embed.repeat(B, 1, 1)
            x += apply_masks(x_pos_embed, masks_x)

        # One mask token per target position.
        mask_index = mask_index % self.num_mask_tokens
        pred_tokens = self.mask_tokens[mask_index]
        pred_tokens = pred_tokens.repeat(B, self.num_patches, 1)
        pred_tokens = apply_masks(pred_tokens, masks_y)

        if pos_embed is not None:
            pos_embs = pos_embed.repeat(B, 1, 1)
            pos_embs = apply_masks(pos_embs, masks_y)
            pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
            pred_tokens += pos_embs

        x = x.repeat(len(masks_x), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        masks_x = torch.cat(masks_x, dim=0)
        masks_y = torch.cat(masks_y, dim=0)
        masks = torch.cat([masks_x, masks_y], dim=1)

        # Sort context+target tokens by position index so the blocks see them
        # in grid order; `argsort` is kept to undo the permutation afterwards.
        argsort = torch.argsort(masks, dim=1)
        masks = torch.gather(masks, 1, argsort)
        x = _gather_tokens(x, argsort)

        if self.chop_last_n_tokens > 0:
            x = x[:, : -self.chop_last_n_tokens]
            masks = masks[:, : -self.chop_last_n_tokens]

        if self.modality_embedding:
            # The (1, 1, D) embedding broadcasts over batch and tokens; the
            # original `.repeat(B, 1, 1)` failed whenever x's batch differed
            # from B.
            if mod == "image":
                x = x + self.img_mod_embed
            else:
                x = x + self.video_mod_embed

        for blk in self.predictor_blocks:
            if self.use_activation_checkpointing:
                x, _ = torch.utils.checkpoint.checkpoint(
                    blk, x, masks, use_reentrant=False
                )
            else:
                x, _ = blk(x, mask=masks)
        x = self.predictor_norm(x)

        # Undo the sort so context tokens come first again, then targets.
        reverse_argsort = torch.argsort(argsort, dim=1)
        x = _gather_tokens(x, reverse_argsort)

        if not self.return_all_tokens:
            x = x[:, N_ctxt:, :]
            x = self.predictor_proj(x)
            return x, None

        x_pred = x[:, N_ctxt:, :]
        x_context = x[:, :N_ctxt, :]
        x_pred = self.predictor_proj(x_pred)
        x_context = self.predictor_proj_context(x_context)
        return x_pred, x_context


def vit_predictor(**kwargs):
    """Build a ``VisionTransformerPredictor`` with mlp_ratio=4, qkv bias and LayerNorm(eps=1e-6)."""
    model = VisionTransformerPredictor(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    return model


# diff --git a/app/vjepa_2_1/models/utils/masks_dist.py b/app/vjepa_2_1/models/utils/masks_dist.py
# new file mode 100644 (index 00000000..216f5738)
import torch
import torch.nn as nn
import torchvision


def _tokens_per_frame_and_row(H_patches, W_patches, grid_size):
    """Return (tokens per frame, tokens per row) for an explicit or square grid."""
    if H_patches is None or W_patches is None:
        return int(grid_size * grid_size), grid_size
    return int(H_patches * W_patches), W_patches


def _get_frame_pos(ids, H_patches=None, W_patches=None, grid_size=None):
    """Return the frame index of each flat token id."""
    tokens_per_frame, _ = _tokens_per_frame_and_row(H_patches, W_patches, grid_size)
    return ids // tokens_per_frame


def _get_height_pos(ids, H_patches=None, W_patches=None, grid_size=None):
    """Return the row index (within its frame) of each flat token id."""
    tokens_per_frame, tokens_per_row = _tokens_per_frame_and_row(
        H_patches, W_patches, grid_size
    )
    frame_ids = _get_frame_pos(ids, H_patches, W_patches, grid_size)
    # Remove frame component from ids
    ids = ids - tokens_per_frame * frame_ids
    return ids // tokens_per_row


def separate_positions(ids, H_patches=None, W_patches=None, grid_size=None):
    """Split flat token ids into float (frame, height, width) coordinates.

    Uses ``H_patches`` x ``W_patches`` when both are given, otherwise a square
    ``grid_size`` x ``grid_size`` frame.
    """
    tokens_per_frame, tokens_per_row = _tokens_per_frame_and_row(
        H_patches, W_patches, grid_size
    )
    frame_ids = _get_frame_pos(ids, H_patches, W_patches, grid_size)
    height_ids = _get_height_pos(ids, H_patches, W_patches, grid_size)
    # Remove frame component from ids (1st term) and height component (2nd term)
    width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
    return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids


def compute_mask_distance(masks_pred, masks_enc, grid_size, offset_context_loss):
    """For every context token, compute the distance to its nearest target token.

    :param masks_pred: [fpc][mask] where each mask is [B, N_pred]
    :param masks_enc: [fpc][mask] where each mask is [B, N_enc]
    :param grid_size: side of the square token grid used to decode token ids
    :param offset_context_loss: if true, scale the distance by
        ``1 / (grid_size // 16)`` and take its square root
    :returns: nested list mirroring the inputs; each entry is a ``(B, N_enc)``
        tensor of Euclidean distances in (frame, height, width) token units.
        The shape is always ``(B, N_enc)``; the previous bare ``squeeze()``
        also dropped the batch axis when ``B == 1``.
    """
    distances = []
    for masks_pred_i, masks_enc_i in zip(masks_pred, masks_enc):
        row_distances = []
        for masks_pred_ij, masks_enc_ij in zip(masks_pred_i, masks_enc_i):
            enc = torch.stack(
                separate_positions(masks_enc_ij, grid_size=grid_size), dim=-1
            )  # (BS, N_enc, 3)
            pred = torch.stack(
                separate_positions(masks_pred_ij, grid_size=grid_size), dim=-1
            )  # (BS, N_pred, 3)
            # One batched call replaces the per-context-token loop; note it
            # materialises the full (BS, N_enc, N_pred) distance matrix.
            dist = torch.cdist(enc, pred, p=2)
            dmin = dist.min(dim=-1).values  # (BS, N_enc)
            if offset_context_loss:
                # 16 is the default value of grid_size; clamp to 1 so grids
                # smaller than 16 no longer divide by zero.
                coeff = max(grid_size // 16, 1)
                dmin = dmin * (1.0 / coeff)
                dmin = dmin**0.5  # square root so the value varies less aggressively
            row_distances.append(dmin)
        distances.append(row_distances)
    return distances


# diff --git a/app/vjepa_2_1/models/utils/modules.py b/app/vjepa_2_1/models/utils/modules.py
# new file mode 100644 (index 00000000..fedf8a2e)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import drop_path


def rotate_queries_or_keys(x, pos, n_registers, has_cls_first):
    """Apply a rotary position embedding to the non-special tokens of ``x``.

    :param x: tensor unpacked as (B, num_heads, N, D); D must be even
    :param pos: positions of the rotated tokens (last axis indexes tokens)
    :param n_registers: number of trailing tokens left un-rotated
    :param has_cls_first: if true, the first token is left un-rotated
    :returns: tensor of the same shape as ``x``
    """
    B, num_heads, N, D = x.size()
    assert (
        D % 2 == 0
    ), "Embedding dimension must be a multiple of 2 for block matrix rotation"

    n_cls = 1 if has_cls_first else 0
    start_ctx = n_cls
    end_ctx = N - n_registers

    x_cls = x[..., :n_cls, :] if n_cls else None
    x_ctx = x[..., start_ctx:end_ctx, :]
    x_reg = x[..., end_ctx:, :] if n_registers > 0 else None

    # Rotation frequencies 1 / 10000^(2i/D), one per pair of channels.
    omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
    omega /= D / 2.0
    omega = 1.0 / 10000**omega
    freq = torch.einsum("..., f -> ... f", pos, omega)

    emb_sin = freq.sin()
    emb_cos = freq.cos()

    # Each frequency drives two adjacent channels.
    emb_sin = emb_sin.repeat_interleave(2, dim=-1)
    emb_cos = emb_cos.repeat_interleave(2, dim=-1)

    # Build the 90-degree rotated copy: (a, b) -> (-b, a) per channel pair.
    y = x_ctx.unflatten(-1, (-1, 2))
    y1, y2 = y.unbind(dim=-1)
    y = torch.stack((-y2, y1), dim=-1)
    y = y.flatten(-2)

    out_ctx = (x_ctx * emb_cos) + (y * emb_sin)

    parts = []
    if n_cls:
        parts.append(x_cls)
    parts.append(out_ctx)
    if n_registers:
        parts.append(x_reg)
    out = torch.cat(parts, dim=-2)

    return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class MLP(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SwiGLUFFN(nn.Module):
    """Gated feed-forward network: ``fc3(silu(fc1(x)) * fc2(x))``.

    With ``wide_silu`` the hidden width is 2/3 of ``hidden_features``, rounded
    up to a multiple of 8.  ``drop`` is accepted but unused, and ``forward``
    always applies SiLU regardless of ``act_layer``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.SiLU,
        drop=0.0,
        wide_silu=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        swiglu_hidden_features = hidden_features = hidden_features or in_features
        if wide_silu:
            swiglu_hidden_features = int(2 * hidden_features / 3)
            align_as = 8
            swiglu_hidden_features = (
                (swiglu_hidden_features + align_as - 1) // align_as * align_as
            )
        self.fc1 = nn.Linear(in_features, swiglu_hidden_features)
        self.fc2 = nn.Linear(in_features, swiglu_hidden_features)
        self.act = act_layer()
        self.fc3 = nn.Linear(swiglu_hidden_features, out_features)

    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        hidden = F.silu(x1) * x2
        return self.fc3(hidden)


class RoPEAttention(nn.Module):
    """Multi-head self-attention with 3-axis (frame, height, width) rotary embeddings.

    Each head's channels are split into three equal even-sized groups rotated
    by the frame, height and width position respectively; leftover channels
    are passed through unrotated.
    """

    # Resolution used to derive the pretrained grid size for each patch size.
    _PRETRAINED_RESOLUTION = {14: 252, 16: 256}

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        use_sdpa=True,
        grid_size=14,
        is_causal=False,
        n_registers=0,
        has_cls_first=False,
        interpolate_rope=False,
        patch_size=16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop_prob = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop_prob = proj_drop
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_sdpa = use_sdpa
        self.d_dim = int(2 * ((head_dim // 3) // 2))
        self.h_dim = int(2 * ((head_dim // 3) // 2))
        self.w_dim = int(2 * ((head_dim // 3) // 2))
        self.grid_size = grid_size
        self.is_causal = is_causal
        self.n_registers = n_registers
        self.has_cls_first = has_cls_first
        self.interpolate_rope = interpolate_rope
        self.pretrained_patch_size = patch_size
        resolution = self._PRETRAINED_RESOLUTION.get(patch_size)
        if resolution is not None:
            self.pretrained_grid_size = int(resolution / patch_size)
        elif interpolate_rope:
            # Previously this left `pretrained_grid_size` undefined and failed
            # with AttributeError on the first forward pass.
            raise ValueError(
                "interpolate_rope is only supported for patch_size 14 or 16, "
                f"got {patch_size}"
            )
        else:
            self.pretrained_grid_size = None

    def _grid_strides(self, H_patches, W_patches):
        """Return (tokens per frame, tokens per row) for an explicit or default grid."""
        if H_patches is None or W_patches is None:
            return int(self.grid_size * self.grid_size), self.grid_size
        return int(H_patches * W_patches), W_patches

    def _get_frame_pos(self, ids, H_patches=None, W_patches=None):
        """Return the frame index of each flat token id."""
        tokens_per_frame, _ = self._grid_strides(H_patches, W_patches)
        return ids // tokens_per_frame

    def _get_height_pos(self, ids, H_patches=None, W_patches=None):
        """Return the row index (within its frame) of each flat token id."""
        tokens_per_frame, tokens_per_row = self._grid_strides(H_patches, W_patches)
        frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
        ids = ids - tokens_per_frame * frame_ids
        return ids // tokens_per_row

    def separate_positions(self, ids, H_patches=None, W_patches=None):
        """Split flat token ids into float (frame, height, width) coordinates."""
        tokens_per_frame, tokens_per_row = self._grid_strides(H_patches, W_patches)
        frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
        height_ids = self._get_height_pos(ids, H_patches, W_patches)
        width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
        return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids

    def forward(
        self,
        x,
        mask=None,
        T=None,
        H_patches=None,
        W_patches=None,
        return_attn=False,
    ):
        """Run attention over ``x`` of shape (B, N, C).

        :param mask: optional (B, N_tokens) flat position ids of the tokens;
            when omitted, positions ``0..T*H*W-1`` (or the default grid) are used
        :param T, H_patches, W_patches: optional explicit grid dimensions
        :param return_attn: return the attention matrix (only available on the
            non-SDPA path; the SDPA path yields ``None``)
        :returns: ``(output, attn_or_None)``
        """
        B, N, C = x.size()
        N_ctx = N - self.n_registers
        grid_depth = int(N_ctx // (self.grid_size * self.grid_size))

        qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
        elif T is None or H_patches is None or W_patches is None:
            mask = torch.arange(
                int(grid_depth * self.grid_size * self.grid_size), device=x.device
            )
        else:
            mask = torch.arange(int(T * H_patches * W_patches), device=x.device)
        d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)

        if self.interpolate_rope:
            if H_patches is None:
                H_patches = int(self.grid_size)
            if W_patches is None:
                W_patches = int(self.grid_size)
            # Rescale spatial positions to the pretrained grid's range.  The
            # max(..., 1) avoids dividing by zero for a single row/column,
            # where every position is 0 anyway.
            h_mask = h_mask * (self.pretrained_grid_size - 1) / max(H_patches - 1, 1)
            w_mask = w_mask * (self.pretrained_grid_size - 1) / max(W_patches - 1, 1)

        # Rotate consecutive channel groups by frame, height and width position.
        q_parts, k_parts = [], []
        start = 0
        for width, pos in (
            (self.d_dim, d_mask),
            (self.h_dim, h_mask),
            (self.w_dim, w_mask),
        ):
            stop = start + width
            q_parts.append(
                rotate_queries_or_keys(
                    q[..., start:stop],
                    pos=pos,
                    n_registers=self.n_registers,
                    has_cls_first=self.has_cls_first,
                )
            )
            k_parts.append(
                rotate_queries_or_keys(
                    k[..., start:stop],
                    pos=pos,
                    n_registers=self.n_registers,
                    has_cls_first=self.has_cls_first,
                )
            )
            start = stop

        # Channels that do not fit into the three groups stay unrotated.
        if start < self.head_dim:
            q_parts.append(q[..., start:])
            k_parts.append(k[..., start:])
        q = torch.cat(q_parts, dim=-1)
        k = torch.cat(k_parts, dim=-1)

        if self.use_sdpa:
            # SDPA applies `dropout_p` unconditionally, so it must be zeroed
            # outside training; the rate is the attention dropout (the original
            # passed the projection dropout rate here).
            dropout_p = self.attn_drop_prob if self.training else 0.0
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=self.is_causal
            )
            attn = None
        else:
            # NOTE(review): this path does not apply `is_causal`.
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if return_attn:
            return x, attn
        return x, None


class Attention(nn.Module):
    """Standard multi-head self-attention (no positional rotation)."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        use_sdpa=True,
        is_causal=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop_prob = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop_prob = proj_drop
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_sdpa = use_sdpa
        self.is_causal = is_causal

    def forward(self, x):
        """Run attention over ``x`` of shape (B, N, C) and return (B, N, C)."""
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.use_sdpa:
            # SDPA applies `dropout_p` unconditionally: use the attention
            # dropout rate and disable it outside training.
            dropout_p = self.attn_drop_prob if self.training else 0.0
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=self.is_causal
            )
        else:
            # NOTE(review): this path does not apply `is_causal`.
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual and drop-path.

    Uses ``RoPEAttention`` when ``use_rope`` is set, otherwise ``Attention``;
    uses ``SwiGLUFFN`` when ``act_layer`` is ``nn.SiLU``, otherwise ``MLP``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        wide_silu=True,
        norm_layer=nn.LayerNorm,
        use_sdpa=True,
        is_causal=False,
        grid_size=16,
        use_rope=False,
        n_registers=0,
        has_cls_first=False,
        interpolate_rope=False,
        patch_size=16,
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.use_rope = use_rope
        if use_rope:
            self.attn = RoPEAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                use_sdpa=use_sdpa,
                is_causal=is_causal,
                grid_size=grid_size,
                proj_drop=drop,
                n_registers=n_registers,
                has_cls_first=has_cls_first,
                interpolate_rope=interpolate_rope,
                patch_size=patch_size,
            )
        else:
            self.attn = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                use_sdpa=use_sdpa,
                is_causal=is_causal,
                proj_drop=drop,
            )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if act_layer is nn.SiLU:
            self.mlp = SwiGLUFFN(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                wide_silu=wide_silu,
                drop=drop,
            )
        else:
            self.mlp = MLP(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )

    def forward(
        self,
        x,
        mask=None,
        T=None,
        H_patches=None,
        W_patches=None,
        return_attn=False,
        mode="video",
    ):
        """Apply the block to ``x``; returns ``(x, attn_or_None)``.

        The positional arguments are forwarded only to ``RoPEAttention``;
        ``mode`` is accepted but unused.
        """
        if self.use_rope:
            y, attn = self.attn(
                self.norm1(x),
                mask=mask,
                T=T,
                H_patches=H_patches,
                W_patches=W_patches,
                return_attn=return_attn,
            )
        else:
            y = self.attn(self.norm1(x))
            attn = None
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        if return_attn:
            return x, attn
        return x, None


class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``q``, keys/values from ``x``."""

    def __init__(self, dim, num_heads=12, qkv_bias=False, use_sdpa=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias)
        self.use_sdpa = use_sdpa

    def forward(self, q, x):
        """Attend from ``q`` (B, n, C) to ``x`` (B, N, C); returns (B, n, C)."""
        B, n, C = q.shape
        q = (
            self.q(q)
            .reshape(B, n, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        B, N, C = x.shape
        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        if self.use_sdpa:
            q = F.scaled_dot_product_attention(q, k, v)
        else:
            xattn = (q @ k.transpose(-2, -1)) * self.scale
            xattn = xattn.softmax(dim=-1)
            q = xattn @ v

        q = q.transpose(1, 2).reshape(B, n, C)
        return q


class CrossAttentionBlock(nn.Module):
    """Cross-attention followed by an MLP, each added residually to the queries."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer
        )

    def forward(self, q, x):
        # Only the key/value input is normalised before attention.
        y = self.xattn(q, self.norm1(x))
        q = q + y
        q = q + self.mlp(self.norm2(q))
        return q


class Lambda_LinearWarmupHold:
    """
    Linear warmup from 0 to lambda_value between [start_iter, end_iter],
    0 before start_iter and constant (=lambda_value) from end_iter onwards.
    """

    def __init__(
        self, lambda_value: float, start_iter: int = 15_000, end_iter: int = 30_000
    ):
        # Raise instead of assert so the check survives `python -O`.
        if end_iter <= start_iter:
            raise ValueError("end_iter must be > start_iter")
        self.lambda_value = float(lambda_value)
        self.start = int(start_iter)
        self.end = int(end_iter)
        self.span = self.end - self.start

    def value(self, global_iter: int) -> float:
        """Return the scheduled weight at iteration ``global_iter``."""
        if global_iter < self.start:
            return 0.0
        if global_iter >= self.end:
            return self.lambda_value
        alpha = (global_iter - self.start) / self.span
        return self.lambda_value * alpha


# diff --git a/app/vjepa_2_1/models/utils/patch_embed.py b/app/vjepa_2_1/models/utils/patch_embed.py
# new file mode 100644 (index 00000000..9dc34352)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
class AudioPatchEmbed(nn.Module):
    """
    Audio to Patch Embedding

    A single-channel Conv2d whose kernel and stride are both
    ``(freq_bands, tubelet_size)`` turns each window of the input into one
    ``embed_dim``-wide token.
    """

    def __init__(self, freq_bands=128, tubelet_size=2, embed_dim=768):
        super().__init__()
        self.freq_bands = freq_bands
        self.tubelet_size = tubelet_size
        window = (freq_bands, tubelet_size)
        self.proj = nn.Conv2d(1, embed_dim, kernel_size=window, stride=window)

    def forward(self, x):
        # Input axes are (b, t, c, f); Conv2d wants channels first, then (f, t).
        channels_first = rearrange(x, "b t c f -> b c f t")
        tokens = self.proj(channels_first)
        # (b, embed_dim, f', t') -> (b, f' * t', embed_dim)
        return tokens.flatten(2).transpose(1, 2)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding

    Non-overlapping ``patch_size`` x ``patch_size`` patches are projected to
    ``embed_dim`` channels by a strided Conv2d.
    """

    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # The unpacking doubles as a rank check: anything but a 4-D
        # (B, C, H, W) input fails here rather than inside the convolution.
        _batch, _chans, _height, _width = x.shape
        tokens = self.proj(x)
        # (B, embed_dim, H', W') -> (B, H' * W', embed_dim)
        return tokens.flatten(2).transpose(1, 2)


class PatchEmbed3D(nn.Module):
    """
    Video to Patch Embedding

    Non-overlapping tubelets of ``tubelet_size`` frames by ``patch_size`` x
    ``patch_size`` pixels are projected to ``embed_dim`` channels by a strided
    Conv3d.
    """

    def __init__(
        self,
        patch_size=16,
        tubelet_size=2,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size

        tubelet = (tubelet_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=tubelet,
            stride=tubelet,
        )

    def forward(self, x, **kwargs):
        # Extra keyword arguments are accepted and ignored.  The unpacking
        # enforces a 5-D (B, C, T, H, W) input.
        _batch, _chans, _frames, _height, _width = x.shape
        tokens = self.proj(x)
        # (B, embed_dim, T', H', W') -> (B, T' * H' * W', embed_dim)
        return tokens.flatten(2).transpose(1, 2)
def get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False):
    """
    Build a fixed sine/cosine positional embedding for a [depth, height, width] grid.

    grid_size: int of the grid height and width
    grid_depth: int of the grid depth
    uniform_power: if False, depth gets embed_dim // 2 channels and height/width
        embed_dim // 4 each; if True, all three axes get the same (even) share
        and the concatenation is truncated to embed_dim.
    returns:
        pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token)
                or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token)
    raises:
        ValueError: if any per-axis channel share is odd.
    """
    grid_d = np.arange(grid_depth, dtype=float)
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    # meshgrid defaults to indexing="xy", which swaps the first two axes:
    # passing (h, d, w) therefore yields arrays indexed as [d, h, w].
    grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w)

    if not uniform_power:
        h_embed_dim = embed_dim // 4
        w_embed_dim = embed_dim // 4
        d_embed_dim = embed_dim // 2
    else:
        h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2)

    emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h)  # (T*H*W, D1)
    emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w)  # (T*H*W, D2)
    emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d)  # (T*H*W, D3)
    pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1)
    # Only has an effect when the per-axis shares add up to more than embed_dim.
    pos_embed = pos_embed[:, :embed_dim]
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed sine/cosine positional embedding for a [height, width] grid.

    grid_size: int of the grid height and width
    returns:
        pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token)
                or [1+grid_size*grid_size, embed_dim] (w/ cls_token)
    raises:
        ValueError: if embed_dim // 2 is odd.
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    # With the default indexing="xy", passing (w, h) yields arrays indexed as [h, w].
    grid_w, grid_h = np.meshgrid(grid_w, grid_h)

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h)  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w)  # (H*W, D/2)
    pos_embed = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed sine/cosine positional embedding for positions 0..grid_size-1.

    embed_dim: output dimension for each position
    grid_size: int of the grid length
    returns:
        pos_embed: [grid_size, embed_dim] (w/o cls_token)
                or [1+grid_size, embed_dim] (w/ cls_token)
    raises:
        ValueError: if embed_dim is odd.
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode positions with embed_dim // 2 sine and embed_dim // 2 cosine channels.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; flattened to size (M,)
    returns: (M, D), sine channels first, cosine channels second
    raises:
        ValueError: if embed_dim is odd.
    """
    # Raised explicitly rather than asserted: with `python -O` an assert is
    # skipped and an odd embed_dim would silently give embed_dim - 1 columns.
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
class VisionTransformer(nn.Module):
    """Vision Transformer encoder for images and videos with multi-level outputs.

    The input is patchified (``PatchEmbed`` for images, ``PatchEmbed3D`` for
    videos), optionally tagged with a learned modality embedding, and passed
    through ``depth`` ``Block`` layers.  The activations of a few
    "hierarchical" layers are layer-normalised by ``norms_block`` and can be
    returned individually (``out_layers``) or concatenated along the channel
    axis (``training=True`` / ``return_hierarchical``).

    Args:
        img_size: Input height/width as an int or an ``(H, W)`` pair.
        patch_size: Spatial patch size.
        num_frames: Number of input frames; values above 1 select the video
            patch embedding.
        tubelet_size: Temporal patch size of the video patch embedding.
        in_chans: Number of input channels.
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        use_silu, wide_silu, use_sdpa, is_causal, use_rope, n_registers,
        has_cls_first, interpolate_rope: Forwarded to every ``Block``.
        drop_path_rate: Final stochastic-depth rate; ramps linearly from 0
            over the blocks.
        norm_layer: Normalisation class for the blocks and the output norms.
        init_std: Std of the truncated-normal init (``init_type="default"``).
        out_layers: If set, ``forward`` returns a list with the normalised
            activations of these layers (each must be a hierarchical layer).
        uniform_power: Stored on the module; not otherwise used in this class.
        use_activation_checkpointing: Run each block under
            ``torch.utils.checkpoint``.
        init_type: ``"default"``, ``"xavier_uniform"`` or ``"xavier_normal"``.
        handle_nonsquare_inputs: When False, ``T``/``H_patches``/``W_patches``
            are passed to the blocks as ``None``.
        img_temporal_dim_size: If set, inputs whose ``shape[2]`` equals this
            value are treated as images and embedded by a dedicated
            ``PatchEmbed3D`` with ``tubelet_size=1``.
        modality_embedding: Add a learned image/video embedding to the tokens.
        n_output_distillation: How many of the last hierarchical layers feed
            the concatenated hierarchical output (1 and 4 reproduce the
            original presets).

    Raises:
        ValueError: For a non-int ``img_temporal_dim_size``, ``depth < 1``,
            ``n_output_distillation < 1`` or an unknown ``init_type``.
    """

    # Blocks whose outputs are exposed as hierarchical features, keyed by
    # depth.  Kept verbatim from the original presets.
    # NOTE(review): depth 48 uses 37 where the every-quarter pattern of the
    # other entries would give 35 -- presumably intentional; confirm before
    # changing, since it alters the features of existing models.
    _HIERARCHICAL_LAYERS = {
        12: (2, 5, 8, 11),
        24: (5, 11, 17, 23),
        40: (9, 19, 29, 39),
        48: (11, 23, 37, 47),
    }

    @classmethod
    def _hierarchical_layers_for_depth(cls, depth):
        """Return the sorted hierarchical block indices for a model of ``depth`` blocks.

        Known depths use the preset table.  Any other depth falls back to the
        last block of each quarter of the network (the rule the 12/24/40
        presets follow), so e.g. depth 32 yields ``[7, 15, 23, 31]`` instead
        of leaving the model without hierarchical layers.
        """
        preset = cls._HIERARCHICAL_LAYERS.get(depth)
        if preset is not None:
            return list(preset)
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        # For depth < 4 some quarters are empty (index -1) or coincide.
        quarter_ends = {(depth * k) // 4 - 1 for k in range(1, 5)}
        return sorted(quarter_ends - {-1})

    def __init__(
        self,
        img_size=(224, 224),
        patch_size=16,
        num_frames=1,
        tubelet_size=2,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        out_layers=None,
        uniform_power=False,
        use_silu=False,
        wide_silu=True,
        use_sdpa=True,
        use_activation_checkpointing=False,
        is_causal=False,
        use_rope=False,
        init_type: str = "default",
        handle_nonsquare_inputs=True,
        img_temporal_dim_size=None,
        n_registers=0,
        has_cls_first=False,
        interpolate_rope=False,
        modality_embedding=True,
        n_output_distillation=4,
        **kwargs,
    ):
        super().__init__()
        if n_output_distillation < 1:
            raise ValueError(
                f"n_output_distillation must be >= 1, got {n_output_distillation}"
            )

        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.out_layers = out_layers
        self.init_type = init_type
        self.handle_nonsquare_inputs = handle_nonsquare_inputs
        self.img_temporal_dim_size = img_temporal_dim_size

        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.img_height, self.img_width = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.is_video = num_frames > 1

        self.use_activation_checkpointing = use_activation_checkpointing

        # Stochastic-depth rate per block, increasing linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        if self.is_video:
            self.patch_embed = PatchEmbed3D(
                patch_size=patch_size,
                tubelet_size=tubelet_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.num_patches = (
                (num_frames // tubelet_size)
                * (img_size[0] // patch_size)
                * (img_size[1] // patch_size)
            )
        else:
            self.patch_embed = PatchEmbed(
                patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
            )
            self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)

        if self.img_temporal_dim_size is not None:
            if not isinstance(self.img_temporal_dim_size, int):
                raise ValueError(
                    f"img_temporal_dim_size must be an int, got {self.img_temporal_dim_size}"
                )
            # Dedicated embedding for image inputs that arrive as short clips.
            self.patch_embed_img = PatchEmbed3D(
                patch_size=patch_size,
                tubelet_size=1,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        else:
            self.patch_embed_img = None

        self.uniform_power = uniform_power

        self.use_rope = use_rope
        self.blocks = nn.ModuleList(
            [
                Block(
                    use_rope=use_rope,
                    grid_size=img_size[0] // patch_size,
                    grid_depth=num_frames // tubelet_size,
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_sdpa=use_sdpa,
                    is_causal=is_causal,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=nn.SiLU if use_silu else nn.GELU,
                    wide_silu=wide_silu,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    n_registers=n_registers,
                    has_cls_first=has_cls_first,
                    interpolate_rope=interpolate_rope,
                    patch_size=patch_size,
                )
                for i in range(depth)
            ]
        )

        self.attn_out = False
        self.init_std = init_std
        self.apply(self._init_weights)
        self._rescale_blocks()

        # Created after `apply`, exactly as before, so the custom init does
        # not touch the output norms or the modality embeddings below.
        self.hierarchical_layers = self._hierarchical_layers_for_depth(depth)
        # The last n hierarchical layers: all four for n=4, only the final
        # one for n=1 (the two settings the original code supported).
        self.out_layers_distillation = self.hierarchical_layers[-n_output_distillation:]
        self.norms_block = nn.ModuleList(
            [norm_layer(embed_dim) for _ in self.hierarchical_layers]
        )

        self.cls_token = None
        self.return_hierarchical = False

        self.modality_embedding = False
        if modality_embedding:
            self.img_mod_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.video_mod_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
            nn.init.normal_(self.img_mod_embed, std=1e-6)
            nn.init.normal_(self.video_mod_embed, std=1e-6)
            self.modality_embedding = True

    def _init_weights(self, m):
        """Initialise one sub-module according to ``self.init_type``."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return

        if self.init_type == "default":
            init_weight = partial(trunc_normal_, std=self.init_std)
        elif self.init_type == "xavier_uniform":
            init_weight = nn.init.xavier_uniform_
        elif self.init_type == "xavier_normal":
            init_weight = nn.init.xavier_normal_
        else:
            raise ValueError(f"Unknown init type {self.init_type}")

        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            init_weight(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _rescale_blocks(self):
        """Divide each block's output projections by sqrt(2 * layer_number)."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def no_weight_decay(self):
        """Return the parameter names excluded from weight decay (none)."""
        return {}

    def check_temporal_dim(self, shape) -> bool:
        """Return True when ``shape[2]`` marks the input as image-mode data."""
        if self.img_temporal_dim_size is not None:
            if shape[2] == self.img_temporal_dim_size:
                return True

        return False

    def forward(self, x, masks=None, training=False):
        """
        :param x: input image (B, C, H, W) or video (B, C, T, H, W)
        :param masks: indices of patch tokens to mask (remove); a single
            tensor or a list of tensors
        :param training: when True, return the concatenated hierarchical
            features instead of the final normalised tokens
        :returns: a list of normalised features if ``out_layers`` is set;
            otherwise the channel-wise concatenation of the distillation
            layers (``training`` / ``return_hierarchical``) or the final
            tokens normalised by the last output norm
        """
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        if x.ndim == 4:
            _, _, H, W = x.shape
            T = 1
        elif x.ndim == 5:
            _, _, T, H, W = x.shape
            # Image-mode clips are embedded with tubelet size 1, so their
            # temporal extent is already the token count along time.
            if not self.check_temporal_dim(x.shape):
                T = T // self.tubelet_size
        else:
            raise ValueError(
                f"Expected a 4-D image or 5-D video tensor, got a {x.ndim}-D input"
            )

        H_patches = H // self.patch_size
        W_patches = W // self.patch_size
        if not self.handle_nonsquare_inputs:
            T = H_patches = W_patches = None

        if not self.use_rope:
            # NOTE(review): this class never creates `pos_embed`, so the
            # non-RoPE path cannot run unless a caller attaches one.  Fail
            # with an explanation instead of a bare missing-attribute error.
            base_pos_embed = getattr(self, "pos_embed", None)
            if base_pos_embed is None:
                raise AttributeError(
                    "VisionTransformer was built with use_rope=False but has no "
                    "`pos_embed`; construct it with use_rope=True or attach a "
                    "positional embedding before calling forward()"
                )
            pos_embed = self.interpolate_pos_encoding(x, base_pos_embed)

        if self.check_temporal_dim(x.shape):
            assert self.patch_embed_img is not None
            x = self.patch_embed_img(x)
            mode = "img"
            if self.modality_embedding:
                # (1, 1, D) broadcasts over batch and tokens.
                x += self.img_mod_embed
        else:
            x = self.patch_embed(x)
            mode = "video"
            if self.modality_embedding:
                x += self.video_mod_embed

        if not self.use_rope:
            x += pos_embed

        if masks is not None:
            x = apply_masks(x, masks)
            masks = torch.cat(masks, dim=0)

        outs = []
        hier = []
        for i, blk in enumerate(self.blocks):
            if self.use_activation_checkpointing:
                x, attn = torch.utils.checkpoint.checkpoint(
                    blk,
                    x,
                    masks,
                    T=T,
                    H_patches=H_patches,
                    W_patches=W_patches,
                    use_reentrant=False,
                    return_attn=self.attn_out,
                    mode=mode,
                )
            else:
                x, attn = blk(
                    x,
                    mask=masks,
                    T=T,
                    H_patches=H_patches,
                    W_patches=W_patches,
                    return_attn=self.attn_out,
                    mode=mode,
                )

            # Each hierarchical layer has its own output norm, addressed by
            # the layer's position in `hierarchical_layers`.
            if self.out_layers is not None and i in self.out_layers:
                out_idx = self.hierarchical_layers.index(i)
                out_norm = self.norms_block[out_idx](x)
                outs.append(out_norm)

            if i in self.out_layers_distillation:
                out_idx = self.hierarchical_layers.index(i)
                hier.append(self.norms_block[out_idx](x))

        if self.out_layers is not None:
            return outs

        if training or self.return_hierarchical:
            hier = torch.cat(hier, dim=2)
            return hier
        else:
            x = self.norms_block[-1](x)
            return x

    def interpolate_pos_encoding(self, x, pos_embed):
        """Resize ``pos_embed`` (1, N, dim) to match the token grid of ``x``."""

        _, N, dim = pos_embed.shape

        if self.is_video:

            _, _, T, H, W = x.shape
            if H == self.img_height and W == self.img_width and T == self.num_frames:
                return pos_embed

            elif H == self.img_height and W == self.img_width and T < self.num_frames:
                # Same spatial size, fewer frames: keep the leading tokens.
                new_N = int(
                    (T // self.tubelet_size)
                    * (H // self.patch_size)
                    * (W // self.patch_size)
                )
                return pos_embed[:, :new_N, :]

            T = T // self.tubelet_size
            H = H // self.patch_size
            W = W // self.patch_size

            N_t = self.num_frames // self.tubelet_size
            N_h = self.img_height // self.patch_size
            N_w = self.img_width // self.patch_size
            assert N_h * N_w * N_t == N, "Positional embedding initialized incorrectly"

            scale_factor = (T / N_t, H / N_h, W / N_w)

            pos_embed = nn.functional.interpolate(
                pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3),
                scale_factor=scale_factor,
                mode="trilinear",
            )
            pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
            return pos_embed

        else:

            _, _, H, W = x.shape
            if H == self.img_height and W == self.img_width:
                return pos_embed

            npatch = (H // self.patch_size) * (W // self.patch_size)
            scale_factor = math.sqrt(npatch / N)

            # Treats the stored embedding as a square sqrt(N) x sqrt(N) grid.
            pos_embed = nn.functional.interpolate(
                pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
                    0, 3, 1, 2
                ),
                scale_factor=scale_factor,
                mode="bicubic",
            )
            pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
            return pos_embed


def _build_vit(patch_size, overrides, **arch):
    """Instantiate a ``VisionTransformer`` preset.

    ``arch`` holds the preset's fixed hyper-parameters and ``overrides`` the
    caller's extra keyword arguments.  As with the hand-written factories,
    overriding a fixed hyper-parameter raises ``TypeError``.
    """
    return VisionTransformer(
        patch_size=patch_size,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **arch,
        **overrides,
    )


def vit_synthetic(patch_size=16, **kwargs):
    """Minimal 1-block, 1-channel model."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1, depth=1, num_heads=1, mlp_ratio=4
    )


def vit_tiny(patch_size=16, **kwargs):
    """ViT-Tiny: width 192, 12 blocks, 3 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4
    )


def vit_small(patch_size=16, **kwargs):
    """ViT-Small: width 384, 12 blocks, 6 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4
    )


def vit_base(patch_size=16, **kwargs):
    """ViT-Base: width 768, 12 blocks, 12 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4
    )


def vit_large(patch_size=16, **kwargs):
    """ViT-Large: width 1024, 24 blocks, 16 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4
    )


def vit_large_rope(patch_size=16, **kwargs):
    """ViT-Large with ``use_rope=True``."""
    return _build_vit(
        patch_size,
        kwargs,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        use_rope=True,
    )


def vit_huge(patch_size=16, **kwargs):
    """ViT-Huge: width 1280, 32 blocks, 16 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4
    )


def vit_huge_rope(patch_size=16, **kwargs):
    """ViT-Huge with ``use_rope=True``."""
    return _build_vit(
        patch_size,
        kwargs,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        use_rope=True,
    )


def vit_giant(patch_size=16, **kwargs):
    """ViT-Giant: width 1408, 40 blocks, 16 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48 / 11
    )


def vit_giant_rope(patch_size=16, **kwargs):
    """ViT-Giant with ``use_rope=True``."""
    return _build_vit(
        patch_size,
        kwargs,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        use_rope=True,
    )


def vit_giant_xformers(patch_size=16, **kwargs):
    """ViT-Giant variant with 22 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1408, depth=40, num_heads=22, mlp_ratio=48 / 11
    )


def vit_giant_xformers_rope(patch_size=16, **kwargs):
    """ViT-Giant variant with 22 heads and ``use_rope=True``."""
    return _build_vit(
        patch_size,
        kwargs,
        embed_dim=1408,
        depth=40,
        num_heads=22,
        mlp_ratio=48 / 11,
        use_rope=True,
    )


def vit_gigantic(patch_size=16, **kwargs):
    """ViT-Gigantic: width 1664, 48 blocks, 16 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64 / 13
    )


def vit_gigantic_xformers(patch_size=16, **kwargs):
    """ViT-Gigantic variant with 26 heads."""
    return _build_vit(
        patch_size, kwargs, embed_dim=1664, depth=48, num_heads=26, mlp_ratio=64 / 13
    )


# Token width of each model family, keyed by factory name.
VIT_EMBED_DIMS = {
    "vit_synthetic": 1,
    "vit_tiny": 192,
    "vit_small": 384,
    "vit_base": 768,
    "vit_large": 1024,
    "vit_huge": 1280,
    "vit_giant": 1408,
    "vit_gigantic": 1664,
}
new file mode 100644 index 00000000..f8fe5949 --- /dev/null +++ b/app/vjepa_2_1/train.py @@ -0,0 +1,835 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] +except Exception: + pass + +import copy +import gc +import random +import time + +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from app.vjepa_2_1.models.utils.masks_dist import compute_mask_distance +from app.vjepa_2_1.models.utils.modules import Lambda_LinearWarmupHold +from app.vjepa_2_1.transforms import make_transforms +from app.vjepa_2_1.utils import ( + init_opt, + init_video_model, + load_checkpoint, + normalize_nested, +) +from src.datasets.data_manager import init_data +from src.masks.multiseq_multiblock3d import MaskCollator +from src.masks.utils import apply_masks +from src.utils.distributed import init_distributed +from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer +from torch.nn.parallel import DistributedDataParallel + + +log_timings = True +log_freq = 10 +CHECKPOINT_FREQ = 1 +GARBAGE_COLLECT_ITR_FREQ = 50 +MAX_REPEAT_COUNTS = 10 + +_GLOBAL_SEED = 0 +random.seed(_GLOBAL_SEED) +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__, force=True) + + +def main(args, resume_preempt=False): + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + folder = args.get("folder") + cfgs_meta = args.get("meta") + load_model = cfgs_meta.get("load_checkpoint") or resume_preempt + r_file = 
cfgs_meta.get("read_checkpoint", None) + seed = cfgs_meta.get("seed", _GLOBAL_SEED) + save_every_freq = cfgs_meta.get("save_every_freq", -1) + skip_batches = cfgs_meta.get("skip_batches", -1) + use_sdpa = cfgs_meta.get("use_sdpa", False) + sync_gc = cfgs_meta.get("sync_gc", False) + logger.info(f"LD_PRELOAD: {os.environ.get('LD_PRELOAD')}") + which_dtype = cfgs_meta.get("dtype") + logger.info(f"{which_dtype=}") + if which_dtype.lower() == "bfloat16": + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == "float16": + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get("mask") + + # -- MODEL + cfgs_model = args.get("model") + compile_model = cfgs_model.get("compile_model", False) + use_activation_checkpointing = cfgs_model.get("use_activation_checkpointing", False) + model_name = cfgs_model.get("model_name") + pred_depth = cfgs_model.get("pred_depth") + pred_num_heads = cfgs_model.get("pred_num_heads", None) + pred_embed_dim = cfgs_model.get("pred_embed_dim") + uniform_power = cfgs_model.get("uniform_power", False) + use_mask_tokens = cfgs_model.get("use_mask_tokens", False) + zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True) + use_rope = cfgs_model.get("use_rope", False) + use_silu = cfgs_model.get("use_silu", False) + use_pred_silu = cfgs_model.get("use_pred_silu", False) + wide_silu = cfgs_model.get("wide_silu", True) + is_causal = cfgs_model.get("is_causal", False) + pred_is_causal = cfgs_model.get("pred_is_causal", False) + init_type = cfgs_model.get("init_type", "default") + img_temporal_dim_size = cfgs_model.get("img_temporal_dim_size", None) + n_registers = cfgs_model.get("n_registers", 0) + has_cls_first = cfgs_model.get("has_cls_first", False) + interpolate_rope = cfgs_model.get("interpolate_rope", False) + lambda_value_img = cfgs_model.get("lambda_value_img", 0.0) + lambda_value_vid = cfgs_model.get("lambda_value_vid", 
0.0) + n_registers_predictor = cfgs_model.get("n_registers_predictor", 0) + lambda_progressive = cfgs_model.get("lambda_progressive", True) + normalize_predictor = cfgs_model.get("normalize_predictor", False) + modality_embedding = cfgs_model.get("modality_embedding", False) + levels_predictor = cfgs_model.get("levels_predictor", 4) + if model_name == "vit_large": + embed_dim_encoder = 1024 + elif model_name == "vit_giant_xformers": + embed_dim_encoder = 1408 + elif model_name == "vit_gigantic_xformers": + embed_dim_encoder = 1664 + else: + print("Model name not recognized :(") + + # -- DATA + cfgs_data = args.get("data") + dataset_type = cfgs_data.get("dataset_type", "videodataset") + dataset_paths = cfgs_data.get("datasets", []) + datasets_weights = cfgs_data.get("datasets_weights") + dataset_fpcs = cfgs_data.get("dataset_fpcs") + max_num_frames = max(dataset_fpcs) + batch_size = cfgs_data.get("batch_size") + tubelet_size = cfgs_data.get("tubelet_size") + fps = cfgs_data.get("fps") + crop_size = cfgs_data.get("crop_size", 224) + patch_size = cfgs_data.get("patch_size") + grid_size = crop_size // patch_size + pin_mem = cfgs_data.get("pin_mem", False) + num_workers = cfgs_data.get("num_workers", 1) + + # -- IMG DATA + cfgs_img_data = args.get("img_data") + img_rank_ratio = 0.25 + img_mask = None + if cfgs_img_data is not None: + img_dataset_type = cfgs_img_data.get("dataset_type", "imagenet") + img_dataset_paths = cfgs_img_data.get("datasets", []) + img_dataset_weights = cfgs_img_data.get("datasets_weights", []) + img_dataset_fpcs = cfgs_img_data.get("dataset_fpcs") + img_dataset_batch_size = cfgs_img_data.get("batch_size") + img_rank_ratio = cfgs_img_data.get("rank_ratio", img_rank_ratio) + img_num_workers = cfgs_img_data.get("num_workers", num_workers) + + img_mask = args.get("img_mask", img_mask) + + # -- DATA AUGS + cfgs_data_aug = args.get("data_aug") + ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) + rr_scale = 
cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) + motion_shift = cfgs_data_aug.get("motion_shift", False) + reprob = cfgs_data_aug.get("reprob", 0.0) + use_aa = cfgs_data_aug.get("auto_augment", False) + + # -- LOSS + cfgs_loss = args.get("loss") + loss_exp = cfgs_loss.get("loss_exp") + shift_by_n = cfgs_loss.get("shift_by_n") + predict_all = cfgs_loss.get("predict_all", True) + weight_distance_loss = cfgs_loss.get("weight_distance_loss", False) + offset_context_loss = cfgs_loss.get("offset_context_loss", False) + + # -- OPTIMIZATION + cfgs_opt = args.get("optimization") + is_anneal = cfgs_opt.get("is_anneal", False) + anneal_ckpt = cfgs_opt.get("anneal_ckpt", None) + if is_anneal and anneal_ckpt is None: + raise ValueError("Must specify anneal_ckpt if is_anneal is True") + resume_anneal = cfgs_opt.get("resume_anneal", False) or ( + is_anneal and resume_preempt + ) + ipe = cfgs_opt.get("ipe", None) + ipe_scale = cfgs_opt.get("ipe_scale", 1.0) + wd = float(cfgs_opt.get("weight_decay")) + final_wd = float(cfgs_opt.get("final_weight_decay")) + num_epochs = cfgs_opt.get("epochs") + warmup = cfgs_opt.get("warmup") + start_lr = cfgs_opt.get("start_lr") + lr = cfgs_opt.get("lr") + final_lr = cfgs_opt.get("final_lr") + ema = cfgs_opt.get("ema") + use_radamw = cfgs_opt.get("use_radamw", False) + betas = cfgs_opt.get("betas", (0.9, 0.999)) + eps = cfgs_opt.get("eps", 1.0e-8) + loss_reg_std_mult = cfgs_opt.get("loss_reg_std_mult", None) + loss_reg_num_tracking_steps = cfgs_opt.get("loss_reg_num_tracking_steps", 300) + loss_reg_min_epoch = cfgs_opt.get("loss_reg_min_epoch", 50) + if loss_reg_std_mult is not None: + logger.info("Loss regulation activated") + # ----------------------------------------------------------------------- # + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + try: + mp.set_start_method("spawn") + except Exception: + pass + + # -- init torch distributed backend + world_size, rank = init_distributed() + 
data_world_size, data_rank = world_size, rank + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") + img_world_size = 0 + + # make adjustments to batch size for image data + model_fpcs = dataset_fpcs + model_cfgs_mask = cfgs_mask + model_tubelet_size = tubelet_size + if cfgs_img_data is not None: + img_world_size = int(world_size * img_rank_ratio) + num_video_ranks = world_size - img_world_size + img_total_batch_size = img_dataset_batch_size * world_size + video_total_batch_size = batch_size * world_size + + if img_total_batch_size % img_world_size != 0: + raise ValueError( + f"img_total_batch_size ({img_total_batch_size}) must be divisible by num_img_ranks ({img_world_size})" + ) + if video_total_batch_size % num_video_ranks != 0: + raise ValueError( + f"video_total_batch_size ({video_total_batch_size}) must be divisible by num_video_ranks ({num_video_ranks})" + ) + + # img_dataset_batch_size = img_total_batch_size // img_world_size + batch_size = video_total_batch_size // num_video_ranks + + if rank < int(world_size * img_rank_ratio): + crop_size = cfgs_img_data.get("crop_size", 512) + grid_size = crop_size // patch_size + + if rank < int(world_size * img_rank_ratio): + logger.info( + f"On rank {rank}, updating dataset with dataset type {img_dataset_type}" + ) + if img_temporal_dim_size is not None: + if img_dataset_fpcs[0] != 1: + raise NotImplementedError( + "Image loader only supports 1 frame per clip with img_temporal_dim_size=1" + ) + tubelet_size = 1 + else: + tubelet_size = tubelet_size + + dataset_type = img_dataset_type + dataset_paths = img_dataset_paths + datasets_weights = img_dataset_weights + dataset_fpcs = img_dataset_fpcs + batch_size = img_dataset_batch_size + num_workers = img_num_workers + if img_mask is not None: + logger.info("Using image mask") + cfgs_mask = img_mask + + data_rank = rank + data_world_size = img_world_size + lambda_value = lambda_value_img # We select a different lambda value depending on video vs. 
image + else: + data_rank = rank - img_world_size + data_world_size = world_size - img_world_size + lambda_value = lambda_value_vid # We select a different lambda value depending on video vs. image + + logger.info( + f"For rank {rank} with world size {world_size}, " + f"we have total image batch size {img_total_batch_size}, total video batch size {video_total_batch_size}, " + f"image ranks: {img_world_size}, video ranks: {num_video_ranks}, " + f"using the following params: " + f"dataset_type: {dataset_type}, " + f"dataset_paths: {dataset_paths}, " + f"datasets_weights: {datasets_weights}, " + f"dataset_fpcs: {dataset_fpcs}, " + f"batch_size: {batch_size}, " + f"num_workers: {num_workers}, " + f"data_rank: {data_rank}, " + f"data_world_size: {data_world_size}" + f"lambda_value for the context loss: {lambda_value}" + ) + else: + lambda_value = lambda_value_vid + + # -- set device + if not torch.cuda.is_available(): + device = torch.device("cpu") + else: + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = os.path.join(folder, f"log_r{rank}.csv") + latest_file = "latest.pth.tar" + latest_path = os.path.join(folder, latest_file) + + load_path = None + if load_model: + if is_anneal: + if os.path.exists(latest_path) and resume_anneal: + load_path = latest_path + else: + load_path = anneal_ckpt + resume_anneal = False + else: + load_path = r_file if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ("%d", "epoch"), + ("%d", "itr"), + ("%.5f", "loss"), + ("%d", "iter-time(ms)"), + ("%d", "gpu-time(ms)"), + ("%d", "dataload-time(ms)"), + ) + + # -- init model + encoder, predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=int(len(model_cfgs_mask) * len(model_fpcs)), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + 
patch_size=patch_size, + max_num_frames=max_num_frames, + tubelet_size=model_tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_num_heads=pred_num_heads, + pred_embed_dim=pred_embed_dim, + is_causal=is_causal, + pred_is_causal=pred_is_causal, + use_sdpa=use_sdpa, + use_silu=use_silu, + use_pred_silu=use_pred_silu, + wide_silu=wide_silu, + use_rope=use_rope, + use_activation_checkpointing=use_activation_checkpointing, + return_all_tokens=predict_all, + chop_last_n_tokens=shift_by_n, + init_type=init_type, + img_temporal_dim_size=img_temporal_dim_size, + n_registers=n_registers, + n_registers_predictor=n_registers_predictor, + has_cls_first=has_cls_first, + interpolate_rope=interpolate_rope, + modality_embedding=modality_embedding, + ) + target_encoder = copy.deepcopy(encoder) + + if compile_model: + logger.info("Compiling encoder, target_encoder, and predictor.") + torch._dynamo.config.optimize_ddp = False + encoder.compile() + target_encoder.compile() + predictor.compile() + + mask_collator = MaskCollator( + cfgs_mask=cfgs_mask, + dataset_fpcs=dataset_fpcs, + crop_size=crop_size, + patch_size=patch_size, + tubelet_size=tubelet_size, + ) + + transform = make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size, + ) + + # -- init data-loaders/samplers + (unsupervised_loader, unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + # clip_len=clip_len, + dataset_fpcs=dataset_fpcs, + fps=fps, + transform=transform, + rank=data_rank, + world_size=data_world_size, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + pin_mem=pin_mem, + log_dir=None, + ) + try: + _dlen = len(unsupervised_loader) + except Exception: + try: + _dlen = unsupervised_loader.num_batches + except 
Exception: + _dlen = -1 + if ipe is None: + ipe = _dlen + logger.info(f"Using batch size of {batch_size}, fpcs of {dataset_fpcs}") + logger.info(f"iterations per epoch/dataset length: {ipe}/{_dlen}") + + # zizi + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + is_anneal=is_anneal, + encoder=encoder, + predictor=predictor, + use_radamw=use_radamw, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps, + ) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel( + predictor, static_graph=False, find_unused_parameters=True + ) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs) + 1) + ) + lambda_sched = Lambda_LinearWarmupHold(lambda_value=lambda_value) + + start_epoch = 0 + # -- load training checkpoint + print("Loadind checkpoint from: ", load_path) + if load_model or os.path.exists(latest_path): + ( + encoder, + predictor, + target_encoder, + optimizer, + scaler, + start_epoch, + ) = load_checkpoint( + r_path=load_path, + encoder=encoder, + predictor=predictor, + target_encoder=target_encoder, + opt=optimizer, + scaler=scaler, + is_anneal=is_anneal and not resume_anneal, + ) + if not is_anneal or resume_anneal: + for _ in range(start_epoch * ipe): + scheduler.step() + wd_scheduler.step() + next(momentum_scheduler) + mask_collator.step() + + def save_checkpoint(epoch, path): + if rank != 0: + return + save_dict = { + "encoder": encoder.state_dict(), + "predictor": predictor.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else 
scaler.state_dict(), + "target_encoder": target_encoder.state_dict(), + "epoch": epoch, + "loss": loss_meter.avg, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, + } + try: + torch.save(save_dict, path) + except Exception as e: + logger.info(f"Encountered exception when saving checkpoint: {e}") + + logger.info("Initializing loader...") + unsupervised_sampler.set_epoch(start_epoch) + loader = iter(unsupervised_loader) + + if skip_batches > 0: + logger.info(f"Skip {skip_batches} batches") + + for itr in range(skip_batches): + if itr % 10 == 0: + logger.info(f"Skip {itr}/{skip_batches} batches") + try: + _ = next(loader) + except Exception: + loader = iter(unsupervised_loader) + _ = next(loader) + + if sync_gc: + gc.disable() + gc.collect() + + trailing_losses = [] + step_count = 0 + + # -- TRAINING LOOP + for epoch in range(start_epoch, num_epochs): + logger.info("Epoch %d" % (epoch + 1)) + + loss_meter = AverageMeter() + mask_meters = {fpc: AverageMeter() for fpc in dataset_fpcs} + iter_time_meter = AverageMeter() + gpu_time_meter = AverageMeter() + data_elapsed_time_meter = AverageMeter() + + for itr in range(ipe): + itr_start_time = time.time() + + iter_retries = 0 + iter_successful = False + while not iter_successful: + try: + sample = next(loader) + iter_successful = True + except StopIteration: + logger.info("Exhausted data loaders. Refreshing...") + if "airstore" in dataset_type.lower(): + unsupervised_sampler.increase_epoch() + else: + unsupervised_sampler.set_epoch(epoch) + loader = iter(unsupervised_loader) + except Exception as e: + NUM_RETRIES = 5 + if iter_retries < NUM_RETRIES: + logger.warning( + f"Encountered exception when loading data (num retries {iter_retries}):\n{e}" + ) + iter_retries += 1 + time.sleep(5) + else: + raise RuntimeError( + f"Exceeded max retries ({NUM_RETRIES}) when loading data." 
+ ) from e + + for _fpc_sample in sample: + bs, fpc = _fpc_sample[0][-1][0].size() + mask_meters[fpc].update(bs / batch_size) + + def load_clips(): + all_clips, all_masks_enc, all_masks_pred = [], [], [] + for fpc_sample in sample: + udata, masks_enc, masks_pred = fpc_sample + all_clips += [udata[0][0].to(device, non_blocking=True)] + all_masks_enc += [ + [m.to(device, non_blocking=True) for m in masks_enc] + ] + all_masks_pred += [ + [m.to(device, non_blocking=True) for m in masks_pred] + ] + return all_clips, all_masks_enc, all_masks_pred + + clips, masks_enc, masks_pred = load_clips() + data_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0 + + if sync_gc and (itr + 1) % GARBAGE_COLLECT_ITR_FREQ == 0: + logger.info("Running garbage collection...") + gc.collect() + + def train_step(): + _new_lr = scheduler.step() + _new_wd = wd_scheduler.step() + + def forward_target(c, embed_dim=embed_dim_encoder): + with torch.no_grad(): + h = target_encoder(c, gram_mode=False, training_mode=True) + new_h = [] + for hi in h: + if levels_predictor > 1: + hi_0 = F.layer_norm(hi[:, :, :embed_dim], (embed_dim,)) + hi_1 = F.layer_norm( + hi[:, :, embed_dim : embed_dim * 2], + (embed_dim,), + ) + hi_2 = F.layer_norm( + hi[:, :, embed_dim * 2 : embed_dim * 3], + (embed_dim,), + ) + hi_3 = F.layer_norm(hi[:, :, -embed_dim:], (embed_dim,)) + hi_norm = torch.cat([hi_0, hi_1, hi_2, hi_3], dim=2) + new_h.append(hi_norm) + else: + new_h.append(F.layer_norm(hi, (hi.size(-1),))) + return new_h + + def forward_context(clips, embed_dim=embed_dim_encoder): + modality = "video" + if img_temporal_dim_size is not None: + if clips[0].shape[2] == img_temporal_dim_size: + modality = "image" + z = encoder(clips, masks_enc, gram_mode=False, training_mode=True) + z_pred, z_context = predictor( + z, masks_enc, masks_pred, mod=modality + ) + if normalize_predictor: + z_pred = normalize_nested(z_pred, embed_dim) + + if predict_all: + z_context = normalize_nested(z_context, embed_dim) + return 
z_pred, z_context + + def loss_fn(z, h, masks_to_apply, cls_loss, d_weights): + if cls_loss: + h_cls = [hi[:, 0].unsqueeze(1) for hi in h] + h = [ + apply_masks(hi[:, 1:], mi, concat=False) + for hi, mi in zip(h, masks_to_apply) + ] + loss, n = 0, 0 + for zi, hi, hi_cls in zip(z, h, h_cls): + for zij, hij in zip(zi, hi): + h_term = torch.cat([hi_cls, hij], dim=1) + loss += ( + torch.mean(torch.abs(zij - h_term) ** loss_exp) + / loss_exp + ) + n += 1 + + loss /= n + return loss + else: + h = [ + apply_masks(hi, mi, concat=False) + for hi, mi in zip(h, masks_to_apply) + ] + + if d_weights is not None: + loss, n = 0, 0 + for zi, hi, d_i in zip(z, h, d_weights): + for zij, hij, d_ij in zip(zi, hi, d_i): + loss_n = torch.abs(zij - hij) ** loss_exp * ( + 1 / d_ij.unsqueeze(2) + ) + loss += torch.mean(loss_n) / loss_exp + n += 1 + loss /= n + return loss + else: + loss, n = 0, 0 + for zi, hi in zip(z, h): + for zij, hij in zip(zi, hi): + loss += ( + torch.mean(torch.abs(zij - hij) ** loss_exp) + / loss_exp + ) + n += 1 + loss /= n + return loss + + # Step 1. Forward + with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): + h = forward_target(clips) + z_pred, z_context = forward_context(clips) + loss = 0 + loss_pred = loss_fn( + z_pred, h, masks_pred, cls_loss=has_cls_first, d_weights=None + ) + loss += loss_pred + + # Context loss + if predict_all: + distance_weights = compute_mask_distance( + masks_pred, masks_enc, grid_size, offset_context_loss + ) + if weight_distance_loss: + d_weights = distance_weights + else: + d_weights = None + loss_context = loss_fn( + z_context, h, masks_enc, cls_loss=False, d_weights=d_weights + ) + if lambda_progressive: + lambda_value_step = lambda_sched.value(epoch * ipe + itr) + else: + lambda_value_step = lambda_value + loss += loss_context * lambda_value_step + + # Step 2. 
Backward & step + run_step = True + if loss_reg_std_mult is not None: + meanval = np.mean(trailing_losses) + stdval = np.std(trailing_losses) + max_bound = meanval + loss_reg_std_mult * stdval + if ( + loss > max_bound + and epoch > loss_reg_min_epoch + and len(trailing_losses) + > int(0.5 * loss_reg_num_tracking_steps) + ): + run_step = False + loss.backward() + logger.info( + f"Loss {loss} is above bound {meanval} + {loss_reg_std_mult} * {stdval}. Skipping step." + ) + + if run_step: + if mixed_precision: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + else: + loss.backward() + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + + # Step 3. momentum update of target encoder + m = min(next(momentum_scheduler), ema[1]) + with torch.no_grad(): + params_k = [] + params_q = [] + for param_q, param_k in zip( + encoder.parameters(), target_encoder.parameters() + ): + params_k.append(param_k) + params_q.append(param_q) + torch._foreach_mul_(params_k, m) + torch._foreach_add_(params_k, params_q, alpha=1 - m) + + return ( + float(loss), + _new_lr, + _new_wd, + run_step, + ) + + ( + loss, + _new_lr, + _new_wd, + run_step, + ), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0 + loss_meter.update(loss) + iter_time_meter.update(iter_elapsed_time_ms) + gpu_time_meter.update(gpu_etime_ms) + data_elapsed_time_meter.update(data_elapsed_time_ms) + + if loss_reg_std_mult is not None: + if run_step: + trailing_losses.append(loss) + if len(trailing_losses) > loss_reg_num_tracking_steps: + trailing_losses = trailing_losses[1:] + else: + step_count += 1 + if step_count > MAX_REPEAT_COUNTS: + raise RuntimeError( + "Loss is above bound for too many tries. Exiting." 
+ ) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + iter_elapsed_time_ms, + gpu_etime_ms, + data_elapsed_time_ms, + ) + if ( + (itr % log_freq == 0) + or (itr == ipe - 1) + or np.isnan(loss) + or np.isinf(loss) + ): + logger.info( + "[%d, %5d] loss: %.3f " + "masks: %s " + "[wd: %.2e] [lr: %.2e] " + "[mem: %.2e] " + "[iter: %.1f ms] " + "[gpu: %.1f ms] " + "[data: %.1f ms]" + % ( + epoch + 1, + itr, + loss_meter.avg, + "[" + + ", ".join( + [ + f"{k}: " + "%.1f" % mask_meters[k].avg + for k in mask_meters + ] + ) + + "]", + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + iter_time_meter.avg, + gpu_time_meter.avg, + data_elapsed_time_meter.avg, + ) + ) + + log_stats() + assert not np.isnan(loss), "loss is nan" + + # -- Save Checkpoint + logger.info("avg. loss %.3f" % loss_meter.avg) + if (epoch + 1) % CHECKPOINT_FREQ == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and (epoch + 1) % save_every_freq == 0: + save_every_file = f"e{epoch}.pth.tar" + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/app/vjepa_2_1/transforms.py b/app/vjepa_2_1/transforms.py new file mode 100644 index 00000000..0fad7ef3 --- /dev/null +++ b/app/vjepa_2_1/transforms.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
class VideoTransform(object):
    """Augmentation pipeline applied to one clip (or one still image).

    The pipeline is: optional RandAugment -> random resized crop (optionally
    with motion shift) -> optional horizontal flip -> per-channel
    normalization -> optional random erasing.

    Args:
        random_horizontal_flip (bool): flip the clip with probability 0.5.
        random_resize_aspect_ratio (tuple): aspect-ratio range forwarded to
            the spatial crop as ``ratio``.
        random_resize_scale (tuple): area-scale range forwarded to the
            spatial crop as ``scale``.
        reprob (float): random-erasing probability; erasing is skipped
            entirely when this is not > 0.
        auto_augment (bool): run the RandAugment policy on every frame.
        motion_shift (bool): use ``random_resized_crop_with_shift`` instead
            of ``random_resized_crop``.
        crop_size (int): output height and width of the spatial crop.
        normalize (tuple): ``(mean, std)`` per-channel statistics expressed
            for inputs in the [0, 1] range.
    """

    def __init__(
        self,
        random_horizontal_flip=True,
        random_resize_aspect_ratio=(3 / 4, 4 / 3),
        random_resize_scale=(0.3, 1.0),
        reprob=0.0,
        auto_augment=False,
        motion_shift=False,
        crop_size=224,
        normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ):
        self.random_horizontal_flip = random_horizontal_flip
        self.random_resize_aspect_ratio = random_resize_aspect_ratio
        self.random_resize_scale = random_resize_scale
        self.auto_augment = auto_augment
        self.motion_shift = motion_shift
        self.crop_size = crop_size
        self.mean = torch.tensor(normalize[0], dtype=torch.float32)
        self.std = torch.tensor(normalize[1], dtype=torch.float32)
        if not self.auto_augment:
            # Without auto-augment the frames are only cast to float32 (no
            # rescaling), so the statistics are moved to the 0-255 range.
            # The auto-augment path goes through ToTensor, which already
            # rescales to [0, 1].
            self.mean *= 255.0
            self.std *= 255.0

        # Built unconditionally; only used when auto_augment is True.
        self.autoaug_transform = video_transforms.create_random_augment(
            input_size=(crop_size, crop_size),
            auto_augment="rand-m7-n4-mstd0.5-inc1",
            interpolation="bicubic",
        )

        self.spatial_transform = (
            video_transforms.random_resized_crop_with_shift
            if motion_shift
            else video_transforms.random_resized_crop
        )

        self.reprob = reprob
        self.erase_transform = RandomErasing(
            reprob,
            mode="pixel",
            max_count=1,
            num_splits=1,
            device="cpu",
        )

    def __call__(self, buffer):
        """Augment ``buffer`` and return the normalized clip.

        Args:
            buffer: a ``PIL.Image.Image`` (treated as a one-frame clip), or a
                clip laid out as (T, H, W, C) given as a tensor or anything
                ``torch.tensor`` accepts.

        Returns:
            torch.Tensor: float32 clip; laid out as (C, T, H, W) assuming the
            spatial transform and flip helpers preserve the layout they are
            given -- TODO confirm against ``video_transforms``.
        """
        # Handle PIL Image input (from ImageNet datasets).
        if isinstance(buffer, Image.Image):
            # Keep the frame as an (H, W, C) ndarray with a leading T=1 axis.
            # It must NOT be turned into a float tensor here: ToPILImage
            # interprets a 3-D tensor as (C, H, W), so a float (H, W, C)
            # tensor makes the auto-augment branch below fail. The float32
            # conversion for the non-augmented case happens in the final
            # ``else`` branch instead and yields the same values.
            buffer = np.expand_dims(np.array(buffer), axis=0)  # T H W C, T=1

        if self.auto_augment:
            # NOTE(review): frames must be in a form ToPILImage accepts as
            # (H, W, C), i.e. ndarrays; tensor clips are not converted here.
            buffer = [transforms.ToPILImage()(frame) for frame in buffer]
            buffer = self.autoaug_transform(buffer)
            buffer = [transforms.ToTensor()(img) for img in buffer]
            buffer = torch.stack(buffer)  # T C H W
            buffer = buffer.permute(0, 2, 3, 1)  # T H W C
        elif torch.is_tensor(buffer):
            buffer = buffer.to(torch.float32)
        else:
            buffer = torch.tensor(buffer, dtype=torch.float32)

        buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W

        buffer = self.spatial_transform(
            images=buffer,
            target_height=self.crop_size,
            target_width=self.crop_size,
            scale=self.random_resize_scale,
            ratio=self.random_resize_aspect_ratio,
        )
        if self.random_horizontal_flip:
            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)

        buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
        if self.reprob > 0:
            # The erasing transform is fed (T, C, H, W); swap back afterwards.
            buffer = buffer.permute(1, 0, 2, 3)
            buffer = self.erase_transform(buffer)
            buffer = buffer.permute(1, 0, 2, 3)

        return buffer
def normalize_and_concat(tensor, embed_dim, num_chunks=4):
    """Layer-normalize consecutive ``embed_dim``-wide slices of the last axis.

    The first ``num_chunks * embed_dim`` features of ``tensor`` are split into
    ``num_chunks`` slices, each slice is passed through ``F.layer_norm``
    (no affine parameters) on its own, and the results are concatenated back
    along the last axis.

    :param tensor: 3-D tensor (it is indexed as ``tensor[:, :, a:b]``).
    :param embed_dim: width of each slice.
    :param num_chunks: number of slices; defaults to 4, the value that was
        previously hard-coded.
    :returns: tensor whose last axis has size ``num_chunks * embed_dim``.
    """
    chunks = [
        F.layer_norm(tensor[:, :, i * embed_dim : (i + 1) * embed_dim], (embed_dim,))
        for i in range(num_chunks)
    ]
    return torch.cat(chunks, dim=2)


def normalize_nested(nested, embed_dim, num_chunks=4):
    """Apply ``normalize_and_concat`` to every leaf of a 3-level nesting.

    ``nested`` is iterated exactly three levels deep (outer -> inner -> leaf);
    the depth is fixed, not recursive. Each leaf must be a 3-D tensor.

    :param nested: iterable of iterables of iterables of 3-D tensors.
    :param embed_dim: slice width forwarded to ``normalize_and_concat``.
    :param num_chunks: slice count forwarded to ``normalize_and_concat``.
    :returns: list of lists of lists mirroring the input structure.
    """
    return [
        [
            [normalize_and_concat(z, embed_dim, num_chunks) for z in inner]
            for inner in outer
        ]
        for outer in nested
    ]
def _load_compatible_state_dict(module, pretrained_dict, label, epoch):
    """Load ``pretrained_dict`` into ``module``, tolerating key/shape drift.

    Keys missing from the checkpoint are logged and left at the module's
    current values (``strict=False``). Keys whose shape differs are logged and
    the checkpoint entry is overwritten with the module's current tensor,
    because ``load_state_dict`` raises on size mismatches even when
    ``strict=False``. ``pretrained_dict`` is modified in place.
    """
    for k, v in module.state_dict().items():
        if k not in pretrained_dict:
            logger.info(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            logger.info(
                f'key "{k}" is of different shape in model and loaded state dict'
            )
            pretrained_dict[k] = v
    msg = module.load_state_dict(pretrained_dict, strict=False)
    logger.info(f"loaded pretrained {label} from epoch {epoch} with msg: {msg}")


def load_checkpoint(
    r_path,
    encoder,
    predictor,
    target_encoder,
    opt,
    scaler,
    is_anneal=False,
):
    """Restore models, optimizer and grad-scaler state from ``r_path``.

    :param r_path: checkpoint location handed to ``robust_checkpoint_loader``.
    :param encoder: module restored from ``checkpoint["encoder"]``.
    :param predictor: module restored from ``checkpoint["predictor"]``.
    :param target_encoder: module restored from ``checkpoint["target_encoder"]``;
        skipped when ``None``.
    :param opt: optimizer restored from ``checkpoint["opt"]``; a
        ``ValueError`` while loading is logged and the optimizer keeps its
        current state.
    :param scaler: grad scaler restored from ``checkpoint["scaler"]``; skipped
        when ``scaler`` is ``None`` or the checkpoint stored ``None``.
    :param is_anneal: when true the stored epoch is ignored and 0 is returned.
    :returns: ``(encoder, predictor, target_encoder, opt, scaler, epoch)``.
    """
    logger.info(f"Loading {r_path}")
    checkpoint = robust_checkpoint_loader(r_path, map_location=torch.device("cpu"))

    # An annealing run starts its own epoch count from zero.
    epoch = 0
    if not is_anneal:
        epoch = checkpoint["epoch"]

    _load_compatible_state_dict(encoder, checkpoint["encoder"], "encoder", epoch)
    _load_compatible_state_dict(
        predictor, checkpoint["predictor"], "predictor", epoch
    )
    if target_encoder is not None:
        _load_compatible_state_dict(
            target_encoder, checkpoint["target_encoder"], "target encoder", epoch
        )

    try:
        opt.load_state_dict(checkpoint["opt"])
    except ValueError:
        # Routed through the module logger (was a bare print) so the message
        # follows the same handlers as the rest of the load log.
        logger.warning("[warn] Optimizer groups mismatch; reinitializing optimizer.")
    if scaler is not None:
        saved_scaler = checkpoint["scaler"]
        # Checkpoints written without a grad scaler store None under this
        # key; there is nothing to restore in that case.
        if saved_scaler is not None:
            scaler.load_state_dict(saved_scaler)
    logger.info(f"loaded optimizers from epoch {epoch}")
    logger.info(f"read-path: {r_path}")
    del checkpoint

    return (
        encoder,
        predictor,
        target_encoder,
        opt,
        scaler,
        epoch,
    )
+ use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + use_sdpa=False, + use_rope=False, + use_silu=False, + use_pred_silu=False, + wide_silu=False, + is_causal=False, + pred_is_causal=False, + use_activation_checkpointing=False, + return_all_tokens=False, + chop_last_n_tokens=0, + init_type="default", + img_temporal_dim_size=None, + n_registers=0, + n_registers_predictor=0, + has_cls_first=False, + interpolate_rope=False, + modality_embedding=False, +): + encoder = video_vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=max_num_frames, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + use_silu=use_silu, + wide_silu=wide_silu, + use_activation_checkpointing=use_activation_checkpointing, + is_causal=is_causal, + use_rope=use_rope, + init_type=init_type, + img_temporal_dim_size=img_temporal_dim_size, + n_registers=n_registers, + has_cls_first=has_cls_first, + interpolate_rope=interpolate_rope, + modality_embedding=modality_embedding, + ) + encoder = MultiSeqWrapper(encoder) + predictor = vit_pred.__dict__["vit_predictor"]( + img_size=crop_size, + use_mask_tokens=use_mask_tokens, + patch_size=patch_size, + num_frames=max_num_frames, + tubelet_size=tubelet_size, + embed_dim=encoder.backbone.embed_dim, + predictor_embed_dim=pred_embed_dim, + depth=pred_depth, + num_heads=( + encoder.backbone.num_heads if pred_num_heads is None else pred_num_heads + ), + uniform_power=uniform_power, + num_mask_tokens=num_mask_tokens, + zero_init_mask_tokens=zero_init_mask_tokens, + use_rope=use_rope, + use_sdpa=use_sdpa, + is_causal=pred_is_causal, + use_silu=use_pred_silu, + wide_silu=wide_silu, + use_activation_checkpointing=use_activation_checkpointing, + return_all_tokens=return_all_tokens, + chop_last_n_tokens=chop_last_n_tokens, + n_registers=n_registers_predictor, + has_cls_first=has_cls_first, + interpolate_rope=interpolate_rope, + modality_embedding=modality_embedding, + 
def init_opt(
    is_anneal,
    encoder,
    predictor,
    iterations_per_epoch,
    start_lr,
    ref_lr,
    warmup,
    num_epochs,
    use_radamw=False,
    wd=1e-6,
    final_wd=1e-6,
    final_lr=0.0,
    mixed_precision=False,
    ipe_scale=1.25,
    betas=(0.9, 0.999),
    eps=1e-8,
    zero_init_bias_wd=True,
):
    """Build the optimizer, grad scaler and LR / weight-decay schedules.

    Parameters are split into four groups, in this order: encoder weights,
    predictor weights, encoder biases/1-D params, predictor biases/1-D params.
    The two bias/1-D groups get ``weight_decay=0`` and carry
    ``zero_init_bias_wd`` under the ``"WD_exclude"`` key (presumably read by
    the weight-decay schedule -- verify in ``CosineWDSchedule``).

    :returns: ``(optimizer, scaler, scheduler, wd_scheduler)``; ``scaler`` is
        ``None`` unless ``mixed_precision`` is set.
    """

    def _is_bias_or_1d(name, param):
        return ("bias" in name) or (len(param.shape) == 1)

    def _pick(module, bias_like):
        # Lazy generator, consumed by the optimizer constructor.
        return (
            p
            for n, p in module.named_parameters()
            if _is_bias_or_1d(n, p) == bias_like
        )

    param_groups = [
        {"params": _pick(encoder, False)},
        {"params": _pick(predictor, False)},
        {
            "params": _pick(encoder, True),
            "WD_exclude": zero_init_bias_wd,
            "weight_decay": 0,
        },
        {
            "params": _pick(predictor, True),
            "WD_exclude": zero_init_bias_wd,
            "weight_decay": 0,
        },
    ]

    if not use_radamw:
        logger.info("Using AdamW")
        optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
    else:
        from src.utils.adamw import AdamW as RAdamW

        logger.info("Using Rescaled-AdamW")
        optimizer = RAdamW(param_groups, betas=betas, eps=eps)

    # Every schedule shares the same (stretched) horizon.
    total_steps = int(ipe_scale * num_epochs * iterations_per_epoch)

    if is_anneal:
        scheduler = LinearDecaySchedule(
            optimizer,
            ref_lr=ref_lr,
            final_lr=final_lr,
            T_max=total_steps,
        )
    else:
        scheduler = WarmupCosineSchedule(
            optimizer,
            warmup_steps=int(warmup * iterations_per_epoch),
            start_lr=start_lr,
            ref_lr=ref_lr,
            final_lr=final_lr,
            T_max=total_steps,
        )
    wd_scheduler = CosineWDSchedule(
        optimizer,
        ref_wd=wd,
        final_wd=final_wd,
        T_max=total_steps,
    )

    scaler = None
    if mixed_precision:
        scaler = torch.cuda.amp.GradScaler()
    return optimizer, scaler, scheduler, wd_scheduler
class PredictorMultiSeqWrapper(nn.Module):
    """Run a predictor backbone over every (sequence group, mask) pair."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x, masks_x, masks_y, mod="video"):
        """
        :param x: [list] List of encoder outputs for different seq lengths
        :param masks_x: [list] List of encoder masks
        :param masks_y: [list] List of predictor masks
        :param mod: modality tag forwarded unchanged to the backbone
        :returns: ``(outs_pred, outs_context)``; each is a list with one
            sub-list per entry of ``x`` holding the backbone's first /
            second output for every mask pair of that entry.
        """
        outs_pred = [[] for _ in x]
        outs_context = [[] for _ in x]
        for i, (x_fpc, mx_fpc, my_fpc) in enumerate(zip(x, masks_x, masks_y)):
            for xij, mx, my in zip(x_fpc, mx_fpc, my_fpc):
                # The group index is passed as mask_index to the backbone.
                x_pred, x_context = self.backbone(xij, mx, my, mask_index=i, mod=mod)
                outs_pred[i].append(x_pred)
                outs_context[i].append(x_context)
        return outs_pred, outs_context
file mode 100644 index 00000000..682c74ac Binary files /dev/null and b/assets/bars_teaser_tikz-1.png differ diff --git a/assets/teaser_screenshot_5dice.png b/assets/teaser_screenshot_5dice.png new file mode 100644 index 00000000..25bf2556 Binary files /dev/null and b/assets/teaser_screenshot_5dice.png differ diff --git a/configs/eval/vitg-384/in1k.yaml b/configs/eval/vitg-384/in1k.yaml index 1b5c70ec..99513dae 100644 --- a/configs/eval/vitg-384/in1k.yaml +++ b/configs/eval/vitg-384/in1k.yaml @@ -13,8 +13,7 @@ experiment: data: dataset_name: ImageNet num_classes: 1000 - root_path: /datasets/ - image_folder: ImageNet_FullSize/240712/061417/ + root_path: /your_data_path/ImageNet_FullSize/240712/061417/ resolution: 384 optimization: batch_size: 8 diff --git a/configs/eval/vitl/in1k.yaml b/configs/eval/vitl/in1k.yaml index 9500666e..961a3913 100644 --- a/configs/eval/vitl/in1k.yaml +++ b/configs/eval/vitl/in1k.yaml @@ -13,8 +13,7 @@ experiment: data: dataset_name: ImageNet num_classes: 1000 - root_path: /datasets/ - image_folder: ImageNet_FullSize/240712/061417/ + root_path: /your_data_path/ImageNet_FullSize/240712/061417/ resolution: 256 optimization: batch_size: 16 diff --git a/configs/eval_2_1/vitG-384/coin.yaml b/configs/eval_2_1/vitG-384/coin.yaml new file mode 100644 index 00000000..31195de4 --- /dev/null +++ b/configs/eval_2_1/vitG-384/coin.yaml @@ -0,0 +1,163 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitG-384/coin +mem_per_gpu: 220G +nodes: 16 +resume_checkpoint: true +tag: coin-vitG16-384-16x8x3 +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_folder/COIN/train_paths.csv + dataset_val: /your_data_folder/COIN/val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 180 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 1 + multihead_kwargs: + - final_lr: 0.0 + 
final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 
0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_gigantic_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitG-384/diving48.yaml b/configs/eval_2_1/vitG-384/diving48.yaml new file mode 100644 index 00000000..21d62862 --- /dev/null +++ b/configs/eval_2_1/vitG-384/diving48.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: diving48-vitG16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitG-384/diving48 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv + dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv + num_classes: 48 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [11, 23, 37, 47] + pretrain_kwargs: + encoder: + model_name: vit_gigantic_xformers + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitG-384/ek100.yaml b/configs/eval_2_1/vitG-384/ek100.yaml new file mode 100644 index 00000000..f51c8a5d --- /dev/null +++ b/configs/eval_2_1/vitG-384/ek100.yaml @@ -0,0 +1,199 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +tag: ek100-vitG16-384 +eval_name: action_anticipation_frozen +folder: /your_folder/evals_2_1/vitG-384/ek100 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + anticipation_time_sec: + - 1.0 + - 1.0 + auto_augment: true + file_format: 0 + dataset: EK100 + # e.g. /home/username/EPIC-KITCHENS + base_path: /your_ek100_root_dir/ + dataset_train: /your_data/EPIC_100_train.csv + dataset_val: /your_data/EPIC_100_validation.csv + frames_per_clip: 32 + frames_per_second: 8 + motion_shift: false + num_workers: 2 + pin_memory: true + random_resize_scale: + - 0.08 + - 1.0 + reprob: 0.25 + resolution: 384 + train_anticipation_point: + - 0.0 + - 0.25 + train_anticipation_time_sec: + - 0.25 + - 1.75 + optimization: + num_epochs: 20 + batch_size: 2 + use_bfloat16: true + use_focal_loss: true + multihead_kwargs: + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar + wrapper_kwargs: + no_predictor: false + num_output_frames: 2 + num_steps: 1 + pretrain_kwargs: + use_v2_1: true + encoder: + model_name: vit_gigantic_xformers + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true + predictor: + model_name: vit_predictor + checkpoint_key: predictor + num_frames: 64 + depth: 24 + num_heads: 12 + predictor_embed_dim: 384 + num_mask_tokens: 8 + img_temporal_dim_size: 1 + uniform_power: true + use_mask_tokens: true + use_sdpa: true + use_silu: false + wide_silu: false + use_rope: true diff --git a/configs/eval_2_1/vitG-384/in1k.yaml b/configs/eval_2_1/vitG-384/in1k.yaml new file mode 100644 index 00000000..43c69869 --- /dev/null +++ b/configs/eval_2_1/vitG-384/in1k.yaml @@ -0,0 +1,162 @@ +cpus_per_task: 16 +eval_name: image_classification_frozen +folder: /your_folder/evals_2_1/vitG-384/in1k +mem_per_gpu: 220G +nodes: 16 +resume_checkpoint: true +tag: in1k-vitG16-384-18f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_name: ImageNet + num_classes: 1000 + root_path: /your_data_path/ImageNet_FullSize/240712/061417/ + resolution: 384 + optimization: + batch_size: 8 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.0 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - 
final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + num_epochs: 20 + use_bfloat16: true +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.image_classification_frozen.modelcustom.vit_encoder + pretrain_kwargs: + 
encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_gigantic_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + img_as_video_nframes: 18 diff --git a/configs/eval_2_1/vitG-384/jester.yaml b/configs/eval_2_1/vitG-384/jester.yaml new file mode 100644 index 00000000..fd38a394 --- /dev/null +++ b/configs/eval_2_1/vitG-384/jester.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: jester-vitG16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitG-384/jester +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv + dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv + num_classes: 27 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [11, 23, 37, 47] + pretrain_kwargs: + encoder: + model_name: vit_gigantic_xformers + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitG-384/k400.yaml b/configs/eval_2_1/vitG-384/k400.yaml new file mode 100644 index 00000000..2e256315 --- /dev/null +++ b/configs/eval_2_1/vitG-384/k400.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitG-384/k400 +mem_per_gpu: 220G +nodes: 32 +num_workers: 8 +resume_checkpoint: true +tag: k400-vitG16-384-16x8x3-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/k400_train_paths.csv + dataset_val: /your_data_path/k400_val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 400 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 1 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + 
start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_gigantic_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitG-384/ssv2.yaml 
b/configs/eval_2_1/vitG-384/ssv2.yaml new file mode 100644 index 00000000..206dfdd5 --- /dev/null +++ b/configs/eval_2_1/vitG-384/ssv2.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitG-384/ssv2 +mem_per_gpu: 220G +nodes: 16 +num_workers: 8 +resume_checkpoint: true +tag: ssv2-vitG16-384-64x2x3 +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/ssv2_train_paths.csv + dataset_val: /your_data_path/ssv2_val_paths.csv + frame_step: 2 + frames_per_clip: 64 + num_classes: 174 + num_segments: 2 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 1 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + 
final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_gigantic_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitb-384/coin.yaml b/configs/eval_2_1/vitb-384/coin.yaml new file mode 100644 index 00000000..e30343d2 --- /dev/null +++ b/configs/eval_2_1/vitb-384/coin.yaml @@ -0,0 +1,163 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitb-384/coin +mem_per_gpu: 220G +nodes: 8 +resume_checkpoint: true +tag: coin-vitb16-384-16x8x3 +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: 
/your_data_folder/COIN/train_paths.csv + dataset_val: /your_data_folder/COIN/val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 180 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 2 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + 
weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_base + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitb-384/diving48.yaml b/configs/eval_2_1/vitb-384/diving48.yaml new file mode 100644 index 00000000..4d0f9ace --- /dev/null +++ b/configs/eval_2_1/vitb-384/diving48.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: diving48-vitb16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitb-384/diving48 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv + dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv + num_classes: 48 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [2, 5, 8, 11] + pretrain_kwargs: + encoder: + model_name: vit_base + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitb-384/ek100.yaml b/configs/eval_2_1/vitb-384/ek100.yaml new file mode 100644 index 00000000..6c445e19 --- /dev/null +++ b/configs/eval_2_1/vitb-384/ek100.yaml @@ -0,0 +1,202 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +tag: ek100-vitb16-384 +eval_name: action_anticipation_frozen +folder: /your_folder/evals_2_1/vitb-384/ek100 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + anticipation_time_sec: + - 1.0 + - 1.0 + auto_augment: true + file_format: 0 + dataset: EK100 + # e.g. /home/username/EPIC-KITCHENS + base_path: /your_ek100_root_dir/ + dataset_train: /your_data/EPIC_100_train.csv + dataset_val: /your_data/EPIC_100_validation.csv + frames_per_clip: 32 + frames_per_second: 8 + motion_shift: false + num_workers: 2 + pin_memory: true + random_resize_scale: + - 0.08 + - 1.0 + reprob: 0.25 + resolution: 384 + train_anticipation_point: + - 0.0 + - 0.25 + train_anticipation_time_sec: + - 0.25 + - 1.75 + optimization: + num_epochs: 20 + batch_size: 2 + use_bfloat16: true + use_focal_loss: true + multihead_kwargs: + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar + wrapper_kwargs: + no_predictor: false + num_output_frames: 2 + num_steps: 1 + pretrain_kwargs: + use_v2_1: true + encoder: + model_name: vit_base + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true + predictor: + model_name: vit_predictor + checkpoint_key: predictor + num_frames: 64 + depth: 12 + num_heads: 12 + predictor_embed_dim: 384 + teacher_embed_dim: 1664 + num_mask_tokens: 8 + n_output_distillation: 1 + return_all_tokens: true + img_temporal_dim_size: 1 + uniform_power: true + use_mask_tokens: true + use_sdpa: true + use_silu: false + wide_silu: false + use_rope: true diff --git a/configs/eval_2_1/vitb-384/in1k.yaml b/configs/eval_2_1/vitb-384/in1k.yaml new file mode 100644 index 00000000..049e332f --- /dev/null +++ b/configs/eval_2_1/vitb-384/in1k.yaml @@ -0,0 +1,162 @@ +cpus_per_task: 16 +eval_name: image_classification_frozen +folder: /your_folder/evals_2_1/vitb-384/in1k +mem_per_gpu: 220G +nodes: 8 +resume_checkpoint: true +tag: in1k-vitb16-384-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_name: ImageNet + num_classes: 1000 + root_path: /your_data_path/ImageNet_FullSize/240712/061417/ + resolution: 384 + optimization: + batch_size: 16 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.0 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 
0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + num_epochs: 20 + use_bfloat16: true 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.image_classification_frozen.modelcustom.vit_encoder + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_base + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + img_as_video_nframes: 16 diff --git a/configs/eval_2_1/vitb-384/jester.yaml b/configs/eval_2_1/vitb-384/jester.yaml new file mode 100644 index 00000000..5958a266 --- /dev/null +++ b/configs/eval_2_1/vitb-384/jester.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: jester-vitb16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitb-384/jester +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv + dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv + num_classes: 27 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [2, 5, 8, 11] + pretrain_kwargs: + encoder: + model_name: vit_base + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitb-384/k400.yaml b/configs/eval_2_1/vitb-384/k400.yaml new file mode 100644 index 00000000..da3f5128 --- /dev/null +++ b/configs/eval_2_1/vitb-384/k400.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitb-384/k400 +mem_per_gpu: 220G +nodes: 8 +num_workers: 8 +resume_checkpoint: true +tag: k400-vitb16-384-16x8x3-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/k400_train_paths.csv + dataset_val: /your_data_path/k400_val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 400 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 4 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 
+ warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_base + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitb-384/ssv2.yaml 
b/configs/eval_2_1/vitb-384/ssv2.yaml new file mode 100644 index 00000000..66aebaae --- /dev/null +++ b/configs/eval_2_1/vitb-384/ssv2.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitb-384/ssv2 +mem_per_gpu: 220G +nodes: 8 +num_workers: 8 +resume_checkpoint: true +tag: ssv2-vitb16-384-16x2x3-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/ssv2_train_paths.csv + dataset_val: /your_data_path/ssv2_val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 174 + num_segments: 2 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 4 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + 
final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitb_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_base + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitg-384/coin.yaml b/configs/eval_2_1/vitg-384/coin.yaml new file mode 100644 index 00000000..fdd89dec --- /dev/null +++ b/configs/eval_2_1/vitg-384/coin.yaml @@ -0,0 +1,163 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitg-384/coin +mem_per_gpu: 220G +nodes: 16 +resume_checkpoint: true +tag: coin-vitg16-384-16x8x3 +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: 
/your_data_folder/COIN/train_paths.csv + dataset_val: /your_data_folder/COIN/val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 180 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 1 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + 
weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_giant_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitg-384/diving48.yaml b/configs/eval_2_1/vitg-384/diving48.yaml new file mode 100644 index 00000000..ac68b44a --- /dev/null +++ b/configs/eval_2_1/vitg-384/diving48.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: diving48-vitg16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitg-384/diving48 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv + dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv + num_classes: 48 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [9, 19, 29, 39] + pretrain_kwargs: + encoder: + model_name: vit_giant_xformers + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitg-384/ek100.yaml b/configs/eval_2_1/vitg-384/ek100.yaml new file mode 100644 index 00000000..faa31913 --- /dev/null +++ b/configs/eval_2_1/vitg-384/ek100.yaml @@ -0,0 +1,199 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +tag: ek100-vitg16-384 +eval_name: action_anticipation_frozen +folder: /your_folder/evals_2_1/vitg-384/ek100 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + anticipation_time_sec: + - 1.0 + - 1.0 + auto_augment: true + file_format: 0 + dataset: EK100 + # e.g. /home/username/EPIC-KITCHENS + base_path: /your_ek100_root_dir/ + dataset_train: /your_data/EPIC_100_train.csv + dataset_val: /your_data/EPIC_100_validation.csv + frames_per_clip: 32 + frames_per_second: 8 + motion_shift: false + num_workers: 2 + pin_memory: true + random_resize_scale: + - 0.08 + - 1.0 + reprob: 0.25 + resolution: 384 + train_anticipation_point: + - 0.0 + - 0.25 + train_anticipation_time_sec: + - 0.25 + - 1.75 + optimization: + num_epochs: 20 + batch_size: 2 + use_bfloat16: true + use_focal_loss: true + multihead_kwargs: + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar + wrapper_kwargs: + no_predictor: false + num_output_frames: 2 + num_steps: 1 + pretrain_kwargs: + use_v2_1: true + encoder: + model_name: vit_giant_xformers + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true + predictor: + model_name: vit_predictor + checkpoint_key: predictor + num_frames: 64 + depth: 24 + num_heads: 12 + predictor_embed_dim: 384 + num_mask_tokens: 8 + img_temporal_dim_size: 1 + uniform_power: true + use_mask_tokens: true + use_sdpa: true + use_silu: false + wide_silu: false + use_rope: true diff --git a/configs/eval_2_1/vitg-384/in1k.yaml b/configs/eval_2_1/vitg-384/in1k.yaml new file mode 100644 index 00000000..73e4321d --- /dev/null +++ b/configs/eval_2_1/vitg-384/in1k.yaml @@ -0,0 +1,162 @@ +cpus_per_task: 16 +eval_name: image_classification_frozen +folder: /your_folder/evals_2_1/vitg-384/in1k +mem_per_gpu: 220G +nodes: 16 +resume_checkpoint: true +tag: in1k-vitg16-384-18f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_name: ImageNet + num_classes: 1000 + root_path: /your_data_path/ImageNet_FullSize/240712/061417/ + resolution: 384 + optimization: + batch_size: 8 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.0 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 
0.0 + final_weight_decay: 0.001 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + num_epochs: 20 + use_bfloat16: true +model_kwargs: + checkpoint: 
/your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.image_classification_frozen.modelcustom.vit_encoder + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_giant_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + img_as_video_nframes: 18 diff --git a/configs/eval_2_1/vitg-384/jester.yaml b/configs/eval_2_1/vitg-384/jester.yaml new file mode 100644 index 00000000..78cb2aac --- /dev/null +++ b/configs/eval_2_1/vitg-384/jester.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: jester-vitg16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitg-384/jester +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv + dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv + num_classes: 27 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [9, 19, 29, 39] + pretrain_kwargs: + encoder: + model_name: vit_giant_xformers + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitg-384/k400.yaml b/configs/eval_2_1/vitg-384/k400.yaml new file mode 100644 index 00000000..3cc6677f --- /dev/null +++ b/configs/eval_2_1/vitg-384/k400.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitg-384/k400 +mem_per_gpu: 220G +nodes: 32 +num_workers: 8 +resume_checkpoint: true +tag: k400-vitg16-384-16x8x3-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/k400_train_paths.csv + dataset_val: /your_data_path/k400_val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 400 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 1 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 
0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_giant_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitg-384/ssv2.yaml 
b/configs/eval_2_1/vitg-384/ssv2.yaml new file mode 100644 index 00000000..97c85f24 --- /dev/null +++ b/configs/eval_2_1/vitg-384/ssv2.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitg-384/ssv2 +mem_per_gpu: 220G +nodes: 16 +num_workers: 8 +resume_checkpoint: true +tag: ssv2-vitg16-384-64x2x3 +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/ssv2_train_paths.csv + dataset_val: /your_data_path/ssv2_val_paths.csv + frame_step: 2 + frames_per_clip: 64 + num_classes: 174 + num_segments: 2 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 1 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + 
final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitg_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: target_encoder + img_temporal_dim_size: 1 + model_name: vit_giant_xformers + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitl-384/coin.yaml b/configs/eval_2_1/vitl-384/coin.yaml new file mode 100644 index 00000000..d4e68613 --- /dev/null +++ b/configs/eval_2_1/vitl-384/coin.yaml @@ -0,0 +1,163 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitl-384/coin +mem_per_gpu: 220G +nodes: 8 +resume_checkpoint: true +tag: coin-vitl16-384-16x8x3 +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: 
/your_data_folder/COIN/train_paths.csv + dataset_val: /your_data_folder/COIN/val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 180 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 2 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + 
weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_large + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitl-384/diving48.yaml b/configs/eval_2_1/vitl-384/diving48.yaml new file mode 100644 index 00000000..144846bc --- /dev/null +++ b/configs/eval_2_1/vitl-384/diving48.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: diving48-vitl16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitl-384/diving48 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv + dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv + num_classes: 48 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [5, 11, 17, 23] + pretrain_kwargs: + encoder: + model_name: vit_large + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitl-384/ek100.yaml b/configs/eval_2_1/vitl-384/ek100.yaml new file mode 100644 index 00000000..c57e4126 --- /dev/null +++ b/configs/eval_2_1/vitl-384/ek100.yaml @@ -0,0 +1,202 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +tag: ek100-vitl16-384 +eval_name: action_anticipation_frozen +folder: /your_folder/evals_2_1/vitl-384/ek100 +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + anticipation_time_sec: + - 1.0 + - 1.0 + auto_augment: true + file_format: 0 + dataset: EK100 + # e.g. /home/username/EPIC-KITCHENS + base_path: /your_ek100_root_dir/ + dataset_train: /your_data/EPIC_100_train.csv + dataset_val: /your_data/EPIC_100_validation.csv + frames_per_clip: 32 + frames_per_second: 8 + motion_shift: false + num_workers: 2 + pin_memory: true + random_resize_scale: + - 0.08 + - 1.0 + reprob: 0.25 + resolution: 384 + train_anticipation_point: + - 0.0 + - 0.25 + train_anticipation_time_sec: + - 0.25 + - 1.75 + optimization: + num_epochs: 20 + batch_size: 2 + use_bfloat16: true + use_focal_loss: true + multihead_kwargs: + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.0001 + final_weight_decay: 0.0001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.001 + final_weight_decay: 0.001 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.01 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. + + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. 
+ - weight_decay: 0.1 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar + wrapper_kwargs: + no_predictor: false + num_output_frames: 2 + num_steps: 1 + pretrain_kwargs: + use_v2_1: true + encoder: + model_name: vit_large + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true + predictor: + model_name: vit_predictor + checkpoint_key: predictor + num_frames: 64 + depth: 12 + num_heads: 12 + predictor_embed_dim: 384 + teacher_embed_dim: 1664 + num_mask_tokens: 8 + n_output_distillation: 1 + return_all_tokens: true + img_temporal_dim_size: 1 + uniform_power: true + use_mask_tokens: true + use_sdpa: true + use_silu: false + wide_silu: false + use_rope: true diff --git a/configs/eval_2_1/vitl-384/in1k.yaml b/configs/eval_2_1/vitl-384/in1k.yaml new file mode 100644 index 00000000..408d466a --- /dev/null +++ b/configs/eval_2_1/vitl-384/in1k.yaml @@ -0,0 +1,162 @@ +cpus_per_task: 16 +eval_name: image_classification_frozen +folder: /your_folder/evals_2_1/vitl-384/in1k +mem_per_gpu: 220G +nodes: 8 +resume_checkpoint: true +tag: in1k-vitl16-384-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_name: ImageNet + num_classes: 1000 + root_path: /your_data_path/ImageNet_FullSize/240712/061417/ + resolution: 384 + optimization: + batch_size: 16 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.0 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + 
lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0005 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.001 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.0015 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + - final_lr: 0.0 + final_weight_decay: 0.008 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.008 + - final_lr: 0.0 + final_weight_decay: 0.004 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.004 + - final_lr: 0.0 + final_weight_decay: 0.002 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.002 + - final_lr: 0.0 + final_weight_decay: 0.001 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.001 + - final_lr: 0.0 + final_weight_decay: 0.0005 + lr: 0.002 + start_lr: 0.0002 + warmup: 5 + weight_decay: 0.0005 + num_epochs: 20 + use_bfloat16: 
true +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.image_classification_frozen.modelcustom.vit_encoder + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_large + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + img_as_video_nframes: 16 diff --git a/configs/eval_2_1/vitl-384/jester.yaml b/configs/eval_2_1/vitl-384/jester.yaml new file mode 100644 index 00000000..e00ca607 --- /dev/null +++ b/configs/eval_2_1/vitl-384/jester.yaml @@ -0,0 +1,62 @@ +nodes: 8 +tasks_per_node: 8 +cpus_per_task: 12 +mem_per_gpu: 200G +tag: jester-vitl16-384-32x4x3 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitl-384/jester +resume_checkpoint: true +experiment: + classifier: + num_probe_blocks: 4 + num_heads: 16 + data: + dataset_type: VideoDataset + dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv + dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv + num_classes: 27 + resolution: 384 + frames_per_clip: 32 + frame_step: 2 + num_segments: 4 + num_views_per_segment: 3 + optimization: + use_pos_embed: false + num_epochs: 100 + batch_size: 2 + use_bfloat16: true + multihead_kwargs: + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + final_lr: 0.0 + warmup: 0. + - weight_decay: 0.8 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + final_lr: 0.0 + warmup: 0. 
+model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false + out_layers: [5, 11, 17, 23] + pretrain_kwargs: + encoder: + model_name: vit_large + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + tubelet_size: 2 + patch_size: 16 + uniform_power: true + use_rope: true diff --git a/configs/eval_2_1/vitl-384/k400.yaml b/configs/eval_2_1/vitl-384/k400.yaml new file mode 100644 index 00000000..5cd8bd67 --- /dev/null +++ b/configs/eval_2_1/vitl-384/k400.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitl-384/k400 +mem_per_gpu: 220G +nodes: 8 +num_workers: 8 +resume_checkpoint: true +tag: k400-vitl16-384-16x8x3-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/k400_train_paths.csv + dataset_val: /your_data_path/k400_val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 400 + num_segments: 8 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 4 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 
0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_large + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/eval_2_1/vitl-384/ssv2.yaml 
b/configs/eval_2_1/vitl-384/ssv2.yaml new file mode 100644 index 00000000..38056e1d --- /dev/null +++ b/configs/eval_2_1/vitl-384/ssv2.yaml @@ -0,0 +1,164 @@ +cpus_per_task: 16 +eval_name: video_classification_frozen +folder: /your_folder/evals_2_1/vitl-384/ssv2 +mem_per_gpu: 220G +nodes: 8 +max_workers: 8 +resume_checkpoint: true +tag: ssv2-vitl16-384-16x2x3-16f +tasks_per_node: 8 +experiment: + classifier: + num_heads: 16 + num_probe_blocks: 4 + data: + dataset_type: VideoDataset + dataset_train: /your_data_path/ssv2_train_paths.csv + dataset_val: /your_data_path/ssv2_val_paths.csv + frame_step: 4 + frames_per_clip: 16 + num_classes: 174 + num_segments: 2 + num_views_per_segment: 3 + resolution: 384 + optimization: + batch_size: 4 + multihead_kwargs: + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.01 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.01 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.1 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.1 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + 
final_weight_decay: 0.4 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.4 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.4 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.005 + start_lr: 0.005 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.003 + start_lr: 0.003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.001 + start_lr: 0.001 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0003 + start_lr: 0.0003 + warmup: 0.0 + weight_decay: 0.8 + - final_lr: 0.0 + final_weight_decay: 0.8 + lr: 0.0001 + start_lr: 0.0001 + warmup: 0.0 + weight_decay: 0.8 + num_epochs: 20 + use_bfloat16: true + use_pos_embed: false +model_kwargs: + checkpoint: /your_vjepa2_1_checkpoints/vjepa2_1_vitl_dist_vitG_384.pt + module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip + pretrain_kwargs: + encoder: + checkpoint_key: ema_encoder + img_temporal_dim_size: 1 + model_name: vit_large + patch_size: 16 + tubelet_size: 2 + uniform_power: true + use_rope: true + wrapper_kwargs: + max_frames: 128 + use_pos_embed: false diff --git a/configs/train_2_1/vitG16/cooldown-256px-64f.yaml b/configs/train_2_1/vitG16/cooldown-256px-64f.yaml new file mode 100644 index 00000000..eca1e592 --- /dev/null +++ b/configs/train_2_1/vitG16/cooldown-256px-64f.yaml @@ -0,0 +1,144 @@ +app: vjepa_2_1 +nodes: 64 +tasks_per_node: 8 +cpus_per_task: 64 +mem_per_gpu: 220G +folder: /your_folder/cooldown_2_1/16.8.vitG.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 
+ - 0.100 + - 0.565 + batch_size: 6 + crop_size: 256 + dataset_fpcs: + - 64 + - 64 + - 64 + fps: 4 + num_workers: 12 + patch_size: 16 + persistent_workers: true + pin_mem: false + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 + +img_data: + batch_size: 36 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +loss: + computing_gram_loss: false + gram_HighRes: false + gram_ckpt: None + gram_loss_weight: 10.0 + loss_exp: 1.0 + predict_all: true + shift_by_n: 0 + weight_distance_loss: false +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + lambda_progressive: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + modality_embedding: true + model_name: vit_gigantic_xformers + n_registers: 0 + normalize_predictor: false + pred_depth: 24 + pred_embed_dim: 384 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + zero_init_mask_tokens: true +optimization: + anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitG.256px.16f/latest.pth.tar + ema: + - 
0.99925 + - 0.99925 + epochs: 40 + final_lr: 1.0e-06 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + is_anneal: true + lr: 0.0006 + resume_anneal: true + start_lr: 0.0001 + warmup: 0 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitG16/pretrain-256px-16f.yaml b/configs/train_2_1/vitG16/pretrain-256px-16f.yaml new file mode 100644 index 00000000..288c045f --- /dev/null +++ b/configs/train_2_1/vitG16/pretrain-256px-16f.yaml @@ -0,0 +1,149 @@ +app: vjepa_2_1 +nodes: 32 +tasks_per_node: 8 +cpus_per_task: 32 +mem_per_gpu: 220G +folder: /your_folder/pretrain_2_1/16.8.vitG.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 12 + crop_size: 256 + patch_size: 16 + dataset_fpcs: + - 16 + - 16 + - 16 + tubelet_size: 2 + fps: 4 + num_workers: 8 + persistent_workers: true + pin_mem: true +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +loss: + loss_exp: 1.0 + predict_all: true + reg_coeff: 0.0 + shift_by_n: 0 + weight_distance_loss: true + +img_data: + batch_size: 36 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 
+meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + is_causal: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + local_window: + - -1 + - -1 + - -1 + modality_embedding: true + model_name: vit_gigantic_xformers + n_registers: 0 + n_registers_predictor: 0 + normalize_predictor: false + pred_depth: 24 + pred_embed_dim: 384 + pred_is_causal: false + pred_local_window: + - -1 + - -1 + - -1 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + vit_conv: false + zero_init_mask_tokens: true +optimization: + ema: + - 0.99925 + - 0.99925 + epochs: 1000 + final_lr: 0.0006 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + lr: 0.0006 + start_lr: 0.0001 + warmup: 40 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitb16/cooldown-256px-64f.yaml b/configs/train_2_1/vitb16/cooldown-256px-64f.yaml new file mode 100644 index 00000000..343ae9ba --- /dev/null +++ b/configs/train_2_1/vitb16/cooldown-256px-64f.yaml @@ -0,0 +1,144 @@ +app: vjepa_2_1 +nodes: 16 +tasks_per_node: 8 +cpus_per_task: 16 +mem_per_gpu: 220G +folder: /your_folder/cooldown_2_1/16.8.vitb.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 24 + crop_size: 256 + dataset_fpcs: + - 64 + - 64 + - 64 + fps: 4 + num_workers: 12 + patch_size: 16 + persistent_workers: true + pin_mem: false + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 + +img_data: + batch_size: 144 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv 
+ datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +loss: + computing_gram_loss: false + gram_HighRes: false + gram_ckpt: None + gram_loss_weight: 10.0 + loss_exp: 1.0 + predict_all: true + shift_by_n: 0 + weight_distance_loss: false +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + lambda_progressive: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + modality_embedding: true + model_name: vit_base + n_registers: 0 + normalize_predictor: false + pred_depth: 12 + pred_embed_dim: 384 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + zero_init_mask_tokens: true +optimization: + anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitb.256px.16f/latest.pth.tar + ema: + - 0.99925 + - 0.99925 + epochs: 40 + final_lr: 1.0e-06 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + is_anneal: true + lr: 0.0006 + resume_anneal: true + start_lr: 0.0001 + warmup: 0 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitb16/pretrain-256px-16f.yaml b/configs/train_2_1/vitb16/pretrain-256px-16f.yaml new file mode 100644 index 00000000..dca536f3 --- /dev/null +++ b/configs/train_2_1/vitb16/pretrain-256px-16f.yaml @@ -0,0 +1,149 @@ +app: vjepa_2_1 +nodes: 8 +tasks_per_node: 8 
+cpus_per_task: 8 +mem_per_gpu: 220G +folder: /your_folder/pretrain_2_1/16.8.vitb.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 48 + crop_size: 256 + patch_size: 16 + dataset_fpcs: + - 16 + - 16 + - 16 + tubelet_size: 2 + fps: 4 + num_workers: 8 + persistent_workers: true + pin_mem: true +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +loss: + loss_exp: 1.0 + predict_all: true + reg_coeff: 0.0 + shift_by_n: 0 + weight_distance_loss: true + +img_data: + batch_size: 144 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + is_causal: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + local_window: + - -1 + - -1 + - -1 + modality_embedding: true + model_name: vit_base + n_registers: 0 + n_registers_predictor: 0 + normalize_predictor: false + pred_depth: 12 + pred_embed_dim: 384 + pred_is_causal: 
false + pred_local_window: + - -1 + - -1 + - -1 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + vit_conv: false + zero_init_mask_tokens: true +optimization: + ema: + - 0.99925 + - 0.99925 + epochs: 1000 + final_lr: 0.0006 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + lr: 0.0006 + start_lr: 0.0001 + warmup: 40 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitg16/cooldown-256px-64f.yaml b/configs/train_2_1/vitg16/cooldown-256px-64f.yaml new file mode 100644 index 00000000..c7a40962 --- /dev/null +++ b/configs/train_2_1/vitg16/cooldown-256px-64f.yaml @@ -0,0 +1,144 @@ +app: vjepa_2_1 +nodes: 32 +tasks_per_node: 8 +cpus_per_task: 32 +mem_per_gpu: 220G +folder: /your_folder/cooldown_2_1/16.8.vitg.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 12 + crop_size: 256 + dataset_fpcs: + - 64 + - 64 + - 64 + fps: 4 + num_workers: 12 + patch_size: 16 + persistent_workers: true + pin_mem: false + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 + +img_data: + batch_size: 72 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +loss: + computing_gram_loss: false + gram_HighRes: false + gram_ckpt: None + gram_loss_weight: 10.0 + loss_exp: 1.0 + predict_all: true + shift_by_n: 0 + weight_distance_loss: false +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + 
max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + lambda_progressive: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + modality_embedding: true + model_name: vit_giant_xformers + n_registers: 0 + normalize_predictor: false + pred_depth: 24 + pred_embed_dim: 384 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + zero_init_mask_tokens: true +optimization: + anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitg.256px.16f/latest.pth.tar + ema: + - 0.99925 + - 0.99925 + epochs: 40 + final_lr: 1.0e-06 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + is_anneal: true + lr: 0.0006 + resume_anneal: true + start_lr: 0.0001 + warmup: 0 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitg16/pretrain-256px-16f.yaml b/configs/train_2_1/vitg16/pretrain-256px-16f.yaml new file mode 100644 index 00000000..170cd2f5 --- /dev/null +++ b/configs/train_2_1/vitg16/pretrain-256px-16f.yaml @@ -0,0 +1,149 @@ +app: vjepa_2_1 +nodes: 16 +tasks_per_node: 8 +cpus_per_task: 16 +mem_per_gpu: 220G +folder: /your_folder/pretrain_2_1/16.8.vitg.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 24 + crop_size: 256 + patch_size: 16 + dataset_fpcs: + - 16 + - 16 + - 16 + tubelet_size: 2 + fps: 4 + num_workers: 8 + persistent_workers: true + pin_mem: true +data_aug: + auto_augment: false + 
motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +loss: + loss_exp: 1.0 + predict_all: true + reg_coeff: 0.0 + shift_by_n: 0 + weight_distance_loss: true + +img_data: + batch_size: 72 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + is_causal: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + local_window: + - -1 + - -1 + - -1 + modality_embedding: true + model_name: vit_giant_xformers + n_registers: 0 + n_registers_predictor: 0 + normalize_predictor: false + pred_depth: 24 + pred_embed_dim: 384 + pred_is_causal: false + pred_local_window: + - -1 + - -1 + - -1 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + vit_conv: false + zero_init_mask_tokens: true +optimization: + ema: + - 0.99925 + - 0.99925 + epochs: 1000 + final_lr: 0.0006 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + lr: 0.0006 + start_lr: 0.0001 + warmup: 40 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitl16/cooldown-256px-64f.yaml 
b/configs/train_2_1/vitl16/cooldown-256px-64f.yaml new file mode 100644 index 00000000..73308856 --- /dev/null +++ b/configs/train_2_1/vitl16/cooldown-256px-64f.yaml @@ -0,0 +1,144 @@ +app: vjepa_2_1 +nodes: 32 +tasks_per_node: 8 +cpus_per_task: 32 +mem_per_gpu: 220G +folder: /your_folder/cooldown_2_1/16.8.vitl.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 12 + crop_size: 256 + dataset_fpcs: + - 64 + - 64 + - 64 + fps: 4 + num_workers: 12 + patch_size: 16 + persistent_workers: true + pin_mem: false + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 + +img_data: + batch_size: 72 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +loss: + computing_gram_loss: false + gram_HighRes: false + gram_ckpt: None + gram_loss_weight: 10.0 + loss_exp: 1.0 + predict_all: true + shift_by_n: 0 + weight_distance_loss: false +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 
1 + interpolate_rope: true + lambda_progressive: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + modality_embedding: true + model_name: vit_large + n_registers: 0 + normalize_predictor: false + pred_depth: 24 + pred_embed_dim: 384 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + zero_init_mask_tokens: true +optimization: + anneal_ckpt: /your_folder/pretrain_2_1/16.8.vitl.256px.16f/latest.pth.tar + ema: + - 0.99925 + - 0.99925 + epochs: 40 + final_lr: 1.0e-06 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + is_anneal: true + lr: 0.0006 + resume_anneal: true + start_lr: 0.0001 + warmup: 0 + weight_decay: 0.04 diff --git a/configs/train_2_1/vitl16/pretrain-256px-16f.yaml b/configs/train_2_1/vitl16/pretrain-256px-16f.yaml new file mode 100644 index 00000000..3988cf9a --- /dev/null +++ b/configs/train_2_1/vitl16/pretrain-256px-16f.yaml @@ -0,0 +1,149 @@ +app: vjepa_2_1 +nodes: 16 +tasks_per_node: 8 +cpus_per_task: 16 +mem_per_gpu: 220G +folder: /your_folder/pretrain_2_1/16.8.vitl.256px.16f +data: + dataset_type: VideoDataset + datasets: + - /your_data/k710_train_paths.csv + - /your_data/ssv2_train_paths.csv + - /your_data/howto_320p.csv + datasets_weights: + - 0.335 + - 0.100 + - 0.565 + batch_size: 24 + crop_size: 256 + patch_size: 16 + dataset_fpcs: + - 16 + - 16 + - 16 + tubelet_size: 2 + fps: 4 + num_workers: 8 + persistent_workers: true + pin_mem: true +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +loss: + loss_exp: 1.0 + predict_all: true + reg_coeff: 0.0 + shift_by_n: 0 + weight_distance_loss: true + +img_data: + batch_size: 72 + crop_size: 256 + dataset_fpcs: + - 1 + dataset_type: VideoDataset + datasets: + - /your_data/imagenet1k.csv + datasets_weights: + - 1.0 + rank_ratio: 0.5 +img_mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null 
+ max_temporal_keep: 1.0 + num_blocks: 10 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 + +mask: +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + full_complement: false + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: null + save_every_freq: 50 + seed: 239 + use_sdpa: true +model: + has_cls_first: false + img_temporal_dim_size: 1 + interpolate_rope: true + is_causal: false + lambda_value_img: 0.5 + lambda_value_vid: 0.5 + local_window: + - -1 + - -1 + - -1 + modality_embedding: true + model_name: vit_large + n_registers: 0 + n_registers_predictor: 0 + normalize_predictor: false + pred_depth: 24 + pred_embed_dim: 384 + pred_is_causal: false + pred_local_window: + - -1 + - -1 + - -1 + pred_num_heads: 12 + uniform_power: true + use_activation_checkpointing: true + use_mask_tokens: true + use_rope: true + vit_conv: false + zero_init_mask_tokens: true +optimization: + ema: + - 0.99925 + - 0.99925 + epochs: 1000 + final_lr: 0.0006 + final_weight_decay: 0.04 + ipe: 300 + ipe_scale: 1.25 + lr: 0.0006 + start_lr: 0.0001 + warmup: 40 + weight_decay: 0.04 diff --git a/evals/action_anticipation_frozen/eval.py b/evals/action_anticipation_frozen/eval.py index 9a15ee04..5dcce775 100644 --- a/evals/action_anticipation_frozen/eval.py +++ b/evals/action_anticipation_frozen/eval.py @@ -743,7 +743,7 @@ def load_checkpoint(device, r_path, classifiers, opt, scaler, val_only=False): [o.load_state_dict(c) for o, c in zip(opt, checkpoint["opt"])] if scaler is not None: - [s.load_state_dict(c) for s, c in zip(opt, checkpoint["scaler"])] + [s.load_state_dict(c) for s, c in zip(scaler, checkpoint["scaler"])] logger.info(f"loaded optimizers 
from epoch {epoch}") return classifiers, opt, scaler, epoch diff --git a/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py b/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py index 90424702..975a908b 100644 --- a/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py +++ b/evals/action_anticipation_frozen/modelcustom/vit_encoder_predictor_concat_ar.py @@ -26,14 +26,23 @@ import torch -import src.models.predictor as vit_pred -import src.models.vision_transformer as vit - logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) +def _get_model_modules(pretrain_kwargs): + """Import encoder/predictor modules based on config.""" + use_v2_1 = pretrain_kwargs.get("use_v2_1", False) + if use_v2_1: + import app.vjepa_2_1.models.predictor as vit_pred + import app.vjepa_2_1.models.vision_transformer as vit + else: + import src.models.predictor as vit_pred + import src.models.vision_transformer as vit + return vit, vit_pred + + def init_module( frames_per_clip: int, frames_per_second: int, @@ -50,6 +59,8 @@ def init_module( # ----------------------------------------------------------------------- # # Initialize Encoder # ----------------------------------------------------------------------- # + vit, vit_pred = _get_model_modules(model_kwargs) + enc_kwargs = model_kwargs["encoder"] enc_ckp_key = enc_kwargs.get("checkpoint_key") enc_model_name = enc_kwargs.get("model_name") @@ -76,13 +87,21 @@ def init_module( prd_ckp_key = prd_kwargs.get("checkpoint_key") prd_model_name = prd_kwargs.get("model_name") + teacher_embed_dim = prd_kwargs.get("teacher_embed_dim") + n_output_distillation = prd_kwargs.get("n_output_distillation", 4) + prd_out_embed_dim = teacher_embed_dim // n_output_distillation if teacher_embed_dim is not None else None + print(f"[DEBUG] teacher_embed_dim={teacher_embed_dim}, n_output_distillation={n_output_distillation}, 
prd_out_embed_dim={prd_out_embed_dim}") + print(f"[DEBUG] vit_pred module: {vit_pred.__name__}") + predictor = vit_pred.__dict__[prd_model_name]( img_size=resolution, embed_dim=encoder.embed_dim, patch_size=encoder.patch_size, tubelet_size=encoder.tubelet_size, + out_embed_dim=prd_out_embed_dim, **prd_kwargs, ) + print(f"[DEBUG] predictor_proj: {predictor.predictor_proj}") pretrained_dict = checkpoint[prd_ckp_key] # -- pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} @@ -92,7 +111,7 @@ def init_module( logger.info(f'key "{k}" could not be found in loaded state dict') elif pretrained_dict[k].shape != v.shape: logger.info( - f'key "{k}" is of different shape in model and loaded state dict', pretrained_dict[k].shape, v.shape + f'key "{k}" is of different shape in model and loaded state dict: {pretrained_dict[k].shape} vs {v.shape}' ) pretrained_dict[k] = v msg = predictor.load_state_dict(pretrained_dict, strict=False) @@ -113,6 +132,10 @@ def init_module( ) model.embed_dim = encoder.embed_dim + # Enable hierarchical feature output for non-distilled predictors + if hasattr(predictor, 'hierarchical_layers') and len(predictor.hierarchical_layers) > 1: + encoder.return_hierarchical = True + return model @@ -151,20 +174,24 @@ def forward(self, x, anticipation_times): :param x: (Tensor) video of shape [B, C, T, H, W] :param anticipation_time: (Tensor) [B] seconds into the future to predict for each sample in batch """ - x = self.encoder(x) + x_full = self.encoder(x) - # determine 1D position of context tokens (x) - # determine 1D position of prediction tokens - # forward predictor with ctxt=x, tgt=None, masks_ctxt, masks_tgt if self.no_predictor: - return x + return x_full + + B, N, D_full = x_full.size() + embed_dim = self.encoder.embed_dim + use_hierarchical = D_full > embed_dim - # Will output representations of $num_output_frames, that are - # $anticipation_time seconds into the future. 
- B, N, D = x.size() + # For the accumulator/classifier, use last-layer features (embed_dim) + # For the predictor, use full hierarchical features (D_full) + if use_hierarchical: + x = x_full[:, :, -embed_dim:] + else: + x = x_full if self.no_encoder: - x_accumulate = torch.rand(x.size(0), 0, x.size(2)).to(x.device) + x_accumulate = torch.rand(B, 0, embed_dim).to(x.device) else: x_accumulate = x.clone() @@ -180,9 +207,18 @@ def forward(self, x, anticipation_times): tgt_positions = torch.arange(N_pred).unsqueeze(0).repeat(B, 1).to(x.device) tgt_positions += skip_positions.unsqueeze(1).repeat(1, N_pred) + x_pred_input = x_full for _ in range(self.num_steps): - x_pred = self.predictor(x, masks_x=ctxt_positions, masks_y=tgt_positions) + pred_out = self.predictor(x_pred_input, masks_x=ctxt_positions, masks_y=tgt_positions) + x_pred_full = pred_out[0] if isinstance(pred_out, tuple) else pred_out + + if x_pred_full.size(-1) != embed_dim: + x_pred = x_pred_full[:, :, -embed_dim:] + else: + x_pred = x_pred_full + x_accumulate = torch.cat([x_accumulate, x_pred], dim=1) - x = torch.cat([x[:, N_pred:, :], x_pred], dim=1) + x_pred_for_input = x_pred_full if x_pred_full.size(-1) == x_pred_input.size(-1) else x_pred + x_pred_input = torch.cat([x_pred_input[:, N_pred:, :], x_pred_for_input], dim=1) return x_accumulate diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index df0bf958..9017cd25 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -81,7 +81,6 @@ def main(args_eval, resume_preempt=False): dataset_name = args_data.get("dataset_name") num_classes = args_data.get("num_classes") root_path = args_data.get("root_path", None) - image_folder = args_data.get("image_folder", None) resolution = args_data.get("resolution", 224) normalization = args_data.get("normalization", None) @@ -157,7 +156,6 @@ def main(args_eval, resume_preempt=False): dataset_name=dataset_name, 
root_path=root_path, img_size=resolution, - image_folder=image_folder, batch_size=batch_size, world_size=world_size, rank=rank, @@ -168,7 +166,6 @@ def main(args_eval, resume_preempt=False): dataset_name=dataset_name, root_path=root_path, img_size=resolution, - image_folder=image_folder, batch_size=batch_size, world_size=world_size, rank=rank, @@ -354,7 +351,6 @@ def load_checkpoint(device, r_path, classifiers, opt, scaler, val_only=False): def make_dataloader( dataset_name, root_path, - image_folder, batch_size, world_size, rank, @@ -396,7 +392,6 @@ def make_dataloader( world_size=world_size, rank=rank, root_path=root_path, - image_folder=image_folder, training=training, drop_last=False, subset_file=subset_file, diff --git a/hubconf.py b/hubconf.py index 1aa48b64..39f303d5 100644 --- a/hubconf.py +++ b/hubconf.py @@ -10,6 +10,10 @@ vjepa2_vit_giant_384, vjepa2_vit_huge, vjepa2_vit_large, + vjepa2_1_vit_base_384, + vjepa2_1_vit_large_384, + vjepa2_1_vit_giant_384, + vjepa2_1_vit_gigantic_384, ) dependencies = ["torch", "timm", "einops"] diff --git a/setup.py b/setup.py index 35af1634..da9c84c1 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ from setuptools import setup NAME = "vjepa2" -VERSION = "0.0.1" +VERSION = "0.0.2" DESCRIPTION = "PyTorch code and models for V-JEPA 2." 
URL = "https://github.com/facebookresearch/vjepa2" diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py index e3548639..d4e1a534 100644 --- a/src/datasets/data_manager.py +++ b/src/datasets/data_manager.py @@ -20,7 +20,6 @@ def init_data( world_size=1, rank=0, root_path=None, - image_folder=None, training=True, drop_last=True, subset_file=None, @@ -52,7 +51,6 @@ def init_data( world_size=world_size, rank=rank, root_path=root_path, - image_folder=image_folder, persistent_workers=persistent_workers, drop_last=drop_last, subset_file=subset_file, diff --git a/src/datasets/imagenet1k.py b/src/datasets/imagenet1k.py index 3597e3a1..55603632 100644 --- a/src/datasets/imagenet1k.py +++ b/src/datasets/imagenet1k.py @@ -21,7 +21,6 @@ class ImageNet(torchvision.datasets.ImageFolder): def __init__( self, root, - image_folder="imagenet_full_size/061417/", tar_file="imagenet_full_size-061417.tar.gz", transform=None, train=True, @@ -35,7 +34,6 @@ def __init__( Dataset wrapper :param root: root network directory for ImageNet data - :param image_folder: path to images inside root network directory :param tar_file: zipped image_folder inside root network directory :param train: whether to load train data (or validation) :param job_id: scheduler job-id used to create dir on local machine @@ -43,7 +41,7 @@ def __init__( """ suffix = "train/" if train else "val/" - data_path = os.path.join(root, image_folder, suffix) + data_path = os.path.join(root, suffix) logger.info(f"data-path {data_path}") super(ImageNet, self).__init__(root=data_path, transform=transform) @@ -120,7 +118,6 @@ def make_imagenet1k( world_size=1, rank=0, root_path=None, - image_folder=None, training=True, drop_last=True, persistent_workers=False, @@ -128,7 +125,6 @@ def make_imagenet1k( ): dataset = ImageNet( root=root_path, - image_folder=image_folder, transform=transform, train=training, index_targets=False, @@ -136,7 +132,9 @@ def make_imagenet1k( if subset_file is not None: dataset = 
ImageNetSubset(dataset, subset_file) logger.info("ImageNet dataset created") - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank) + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, num_replicas=world_size, rank=rank + ) data_loader = torch.utils.data.DataLoader( dataset, collate_fn=collator, diff --git a/src/datasets/utils/weighted_sampler.py b/src/datasets/utils/weighted_sampler.py index 9b4105fa..63a67b4e 100644 --- a/src/datasets/utils/weighted_sampler.py +++ b/src/datasets/utils/weighted_sampler.py @@ -8,9 +8,9 @@ import numpy as np import torch -from torch.utils.data import DistributedSampler, RandomSampler from src.utils.logging import get_logger +from torch.utils.data import DistributedSampler, RandomSampler logger = get_logger("WeightedSampler") @@ -35,7 +35,9 @@ def __init__( seed: int = 0, drop_last: bool = False, ): - logger.info(f"Using DistributedWeightedSampler with rank {rank} / {num_replicas}") + logger.info( + f"Using DistributedWeightedSampler with rank {rank} / {num_replicas}" + ) assert hasattr( dataset, "sample_weights" ), "Dataset must have sample_weights property for using DistributedWeightedSampler" @@ -78,7 +80,9 @@ def __iter__(self) -> Iterator: if padding_size <= len(indices): indices += indices[:padding_size] else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] else: # remove tail of data to make it evenly divisible indices = indices[: self.total_size] @@ -106,7 +110,9 @@ def __init__( shuffle: bool = True, seed: int = 0, ): - logger.info(f"Using MemoryEfficientDistributedWeightedSampler with rank {rank} / {num_replicas}") + logger.info( + f"Using MemoryEfficientDistributedWeightedSampler with rank {rank} / {num_replicas}" + ) assert hasattr( dataset, "dataset_weights" ), "Dataset must have dataset_weights property for using 
MemoryEfficientDistributedWeightedSampler" @@ -129,10 +135,14 @@ def __init__( if self.shuffle: self.rng = np.random.default_rng(self.seed + self.rank + self.epoch) total_weights = sum(self.dataset_weights) - self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights]) + self.dataset_probablities = np.array( + [w / total_weights for w in self.dataset_weights] + ) else: if any([not isinstance(w, int) for w in self.dataset_weights]): - raise ValueError("Dataset weights must be integers when shuffle is False") + raise ValueError( + "Dataset weights must be integers when shuffle is False" + ) self.dataset_orders = [] for i, w in enumerate(self.dataset_weights): @@ -145,7 +155,9 @@ def __iter__(self) -> Iterator: def __next__(self) -> int: if self.shuffle: - selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities) + selected_dataset_idx = self.rng.choice( + range(len(self.dataset_weights)), p=self.dataset_probablities + ) # In order to avoid sampling the same example multiple times between the ranks, # we limit each rank to a subset of the total number of samples in the dataset. 
@@ -171,12 +183,14 @@ def __next__(self) -> int: else: # Iterate through the dataset orders in a round-robin fashion, offset by the rank - dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders) + dataset_orders_idx = (self.rank + self.drawn_samples) % len( + self.dataset_orders + ) selected_dataset_idx = self.dataset_orders[dataset_orders_idx] # Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples - sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[ - selected_dataset_idx - ] + sample_idx_in_dataset = ( + self.drawn_samples * self.num_replicas + self.rank + ) % self.dataset_sizes[selected_dataset_idx] self.drawn_samples += 1 # Getting the index of the sample in the whole dataset @@ -224,7 +238,9 @@ def __init__( shuffle: bool = True, seed: int = 0, ): - logger.info(f"Using MemoryEfficientDistributedWeightedSamplerLessRepeat with rank {rank} / {num_replicas}") + logger.info( + f"Using MemoryEfficientDistributedWeightedSamplerLessRepeat with rank {rank} / {num_replicas}" + ) assert hasattr( dataset, "dataset_weights" ), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSamplerLessRepeat" @@ -250,11 +266,15 @@ def __init__( if self.shuffle: self.rng = np.random.default_rng(self.seed + self.rank + self.epoch) total_weights = sum(self.dataset_weights) - self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights]) + self.dataset_probablities = np.array( + [w / total_weights for w in self.dataset_weights] + ) # For each dataset we generate a permutation of the indices that will be processed by that rank. # This is going to be the subset of indices, selected by the steps sizes of the world size. 
- logger.info(f"Generating dataset indices for rank {self.rank} / {self.num_replicas}") + logger.info( + f"Generating dataset indices for rank {self.rank} / {self.num_replicas}" + ) # Getting a RandomSampler for indices assigned to each dataset. self.individual_dataset_sampler = [] @@ -263,11 +283,15 @@ def __init__( # NOTE: this may effectively drop the last batch, # but given the sample sizes that we use this sampler with, it should not be an issue. num_samples_in_rank = ds // self.num_replicas - self.individual_dataset_sampler.append(self._new_sampler(num_samples_in_rank)) + self.individual_dataset_sampler.append( + self._new_sampler(num_samples_in_rank) + ) else: if any([not isinstance(w, int) for w in self.dataset_weights]): - raise ValueError("Dataset weights must be integers when shuffle is False") + raise ValueError( + "Dataset weights must be integers when shuffle is False" + ) self.dataset_orders = [] for i, w in enumerate(self.dataset_weights): @@ -295,7 +319,9 @@ def _in_rank_next_index_for_dataset(self, dataset_idx: int) -> int: if next_sampler_idx is None: # We have reached the end of the dataset, we need to reset the sampler. num_samples_in_rank = self.dataset_sizes[dataset_idx] // self.num_replicas - self.individual_dataset_sampler[dataset_idx] = self._new_sampler(num_samples_in_rank) + self.individual_dataset_sampler[dataset_idx] = self._new_sampler( + num_samples_in_rank + ) next_sampler_idx = safe_next(self.individual_dataset_sampler[dataset_idx]) assert next_sampler_idx is not None @@ -303,7 +329,9 @@ def _in_rank_next_index_for_dataset(self, dataset_idx: int) -> int: def __next__(self) -> int: if self.shuffle: - selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities) + selected_dataset_idx = self.rng.choice( + range(len(self.dataset_weights)), p=self.dataset_probablities + ) in_rank_sample = self._in_rank_next_index_for_dataset(selected_dataset_idx) # 2) Getting sample index in the dataset. 
@@ -311,12 +339,14 @@ def __next__(self) -> int: else: # Iterate through the dataset orders in a round-robin fashion, offset by the rank - dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders) + dataset_orders_idx = (self.rank + self.drawn_samples) % len( + self.dataset_orders + ) selected_dataset_idx = self.dataset_orders[dataset_orders_idx] # Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples - sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[ - selected_dataset_idx - ] + sample_idx_in_dataset = ( + self.drawn_samples * self.num_replicas + self.rank + ) % self.dataset_sizes[selected_dataset_idx] self.drawn_samples += 1 # Getting the index of the sample in the whole dataset diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index b0c47f66..05ec2a38 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -13,9 +13,13 @@ import pandas as pd import torch import torchvision -from decord import VideoReader, cpu +from decord import cpu, VideoReader -from src.datasets.utils.dataloader import ConcatIndices, MonitoredDataset, NondeterministicDataLoader +from src.datasets.utils.dataloader import ( + ConcatIndices, + MonitoredDataset, + NondeterministicDataLoader, +) from src.datasets.utils.weighted_sampler import DistributedWeightedSampler _GLOBAL_SEED = 0 @@ -79,7 +83,9 @@ def make_videodataset( logger.info("VideoDataset dataset created") if datasets_weights is not None: - dist_sampler = DistributedWeightedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True) + dist_sampler = DistributedWeightedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) else: dist_sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=True @@ -146,7 +152,9 @@ def __init__( self.fps = fps if sum([v is not None for v in (fps, duration, 
frame_step)]) != 1: - raise ValueError(f"Must specify exactly one of either {fps=}, {duration=}, or {frame_step=}.") + raise ValueError( + f"Must specify exactly one of either {fps=}, {duration=}, or {frame_step=}." + ) if isinstance(data_paths, str): data_paths = [data_paths] @@ -155,11 +163,15 @@ def __init__( self.dataset_fpcs = [frames_per_clip for _ in data_paths] else: if len(dataset_fpcs) != len(data_paths): - raise ValueError("Frames per clip not properly specified for NFS data paths") + raise ValueError( + "Frames per clip not properly specified for NFS data paths" + ) self.dataset_fpcs = dataset_fpcs if VideoReader is None: - raise ImportError('Unable to import "decord" which is required to read videos.') + raise ImportError( + 'Unable to import "decord" which is required to read videos.' + ) # Load video paths and labels samples, labels = [], [] @@ -222,7 +234,9 @@ def get_item_video(self, index): dataset_idx, _ = self.per_dataset_indices[index] frames_per_clip = self.dataset_fpcs[dataset_idx] - buffer, clip_indices = self.loadvideo_decord(sample, frames_per_clip) # [T H W 3] + buffer, clip_indices = self.loadvideo_decord( + sample, frames_per_clip + ) # [T H W 3] loaded_video = len(buffer) > 0 if not loaded_video: return @@ -251,7 +265,9 @@ def get_item_image(self, index): fpc = self.dataset_fpcs[dataset_idx] try: - image_tensor = torchvision.io.read_image(path=sample, mode=torchvision.io.ImageReadMode.RGB) + image_tensor = torchvision.io.read_image( + path=sample, mode=torchvision.io.ImageReadMode.RGB + ) except Exception: return label = self.labels[index] diff --git a/src/hub/backbones.py b/src/hub/backbones.py index 35308f86..435c72ea 100644 --- a/src/hub/backbones.py +++ b/src/hub/backbones.py @@ -5,17 +5,23 @@ import torch -VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2" +# VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2" # for testing -# VJEPA_BASE_URL = "http://localhost:8300" +VJEPA_BASE_URL = "http://localhost:8300" 
ARCH_NAME_MAP = { + # V-JEPA 2 "vit_large": ("vit_large", "vitl"), "vit_huge": ("vit_huge", "vith"), "vit_giant": ("vit_giant_xformers", "vitg"), "vit_ac_giant": ("vit_giant_xformers", "vjepa2-ac-vitg"), "vit_giant_384": ("vit_giant_xformers", "vitg-384"), + # V-JEPA 2.1 + "vjepa2_1_vit_base_384": ("vit_base", "vjepa2_1_vitb_dist_vitG_384"), + "vjepa2_1_vit_large_384": ("vit_large", "vjepa2_1_vitl_dist_vitG_384"), + "vjepa2_1_vit_giant_384": ("vit_giant_xformers", "vjepa2_1_vitg_384"), + "vjepa2_1_vit_gigantic_384": ("vit_gigantic_xformers", "vjepa2_1_vitG_384"), } @@ -38,8 +44,10 @@ def _make_vjepa2_ac_model( pretrained: bool = True, **kwargs, ): - from ..models import ac_predictor as vit_ac_predictor - from ..models import vision_transformer as vit_encoder + from ..models import ( + ac_predictor as vit_ac_predictor, + vision_transformer as vit_encoder, + ) vit_encoder_kwargs = dict( patch_size=patch_size, @@ -83,15 +91,17 @@ def _make_vjepa2_ac_model( def _make_vjepa2_model( *, model_name: str = "vit_large", + checkpoint_key="target_encoder", img_size=256, patch_size=16, tubelet_size=2, num_frames=64, + predictor_embed_dim=384, + predictor_out_embed_dim=None, pretrained: bool = True, **kwargs, ): - from ..models import predictor as vit_predictor - from ..models import vision_transformer as vit_encoder + from ..models import predictor as vit_predictor, vision_transformer as vit_encoder vit_encoder_kwargs = dict( patch_size=patch_size, @@ -114,7 +124,8 @@ def _make_vjepa2_model( patch_size=patch_size, use_mask_tokens=True, embed_dim=encoder.embed_dim, - predictor_embed_dim=384, + predictor_embed_dim=predictor_embed_dim, + out_embed_dim=predictor_out_embed_dim, num_frames=num_frames, tubelet_size=tubelet_size, depth=12, @@ -134,10 +145,14 @@ def _make_vjepa2_model( model_file = ARCH_NAME_MAP[model_name][-1] url = VJEPA_BASE_URL + f"/{model_file}.pt" state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") - encoder_state_dict = 
_clean_backbone_key(state_dict["encoder"]) - encoder.load_state_dict(encoder_state_dict, strict=False) # state_dict has pos_embed but we use RoPE + encoder_state_dict = _clean_backbone_key(state_dict[checkpoint_key]) + encoder.load_state_dict( + encoder_state_dict, strict=False + ) # state_dict has pos_embed but we use RoPE predictor_state_dict = _clean_backbone_key(state_dict["predictor"]) - predictor.load_state_dict(predictor_state_dict, strict=False) # state_dict has pos_embed but we use RoPE + predictor.load_state_dict( + predictor_state_dict, strict=False + ) # state_dict has pos_embed but we use RoPE return encoder, predictor @@ -146,32 +161,179 @@ def vjepa2_vit_large(*, pretrained: bool = True, **kwargs): """ VJEPA 2 ViT-Large model """ - return _make_vjepa2_model(model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs) + return _make_vjepa2_model( + model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs + ) def vjepa2_vit_huge(*, pretrained: bool = True, **kwargs): """ VJEPA 2 ViT-Huge model """ - return _make_vjepa2_model(model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs) + return _make_vjepa2_model( + model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs + ) def vjepa2_vit_giant(*, pretrained: bool = True, **kwargs): """ VJEPA 2 ViT-giant model """ - return _make_vjepa2_model(model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs) + return _make_vjepa2_model( + model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs + ) def vjepa2_vit_giant_384(*, pretrained: bool = True, **kwargs): """ VJEPA 2 ViT-giant-384 model """ - return _make_vjepa2_model(model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs) + return _make_vjepa2_model( + model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs + ) def vjepa2_ac_vit_giant(*, pretrained: bool = True, **kwargs): """ VJEPA 2-AC ViT-giant model """ - return 
_make_vjepa2_ac_model(model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs) + return _make_vjepa2_ac_model( + model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs + ) + + +# ########## V-JEPA 2.1 ########## + + +vjepa2_1_teacher_embed_dim = 1664 + + +def _make_vjepa2_1_model( + model_name: str = "vjepa2_1_vit_large_384", + checkpoint_key="target_encoder", + img_size=384, + patch_size=16, + tubelet_size=2, + num_frames=64, + predictor_embed_dim=384, + predictor_depth=24, + predictor_num_mask_tokens=10, + n_output_distillation=4, + return_all_tokens=False, + teacher_embed_dim=None, + pretrained: bool = True, + **kwargs, +): + from app.vjepa_2_1.models import predictor as vit_predictor, vision_transformer as vit_encoder + + vit_encoder_kwargs = dict( + patch_size=patch_size, + img_size=(img_size, img_size), + num_frames=num_frames, + tubelet_size=tubelet_size, + use_sdpa=True, + use_SiLU=False, + wide_SiLU=True, + uniform_power=False, + use_rope=True, + img_temporal_dim_size=1, + interpolate_rope=True, + ) + vit_encoder_kwargs.update(**kwargs) + + arch_name = ARCH_NAME_MAP[model_name][0] + encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs) + + vit_predictor_kwargs = dict( + img_size=(img_size, img_size), + patch_size=patch_size, + use_mask_tokens=True, + embed_dim=encoder.embed_dim, + predictor_embed_dim=predictor_embed_dim, + teacher_embed_dim=teacher_embed_dim, + num_frames=num_frames, + tubelet_size=tubelet_size, + depth=predictor_depth, + num_heads=12, + num_mask_tokens=predictor_num_mask_tokens, + use_rope=True, + uniform_power=False, + use_sdpa=True, + use_silu=False, + wide_silu=True, + n_output_distillation=n_output_distillation, + return_all_tokens=return_all_tokens, + img_temporal_dim_size=1, + ) + vit_predictor_kwargs.update(**kwargs) + + predictor = vit_predictor.__dict__["vit_predictor"](**vit_predictor_kwargs) + + if pretrained: + model_file = ARCH_NAME_MAP[model_name][-1] + url = VJEPA_BASE_URL + 
f"/{model_file}.pt" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + encoder_state_dict = _clean_backbone_key(state_dict[checkpoint_key]) + encoder.load_state_dict( + encoder_state_dict, strict=True + ) # state_dict has pos_embed but we use RoPE + predictor_state_dict = _clean_backbone_key(state_dict["predictor"]) + predictor.load_state_dict( + predictor_state_dict, strict=True + ) # state_dict has pos_embed but we use RoPE + + return encoder, predictor + + +def vjepa2_1_vit_base_384(*, pretrained: bool = True, **kwargs): + return _make_vjepa2_1_model( + model_name="vjepa2_1_vit_base_384", + checkpoint_key="ema_encoder", + img_size=384, + predictor_depth=12, + predictor_num_mask_tokens=8, + n_output_distillation=1, + return_all_tokens=True, + teacher_embed_dim=vjepa2_1_teacher_embed_dim, + pretrained=pretrained, + **kwargs, + ) + + +def vjepa2_1_vit_large_384(*, pretrained: bool = True, **kwargs): + return _make_vjepa2_1_model( + model_name="vjepa2_1_vit_large_384", + checkpoint_key="ema_encoder", + img_size=384, + predictor_depth=12, + predictor_num_mask_tokens=8, + n_output_distillation=1, + return_all_tokens=True, + teacher_embed_dim=vjepa2_1_teacher_embed_dim, + pretrained=pretrained, + **kwargs, + ) + + +def vjepa2_1_vit_giant_384(*, pretrained: bool = True, **kwargs): + return _make_vjepa2_1_model( + model_name="vjepa2_1_vit_giant_384", + img_size=384, + predictor_num_mask_tokens=8, + n_output_distillation=4, + return_all_tokens=True, + pretrained=pretrained, + **kwargs, + ) + + +def vjepa2_1_vit_gigantic_384(*, pretrained: bool = True, **kwargs): + return _make_vjepa2_1_model( + model_name="vjepa2_1_vit_gigantic_384", + img_size=384, + predictor_num_mask_tokens=8, + n_output_distillation=4, + return_all_tokens=True, + pretrained=pretrained, + **kwargs, + ) diff --git a/src/masks/multiseq_multiblock3d.py b/src/masks/multiseq_multiblock3d.py index 85d9b049..cd8e2f2f 100644 --- a/src/masks/multiseq_multiblock3d.py +++ 
b/src/masks/multiseq_multiblock3d.py @@ -53,11 +53,23 @@ def step(self): def __call__(self, batch): - # Batch: [buffer, label, clip_indices] + # Batch: [buffer, label, clip_indices] for video + # or [buffer, label] for images filtered_batches = {fpc: [] for fpc in self.mask_generators} for sample in batch: - fpc = len(sample[-1][-1]) - filtered_batches[fpc] += [sample] + # Check if sample is from video dataset (has clip_indices) or image dataset + if len(sample) >= 3 and isinstance(sample[-1], (list, tuple)): + # Video sample: sample[-1] is clip_indices, sample[-1][-1] contains frame indices + try: + fpc = len(sample[-1][-1]) + except (TypeError, IndexError): + # Fallback: assume single frame if structure is unexpected + fpc = 1 + else: + # Image sample: single frame + fpc = 1 + if fpc in filtered_batches: + filtered_batches[fpc] += [sample] fpc_collations = [] for fpc in filtered_batches: @@ -71,7 +83,9 @@ def __call__(self, batch): masks_enc, masks_pred = mask_generator(batch_size) collated_masks_enc.append(masks_enc) collated_masks_pred.append(masks_pred) - fpc_collations += [(collated_batch, collated_masks_enc, collated_masks_pred)] + fpc_collations += [ + (collated_batch, collated_masks_enc, collated_masks_pred) + ] return fpc_collations @@ -100,7 +114,9 @@ def __init__( if not isinstance(spatial_patch_size, tuple): spatial_patch_size = (spatial_patch_size,) * 2 self.crop_size = crop_size - self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)] + self.height, self.width = [ + crop_size[i] // spatial_patch_size[i] for i in (0, 1) + ] self.duration = num_frames // temporal_patch_size self.full_complement = full_complement self.pred_full_complement = pred_full_complement @@ -126,7 +142,9 @@ def step(self): v = i.value return v - def _sample_block_size(self, generator, temporal_scale, spatial_scale, aspect_ratio_scale): + def _sample_block_size( + self, generator, temporal_scale, spatial_scale, aspect_ratio_scale + ): # -- Sample 
temporal block mask scale _rand = torch.rand(1, generator=generator).item() min_t, max_t = temporal_scale @@ -193,7 +211,9 @@ def __call__(self, batch_size): empty_context = True while empty_context: - mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask_e = torch.ones( + (self.duration, self.height, self.width), dtype=torch.int32 + ) for _ in range(self.npred): mask_e *= self._sample_block_mask(p_size) mask_e = mask_e.flatten() @@ -216,7 +236,12 @@ def __call__(self, batch_size): if self.full_complement: # predictor mask is just complement of encoder mask collated_masks_pred = [ torch.tensor( - sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))), + sorted( + list( + set(range(int(self.duration * self.height * self.width))) + - set(cm.tolist()) + ) + ), dtype=cm.dtype, ) for cm in collated_masks_enc @@ -224,7 +249,12 @@ def __call__(self, batch_size): elif self.pred_full_complement: collated_masks_enc = [ torch.tensor( - sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))), + sorted( + list( + set(range(int(self.duration * self.height * self.width))) + - set(cm.tolist()) + ) + ), dtype=cm.dtype, ) for cm in collated_masks_pred diff --git a/src/models/predictor.py b/src/models/predictor.py index 1b00189f..968c5cc5 100644 --- a/src/models/predictor.py +++ b/src/models/predictor.py @@ -26,6 +26,7 @@ def __init__( tubelet_size=2, embed_dim=768, predictor_embed_dim=384, + out_embed_dim=None, depth=6, num_heads=12, mlp_ratio=4.0, @@ -120,9 +121,16 @@ def __init__( ] ) + if out_embed_dim is None: + teacher_embed_dim = kwargs.get("teacher_embed_dim", None) + if teacher_embed_dim is not None: + out_embed_dim = teacher_embed_dim + else: + out_embed_dim = embed_dim + # Normalize & project back to input dimension self.predictor_norm = norm_layer(predictor_embed_dim) - self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + self.predictor_proj = 
nn.Linear(predictor_embed_dim, out_embed_dim, bias=True) # ------ initialize weights if self.predictor_pos_embed is not None: diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py index e2c43592..be2bbfc5 100644 --- a/src/models/vision_transformer.py +++ b/src/models/vision_transformer.py @@ -453,7 +453,7 @@ def vit_gigantic(patch_size=16, **kwargs): embed_dim=1664, depth=48, num_heads=16, - mpl_ratio=64 / 13, + mlp_ratio=64 / 13, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs @@ -467,7 +467,7 @@ def vit_gigantic_xformers(patch_size=16, **kwargs): embed_dim=1664, depth=48, num_heads=26, - mpl_ratio=64 / 13, + mlp_ratio=64 / 13, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs diff --git a/src/utils/wrappers.py b/src/utils/wrappers.py index 7e08ec2d..cb24a136 100644 --- a/src/utils/wrappers.py +++ b/src/utils/wrappers.py @@ -11,6 +11,7 @@ class MultiSeqWrapper(nn.Module): def __init__(self, backbone): super().__init__() self.backbone = backbone + self.embed_dim = backbone.embed_dim def forward(self, x, masks=None): """