-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·111 lines (77 loc) · 3.85 KB
/
train.py
File metadata and controls
executable file
·111 lines (77 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from erlich import BaseTrainer, AverageEstimator, Erlich
from erlich.figure import Figure
import neural
from dataset import ConfidenceDataset
class Trainer(BaseTrainer):
    """Trainer for the confidence model.

    Combines a pixel-wise supervised loss (from sparse good/bad annotations)
    with two regularizers: a "physics" monotonicity prior along dim 2 and a
    horizontal smoothness prior along dim 3.
    """

    def __init__(self, cfg, model_parts, advanced):
        super().__init__(cfg, model_parts, advanced)
        self.model = model_parts["model"]
        self.model.train()
        # Input size; defaults to 192 (written as 128 + 64 in the original config).
        self.size = cfg.get("size", 128 + 64)

    def get_dataloader(self, batch_size):
        """Return the shuffled training DataLoader over the train tar shard."""
        train_dataset = ConfidenceDataset(["train_data.tar"], size=self.size)
        return DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=8, pin_memory=False, drop_last=True)

    def get_validation_dataloader(self, validation_batch_size):
        """Return the validation DataLoader (unshuffled, keeps the last partial batch)."""
        # Plain literal: the original f-string had no placeholders (F541).
        val_dataset = ConfidenceDataset(["val_data.tar"], size=self.size, validation=True)
        return DataLoader(val_dataset, batch_size=validation_batch_size, shuffle=False,
                          num_workers=4, pin_memory=False, drop_last=False)

    def get_train_metrics(self):
        """Declare the running averages updated by train_step.

        A "monotonicity" estimator used to be declared here but was never
        updated by train_step (step returns only four values), so it only ever
        showed a stale value; it has been removed.
        """
        return {
            "loss": AverageEstimator("Loss"),
            "supervision": AverageEstimator("Supervision"),
            "physics": AverageEstimator("Physics"),
            "smoothness": AverageEstimator("Smoothness"),
        }

    def before_train_epoch(self, epoch: int):
        # Re-enable training mode (validation switches the model to eval).
        self.model.train()

    def before_validation(self, epoch: int, train_batch: int):
        self.model.eval()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; debug figures go to val.png."""
        loss, supervision, physics, smoothness = self.step(batch, batch_idx, "val.png")
        return {"loss": loss, "supervision": supervision, "physics": physics, "smoothness": smoothness}

    def step(self, batch, batch_idx, img_name="out.png"):
        """Compute the composite loss for one batch.

        batch is (x, good, bad): input frames plus positive/negative annotation
        masks (presumably binary {0, 1} — confirm against dataset.py).
        Returns (loss, supervision, physics, smoothness). Every 100th batch a
        debug figure of up to 8 samples is saved to img_name.
        """
        x, good, bad = batch
        y = self.model(x)
        alpha = 5 - 1  # weight of annotated pixels relative to the 0.1 unannotated pull
        log_sig_pos = F.logsigmoid(y)
        log_sig_neg = F.logsigmoid(-y)
        # Create a binary mask that is 1 on unannotated pixels and 0 elsewhere
        empty = -(good + bad - 1)
        # Supervision coming from the annotations; unannotated pixels get a weak
        # symmetric pull toward 0.5 via both log-sigmoid terms, weighted 0.1.
        supervision = (-alpha*log_sig_pos * good - alpha*log_sig_neg * bad - 0.1*(log_sig_pos + log_sig_neg) * empty).mean()
        y_sigmoid = torch.sigmoid(y)
        # Physics prior: penalize non-decreasing steps along dim 2, with a margin
        # of 0.3/256. NOTE(review): applied in log-sigmoid space, which preserves
        # the monotonic direction of y but not the raw slope — confirm intended.
        y_diff = log_sig_pos[:, :, 1:] - log_sig_pos[:, :, :-1]
        physics = torch.mean(torch.relu(y_diff + 0.3 / 256))
        # Smoothness prior: y should be horizontally smooth
        smoothness = torch.mean((y_sigmoid[:, :, :, 1:] - y_sigmoid[:, :, :, :-1])**2)
        if batch_idx % 100 == 0:
            bs = x.size(0)
            figure = Figure()
            for i in range(min(bs, 8)):
                figure.add_image("Frame", x[i, 0])
                figure.add_image("Labels", good[i, 0] - bad[i, 0], vmin=-1, vmax=1)
                figure.add_image("Prediction", torch.sigmoid(y[i, 0]), vmin=0, vmax=1)
            figure.save(img_name, 300)
        # Fixed weighting of the regularizers relative to the supervised term.
        loss = supervision + 75 * physics + 0.25 * smoothness
        return loss, supervision, physics, smoothness

    def train_step(self, batch, batch_idx, train_metrics):
        """Run one training batch and fold its terms into the running metrics."""
        loss, supervision, physics, smoothness = self.step(batch, batch_idx)
        train_metrics["loss"].update(loss)
        train_metrics["supervision"].update(supervision)
        train_metrics["physics"].update(physics)
        train_metrics["smoothness"].update(smoothness)
        return loss
def main():
    """Entry point: fix random seeds, then launch training through Erlich."""
    # Seed every RNG source for reproducibility. The original only seeded
    # numpy and the CUDA generators, leaving the CPU torch RNG (weight
    # initialization, DataLoader shuffling) unseeded; torch.manual_seed
    # seeds both the CPU and all CUDA generators.
    np.random.seed(701)
    torch.manual_seed(701)
    torch.cuda.manual_seed_all(701)
    erlich = Erlich("erlich-models")
    cfg = erlich.config_from_cli()
    erlich.train(Trainer, cfg, devices=[0], logger_min_wait=5,
                 files_to_save=["train.py", "dataset.py", "neural.py"])


if __name__ == "__main__":
    main()