这个是CsiNet的网络结构:
自己写的:
class RefineNet(nn.Module):
def __init__(self):
super(RefineNet, self).__init__()
## input size: 2*32*32
self.conv1 = nn.Sequential(
nn.Conv2d(2, 8, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(8),
nn.LeakyReLU(negative_slope=0.3, inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(16),
nn.LeakyReLU(negative_slope=0.3, inplace=True)
)
self.conv3 = nn.Sequential(
nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(2)
)
self.r1 = nn.LeakyReLU(negative_slope=0.3, inplace=True)
self.identity = nn.Identity()
# 前向传播函数
def forward(self, x):
## input size: 2*32*32
residual = self.identity(x)
x = self.conv1(x) # 8*32*32
x = self.conv2(x) # 16*32*32
x = self.conv3(x) # 2*32*32
x += residual
x = self.r1(x)
return x
class CsiNet(nn.Module):
def __init__(self):
super(CsiNet, self).__init__()
## input size: 2*32*32
self.conv1 = nn.Sequential(
nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1), # 2*32*32
nn.BatchNorm2d(2),
nn.LeakyReLU(negative_slope=0.3, inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1), # 2*32*32
nn.BatchNorm2d(2),
nn.Sigmoid()
)
self.layer1 = RefineNet()
self.layer2 = RefineNet()
self.fc1 = nn.Linear(img_total, dim_out)
self.fc2 = nn.Linear(dim_out, img_total)
# 前向传播函数
def forward(self, x):
## input size: 2*32*32
x = self.conv1(x) ## 2*32*32
x = x.view(batch_size, img_total) ## 2048
x = self.fc1(x)
x = self.fc2(x)
x = x.view(batch_size, 2, 32, 32)
x = self.layer1(x)
x = self.layer2(x)
x = self.conv2(x)
return x
网上正确的代码(实在不知道哪里不一样):
class ConvBN(nn.Sequential): # 包含卷积;批次归一化;激活函数
def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
if not isinstance(kernel_size, int):
padding = [(i - 1) // 2 for i in kernel_size]
else:
padding = (kernel_size - 1) // 2 # padding的设置是为了让输出的特征图的大小保持一致
super(ConvBN, self).__init__(OrderedDict([
('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
padding=padding, groups=groups, bias=False)), # 为什么bia设置为FALSE呢?
('bn', nn.BatchNorm2d(out_planes)),
('relu',nn.LeakyReLU(negative_slope=0.3, inplace=True))
]))
class ResBlock(nn.Module):
def __init__(self):
super(ResBlock, self).__init__()
self.direct_path = nn.Sequential(OrderedDict([
("conv_1", ConvBN(2, 8, kernel_size=3)),
("conv_2", ConvBN(8, 16, kernel_size=3)),
("conv_3", nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1)),
("bn", nn.BatchNorm2d(2))
]))
self.identity = nn.Identity()
self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)
def forward(self, x):
identity = self.identity(x)
out = self.direct_path(x)
out = self.relu(out + identity)
return out
class CsiNet(nn.Module):
def __init__(self,reduction=16):
super(CsiNet, self).__init__()
total_size, in_channel, w, h = 2048, 2, 32, 32
dim_out = total_size // reduction
self.encoder_convbn = ConvBN(in_channel, 2, kernel_size=3)
self.encoder_fc = nn.Linear(total_size, dim_out)
self.decoder_fc = nn.Linear(dim_out, total_size)
self.decoder_RefineNet1 = ResBlock()
self.decoder_RefineNet2 = ResBlock()
self.decoder_conv = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1)
self.decoder_bn = nn.BatchNorm2d(2)
self.decoder_sigmoid = nn.Sigmoid()
def forward(self, x):
n,c, h, w = x.detach().size()
x = self.encoder_convbn(x)
x = x.view(n,-1) # 平坦化,reshape
x = self.encoder_fc(x)
# 此时x为编码后的输出,需要将x回传给发送端
x = self.decoder_fc(x)
x = x.view(n, c, h, w)
x = self.decoder_RefineNet1(x)
x = self.decoder_RefineNet2(x)
x = self.decoder_conv(x)
x = self.decoder_bn(x)
x = self.decoder_sigmoid(x)
return x
定义里的 img_total
和 dim_out
是什么,外部变量吗?是不是这个的问题?