作者你好,我想改动一个新的结构,是在SE的地方改动的,有点困惑,mxnet 的symbol,不能直接得到bchw的值,
pytorch 的SGE,一个实现架构语句, 对应你提供的模型SE代码位置修改的话,symbol每一层bn3 后边的bchw,我直接得不到,我要mxnet,实现这句话,b, c, h, w = x.size(), x = x.reshape(b * self.groups, -1, h, w) 我对mxnet 不是那么熟悉,不知道作者你有没有好的方式实现这句reshape
我在修改的地方
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
#if use_se:
if usr_sge:
得到 bn3的 bchw
然后reshape
下面是对应pytorch 实现
class SpatialGroupEnhance(nn.Module): # 3 2 1 hw is half, 311 is same size
def init(self, groups = 64):
super(SpatialGroupEnhance, self).init()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.weight = Parameter(torch.zeros(1, groups, 1, 1))
self.bias = Parameter(torch.ones(1, groups, 1, 1))
self.sig = nn.Sigmoid()
def forward(self, x): # (b, c, h, w)
b, c, h, w = x.size()
x = x.view(b * self.groups, -1, h, w) ##reshape
xn = x * self.avg_pool(x) # x * global pooling(h,w change 1)
xn = xn.sum(dim=1, keepdim=True) #(b,1,h,w)
t = xn.view(b * self.groups, -1)
t = t - t.mean(dim=1, keepdim=True)
std = t.std(dim=1, keepdim=True) + 1e-5
t = t / std # normalize -mean/std
t = t.view(b, self.groups, h, w)
t = t * self.weight + self.bias
t = t.view(b * self.groups, 1, h, w)
x = x * self.sig(t) #in order to sigmod facter,this is group factor (0-1)
x = x.view(b, c, h, w) #get to varying degrees of importance,Restoration dimension
return x
作者你好,我想改动一个新的结构,是在SE的地方改动的,有点困惑,mxnet 的symbol,不能直接得到bchw的值,
pytorch 的SGE,一个实现架构语句, 对应你提供的模型SE代码位置修改的话,symbol每一层bn3 后边的bchw,我直接得不到,我要mxnet,实现这句话,b, c, h, w = x.size(), x = x.reshape(b * self.groups, -1, h, w) 我对mxnet 不是那么熟悉,不知道作者你有没有好的方式实现这句reshape
我在修改的地方
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
#if use_se:
if usr_sge:
得到 bn3的 bchw
然后reshape
下面是对应pytorch 实现
class SpatialGroupEnhance(nn.Module): # 3 2 1 hw is half, 311 is same size
def init(self, groups = 64):
super(SpatialGroupEnhance, self).init()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.weight = Parameter(torch.zeros(1, groups, 1, 1))
self.bias = Parameter(torch.ones(1, groups, 1, 1))
self.sig = nn.Sigmoid()
def forward(self, x): # (b, c, h, w)
b, c, h, w = x.size()
x = x.view(b * self.groups, -1, h, w) ##reshape
xn = x * self.avg_pool(x) # x * global pooling(h,w change 1)
xn = xn.sum(dim=1, keepdim=True) #(b,1,h,w)
t = xn.view(b * self.groups, -1)
t = t - t.mean(dim=1, keepdim=True)
std = t.std(dim=1, keepdim=True) + 1e-5
t = t / std # normalize -mean/std
t = t.view(b, self.groups, h, w)
t = t * self.weight + self.bias
t = t.view(b * self.groups, 1, h, w)
x = x * self.sig(t) #in order to sigmod facter,this is group factor (0-1)
x = x.view(b, c, h, w) #get to varying degrees of importance,Restoration dimension
return x