load_state_dict加载时间很久

pytorch框架使用load_state_dict(torch. Load(path)['model'])时加载时间很久,需要好几个小时,但其实模型文件只有十几mb

是不是参数比较多

  • 请看👉 :torch.load_state_dict()用法
  • 你还可以看下pytorch参考手册中的 pytorch load_state_dict_from_url() (in module torch.hub)
  • 除此之外, 这篇博客: 人工智能-深度学习-Pytorch与神经网络中的 model.load_state_dict(torch.load(“regression.pt”)) 方法可以对模型进行加载复原 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
  • """
    多项式回归( 本地模型的加载 )
    """
    import torch
    import torch.nn as nn
    import numpy as np
    import matplotlib.pyplot as plt
    
    plt.switch_backend("TkAgg")
    
    
    # 一. 我们保存的模型信息
    
    # 将整个数学模型和参数进行封装
    # 继承 nn.Module
    class Model(nn.Module):
        # 初始化
        def __init__(self):
            # 将父类初始化
            super().__init__()
            # 定义可训练参数并赋予初始值,并封装注册到nn.parameter.Parameter类中,
            # Parameter封装好的参数可以通过 model.parameters() 调用
            self.w = nn.parameter.Parameter(torch.randn([2, 1]))
            self.b = nn.parameter.Parameter(torch.randn([1]))
    
        def forward(self, x):
            # 正向计算过程, 即我们定义的数学模型,  y = x @ w +b
            return x @ self.w + self.b
    
    
    # 二. 加载模型
    
    model = Model()
    model.load_state_dict(torch.load("regression.pt"))
    
    
    # 三. 绘制测试集和预测曲线
    
    # 3.1 绘制测试集
    
    # 测试集的特征, 以x=0为对称轴, 标准差为1的正态分布生成1000个点
    # 注意: 要以 矩阵的方式生成点, 即 1000行 1列的矩阵
    x = np.random.normal(0, 1, [1000, 1])
    
    # 测试集标签, 大概遵从 y = x^2 + 2x + 1 的曲线分布
    d = x ** 2 + 2 * x + 1 + np.random.normal(0, 0.3, [1000, 1])
    # 绘制散点图
    plt.scatter(x, d, c="k")
    
    
    # 3.2 绘制预测曲线
    
    # 生成-3 到 3之间 100个数, 并变为 100 x 1 的矩阵
    x_plt = np.linspace(-3, 3, 100)[..., np.newaxis]
    # 将 x_plt 转为 Tensor 类型
    xp = torch.from_numpy(x_plt).float()
    # 拼接特征 [x^2, x]
    xp = torch.cat((xp ** 2, xp), dim=1)
    
    # 代入模型
    yp = model(xp)
    
    # 由于 y_plt 为计算图的一部分, 所以进行一下截断 detach() , 只要 y 节点
    y_plt = yp.detach().numpy()
    
    # 绘制曲线图
    plt.plot(x_plt, y_plt, c="r", lw=2)
    
    plt.show()
    

    在这里插入图片描述