Skip to content

Update fusion.py#27

Open
littleSpongebob wants to merge 1 commit into
YimianDai:masterfrom
littleSpongebob:patch-1
Open

Update fusion.py#27
littleSpongebob wants to merge 1 commit into
YimianDai:masterfrom
littleSpongebob:patch-1

Conversation

@littleSpongebob

Copy link
Copy Markdown

第二次全局注意力模块没有用到

第二次全局注意力模块没有用到
@YimianDai

Copy link
Copy Markdown
Owner

用到了吧,代码如下

class ResGlobLocaforGlobLocaChaFuse(HybridBlock):
    def __init__(self, channels=64, r=4):
        super(ResGlobLocaforGlobLocaChaFuse, self).__init__()
        inter_channels = int(channels // r)

        with self.name_scope():

            self.local_att = nn.HybridSequential(prefix='local_att')
            self.local_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.local_att.add(nn.BatchNorm())
            self.local_att.add(nn.Activation('relu'))
            self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.local_att.add(nn.BatchNorm())

            self.global_att = nn.HybridSequential(prefix='global_att')
            self.global_att.add(nn.GlobalAvgPool2D())
            self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.global_att.add(nn.BatchNorm())
            self.global_att.add(nn.Activation('relu'))
            self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.global_att.add(nn.BatchNorm())

            self.local_att2 = nn.HybridSequential(prefix='local_att2')
            self.local_att2.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.local_att2.add(nn.BatchNorm())
            self.local_att2.add(nn.Activation('relu'))
            self.local_att2.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.local_att2.add(nn.BatchNorm())

            self.global_att2 = nn.HybridSequential(prefix='global_att2')
            self.global_att2.add(nn.GlobalAvgPool2D())
            self.global_att2.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            self.global_att2.add(nn.BatchNorm())
            self.global_att2.add(nn.Activation('relu'))
            self.global_att2.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.global_att2.add(nn.BatchNorm())

            self.sig1 = nn.Activation('sigmoid')
            self.sig2 = nn.Activation('sigmoid')

    def hybrid_forward(self, F, x, residual):

        xa = x + residual
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = F.broadcast_add(xl, xg)
        wei = self.sig1(xlg)

        xi = F.broadcast_mul(x, wei) + F.broadcast_mul(residual, 1-wei)
        xl2 = self.local_att2(xi)
        xg2 = self.global_att2(xi)
        xlg2 = F.broadcast_add(xl2, xg2)
        wei2 = self.sig2(xlg2)
        xo = F.broadcast_mul(x, wei2) + F.broadcast_mul(residual, 1-wei2)

        return xo

@littleSpongebob

Copy link
Copy Markdown
Author

好吧,我看的是aff_pytorch这个文件夹里的fusion.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants