PyTorch加载模型错误

我把在PyTorch框架下用torch.save()训练模型保存下来,但是我加载不了我保存下来的模型。
import torch
model=torch.load("xxx.pt)总是出现没有“model”的错误

img

  • 帮你找了个相似的问题, 你可以看下: https://ask.csdn.net/questions/7596107
  • 我还给你找了一篇非常好的博客,你可以看看是否有帮助,链接:PyTorch模型保存torch.save()与加载torch.load()
  • 你还可以看下pytorch参考手册中的 pytorch torch.nn到底是什么?
  • 除此之外, 这篇博客: Pytorch——保存训练好的模型参数中的 2.torch.save(保存模型) 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
  • 首先,先搭建一个神经网络

    import torch
    from torch import nn
    import matplotlib.pyplot as plt
    torch.manual_seed(11)    # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样
    
    
    x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # [100] -> [100,1]
    y = x.pow(2) + 0.5*torch.rand(x.size())  # y的形状与x一样
    
    
    def make_and_save_model():
        
        network = torch.nn.Sequential(
            torch.nn.Linear(1, 8),
            torch.nn.ReLU(),
            torch.nn.Linear(8, 1)
        )
        optimizer = torch.optim.SGD(network.parameters(), lr=0.3)   #优化器
        criterion = torch.nn.MSELoss()     #损失函数
    
        # 训练
        for i in range(200):
            prediction = network(x)      #数据放入模型后得到预测值
            loss = criterion(prediction, y)    #计算预测值与真实值之间的误差
            optimizer.zero_grad()       #清空梯度
            loss.backward()           #误差反向传播
            optimizer.step()          #更新参数
        torch.save(network, 'network.pth')  # 保存整个网络
        torch.save(network.state_dict(), 'network_params.pth')   # 只保存网络中的参数
        
        plt.figure(1, figsize = (10,3))
        plt.subplot(131)
        plt.title('network')
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
        plt.pause(1)
    
  • 以下回答由chatgpt基于相关博客总结生成:

    针对该问题,可能有以下几个原因导致无法成功加载模型: 1. 加载模型的路径错误; 2. 保存模型时出现了错误,导致模型文件已损坏; 3. 加载模型的代码存在问题,导致无法成功加载。

    针对这些可能的原因,可以采取以下方法逐一排查问题:

    1. 确认加载模型的路径是否正确,可以使用os.path.exists()方法检查文件是否存在;
    2. 确认保存模型时没有出现错误,可以查看保存模型时的输出信息或者捕获异常信息;
    3. 检查加载模型的代码是否正确,包括模型文件格式是否正确,是否使用了正确的模型加载方法等。

    具体的代码示例如下:

    import torch
    
    # 加载模型的路径
    MODEL_DIR = './model.pt'
    
    def load_model():
        if not os.path.exists(MODEL_DIR):
            print('模型文件不存在!')
            return None
        # 加载模型文件
        try:
            model = torch.load(MODEL_DIR)
        except Exception as e:
            print(f'加载模型失败,错误信息:{e}')
            return None
        return model
    
    # 在主函数中调用加载模型的方法
    if __name__ == '__main__':
        # 加载模型
        model = load_model()
        if model is None:
            print('无法加载模型!')
        else:
            # 对加载的模型进行测试等操作
            pass
    

    针对第三种可能的问题,如果确信模型文件格式正确且加载代码没有误,可以尝试使用其他的模型加载方法,例如按照state_dict方式加载:

    import torch
    
    # 加载模型的路径
    MODEL_DIR = './model.pth'
    
    def load_model():
        if not os.path.exists(MODEL_DIR):
            print('模型文件不存在!')
            return None
        # 加载模型文件
        try:
            # 使用state_dict方式加载
            model = torch.nn.Sequential(
                torch.nn.Linear(1, 8),
                torch.nn.ReLU(),
                torch.nn.Linear(8, 1)
            )
            checkpoint = torch.load(MODEL_DIR)
            model.load_state_dict(checkpoint)
        except Exception as e:
            print(f'加载模型失败,错误信息:{e}')
            return None
        return model
    
    # 在主函数中调用加载模型的方法
    if __name__ == '__main__':
        # 加载模型
        model = load_model()
        if model is None:
            print('无法加载模型!')
        else:
            # 对加载的模型进行测试等操作
            pass