forked from Dawn5786/ComDPSE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmultihead_attention.py
More file actions
109 lines (92 loc) · 3.97 KB
/
multihead_attention.py
File metadata and controls
109 lines (92 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch.nn as nn
import torch
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pdb
class MultiheadAttention(nn.Module):
    """Multi-head attention built from independent RelationUnit heads.

    Each head is a separate ``RelationUnit``; the head outputs are
    concatenated along the last (feature) dimension. No output
    projection is applied (``out_conv`` is intentionally disabled).

    Args:
        feature_dim: dimensionality of the input features per token.
        n_head: number of attention heads.
        key_feature_dim: projection dimension used inside each head.
    """

    def __init__(self, feature_dim=512, n_head=8, key_feature_dim=64):
        super(MultiheadAttention, self).__init__()
        self.Nh = n_head
        # One independent RelationUnit per head.
        self.head = nn.ModuleList(
            RelationUnit(feature_dim, key_feature_dim) for _ in range(self.Nh)
        )
        # self.out_conv = nn.Linear(n_head*key_feature_dim, feature_dim)  # bias=False

    def forward(self, query=None, key=None, value=None):
        """Apply every head and concatenate results on the feature axis.

        query/key/value follow the (Len, Batch, Dim) layout expected by
        ``RelationUnit``; the output keeps that layout with the feature
        dimension being the concatenation of all heads.
        """
        # Single cat over all head outputs instead of repeated pairwise
        # torch.cat inside a loop (which copied the accumulator each step).
        output = torch.cat([h(query, key, value) for h in self.head], dim=-1)
        # output = self.out_conv(output)
        return output
class RelationUnit(nn.Module):
def __init__(self, feature_dim=512, key_feature_dim=64): #(512, 128)
super(RelationUnit, self).__init__()
self.temp = 30
self.WK = nn.Linear(feature_dim, key_feature_dim) # bias=False
# self.WQ = nn.Linear(feature_dim, key_feature_dim)
self.WV = nn.Linear(feature_dim, feature_dim)
# Init weights
for m in self.WK.modules():
m.weight.data.normal_(0, math.sqrt(2. / m.out_features))
# m.weight.data.normal_(1. / m.in_features, math.sqrt(2. / m.out_features))
if m.bias is not None:
m.bias.data.zero_()
'''
for m in self.WQ.modules():
m.weight.data.normal_(0, math.sqrt(2. / m.out_features))
# m.weight.data.normal_(1. / m.in_features, math.sqrt(2. / m.out_features))
if m.bias is not None:
m.bias.data.zero_()
'''
for m in self.WV.modules():
m.weight.data.normal_(0, math.sqrt(2. / m.out_features))
# m.weight.data.normal_(1. / m.in_features, math.sqrt(2. / m.out_features))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, query=None, key=None, value=None):
w_k = self.WK(key)
w_k = F.normalize(w_k, p=2, dim=-1)
w_k = w_k.permute(1, 2, 0) # Batch, Dim, Len_1
w_q = self.WK(query)
w_q = F.normalize(w_q, p=2, dim=-1)
w_q = w_q.permute(1, 0, 2) # Batch, Len_2, Dim
with torch.no_grad():
dot_prod = torch.bmm(w_q, w_k) # Batch, Len_2, Len_1
# affinity = F.softmax(dot_prod*self.temp, dim=-1)
affinity = F.softmax(dot_prod*self.temp, dim=-1)
###################################
# aff_copy = dot_prod.cpu()
#
# plt.cla()
# # plt.imshow(aff_copy[0:3, :, :].permute(1,2,0))
# plt.imshow(1000*aff_copy[0, :, :])
# plt.axis('off')
# plt.axis('equal')
# bug_path = '/data1/lxt/2021projects/work-2021/Hyper-para/Fig7_draw'
# file_name = 'atten_score'
# fig_file = '{}/{}'.format(bug_path, file_name)
# fig_file = '{}.jpg'.format(fig_file)
# Scoremap = plt.gcf()
# Scoremap.savefig(fig_file, dpi=1200)
#####################################
w_v = value.permute(1, 0, 2) # Batch, Len_1, Dim
output = torch.bmm(affinity, w_v) # Batch, Len_2, Dim
output = output.permute(1, 0, 2)
######################################
# with torch.no_grad():
# output_view = output.cpu()
# plt.cla()
# plt.imshow(output_view[:, 2, :])
# plt.axis('off')
# plt.axis('equal')
# map_name = '/data1/lxt/2021projects/work-2021/Hyper-para/Fig7_draw/hotmap.jpg'
# Hotmap = plt.gcf()
# Hotmap.savefig(map_name, dpi=1200)
###################################
return output