Skip to content

RuntimeError: The size of tensor a (11) must match the size of tensor b (10) at non-singleton dimension 3 #3

@QingdaChen

Description

@QingdaChen

This is my code:

class Accurate_Modle(torch.nn.Module):
    """Shared-weight two-patch matching network.

    Both input patches are passed through the same ``conv1`` feature tower.
    Their features are concatenated along dim 1 and scored by ``full``, a
    stack of 1x1 convolutions ending in a Sigmoid.

    Args:
        data_kind: dataset selector; only ``'mb'`` is implemented here
            (a ``'kitt'`` branch is sketched but commented out).
        padding: accepted for interface compatibility; currently unused.

    Raises:
        ValueError: if ``data_kind`` is not a supported dataset.
    """

    def __init__(self, data_kind, padding):
        # NOTE: the issue text shows ``def init`` / ``super().init()``; the
        # double underscores were presumably eaten by Markdown. Without a real
        # ``__init__`` none of the layers below would ever be created.
        super(Accurate_Modle, self).__init__()

        # 2 * 112 input channels: forward() concatenates two 112-channel
        # feature maps (one per patch) along dim 1.
        self.con6 = torch.nn.Conv2d(2 * 112, 384, kernel_size=1, stride=1, padding=0)
        self.con7 = torch.nn.Conv2d(384, 384, kernel_size=1, stride=1, padding=0)
        self.con8 = torch.nn.Conv2d(384, 384, kernel_size=1, stride=1, padding=0)
        self.con9 = torch.nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)

        if data_kind == 'mb':
            # torch.nn.init.constant_(self.con5.bias, 0)
            self.conv1 = torch.nn.Sequential(
                OctConv2d('first', in_channels=1, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('regular', in_channels=112, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('regular', in_channels=112, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('regular', in_channels=112, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('last', in_channels=112, out_channels=112, kernel_size=3),
                # torch.nn instead of the bare ``nn`` name, which is not
                # imported anywhere in the snippet.
                torch.nn.ReLU(),
            )
            self.full = torch.nn.Sequential(
                self.con6,
                torch.nn.ReLU(),
                self.con7,
                torch.nn.ReLU(),
                self.con8,
                torch.nn.ReLU(),
                self.con9,
                torch.nn.Sigmoid(),
            )
        # elif data_kind == 'kitt':
        #     ...
        else:
            # Fail at construction time rather than with an AttributeError on
            # ``self.conv1`` at the first forward pass.
            raise ValueError(f"unsupported data_kind {data_kind!r}; expected 'mb'")

    def forward(self, x0, x1, flag):
        """Score how well patch ``x0`` matches patch ``x1``.

        Args:
            x0: left patch tensor.
            x1: right patch tensor.
            flag: run mode; only ``'train'`` is implemented here.

        Returns:
            The output of ``self.full`` on the concatenated features.

        Raises:
            ValueError: if ``flag`` is not a supported mode.
        """
        if flag != 'train':
            raise ValueError(f"unsupported flag {flag!r}; expected 'train'")

        # NOTE(review): "size of tensor a (11) must match the size of tensor b
        # (10) at non-singleton dimension 3" is an element-wise shape mismatch.
        # It is presumably raised inside OctConv2d, where a half-resolution
        # branch (11 -> 5) is upsampled back (5 -> 10) and added to the
        # full-resolution branch (11). If so, patches need even H/W, or
        # OctConv2d must resize to the exact high-frequency size. TODO confirm.
        y0 = self.conv1(x0)  # left_patch
        y1 = self.conv1(x1)  # right patch, same weights
        y3 = torch.cat((y0, y1), 1)
        return self.full(y3)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions