diff --git a/app/vjepa_2_1/train.py b/app/vjepa_2_1/train.py index f8fe5949..abf0e209 100644 --- a/app/vjepa_2_1/train.py +++ b/app/vjepa_2_1/train.py @@ -114,7 +114,9 @@ def main(args, resume_preempt=False): normalize_predictor = cfgs_model.get("normalize_predictor", False) modality_embedding = cfgs_model.get("modality_embedding", False) levels_predictor = cfgs_model.get("levels_predictor", 4) - if model_name == "vit_large": + if model_name == "vit_base": + embed_dim_encoder = 768 + elif model_name == "vit_large": embed_dim_encoder = 1024 elif model_name == "vit_giant_xformers": embed_dim_encoder = 1408