Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
99da195
Add Spatiotemporal Area Attention (ST-A²) for V-JEPA 2
tarassh Feb 8, 2026
07bb0b1
Add Colab notebook for ST-A2 verification tests
tarassh Feb 8, 2026
094da22
Fix PyTorch 2.9+ compatibility: total_mem → total_memory
tarassh Feb 8, 2026
03fae9a
Add ST-A² ablation notebook for Colab T4
tarassh Feb 8, 2026
5ff32c7
Bump ablation to full resolution: 256px, 16 frames, batch=1
tarassh Feb 8, 2026
88282ae
Add git pull to Cell 1 so re-runs pick up latest code
tarassh Feb 8, 2026
31ddd46
Add config details to ablation summary and CSV output
tarassh Feb 8, 2026
b015999
Add per-layer profiling: attention vs MLP time breakdown
tarassh Feb 8, 2026
9c53fd6
Vectorize RoPEAreaAttention: sort-pad-attend-unsort replaces Python l…
tarassh Feb 8, 2026
b68274a
Add H100 multi-resolution ablation notebook
tarassh Feb 8, 2026
59e20aa
Add standalone Python script for H100 multi-resolution sweep
tarassh Feb 8, 2026
8640270
Fix f-string syntax for Python 3.10 compat
tarassh Feb 8, 2026
57bfd73
Add PR write-up and downstream eval configs for ST-A²
tarassh Feb 8, 2026
26f6490
Add fine-tune config and flexible checkpoint loading for ST-A²
tarassh Feb 9, 2026
ed48b6d
Add K400 CSV manifest generation script for downstream eval
tarassh Feb 9, 2026
44016f1
Add one-shot setup and eval script for Lambda A10/A100
tarassh Feb 9, 2026
420e8fe
Fix K400 CSV generation for flat directory layout
tarassh Feb 9, 2026
6d09099
Add K400 downstream eval results to PR description
tarassh Feb 10, 2026
2e867d0
Add one-shot fine-tune and eval script for Lambda A10/A100
tarassh Feb 10, 2026
dedfa42
Fix config generation: use Python yaml instead of fragile sed
tarassh Feb 10, 2026
4b11859
Fix fine-tuned checkpoint filename: latest.pt not jepa-latest.pth.tar
tarassh Feb 11, 2026
5f8203b
Add fine-tuned ST-A² downstream eval results to PR description
tarassh Feb 11, 2026
cf07386
Rewrite script as full 3-way eval pipeline
tarassh Feb 11, 2026
26dacc7
Rewrite as full validation pipeline with resume support
tarassh Feb 11, 2026
dddc179
Increase eval to 10 epochs / 5 HP sweeps for publication-grade results
tarassh Feb 11, 2026
24b1b04
Update PR description with definitive L40S results
tarassh Feb 14, 2026
26f8a61
remove pr_*.md file
tarassh Feb 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions app/vjepa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def main(args, resume_preempt=False):
use_silu = cfgs_model.get("use_silu", False)
use_pred_silu = cfgs_model.get("use_pred_silu", False)
wide_silu = cfgs_model.get("wide_silu", True)
# -- ST-A² (Spatiotemporal Area Attention)
use_area_attention = cfgs_model.get("use_area_attention", False)
area_attention_layers = cfgs_model.get("area_attention_layers", None)
area_spatial_splits = cfgs_model.get("area_spatial_splits", 2)
area_temporal_splits = cfgs_model.get("area_temporal_splits", 2)
area_residual_scale = cfgs_model.get("area_residual_scale", 1.0)

# -- DATA
cfgs_data = args.get("data")
Expand Down Expand Up @@ -218,6 +224,11 @@ def main(args, resume_preempt=False):
wide_silu=wide_silu,
use_rope=use_rope,
use_activation_checkpointing=use_activation_checkpointing,
use_area_attention=use_area_attention,
area_attention_layers=area_attention_layers,
area_spatial_splits=area_spatial_splits,
area_temporal_splits=area_temporal_splits,
area_residual_scale=area_residual_scale,
)
target_encoder = copy.deepcopy(encoder)

Expand Down
34 changes: 27 additions & 7 deletions app/vjepa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,36 @@ def load_checkpoint(
epoch = checkpoint["epoch"]

# -- loading encoder
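# Use strict=False when annealing to allow loading baseline checkpoints
# into area-attention models (RoPEAreaAttention has identical weight
# structure to RoPEAttention, so all shared params load correctly).
# Non-strict loading does not raise on missing/unexpected keys; they are
# only reported in the `msg` logged below, so check that log line.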
pretrained_dict = checkpoint["encoder"]
msg = encoder.load_state_dict(pretrained_dict)
msg = encoder.load_state_dict(pretrained_dict, strict=not is_anneal)
logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")

# -- loading predictor
pretrained_dict = checkpoint["predictor"]
msg = predictor.load_state_dict(pretrained_dict)
msg = predictor.load_state_dict(pretrained_dict, strict=not is_anneal)
logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}")

# -- loading target_encoder
if target_encoder is not None:
print(list(checkpoint.keys()))
pretrained_dict = checkpoint["target_encoder"]
msg = target_encoder.load_state_dict(pretrained_dict)
msg = target_encoder.load_state_dict(pretrained_dict, strict=not is_anneal)
logger.info(f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}")

# -- loading optimizer
opt.load_state_dict(checkpoint["opt"])
if scaler is not None:
scaler.load_state_dict(checkpoint["scaler"])
logger.info(f"loaded optimizers from epoch {epoch}")
# Skip optimizer/scaler restore whenever is_anneal is set: the checkpoint
# may come from a different architecture (e.g., baseline → area-attention),
# in which case its optimizer state would not match the new parameter set,
# so annealing always starts from a fresh optimizer.
if is_anneal:
logger.info("Annealing: skipping optimizer/scaler restore (fresh optimizer)")
else:
opt.load_state_dict(checkpoint["opt"])
if scaler is not None:
scaler.load_state_dict(checkpoint["scaler"])
logger.info(f"loaded optimizers from epoch {epoch}")
logger.info(f"read-path: {r_path}")
del checkpoint

Expand Down Expand Up @@ -158,6 +167,12 @@ def init_video_model(
use_pred_silu=False,
wide_silu=False,
use_activation_checkpointing=False,
# -- ST-A² params
use_area_attention=False,
area_attention_layers=None,
area_spatial_splits=2,
area_temporal_splits=2,
area_residual_scale=1.0,
):
encoder = video_vit.__dict__[model_name](
img_size=crop_size,
Expand All @@ -170,6 +185,11 @@ def init_video_model(
wide_silu=wide_silu,
use_activation_checkpointing=use_activation_checkpointing,
use_rope=use_rope,
use_area_attention=use_area_attention,
area_attention_layers=area_attention_layers,
area_spatial_splits=area_spatial_splits,
area_temporal_splits=area_temporal_splits,
area_residual_scale=area_residual_scale,
)
encoder = MultiSeqWrapper(encoder)
predictor = vit_pred.__dict__["vit_predictor"](
Expand Down
182 changes: 182 additions & 0 deletions configs/eval/vitl/k400-area-attn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# ST-A² (Spatiotemporal Area Attention) eval config for Kinetics-400
# Based on k400.yaml with area attention enabled in the encoder.
#
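# The encoder checkpoint must have been trained with matching area attention
# settings (use_area_attention=true, layers 0-17, area_spatial_splits=2,
# area_temporal_splits=2).
# NOTE(review): layers 0-17 are given below as area_attention_layers: [0, 18],
# presumably a half-open [start, end) range -- confirm against the encoder code.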
#
# Usage:
# python -m evals.main --fname configs/eval/vitl/k400-area-attn.yaml \
# --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7

cpus_per_task: 16
eval_name: video_classification_frozen
folder: /your_folder/evals/vitl/k400-area-attn
mem_per_gpu: 220G
nodes: 8
num_workers: 8
resume_checkpoint: true
tag: k400-vitl16-16x8x3-16f-area-attn
tasks_per_node: 8
experiment:
classifier:
num_heads: 16
num_probe_blocks: 4
data:
dataset_type: VideoDataset
dataset_train: /your_data_path/k400_train_paths.csv
dataset_val: /your_data_path/k400_val_paths.csv
frame_step: 4
frames_per_clip: 16
num_classes: 400
num_segments: 8
num_views_per_segment: 3
resolution: 256
optimization:
batch_size: 4
multihead_kwargs:
- final_lr: 0.0
final_weight_decay: 0.01
lr: 0.005
start_lr: 0.005
warmup: 0.0
weight_decay: 0.01
- final_lr: 0.0
final_weight_decay: 0.01
lr: 0.003
start_lr: 0.003
warmup: 0.0
weight_decay: 0.01
- final_lr: 0.0
final_weight_decay: 0.01
lr: 0.001
start_lr: 0.001
warmup: 0.0
weight_decay: 0.01
- final_lr: 0.0
final_weight_decay: 0.01
lr: 0.0003
start_lr: 0.0003
warmup: 0.0
weight_decay: 0.01
- final_lr: 0.0
final_weight_decay: 0.01
lr: 0.0001
start_lr: 0.0001
warmup: 0.0
weight_decay: 0.01
- final_lr: 0.0
final_weight_decay: 0.1
lr: 0.005
start_lr: 0.005
warmup: 0.0
weight_decay: 0.1
- final_lr: 0.0
final_weight_decay: 0.1
lr: 0.003
start_lr: 0.003
warmup: 0.0
weight_decay: 0.1
- final_lr: 0.0
final_weight_decay: 0.1
lr: 0.001
start_lr: 0.001
warmup: 0.0
weight_decay: 0.1
- final_lr: 0.0
final_weight_decay: 0.1
lr: 0.0003
start_lr: 0.0003
warmup: 0.0
weight_decay: 0.1
- final_lr: 0.0
final_weight_decay: 0.1
lr: 0.0001
start_lr: 0.0001
warmup: 0.0
weight_decay: 0.1
- final_lr: 0.0
final_weight_decay: 0.4
lr: 0.005
start_lr: 0.005
warmup: 0.0
weight_decay: 0.4
- final_lr: 0.0
final_weight_decay: 0.4
lr: 0.003
start_lr: 0.003
warmup: 0.0
weight_decay: 0.4
- final_lr: 0.0
final_weight_decay: 0.4
lr: 0.001
start_lr: 0.001
warmup: 0.0
weight_decay: 0.4
- final_lr: 0.0
final_weight_decay: 0.4
lr: 0.0003
start_lr: 0.0003
warmup: 0.0
weight_decay: 0.4
- final_lr: 0.0
final_weight_decay: 0.4
lr: 0.0001
start_lr: 0.0001
warmup: 0.0
weight_decay: 0.4
- final_lr: 0.0
final_weight_decay: 0.8
lr: 0.005
start_lr: 0.005
warmup: 0.0
weight_decay: 0.8
- final_lr: 0.0
final_weight_decay: 0.8
lr: 0.003
start_lr: 0.003
warmup: 0.0
weight_decay: 0.8
- final_lr: 0.0
final_weight_decay: 0.8
lr: 0.001
start_lr: 0.001
warmup: 0.0
weight_decay: 0.8
- final_lr: 0.0
final_weight_decay: 0.8
lr: 0.0003
start_lr: 0.0003
warmup: 0.0
weight_decay: 0.8
- final_lr: 0.0
final_weight_decay: 0.8
lr: 0.0001
start_lr: 0.0001
warmup: 0.0
weight_decay: 0.8
num_epochs: 20
use_bfloat16: true
use_pos_embed: false
model_kwargs:
checkpoint: /your_vjepa2_checkpoints/vitl-area-attn.pt
module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
pretrain_kwargs:
encoder:
checkpoint_key: target_encoder
img_temporal_dim_size: null
model_name: vit_large
patch_size: 16
tubelet_size: 2
uniform_power: true
use_rope: true
# -- ST-A² configuration (must match pretraining)
use_area_attention: true
area_attention_layers:
- 0
- 18
area_spatial_splits: 2
area_temporal_splits: 2
area_residual_scale: 1.0
wrapper_kwargs:
max_frames: 128
use_pos_embed: false
Loading