forked from Dawn5786/ComDPSE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmultihead_attention_complex.py
More file actions
84 lines (68 loc) · 3.08 KB
/
multihead_attention_complex.py
File metadata and controls
84 lines (68 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
import numpy as np
import pdb
# class MultiheadAttention(nn.Module):
class ComMultiheadAttention(nn.Module):
    """Multi-head attention assembled from independent ComRelationUnit heads.

    Every head receives the same (query, key, value) triple; the per-head
    outputs are concatenated along the last (feature) dimension.  There is
    no final output projection (the out_conv layer was deliberately removed).
    """

    def __init__(self, feature_dim=512, n_head=8, key_feature_dim=64):
        super(ComMultiheadAttention, self).__init__()
        self.Nh = n_head
        # One independently-parameterized relation unit per head.
        self.head = nn.ModuleList(
            ComRelationUnit(feature_dim, key_feature_dim) for _ in range(self.Nh)
        )
        # self.out_conv = nn.Linear(n_head*key_feature_dim, feature_dim)  # bias=False

    def forward(self, query=None, key=None, value=None):
        """Apply all heads to the same inputs and join them feature-wise.

        Shapes are sequence-first, matching ComRelationUnit:
        query (Lq, B, D), key/value (Lk, B, D) -> output (Lq, B, Nh * Dv).
        """
        per_head = [unit(query, key, value) for unit in self.head]
        return torch.cat(per_head, -1)
# class RelationUnit(nn.Module):
class ComRelationUnit(nn.Module):
def __init__(self, feature_dim=512, key_feature_dim=64): #(512, 128)
# super(RelationUnit, self).__init__()
super(ComRelationUnit, self).__init__()
self.temp = 30
self.WK = nn.Linear(feature_dim, key_feature_dim) # bias=False
# self.WQ = nn.Linear(feature_dim, key_feature_dim)
self.WV = nn.Linear(feature_dim, feature_dim)
# Init weights
for m in self.WK.modules():
m.weight.data.normal_(0, math.sqrt(2. / m.out_features))
# m.weight.data.normal_(1. / m.in_features, math.sqrt(2. / m.out_features))
if m.bias is not None:
m.bias.data.zero_()
'''
for m in self.WQ.modules():
m.weight.data.normal_(0, math.sqrt(2. / m.out_features))
# m.weight.data.normal_(1. / m.in_features, math.sqrt(2. / m.out_features))
if m.bias is not None:
m.bias.data.zero_()
'''
for m in self.WV.modules():
m.weight.data.normal_(0, math.sqrt(2. / m.out_features))
# m.weight.data.normal_(1. / m.in_features, math.sqrt(2. / m.out_features))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, query=None, key=None, value=None):
w_k = self.WK(key)
w_k = F.normalize(w_k, p=2, dim=-1)
w_k = w_k.permute(1,2,0) # Batch, Dim, Len_1
w_q = self.WK(query)
w_q = F.normalize(w_q, p=2, dim=-1)
w_q = w_q.permute(1,0,2) # Batch, Len_2, Dim
dot_prod = torch.bmm(w_q, w_k) # Batch, Len_2, Len_1
affinity = F.softmax(dot_prod*self.temp, dim=-1)
w_v = value.permute(1,0,2) # Batch, Len_1, Dim
output = torch.bmm(affinity, w_v) # Batch, Len_2, Dim
output = output.permute(1,0,2)
return output