-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_vqvae.py
More file actions
60 lines (50 loc) · 1.71 KB
/
train_vqvae.py
File metadata and controls
60 lines (50 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import pytorch_lightning as pl
# from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import TensorBoardLogger
import torch
# from tqdm import tqdm
from models import MInterface
from dataset.get_datasets import build_dataloader
from Utils.base_utils import load_model_path_by_config, ConfigLoader
import logging
# Module-level handle to the root logger (getLogger() is called with no name).
# NOTE(review): not referenced by any function in this file; consider
# logging.getLogger(__name__) if module-scoped logging is intended.
logger = logging.getLogger()
def load_callbacks(config):
    """Assemble the Lightning callbacks used by the trainer.

    Always includes a ``ModelCheckpoint`` that keeps only the single best
    checkpoint, ranked by the lowest logged ``fid`` value, and named with the
    epoch and that ``fid``. When ``config.lr_scheduler`` is truthy, a
    ``LearningRateMonitor`` that logs the learning rate every step is
    appended as well.

    Args:
        config: Loaded training config; only ``config.lr_scheduler`` is read.

    Returns:
        list: The callback instances, checkpoint callback first.
    """
    best_fid_checkpoint = plc.ModelCheckpoint(
        monitor='fid',
        filename='{epoch:02d}-{fid:.6f}',
        save_top_k=1,
        mode='min',
    )
    callbacks = [best_fid_checkpoint]

    if config.lr_scheduler:
        # Only worth tracking the LR when a scheduler is actually changing it.
        callbacks.append(plc.LearningRateMonitor(logging_interval='step'))

    return callbacks
def main(data="stock"):
    """Train the VQ-VAE stage for one dataset.

    Loads ``configs/train_vq_{data}.yaml``, seeds the RNGs with
    ``config.seed``, builds the train/validation dataloaders, optionally
    restores model weights from the checkpoint path resolved by
    ``load_model_path_by_config``, and runs a Lightning ``Trainer`` that
    logs to TensorBoard under ``log/vq_{data}``.

    Args:
        data: Dataset name used to select the YAML config and the
            TensorBoard run name. Defaults to ``"stock"`` (the value that
            was previously hard-coded).
    """
    config = ConfigLoader.load_vq_config(config=f"configs/train_vq_{data}.yaml")
    # Named tb_logger so it does not shadow the module-level `logger`.
    tb_logger = TensorBoardLogger(save_dir="log/", name=f"vq_{data}")
    pl.seed_everything(config.seed)

    load_path = load_model_path_by_config(config)
    train_dataloader, val_dataloader = build_dataloader(config)

    model = MInterface(**config)
    if load_path is not None:
        print('=' * 30 + f'{load_path}' + '=' * 30)
        # map_location="cpu" maps every storage to CPU, the same effect as the
        # identity lambda `lambda storage, loc: storage`.
        checkpoint = torch.load(load_path, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=True)

    trainer = Trainer(
        devices=config.devices,
        max_epochs=5000,
        log_every_n_steps=5,
        callbacks=load_callbacks(config),
        check_val_every_n_epoch=100,
        logger=tb_logger,
    )

    # Previously the training loader was passed twice, so the validation split
    # from build_dataloader was discarded and validation metrics (presumably
    # including 'fid', which the checkpoint callback monitors) were computed on
    # training data. Fall back to the training loader only when no validation
    # loader is provided, which preserves the old behavior in that case.
    eval_dataloader = val_dataloader if val_dataloader is not None else train_dataloader
    trainer.fit(model, train_dataloader, eval_dataloader)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()