When I try to reproduce SRCNN, for some reason the model stops improving after about 200 epochs, and the results are a coin flip: some images get slightly better, others come out clearly worse.
The training images I use are the ones provided by the original author.
import os
import math
import numpy as np
import cv2
from matplotlib import pyplot as plt
# unify on tensorflow.keras; mixing `keras` and `tensorflow.keras` imports
# in one script is a common source of hard-to-trace errors
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
# compare_ssim was removed from skimage.measure in newer releases
from skimage.metrics import structural_similarity as ssim
def psnr(target, ref):
    # peak signal-to-noise ratio, assuming 8-bit pixel values
    target_data = target.astype(float)
    ref_data = ref.astype(float)
    diff = (ref_data - target_data).flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))
    return 20 * math.log10(255. / rmse)
def mse(target, ref):
    err = np.sum((target.astype('float') - ref.astype('float')) ** 2)
    err /= float(target.shape[0] * target.shape[1])
    return err
def compare_images(target, ref):
    scores = []
    scores.append(psnr(target, ref))
    scores.append(mse(target, ref))
    # channel_axis replaces the deprecated multichannel=True in newer skimage
    scores.append(ssim(target, ref, channel_axis=-1))
    return scores
def modcrop(img, scale):
    # crop so that height and width are divisible by scale
    tmpsz = img.shape
    sz = tmpsz[0:2]
    sz = sz - np.mod(sz, scale)  # np.mod(sz, scale) is sz % scale
    img = img[0:sz[0], 0:sz[1]]
    return img
def shave(image, border):
    """Crop `border` pixels off every side of the image."""
    return image[border:-border, border:-border]
path = './Train'
deg = []
ref = []
for file in os.listdir(path):
    if file != ".DS_Store":
        ref_e = cv2.imread(path + '/' + file)
        ref_e = cv2.cvtColor(ref_e, cv2.COLOR_BGR2YCrCb)
        # channel 0 is the luminance (Y) channel; index 1 would be Cr,
        # which would not match the Y channel used at prediction time
        ref_e = ref_e[:, :, 0]
        ref_e = modcrop(ref_e, 2)  # match the 2x down/upsampling below
        h = ref_e.shape[0]
        w = ref_e.shape[1]
        new_height = h // 2
        new_width = w // 2
        # degrade: downscale then upscale back (SRCNN uses bicubic here)
        deg_e = cv2.resize(
            cv2.resize(ref_e, (new_width, new_height), interpolation=cv2.INTER_CUBIC),
            (w, h), interpolation=cv2.INTER_CUBIC)
        for x in range(0, ref_e.shape[0] - 33, 14):
            for y in range(0, ref_e.shape[1] - 33, 14):
                # allocate fresh arrays inside the loop: appending the same
                # array object over and over makes every entry in the list
                # point to the last patch written
                temp1 = np.zeros((32, 32, 1))
                temp2 = np.zeros((20, 20, 1))
                temp1[:, :, 0] = deg_e[x:x + 32, y:y + 32].astype(float) / 255
                temp2[:, :, 0] = ref_e[x + 6:x + 26, y + 6:y + 26].astype(float) / 255
                deg.append(temp1)
                ref.append(temp2)
ref = np.array(ref)
deg = np.array(deg)
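# (added) sanity check: shapes should be (N, 32, 32, 1) and (N, 20, 20, 1);
# before the per-patch allocation fix above, every entry was a reference
# to the same array
print(deg.shape, ref.shape)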
def model():
    # SRCNN-style network; the two 'valid' convolutions shrink the
    # 32x32 input to the 20x20 output that the label patches match
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=64, kernel_size=(9, 9), activation='relu', padding='valid', use_bias=True, input_shape=(32, 32, 1)))
    SRCNN.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='same', use_bias=True))
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5), padding='valid', use_bias=True))
    SRCNN.compile(optimizer=Adam(), loss='mean_squared_error', metrics=['mean_squared_error'])
    return SRCNN
srcnn = model()  # model() only builds the network; it must be instantiated before fit
checkpoint = ModelCheckpoint("SRCNN_check_1.h5", monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=False, mode='min')
callbacks_list = [checkpoint]
history = srcnn.fit(deg, ref, validation_split=0.33, epochs=200, batch_size=128, verbose=1, callbacks=callbacks_list)
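# (added) plot the training curves; if val_loss flattens while loss keeps
# falling, the model is overfitting rather than hitting a capacity limit
plt.figure()
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.show()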
def predict(image_path):
    # load the degraded and reference images
    path, file = os.path.split(image_path)
    degraded = cv2.imread(image_path)
    ref = cv2.imread('Train/{}'.format(file))
    # preprocess both images with modcrop so their sizes agree
    ref = modcrop(ref, 3)
    degraded = modcrop(degraded, 3)
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
    # create image slice and normalize
    Y = np.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255
    srcnn.load_weights('SRCNN_check_1.h5')
    # perform super-resolution with srcnn
    pre = srcnn.predict(Y, batch_size=1)
    # post-process output: rescale and clip to the valid 8-bit range
    pre = np.clip(pre * 255, 0, 255).astype(np.uint8)
    # copy the Y channel back into the (shaved) image and convert to BGR;
    # the two 'valid' convolutions remove 6 pixels from every border
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)
    # remove the same border from the reference and degraded images
    ref = shave(ref.astype(np.uint8), 6)
    degraded = shave(degraded.astype(np.uint8), 6)
    # image quality metrics: degraded vs. ref, then output vs. ref
    scores = []
    scores.append(compare_images(degraded, ref))
    scores.append(compare_images(output, ref))
    # return images and scores
    return ref, degraded, output, scores
ref, degraded, output, scores = predict('train_lr/butterfly_GT.bmp')
print(degraded.shape)
print(output.shape)
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[0][0], scores[0][1], scores[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[1][0], scores[1][1], scores[1][2]))
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
axs[0].set_title('Degraded')
axs[1].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
axs[1].set_title('SRCNN')
axs[2].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
axs[2].set_title('Original')
# plt.savefig("test2.jpg")
# remove the x and y ticks
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
For this kind of problem, if the original author's code is correct and your settings are correct too, then it comes down to your data: is the training set large enough, and is its distribution reasonable?
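If the patch set is small, cheap geometric augmentation multiplies it for free. A rough sketch against the `deg`/`ref` arrays built in your code (the flip/rotation choice is my own illustration, not part of the original SRCNN recipe):

import numpy as np

def augment_patches(lr, hr):
    # extend (N, h, w, 1) patch arrays with a horizontal flip and
    # 90/180/270-degree rotations; inputs and labels are transformed
    # together so the pairs stay aligned (the patches are square, so
    # rotation preserves their shapes)
    lr_aug, hr_aug = [lr], [hr]
    lr_aug.append(lr[:, :, ::-1, :])
    hr_aug.append(hr[:, :, ::-1, :])
    for k in (1, 2, 3):
        lr_aug.append(np.rot90(lr, k, axes=(1, 2)))
        hr_aug.append(np.rot90(hr, k, axes=(1, 2)))
    return np.concatenate(lr_aug), np.concatenate(hr_aug)

# deg, ref = augment_patches(deg, ref)  # 5x the training pairs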
Another thing to check is what results the original author actually achieved. Sometimes the bottleneck is simply the network itself, and in that case you have to tune the parameters yourself for your specific problem, which is the hardest situation of all.
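If it looks like a plateau rather than a hard ceiling, dropping the learning rate when val_loss stalls is the cheapest parameter to try. A sketch using Keras's ReduceLROnPlateau (the 1e-4 starting rate and the patience values are illustrative guesses, not tuned numbers):

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# start below Adam's default 1e-3: MSE on [0, 1] pixels is a small signal,
# and too large a step can oscillate around the minimum instead of settling
srcnn.compile(optimizer=Adam(learning_rate=1e-4), loss='mean_squared_error')
callbacks = [
    # halve the learning rate after 10 epochs without val_loss improvement
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, verbose=1),
    # stop once even a much smaller learning rate no longer helps
    EarlyStopping(monitor='val_loss', patience=30, restore_best_weights=True),
]
history = srcnn.fit(deg, ref, validation_split=0.33, epochs=400,
                    batch_size=128, callbacks=callbacks)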