在应用CNN进行张量融合时报错:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x18240 and 10x18240)
模型代码如下
class MyCNN(nn.Module):
def __init__(self):
super(MyCNN, self).__init__()
self.conv1 = nn.Conv2d(6, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(10,18240)
self.fc2 = nn.Linear(18240,1)
def forward(self, xp, xc):
batch_size = xp.size(0)
x = torch.cat((xp.unsqueeze(1), xc), dim=1) #(10,6,20,456)
print(x.shape)
x = F.relu(self.conv1(x))
print(x.shape)
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(batch_size, -1)
print(x.shape)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x.squeeze()
为什么会出现这样的问题?要怎么修改?
将nn.Linear(10, 18240)更改为nn.Linear(1610228, 18240)看看呢
提示是维度不一致
原因: 网络参数设置不合适!!
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
import torch
class Lyy(nn.Module):
def __init__(self):
super(Lyy, self).__init__()
self.conv1 = Conv2d(3, 32, 5, padding=2)
self.maxpool1 = MaxPool2d(2)
self.conv2 = Conv2d(32, 32, 5, padding=2)
self.maxpool2 = MaxPool2d(2)
self.conv3 = Conv2d(32, 64, 5, padding=2)
self.maxpool3 = MaxPool2d(2)
self.flatten = Flatten()
self.linear1 = Linear(1024, 64)
self.linear2 = Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
lyy = Lyy()
input = torch.ones((64,3,32,32))
print(input.shape)
output = lyy(input)
print(output.shape)
上述代码的,等价替换版本
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
import torch
class Lyy(nn.Module):
def __init__(self):
super(Lyy, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self, x):
x = self.model1(x)
return x
lyy = Lyy()
input = torch.ones((64,3,32,32))
print(input.shape)
output = lyy(input)
print(output.shape)
根据报错信息,mat1和mat2的大小分别为10x18240和10x18240,在矩阵相乘时不符合尺寸要求。可能原因是在进行张量融合时,数据维度和模型参数设置不符合。为了解决这个问题,需要对代码进行调整,确保数据和参数的维度匹配。
具体的操作步骤如下:
1.首先检查输入数据的维度是否和模型参数的要求一致,比如输入的数据维度是否为4维,卷积层的输入通道数是否和输入数据通道数一致等。
2.如果前面的检查都没有问题,那么需要检查模型参数的设置是否符合实际应用需要。比如卷积核的大小、通道数、池化层的大小等是否设置得当。
3.如果数据和参数都没有问题,那么建议将数据和参数的维度打印出来,检查究竟是哪一步出现了维度不匹配的问题。
4.最后,如果还不能解决问题,建议通过查询文献或者请教更有经验的同行来寻求帮助。