Looking for someone experienced in deep learning and image generation (DCGAN)

https://github.com/eriklindernoren/Keras-GAN/blob/master/dcgan/dcgan.py

I want to train on my own dataset with the open-source DCGAN from GitHub. How do I load my own dataset into it?

I downloaded the code from your link and ran it through once; it works as-is.

On line 109, dcgan.py calls mnist.load_data(), which reads the bundled mnist.npz dataset. The source of mnist.load_data() looks like this:

"""MNIST handwritten digits dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).

  This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
  along with a test set of 10,000 images.
  More info can be found at the
  [MNIST homepage](http://yann.lecun.com/exdb/mnist/).


  Arguments:
      path: path where to cache the dataset locally
          (relative to `~/.keras/datasets`).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

      **x_train, x_test**: uint8 arrays of grayscale image data with shapes
        (num_samples, 28, 28).

      **y_train, y_test**: uint8 arrays of digit labels (integers in range 0-9)
        with shapes (num_samples,).

  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """
  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'mnist.npz',
      file_hash=
      '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')
  with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

    return (x_train, y_train), (x_test, y_test)
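
Incidentally, the function above just fetches an .npz file and reads the four keys x_train, y_train, x_test and y_test out of it, so one option is to pack your own arrays into an .npz with the same keys via np.savez and load it the same way. A minimal sketch (the file name my_dataset.npz and the zero-filled arrays are only placeholders standing in for your real data):

import numpy as np

# Placeholder arrays standing in for your own images and labels
x_train = np.zeros((900, 28, 28), dtype=np.uint8)
y_train = np.zeros((900,), dtype=np.uint8)
x_test = np.zeros((100, 28, 28), dtype=np.uint8)
y_test = np.zeros((100,), dtype=np.uint8)

# Save with the same keys that mnist.load_data() reads
np.savez('my_dataset.npz', x_train=x_train, y_train=y_train,
         x_test=x_test, y_test=y_test)

# Load it back exactly the way the function above does
with np.load('my_dataset.npz', allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']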

Alternatively, you can write a data-loading function of your own modeled on that function, which is what I did below. One thing I checked: the samples in mnist.npz are 28×28, so your own images need to be resized to 28×28.

The resulting code looks like this:

import os
import random

import cv2
import numpy as np


def get_image(image_index, path=r'C:\Coding\Python\CSDN\Image\bibimbap', img_prefix="hed"):
    '''
    image_index is an integer from 0 to 999.
    path is the absolute path to the dataset (a relative path works too).
    img_prefix is the filename prefix of the dataset images.
    '''
    # Zero-pad the index to four digits, e.g. 7 -> "0007"
    image_index = "%04d" % image_index
    image_path = os.path.join(path, img_prefix + image_index + '.png')
    # Read the image as grayscale
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    img = Reduce(img)
    return img


def Reduce(image):
    # Shrink the image to 28x28 so it matches the MNIST sample size
    shrink = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
    return shrink


def load_data_mydatabase(path=r'C:\Coding\Python\CSDN\Image\bibimbap'):
    # The samples in mnist.npz are single-channel 28x28 images, so the
    # preprocessed hed0000.png ... hed0999.png files under bibimbap are used
    # as the training and test data.
    # With only 1000 images, the first 900 go into the training set and the
    # last 100 into the test set.
    x_train = []
    x_test = []
    y_train = []
    y_test = []
    # I am not sure which classes you want for this cold-noodle data,
    # so ten class labels are generated at random here as placeholders.
    for train_index in range(900):
        x_train.append(get_image(train_index, path))
        y_train.append(random.randint(0, 9))
    x_train = np.array(x_train)
    y_train = np.array(y_train)
    for test_index in range(900, 1000):
        x_test.append(get_image(test_index, path))
        y_test.append(random.randint(0, 9))
    x_test = np.array(x_test)
    y_test = np.array(y_test)
    return (x_train, y_train), (x_test, y_test)
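
For a quick sanity check that the shapes line up with what dcgan.py gets from mnist.load_data(), you can run something like this (assuming the 1000 hed*.png files actually exist at that path):

(x_train, y_train), (x_test, y_test) = load_data_mydatabase()
print(x_train.shape, y_train.shape)  # expected: (900, 28, 28) (900,)
print(x_test.shape, y_test.shape)    # expected: (100, 28, 28) (100,)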

I remembered that you wanted to build a cold-noodle dataset, so I downloaded a copy as well. I don't know how many classes you need, though, so I generated random labels for ten classes; replace them with your own label values as needed. (The DCGAN itself never uses the labels anyway, as the _ placeholders in the call below show.)

To use it, drop this code into the original script and replace line 109 of dcgan.py with

        (X_train, _), (_, _) = load_data_mydatabase()

and that's it.
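
For context, the lines right after that call in train() rescale the pixel values to [-1, 1] and add a channel axis; after the swap, that part of the method should read roughly as follows (a sketch from memory of dcgan.py, not a verbatim copy):

        # Load the dataset (originally: (X_train, _), (_, _) = mnist.load_data())
        (X_train, _), (_, _) = load_data_mydatabase()

        # Rescale from [0, 255] to [-1, 1] and add the channel axis,
        # just as the original script does with the MNIST images
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)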
 

Thanks for your answer. I followed your suggestion and got this error; could you take a look at what might be causing it?

First use a for loop to read your own images one by one, for example with OpenCV, and resize them to one uniform shape. Then create a numpy array of shape (number_of_images, height, width, channels=3) and assign the images into it one by one. I also see that this GitHub code goes on to divide by 127.5 and subtract 1; you can do the same, and with that the data part is ready.

For example, first create an empty array to hold the images:

x_train = np.zeros((20000, 224, 224, 3), dtype=np.uint8)

Then use a for loop to put each image into the array after reading and processing it, and you're done:

# image_num is the total number of images here; tqdm just draws a progress bar
for i in tqdm(range(int(image_num / 2))):
    x_train[i] = cv2.resize(cv2.imread('train/cat.%d.jpg' % i), (224, 224))
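
Putting those pieces together, a self-contained version of that loop might look like the sketch below (image_num, the train/ folder and the cat.%d.jpg naming are just the assumptions carried over from the example above; swap in your own paths and image count):

from tqdm import tqdm
import cv2
import numpy as np

image_num = 20000          # assumed total number of images
num_cats = image_num // 2  # this example only loads the cat half

# Preallocate an array for the resized RGB images
x_train = np.zeros((num_cats, 224, 224, 3), dtype=np.uint8)

# Read, resize and store each image, with a tqdm progress bar
for i in tqdm(range(num_cats)):
    x_train[i] = cv2.resize(cv2.imread('train/cat.%d.jpg' % i), (224, 224))

# Rescale from [0, 255] to [-1, 1], matching what dcgan.py does with MNIST
x_train = x_train / 127.5 - 1.

Keep in mind that the network in this particular dcgan.py is built for 28×28 single-channel images, so for that script you would either resize to 28×28 grayscale, as in the loader earlier in this thread, or change the image size and channel settings near the top of the script.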