-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathrun_vr.py
More file actions
97 lines (86 loc) · 3.55 KB
/
run_vr.py
File metadata and controls
97 lines (86 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import copy
import torch
import random
import numpy as np
import argparse
from dataloader import KUAIRECDataLoader
from model import WideAndDeep
from utils import eval_mae, eval_xauc, eval_kl
def get_args():
    """Parse command-line options for a training/evaluation run.

    Returns:
        argparse.Namespace with dataset, device, and optimizer settings.
    """
    spec = [
        ('--dataset_name', {'default': 'kuairec'}),
        ('--dataset_path', {'default': './dataset/'}),
        ('--device', {'default': 'cuda:0'}),
        ('--bsz', {'type': int, 'default': 2048}),
        ('--log_interval', {'type': int, 'default': 10}),
        ('--epoch', {'type': int, 'default': 10}),
        ('--lr', {'type': float, 'default': 0.1}),
        ('--weight_decay', {'type': float, 'default': 1e-6}),
        ('--seed', {'type': int, 'default': 42}),
    ]
    parser = argparse.ArgumentParser()
    for flag, kwargs in spec:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def get_loaders(name, dataset_path, device, bsz):
    """Build the dataloader bundle for the named dataset.

    Args:
        name: dataset identifier; only 'kuairec' is supported here.
        dataset_path: root directory containing per-dataset subfolders.
        device: torch device the loader should place batches on.
        bsz: batch size.

    Returns:
        A KUAIRECDataLoader (indexable by split name, e.g. 'train'/'test').

    Raises:
        ValueError: if `name` is not a recognized dataset.
    """
    path = os.path.join(dataset_path, name, "{}_data.pkl".format(name))
    if name == 'kuairec':
        return KUAIRECDataLoader(name, path, device, bsz=bsz)
    # Fixed typo in the error message: 'unkown' -> 'unknown'.
    raise ValueError('unknown dataset name: {}'.format(name))
def mae_rescale_to_second(dataset, mae):
    """Convert a normalized MAE back into seconds for the given dataset.

    Each dataset's labels were normalized with dataset-specific constants;
    these expressions invert that scaling (the /1000 factors convert
    milliseconds to seconds — presumably; verify against the dataloaders).

    Args:
        dataset: dataset name ('kuairec', 'wechat', or 'cikm16').
        mae: MAE measured in the normalized label space.

    Returns:
        MAE expressed in seconds.

    Raises:
        ValueError: for an unrecognized dataset name.
    """
    if dataset == 'kuairec':
        return mae * 999639 / 1000
    if dataset == 'wechat':
        return mae * 20840
    if dataset == 'cikm16':
        # NOTE(review): looks like it inverts a (t - 31) / (6000 - 31)
        # min-max scaling — confirm against the cikm16 dataloader.
        return (mae * (6000 - 31) + 31) / 1000
    # Fixed typo in the error message: 'unkown' -> 'unknown'.
    raise ValueError('unknown dataset name: {}'.format(dataset))
def test(args, model, dataloaders):
    """Evaluate `model` on the test split and print MAE / XAUC / KL.

    Model outputs are divided by 20 to undo the *20 target scaling applied
    during training (see the MSE loss in the training loop), and the MAE is
    rescaled back into seconds before printing.

    Args:
        args: parsed CLI namespace (uses `dataset_name`).
        model: trained model; called as model(features).
        dataloaders: loader bundle; dataloaders['test'] yields
            (features, label) batches.
    """
    model.eval()
    # Removed the unused `predicts` accumulator from the original.
    labels, scores = [], []
    with torch.no_grad():
        for features, label in dataloaders['test']:
            y = model(features) / 20  # undo the *20 label scaling used in training
            labels.extend(label.tolist())
            scores.extend(y.tolist())
    labels, scores = np.array(labels), np.array(scores)
    mae = eval_mae(labels, scores)
    xauc = eval_xauc(labels, scores)
    kl = eval_kl(labels, scores)
    mae = mae_rescale_to_second(args.dataset_name, mae)
    print("test result | MAE: {:.7f} | XAUC: {:.7f} | KL: {:.7f}".format(mae, xauc, kl))
if __name__ == '__main__':
    args = get_args()
    # Seed every RNG for reproducibility; a negative --seed disables seeding.
    # (Also seeds the stdlib `random` module, which the original imported
    # but never seeded.)
    if args.seed > -1:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
    torch.cuda.empty_cache()
    device = torch.device(args.device)
    dataloaders = get_loaders(args.dataset_name, args.dataset_path, device, args.bsz)
    model = WideAndDeep(dataloaders.description, embed_dim=16,
                        mlp_dims=(512, 256, 128, 64, 32), dropout=0.0)
    model = model.to(device)

    # ---- training ----
    dataloader_train = dataloaders['train']
    # Regression on scaled targets: labels are multiplied by 20 in the loss.
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    for epoch_i in range(1, args.epoch + 1):
        model.train()
        epoch_loss = 0.0
        running_loss = 0.0
        total_iters = len(dataloader_train)
        for i, (features, label) in enumerate(dataloader_train):
            y = model(features)
            loss = criterion(y, 20 * label.float())
            model.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            running_loss += loss.item()
            # Fixed: the condition was hard-coded to `% 10` while the printed
            # average divided by args.log_interval, so any non-default
            # --log_interval produced a wrong running average.
            if (i + 1) % args.log_interval == 0:
                # Fixed off-by-one: the total was printed as total_iters + 1.
                print("    Iter {}/{} loss: {:.7f}".format(
                    i + 1, total_iters, running_loss / args.log_interval), end='\r')
                running_loss = 0.0
        print("Epoch {}/{} average Loss: {:.7f}".format(
            epoch_i, args.epoch, epoch_loss / total_iters))
    test(args, model, dataloaders)