我把在PyTorch框架下用torch.save()训练模型保存下来,但是我加载不了我保存下来的模型。
import torch
model=torch.load("xxx.pt)总是出现没有“model”的错误
首先,先搭建一个神经网络
import torch
from torch import nn
import matplotlib.pyplot as plt
torch.manual_seed(11) # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # [100] -> [100,1]
y = x.pow(2) + 0.5*torch.rand(x.size()) # y的形状与x一样
def make_and_save_model():
network = torch.nn.Sequential(
torch.nn.Linear(1, 8),
torch.nn.ReLU(),
torch.nn.Linear(8, 1)
)
optimizer = torch.optim.SGD(network.parameters(), lr=0.3) #优化器
criterion = torch.nn.MSELoss() #损失函数
# 训练
for i in range(200):
prediction = network(x) #数据放入模型后得到预测值
loss = criterion(prediction, y) #计算预测值与真实值之间的误差
optimizer.zero_grad() #清空梯度
loss.backward() #误差反向传播
optimizer.step() #更新参数
torch.save(network, 'network.pth') # 保存整个网络
torch.save(network.state_dict(), 'network_params.pth') # 只保存网络中的参数
plt.figure(1, figsize = (10,3))
plt.subplot(131)
plt.title('network')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
plt.pause(1)
针对该问题,可能有以下几个原因导致无法成功加载模型: 1. 加载模型的路径错误; 2. 保存模型时出现了错误,导致模型文件已损坏; 3. 加载模型的代码存在问题,导致无法成功加载。
针对这些可能的原因,可以采取以下方法逐一排查问题:
具体的代码示例如下:
import torch
# 加载模型的路径
MODEL_DIR = './model.pt'
def load_model():
if not os.path.exists(MODEL_DIR):
print('模型文件不存在!')
return None
# 加载模型文件
try:
model = torch.load(MODEL_DIR)
except Exception as e:
print(f'加载模型失败,错误信息:{e}')
return None
return model
# 在主函数中调用加载模型的方法
if __name__ == '__main__':
# 加载模型
model = load_model()
if model is None:
print('无法加载模型!')
else:
# 对加载的模型进行测试等操作
pass
针对第三种可能的问题,如果确信模型文件格式正确且加载代码没有误,可以尝试使用其他的模型加载方法,例如按照state_dict方式加载:
import torch
# 加载模型的路径
MODEL_DIR = './model.pth'
def load_model():
if not os.path.exists(MODEL_DIR):
print('模型文件不存在!')
return None
# 加载模型文件
try:
# 使用state_dict方式加载
model = torch.nn.Sequential(
torch.nn.Linear(1, 8),
torch.nn.ReLU(),
torch.nn.Linear(8, 1)
)
checkpoint = torch.load(MODEL_DIR)
model.load_state_dict(checkpoint)
except Exception as e:
print(f'加载模型失败,错误信息:{e}')
return None
return model
# 在主函数中调用加载模型的方法
if __name__ == '__main__':
# 加载模型
model = load_model()
if model is None:
print('无法加载模型!')
else:
# 对加载的模型进行测试等操作
pass