pytorch图像分类

Given groups=1, weight of size [16, 1, 3, 3], expected input[1, 3, 512, 512] to have 1 channels, but got 3 channels instead
我的图片是rgb类型的,这个该怎么修改代码呀,盆友们

您可以尝试在代码中修改groups参数的值,将其从1改为3,来适配输入图片的RGB通道数,如下所示:

import torch.nn as nn

class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x

这里我将groups参数从1改为3,同时修改了Conv2d的输入通道数为3(代表rgb图片的3个通道)。如果您的图片不是RGB类型的,请将输入通道数改为对应的通道数。

改backbone或图片RGB转灰度