-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathAttr.py
More file actions
48 lines (35 loc) · 1.32 KB
/
Attr.py
File metadata and controls
48 lines (35 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
import numpy as np
from torch.autograd import Variable
class Attr(nn.Module):
    """Attribute component: embeds categorical attributes and appends raw scalars.

    For every ``(name, num_embeddings, embedding_dim)`` entry in
    ``embed_dims`` an ``nn.Embedding`` is registered as the submodule
    ``<name>_em``. ``forward`` looks each name up in the input mapping,
    embeds it, and concatenates the embeddings with the ``dist`` and
    ``dateID`` values (one column each) along dim 1.
    """

    # Each entry: (input key, embedding vocabulary size, embedding width).
    # NOTE(review): the vocabulary sizes look like driver count / days of the
    # week / 15-minute slots per day -- confirm against the data pipeline.
    embed_dims = [('driverID', 13000, 8), ('weekID', 7, 3), ('timeID', 96, 8)]

    def __init__(self):
        super().__init__()
        self.build()

    def build(self):
        """Register one ``nn.Embedding`` per ``embed_dims`` entry as ``<name>_em``."""
        for name, dim_in, dim_out in Attr.embed_dims:
            self.add_module(name + '_em', nn.Embedding(dim_in, dim_out))

    def out_size(self):
        """Return the feature width of ``forward``'s output.

        This is the sum of all embedding widths plus 2 for the ``dist`` and
        ``dateID`` columns that ``forward`` appends.
        """
        return sum(dim_out for _, _, dim_out in Attr.embed_dims) + 2

    def forward(self, attr):
        """Embed the categorical attributes and concatenate them with the scalars.

        Args:
            attr: mapping from key to tensor. It must contain every name in
                ``embed_dims`` plus ``'dist'`` and ``'dateID'``.
                - The categorical entries must hold integer indices accepted
                  by ``nn.Embedding``.
                - Every tensor is flattened, so each is assumed to hold
                  exactly one value per sample (TODO: confirm with caller).

        Returns:
            Tensor of shape ``(N, self.out_size())``, where ``N`` is the
            number of elements in each input tensor. Columns appear in
            ``embed_dims`` order, then ``dist``, then ``dateID`` (cast to
            float).
        """
        em_list = []
        for name, _, _ in Attr.embed_dims:
            embed = getattr(self, name + '_em')
            # Embedding a 1-D index tensor yields (N, dim_out) directly.
            # The previous view(-1, 1) + torch.squeeze() also dropped the
            # batch dimension when N == 1, which made torch.cat(dim=1) fail.
            em_list.append(embed(attr[name].view(-1)))
        # 'dist' is concatenated without a dtype cast; presumably it is
        # already a float tensor matching the embedding dtype -- confirm.
        em_list.append(attr['dist'].view(-1, 1))
        em_list.append(attr['dateID'].float().view(-1, 1))
        return torch.cat(em_list, dim=1)