forked from kuc2477/pytorch-vae
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
98 lines (85 loc) · 3.35 KB
/
train.py
File metadata and controls
98 lines (85 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from torch import optim
from torch.autograd import Variable
from tqdm import tqdm
import utils
import visual
def train_model(model, dataset, epochs=10,
                batch_size=32, sample_size=32,
                lr=3e-04, weight_decay=1e-5,
                loss_log_interval=30,
                image_log_interval=300,
                checkpoint_dir='./checkpoints',
                resume=False,
                cuda=False):
    """Train a VAE on ``dataset``, logging losses and generated samples.

    Args:
        model: VAE exposing ``reconstruction_loss(x_hat, x)``,
            ``kl_divergence_loss(mean, logvar)``, ``sample(n)`` and a
            ``name`` attribute (used as the visualization environment).
        dataset: torch Dataset yielding ``(image, label)`` pairs; labels
            are ignored.
        epochs: last epoch number to train through (inclusive).
        batch_size: mini-batch size for the data loader.
        sample_size: number of images to generate for image logging.
        lr: Adam learning rate.
        weight_decay: Adam L2 weight-decay coefficient.
        loss_log_interval: global iterations between scalar-loss plots.
        image_log_interval: global iterations between sample-image plots.
        checkpoint_dir: directory for saving/loading checkpoints.
        resume: if True, restore model weights from ``checkpoint_dir`` and
            resume from the saved epoch; otherwise start at epoch 1.
        cuda: if True, move input batches to GPU.
    """
    # Prepare optimizer and put the model into training mode.
    model.train()
    optimizer = optim.Adam(
        model.parameters(), lr=lr,
        weight_decay=weight_decay,
    )

    if resume:
        epoch_start = utils.load_checkpoint(model, checkpoint_dir)
    else:
        epoch_start = 1

    for epoch in range(epoch_start, epochs + 1):
        data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda)
        # Hoist the loop-invariant batch count; using len(data_loader)
        # (rather than len(dataset)//batch_size) keeps the global iteration
        # counter monotonic even when the loader yields a final partial
        # batch, and matches the progress computation below.
        batches_per_epoch = len(data_loader)
        data_stream = tqdm(enumerate(data_loader, 1))
        for batch_index, (x, _) in data_stream:
            # Global iteration index across all epochs (1-based).
            iteration = (epoch - 1) * batches_per_epoch + batch_index

            # Prepare data on GPU if needed. Variable is a no-op wrapper on
            # modern PyTorch but kept for compatibility with the file's API.
            x = Variable(x).cuda() if cuda else Variable(x)

            # Flush gradients and run the model forward.
            optimizer.zero_grad()
            (mean, logvar), x_reconstructed = model(x)
            reconstruction_loss = model.reconstruction_loss(x_reconstructed, x)
            kl_divergence_loss = model.kl_divergence_loss(mean, logvar)
            total_loss = reconstruction_loss + kl_divergence_loss

            # Backprop gradients from the loss.
            total_loss.backward()
            optimizer.step()

            # Extract Python scalars once per batch. `.item()` works on
            # 0-dim tensors; the old `.data[0]` indexing raises on
            # PyTorch >= 0.4.
            total_loss_val = total_loss.item()
            reconstruction_loss_val = reconstruction_loss.item()
            kl_divergence_loss_val = kl_divergence_loss.item()

            # Update the progress bar description.
            data_stream.set_description((
                'epoch: {epoch} | '
                'iteration: {iteration} | '
                'progress: [{trained}/{total}] ({progress:.0f}%) | '
                'loss => '
                'total: {total_loss:.4f} / '
                're: {reconstruction_loss:.3f} / '
                'kl: {kl_divergence_loss:.3f}'
            ).format(
                epoch=epoch,
                iteration=iteration,
                trained=batch_index * len(x),
                total=len(data_loader.dataset),
                progress=(100. * batch_index / batches_per_epoch),
                total_loss=total_loss_val,
                reconstruction_loss=reconstruction_loss_val,
                kl_divergence_loss=kl_divergence_loss_val,
            ))

            # Periodically plot the scalar losses.
            if iteration % loss_log_interval == 0:
                losses = [
                    reconstruction_loss_val,
                    kl_divergence_loss_val,
                    total_loss_val,
                ]
                names = ['reconstruction', 'kl divergence', 'total']
                visual.visualize_scalars(
                    losses, names, 'loss',
                    iteration, env=model.name)

            # Periodically plot freshly generated samples.
            if iteration % image_log_interval == 0:
                images = model.sample(sample_size)
                visual.visualize_images(
                    images, 'generated samples',
                    env=model.name
                )

        # Notify that we've reached a new checkpoint.
        print()
        print()
        print('#############')
        print('# checkpoint!')
        print('#############')
        print()

        # Save the checkpoint.
        utils.save_checkpoint(model, checkpoint_dir, epoch)
        print()