pytorch搭建cnn-lstm网络的shape修改问题



class cnn_lstm(nn.Module):
    def __init__(self,window_size,feature_number):
        super(cnn_lstm, self).__init__()
        self.window_size=window_size
        self.feature_number=feature_number
        self.conv1 = Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu1 = ReLU()
        self.maxpooling1 = MaxPool2d(3, stride=1,padding=1)
        self.dropout1 = Dropout(0.3)
        self.lstm1 = nn.LSTM(input_size=64 * feature_number, hidden_size=128, num_layers=1, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Linear(in_features=64, out_features=32)
        self.relu2 = nn.ReLU()
        self.head = nn.Linear(in_features=32, out_features=1)

    def forward(self, x):

            x = x.reshape([x.shape[0], 1, self.window_size, self.feature_number])
            # x = x.transpose(-1, -2)
            x = self.conv1(x)
            x = self.relu1(x)
            x = self.maxpooling1(x)
            x = self.dropout1(x)

            x = x.reshape([x.shape[0], self.window_size, -1])
            # x = x.transpose(-1, -2)  #
            x, (h, c) = self.lstm1(x)
            x, (h, c) = self.lstm2(x)
            x = x[:, -1, :]  # 最后一个LSTM只要窗口中最后一个特征的输出
            x = self.fc(x)
            x = self.relu2(x)
            x = self.head(x)

            return x

# 实例化模型
model=cnn_lstm(window_size=3,feature_number=9)

这是我借鉴pytorch写的cnn-lstm架构,在处理数据时报RuntimeError: The size of tensor a (7963) must match the size of tensor b (3) at non-singleton dimension 1的错误。
数据的shape是:
x_train.shape = torch.Size([7963, 3, 9])
y_train.shape = torch.Size([7963, 3, 1])
x_test.shape = torch.Size([1991, 3, 9])
y_test.shape = torch.Size([1991, 3, 1])
请问需要哪里进行修改啊

不知道你这个问题是否已经解决, 如果还没有解决的话:
  • 以下回答来自chatgpt:

    首先需要理解错误信息的含义,该错误表示a(维度为7963)和b(维度为3)在非单一维度1上的大小必须匹配。因此我们需要修改数据shape以匹配网络架构。

    假设当前的CNN-LSTM网络架构已经确定,需要修改数据shape以满足网络的输入要求。具体来说,我们需要根据当前输入数据的shape确定哪些部分需要进行修改。如果我们能够通过数据切片来满足网络输入,则我们可以使用PyTorch的view函数来调整数据,否则我们需要使用其他方式来修改数据shape。

    以下是一些常见的修改数据shape的方式:

    1.使用numpy的reshape函数,它可以在不改变数据本身的前提下修改数据的shape。

    2.使用PyTorch的view函数,在不改变存储结构的前提下修改数据的shape。

    3.使用PyTorch的transpose函数,可以交换维度的位置。

    具体来说,对于本问题,可以通过对数据进行切片、使用PyTorch的view函数和transpose函数来修改数据shape,使其与指定的CNN-LSTM网络架构匹配。

    假设当前数据的shape为(batch_size, seq_len, feature_dim),其中batch_size为样本数,seq_len为序列长度,feature_dim为特征数。我们需要将其修改为(batch_size, channel, seq_len, width, height),其中channel为通道数,可以设置为1,seq_len为序列长度,可以设置为1,width和height为图像的宽和高。

    具体的代码如下:

    # 将数据切片并使用view函数修改数据shape
    data = data.view(batch_size, 1, seq_len, width, height)
    
    # 使用transpose函数交换维度的位置
    data = data.transpose(1, 2)
    
    # 将数据转换为PyTorch张量
    data = torch.Tensor(data)
    
    # 将数据放入GPU
    data = data.to(device)
    

    这样就可以将数据shape调整为(batch_size, channel, seq_len, width, height),使其与CNN-LSTM网络架构匹配,从而避免了RuntimeError。


如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^