torch 自己写的CsiNet找不出哪里错了,后面是网上的正确的CsiNet

这个是CsiNet的网络结构:

img

自己写的:

class RefineNet(nn.Module):
    """Residual refinement block built from three 3x3 convolution stages.

    Channels go 2 -> 8 -> 16 -> 2. Every convolution uses stride 1 and
    padding 1, so the spatial size of the input is preserved. The output of
    the last stage is added to the block input and passed through a
    LeakyReLU (negative slope 0.3).
    """

    @staticmethod
    def _conv_bn(in_ch, out_ch, activate):
        """Build Conv2d(3x3) -> BatchNorm2d, optionally followed by LeakyReLU."""
        stage = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
        ]
        if activate:
            stage.append(nn.LeakyReLU(negative_slope=0.3, inplace=True))
        return nn.Sequential(*stage)

    def __init__(self):
        super().__init__()
        # Size notes below assume a 2*32*32 input.
        self.conv1 = self._conv_bn(2, 8, activate=True)    # -> 8*32*32
        self.conv2 = self._conv_bn(8, 16, activate=True)   # -> 16*32*32
        # The last stage has no activation of its own: the non-linearity is
        # applied after the shortcut has been added (see forward).
        self.conv3 = self._conv_bn(16, 2, activate=False)  # -> 2*32*32
        self.r1 = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        self.identity = nn.Identity()

    def forward(self, x):
        """Forward pass: return LeakyReLU(conv_stack(x) + x)."""
        shortcut = self.identity(x)
        out = self.conv3(self.conv2(self.conv1(x)))
        # In-place add on the freshly produced conv output; the caller's
        # input tensor is left untouched.
        out += shortcut
        return self.r1(out)

class CsiNet(nn.Module):
    """CSI autoencoder: conv -> flatten -> fc1 -> fc2 -> 2x RefineNet -> conv.

    ``fc1`` compresses the flattened feature map to ``dim_out`` values and
    ``fc2`` expands it back to ``img_total`` values. ``img_total`` and
    ``dim_out`` are not defined in this class; they are looked up in the
    enclosing module scope when the model is constructed, so they must exist
    there before ``CsiNet()`` is called. The reshape in ``forward`` assumes
    ``img_total == 2 * 32 * 32``.

    The batch size is taken from the input tensor on every call instead of a
    module-level ``batch_size`` constant, so batches of any size (for example
    the last, smaller batch of an epoch, or a single sample at inference
    time) go through without a ``view`` shape error.
    """

    def __init__(self):
        super(CsiNet, self).__init__()
        ## input size: 2*32*32
        self.conv1 = nn.Sequential(
                        nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1),  # 2*32*32
                        nn.BatchNorm2d(2),
                        nn.LeakyReLU(negative_slope=0.3, inplace=True)
                        )
        self.conv2 = nn.Sequential(
                        nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1),  # 2*32*32
                        nn.BatchNorm2d(2),
                        nn.Sigmoid()
                        )
        self.layer1 = RefineNet()

        self.layer2 = RefineNet()

        # Encoder / decoder fully-connected layers (sizes come from module scope).
        self.fc1 = nn.Linear(img_total, dim_out)

        self.fc2 = nn.Linear(dim_out, img_total)

    def forward(self, x):
        """Forward pass: encode ``x`` to ``dim_out`` values and reconstruct it.

        Expects ``x`` with 2 channels and ``fc1.in_features`` elements per
        sample; returns a tensor of shape ``(N, 2, 32, 32)`` squashed by the
        final Sigmoid.
        """
        # Derive the batch size from the data rather than a global constant.
        n = x.size(0)
        x = self.conv1(x)  ## 2*32*32
        # Flatten per sample; using fc1.in_features keeps the element-count
        # check that the explicit size provided.
        x = x.view(n, self.fc1.in_features)
        x = self.fc1(x)
        x = self.fc2(x)
        x = x.view(n, 2, 32, 32)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.conv2(x)
        return x

网上正确的代码(实在不知道哪里不一样):

class ConvBN(nn.Sequential):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.3), registered as 'conv', 'bn', 'relu'.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: an int, or an iterable with one int per spatial dim.
        stride: convolution stride.
        groups: convolution groups.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        # (k - 1) // 2 padding keeps the feature-map size unchanged for an
        # odd kernel at stride 1.
        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        else:
            padding = [(k - 1) // 2 for k in kernel_size]

        # No conv bias: the BatchNorm that follows subtracts the per-channel
        # mean and has its own learnable shift, so a bias here is redundant.
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                         padding=padding, groups=groups, bias=False)
        norm = nn.BatchNorm2d(out_planes)
        activation = nn.LeakyReLU(negative_slope=0.3, inplace=True)

        layers = OrderedDict()
        layers['conv'] = conv
        layers['bn'] = norm
        layers['relu'] = activation
        super().__init__(layers)


class ResBlock(nn.Module):
    """Residual block: ConvBN(2->8) -> ConvBN(8->16) -> Conv2d(16->2) -> BatchNorm2d.

    The result of that direct path is added to the block input and passed
    through a LeakyReLU (negative slope 0.3).
    """

    def __init__(self):
        super().__init__()
        stages = OrderedDict()
        stages["conv_1"] = ConvBN(2, 8, kernel_size=3)
        stages["conv_2"] = ConvBN(8, 16, kernel_size=3)
        # The last stage is a plain conv + BN: its activation is applied
        # only after the shortcut has been added (see forward).
        stages["conv_3"] = nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1)
        stages["bn"] = nn.BatchNorm2d(2)
        self.direct_path = nn.Sequential(stages)
        self.identity = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)

    def forward(self, x):
        """Return LeakyReLU(direct_path(x) + x)."""
        shortcut = self.identity(x)
        refined = self.direct_path(x)
        return self.relu(refined + shortcut)

class CsiNet(nn.Module):
    """CSI autoencoder with a fully-connected bottleneck.

    Encoder: ConvBN -> flatten -> Linear(2048 -> 2048 // reduction).
    Decoder: Linear back to 2048 -> reshape to the input shape -> two
    ResBlocks -> Conv2d -> BatchNorm2d -> Sigmoid.

    Args:
        reduction: compression ratio; the codeword has ``2048 // reduction``
            elements.
    """

    def __init__(self, reduction=16):
        super().__init__()
        in_channel = 2
        total_size = 2048  # 2 * 32 * 32 elements per sample
        dim_out = total_size // reduction

        self.encoder_convbn = ConvBN(in_channel, 2, kernel_size=3)
        self.encoder_fc = nn.Linear(total_size, dim_out)

        self.decoder_fc = nn.Linear(dim_out, total_size)
        self.decoder_RefineNet1 = ResBlock()
        self.decoder_RefineNet2 = ResBlock()
        self.decoder_conv = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1)
        self.decoder_bn = nn.BatchNorm2d(2)
        self.decoder_sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Encode ``x`` to a codeword and reconstruct a tensor of the same shape."""
        # Shape is read from the input, so any batch size works.
        n, c, h, w = x.shape

        # --- encoder ---
        flat = self.encoder_convbn(x).view(n, -1)  # flatten per sample
        codeword = self.encoder_fc(flat)
        # codeword is the encoded output; per the original author's note it
        # is what gets fed back to the transmitter.

        # --- decoder ---
        out = self.decoder_fc(codeword).view(n, c, h, w)
        out = self.decoder_RefineNet2(self.decoder_RefineNet1(out))
        out = self.decoder_bn(self.decoder_conv(out))
        return self.decoder_sigmoid(out)

定义里的 img_total、dim_out 是什么,外部变量吗?是不是这个的问题?