基于python的mnist调用时出现远端服务器无响应

在学习神经网络时,读入mnist数据集总是出现错误,尝试了各种方法都不能有效地解决

代码如下

import os
import sys

import numpy as np
from PIL import Image

# Make the parent directory importable so that the `dataset` package is found.
sys.path.append(os.pardir)
from dataset.mnist import load_mnist


def img_show(img):
    """Display an array of pixel values as an image using PIL."""
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


# Load the data once: normalize=False leaves the pixel values unscaled and
# flatten=True keeps every image as a flat vector.
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=False, flatten=True)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)

img = x_train[0]
label = t_train[0]
print(label)

print(img.shape)  # (784,)
img = img.reshape(28, 28)  # restore the original image dimensions
print(img.shape)  # (28, 28)
# img_show requires the image as its argument; calling it with no argument
# raises TypeError.
img_show(img)

错误是这样 

Downloading t10k-labels-idx1-ubyte.gz ... 
Traceback (most recent call last):
  File "E:\pycharm\ch3\minst_show.py", line 4, in <module>
    ( x_train, t_train), (x_test, t_test) = load_mnist(normalize=False, flatten=True)
  File "E:\pycharm\dataset\mnist.py", line 106, in load_mnist
    init_mnist()
  File "E:\pycharm\dataset\mnist.py", line 75, in init_mnist
    download_mnist()
  File "E:\pycharm\dataset\mnist.py", line 42, in download_mnist
    _download(v)
  File "E:\pycharm\dataset\mnist.py", line 37, in _download
    urllib.request.urlretrieve(url_base + file_name, file_path)
  File "C:\Python\Python39\lib\urllib\request.py", line 239, in urlretrieve
    with contextlib.closing(urlopen(url, data)) as fp:
  File "C:\Python\Python39\lib\urllib\request.py", line 214, in urlopen
    return opener.open(url, data, timeout)
  File "C:\Python\Python39\lib\urllib\request.py", line 523, in open
    response = meth(req, response)
  File "C:\Python\Python39\lib\urllib\request.py", line 632, in http_response
    response = self.parent.error(
  File "C:\Python\Python39\lib\urllib\request.py", line 561, in error
    return self._call_chain(*args)
  File "C:\Python\Python39\lib\urllib\request.py", line 494, in _call_chain
    result = func(*args)
  File "C:\Python\Python39\lib\urllib\request.py", line 641, in http_error_default
    raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 503: Service Unavailable

 

HTTP Error 503: Service Unavailable,说明下载地址所在的服务器当前无法提供服务(并不是代码里的地址写错了,而是服务端暂时不可用)。

这是因为 http://yann.lecun.com/exdb/mnist/ 上的最后一个文件(t10k-labels-idx1-ubyte.gz)暂时不提供下载服务。可以尝试从其他地方把文件下载到本地(放到 mnist.py 所在的 dataset 目录下)再调用,也可以使用下面的替代方法:

from torchvision.datasets import MNIST

# Point torchvision at the S3 mirror BEFORE constructing the dataset: an
# override applied after MNIST(..., download=True) cannot affect that download.
_MIRROR = 'https://ossci-datasets.s3.amazonaws.com/mnist/'
_ARCHIVES = [
    ('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
    ('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
    ('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
    ('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'),
]

# NOTE(review): the attribute layout depends on the installed torchvision
# version -- confirm against the version in use.
if hasattr(MNIST, 'mirrors'):
    # Newer torchvision keeps bare file names in `resources` and the base
    # URLs in `mirrors`, so only the mirror list needs replacing.
    MNIST.mirrors = [_MIRROR]
else:
    # Older torchvision stores the full download URL in each `resources` entry.
    MNIST.resources = [(_MIRROR + name, md5) for name, md5 in _ARCHIVES]

MNIST(".", download=True)

示例,从keras 可正常导入:

from keras.datasets import mnist
from matplotlib import pyplot

# Load the train/test split.
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# Report the shape of each array.
for title, array in (
    ('X_train: ', train_X),
    ('Y_train: ', train_y),
    ('X_test:  ', test_X),
    ('Y_test:  ', test_y),
):
    print(title + str(array.shape))

# Draw the first nine training images on a 3x3 grid of subplots.
gray = pyplot.get_cmap('gray')
for position in range(1, 10):
    pyplot.subplot(3, 3, position)
    pyplot.imshow(train_X[position - 1], cmap=gray)
pyplot.show()

 

# coding: utf-8
try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np


# Base URL of the MNIST archives. The original host,
# http://yann.lecun.com/exdb/mnist/, answers with "HTTP Error 503: Service
# Unavailable", so the S3 mirror that serves the same four files is used.
url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/'

# Logical dataset part -> archive file name appended to url_base.
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

# Archives and the pickle cache are stored next to this module.
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)  # same shape load_mnist reshapes to when flatten=False
img_size = 784         # number of values per flattened image (28 * 28)


def _download(file_name):
    """Download one archive from ``url_base`` into ``dataset_dir``.

    Does nothing when the target file already exists.

    Parameters
    ----------
    file_name : name of the archive, appended to both ``url_base`` and
        ``dataset_dir``.

    Raises
    ------
    Whatever ``urllib.request.urlretrieve`` raises (for example
    ``urllib.error.HTTPError``); no file is left at the target path then.
    """
    file_path = dataset_dir + "/" + file_name

    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted transfer would leave a truncated archive that the existence
    # check above would accept as complete on the next run.
    tmp_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(url_base + file_name, tmp_path)
        os.replace(tmp_path, file_path)
    finally:
        # After a successful rename the temporary file no longer exists, so
        # this only cleans up after a failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Done")
    
def download_mnist():
    """Call ``_download`` for every archive file name listed in ``key_file``."""
    for part in key_file:
        _download(key_file[part])
        
def _load_label(file_name):
    """Read a gzip-compressed label file from ``dataset_dir`` as a uint8 array.

    The first 8 bytes of the decompressed data are skipped; every remaining
    byte becomes one element of the returned array.
    """
    archive = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(archive, 'rb') as stream:
        raw = stream.read()
    labels = np.frombuffer(raw, np.uint8, offset=8)
    print("Done")

    return labels

def _load_img(file_name):
    """Read a gzip-compressed image file from ``dataset_dir`` as a 2-D uint8 array.

    The first 16 bytes of the decompressed data are skipped; the remaining
    bytes are reshaped into rows of ``img_size`` values each.
    """
    archive = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(archive, 'rb') as stream:
        raw = stream.read()
    pixels = np.frombuffer(raw, np.uint8, offset=16)
    images = pixels.reshape(-1, img_size)
    print("Done")

    return images
    
def _convert_numpy():
    """Load all four archives and return them in a dict keyed like ``key_file``."""
    # Insertion order fixes the loading order: train images, train labels,
    # test images, test labels.
    loaders = {
        'train_img': _load_img,
        'train_label': _load_label,
        'test_img': _load_img,
        'test_label': _load_label,
    }
    return {part: load(key_file[part]) for part, load in loaders.items()}

def init_mnist():
    """Download the archives, convert them to arrays and pickle them to ``save_file``."""
    download_mnist()
    arrays = _convert_numpy()

    print("Creating pickle file ...")
    with open(save_file, 'wb') as cache:
        # A negative protocol number selects pickle.HIGHEST_PROTOCOL.
        pickle.dump(arrays, cache, protocol=pickle.HIGHEST_PROTOCOL)
    print("Done!")

def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
        
    return T
    

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset, building the pickle cache first if it is missing.

    Parameters
    ----------
    normalize : if True, convert the images to float32 and scale the pixel
        values to 0.0-1.0.
    flatten : if True, keep every image as a 1-D array; otherwise reshape
        the images to (-1, 1, 28, 28).
    one_hot_label : if True, return the labels as one-hot arrays such as
        [0,0,1,0,0,0,0,0,0,0].

    Returns
    -------
    (train images, train labels), (test images, test labels)
    """
    cache_missing = not os.path.exists(save_file)
    if cache_missing:
        init_mnist()

    # NOTE: pickle.load can execute arbitrary code; save_file is expected to
    # be the cache written by init_mnist, not a file from an untrusted source.
    with open(save_file, 'rb') as cache:
        data = pickle.load(cache)

    image_keys = ('train_img', 'test_img')
    label_keys = ('train_label', 'test_label')

    if normalize:
        for name in image_keys:
            data[name] = data[name].astype(np.float32) / 255.0

    if one_hot_label:
        for name in label_keys:
            data[name] = _change_one_hot_label(data[name])

    if not flatten:
        for name in image_keys:
            data[name] = data[name].reshape(-1, 1, 28, 28)

    train = (data['train_img'], data['train_label'])
    test = (data['test_img'], data['test_label'])
    return train, test


# When this module is run as a script, download the archives and build the
# pickle cache.
if __name__ == '__main__':
    init_mnist()

这个是 mnist.py 的代码,也是书上给的源代码

 

 

您好,我是有问必答小助手,你的问题已经有小伙伴为您解答了问题,您看下是否解决了您的问题,可以追评进行沟通哦~

如果有您比较满意的答案 / 帮您提供解决思路的答案,可以点击【采纳】按钮,给回答的小伙伴一些鼓励哦~~

ps:问答VIP仅需29元,即可享受5次/月 有问必答服务,了解详情>>>https://vip.csdn.net/askvip?utm_source=1146287632

非常感谢您使用有问必答服务,为了后续更快速的帮您解决问题,现诚邀您参与有问必答体验反馈。您的建议将会运用到我们的产品优化中,希望能得到您的支持与协助!

速戳参与调研>>>https://t.csdnimg.cn/Kf0y