SRCNN,训练的效果很差, 应该怎么改

我复现SRCNN的时候,不知道为什么model跑个200个epochs就没办法优化了,导致效果像抛硬币,有的图片提升了一丢丢,有的完全是变得更差了.
我用的训练图像是原作者提供的。

import numpy as np
from matplotlib import pyplot as plt
import sys
import keras
import cv2
import numpy
from keras.models import Sequential
from keras.layers import Conv2D
from keras.optimizers import Adam
from skimage.measure import compare_ssim as ssim
import cv2
import math
import random
from keras.callbacks import ModelCheckpoint
import os
from tensorflow.keras.callbacks import ModelCheckpoint


def psnr(target, ref):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Assumes pixel values are on a 0-255 scale (MAX_I = 255).

    Args:
        target: distorted image, any array convertible to float.
        ref: reference image, same shape as target.

    Returns:
        PSNR in decibels; float('inf') for identical images.  The original
        version raised a ZeroDivisionError via log10(255/0) when the two
        images were equal.
    """
    target_data = target.astype(float)
    ref_data = ref.astype(float)
    diff = (ref_data - target_data).flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))
    if rmse == 0:
        # Identical images: PSNR is unbounded, not an error.
        return float('inf')
    return 20 * math.log10(255. / rmse)


def mse(target, ref):
    """Sum of squared pixel differences normalized by height * width.

    Note: channels (if any) are summed, not averaged, matching the
    original implementation.
    """
    diff = target.astype('float') - ref.astype('float')
    height, width = target.shape[0], target.shape[1]
    return np.sum(diff ** 2) / float(height * width)


def compare_images(target, ref):
    """Return a list of [PSNR, MSE, SSIM] scores for target vs. ref."""
    return [
        psnr(target, ref),
        mse(target, ref),
        ssim(target, ref, multichannel=True),
    ]

def modcrop(img, scale):
    """Crop img so its height and width are exact multiples of scale.

    Works for both 2-D (grayscale) and 3-D (color) arrays; any trailing
    channel axis is kept intact.
    """
    height, width = img.shape[:2]
    height -= height % scale
    width -= width % scale
    return img[:height, :width]


def shave(image, border):
    """Remove `border` pixels from every edge of the first two axes.

    Fixes the border == 0 edge case: the original slice
    image[0:-0, 0:-0] evaluates to an empty array, so a non-positive
    border must return the image unchanged.
    """
    if border <= 0:
        return image
    return image[border:-border, border:-border]


path = './Train'
deg = []  # 32x32x1 degraded input patches, scaled to [0, 1]
ref = []  # 20x20x1 ground-truth center patches (SRCNN valid-conv output size)
count = 0
for file in os.listdir('./Train'):
    if file != ".DS_Store":
        ref_e = cv2.imread(path + '/' + file)
        ref_e = cv2.cvtColor(ref_e, cv2.COLOR_BGR2YCrCb)

        # BUG FIX: train on the luminance (Y) channel, index 0.  The original
        # used index 1 (Cr), while predict() applies the network to channel 0,
        # so the model was trained and evaluated on different channels.
        ref_e = ref_e[:, :, 0]
        ref_e = modcrop(ref_e, 3)

        h = ref_e.shape[0]
        w = ref_e.shape[1]

        # Degrade by downscaling 2x then upscaling back to the original size.
        # NOTE(review): modcrop uses scale 3 but the degradation factor is 2 —
        # confirm which scale the experiment intends.
        new_height = h // 2
        new_width = w // 2
        deg_e = cv2.resize(cv2.resize(ref_e, (new_width, new_height)), (w, h))

        # Cut overlapping sub-images with stride 14.  Input patches are 32x32;
        # labels are the 20x20 centers offset by 6 px, matching the network's
        # receptive-field shrinkage (9x9 valid: 4 px/side, 5x5 valid: 2 px/side).
        for x in range(0, ref_e.shape[0] - 33, 14):
            for y in range(0, ref_e.shape[1] - 33, 14):
                # BUG FIX: build a fresh array for every patch.  The original
                # reused one temp1/temp2 buffer and appended the same object
                # each iteration, so every training sample ended up identical
                # to the last patch — the main reason training plateaued.
                patch_in = deg_e[x:x + 33 - 1, y:y + 33 - 1].astype(float) / 255
                patch_out = ref_e[x + 6:x + 6 + 21 - 1, y + 6:y + 6 + 21 - 1].astype(float) / 255
                deg.append(patch_in.reshape(32, 32, 1))
                ref.append(patch_out.reshape(20, 20, 1))

ref = np.array(ref)
deg = np.array(deg)


def model():
    """Build the 3-layer SRCNN (9-3-5 variant).

    32x32 single-channel input -> 20x20 output: the 9x9 valid conv removes
    4 px per side and the 5x5 valid conv removes 2 px per side, matching
    the 6-px label offset used when cutting training patches.

    BUG FIX: the original compiled with the string 'adam', i.e. the default
    learning rate 1e-3, which is too aggressive for SRCNN and commonly makes
    training plateau or oscillate.  Use an explicit, smaller rate (Adam is
    already imported at file level).

    Returns:
        A compiled keras Sequential model with MSE loss.
    """
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=64, kernel_size=(9, 9), activation='relu',
                     padding='valid', use_bias=True, input_shape=(32, 32, 1)))
    SRCNN.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
                     padding='same', use_bias=True))
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5), padding='valid',
                     use_bias=True))
    SRCNN.compile(optimizer=Adam(learning_rate=0.0003),
                  loss='mean_squared_error', metrics=['mean_squared_error'])

    return SRCNN


# BUG FIX: `srcnn` was never created — model() was defined but never called,
# so srcnn.fit(...) raised a NameError.  Instantiate the network first.
srcnn = model()

# Save only the weights that achieve the best validation loss.
checkpoint = ModelCheckpoint("SRCNN_check_1.h5", monitor='val_loss', verbose=1,
                             save_best_only=True, save_weights_only=False, mode='min')
callbacks_list = [checkpoint]
history = srcnn.fit(deg, ref, validation_split=0.33, epochs=200, batch_size=128,
                    verbose=1, callbacks=callbacks_list)


def predict(image_path):
    """Super-resolve one image with the trained SRCNN and score the result.

    Args:
        image_path: path to the degraded (low-resolution, pre-upscaled) image;
            a same-named ground-truth file is expected under 'Train/'.

    Returns:
        (ref, degraded, output, scores) — all images border-shaved by 6 px;
        scores is [compare_images(degraded, ref), compare_images(output, ref)].

    NOTE(review): weights are reloaded from disk on every call — fine for a
    one-off script, wasteful in a loop.
    """
    # load the degraded and reference images
    path, file = os.path.split(image_path)
    degraded = cv2.imread(image_path)
    ref = cv2.imread('Train/{}'.format(file))
    
    # preprocess the image with modcrop
    ref = modcrop(ref, 3)
    degraded = modcrop(degraded, 3)
    
    # Convert to YCrCb; the network operates on the Y (luminance) channel only.
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
    
    # create image slice and normalize  
    # Batch of one full-size image, Y channel scaled to [0, 1].
    Y = numpy.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255
    
    srcnn.load_weights('SRCNN_check_1.h5')
    # perform super-resolution with srcnn
    pre = srcnn.predict(Y, batch_size=1)
    
    # post-process output: back to 0-255, clipped, uint8
    pre *= 255
    pre[pre[:] > 255] = 255
    pre[pre[:] < 0] = 0
    pre = pre.astype(np.uint8)
    
    # copy Y channel back to image and convert to BGR
    # The two valid convolutions shrink the output by 6 px per side, so shave
    # the Cr/Cb container by the same amount before inserting the predicted Y.
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)
    
    # remove border from reference and degraged image so all three align
    ref = shave(ref.astype(np.uint8), 6)
    degraded = shave(degraded.astype(np.uint8), 6)
    
    # image quality calculations: [degraded vs ref, output vs ref]
    scores = []
    scores.append(compare_images(degraded, ref))
    scores.append(compare_images(output, ref))
    
    # return images and scores
    return ref, degraded, output, scores


# Run the trained model on one test image and report/visualize the results.
ref, degraded, output, scores = predict('train_lr/butterfly_GT.bmp')
print(degraded.shape)
print(output.shape)
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[0][0], scores[0][1], scores[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[1][0], scores[1][1], scores[1][2]))

# Side-by-side comparison: degraded input, SRCNN output, ground truth.
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
panels = [(degraded, 'Degraded'), (output, 'SRCNN'), (ref, 'orginal')]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.set_title(title)
    # remove the x and y ticks
    ax.set_xticks([])
    ax.set_yticks([])
# plt.savefig("test2.jpg")

这种问题如果原作者的代码没有问题的话,你的设置也没有问题的话,那么就是你的数据量的问题了,数据量是否足够大,分布是否合理?
还有一种情况:你要看下原作者的效果能达到什么样子,有些时候瓶颈就在网络本身,一旦是这种情况,那么就需要你自己根据具体的问题修改一些参数了,这种最难。