用pytorch实现mnist手写数字识别
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
# MNIST train/test splits, converted with transforms.ToTensor() and stored under ./num/.
train_dataset = datasets.MNIST(root='./num/',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# download=True here too: torchvision skips the download when the files already
# exist, and otherwise the test split would fail to load on a fresh checkout.
test_dataset = datasets.MNIST(root='./num/',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)
batch_size = 64
# Shuffle only the training data; each epoch then sees batches in a new order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps test runs reproducible.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
class LeNet(nn.Module):
    """LeNet-style CNN classifier for single-channel 28x28 images.

    Spatial sizes for a 28x28 input (why fc1 expects 16*5*5 features):
    28 -conv(k=3, pad=2)-> 30 -pool-> 15 -conv(k=5)-> 11 -pool-> 5.

    Args:
        num_classes: number of output logits (default 10, one per digit).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    # `forward` must be indented as a method of this class. If it sits at any
    # other level, nn.Module's default forward raises NotImplementedError.
    def forward(self, x):
        """Map a batch of shape (N, 1, 28, 28) to logits of shape (N, num_classes)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        x = self.fc1(x)
        return self.fc2(x)
# Prefer the GPU when available; all tensors below are moved with .to(device),
# so the script also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LR = 0.001
epoch = 1  # number of passes over the training set
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)

if __name__ == '__main__':
    for ep in range(epoch):  # `ep` avoids shadowing the `epoch` count
        net.train()  # BatchNorm layers use per-batch statistics while training
        sum_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            if i % 100 == 99:
                # Report the mean loss over the last 100 batches.
                print('[%d,%d] loss:%.03f' % (ep + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0

    # eval() makes BatchNorm use its running statistics (and stop updating them);
    # no_grad() skips building the autograd graph during evaluation.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            output_test = net(images)
            _, predicted = torch.max(output_test, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("correct1:", correct)
    print("test acc:{0}".format(correct / total))
Traceback (most recent call last):
File "C:\Users\12137\PycharmProjects\pythonProject2\main.py", line 65, in <module>
outputs =net(inputs)
File "D:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "D:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 201, in _forward_unimplemented
raise NotImplementedError
NotImplementedError
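原因:`def forward` 没有和 `__init__` 对齐(缩进不一致),因此 `forward` 不是 `LeNet` 类的方法。调用 `net(inputs)` 时会落到 `nn.Module` 默认的 `_forward_unimplemented`,从而抛出 `NotImplementedError`。
改正缩进后为: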
class LeNet(nn.Module):
    """LeNet-style CNN classifier for single-channel 28x28 images.

    Spatial sizes for a 28x28 input (why fc1 expects 16*5*5 features):
    28 -conv(k=3, pad=2)-> 30 -pool-> 15 -conv(k=5)-> 11 -pool-> 5.

    Args:
        num_classes: number of output logits (default 10, one per digit).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    # `forward` must be indented as a method of this class. If it sits at any
    # other level, nn.Module's default forward raises NotImplementedError.
    def forward(self, x):
        """Map a batch of shape (N, 1, 28, 28) to logits of shape (N, num_classes)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        x = self.fc1(x)
        return self.fc2(x)
你好,我是有问必答小助手,非常抱歉,本次您提出的有问必答问题,技术专家团超时未为您做出解答
本次提问扣除的有问必答次数,将会以问答VIP体验卡(1次有问必答机会、商城购买实体图书享受95折优惠)的形式为您补发到账户。
因为有问必答VIP体验卡有效期仅有1天,您在需要使用的时候【私信】联系我,我会为您补发。