画出部分MNIST数据集图像时内核崩溃

画出部分MNIST数据集图像时内核崩溃
用代码块功能插入代码，请勿粘贴截图
import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
# Hyperparameters. Only batch_size is used in this section; the others are
# presumably consumed by a training loop later in the file — confirm.
batch_size = 64
learning_rate = 0.01
momentum = 0.5
EPOCH = 10

# ToTensor() then per-channel standardization; 0.1307 / 0.3081 are the
# commonly quoted MNIST mean / std.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Raw string: in a plain literal "\c", "\P" and "\m" are invalid escape
# sequences (DeprecationWarning, SyntaxWarning on Python 3.12+). The resulting
# path value is identical to the original.
MNIST_ROOT = r'D:\code\PyTorch\mnist'

# download=False: the MNIST files must already be present under MNIST_ROOT.
train_dataset = datasets.MNIST(root=MNIST_ROOT, train=True, download=False, transform=transform)
test_dataset = datasets.MNIST(root=MNIST_ROOT, train=False, download=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Preview the first 12 training samples in a 3x4 grid.
# .data / .targets read the stored images and labels directly, so `transform`
# is not applied here. .train_data / .train_labels used previously are
# deprecated torchvision aliases for the same attributes and emit warnings.
# NOTE(review): if the notebook kernel dies silently around here on Windows,
# check the launching console for "OMP: Error #15 ... libiomp5md.dll already
# initialized". That indicates two OpenMP runtimes loaded in one process
# (e.g. PyTorch plus NumPy/MKL). Fixing the environment is preferable. The
# often-suggested KMP_DUPLICATE_LIB_OK=TRUE is described by the runtime itself
# as an unsafe workaround.
fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    plt.imshow(train_dataset.data[i], cmap='gray', interpolation='none')
    plt.title("Labels: {}".format(train_dataset.targets[i].item()))
    plt.xticks([])
    plt.yticks([])
# Lay out once, after every subplot and title exists. Calling tight_layout()
# inside the loop recomputed the layout 12 times, and its final pass ran before
# the last subplot's title was set.
plt.tight_layout()
plt.show()

img

内核直接崩溃了。代码来自这篇博客：https://blog.csdn.net/qq_45588019/article/details/120935828