将模型的输出信息(loss、train accuracy、test accuracy 等)用 matplotlib 绘制成两幅图像。
示例代码
import torch
from torch import nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Define a MLP
class MLP(nn.Module):
    """Two-layer perceptron producing 10 unnormalized class scores (logits).

    Each input sample is flattened (batch dimension kept) and must yield
    28 * 28 = 784 features, e.g. a 1x28x28 image. It then goes through a
    512-unit hidden layer with ReLU and a final linear layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        flat_features = 28 * 28
        stack = [
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        # The attribute name `layers` and the module order determine the
        # parameter names (layers.1.weight, ...), so both are kept stable.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits, shape (batch, 10), for the input batch ``x``."""
        logits = self.layers(x)
        return logits
# Load data: ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,))
# then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
# Evaluation metrics do not depend on sample order, so no shuffling is needed.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
model = MLP()
# CrossEntropyLoss takes the raw logits that MLP returns.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training
epochs = 10
# Per-epoch history, plotted after training.
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []
for e in range(epochs):
    model.train()
    running_loss = 0.0
    train_correct = 0
    train_total = 0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # criterion returns the batch mean; re-weight by batch size so the
        # epoch value is an exact per-sample mean (last batch may be smaller).
        running_loss += loss.item() * batch_size
        # The class with the largest logit is the prediction; no softmax/exp needed.
        train_correct += (output.argmax(dim=1) == labels).sum().item()
        train_total += batch_size

    model.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0
    # Turn off gradients for validation
    with torch.no_grad():
        for images, labels in testloader:
            output = model(images)
            batch_size = labels.size(0)
            # .item() accumulates a Python float rather than a tensor.
            test_loss += criterion(output, labels).item() * batch_size
            test_correct += (output.argmax(dim=1) == labels).sum().item()
            test_total += batch_size

    train_losses.append(running_loss / train_total)
    test_losses.append(test_loss / test_total)
    train_accuracies.append(train_correct / train_total)
    test_accuracies.append(test_correct / test_total)
    print("Epoch: {}/{}.. ".format(e+1, epochs),
          "Training Loss: {:.3f}.. ".format(train_losses[-1]),
          "Test Loss: {:.3f}.. ".format(test_losses[-1]),
          "Train Accuracy: {:.3f}.. ".format(train_accuracies[-1]),
          "Test Accuracy: {:.3f}".format(test_accuracies[-1]))

# Plot the history as two figures side by side: losses (left), accuracies (right).
epoch_axis = list(range(1, epochs + 1))
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_losses, label='Training loss')
plt.plot(epoch_axis, test_losses, label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epoch_axis, train_accuracies, label='Training accuracy')
plt.plot(epoch_axis, test_accuracies, label='Validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.tight_layout()
plt.show()
如果有帮助,点击一下采纳该答案~谢谢