diff --git a/.gitignore b/.gitignore
index 34a74493..f207be41 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,8 @@ __pycache__/
# Distribution / packaging
.Python
build/
+checkpoints
+Data
develop-eggs/
dist/
downloads/
@@ -18,6 +20,7 @@ lib/
lib64/
parts/
sdist/
+segmentation_output
var/
wheels/
pip-wheel-metadata/
@@ -104,6 +107,7 @@ celerybeat.pid
# Environments
.env
.venv
+venv3.9
env/
venv/
ENV/
diff --git a/engine_finetune.py b/engine_finetune.py
index fe60d442..2cc37f33 100644
--- a/engine_finetune.py
+++ b/engine_finetune.py
@@ -1,148 +1,148 @@
-import os
-import csv
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-import matplotlib.pyplot as plt
-from typing import Iterable, Optional
-from timm.data import Mixup
-from timm.utils import accuracy
-from sklearn.metrics import (
- accuracy_score, roc_auc_score, f1_score, average_precision_score,
- hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
-)
-from pycm import ConfusionMatrix
-import util.misc as misc
-import util.lr_sched as lr_sched
-
-def train_one_epoch(
- model: torch.nn.Module,
- criterion: torch.nn.Module,
- data_loader: Iterable,
- optimizer: torch.optim.Optimizer,
- device: torch.device,
- epoch: int,
- loss_scaler,
- max_norm: float = 0,
- mixup_fn: Optional[Mixup] = None,
- log_writer=None,
- args=None
-):
- """Train the model for one epoch."""
- model.train(True)
- metric_logger = misc.MetricLogger(delimiter=" ")
- metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
- print_freq, accum_iter = 20, args.accum_iter
- optimizer.zero_grad()
-
- if log_writer:
- print(f'log_dir: {log_writer.log_dir}')
-
- for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
- if data_iter_step % accum_iter == 0:
- lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
-
- samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
- if mixup_fn:
- samples, targets = mixup_fn(samples, targets)
-
- with torch.cuda.amp.autocast():
- outputs = model(samples)
- loss = criterion(outputs, targets)
- loss_value = loss.item()
- loss /= accum_iter
-
- loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
- update_grad=(data_iter_step + 1) % accum_iter == 0)
- if (data_iter_step + 1) % accum_iter == 0:
- optimizer.zero_grad()
-
- torch.cuda.synchronize()
- metric_logger.update(loss=loss_value)
- min_lr = 10.
- max_lr = 0.
- for group in optimizer.param_groups:
- min_lr = min(min_lr, group["lr"])
- max_lr = max(max_lr, group["lr"])
-
- metric_logger.update(lr=max_lr)
-
- loss_value_reduce = misc.all_reduce_mean(loss_value)
- if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
- """ We use epoch_1000x as the x-axis in tensorboard.
- This calibrates different curves when batch size changes.
- """
- epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
- log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
- log_writer.add_scalar('lr', max_lr, epoch_1000x)
-
- metric_logger.synchronize_between_processes()
- print("Averaged stats:", metric_logger)
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-
-@torch.no_grad()
-def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
- """Evaluate the model."""
- criterion = nn.CrossEntropyLoss()
- metric_logger = misc.MetricLogger(delimiter=" ")
- os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
-
- model.eval()
- true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
-
- for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
- images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
- target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)
-
- with torch.cuda.amp.autocast():
- output = model(images)
- loss = criterion(output, target)
- output_ = nn.Softmax(dim=1)(output)
- output_label = output_.argmax(dim=1)
- output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)
-
- metric_logger.update(loss=loss.item())
- true_onehot.extend(target_onehot.cpu().numpy())
- pred_onehot.extend(output_onehot.detach().cpu().numpy())
- true_labels.extend(target.cpu().numpy())
- pred_labels.extend(output_label.detach().cpu().numpy())
- pred_softmax.extend(output_.detach().cpu().numpy())
-
- accuracy = accuracy_score(true_labels, pred_labels)
- hamming = hamming_loss(true_onehot, pred_onehot)
- jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
- average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
- kappa = cohen_kappa_score(true_labels, pred_labels)
- f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
- roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
- precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
- recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')
-
- score = (f1 + roc_auc + kappa) / 3
- if log_writer:
- for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
- [accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
- log_writer.add_scalar(f'perf/{metric_name}', value, epoch)
-
- print(f'val loss: {metric_logger.meters["loss"].global_avg}')
- print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
- f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
- f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')
-
- metric_logger.synchronize_between_processes()
-
- results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
- file_exists = os.path.isfile(results_path)
- with open(results_path, 'a', newline='', encoding='utf8') as cfa:
- wf = csv.writer(cfa)
- if not file_exists:
- wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
- wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])
-
- if mode == 'test':
- cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
- cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
- plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
-
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
+import os
+import csv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import matplotlib.pyplot as plt
+from typing import Iterable, Optional
+from timm.data import Mixup
+from timm.utils import accuracy
+from sklearn.metrics import (
+ accuracy_score, roc_auc_score, f1_score, average_precision_score,
+ hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
+)
+from pycm import ConfusionMatrix
+import util.misc as misc
+import util.lr_sched as lr_sched
+
+def train_one_epoch(
+ model: torch.nn.Module,
+ criterion: torch.nn.Module,
+ data_loader: Iterable,
+ optimizer: torch.optim.Optimizer,
+ device: torch.device,
+ epoch: int,
+ loss_scaler,
+ max_norm: float = 0,
+ mixup_fn: Optional[Mixup] = None,
+ log_writer=None,
+ args=None
+):
+ """Train the model for one epoch."""
+ model.train(True)
+ metric_logger = misc.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ print_freq, accum_iter = 20, args.accum_iter
+ optimizer.zero_grad()
+
+ if log_writer:
+ print(f'log_dir: {log_writer.log_dir}')
+
+ for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
+ if data_iter_step % accum_iter == 0:
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
+
+ samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
+ if mixup_fn:
+ samples, targets = mixup_fn(samples, targets)
+
+ with torch.cuda.amp.autocast():
+ outputs = model(samples)
+ loss = criterion(outputs, targets)
+ loss_value = loss.item()
+ loss /= accum_iter
+
+ loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
+ if (data_iter_step + 1) % accum_iter == 0:
+ optimizer.zero_grad()
+
+ torch.cuda.synchronize()
+ metric_logger.update(loss=loss_value)
+ min_lr = 10.
+ max_lr = 0.
+ for group in optimizer.param_groups:
+ min_lr = min(min_lr, group["lr"])
+ max_lr = max(max_lr, group["lr"])
+
+ metric_logger.update(lr=max_lr)
+
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
+ if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
+ """ We use epoch_1000x as the x-axis in tensorboard.
+ This calibrates different curves when batch size changes.
+ """
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
+ log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
+ log_writer.add_scalar('lr', max_lr, epoch_1000x)
+
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+@torch.no_grad()
+def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
+ """Evaluate the model."""
+ criterion = nn.CrossEntropyLoss()
+ metric_logger = misc.MetricLogger(delimiter=" ")
+ os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
+
+ model.eval()
+ true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
+
+ for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
+ images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
+ target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)
+
+ with torch.cuda.amp.autocast():
+ output = model(images)
+ loss = criterion(output, target)
+ output_ = nn.Softmax(dim=1)(output)
+ output_label = output_.argmax(dim=1)
+ output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)
+
+ metric_logger.update(loss=loss.item())
+ true_onehot.extend(target_onehot.cpu().numpy())
+ pred_onehot.extend(output_onehot.detach().cpu().numpy())
+ true_labels.extend(target.cpu().numpy())
+ pred_labels.extend(output_label.detach().cpu().numpy())
+ pred_softmax.extend(output_.detach().cpu().numpy())
+
+ accuracy = accuracy_score(true_labels, pred_labels)
+ hamming = hamming_loss(true_onehot, pred_onehot)
+ jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
+ average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
+ kappa = cohen_kappa_score(true_labels, pred_labels)
+ f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
+ roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
+ precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
+ recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')
+
+ score = (f1 + roc_auc + kappa) / 3
+ if log_writer:
+ for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
+ [accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
+ log_writer.add_scalar(f'perf/{metric_name}', value, epoch)
+
+ print(f'val loss: {metric_logger.meters["loss"].global_avg}')
+ print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
+ f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
+ f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')
+
+ metric_logger.synchronize_between_processes()
+
+ results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
+ file_exists = os.path.isfile(results_path)
+ with open(results_path, 'a', newline='', encoding='utf8') as cfa:
+ wf = csv.writer(cfa)
+ if not file_exists:
+ wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
+ wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])
+
+ if mode == 'test':
+ cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
+ cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
+ plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
diff --git a/engine_segmentation.py b/engine_segmentation.py
new file mode 100644
index 00000000..c1c11229
--- /dev/null
+++ b/engine_segmentation.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+# ============================================================
+# Dice Loss
+# ============================================================
+def dice_loss(pred, target, smooth=1e-6):
+ """
+ Computes multi-class Dice loss.
+
+ Args:
+ pred : Raw logits from the model [B, C, H, W]
+ target : Ground truth labels [B, H, W]
+ smooth : Small constant to avoid division by zero
+
+ Returns:
+ Scalar dice loss (1 - dice coefficient)
+ """
+
+ # Convert logits to probabilities
+ pred = F.softmax(pred, dim=1)
+
+ # Number of classes (e.g., 2 for background/drusen)
+ num_classes = pred.shape[1]
+
+ # Convert target to one-hot representation
+ target_oh = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
+
+ # Intersection between prediction and ground truth
+ inter = (pred * target_oh).sum((2, 3))
+
+ # Sum of prediction and ground truth areas
+ union = pred.sum((2, 3)) + target_oh.sum((2, 3))
+
+ # Dice coefficient → converted to loss (1 - dice)
+ return 1 - ((2 * inter + smooth) / (union + smooth)).mean()
+
+
+# ============================================================
+# Combined CE + Dice Loss
+# ============================================================
+def combined_loss_fn(outputs, targets, ce_fn, dice_w=1.0):
+ """
+ Combines Cross-Entropy loss with Dice loss.
+
+ Args:
+ outputs : Model logits [B, C, H, W]
+ targets : Ground truth [B, H, W]
+ ce_fn : CrossEntropyLoss function
+ dice_w : Weight for dice loss term
+
+ Returns:
+ Weighted sum of CE and Dice loss
+ """
+
+ return ce_fn(outputs, targets) + dice_w * dice_loss(outputs, targets)
+
+
+# ============================================================
+# Metric Computation
+# ============================================================
+def compute_metrics(preds, targets, smooth=1e-6):
+ """
+ Computes pixel accuracy, Dice, and IoU.
+
+ Args:
+ preds : Binary predictions (numpy array)
+ targets : Binary ground truth (numpy array)
+
+ Returns:
+ pixel_acc : Pixel-wise accuracy
+ dice : Dice coefficient
+ iou : Intersection over Union
+ """
+
+ # Pixel-wise accuracy
+ pixel_acc = (preds == targets).mean()
+
+ # Intersection between prediction and GT
+ inter = (preds & targets).sum()
+
+ # Dice coefficient
+ dice = (2 * inter + smooth) / (preds.sum() + targets.sum() + smooth)
+
+ # IoU computation
+ union = preds.sum() + targets.sum() - inter
+ iou = (inter + smooth) / (union + smooth)
+
+ return pixel_acc, dice, iou
+
+
+# ============================================================
+# Training Loop
+# ============================================================
+def train_segmentation(model, loader, loss_fn, optimizer, device):
+ """
+ One epoch training for segmentation model.
+
+ Args:
+ model : Segmentation network
+ loader : Training dataloader
+ loss_fn : Loss function (CE + Dice)
+ optimizer : Optimizer
+ device : cuda/cpu
+
+ Returns:
+ Average epoch loss
+ """
+
+ model.train()
+ total = 0
+
+ for step, (x, y) in enumerate(loader):
+
+ # Move data to device
+ x, y = x.to(device), y.to(device)
+
+ optimizer.zero_grad()
+
+ # Forward pass
+ out = model(x)
+
+ # Compute loss
+ loss = loss_fn(out, y)
+
+ # Backpropagation
+ loss.backward()
+ optimizer.step()
+
+ # Accumulate loss
+ total += loss.item() * x.size(0)
+
+ # Progress print
+ if step % 10 == 0:
+ print(f" [batch {step}/{len(loader)}] loss: {loss.item():.4f}")
+
+ return total / len(loader.dataset)
+
+
+# ============================================================
+# Validation / Evaluation Loop
+# ============================================================
+@torch.no_grad()
+def evaluate_segmentation(model, loader, loss_fn, device):
+ """
+ Runs inference on validation/test set and collects predictions.
+
+ Args:
+ model : Trained model
+ loader : Val/test dataloader
+ loss_fn : Loss function
+ device : cuda/cpu
+
+ Returns:
+ avg_loss : Average loss over dataset
+ P : All predictions (numpy)
+ T : All ground truth (numpy)
+ """
+
+ model.eval()
+
+ total = 0
+ P, T = [], []
+
+ for x, y in loader:
+
+ x, y = x.to(device), y.to(device)
+
+ # Forward pass
+ out = model(x)
+
+ # Loss computation
+ loss = loss_fn(out, y)
+ total += loss.item() * x.size(0)
+
+ # Store predictions and targets
+ P.append(out.argmax(1).cpu())
+ T.append(y.cpu())
+
+ return total / len(loader.dataset), torch.cat(P).numpy(), torch.cat(T).numpy()
diff --git a/examples/RETFound_MendeleyOCT_demo.ipynb b/examples/RETFound_MendeleyOCT_demo.ipynb
new file mode 100644
index 00000000..7ea2b1bf
--- /dev/null
+++ b/examples/RETFound_MendeleyOCT_demo.ipynb
@@ -0,0 +1,1166 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "76b39fb1",
+ "metadata": {
+ "id": "76b39fb1",
+ "jp-MarkdownHeadingCollapsed": true
+ },
+ "source": [
+ "## Jupyter notebook example - Segementation task\n",
+ "### Example using [MendeleyOCT](https://data.mendeley.com/datasets/rscbjbr9sj/2) dataset\n",
+ "**Application**: Using RETFound for Drusen segmentation\n",
+ "\n",
+ "**Author**: Yukun Zhou, Salman Shams\n",
+ "\n",
+ "**Date**: 08 Jan 2026\n",
+ "\n",
+ "**Contribution:** \n",
+ "This notebook extends the original RETFound classification pipeline to **semantic segmentation** by adding a lightweight decoder on top of the pretrained ViT encoder and training with CE + Dice loss.\n",
+ "\n",
+ "**Performance**:\n",
+ "\n",
+ "
\n",
+ "\n",
+ " | Dice | \n",
+ " IOU | \n",
+ "
\n",
+ "\n",
+ " | 0.4804 | \n",
+ " 0.3495 | \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7ec435a7",
+ "metadata": {
+ "id": "7ec435a7"
+ },
+ "source": [
+ "## 1. Install environment\n",
+ "1. Follow [RETFound README](https://github.com/rmaphoh/RETFound) to install environment\n",
+ "2. Restart this Jupyter Notebook\n",
+ "3. Select Kernel retfound\n",
+ "\n",
+ "> **Note:** Ensure the same PyTorch / timm versions as the original RETFound repository to avoid weight-loading issues."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "7cbf5e93-6ca0-4401-88e6-64e39968e7cd",
+ "metadata": {
+ "id": "7cbf5e93-6ca0-4401-88e6-64e39968e7cd"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Project root: F:\\GitHub\\RETFound\n",
+ "sys.executable: f:\\GitHub\\RETFound\\venv3.9\\Scripts\\python.exe\n",
+ "torch version: 2.8.0+cpu\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys, torch\n",
+ "from pathlib import Path\n",
+ "import os\n",
+ "\n",
+ "PROJECT_ROOT = Path.cwd().resolve()\n",
+ "\n",
+ "if PROJECT_ROOT.name == 'examples': PROJECT_ROOT = PROJECT_ROOT.parent\n",
+ "os.chdir(PROJECT_ROOT)\n",
+ "\n",
+ "print('Project root:', PROJECT_ROOT)\n",
+ "print(\"sys.executable:\", sys.executable)\n",
+ "print(\"torch version:\", torch.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed67953f",
+ "metadata": {
+ "id": "ed67953f"
+ },
+ "source": [
+ "## 2. Prepare MendeleyOCT dataset\n",
+ "1. Download dataset from the [gdrive](https://drive.google.com/drive/folders/1gBFXrkhRpp8EbTBlTn72h-UvS6a1JYBv?usp=sharing).\n",
+ "2. Put the data folder under the project directory, e.g. \"RETFound/MendeleyOCT\".\n",
+ "\n",
+ "> **Note:** \n",
+ "The dataset used in this work has been **preprocessed and annotated for the segmentation task**. \n",
+ "> - Each B-scan was **horizontally cropped from the top and bottom** to focus on the retinal region. \n",
+ "> - Binary pixel-level annotations were created for **drusen segmentation** (0: background, 1: drusen). \n",
+ "> - Image–mask pairs are provided in JPEG/PNG format for direct training.\n",
+ "> - Paired format: `images/*.jpeg` ↔ `masks/*_mask.png`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "357be2fa-a914-4d1f-8759-76b2b1c3f20f",
+ "metadata": {
+ "id": "357be2fa-a914-4d1f-8759-76b2b1c3f20f"
+ },
+ "source": [
+ "## 3. Hyperparameter and Path Settings\n",
+ "- Backbone: RETFound ViT-Large \n",
+ "- Task: Binary drusen segmentation \n",
+ "- Loss: Cross-Entropy + Dice \n",
+ "- Image size: 256×256 \n",
+ "- Classes: 2 (background, drusen)\n",
+ "\n",
+ "> **Note:** Encoder check point can be downloaded from [here](https://drive.google.com/drive/folders/14SQdLuIxfkiqz_zmpvNkd9Ka4NTW3Fml?usp=sharing)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "7f192e16",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset: MendeleyOCT\\Data\n",
+ "Checkpoint: checkpoints\\checkpoint-best.pth\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "DATA_PATH = Path(\"MendeleyOCT/Data\")\n",
+ "CKPT = Path(\"checkpoints/checkpoint-best.pth\")\n",
+ "\n",
+ "IMG_SIZE = 256\n",
+ "BATCH_SIZE = 4\n",
+ "EPOCHS = 20\n",
+ "CE_WEIGHT = \"0.1,0.9\"\n",
+ "\n",
+ "print(\"Dataset:\", DATA_PATH)\n",
+ "print(\"Checkpoint:\", CKPT)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6ac04845",
+ "metadata": {
+ "id": "6ac04845"
+ },
+ "source": [
+ "## 4. Fine-tuning RETFound for Segmentation\n",
+ "\n",
+ "The pretrained ViT encoder is initialized from RETFound weights, \n",
+ "and a lightweight convolutional decoder is trained for pixel prediction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "374fdce3",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 12160408,
+ "status": "ok",
+ "timestamp": 1768486129477,
+ "user": {
+ "displayName": "MD SALMAN SHAMS",
+ "userId": "17411188514128174175"
+ },
+ "user_tz": -330
+ },
+ "id": "374fdce3",
+ "outputId": "47f94636-3e02-400e-cf47-fba01f0b88bd"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.12/dist-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.8 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
+ " check_for_updates()\n",
+ "Loaded 826 valid samples from Segmentation/Data/train\n",
+ "Loaded 200 valid samples from Segmentation/Data/val\n",
+ "Loaded 50 valid samples from Segmentation/Data/test\n",
+ "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
+ " warnings.warn(\n",
+ "Position interpolate from 14x14 to 16x16\n",
+ "Pretrained RETFound weights loaded.\n",
+ "\n",
+ "[DEBUG] Starting training loop...\n",
+ "\n",
+ "[DEBUG] Entered epoch 1\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 1.6072\n",
+ " [batch 10/207] loss: 0.5284\n",
+ " [batch 20/207] loss: 0.5346\n",
+ " [batch 30/207] loss: 0.4621\n",
+ " [batch 40/207] loss: 0.4229\n",
+ " [batch 50/207] loss: 0.4235\n",
+ " [batch 60/207] loss: 0.3985\n",
+ " [batch 70/207] loss: 0.3799\n",
+ " [batch 80/207] loss: 0.4031\n",
+ " [batch 90/207] loss: 0.3644\n",
+ " [batch 100/207] loss: 0.3574\n",
+ " [batch 110/207] loss: 0.4511\n",
+ " [batch 120/207] loss: 0.3842\n",
+ " [batch 130/207] loss: 0.3465\n",
+ " [batch 140/207] loss: 0.3782\n",
+ " [batch 150/207] loss: 0.3437\n",
+ " [batch 160/207] loss: 0.3442\n",
+ " [batch 170/207] loss: 0.3764\n",
+ " [batch 180/207] loss: 0.3371\n",
+ " [batch 190/207] loss: 0.3378\n",
+ " [batch 200/207] loss: 0.3950\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 1: Train=0.4005 | Val=0.2806 | Dice=0.5605 | IoU=0.3894\n",
+ "[DEBUG] Entered epoch 2\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.3567\n",
+ " [batch 10/207] loss: 0.2937\n",
+ " [batch 20/207] loss: 0.3549\n",
+ " [batch 30/207] loss: 0.2875\n",
+ " [batch 40/207] loss: 0.2788\n",
+ " [batch 50/207] loss: 0.2848\n",
+ " [batch 60/207] loss: 0.3167\n",
+ " [batch 70/207] loss: 0.2793\n",
+ " [batch 80/207] loss: 0.3208\n",
+ " [batch 90/207] loss: 0.2842\n",
+ " [batch 100/207] loss: 0.3256\n",
+ " [batch 110/207] loss: 0.2269\n",
+ " [batch 120/207] loss: 0.3196\n",
+ " [batch 130/207] loss: 0.2960\n",
+ " [batch 140/207] loss: 0.3300\n",
+ " [batch 150/207] loss: 0.3175\n",
+ " [batch 160/207] loss: 0.2969\n",
+ " [batch 170/207] loss: 0.3674\n",
+ " [batch 180/207] loss: 0.2710\n",
+ " [batch 190/207] loss: 0.3173\n",
+ " [batch 200/207] loss: 0.2744\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 2: Train=0.3019 | Val=0.2676 | Dice=0.5939 | IoU=0.4224\n",
+ "[DEBUG] Entered epoch 3\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2541\n",
+ " [batch 10/207] loss: 0.2887\n",
+ " [batch 20/207] loss: 0.3042\n",
+ " [batch 30/207] loss: 0.2080\n",
+ " [batch 40/207] loss: 0.2231\n",
+ " [batch 50/207] loss: 0.3497\n",
+ " [batch 60/207] loss: 0.3842\n",
+ " [batch 70/207] loss: 0.2624\n",
+ " [batch 80/207] loss: 0.2801\n",
+ " [batch 90/207] loss: 0.3401\n",
+ " [batch 100/207] loss: 0.3218\n",
+ " [batch 110/207] loss: 0.2471\n",
+ " [batch 120/207] loss: 0.2856\n",
+ " [batch 130/207] loss: 0.3261\n",
+ " [batch 140/207] loss: 0.2053\n",
+ " [batch 150/207] loss: 0.2462\n",
+ " [batch 160/207] loss: 0.2413\n",
+ " [batch 170/207] loss: 0.3874\n",
+ " [batch 180/207] loss: 0.2323\n",
+ " [batch 190/207] loss: 0.3317\n",
+ " [batch 200/207] loss: 0.2441\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 3: Train=0.2727 | Val=0.2280 | Dice=0.6600 | IoU=0.4925\n",
+ "[DEBUG] Entered epoch 4\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2119\n",
+ " [batch 10/207] loss: 0.3079\n",
+ " [batch 20/207] loss: 0.2442\n",
+ " [batch 30/207] loss: 0.2435\n",
+ " [batch 40/207] loss: 0.3318\n",
+ " [batch 50/207] loss: 0.2625\n",
+ " [batch 60/207] loss: 0.2661\n",
+ " [batch 70/207] loss: 0.2390\n",
+ " [batch 80/207] loss: 0.2467\n",
+ " [batch 90/207] loss: 0.3138\n",
+ " [batch 100/207] loss: 0.2086\n",
+ " [batch 110/207] loss: 0.2197\n",
+ " [batch 120/207] loss: 0.2395\n",
+ " [batch 130/207] loss: 0.1610\n",
+ " [batch 140/207] loss: 0.3063\n",
+ " [batch 150/207] loss: 0.2948\n",
+ " [batch 160/207] loss: 0.2099\n",
+ " [batch 170/207] loss: 0.2426\n",
+ " [batch 180/207] loss: 0.1962\n",
+ " [batch 190/207] loss: 0.2409\n",
+ " [batch 200/207] loss: 0.2781\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 4: Train=0.2448 | Val=0.2342 | Dice=0.6503 | IoU=0.4818\n",
+ "[DEBUG] Entered epoch 5\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.3115\n",
+ " [batch 10/207] loss: 0.1931\n",
+ " [batch 20/207] loss: 0.2424\n",
+ " [batch 30/207] loss: 0.2089\n",
+ " [batch 40/207] loss: 0.1868\n",
+ " [batch 50/207] loss: 0.3348\n",
+ " [batch 60/207] loss: 0.2479\n",
+ " [batch 70/207] loss: 0.1659\n",
+ " [batch 80/207] loss: 0.2025\n",
+ " [batch 90/207] loss: 0.1686\n",
+ " [batch 100/207] loss: 0.2567\n",
+ " [batch 110/207] loss: 0.1917\n",
+ " [batch 120/207] loss: 0.2328\n",
+ " [batch 130/207] loss: 0.2900\n",
+ " [batch 140/207] loss: 0.2101\n",
+ " [batch 150/207] loss: 0.2216\n",
+ " [batch 160/207] loss: 0.2112\n",
+ " [batch 170/207] loss: 0.2272\n",
+ " [batch 180/207] loss: 0.1955\n",
+ " [batch 190/207] loss: 0.2354\n",
+ " [batch 200/207] loss: 0.3457\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 5: Train=0.2316 | Val=0.2245 | Dice=0.6685 | IoU=0.5020\n",
+ "[DEBUG] Entered epoch 6\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2025\n",
+ " [batch 10/207] loss: 0.2152\n",
+ " [batch 20/207] loss: 0.2111\n",
+ " [batch 30/207] loss: 0.2317\n",
+ " [batch 40/207] loss: 0.2358\n",
+ " [batch 50/207] loss: 0.1877\n",
+ " [batch 60/207] loss: 0.1501\n",
+ " [batch 70/207] loss: 0.2540\n",
+ " [batch 80/207] loss: 0.2082\n",
+ " [batch 90/207] loss: 0.2259\n",
+ " [batch 100/207] loss: 0.1742\n",
+ " [batch 110/207] loss: 0.1807\n",
+ " [batch 120/207] loss: 0.1410\n",
+ " [batch 130/207] loss: 0.3102\n",
+ " [batch 140/207] loss: 0.2125\n",
+ " [batch 150/207] loss: 0.2967\n",
+ " [batch 160/207] loss: 0.2384\n",
+ " [batch 170/207] loss: 0.1771\n",
+ " [batch 180/207] loss: 0.2419\n",
+ " [batch 190/207] loss: 0.3038\n",
+ " [batch 200/207] loss: 0.2155\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 6: Train=0.2164 | Val=0.2218 | Dice=0.6714 | IoU=0.5053\n",
+ "[DEBUG] Entered epoch 7\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1777\n",
+ " [batch 10/207] loss: 0.2458\n",
+ " [batch 20/207] loss: 0.2361\n",
+ " [batch 30/207] loss: 0.2720\n",
+ " [batch 40/207] loss: 0.2114\n",
+ " [batch 50/207] loss: 0.2370\n",
+ " [batch 60/207] loss: 0.2767\n",
+ " [batch 70/207] loss: 0.1874\n",
+ " [batch 80/207] loss: 0.1919\n",
+ " [batch 90/207] loss: 0.2213\n",
+ " [batch 100/207] loss: 0.2213\n",
+ " [batch 110/207] loss: 0.2263\n",
+ " [batch 120/207] loss: 0.1775\n",
+ " [batch 130/207] loss: 0.2739\n",
+ " [batch 140/207] loss: 0.3291\n",
+ " [batch 150/207] loss: 0.2236\n",
+ " [batch 160/207] loss: 0.2162\n",
+ " [batch 170/207] loss: 0.1912\n",
+ " [batch 180/207] loss: 0.1648\n",
+ " [batch 190/207] loss: 0.1859\n",
+ " [batch 200/207] loss: 0.2054\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 7: Train=0.2038 | Val=0.2250 | Dice=0.6699 | IoU=0.5036\n",
+ "[DEBUG] Entered epoch 8\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2239\n",
+ " [batch 10/207] loss: 0.1675\n",
+ " [batch 20/207] loss: 0.1499\n",
+ " [batch 30/207] loss: 0.2443\n",
+ " [batch 40/207] loss: 0.2731\n",
+ " [batch 50/207] loss: 0.1959\n",
+ " [batch 60/207] loss: 0.2296\n",
+ " [batch 70/207] loss: 0.2078\n",
+ " [batch 80/207] loss: 0.2190\n",
+ " [batch 90/207] loss: 0.2172\n",
+ " [batch 100/207] loss: 0.1288\n",
+ " [batch 110/207] loss: 0.1876\n",
+ " [batch 120/207] loss: 0.1710\n",
+ " [batch 130/207] loss: 0.2748\n",
+ " [batch 140/207] loss: 0.1927\n",
+ " [batch 150/207] loss: 0.1707\n",
+ " [batch 160/207] loss: 0.2482\n",
+ " [batch 170/207] loss: 0.1914\n",
+ " [batch 180/207] loss: 0.2109\n",
+ " [batch 190/207] loss: 0.1399\n",
+ " [batch 200/207] loss: 0.2101\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 8: Train=0.1969 | Val=0.2102 | Dice=0.6928 | IoU=0.5299\n",
+ "[DEBUG] Entered epoch 9\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1398\n",
+ " [batch 10/207] loss: 0.1192\n",
+ " [batch 20/207] loss: 0.1453\n",
+ " [batch 30/207] loss: 0.2658\n",
+ " [batch 40/207] loss: 0.2172\n",
+ " [batch 50/207] loss: 0.1517\n",
+ " [batch 60/207] loss: 0.1373\n",
+ " [batch 70/207] loss: 0.2481\n",
+ " [batch 80/207] loss: 0.2536\n",
+ " [batch 90/207] loss: 0.1781\n",
+ " [batch 100/207] loss: 0.2304\n",
+ " [batch 110/207] loss: 0.1664\n",
+ " [batch 120/207] loss: 0.3063\n",
+ " [batch 130/207] loss: 0.1773\n",
+ " [batch 140/207] loss: 0.2314\n",
+ " [batch 150/207] loss: 0.3274\n",
+ " [batch 160/207] loss: 0.1585\n",
+ " [batch 170/207] loss: 0.3672\n",
+ " [batch 180/207] loss: 0.2085\n",
+ " [batch 190/207] loss: 0.1775\n",
+ " [batch 200/207] loss: 0.1586\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 9: Train=0.1890 | Val=0.2178 | Dice=0.6853 | IoU=0.5213\n",
+ "[DEBUG] Entered epoch 10\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1770\n",
+ " [batch 10/207] loss: 0.1227\n",
+ " [batch 20/207] loss: 0.1684\n",
+ " [batch 30/207] loss: 0.1987\n",
+ " [batch 40/207] loss: 0.1156\n",
+ " [batch 50/207] loss: 0.1463\n",
+ " [batch 60/207] loss: 0.1355\n",
+ " [batch 70/207] loss: 0.2587\n",
+ " [batch 80/207] loss: 0.0852\n",
+ " [batch 90/207] loss: 0.1101\n",
+ " [batch 100/207] loss: 0.1334\n",
+ " [batch 110/207] loss: 0.1751\n",
+ " [batch 120/207] loss: 0.2327\n",
+ " [batch 130/207] loss: 0.1607\n",
+ " [batch 140/207] loss: 0.1764\n",
+ " [batch 150/207] loss: 0.2233\n",
+ " [batch 160/207] loss: 0.1401\n",
+ " [batch 170/207] loss: 0.2363\n",
+ " [batch 180/207] loss: 0.2062\n",
+ " [batch 190/207] loss: 0.1247\n",
+ " [batch 200/207] loss: 0.1477\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 10: Train=0.1837 | Val=0.2289 | Dice=0.6750 | IoU=0.5094\n",
+ "[DEBUG] Entered epoch 11\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2078\n",
+ " [batch 10/207] loss: 0.2047\n",
+ " [batch 20/207] loss: 0.1852\n",
+ " [batch 30/207] loss: 0.2619\n",
+ " [batch 40/207] loss: 0.1754\n",
+ " [batch 50/207] loss: 0.1004\n",
+ " [batch 60/207] loss: 0.1033\n",
+ " [batch 70/207] loss: 0.2141\n",
+ " [batch 80/207] loss: 0.1825\n",
+ " [batch 90/207] loss: 0.1713\n",
+ " [batch 100/207] loss: 0.1536\n",
+ " [batch 110/207] loss: 0.2070\n",
+ " [batch 120/207] loss: 0.1225\n",
+ " [batch 130/207] loss: 0.2451\n",
+ " [batch 140/207] loss: 0.1151\n",
+ " [batch 150/207] loss: 0.1468\n",
+ " [batch 160/207] loss: 0.1393\n",
+ " [batch 170/207] loss: 0.1352\n",
+ " [batch 180/207] loss: 0.1514\n",
+ " [batch 190/207] loss: 0.1821\n",
+ " [batch 200/207] loss: 0.1583\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 11: Train=0.1776 | Val=0.2207 | Dice=0.6855 | IoU=0.5215\n",
+ "[DEBUG] Entered epoch 12\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2290\n",
+ " [batch 10/207] loss: 0.1820\n",
+ " [batch 20/207] loss: 0.1776\n",
+ " [batch 30/207] loss: 0.1927\n",
+ " [batch 40/207] loss: 0.1674\n",
+ " [batch 50/207] loss: 0.1655\n",
+ " [batch 60/207] loss: 0.1710\n",
+ " [batch 70/207] loss: 0.1582\n",
+ " [batch 80/207] loss: 0.1816\n",
+ " [batch 90/207] loss: 0.1819\n",
+ " [batch 100/207] loss: 0.1738\n",
+ " [batch 110/207] loss: 0.2001\n",
+ " [batch 120/207] loss: 0.2124\n",
+ " [batch 130/207] loss: 0.1337\n",
+ " [batch 140/207] loss: 0.2463\n",
+ " [batch 150/207] loss: 0.1945\n",
+ " [batch 160/207] loss: 0.3018\n",
+ " [batch 170/207] loss: 0.1278\n",
+ " [batch 180/207] loss: 0.2135\n",
+ " [batch 190/207] loss: 0.2194\n",
+ " [batch 200/207] loss: 0.2093\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 12: Train=0.1721 | Val=0.2217 | Dice=0.6907 | IoU=0.5275\n",
+ "[DEBUG] Entered epoch 13\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1652\n",
+ " [batch 10/207] loss: 0.1329\n",
+ " [batch 20/207] loss: 0.2092\n",
+ " [batch 30/207] loss: 0.1648\n",
+ " [batch 40/207] loss: 0.0932\n",
+ " [batch 50/207] loss: 0.1448\n",
+ " [batch 60/207] loss: 0.2008\n",
+ " [batch 70/207] loss: 0.1754\n",
+ " [batch 80/207] loss: 0.2081\n",
+ " [batch 90/207] loss: 0.1241\n",
+ " [batch 100/207] loss: 0.1880\n",
+ " [batch 110/207] loss: 0.1601\n",
+ " [batch 120/207] loss: 0.1282\n",
+ " [batch 130/207] loss: 0.1487\n",
+ " [batch 140/207] loss: 0.1517\n",
+ " [batch 150/207] loss: 0.1544\n",
+ " [batch 160/207] loss: 0.1018\n",
+ " [batch 170/207] loss: 0.1403\n",
+ " [batch 180/207] loss: 0.1175\n",
+ " [batch 190/207] loss: 0.1488\n",
+ " [batch 200/207] loss: 0.1415\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 13: Train=0.1682 | Val=0.2622 | Dice=0.6457 | IoU=0.4768\n",
+ "[DEBUG] Entered epoch 14\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2855\n",
+ " [batch 10/207] loss: 0.2600\n",
+ " [batch 20/207] loss: 0.1778\n",
+ " [batch 30/207] loss: 0.1939\n",
+ " [batch 40/207] loss: 0.1395\n",
+ " [batch 50/207] loss: 0.0920\n",
+ " [batch 60/207] loss: 0.1603\n",
+ " [batch 70/207] loss: 0.1238\n",
+ " [batch 80/207] loss: 0.1673\n",
+ " [batch 90/207] loss: 0.1552\n",
+ " [batch 100/207] loss: 0.1225\n",
+ " [batch 110/207] loss: 0.1592\n",
+ " [batch 120/207] loss: 0.1972\n",
+ " [batch 130/207] loss: 0.1940\n",
+ " [batch 140/207] loss: 0.1740\n",
+ " [batch 150/207] loss: 0.1977\n",
+ " [batch 160/207] loss: 0.1552\n",
+ " [batch 170/207] loss: 0.2466\n",
+ " [batch 180/207] loss: 0.2127\n",
+ " [batch 190/207] loss: 0.1929\n",
+ " [batch 200/207] loss: 0.1205\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 14: Train=0.1700 | Val=0.2158 | Dice=0.6897 | IoU=0.5263\n",
+ "[DEBUG] Entered epoch 15\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1854\n",
+ " [batch 10/207] loss: 0.2445\n",
+ " [batch 20/207] loss: 0.0977\n",
+ " [batch 30/207] loss: 0.1831\n",
+ " [batch 40/207] loss: 0.2082\n",
+ " [batch 50/207] loss: 0.1230\n",
+ " [batch 60/207] loss: 0.2431\n",
+ " [batch 70/207] loss: 0.2035\n",
+ " [batch 80/207] loss: 0.1267\n",
+ " [batch 90/207] loss: 0.1210\n",
+ " [batch 100/207] loss: 0.1787\n",
+ " [batch 110/207] loss: 0.2013\n",
+ " [batch 120/207] loss: 0.2087\n",
+ " [batch 130/207] loss: 0.1681\n",
+ " [batch 140/207] loss: 0.1317\n",
+ " [batch 150/207] loss: 0.1911\n",
+ " [batch 160/207] loss: 0.1193\n",
+ " [batch 170/207] loss: 0.1786\n",
+ " [batch 180/207] loss: 0.1416\n",
+ " [batch 190/207] loss: 0.1293\n",
+ " [batch 200/207] loss: 0.1544\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 15: Train=0.1661 | Val=0.2179 | Dice=0.6994 | IoU=0.5378\n",
+ "[DEBUG] Entered epoch 16\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1455\n",
+ " [batch 10/207] loss: 0.1057\n",
+ " [batch 20/207] loss: 0.1510\n",
+ " [batch 30/207] loss: 0.1456\n",
+ " [batch 40/207] loss: 0.2355\n",
+ " [batch 50/207] loss: 0.1749\n",
+ " [batch 60/207] loss: 0.1486\n",
+ " [batch 70/207] loss: 0.1764\n",
+ " [batch 80/207] loss: 0.1339\n",
+ " [batch 90/207] loss: 0.1780\n",
+ " [batch 100/207] loss: 0.1619\n",
+ " [batch 110/207] loss: 0.1672\n",
+ " [batch 120/207] loss: 0.2510\n",
+ " [batch 130/207] loss: 0.1626\n",
+ " [batch 140/207] loss: 0.1840\n",
+ " [batch 150/207] loss: 0.1420\n",
+ " [batch 160/207] loss: 0.1930\n",
+ " [batch 170/207] loss: 0.1624\n",
+ " [batch 180/207] loss: 0.1335\n",
+ " [batch 190/207] loss: 0.1410\n",
+ " [batch 200/207] loss: 0.1213\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 16: Train=0.1537 | Val=0.2204 | Dice=0.7009 | IoU=0.5395\n",
+ "[DEBUG] Entered epoch 17\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1523\n",
+ " [batch 10/207] loss: 0.1845\n",
+ " [batch 20/207] loss: 0.1752\n",
+ " [batch 30/207] loss: 0.2439\n",
+ " [batch 40/207] loss: 0.1606\n",
+ " [batch 50/207] loss: 0.1179\n",
+ " [batch 60/207] loss: 0.1376\n",
+ " [batch 70/207] loss: 0.1196\n",
+ " [batch 80/207] loss: 0.2147\n",
+ " [batch 90/207] loss: 0.1457\n",
+ " [batch 100/207] loss: 0.2614\n",
+ " [batch 110/207] loss: 0.2404\n",
+ " [batch 120/207] loss: 0.1848\n",
+ " [batch 130/207] loss: 0.1519\n",
+ " [batch 140/207] loss: 0.1041\n",
+ " [batch 150/207] loss: 0.1542\n",
+ " [batch 160/207] loss: 0.2578\n",
+ " [batch 170/207] loss: 0.1130\n",
+ " [batch 180/207] loss: 0.1311\n",
+ " [batch 190/207] loss: 0.1231\n",
+ " [batch 200/207] loss: 0.1946\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 17: Train=0.1509 | Val=0.2207 | Dice=0.6998 | IoU=0.5383\n",
+ "[DEBUG] Entered epoch 18\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1491\n",
+ " [batch 10/207] loss: 0.1399\n",
+ " [batch 20/207] loss: 0.1529\n",
+ " [batch 30/207] loss: 0.2028\n",
+ " [batch 40/207] loss: 0.1463\n",
+ " [batch 50/207] loss: 0.1690\n",
+ " [batch 60/207] loss: 0.1515\n",
+ " [batch 70/207] loss: 0.1616\n",
+ " [batch 80/207] loss: 0.1476\n",
+ " [batch 90/207] loss: 0.1398\n",
+ " [batch 100/207] loss: 0.0772\n",
+ " [batch 110/207] loss: 0.0930\n",
+ " [batch 120/207] loss: 0.1421\n",
+ " [batch 130/207] loss: 0.1534\n",
+ " [batch 140/207] loss: 0.1148\n",
+ " [batch 150/207] loss: 0.1338\n",
+ " [batch 160/207] loss: 0.1688\n",
+ " [batch 170/207] loss: 0.1613\n",
+ " [batch 180/207] loss: 0.1792\n",
+ " [batch 190/207] loss: 0.1392\n",
+ " [batch 200/207] loss: 0.1403\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 18: Train=0.1491 | Val=0.2116 | Dice=0.7052 | IoU=0.5446\n",
+ "[DEBUG] Entered epoch 19\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.1171\n",
+ " [batch 10/207] loss: 0.1994\n",
+ " [batch 20/207] loss: 0.1493\n",
+ " [batch 30/207] loss: 0.1619\n",
+ " [batch 40/207] loss: 0.1725\n",
+ " [batch 50/207] loss: 0.1558\n",
+ " [batch 60/207] loss: 0.1124\n",
+ " [batch 70/207] loss: 0.1338\n",
+ " [batch 80/207] loss: 0.1428\n",
+ " [batch 90/207] loss: 0.0972\n",
+ " [batch 100/207] loss: 0.1496\n",
+ " [batch 110/207] loss: 0.1751\n",
+ " [batch 120/207] loss: 0.1270\n",
+ " [batch 130/207] loss: 0.1437\n",
+ " [batch 140/207] loss: 0.2481\n",
+ " [batch 150/207] loss: 0.1262\n",
+ " [batch 160/207] loss: 0.1557\n",
+ " [batch 170/207] loss: 0.1365\n",
+ " [batch 180/207] loss: 0.0968\n",
+ " [batch 190/207] loss: 0.1993\n",
+ " [batch 200/207] loss: 0.1454\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 19: Train=0.1451 | Val=0.2132 | Dice=0.7078 | IoU=0.5478\n",
+ "[DEBUG] Entered epoch 20\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ " [batch 0/207] loss: 0.2005\n",
+ " [batch 10/207] loss: 0.2464\n",
+ " [batch 20/207] loss: 0.1515\n",
+ " [batch 30/207] loss: 0.2253\n",
+ " [batch 40/207] loss: 0.0954\n",
+ " [batch 50/207] loss: 0.1562\n",
+ " [batch 60/207] loss: 0.1923\n",
+ " [batch 70/207] loss: 0.1598\n",
+ " [batch 80/207] loss: 0.1240\n",
+ " [batch 90/207] loss: 0.1108\n",
+ " [batch 100/207] loss: 0.1739\n",
+ " [batch 110/207] loss: 0.2150\n",
+ " [batch 120/207] loss: 0.1125\n",
+ " [batch 130/207] loss: 0.1525\n",
+ " [batch 140/207] loss: 0.1313\n",
+ " [batch 150/207] loss: 0.1273\n",
+ " [batch 160/207] loss: 0.1075\n",
+ " [batch 170/207] loss: 0.1172\n",
+ " [batch 180/207] loss: 0.3110\n",
+ " [batch 190/207] loss: 0.1552\n",
+ " [batch 200/207] loss: 0.1810\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Epoch 20: Train=0.1420 | Val=0.2250 | Dice=0.7002 | IoU=0.5387\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "/content/drive/MyDrive/Assignments/RETFound/main_segmentation.py:62: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " return img, torch.tensor(mask, dtype=torch.long)\n",
+ "Test: Loss=0.2997 | Dice=0.6362 | IoU=0.4665\n"
+ ]
+ }
+ ],
+ "source": [
+ "!python main_segmentation.py \\\n",
+ " --data_path MendeleyOCT/Data \\\n",
+ " --epochs 20 \\\n",
+ " --batch_size 4 \\\n",
+ " --finetune RETFound_OCT"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "84ce93ac",
+ "metadata": {
+ "id": "84ce93ac"
+ },
+ "source": [
+ "## 5. Inference and Evaluation\n",
+ "\n",
+ "The following script:\n",
+ "- generates overlay visualizations \n",
+ "- computes Dice and IOU \n",
+ "- reports final metrics on the test set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02d2dce7-31c2-48e2-87ce-9223b74cf94e",
+ "metadata": {
+ "id": "02d2dce7-31c2-48e2-87ce-9223b74cf94e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "===== Running Inference =====\n",
+ "\n",
+ "DRUSEN-100580-1.jpeg Dice: 0.3590 IoU: 0.2188\n",
+ "DRUSEN-103885-1.jpeg Dice: 0.5493 IoU: 0.3786\n",
+ "DRUSEN-103885-2.jpeg Dice: 0.6422 IoU: 0.4730\n",
+ "DRUSEN-103885-3.jpeg Dice: 0.6422 IoU: 0.4730\n",
+ "DRUSEN-103885-4.jpeg Dice: 0.6587 IoU: 0.4911\n",
+ "DRUSEN-103885-5.jpeg Dice: 0.6587 IoU: 0.4911\n",
+ "DRUSEN-142234-1.jpeg Dice: 0.2874 IoU: 0.1678\n",
+ "DRUSEN-142234-10.jpeg Dice: 0.3077 IoU: 0.1818\n",
+ "DRUSEN-142234-11.jpeg Dice: 0.6610 IoU: 0.4937\n",
+ "DRUSEN-142234-12.jpeg Dice: 0.1434 IoU: 0.0773\n",
+ "DRUSEN-142234-13.jpeg Dice: 0.5970 IoU: 0.4255\n",
+ "DRUSEN-142234-14.jpeg Dice: 0.3913 IoU: 0.2432\n",
+ "DRUSEN-142234-15.jpeg Dice: 0.0000 IoU: 0.0000\n",
+ "DRUSEN-142234-16.jpeg Dice: 0.0000 IoU: 0.0000\n",
+ "DRUSEN-142234-17.jpeg Dice: 0.7333 IoU: 0.5789\n",
+ "DRUSEN-142234-18.jpeg Dice: 0.1111 IoU: 0.0588\n",
+ "DRUSEN-142234-19.jpeg Dice: 0.0000 IoU: 0.0000\n",
+ "DRUSEN-142234-2.jpeg Dice: 0.5764 IoU: 0.4049\n",
+ "DRUSEN-142234-20.jpeg Dice: 0.2250 IoU: 0.1268\n",
+ "DRUSEN-142234-21.jpeg Dice: 0.5943 IoU: 0.4228\n",
+ "DRUSEN-142234-22.jpeg Dice: 0.4255 IoU: 0.2703\n",
+ "DRUSEN-142234-23.jpeg Dice: 0.4754 IoU: 0.3118\n",
+ "DRUSEN-142234-25.jpeg Dice: 0.0000 IoU: 0.0000\n",
+ "DRUSEN-142234-26.jpeg Dice: 0.1892 IoU: 0.1045\n",
+ "DRUSEN-142234-27.jpeg Dice: 0.5161 IoU: 0.3478\n",
+ "DRUSEN-142234-3.jpeg Dice: 0.4615 IoU: 0.3000\n",
+ "DRUSEN-142234-4.jpeg Dice: 0.3585 IoU: 0.2184\n",
+ "DRUSEN-142234-5.jpeg Dice: 0.3864 IoU: 0.2394\n",
+ "DRUSEN-142234-6.jpeg Dice: 0.5495 IoU: 0.3788\n",
+ "DRUSEN-142234-7.jpeg Dice: 0.0000 IoU: 0.0000\n",
+ "DRUSEN-142234-8.jpeg Dice: 0.1793 IoU: 0.0985\n",
+ "DRUSEN-142234-9.jpeg Dice: 0.2857 IoU: 0.1667\n",
+ "DRUSEN-163081-1.jpeg Dice: 0.7474 IoU: 0.5967\n",
+ "DRUSEN-228939-1.jpeg Dice: 0.7668 IoU: 0.6218\n",
+ "DRUSEN-228939-2.jpeg Dice: 0.6174 IoU: 0.4465\n",
+ "DRUSEN-303435-1.jpeg Dice: 0.3904 IoU: 0.2425\n",
+ "DRUSEN-349021-1.jpeg Dice: 0.7522 IoU: 0.6029\n",
+ "DRUSEN-349021-2.jpeg Dice: 0.6172 IoU: 0.4464\n",
+ "DRUSEN-364469-1.jpeg Dice: 0.7781 IoU: 0.6368\n",
+ "DRUSEN-364469-2.jpeg Dice: 0.6948 IoU: 0.5323\n",
+ "DRUSEN-364469-3.jpeg Dice: 0.6527 IoU: 0.4845\n",
+ "DRUSEN-364469-4.jpeg Dice: 0.6513 IoU: 0.4829\n",
+ "DRUSEN-457907-1.jpeg Dice: 0.3455 IoU: 0.2089\n",
+ "DRUSEN-95633-1.jpeg Dice: 0.4597 IoU: 0.2985\n",
+ "DRUSEN-9800172-2.jpeg Dice: 0.6992 IoU: 0.5375\n",
+ "DRUSEN-9837663-1.jpeg Dice: 0.8452 IoU: 0.7318\n",
+ "DRUSEN-9861332-1.jpeg Dice: 0.8508 IoU: 0.7404\n",
+ "DRUSEN-9884539-1.jpeg Dice: 0.7580 IoU: 0.6103\n",
+ "DRUSEN-9894035-2.jpeg Dice: 0.7279 IoU: 0.5722\n",
+ "DRUSEN-9928043-1.jpeg Dice: 0.7005 IoU: 0.5391\n",
+ "\n",
+ "===== FINAL REPORT =====\n",
+ "Mean Dice: 0.4804\n",
+ "Mean IoU : 0.3495\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "f:\\GitHub\\RETFound\\venv3.9\\lib\\site-packages\\albumentations\\__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.8 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
+ " check_for_updates()\n"
+ ]
+ }
+ ],
+ "source": [
+ "!python inference_segmentation.py \\\n",
+ " --ckpt segmentation_output/best.pth \\\n",
+ " --data_path MendeleyOCT/Data \\\n",
+ " --out_dir segmentation_output/inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cdccae46",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "- Extended RETFound to **drusen segmentation** via decoder \n",
+ "- Reused MAE-pretrained ViT encoder \n",
+ "- Provided training & inference pipeline \n",
+ "- Evaluated with Dice and IoU metrics \n",
+ "- Demonstrated transfer to dense prediction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20f7ee18",
+ "metadata": {},
+ "source": [
+ "## Future Work\n",
+ "\n",
+ "- Improve annotation consistency \n",
+ "- Add more diverse OCT data \n",
+ "- Stronger preprocessing & normalization \n",
+ "- Deeper decoder (if compute allows) \n",
+ "- Hyperparameter tuning"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "environment": {
+ "kernel": "retfound",
+ "name": "workbench-notebooks.m128",
+ "type": "gcloud",
+ "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
+ },
+ "kernelspec": {
+ "display_name": "venv3.9",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/inference_segmentation.py b/inference_segmentation.py
new file mode 100644
index 00000000..b8e901ad
--- /dev/null
+++ b/inference_segmentation.py
@@ -0,0 +1,244 @@
+import os
+import cv2
+import torch
+import argparse
+import numpy as np
+from albumentations import Compose, Resize, Normalize
+from albumentations.pytorch import ToTensorV2
+
+from models_segmentation import RETFoundSegmentation
+
+# Select device automatically
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+# ============================================================
+# Metrics
+# ============================================================
+def dice_iou(pred, gt, smooth=1e-6):
+ """
+ Compute Dice coefficient and IoU between prediction and ground truth.
+
+ Args:
+ pred : Binary prediction mask (numpy)
+ gt : Binary ground truth mask (numpy)
+ smooth: Small constant to avoid division by zero
+
+ Returns:
+ dice : Dice similarity score
+ iou : Intersection over Union score
+ """
+
+ # Convert to boolean for logical operations
+ pred = pred.astype(bool)
+ gt = gt.astype(bool)
+
+ # Intersection and union areas
+ inter = (pred & gt).sum()
+ union = (pred | gt).sum()
+
+ # Dice and IoU formulas
+ dice = (2 * inter + smooth) / (pred.sum() + gt.sum() + smooth)
+ iou = (inter + smooth) / (union + smooth)
+
+ return dice, iou
+
+
+# ============================================================
+# Model loader
+# ============================================================
+def load_model(ckpt, img_size=256):
+ """
+ Load trained RETFound segmentation model.
+
+ Args:
+ ckpt : Path to checkpoint
+ img_size : Input size used during training
+
+ Returns:
+ model in evaluation mode
+ """
+
+ model = RETFoundSegmentation(img_size=img_size).to(DEVICE)
+
+ # Load weights
+ state = torch.load(ckpt, map_location=DEVICE, weights_only=False)
+ model.load_state_dict(state)
+
+ model.eval()
+ return model
+
+
+# ============================================================
+# Preprocess
+# ============================================================
+def preprocess(img_path, img_size=256):
+ """
+ Read and preprocess input image.
+
+ - Read image using OpenCV
+ - Convert BGR → RGB
+ - Resize and normalize as per RETFound training
+ - Convert to tensor
+
+ Returns:
+ tensor image for model, original image for visualization
+ """
+
+ img = cv2.imread(img_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ tf = Compose([
+ Resize(img_size, img_size),
+ Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)),
+ ToTensorV2()
+ ])
+
+ aug = tf(image=img)
+ return aug["image"].unsqueeze(0), img
+
+
+# ============================================================
+# Overlay visualization
+# ============================================================
+def make_overlay(img, mask):
+ """
+ Create green overlay on original image using predicted mask.
+
+ Args:
+ img : Original RGB image
+ mask : Binary prediction mask
+
+ Returns:
+ Overlay visualization
+ """
+
+ # Ensure uint8 type
+ mask = mask.astype(np.uint8)
+
+ # Resize mask back to original image size
+ mask = cv2.resize(
+ mask,
+ (img.shape[1], img.shape[0]),
+ interpolation=cv2.INTER_NEAREST
+ )
+
+ # Create green color mask
+ color = np.zeros_like(img, dtype=np.uint8)
+ color[:, :, 1] = mask * 255 # Green channel
+
+ # Blend with original image
+ overlay = cv2.addWeighted(img, 0.7, color, 0.3, 0)
+ return overlay
+
+
+# ============================================================
+# Main inference
+# ============================================================
+def run_inference(args):
+ """
+ Perform inference on test set.
+
+ - Load model
+ - Iterate over test images
+ - Predict masks
+ - Compute Dice & IoU
+ - Save overlay visualizations
+ """
+
+ # ----- Resolve dataset paths from single root -----
+ test_img_dir = os.path.join(args.data_path, "test", "images")
+ test_mask_dir = os.path.join(args.data_path, "test", "masks")
+
+ # Validate paths
+ if not os.path.isdir(test_img_dir):
+ raise FileNotFoundError(f"Images folder not found: {test_img_dir}")
+
+ if not os.path.isdir(test_mask_dir):
+ raise FileNotFoundError(f"Masks folder not found: {test_mask_dir}")
+
+ os.makedirs(args.out_dir, exist_ok=True)
+
+ # Load trained model
+ model = load_model(args.ckpt, args.img_size)
+
+ dice_scores = []
+ iou_scores = []
+
+ print("\n===== Running Inference =====\n")
+
+ # Iterate through test images
+ for name in sorted(os.listdir(test_img_dir)):
+
+ img_path = os.path.join(test_img_dir, name)
+
+ # Corresponding ground truth name
+ stem = os.path.splitext(name)[0]
+ gt_name = stem + "_mask.png"
+ gt_path = os.path.join(test_mask_dir, gt_name)
+
+ if not os.path.isfile(gt_path):
+ continue
+
+ # ----- Predict -----
+ tensor, orig = preprocess(img_path, args.img_size)
+
+ with torch.no_grad():
+ out = model(tensor.to(DEVICE))
+ pred = out.argmax(1).squeeze().cpu().numpy()
+
+ # ----- Load GT -----
+ gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
+ gt = (gt > 0).astype("uint8")
+
+ # Align GT to prediction resolution
+ gt = cv2.resize(
+ gt,
+ (pred.shape[1], pred.shape[0]),
+ interpolation=cv2.INTER_NEAREST
+ )
+
+ # ----- Metrics -----
+ dice, iou = dice_iou(pred, gt)
+ dice_scores.append(dice)
+ iou_scores.append(iou)
+
+ # ----- Save overlay only -----
+ over = make_overlay(orig, pred)
+ cv2.imwrite(os.path.join(args.out_dir, stem + "_overlay.png"), over)
+
+ # Print per-image result
+ print(f"{name:25s} Dice: {dice:.4f} IoU: {iou:.4f}")
+
+ # ---------------------------------------------------------
+ # Final aggregated report
+ print("\n===== FINAL REPORT =====")
+ print("Mean Dice:", round(np.mean(dice_scores), 4))
+ print("Mean IoU :", round(np.mean(iou_scores), 4))
+
+
+# ============================================================
+# CLI
+# ============================================================
+if __name__ == "__main__":
+
+ # Command line argument parser
+ parser = argparse.ArgumentParser(description="RETFound Segmentation Inference")
+
+ parser.add_argument("--ckpt", type=str, required=True,
+ help="Path to trained model checkpoint")
+
+ parser.add_argument("--data_path", type=str, required=True,
+ help="Root dataset folder containing test/images and test/masks")
+
+ parser.add_argument("--out_dir", type=str, default="segmentation_output/inference",
+ help="Output directory for overlays")
+
+ parser.add_argument("--img_size", type=int, default=256,
+ help="Input resize dimension")
+
+ args = parser.parse_args()
+
+ # Start inference
+ run_inference(args)
diff --git a/main_finetune.py b/main_finetune.py
index 0f4513f4..91e192ee 100644
--- a/main_finetune.py
+++ b/main_finetune.py
@@ -1,448 +1,448 @@
-#!/usr/bin/env python3
-
-# =========================
-import argparse
-import datetime
-import json
-import os
-import time
-from pathlib import Path
-import warnings
-import faulthandler
-
-# =========================
-import numpy as np
-import torch
-import torch.backends.cudnn as cudnn
-from torch.utils.tensorboard import SummaryWriter
-from timm.models.layers import trunc_normal_
-from timm.data.mixup import Mixup
-from huggingface_hub import hf_hub_download, login # login imported as in original
-
-# =========================
-import models_vit as models
-import util.lr_decay as lrd
-import util.misc as misc
-from util.datasets import build_dataset
-from util.pos_embed import interpolate_pos_embed
-from util.misc import NativeScalerWithGradNormCount as NativeScaler
-from engine_finetune import train_one_epoch, evaluate
-
-# =========================
-faulthandler.enable()
-warnings.simplefilter(action="ignore", category=FutureWarning)
-
-
-def get_args_parser():
- parser = argparse.ArgumentParser(
- "MAE fine-tuning / linear probing for image classification", add_help=False
- )
-
- # ---- Core training
- parser.add_argument("--batch_size", default=128, type=int,
- help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
- parser.add_argument("--epochs", default=50, type=int)
- parser.add_argument("--accum_iter", default=1, type=int,
- help="Gradient accumulation steps")
-
- # ---- Model parameters
- parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
- help="Model entry in models_vit.py")
- parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
- help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
- parser.add_argument("--input_size", default=256, type=int, help="Image size")
- parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
- parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
- parser.add_argument("--cls_token", action="store_false", dest="global_pool",
- help="Use class token instead of global pool for classification")
-
- # ---- Optimizer parameters
- parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
- parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
- parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
- parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
- help="Base LR: lr = blr * total_batch_size / 256")
- parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
- parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
- parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
-
- # ---- Augmentation
- parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
- parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
- parser.add_argument("--smoothing", type=float, default=0.1)
-
- # ---- Random erase
- parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
- parser.add_argument("--remode", type=str, default="pixel")
- parser.add_argument("--recount", type=int, default=1)
- parser.add_argument("--resplit", action="store_true", default=False)
-
- # ---- Mixup/Cutmix
- parser.add_argument("--mixup", type=float, default=0.0)
- parser.add_argument("--cutmix", type=float, default=0.0)
- parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
- parser.add_argument("--mixup_prob", type=float, default=1.0)
- parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
- parser.add_argument("--mixup_mode", type=str, default="batch")
-
- # ---- Finetuning & adaptation
- parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
- parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
- parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
- help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
-
- # ---- Dataset & paths
- parser.add_argument("--data_path", default="./data/", type=str)
- parser.add_argument("--nb_classes", default=8, type=int)
- parser.add_argument("--output_dir", default="./output_dir")
- parser.add_argument("--log_dir", default="./output_logs")
-
- # >>> NEW: training data efficiency <<<
- parser.add_argument(
- "--dataratio", type=str, default="1.0",
- help=('Training data ratio(s) for subsampling in build_dataset. '
- 'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
- '(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
- )
- parser.add_argument(
- "--stratified", action="store_true",
- help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
- )
-
- # ---- Runtime
- parser.add_argument("--device", default="cuda")
- parser.add_argument("--seed", default=0, type=int)
- parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
- parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
- parser.add_argument("--eval", action="store_true", help="Evaluation only")
- parser.add_argument("--dist_eval", action="store_true", default=False,
- help="Distributed evaluation (faster monitoring during training)")
- parser.add_argument("--num_workers", default=10, type=int)
- parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
-
- # ---- Distributed
- parser.add_argument("--world_size", default=1, type=int)
- parser.add_argument("--local_rank", default=-1, type=int)
- parser.add_argument("--dist_on_itp", action="store_true")
- parser.add_argument("--dist_url", default="env://")
-
- # ---- Misc
- parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
- parser.add_argument("--norm", default="IMAGENET", type=str)
- parser.add_argument("--enhance", action="store_true", default=False)
- parser.add_argument("--datasets_seed", default=2026, type=int)
-
- return parser
-
-
-# =========================
-# Main
-# =========================
-def main(args, criterion):
- # ---- Optionally load args from resume (when training)
- if args.resume and not args.eval:
- resume_path = args.resume
- checkpoint = torch.load(args.resume, map_location="cpu")
- print(f"Load checkpoint (args) from: {args.resume}")
- args = checkpoint["args"]
- args.resume = resume_path
-
- # ---- Distributed setup
- misc.init_distributed_mode(args)
-
- print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
- print(f"{args}".replace(", ", ",\n"))
-
- device = torch.device(args.device)
-
- # ---- Reproducibility
- seed = args.seed + misc.get_rank()
- torch.manual_seed(seed)
- np.random.seed(seed)
- cudnn.benchmark = True
-
- # ---- Build model
- if args.model == "RETFound_mae":
- model = models.__dict__[args.model](
- img_size=args.input_size,
- num_classes=args.nb_classes,
- drop_path_rate=args.drop_path,
- global_pool=args.global_pool,
- )
- else:
- model = models.__dict__[args.model](
- num_classes=args.nb_classes,
- drop_path_rate=args.drop_path,
- args=args,
- )
-
- # ---- Load pre-trained weights (if requested and not eval-only)
- if args.finetune and not args.eval:
- print(f"Preparing to load pre-trained weights: {args.finetune}")
-
- if args.model in ["Dinov3", "Dinov2"]:
- checkpoint_path = args.finetune # local path
- elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
- print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
- checkpoint_path = hf_hub_download(
- repo_id=f"YukunZhou/{args.finetune}",
- filename=f"{args.finetune}.pth",
- )
- else:
- raise ValueError(
- f"Unsupported model '{args.model}'. "
- f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
- )
-
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
- print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
-
- if args.model in ["Dinov3", "Dinov2"]:
- checkpoint_model = checkpoint
- elif args.model == "RETFound_dinov2":
- checkpoint_model = checkpoint["teacher"]
- else: # RETFound_mae
- checkpoint_model = checkpoint["model"]
-
- # -- Key hygiene
- checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
- checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
- checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
-
- # -- Remove classifier if shape mismatched
- state_dict = model.state_dict()
- for k in ["head.weight", "head.bias"]:
- if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
- print(f"Removing key {k} from pretrained checkpoint")
- del checkpoint_model[k]
-
- # -- Interpolate pos embed (ViT)
- interpolate_pos_embed(model, checkpoint_model)
-
- # -- Load backbone weights (non-strict)
- _ = model.load_state_dict(checkpoint_model, strict=False)
-
- # -- Re-init head
- if hasattr(model, "head") and hasattr(model.head, "weight"):
- trunc_normal_(model.head.weight, std=2e-5)
-
- # ---- Datasets & samplers
- dataset_train = build_dataset(is_train="train", args=args)
- dataset_val = build_dataset(is_train="val", args=args)
- dataset_test = build_dataset(is_train="test", args=args)
-
- num_tasks = misc.get_world_size()
- global_rank = misc.get_rank()
-
- if not args.eval:
- sampler_train = torch.utils.data.DistributedSampler(
- dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
- )
- print(f"Sampler_train = {sampler_train}")
- if args.dist_eval:
- if len(dataset_val) % num_tasks != 0:
- print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
- sampler_val = torch.utils.data.DistributedSampler(
- dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
- )
- else:
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-
- if args.dist_eval:
- if len(dataset_test) % num_tasks != 0:
- print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
- sampler_test = torch.utils.data.DistributedSampler(
- dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
- )
- else:
- sampler_test = torch.utils.data.SequentialSampler(dataset_test)
-
- # ---- Logging
- if global_rank == 0 and args.log_dir is not None and not args.eval:
- os.makedirs(args.log_dir, exist_ok=True)
- log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
- else:
- log_writer = None
-
- # ---- DataLoaders
- if not args.eval:
- data_loader_train = torch.utils.data.DataLoader(
- dataset_train, sampler=sampler_train,
- batch_size=args.batch_size, num_workers=args.num_workers,
- pin_memory=args.pin_mem, drop_last=True,
- )
- print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
-
- data_loader_val = torch.utils.data.DataLoader(
- dataset_val, sampler=sampler_val,
- batch_size=args.batch_size, num_workers=args.num_workers,
- pin_memory=args.pin_mem, drop_last=False,
- )
-
- data_loader_test = torch.utils.data.DataLoader(
- dataset_test, sampler=sampler_test,
- batch_size=args.batch_size, num_workers=args.num_workers,
- pin_memory=args.pin_mem, drop_last=False,
- )
-
- # ---- Mixup/CutMix
- mixup_fn = None
- mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
- if mixup_active:
- print("Mixup is activated!")
- mixup_fn = Mixup(
- mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
- prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
- label_smoothing=args.smoothing, num_classes=args.nb_classes
- )
-
- # ---- Eval-only: resume weights
- if args.resume and args.eval:
- checkpoint = torch.load(args.resume, map_location="cpu")
- print(f"Load checkpoint for eval from: {args.resume}")
- model.load_state_dict(checkpoint["model"])
-
- model.to(device)
- model_without_ddp = model
-
- # ---- Adaptation toggle
- if args.adaptation == "lp":
- for name, param in model.named_parameters():
- param.requires_grad = ("head" in name)
- print("[Adaptation] Linear probe: training classifier head only.")
- else:
- print("[Adaptation] Full fine-tuning: training all parameters.")
-
- # ---- Count trainable params
- n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
-
- # ---- LR scaling by effective batch size
- eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
- if args.lr is None:
- args.lr = args.blr * eff_batch_size / 256
- print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
- print(f"actual lr: {args.lr:.2e}")
- print(f"accumulate grad iterations: {args.accum_iter}")
- print(f"effective batch size: {eff_batch_size}")
-
- # ---- DDP (if available)
- if args.distributed and torch.cuda.device_count() > 1:
- ddp_kwargs = {}
- if args.adaptation == "lp":
- ddp_kwargs["find_unused_parameters"] = True
- model = torch.nn.parallel.DistributedDataParallel(
- model, device_ids=[args.gpu], **ddp_kwargs
- )
- model_without_ddp = model.module
- else:
- model_without_ddp = model # single-GPU
-
- # ---- Optimizer param groups (after freezing)
- no_weight_decay = (model_without_ddp.no_weight_decay()
- if hasattr(model_without_ddp, "no_weight_decay") else [])
-
-
- param_groups = lrd.param_groups_lrd(
- model_without_ddp,
- weight_decay=args.weight_decay,
- no_weight_decay_list=no_weight_decay,
- layer_decay=args.layer_decay,
- )
- for g in param_groups:
- g["params"] = [p for p in g["params"] if p.requires_grad]
-
- optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
- loss_scaler = NativeScaler()
- print(f"criterion = {criterion}")
-
- # ---- Load previous full state (optimizer, scaler, etc.)
- misc.load_model(args=args, model_without_ddp=model_without_ddp,
- optimizer=optimizer, loss_scaler=loss_scaler)
-
- # =========================
- # Eval-only Short Circuit
- # =========================
- if args.eval:
- if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
- print(f"Test with the best model at epoch = {checkpoint['epoch']}")
- test_stats, auc_roc = evaluate(
- data_loader_test, model, device, args, epoch=0, mode="test",
- num_class=args.nb_classes, log_writer=log_writer
- )
- return
-
- # =========================
- # Train Loop
- # =========================
- print(f"Start training for {args.epochs} epochs")
- start_time = time.time()
- max_score = 0.0
- best_epoch = 0
-
- for epoch in range(args.start_epoch, args.epochs):
- if args.distributed:
- data_loader_train.sampler.set_epoch(epoch)
-
- train_stats = train_one_epoch(
- model, criterion, data_loader_train,
- optimizer, device, epoch, loss_scaler,
- args.clip_grad, mixup_fn,
- log_writer=log_writer, args=args
- )
-
- val_stats, val_score = evaluate(
- data_loader_val, model, device, args, epoch, mode="val",
- num_class=args.nb_classes, log_writer=log_writer
- )
-
- if max_score < val_score:
- max_score = val_score
- best_epoch = epoch
- if args.output_dir and args.savemodel:
- misc.save_model(
- args=args, model=model, model_without_ddp=model_without_ddp,
- optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
- )
- print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
-
- if log_writer is not None:
- log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
- log_writer.flush()
-
- log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
- "epoch": epoch,
- "n_parameters": n_parameters}
-
- if args.output_dir and misc.is_main_process():
- with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
- f.write(json.dumps(log_stats) + "\n")
-
- # =========================
- # Final Test (Best Ckpt)
- # =========================
- ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
- model.to(device)
- print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
- _test_stats, _auc_roc = evaluate(
- data_loader_test, model, device, args, -1, mode="test",
- num_class=args.nb_classes, log_writer=None
- )
-
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print(f"Training time {total_time_str}")
-
-
-if __name__ == "__main__":
- args = get_args_parser()
- args = args.parse_args()
-
- criterion = torch.nn.CrossEntropyLoss()
-
- if args.output_dir:
- Path(args.output_dir).mkdir(parents=True, exist_ok=True)
-
- main(args, criterion)
+#!/usr/bin/env python3
+
+# =========================
+import argparse
+import datetime
+import json
+import os
+import time
+from pathlib import Path
+import warnings
+import faulthandler
+
+# =========================
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+from timm.models.layers import trunc_normal_
+from timm.data.mixup import Mixup
+from huggingface_hub import hf_hub_download, login # login imported as in original
+
+# =========================
+import models_vit as models
+import util.lr_decay as lrd
+import util.misc as misc
+from util.datasets import build_dataset
+from util.pos_embed import interpolate_pos_embed
+from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from engine_finetune import train_one_epoch, evaluate
+
+# =========================
+faulthandler.enable()
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser(
+ "MAE fine-tuning / linear probing for image classification", add_help=False
+ )
+
+ # ---- Core training
+ parser.add_argument("--batch_size", default=128, type=int,
+ help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
+ parser.add_argument("--epochs", default=50, type=int)
+ parser.add_argument("--accum_iter", default=1, type=int,
+ help="Gradient accumulation steps")
+
+ # ---- Model parameters
+ parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
+ help="Model entry in models_vit.py")
+ parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
+ help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
+ parser.add_argument("--input_size", default=256, type=int, help="Image size")
+ parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
+ parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
+ parser.add_argument("--cls_token", action="store_false", dest="global_pool",
+ help="Use class token instead of global pool for classification")
+
+ # ---- Optimizer parameters
+ parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
+ parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
+ parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
+ parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
+ help="Base LR: lr = blr * total_batch_size / 256")
+ parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
+ parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
+ parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
+
+ # ---- Augmentation
+ parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
+ parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
+ parser.add_argument("--smoothing", type=float, default=0.1)
+
+ # ---- Random erase
+ parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
+ parser.add_argument("--remode", type=str, default="pixel")
+ parser.add_argument("--recount", type=int, default=1)
+ parser.add_argument("--resplit", action="store_true", default=False)
+
+ # ---- Mixup/Cutmix
+ parser.add_argument("--mixup", type=float, default=0.0)
+ parser.add_argument("--cutmix", type=float, default=0.0)
+ parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
+ parser.add_argument("--mixup_prob", type=float, default=1.0)
+ parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
+ parser.add_argument("--mixup_mode", type=str, default="batch")
+
+ # ---- Finetuning & adaptation
+ parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
+ parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
+ parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
+ help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
+
+ # ---- Dataset & paths
+ parser.add_argument("--data_path", default="./data/", type=str)
+ parser.add_argument("--nb_classes", default=8, type=int)
+ parser.add_argument("--output_dir", default="./output_dir")
+ parser.add_argument("--log_dir", default="./output_logs")
+
+ # >>> NEW: training data efficiency <<<
+ parser.add_argument(
+ "--dataratio", type=str, default="1.0",
+ help=('Training data ratio(s) for subsampling in build_dataset. '
+ 'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
+ '(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
+ )
+ parser.add_argument(
+ "--stratified", action="store_true",
+ help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
+ )
+
+ # ---- Runtime
+ parser.add_argument("--device", default="cuda")
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
+ parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
+ parser.add_argument("--eval", action="store_true", help="Evaluation only")
+ parser.add_argument("--dist_eval", action="store_true", default=False,
+ help="Distributed evaluation (faster monitoring during training)")
+ parser.add_argument("--num_workers", default=10, type=int)
+ parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
+
+ # ---- Distributed
+ parser.add_argument("--world_size", default=1, type=int)
+ parser.add_argument("--local_rank", default=-1, type=int)
+ parser.add_argument("--dist_on_itp", action="store_true")
+ parser.add_argument("--dist_url", default="env://")
+
+ # ---- Misc
+ parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
+ parser.add_argument("--norm", default="IMAGENET", type=str)
+ parser.add_argument("--enhance", action="store_true", default=False)
+ parser.add_argument("--datasets_seed", default=2026, type=int)
+
+ return parser
+
+
+# =========================
+# Main
+# =========================
+def main(args, criterion):
+ # ---- Optionally load args from resume (when training)
+ if args.resume and not args.eval:
+ resume_path = args.resume
+ checkpoint = torch.load(args.resume, map_location="cpu")
+ print(f"Load checkpoint (args) from: {args.resume}")
+ args = checkpoint["args"]
+ args.resume = resume_path
+
+ # ---- Distributed setup
+ misc.init_distributed_mode(args)
+
+ print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
+ print(f"{args}".replace(", ", ",\n"))
+
+ device = torch.device(args.device)
+
+ # ---- Reproducibility
+ seed = args.seed + misc.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ cudnn.benchmark = True
+
+ # ---- Build model
+ if args.model == "RETFound_mae":
+ model = models.__dict__[args.model](
+ img_size=args.input_size,
+ num_classes=args.nb_classes,
+ drop_path_rate=args.drop_path,
+ global_pool=args.global_pool,
+ )
+ else:
+ model = models.__dict__[args.model](
+ num_classes=args.nb_classes,
+ drop_path_rate=args.drop_path,
+ args=args,
+ )
+
+ # ---- Load pre-trained weights (if requested and not eval-only)
+ if args.finetune and not args.eval:
+ print(f"Preparing to load pre-trained weights: {args.finetune}")
+
+ if args.model in ["Dinov3", "Dinov2"]:
+ checkpoint_path = args.finetune # local path
+ elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
+ print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
+ checkpoint_path = hf_hub_download(
+ repo_id=f"YukunZhou/{args.finetune}",
+ filename=f"{args.finetune}.pth",
+ )
+ else:
+ raise ValueError(
+ f"Unsupported model '{args.model}'. "
+ f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
+ )
+
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
+
+ if args.model in ["Dinov3", "Dinov2"]:
+ checkpoint_model = checkpoint
+ elif args.model == "RETFound_dinov2":
+ checkpoint_model = checkpoint["teacher"]
+ else: # RETFound_mae
+ checkpoint_model = checkpoint["model"]
+
+ # -- Key hygiene
+ checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
+ checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
+ checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
+
+ # -- Remove classifier if shape mismatched
+ state_dict = model.state_dict()
+ for k in ["head.weight", "head.bias"]:
+ if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+ print(f"Removing key {k} from pretrained checkpoint")
+ del checkpoint_model[k]
+
+ # -- Interpolate pos embed (ViT)
+ interpolate_pos_embed(model, checkpoint_model)
+
+ # -- Load backbone weights (non-strict)
+ _ = model.load_state_dict(checkpoint_model, strict=False)
+
+ # -- Re-init head
+ if hasattr(model, "head") and hasattr(model.head, "weight"):
+ trunc_normal_(model.head.weight, std=2e-5)
+
+ # ---- Datasets & samplers
+ dataset_train = build_dataset(is_train="train", args=args)
+ dataset_val = build_dataset(is_train="val", args=args)
+ dataset_test = build_dataset(is_train="test", args=args)
+
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
+
+ if not args.eval:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print(f"Sampler_train = {sampler_train}")
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if args.dist_eval:
+ if len(dataset_test) % num_tasks != 0:
+ print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
+ sampler_test = torch.utils.data.DistributedSampler(
+ dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ else:
+ sampler_test = torch.utils.data.SequentialSampler(dataset_test)
+
+ # ---- Logging
+ if global_rank == 0 and args.log_dir is not None and not args.eval:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
+ else:
+ log_writer = None
+
+ # ---- DataLoaders
+ if not args.eval:
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size, num_workers=args.num_workers,
+ pin_memory=args.pin_mem, drop_last=True,
+ )
+ print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=args.batch_size, num_workers=args.num_workers,
+ pin_memory=args.pin_mem, drop_last=False,
+ )
+
+ data_loader_test = torch.utils.data.DataLoader(
+ dataset_test, sampler=sampler_test,
+ batch_size=args.batch_size, num_workers=args.num_workers,
+ pin_memory=args.pin_mem, drop_last=False,
+ )
+
+ # ---- Mixup/CutMix
+ mixup_fn = None
+ mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
+ if mixup_active:
+ print("Mixup is activated!")
+ mixup_fn = Mixup(
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes
+ )
+
+ # ---- Eval-only: resume weights
+ if args.resume and args.eval:
+ checkpoint = torch.load(args.resume, map_location="cpu")
+ print(f"Load checkpoint for eval from: {args.resume}")
+ model.load_state_dict(checkpoint["model"])
+
+ model.to(device)
+ model_without_ddp = model
+
+ # ---- Adaptation toggle
+ if args.adaptation == "lp":
+ for name, param in model.named_parameters():
+ param.requires_grad = ("head" in name)
+ print("[Adaptation] Linear probe: training classifier head only.")
+ else:
+ print("[Adaptation] Full fine-tuning: training all parameters.")
+
+ # ---- Count trainable params
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
+
+ # ---- LR scaling by effective batch size
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+ if args.lr is None:
+ args.lr = args.blr * eff_batch_size / 256
+ print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
+ print(f"actual lr: {args.lr:.2e}")
+ print(f"accumulate grad iterations: {args.accum_iter}")
+ print(f"effective batch size: {eff_batch_size}")
+
+ # ---- DDP (if available)
+ if args.distributed and torch.cuda.device_count() > 1:
+ ddp_kwargs = {}
+ if args.adaptation == "lp":
+ ddp_kwargs["find_unused_parameters"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[args.gpu], **ddp_kwargs
+ )
+ model_without_ddp = model.module
+ else:
+ model_without_ddp = model # single-GPU
+
+ # ---- Optimizer param groups (after freezing)
+ no_weight_decay = (model_without_ddp.no_weight_decay()
+ if hasattr(model_without_ddp, "no_weight_decay") else [])
+
+
+ param_groups = lrd.param_groups_lrd(
+ model_without_ddp,
+ weight_decay=args.weight_decay,
+ no_weight_decay_list=no_weight_decay,
+ layer_decay=args.layer_decay,
+ )
+ for g in param_groups:
+ g["params"] = [p for p in g["params"] if p.requires_grad]
+
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
+ loss_scaler = NativeScaler()
+ print(f"criterion = {criterion}")
+
+ # ---- Load previous full state (optimizer, scaler, etc.)
+ misc.load_model(args=args, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler)
+
+ # =========================
+ # Eval-only Short Circuit
+ # =========================
+ if args.eval:
+ if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
+ print(f"Test with the best model at epoch = {checkpoint['epoch']}")
+ test_stats, auc_roc = evaluate(
+ data_loader_test, model, device, args, epoch=0, mode="test",
+ num_class=args.nb_classes, log_writer=log_writer
+ )
+ return
+
+ # =========================
+ # Train Loop
+ # =========================
+ print(f"Start training for {args.epochs} epochs")
+ start_time = time.time()
+ max_score = 0.0
+ best_epoch = 0
+
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train,
+ optimizer, device, epoch, loss_scaler,
+ args.clip_grad, mixup_fn,
+ log_writer=log_writer, args=args
+ )
+
+ val_stats, val_score = evaluate(
+ data_loader_val, model, device, args, epoch, mode="val",
+ num_class=args.nb_classes, log_writer=log_writer
+ )
+
+ if max_score < val_score:
+ max_score = val_score
+ best_epoch = epoch
+ if args.output_dir and args.savemodel:
+ misc.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
+ )
+ print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
+
+ if log_writer is not None:
+ log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
+ log_writer.flush()
+
+ log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
+ "epoch": epoch,
+ "n_parameters": n_parameters}
+
+ if args.output_dir and misc.is_main_process():
+ with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ # =========================
+ # Final Test (Best Ckpt)
+ # =========================
+ ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
+ model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
+ _test_stats, _auc_roc = evaluate(
+ data_loader_test, model, device, args, -1, mode="test",
+ num_class=args.nb_classes, log_writer=None
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(f"Training time {total_time_str}")
+
+
+if __name__ == "__main__":
+ args = get_args_parser()
+ args = args.parse_args()
+
+ criterion = torch.nn.CrossEntropyLoss()
+
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ main(args, criterion)
diff --git a/main_segmentation.py b/main_segmentation.py
new file mode 100644
index 00000000..2b2c257c
--- /dev/null
+++ b/main_segmentation.py
@@ -0,0 +1,227 @@
+import os
+import argparse
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import cv2
+from albumentations import Compose, Resize, Normalize
+from albumentations.pytorch import ToTensorV2
+from huggingface_hub import hf_hub_download
+from util.pos_embed import interpolate_pos_embed
+
+from models_segmentation import RETFoundSegmentation
+from engine_segmentation import (
+ train_segmentation,
+ evaluate_segmentation,
+ combined_loss_fn,
+ compute_metrics,
+)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================
+# Dataset Definition
+# ============================================================
+class OCTDrusenDataset(Dataset):
+ """
+ Dataset for OCT drusen segmentation.
+
+ Expected structure:
+ root/
+ ├── images/
+ └── masks/ (same filename + _mask.png)
+ """
+
+ def __init__(self, root, transform=None):
+
+ self.image_dir = os.path.join(root, "images")
+ self.mask_dir = os.path.join(root, "masks")
+ self.transform = transform
+
+ # Collect valid image–mask pairs
+ self.samples = []
+ for img_name in sorted(os.listdir(self.image_dir)):
+ stem = os.path.splitext(img_name)[0]
+ mask_name = stem + "_mask.png"
+ mask_path = os.path.join(self.mask_dir, mask_name)
+
+ if os.path.isfile(mask_path):
+ self.samples.append((img_name, mask_name))
+
+ print(f"Loaded {len(self.samples)} valid samples from {root}")
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, idx):
+ """
+ Load image and mask and apply transforms.
+ """
+
+ img_name, mask_name = self.samples[idx]
+
+ # ----- Image -----
+ img = cv2.imread(os.path.join(self.image_dir, img_name))
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ----- Mask -----
+ mask = cv2.imread(
+ os.path.join(self.mask_dir, mask_name),
+ cv2.IMREAD_GRAYSCALE
+ )
+
+ mask = (mask > 0).astype("uint8")
+
+ # ----- Transform -----
+ if self.transform:
+ aug = self.transform(image=img, mask=mask)
+ img, mask = aug["image"], aug["mask"]
+
+ return img, torch.tensor(mask, dtype=torch.long)
+
+
+# ============================================================
+# Main Training Script
+# ============================================================
+def main():
+
+ # --------------------------------------------------------
+ # Arguments
+ # --------------------------------------------------------
+ parser = argparse.ArgumentParser("RETFound Segmentation")
+
+ parser.add_argument("--data_path", type=str, required=True)
+ parser.add_argument("--epochs", type=int, default=50)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--lr", type=float, default=1e-4)
+
+ parser.add_argument("--img_size", type=int, default=512)
+ parser.add_argument("--patch_size", type=int, default=16)
+ parser.add_argument("--drop_path", type=float, default=0.2)
+
+ parser.add_argument("--finetune", type=str, default="")
+ parser.add_argument("--output_dir", type=str, default="./segmentation_output")
+
+ parser.add_argument("--dice_weight", type=float, default=1.0)
+ parser.add_argument("--ce_weight", type=str, default="0.3,0.7")
+
+ args = parser.parse_args()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # --------------------------------------------------------
+ # Transform
+ # --------------------------------------------------------
+ transform = Compose([
+ Resize(args.img_size, args.img_size),
+ Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)),
+ ToTensorV2()
+ ])
+
+ # --------------------------------------------------------
+ # Datasets & Loaders (ONLY TRAIN + VAL)
+ # --------------------------------------------------------
+ train_ds = OCTDrusenDataset(os.path.join(args.data_path, "train"), transform)
+ val_ds = OCTDrusenDataset(os.path.join(args.data_path, "val"), transform)
+
+ train_loader = DataLoader(train_ds, args.batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_ds, args.batch_size, shuffle=False, num_workers=4)
+
+ # --------------------------------------------------------
+ # Model
+ # --------------------------------------------------------
+ model = RETFoundSegmentation(
+ args.img_size,
+ args.patch_size,
+ num_classes=2,
+ drop_path=args.drop_path
+ ).to(device)
+
+ # --------------------------------------------------------
+ # Load Pretrained RETFound Encoder (optional)
+ # --------------------------------------------------------
+ if args.finetune:
+
+ if os.path.isfile(args.finetune):
+ ckpt_path = args.finetune
+ else:
+ ckpt_path = hf_hub_download(
+ repo_id=f"YukunZhou/{args.finetune}",
+ filename="pytorch_model.bin"
+ )
+
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+ state = state["model"] if "model" in state else state
+
+ # Remove classification head
+ for k in ["head.weight", "head.bias"]:
+ if k in state:
+ del state[k]
+
+ interpolate_pos_embed(model.encoder, state)
+ model.encoder.load_state_dict(state, strict=False)
+
+ print("Pretrained RETFound weights loaded.")
+
+ # --------------------------------------------------------
+ # Loss & Optimizer
+ # --------------------------------------------------------
+ ce_weights = torch.tensor(
+ [float(x) for x in args.ce_weight.split(",")]
+ ).to(device)
+
+ ce_loss = nn.CrossEntropyLoss(weight=ce_weights)
+
+ def loss_fn(out, tgt):
+ return combined_loss_fn(out, tgt, ce_loss, args.dice_weight)
+
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr)
+
+ print("\n[DEBUG] Starting training loop...\n")
+
+ # --------------------------------------------------------
+ # Training Loop (NO TEST DATA)
+ # --------------------------------------------------------
+ best = float("inf")
+
+ for e in range(args.epochs):
+
+ print(f"[DEBUG] Entered epoch {e+1}")
+
+ train_loss = train_segmentation(
+ model, train_loader, loss_fn, optimizer, device
+ )
+
+ val_loss, P, T = evaluate_segmentation(
+ model, val_loader, loss_fn, device
+ )
+
+ acc, dice, iou = compute_metrics(P, T)
+
+ print(
+ f"Epoch {e+1}: "
+ f"Train={train_loss:.4f} | "
+ f"Val={val_loss:.4f} | "
+ f"Dice={dice:.4f} | "
+ f"IoU={iou:.4f}"
+ )
+
+ # Save best model based on validation loss
+ if val_loss < best:
+ best = val_loss
+ torch.save(
+ model.state_dict(),
+ os.path.join(args.output_dir, "best.pth")
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/models_segmentation.py b/models_segmentation.py
new file mode 100644
index 00000000..9adef928
--- /dev/null
+++ b/models_segmentation.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from models_vit import RETFound_mae
+
+
+# ============================================================
+# Decoder / Segmentation Head
+# ============================================================
+class SegmentationHead(nn.Module):
+ """
+ Lightweight decoder that converts ViT patch embeddings
+ into full-resolution segmentation map.
+
+ Steps:
+ - Reshape sequence → 2D feature map
+ - Upsample to original image size
+ - Apply small CNN to produce class logits
+ """
+
+ def __init__(self, hidden_dim, num_classes, img_size, patch_size):
+ super().__init__()
+
+ # Patch geometry from ViT
+ self.patch_size = patch_size
+ self.h = img_size // patch_size
+ self.w = img_size // patch_size
+
+ # Simple convolutional decoder
+ self.conv = nn.Sequential(
+ # Reduce channel dimension
+ nn.Conv2d(hidden_dim, hidden_dim // 2, 3, padding=1),
+ nn.ReLU(inplace=True),
+
+ # Final layer → number of classes
+ nn.Conv2d(hidden_dim // 2, num_classes, 1),
+ )
+
+ def forward(self, x):
+ """
+ Args:
+ x: ViT token embeddings [B, N, C]
+ (without CLS token)
+
+ Returns:
+ Segmentation logits [B, num_classes, H, W]
+ """
+
+ B, N, C = x.shape
+
+ # Reshape sequence back to 2D feature map
+ x = x.reshape(B, self.h, self.w, C).permute(0, 3, 1, 2)
+
+ # Upsample from patch grid → image resolution
+ x = F.interpolate(
+ x,
+ scale_factor=self.patch_size,
+ mode="bilinear",
+ align_corners=False
+ )
+
+ # Apply conv decoder to get class logits
+ return self.conv(x)
+
+
+# ============================================================
+# Full RETFound + Decoder Model
+# ============================================================
+class RETFoundSegmentation(nn.Module):
+ """
+ Segmentation model built on top of RETFound MAE encoder.
+
+ Architecture:
+ RETFound ViT Encoder → SegmentationHead Decoder
+ """
+
+ def __init__(
+ self,
+ img_size=512,
+ patch_size=16,
+ hidden_dim=1024,
+ num_classes=2,
+ drop_path=0.2
+ ):
+ super().__init__()
+
+ # ----------------------------------------------------
+ # Encoder: pretrained RETFound ViT (MAE)
+ # ----------------------------------------------------
+ self.encoder = RETFound_mae(
+ img_size=img_size,
+ num_classes=num_classes,
+ drop_path_rate=drop_path,
+ global_pool=False # keep token sequence
+ )
+
+ # ----------------------------------------------------
+ # Decoder head for pixel prediction
+ # ----------------------------------------------------
+ self.seg_head = SegmentationHead(
+ hidden_dim,
+ num_classes,
+ img_size,
+ patch_size
+ )
+
+ def forward(self, x):
+ """
+ Forward pass:
+ 1. Patch embedding
+ 2. Add CLS token
+ 3. Positional embedding
+ 4. Transformer blocks
+ 5. Decoder head
+ """
+
+ B = x.size(0)
+
+ # ----- Patch embedding -----
+ x = self.encoder.patch_embed(x)
+
+ # ----- Add CLS token -----
+ cls = self.encoder.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls, x), dim=1)
+
+ # ----- Positional encoding -----
+ x = x + self.encoder.pos_embed
+ x = self.encoder.pos_drop(x)
+
+ # ----- Transformer encoder blocks -----
+ for blk in self.encoder.blocks:
+ x = blk(x)
+
+ # ----- Final normalization -----
+ x = self.encoder.norm(x)
+
+ # ----- Remove CLS token & decode -----
+ return self.seg_head(x[:, 1:])
diff --git a/models_vit.py b/models_vit.py
index 82e7fbdb..6640aa80 100644
--- a/models_vit.py
+++ b/models_vit.py
@@ -1,105 +1,105 @@
-
-from functools import partial
-
-import timm.models.vision_transformer
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from timm.models.layers import trunc_normal_
-
-class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
- """ Vision Transformer with support for global average pooling
- """
- def __init__(self, global_pool=False, **kwargs):
- super(VisionTransformer, self).__init__(**kwargs)
-
- self.global_pool = global_pool
- if self.global_pool:
- norm_layer = kwargs['norm_layer']
- embed_dim = kwargs['embed_dim']
- self.fc_norm = norm_layer(embed_dim)
-
- del self.norm # remove the original norm
-
- def forward_features(self, x):
- B = x.shape[0]
- x = self.patch_embed(x)
-
- cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
- x = torch.cat((cls_tokens, x), dim=1)
- x = x + self.pos_embed
- x = self.pos_drop(x)
-
- for blk in self.blocks:
- x = blk(x)
-
- if self.global_pool:
- x = x[:, 1:, :].mean(dim=1,keepdim=True) # global pool without cls token
- outcome = self.fc_norm(x)
- else:
- x = self.norm(x)
- outcome = x[:, 0]
-
- return outcome
-
-
-def RETFound_mae(**kwargs):
- model = VisionTransformer(
- patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
- norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
- return model
-
-
-
-def Dinov2(args, **kwargs):
-
- if args.model_arch == 'dinov2_vits14':
- arch = 'vit_small_patch14_dinov2.lvd142m'
- elif args.model_arch == 'dinov2_vitb14':
- arch = 'vit_base_patch14_dinov2.lvd142m'
- elif args.model_arch == 'dinov2_vitl14':
- arch = 'vit_large_patch14_dinov2.lvd142m'
- elif args.model_arch == 'dinov2_vitg14':
- arch = 'vit_giant_patch14_dinov2.lvd142m'
- else:
- raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
- f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
-
- model = timm.create_model(
- arch,
- pretrained=True,
- img_size=224,
- **kwargs
- )
- return model
-
-
-
-def RETFound_dinov2(args, **kwargs):
- model = timm.create_model(
- 'vit_large_patch14_dinov2.lvd142m',
- pretrained=True,
- img_size=224,
- **kwargs
- )
- return model
-
-
-def Dinov3(args, **kwargs):
- # Load ViT-L/16 backbone (hub model has `head = Identity` by default)
- model = torch.hub.load(
- repo_or_dir="facebookresearch/dinov3",
- model=args.model_arch,
- pretrained=False, # main() will load your checkpoint
- trust_repo=True,
- )
-
- # Figure out feature dimension for the probe
- feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
- model.head = nn.Linear(feat_dim, args.nb_classes)
- trunc_normal_(model.head.weight, std=2e-5)
- if model.head.bias is not None:
- nn.init.zeros_(model.head.bias)
-
- return model
+
+from functools import partial
+
+import timm.models.vision_transformer
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from timm.models.layers import trunc_normal_
+
+class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+ """ Vision Transformer with support for global average pooling
+ """
+ def __init__(self, global_pool=False, **kwargs):
+ super(VisionTransformer, self).__init__(**kwargs)
+
+ self.global_pool = global_pool
+ if self.global_pool:
+ norm_layer = kwargs['norm_layer']
+ embed_dim = kwargs['embed_dim']
+ self.fc_norm = norm_layer(embed_dim)
+
+ del self.norm # remove the original norm
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.global_pool:
+ x = x[:, 1:, :].mean(dim=1,keepdim=True) # global pool without cls token
+ outcome = self.fc_norm(x)
+ else:
+ x = self.norm(x)
+ outcome = x[:, 0]
+
+ return outcome
+
+
+def RETFound_mae(**kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ return model
+
+
+
+def Dinov2(args, **kwargs):
+
+ if args.model_arch == 'dinov2_vits14':
+ arch = 'vit_small_patch14_dinov2.lvd142m'
+ elif args.model_arch == 'dinov2_vitb14':
+ arch = 'vit_base_patch14_dinov2.lvd142m'
+ elif args.model_arch == 'dinov2_vitl14':
+ arch = 'vit_large_patch14_dinov2.lvd142m'
+ elif args.model_arch == 'dinov2_vitg14':
+ arch = 'vit_giant_patch14_dinov2.lvd142m'
+ else:
+ raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
+ f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
+
+ model = timm.create_model(
+ arch,
+ pretrained=True,
+ img_size=224,
+ **kwargs
+ )
+ return model
+
+
+
+def RETFound_dinov2(args, **kwargs):
+ model = timm.create_model(
+ 'vit_large_patch14_dinov2.lvd142m',
+ pretrained=True,
+ img_size=224,
+ **kwargs
+ )
+ return model
+
+
+def Dinov3(args, **kwargs):
+ # Load ViT-L/16 backbone (hub model has `head = Identity` by default)
+ model = torch.hub.load(
+ repo_or_dir="facebookresearch/dinov3",
+ model=args.model_arch,
+ pretrained=False, # main() will load your checkpoint
+ trust_repo=True,
+ )
+
+ # Figure out feature dimension for the probe
+ feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
+ model.head = nn.Linear(feat_dim, args.nb_classes)
+ trunc_normal_(model.head.weight, std=2e-5)
+ if model.head.bias is not None:
+ nn.init.zeros_(model.head.bias)
+
+ return model
diff --git a/requirements.txt b/requirements.txt
index 0c2b9399..6f012a00 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,12 @@
-opencv-python~=4.9.0.80
-Pillow~=10.2.0
-pycm~=4.0
-scikit-learn~=1.4.2
-timm~=0.9.2
-
-numpy~=1.26.4
-matplotlib~=3.8.4
-scikit-multilearn~=0.2.0
-huggingface-hub~=0.23.4
-tensorboard~=2.17.0
\ No newline at end of file
+opencv-python~=4.9.0.80
+Pillow~=10.2.0
+pycm~=4.0
+scikit-learn~=1.4.2
+timm~=0.9.2
+
+numpy~=1.26.4
+matplotlib~=3.8.4
+scikit-multilearn~=0.2.0
+huggingface-hub~=0.23.4
+tensorboard~=2.17.0
+albumentations~=1.4.3
\ No newline at end of file
diff --git a/tree.txt b/tree.txt
new file mode 100644
index 00000000..ad8561c1
Binary files /dev/null and b/tree.txt differ
diff --git a/util/datasets.py b/util/datasets.py
index 20f65f20..7458331f 100644
--- a/util/datasets.py
+++ b/util/datasets.py
@@ -1,81 +1,81 @@
-import os
-import torch
-from torch.utils.data import Subset
-from torchvision import datasets, transforms
-from timm.data import create_transform
-from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
-def build_dataset(is_train, args):
- transform = build_transform(is_train, args)
- root = os.path.join(args.data_path, is_train)
- dataset = datasets.ImageFolder(root, transform=transform)
-
- if is_train == 'train':
- ratio = float(getattr(args, "dataratio", 1.0))
- seed = int(getattr(args, "seed", 0))
- stratified = bool(getattr(args, "stratified", False))
-
- if 0.0 < ratio < 1.0:
- if stratified:
- idx = _stratified_indices(dataset.targets, ratio, seed)
- else:
- # simple uniform subsample with torch.Generator for reproducibility
- g = torch.Generator().manual_seed(seed)
- n = len(dataset)
- k = max(1, int(n * ratio))
- idx = torch.randperm(n, generator=g)[:k].tolist()
- dataset = Subset(dataset, idx)
-
- return dataset
-
-def build_transform(is_train, args):
- mean = IMAGENET_DEFAULT_MEAN
- std = IMAGENET_DEFAULT_STD
-
- if is_train == 'train':
- return create_transform(
- input_size=args.input_size,
- is_training=True,
- color_jitter=args.color_jitter,
- auto_augment=args.aa,
- interpolation='bicubic',
- re_prob=args.reprob,
- re_mode=args.remode,
- re_count=args.recount,
- mean=mean,
- std=std,
- )
-
- # eval transform
- crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
- size = int(args.input_size / crop_pct)
- t = [
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
- transforms.CenterCrop(args.input_size),
- transforms.ToTensor(),
- transforms.Normalize(mean, std),
- ]
- return transforms.Compose(t)
-
-# ---- helpers ----
-
-def _stratified_indices(targets, ratio: float, seed: int):
- """Maintain class proportions. Ensures at least 1 sample per class when possible."""
- t = torch.as_tensor(targets)
- classes = torch.unique(t)
- g = torch.Generator().manual_seed(seed)
-
- keep = []
- for c in classes.tolist():
- cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
- if len(cls_idx) == 0:
- continue
- k = max(1, int(round(len(cls_idx) * ratio)))
- sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
- keep.extend(sel.tolist())
-
- # shuffle final indices (stable across seed)
- g2 = torch.Generator().manual_seed(seed + 1)
- keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
- return keep
-
+import os
+import torch
+from torch.utils.data import Subset
+from torchvision import datasets, transforms
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+def build_dataset(is_train, args):
+ transform = build_transform(is_train, args)
+ root = os.path.join(args.data_path, is_train)
+ dataset = datasets.ImageFolder(root, transform=transform)
+
+ if is_train == 'train':
+ ratio = float(getattr(args, "dataratio", 1.0))
+ seed = int(getattr(args, "seed", 0))
+ stratified = bool(getattr(args, "stratified", False))
+
+ if 0.0 < ratio < 1.0:
+ if stratified:
+ idx = _stratified_indices(dataset.targets, ratio, seed)
+ else:
+ # simple uniform subsample with torch.Generator for reproducibility
+ g = torch.Generator().manual_seed(seed)
+ n = len(dataset)
+ k = max(1, int(n * ratio))
+ idx = torch.randperm(n, generator=g)[:k].tolist()
+ dataset = Subset(dataset, idx)
+
+ return dataset
+
+def build_transform(is_train, args):
+ mean = IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_DEFAULT_STD
+
+ if is_train == 'train':
+ return create_transform(
+ input_size=args.input_size,
+ is_training=True,
+ color_jitter=args.color_jitter,
+ auto_augment=args.aa,
+ interpolation='bicubic',
+ re_prob=args.reprob,
+ re_mode=args.remode,
+ re_count=args.recount,
+ mean=mean,
+ std=std,
+ )
+
+ # eval transform
+ crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
+ size = int(args.input_size / crop_pct)
+ t = [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.CenterCrop(args.input_size),
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ ]
+ return transforms.Compose(t)
+
+# ---- helpers ----
+
+def _stratified_indices(targets, ratio: float, seed: int):
+ """Maintain class proportions. Ensures at least 1 sample per class when possible."""
+ t = torch.as_tensor(targets)
+ classes = torch.unique(t)
+ g = torch.Generator().manual_seed(seed)
+
+ keep = []
+ for c in classes.tolist():
+ cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
+ if len(cls_idx) == 0:
+ continue
+ k = max(1, int(round(len(cls_idx) * ratio)))
+ sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
+ keep.extend(sel.tolist())
+
+ # shuffle final indices (stable across seed)
+ g2 = torch.Generator().manual_seed(seed + 1)
+ keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
+ return keep
+
diff --git a/util/lr_decay.py b/util/lr_decay.py
index 652fcd87..c915491b 100644
--- a/util/lr_decay.py
+++ b/util/lr_decay.py
@@ -1,74 +1,74 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# Partly revised by YZ @UCL&Moorfields
-# --------------------------------------------------------
-
-import json
-
-
-def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
- """
- Parameter groups for layer-wise lr decay
- Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
- """
- param_group_names = {}
- param_groups = {}
-
- if hasattr(model, 'blocks'):
- num_layers = len(model.blocks) + 1
- else:
- # use the number of layers in the ResNet model as a default value
- num_layers = len(model.layer1) + len(model.layer2) + len(model.layer3) + len(model.layer4) + 1
-
- layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
-
- for n, p in model.named_parameters():
- if not p.requires_grad:
- continue
-
- # no decay: all 1D parameters and model specific ones
- if p.ndim == 1 or n in no_weight_decay_list:
- g_decay = "no_decay"
- this_decay = 0.
- else:
- g_decay = "decay"
- this_decay = weight_decay
-
- layer_id = get_layer_id_for_vit(n, num_layers)
- group_name = "layer_%d_%s" % (layer_id, g_decay)
-
- if group_name not in param_group_names:
- this_scale = layer_scales[layer_id]
-
- param_group_names[group_name] = {
- "lr_scale": this_scale,
- "weight_decay": this_decay,
- "params": [],
- }
- param_groups[group_name] = {
- "lr_scale": this_scale,
- "weight_decay": this_decay,
- "params": [],
- }
-
- param_group_names[group_name]["params"].append(n)
- param_groups[group_name]["params"].append(p)
-
- # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
-
- return list(param_groups.values())
-
-
-def get_layer_id_for_vit(name, num_layers):
- """
- Assign a parameter with its layer id
- Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
- """
- if name in ['cls_token', 'pos_embed']:
- return 0
- elif name.startswith('patch_embed'):
- return 0
- elif name.startswith('blocks'):
- return int(name.split('.')[1]) + 1
- else:
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Partly revised by YZ @UCL&Moorfields
+# --------------------------------------------------------
+
+import json
+
+
+def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
+ """
+ Parameter groups for layer-wise lr decay
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+ """
+ param_group_names = {}
+ param_groups = {}
+
+ if hasattr(model, 'blocks'):
+ num_layers = len(model.blocks) + 1
+ else:
+ # use the number of layers in the ResNet model as a default value
+ num_layers = len(model.layer1) + len(model.layer2) + len(model.layer3) + len(model.layer4) + 1
+
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+
+ # no decay: all 1D parameters and model specific ones
+ if p.ndim == 1 or n in no_weight_decay_list:
+ g_decay = "no_decay"
+ this_decay = 0.
+ else:
+ g_decay = "decay"
+ this_decay = weight_decay
+
+ layer_id = get_layer_id_for_vit(n, num_layers)
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+ if group_name not in param_group_names:
+ this_scale = layer_scales[layer_id]
+
+ param_group_names[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+ param_groups[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+
+ param_group_names[group_name]["params"].append(n)
+ param_groups[group_name]["params"].append(p)
+
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+ return list(param_groups.values())
+
+
+def get_layer_id_for_vit(name, num_layers):
+ """
+ Assign a parameter with its layer id
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if name in ['cls_token', 'pos_embed']:
+ return 0
+ elif name.startswith('patch_embed'):
+ return 0
+ elif name.startswith('blocks'):
+ return int(name.split('.')[1]) + 1
+ else:
return num_layers
\ No newline at end of file
diff --git a/util/lr_sched.py b/util/lr_sched.py
index 178e1bfb..7a3a107c 100644
--- a/util/lr_sched.py
+++ b/util/lr_sched.py
@@ -1,20 +1,20 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# Partly revised by YZ @UCL&Moorfields
-# --------------------------------------------------------
-
-import math
-
-def adjust_learning_rate(optimizer, epoch, args):
- """Decay the learning rate with half-cycle cosine after warmup"""
- if epoch < args.warmup_epochs:
- lr = args.lr * epoch / args.warmup_epochs
- else:
- lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
- (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
- for param_group in optimizer.param_groups:
- if "lr_scale" in param_group:
- param_group["lr"] = lr * param_group["lr_scale"]
- else:
- param_group["lr"] = lr
- return lr
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Partly revised by YZ @UCL&Moorfields
+# --------------------------------------------------------
+
+import math
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
diff --git a/util/misc.py b/util/misc.py
index 47f7fde2..fd4ffb2a 100644
--- a/util/misc.py
+++ b/util/misc.py
@@ -1,369 +1,369 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# Partly revised by YZ @UCL&Moorfields
-# --------------------------------------------------------
-
-import builtins
-import datetime
-import os
-import time
-from collections import defaultdict, deque
-from pathlib import Path
-
-import torch
-import torch.distributed as dist
-from math import inf
-
-
-class SmoothedValue(object):
- """Track a series of values and provide access to smoothed values over a
- window or the global series average.
- """
-
- def __init__(self, window_size=20, fmt=None):
- if fmt is None:
- fmt = "{median:.4f} ({global_avg:.4f})"
- self.deque = deque(maxlen=window_size)
- self.total = 0.0
- self.count = 0
- self.fmt = fmt
-
- def update(self, value, n=1):
- self.deque.append(value)
- self.count += n
- self.total += value * n
-
- def synchronize_between_processes(self):
- """
- Warning: does not synchronize the deque!
- """
- if not is_dist_avail_and_initialized():
- return
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
- dist.barrier()
- dist.all_reduce(t)
- t = t.tolist()
- self.count = int(t[0])
- self.total = t[1]
-
- @property
- def median(self):
- d = torch.tensor(list(self.deque))
- return d.median().item()
-
- @property
- def avg(self):
- d = torch.tensor(list(self.deque), dtype=torch.float32)
- return d.mean().item()
-
- @property
- def global_avg(self):
- return self.total / self.count
-
- @property
- def max(self):
- return max(self.deque)
-
- @property
- def value(self):
- return self.deque[-1]
-
- def __str__(self):
- return self.fmt.format(
- median=self.median,
- avg=self.avg,
- global_avg=self.global_avg,
- max=self.max,
- value=self.value)
-
-
-class MetricLogger(object):
- def __init__(self, delimiter="\t"):
- self.meters = defaultdict(SmoothedValue)
- self.delimiter = delimiter
-
- def update(self, **kwargs):
- for k, v in kwargs.items():
- if v is None:
- continue
- if isinstance(v, torch.Tensor):
- v = v.item()
- assert isinstance(v, (float, int))
- self.meters[k].update(v)
-
- def __getattr__(self, attr):
- if attr in self.meters:
- return self.meters[attr]
- if attr in self.__dict__:
- return self.__dict__[attr]
- raise AttributeError("'{}' object has no attribute '{}'".format(
- type(self).__name__, attr))
-
- def __str__(self):
- loss_str = []
- for name, meter in self.meters.items():
- loss_str.append(
- "{}: {}".format(name, str(meter))
- )
- return self.delimiter.join(loss_str)
-
- def synchronize_between_processes(self):
- for meter in self.meters.values():
- meter.synchronize_between_processes()
-
- def add_meter(self, name, meter):
- self.meters[name] = meter
-
- def log_every(self, iterable, print_freq, header=None):
- i = 0
- if not header:
- header = ''
- start_time = time.time()
- end = time.time()
- iter_time = SmoothedValue(fmt='{avg:.4f}')
- data_time = SmoothedValue(fmt='{avg:.4f}')
- space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
- log_msg = [
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}'
- ]
- if torch.cuda.is_available():
- log_msg.append('max mem: {memory:.0f}')
- log_msg = self.delimiter.join(log_msg)
- MB = 1024.0 * 1024.0
- for obj in iterable:
- data_time.update(time.time() - end)
- yield obj
- iter_time.update(time.time() - end)
- if i % print_freq == 0 or i == len(iterable) - 1:
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
- if torch.cuda.is_available():
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time),
- memory=torch.cuda.max_memory_allocated() / MB))
- else:
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time)))
- i += 1
- end = time.time()
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print('{} Total time: {} ({:.4f} s / it)'.format(
- header, total_time_str, total_time / len(iterable)))
-
-
-def setup_for_distributed(is_master):
- """
- This function disables printing when not in master process
- """
- builtin_print = builtins.print
-
- def print(*args, **kwargs):
- force = kwargs.pop('force', False)
- force = force or (get_world_size() > 8)
- if is_master or force:
- now = datetime.datetime.now().time()
- builtin_print('[{}] '.format(now), end='') # print with time stamp
- builtin_print(*args, **kwargs)
-
- builtins.print = print
-
-
-def is_dist_avail_and_initialized():
- if not dist.is_available():
- return False
- if not dist.is_initialized():
- return False
- return True
-
-
-def get_world_size():
- if not is_dist_avail_and_initialized():
- return 1
- return dist.get_world_size()
-
-
-def get_rank():
- if not is_dist_avail_and_initialized():
- return 0
- return dist.get_rank()
-
-
-def is_main_process():
- return get_rank() == 0
-
-
-def save_on_master(*args, **kwargs):
- if is_main_process():
- torch.save(*args, **kwargs)
-
-
-def init_distributed_mode(args):
- if args.dist_on_itp:
- args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
- args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
- args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
- args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
- os.environ['LOCAL_RANK'] = str(args.gpu)
- os.environ['RANK'] = str(args.rank)
- os.environ['WORLD_SIZE'] = str(args.world_size)
- # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
- elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
- args.rank = int(os.environ["RANK"])
- args.world_size = int(os.environ['WORLD_SIZE'])
- args.gpu = int(os.environ['LOCAL_RANK'])
- elif 'SLURM_PROCID' in os.environ:
- args.rank = int(os.environ['SLURM_PROCID'])
- args.gpu = args.rank % torch.cuda.device_count()
- else:
- print('Not using distributed mode')
- setup_for_distributed(is_master=True) # hack
- args.distributed = False
- return
-
- args.distributed = True
-
- torch.cuda.set_device(args.gpu)
- args.dist_backend = 'nccl'
- print('| distributed init (rank {}): {}, gpu {}'.format(
- args.rank, args.dist_url, args.gpu), flush=True)
- torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
- world_size=args.world_size, rank=args.rank)
- torch.distributed.barrier()
- setup_for_distributed(args.rank == 0)
-
-
-class NativeScalerWithGradNormCount:
- state_dict_key = "amp_scaler"
-
- def __init__(self):
- self._scaler = torch.cuda.amp.GradScaler()
-
- def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
- self._scaler.scale(loss).backward(create_graph=create_graph)
- if update_grad:
- if clip_grad is not None:
- assert parameters is not None
- self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
- norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
- else:
- self._scaler.unscale_(optimizer)
- norm = get_grad_norm_(parameters)
- self._scaler.step(optimizer)
- self._scaler.update()
- else:
- norm = None
- return norm
-
- def state_dict(self):
- return self._scaler.state_dict()
-
- def load_state_dict(self, state_dict):
- self._scaler.load_state_dict(state_dict)
-
-
-def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
- parameters = [p for p in parameters if p.grad is not None]
- norm_type = float(norm_type)
- if len(parameters) == 0:
- return torch.tensor(0.)
- device = parameters[0].grad.device
- if norm_type == inf:
- total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
- else:
- total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
- norm_type)
- return total_norm
-
-
-def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, mode):
- output_dir = Path(args.output_dir)
- epoch_name = str(epoch)
- os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
- if loss_scaler is not None:
- if mode == 'best':
- checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-best.pth')]
- else:
- checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-latest.pth')]
- for checkpoint_path in checkpoint_paths:
- if mode == 'best':
- to_save = {
- 'model': model_without_ddp.state_dict(),
- 'epoch': epoch,
- 'args': args, }
- else:
- if epoch == args.epochs - 1:
- to_save = {
- 'model': model_without_ddp.state_dict(),
- 'args': args, }
- else:
- to_save = {
- 'model': model_without_ddp.state_dict(),
- 'optimizer': optimizer.state_dict(),
- 'epoch': epoch,
- 'scaler': loss_scaler.state_dict(),
- 'args': args,
- }
-
- save_on_master(to_save, checkpoint_path)
- else:
- if mode == 'best':
- to_save = {
- 'model': model_without_ddp.state_dict(),
- 'epoch': epoch, }
- torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-best.pth"))
- else:
- if epoch == args.epochs - 1:
- to_save = {
- 'model': model_without_ddp.state_dict(), }
- else:
- to_save = {
- 'model': model_without_ddp.state_dict(),
- 'optimizer': optimizer.state_dict(),
- 'epoch': epoch,
- 'args': args,
- }
- torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-latest.pth"))
-
-
-def load_model(args, model_without_ddp, optimizer, loss_scaler):
- if args.resume:
- if args.resume.startswith('https'):
- checkpoint = torch.hub.load_state_dict_from_url(
- args.resume, map_location='cpu', check_hash=True)
- else:
- checkpoint = torch.load(args.resume, map_location='cpu')
- if 'model' in checkpoint:
- checkpoint_model = checkpoint['model']
- else:
- checkpoint_model = checkpoint
- model_without_ddp.load_state_dict(checkpoint_model, strict=False)
- print("Resume checkpoint %s" % args.resume)
- if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
- optimizer.load_state_dict(checkpoint['optimizer'])
- args.start_epoch = checkpoint['epoch'] + 1
- if 'scaler' in checkpoint:
- loss_scaler.load_state_dict(checkpoint['scaler'])
- print("With optim & sched!")
-
-
-def all_reduce_mean(x):
- world_size = get_world_size()
- if world_size > 1:
- x_reduce = torch.tensor(x).cuda()
- dist.all_reduce(x_reduce)
- x_reduce /= world_size
- return x_reduce.item()
- else:
- return x
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Partly revised by YZ @UCL&Moorfields
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from math import inf
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
+ norm_type)
+ return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, mode):
+ output_dir = Path(args.output_dir)
+ epoch_name = str(epoch)
+ os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
+ if loss_scaler is not None:
+ if mode == 'best':
+ checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-best.pth')]
+ else:
+ checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-latest.pth')]
+ for checkpoint_path in checkpoint_paths:
+ if mode == 'best':
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'epoch': epoch,
+ 'args': args, }
+ else:
+ if epoch == args.epochs - 1:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'args': args, }
+ else:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }
+
+ save_on_master(to_save, checkpoint_path)
+ else:
+ if mode == 'best':
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'epoch': epoch, }
+ torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-best.pth"))
+ else:
+ if epoch == args.epochs - 1:
+ to_save = {
+ 'model': model_without_ddp.state_dict(), }
+ else:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'args': args,
+ }
+ torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-latest.pth"))
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+ if args.resume:
+ if args.resume.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.resume, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ if 'model' in checkpoint:
+ checkpoint_model = checkpoint['model']
+ else:
+ checkpoint_model = checkpoint
+ model_without_ddp.load_state_dict(checkpoint_model, strict=False)
+ print("Resume checkpoint %s" % args.resume)
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print("With optim & sched!")
+
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
diff --git a/util/pos_embed.py b/util/pos_embed.py
index 4652ff22..11f7128b 100644
--- a/util/pos_embed.py
+++ b/util/pos_embed.py
@@ -1,92 +1,92 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# Partly revised by YZ @UCL&Moorfields
-# --------------------------------------------------------
-
-import numpy as np
-
-import torch
-
-# --------------------------------------------------------
-# 2D sine-cosine position embedding
-# References:
-# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
-# MoCo v3: https://github.com/facebookresearch/moco-v3
-# --------------------------------------------------------
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,)
- out: (M, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- return emb
-
-
-# --------------------------------------------------------
-# Interpolate position embeddings for high-resolution
-# References:
-# DeiT: https://github.com/facebookresearch/deit
-# --------------------------------------------------------
-def interpolate_pos_embed(model, checkpoint_model):
- if 'pos_embed' in checkpoint_model:
- pos_embed_checkpoint = checkpoint_model['pos_embed']
- embedding_size = pos_embed_checkpoint.shape[-1]
- num_patches = model.patch_embed.num_patches
- num_extra_tokens = model.pos_embed.shape[-2] - num_patches
- # height (== width) for the checkpoint position embedding
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
- # height (== width) for the new position embedding
- new_size = int(num_patches ** 0.5)
- # class_token and dist_token are kept unchanged
- if orig_size != new_size:
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
- # only the position tokens are interpolated
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
- pos_tokens = torch.nn.functional.interpolate(
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
- checkpoint_model['pos_embed'] = new_pos_embed
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Partly revised by YZ @UCL&Moorfields
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed