关于#pytorch#的问题:pytorch实现mnist手写数字识别

我正在用pytorch实现mnist手写数字识别,但是我的loss从一开始就不变,这是为什么? 以下是我的代码:


import gzip
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.autograd import variable

#读取训练数据
f = gzip.open("./mnist.pkl.gz","rb")
train_data, val_data, test_data = pickle.load(f,encoding='latin1')
f.close()

# 将50000张训练图片分为250组,每组200张图片,图片大小
train_data_img = train_data[0].reshape(250,200,28,28)
train_data_ans = train_data[1].reshape(250,200)
#搭建网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self,x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.0)

#迭代
losses = []
loss = 1
for epoch in range(10):
    for data,ans in zip(train_data_img,train_data_ans):
        out = net(torch.tensor(data.reshape(200,784)))
        ans = F.one_hot(torch.tensor(ans))
        loss = F.mse_loss(out, ans)
        optimizer.zero_grad()
        loss.backword()
        optimizer.step()
    losses.append(loss.item())

print(losses)
xlabel = np.linspace(0,len(losses),len(losses))
plt.plot(xlabel,losses)
plt.show()

可以将每一次的w、b、loss、dw、db都打印出来,看是否随机梯度下降没起作用