Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ __pycache__/
# Distribution / packaging
.Python
build/
checkpoints
Data
develop-eggs/
dist/
downloads/
Expand All @@ -18,6 +20,7 @@ lib/
lib64/
parts/
sdist/
segmentation_output
var/
wheels/
pip-wheel-metadata/
Expand Down Expand Up @@ -104,6 +107,7 @@ celerybeat.pid
# Environments
.env
.venv
venv3.9
env/
venv/
ENV/
Expand Down
296 changes: 148 additions & 148 deletions engine_finetune.py
Original file line number Diff line number Diff line change
@@ -1,148 +1,148 @@
import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Iterable, Optional
from timm.data import Mixup
from timm.utils import accuracy
from sklearn.metrics import (
accuracy_score, roc_auc_score, f1_score, average_precision_score,
hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
)
from pycm import ConfusionMatrix
import util.misc as misc
import util.lr_sched as lr_sched
def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
device: torch.device,
epoch: int,
loss_scaler,
max_norm: float = 0,
mixup_fn: Optional[Mixup] = None,
log_writer=None,
args=None
):
"""Train the model for one epoch."""
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
print_freq, accum_iter = 20, args.accum_iter
optimizer.zero_grad()
if log_writer:
print(f'log_dir: {log_writer.log_dir}')
for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
if mixup_fn:
samples, targets = mixup_fn(samples, targets)
with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(outputs, targets)
loss_value = loss.item()
loss /= accum_iter
loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()
torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])
metric_logger.update(lr=max_lr)
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', max_lr, epoch_1000x)
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
"""Evaluate the model."""
criterion = nn.CrossEntropyLoss()
metric_logger = misc.MetricLogger(delimiter=" ")
os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
model.eval()
true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
output_ = nn.Softmax(dim=1)(output)
output_label = output_.argmax(dim=1)
output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)
metric_logger.update(loss=loss.item())
true_onehot.extend(target_onehot.cpu().numpy())
pred_onehot.extend(output_onehot.detach().cpu().numpy())
true_labels.extend(target.cpu().numpy())
pred_labels.extend(output_label.detach().cpu().numpy())
pred_softmax.extend(output_.detach().cpu().numpy())
accuracy = accuracy_score(true_labels, pred_labels)
hamming = hamming_loss(true_onehot, pred_onehot)
jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
kappa = cohen_kappa_score(true_labels, pred_labels)
f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')
score = (f1 + roc_auc + kappa) / 3
if log_writer:
for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
[accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
log_writer.add_scalar(f'perf/{metric_name}', value, epoch)
print(f'val loss: {metric_logger.meters["loss"].global_avg}')
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')
metric_logger.synchronize_between_processes()
results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
file_exists = os.path.isfile(results_path)
with open(results_path, 'a', newline='', encoding='utf8') as cfa:
wf = csv.writer(cfa)
if not file_exists:
wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])
if mode == 'test':
cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Iterable, Optional
from timm.data import Mixup
from timm.utils import accuracy
from sklearn.metrics import (
accuracy_score, roc_auc_score, f1_score, average_precision_score,
hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
)
from pycm import ConfusionMatrix
import util.misc as misc
import util.lr_sched as lr_sched

def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
device: torch.device,
epoch: int,
loss_scaler,
max_norm: float = 0,
mixup_fn: Optional[Mixup] = None,
log_writer=None,
args=None
):
"""Train the model for one epoch."""
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
print_freq, accum_iter = 20, args.accum_iter
optimizer.zero_grad()

if log_writer:
print(f'log_dir: {log_writer.log_dir}')

for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
if mixup_fn:
samples, targets = mixup_fn(samples, targets)

with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(outputs, targets)
loss_value = loss.item()
loss /= accum_iter

loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()

torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
min_lr = 10.
max_lr = 0.
for group in optimizer.param_groups:
min_lr = min(min_lr, group["lr"])
max_lr = max(max_lr, group["lr"])

metric_logger.update(lr=max_lr)

loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', max_lr, epoch_1000x)

metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
"""Evaluate the model."""
criterion = nn.CrossEntropyLoss()
metric_logger = misc.MetricLogger(delimiter=" ")
os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)

model.eval()
true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []

for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)

with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
output_ = nn.Softmax(dim=1)(output)
output_label = output_.argmax(dim=1)
output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)

metric_logger.update(loss=loss.item())
true_onehot.extend(target_onehot.cpu().numpy())
pred_onehot.extend(output_onehot.detach().cpu().numpy())
true_labels.extend(target.cpu().numpy())
pred_labels.extend(output_label.detach().cpu().numpy())
pred_softmax.extend(output_.detach().cpu().numpy())

accuracy = accuracy_score(true_labels, pred_labels)
hamming = hamming_loss(true_onehot, pred_onehot)
jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
kappa = cohen_kappa_score(true_labels, pred_labels)
f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')

score = (f1 + roc_auc + kappa) / 3
if log_writer:
for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
[accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
log_writer.add_scalar(f'perf/{metric_name}', value, epoch)

print(f'val loss: {metric_logger.meters["loss"].global_avg}')
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')

metric_logger.synchronize_between_processes()

results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
file_exists = os.path.isfile(results_path)
with open(results_path, 'a', newline='', encoding='utf8') as cfa:
wf = csv.writer(cfa)
if not file_exists:
wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])

if mode == 'test':
cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')

return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
Loading