diff --git a/notebooks/vjepa2_demo.ipynb b/notebooks/vjepa2_demo.ipynb index 2a816bc5..67c2945d 100644 --- a/notebooks/vjepa2_demo.ipynb +++ b/notebooks/vjepa2_demo.ipynb @@ -81,14 +81,14 @@ " return video\n", "\n", "\n", - "def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):\n", + "def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform, device):\n", " # Run a sample inference with VJEPA\n", " with torch.inference_mode():\n", " # Read and pre-process the image\n", " video = get_video() # T x H x W x C\n", " video = torch.from_numpy(video).permute(0, 3, 1, 2) # T x C x H x W\n", - " x_pt = pt_transform(video).cuda().unsqueeze(0)\n", - " x_hf = hf_transform(video, return_tensors=\"pt\")[\"pixel_values_videos\"].to(\"cuda\")\n", + " x_pt = pt_transform(video).to(device).unsqueeze(0)\n", + " x_hf = hf_transform(video, return_tensors=\"pt\")[\"pixel_values_videos\"].to(device)\n", " # Extract the patch-wise features from the last layer\n", " out_patch_features_pt = model_pt(x_pt)\n", " out_patch_features_hf = model_hf.get_vision_features(x_hf)\n", @@ -176,9 +176,17 @@ "# Path to local PyTorch weights\n", "pt_model_path = \"YOUR_MODEL_PATH\"\n", "\n", + "# Configuring GPU acceleration for CUDA or MPS(Apple Silicon)\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\") # Apple Silicon GPU support\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", "# Initialize the HuggingFace model, load pretrained weights\n", "model_hf = AutoModel.from_pretrained(hf_model_name)\n", - "model_hf.cuda().eval()\n", + "model_hf.to(device).eval()\n", "\n", "# Build HuggingFace preprocessing transform\n", "hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)\n", @@ -186,12 +194,12 @@ "\n", "# Initialize the PyTorch model, load pretrained weights\n", "model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)\n", - "model_pt.cuda().eval()\n", + "model_pt.to(device).eval()\n", "load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)\n", "\n", "### Can also use torch.hub to load the model\n", "# model_pt, _ = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')\n", - "# model_pt.cuda().eval()\n", + "# model_pt.to(device).eval()\n", "\n", "# Build PyTorch preprocessing transform\n", "pt_video_transform = build_pt_video_transform(img_size=img_size)" @@ -212,7 +220,7 @@ "source": [ "# Inference on video to get the patch-wise features\n", "out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(\n", - " model_hf, model_pt, hf_transform, pt_video_transform\n", + " model_hf, model_pt, hf_transform, pt_video_transform, device\n", ")\n", "\n", "print(\n", @@ -246,7 +254,7 @@ "# Initialize the classifier\n", "classifier_model_path = \"YOUR_ATTENTIVE_PROBE_PATH\"\n", "classifier = (\n", - " AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).cuda().eval()\n", + " AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).to(device).eval()\n", ")\n", "load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)\n", "\n", diff --git a/notebooks/vjepa2_demo.py b/notebooks/vjepa2_demo.py index 625c7112..1267bfb8 100644 --- a/notebooks/vjepa2_demo.py +++ b/notebooks/vjepa2_demo.py @@ -63,14 +63,14 @@ def get_video(): return video -def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform): +def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform, device): # Run a sample inference with VJEPA with torch.inference_mode(): # Read and pre-process the image video = get_video() # T x H x W x C video = torch.from_numpy(video).permute(0, 3, 1, 2) # T x C x H x W - x_pt = pt_transform(video).cuda().unsqueeze(0) - x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to("cuda") + x_pt = pt_transform(video).to(device).unsqueeze(0) + x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to(device) # Extract the patch-wise features from the last layer out_patch_features_pt = model_pt(x_pt) out_patch_features_hf = model_hf.get_vision_features(x_hf) @@ -96,7 +96,7 @@ def get_vjepa_video_classification_results(classifier, out_patch_features_pt): return -def run_sample_inference(): +def run_sample_inference(device): # HuggingFace model repo name hf_model_name = ( "facebook/vjepa2-vitg-fpc64-384" # Replace with your favored model, e.g. facebook/vjepa2-vitg-fpc64-384 @@ -114,7 +114,7 @@ def run_sample_inference(): # Initialize the HuggingFace model, load pretrained weights model_hf = AutoModel.from_pretrained(hf_model_name) - model_hf.cuda().eval() + model_hf.to(device).eval() # Build HuggingFace preprocessing transform hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name) @@ -122,7 +122,7 @@ def run_sample_inference(): # Initialize the PyTorch model, load pretrained weights model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64) - model_pt.cuda().eval() + model_pt.to(device).eval() load_pretrained_vjepa_pt_weights(model_pt, pt_model_path) # Build PyTorch preprocessing transform @@ -130,7 +130,7 @@ def run_sample_inference(): # Inference on video out_patch_features_hf, out_patch_features_pt = forward_vjepa_video( - model_hf, model_pt, hf_transform, pt_video_transform + model_hf, model_pt, hf_transform, pt_video_transform, device ) print( @@ -146,7 +146,7 @@ def run_sample_inference(): # Initialize the classifier classifier_model_path = "YOUR_ATTENTIVE_PROBE_PATH" classifier = ( - AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).cuda().eval() + AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).to(device).eval() ) load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path) @@ -167,5 +167,12 @@ def run_sample_inference(): if __name__ == "__main__": + # Configuring GPU acceleration for CUDA or MPS(Apple Silicon) + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") # Run with: `python -m notebooks.vjepa2_demo` - run_sample_inference() + run_sample_inference(device)