Hugging Face载入模型双卡训练报错

模型载入代码如下:

device = torch.device("cuda:0")
model6 = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny",
                                                       ignore_mismatched_sizes=True)  # /upernet-swin-large,upernet-convnext-tiny
model6 = nn.DataParallel(model6, device_ids = [0, 1])
model6 = model6.to(device)
x2 = torch.randn(4, 3, 256,256).to(device)
print(model6(x2).shape)

报错:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as
tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking 
arguments for cudnn_convolution)

请问如何解决?

模型的输入数据和模型权重必须在同一设备上。

代码里模型移动到了GPU(device 0)上,执行 model6(x2) 时,输入数据 x2 却没有明确地分配到某一个GPU上。

试试使用 nn.DataParallel 的 forward() 方法来解决这个问题。将 model6(x2) 更改为 model6.module.forward(x2),看看能否解决问题。这样会确保你的输入数据被分配到所有的GPU上。

device = torch.device("cuda:0")
model6 = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny",
                                                       ignore_mismatched_sizes=True)  # /upernet-swin-large,upernet-convnext-tiny
model6 = nn.DataParallel(model6, device_ids=[0, 1])
model6 = model6.to(device)
x2 = torch.randn(4, 3, 256, 256).to(device)
print(model6.module.forward(x2).shape)
不知道你这个问题是否已经解决, 如果还没有解决的话:
  • 看下这篇博客,也许你就懂了,链接:Hugging Face 预训练模型的下载及使用
  • 除此之外, 这篇博客: Hugging face预训练模型下载和使用中的 预训练模型的使用(API) 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:

    transformers库最关键的三个类

    • model
    • tokenizer
    • configuration

    其他类如GPT2LMHeadModel都是对应类的子类而已,根据模型的特点进行改进。

    1. 使用方式一:指定模型名字

    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained(“nghuyong/ernie-1.0”)
    model = AutoModel.from_pretrained(“nghuyong/ernie-1.0”)
    这种方式不需要下载预训练模型,函数调用过程中如果发现没有这个模型就会自动下载

    1. 使用方式二:指定路径

    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained(“/home/models/huggingface/gpt2”)
    model = AutoModel.from_pretrained(“/home/models/huggingface/gpt2”)
    这种方式需要先下载好预训练模型的文件


如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^