-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmain.py
More file actions
60 lines (50 loc) · 1.48 KB
/
main.py
File metadata and controls
60 lines (50 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from torch import nn
import os
from lib import model, trainer, dataset
def main():
    """Train a NothingButNetDQN with DeepQTrainer on the dummy basketball data.

    Parses command-line arguments with ``trainer.make_parser()``, selects the
    CUDA device given by ``args.device``, seeds the torch CPU and CUDA RNGs,
    builds a ``dataset.DummyBballDataset`` and a ``model.NothingButNetDQN``,
    and runs ``trainer.DeepQTrainer.train`` with the settings below.
    """
    print('Beginning')
    parser = trainer.make_parser()
    args = parser.parse_args()
    print('args: ', args)

    # Run settings
    save_loc = '/data/bball/save_loc'
    name = 'tst'
    log_dir = '/data/bball/logs'
    # exist_ok=True replaces the racy "exists() then makedirs()" check.
    os.makedirs(save_loc, exist_ok=True)

    seed = 1
    batch_size = 32
    epoch_limit = 20
    criterion = nn.functional.smooth_l1_loss
    # Hand the trainer the same device that torch.cuda.set_device selects
    # below; this was previously hard-coded to 1 and could disagree with
    # the device parsed into args.device.
    gpu_device = args.device
    lr = 0.005
    clip = 1
    on_policy = False
    double_q = False
    num_workers = 1
    reset_rate = 10000
    validate_rate = 50
    gamma = 1

    torch.cuda.set_device(gpu_device)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    data = dataset.DummyBballDataset(data_size=1000, seed=seed)
    # Bound to `net` rather than `model` so the imported lib.model module
    # is not shadowed (which would also make this line an UnboundLocalError
    # inside a function).
    net = model.NothingButNetDQN()
    tr = trainer.DeepQTrainer(data=data,
                              batch_size=batch_size,
                              epoch_limit=epoch_limit,
                              criterion=criterion,
                              save_loc=save_loc,
                              name=name,
                              log_dir=log_dir,
                              gpu_device=gpu_device,
                              lr=lr,
                              clip=clip,
                              on_policy=on_policy,
                              double_q=double_q,
                              num_workers=num_workers,
                              reset_rate=reset_rate,
                              validate_rate=validate_rate)
    print(name)
    print('training for {} epochs'.format(epoch_limit))
    tr.train(model=net, gamma=gamma)


if __name__ == '__main__':
    main()