pytorch框架使用load_state_dict(torch. Load(path)['model'])时加载时间很久,需要好几个小时,但其实模型文件只有十几mb
是不是参数比较多
"""
多项式回归( 本地模型的加载 )
"""
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend("TkAgg")
# 一. 我们保存的模型信息
# 将整个数学模型和参数进行封装
# 继承 nn.Module
class Model(nn.Module):
# 初始化
def __init__(self):
# 将父类初始化
super().__init__()
# 定义可训练参数并赋予初始值,并封装注册到nn.parameter.Parameter类中,
# Parameter封装好的参数可以通过 model.parameters() 调用
self.w = nn.parameter.Parameter(torch.randn([2, 1]))
self.b = nn.parameter.Parameter(torch.randn([1]))
def forward(self, x):
# 正向计算过程, 即我们定义的数学模型, y = x @ w +b
return x @ self.w + self.b
# 二. 加载模型
model = Model()
model.load_state_dict(torch.load("regression.pt"))
# 三. 绘制测试集和预测曲线
# 3.1 绘制测试集
# 测试集的特征, 以x=0为对称轴, 标准差为1的正态分布生成1000个点
# 注意: 要以 矩阵的方式生成点, 即 1000行 1列的矩阵
x = np.random.normal(0, 1, [1000, 1])
# 测试集标签, 大概遵从 y = x^2 + 2x + 1 的曲线分布
d = x ** 2 + 2 * x + 1 + np.random.normal(0, 0.3, [1000, 1])
# 绘制散点图
plt.scatter(x, d, c="k")
# 3.2 绘制预测曲线
# 生成-3 到 3之间 100个数, 并变为 100 x 1 的矩阵
x_plt = np.linspace(-3, 3, 100)[..., np.newaxis]
# 将 x_plt 转为 Tensor 类型
xp = torch.from_numpy(x_plt).float()
# 拼接特征 [x^2, x]
xp = torch.cat((xp ** 2, xp), dim=1)
# 代入模型
yp = model(xp)
# 由于 y_plt 为计算图的一部分, 所以进行一下截断 detach() , 只要 y 节点
y_plt = yp.detach().numpy()
# 绘制曲线图
plt.plot(x_plt, y_plt, c="r", lw=2)
plt.show()