-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
88 lines (74 loc) · 2.86 KB
/
model.py
File metadata and controls
88 lines (74 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import copy
class SwAVModel(nn.Module):
    """Chain a backbone, segmentation head and projector, optionally scoring against prototypes.

    All four sub-networks are supplied by the caller; this module only wires
    them together in ``forward``.
    """

    def __init__(self, backbone, sem_seg_head, projector, prototypes):
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules, so their parameters belong to this model.
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.projector = projector
        self.prototypes = prototypes

    def forward(self, x, predict = False):
        """Run the pipeline on ``x``.

        Args:
            x: input handed directly to ``self.backbone``.
            predict: when true, L2-normalize the projection along dim 1 and
                additionally pass it through ``self.prototypes``.

        Returns:
            ``(representation, projection)`` when ``predict`` is false, where
            ``projection`` is the raw projector output; otherwise
            ``(representation, normalized_projection, similarity)``.
        """
        rep = self.sem_seg_head(self.backbone(x))
        proj = self.projector(rep)
        if not predict:
            return rep, proj
        # Normalize across dim 1 (channels for NCHW input — assumption, confirm
        # with caller) before scoring against the prototypes.
        unit_proj = F.normalize(proj, p=2, dim=1)
        return rep, unit_proj, self.prototypes(unit_proj)
class MLP(nn.Module):
    """Stack of 1x1 Conv-BatchNorm-ReLU blocks with an optional final 1x1 conv.

    Args:
        in_channels: channel count of the input tensor.
        hidden_channels: channel count produced by every block; defaults to
            ``in_channels``.
        out_channels: channel count of the final bias-free 1x1 conv; defaults
            to ``hidden_channels``. Unused when ``get_intermediate`` is true.
        layer_num: number of Conv-BN-ReLU blocks to stack. With 0 the stack is
            empty and passes its input through unchanged.
        get_intermediate: when true, no final conv is built and ``forward``
            returns the block-stack output directly.
    """

    def __init__(self, in_channels, hidden_channels = None, out_channels = None, layer_num = 2, get_intermediate = False):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = self.in_channels if hidden_channels is None else hidden_channels
        self.out_channels = self.hidden_channels if out_channels is None else out_channels
        self.layer_num = layer_num
        self.get_intermediate = get_intermediate

        layers_list = []
        in_dim = self.in_channels
        for i in range(self.layer_num):
            # Names "layer0", "layer1", ... become part of the state_dict keys.
            layers_list.append(("layer" + str(i), MLP._block(in_dim, self.hidden_channels)))
            in_dim = self.hidden_channels
        self.layers = nn.Sequential(OrderedDict(layers_list))

        if not self.get_intermediate:
            # Size the final conv from the channel count that actually leaves
            # the block stack. With layer_num == 0 the stack is empty and
            # forwards in_channels, so using hidden_channels here would
            # mismatch whenever in_channels != hidden_channels.
            self.convs = nn.Conv2d(
                in_channels=in_dim,
                out_channels=self.out_channels,
                kernel_size=1,
                padding=0,
                bias=False,
            )

    def forward(self, representation):
        """Apply the block stack, then the final 1x1 conv unless ``get_intermediate`` is set."""
        hidden = self.layers(representation)
        if self.get_intermediate:
            return hidden
        return self.convs(hidden)

    @staticmethod
    def _block(in_channels, features):
        """Build one 1x1 Conv2d -> BatchNorm2d -> ReLU block mapping ``in_channels`` to ``features``."""
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        # NOTE(review): this bias is redundant with the following
                        # BatchNorm's shift; kept because removing it would change
                        # the state_dict keys.
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=1,
                            padding=0,
                            bias=True,
                        ),
                    ),
                    ("norm", nn.BatchNorm2d(num_features=features)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )