Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Lint

on:
  push:
    branches: [main]
  pull_request:

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/setup-uv@v5
      - run: uv python install 3.10
      - run: uv pip install ruff
      - run: uv run ruff check .
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.2
    hooks:
      - id: ruff
        args: [--fix]
3 changes: 1 addition & 2 deletions app/vjepa_2_1/models/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import torch
import torch.nn as nn

from app.vjepa_2_1.models.utils.modules import Block
from src.masks.utils import apply_masks
from src.utils.tensors import repeat_interleave_batch, trunc_normal_

from app.vjepa_2_1.models.utils.modules import Block


class VisionTransformerPredictor(nn.Module):
"""Vision Transformer Predictor"""
Expand Down
2 changes: 0 additions & 2 deletions app/vjepa_2_1/models/utils/masks_dist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import torch
import torch.nn as nn
import torchvision


def _get_frame_pos(ids, H_patches=None, W_patches=None, grid_size=None):
Expand Down
1 change: 0 additions & 1 deletion app/vjepa_2_1/models/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import drop_path


Expand Down
3 changes: 1 addition & 2 deletions app/vjepa_2_1/models/utils/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
# LICENSE file in the root directory of this source tree.
#

from einops import rearrange

import torch.nn as nn
from einops import rearrange


class AudioPatchEmbed(nn.Module):
Expand Down
5 changes: 2 additions & 3 deletions app/vjepa_2_1/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import torch
import torch.nn as nn

from src.masks.utils import apply_masks
from src.utils.tensors import trunc_normal_

from app.vjepa_2_1.models.utils.modules import Block
from app.vjepa_2_1.models.utils.patch_embed import PatchEmbed, PatchEmbed3D
from src.masks.utils import apply_masks
from src.utils.tensors import trunc_normal_


class VisionTransformer(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions app/vjepa_2_1/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from app.vjepa_2_1.models.utils.masks_dist import compute_mask_distance
from app.vjepa_2_1.models.utils.modules import Lambda_LinearWarmupHold
from app.vjepa_2_1.transforms import make_transforms
Expand All @@ -35,8 +37,6 @@
from src.masks.utils import apply_masks
from src.utils.distributed import init_distributed
from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer
from torch.nn.parallel import DistributedDataParallel


log_timings = True
log_freq = 10
Expand Down
4 changes: 2 additions & 2 deletions app/vjepa_2_1/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
# LICENSE file in the root directory of this source tree.

import numpy as np

import src.datasets.utils.video.transforms as video_transforms
import torch
import torchvision.transforms as transforms
from PIL import Image

import src.datasets.utils.video.transforms as video_transforms
from src.datasets.utils.video.randerase import RandomErasing


Expand Down
5 changes: 3 additions & 2 deletions app/vjepa_2_1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import logging
import sys

import app.vjepa_2_1.models.predictor as vit_pred
import app.vjepa_2_1.models.vision_transformer as video_vit
import torch
import torch.nn.functional as F
import yaml

import app.vjepa_2_1.models.predictor as vit_pred
import app.vjepa_2_1.models.vision_transformer as video_vit
from app.vjepa_2_1.wrappers import MultiSeqWrapper, PredictorMultiSeqWrapper
from src.utils.checkpoint_loader import robust_checkpoint_loader
from src.utils.schedulers import (
Expand Down
1 change: 0 additions & 1 deletion evals/image_classification_frozen/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from evals.image_classification_frozen.models import init_module
from src.datasets.data_manager import init_data
from src.models.attentive_pooler import AttentiveClassifier
from src.models.utils.modules import Block, CrossAttentionBlock
from src.utils.checkpoint_loader import robust_checkpoint_loader
from src.utils.distributed import AllReduce, init_distributed
from src.utils.logging import AverageMeter, CSVLogger
Expand Down
12 changes: 0 additions & 12 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,5 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from evals.hub.preprocessor import vjepa2_preprocessor
from src.hub.backbones import (
vjepa2_ac_vit_giant,
vjepa2_vit_giant,
vjepa2_vit_giant_384,
vjepa2_vit_huge,
vjepa2_vit_large,
vjepa2_1_vit_base_384,
vjepa2_1_vit_large_384,
vjepa2_1_vit_giant_384,
vjepa2_1_vit_gigantic_384,
)

dependencies = ["torch", "timm", "einops"]
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,14 @@ line_length=119

[tool.black]
line-length = 119

[tool.ruff]
line-length = 119
extend-exclude = ["*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "I"]
ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
2 changes: 0 additions & 2 deletions src/datasets/imagenet1k.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import time
from logging import getLogger

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions src/datasets/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from src.utils.cluster import dataset_paths

from src.utils.logging import get_logger

logger = get_logger("Datasets utils")
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/utils/weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

import numpy as np
import torch
from torch.utils.data import DistributedSampler, RandomSampler

from src.utils.logging import get_logger
from torch.utils.data import DistributedSampler, RandomSampler

logger = get_logger("WeightedSampler")

Expand Down
2 changes: 1 addition & 1 deletion src/datasets/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
import torch
import torchvision
from decord import cpu, VideoReader
from decord import VideoReader, cpu

from src.datasets.utils.dataloader import (
ConcatIndices,
Expand Down
8 changes: 6 additions & 2 deletions src/hub/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def _make_vjepa2_ac_model(
):
from ..models import (
ac_predictor as vit_ac_predictor,
)
from ..models import (
vision_transformer as vit_encoder,
)

Expand Down Expand Up @@ -101,7 +103,8 @@ def _make_vjepa2_model(
pretrained: bool = True,
**kwargs,
):
from ..models import predictor as vit_predictor, vision_transformer as vit_encoder
from ..models import predictor as vit_predictor
from ..models import vision_transformer as vit_encoder

vit_encoder_kwargs = dict(
patch_size=patch_size,
Expand Down Expand Up @@ -224,7 +227,8 @@ def _make_vjepa2_1_model(
pretrained: bool = True,
**kwargs,
):
from app.vjepa_2_1.models import predictor as vit_predictor, vision_transformer as vit_encoder
from app.vjepa_2_1.models import predictor as vit_predictor
from app.vjepa_2_1.models import vision_transformer as vit_encoder

vit_encoder_kwargs = dict(
patch_size=patch_size,
Expand Down
1 change: 0 additions & 1 deletion src/models/utils/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from einops import rearrange


class PatchEmbed(nn.Module):
Expand Down