在训练网络时,报错如下:
代码如下:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from datetime import datetime
import time
import random
import tensorflow as tf
import numpy as np
# import QualityNet
from QualityNet import QualityNet
SEED = 66478 # Set to None for random seed.
NUM_EPOCHS = 15
DECAY_EPOCH = 5
LEARNING_RATE = 0.01
LEARNING_RATE_DECAY = 0.99#0.1
MAX_STEPS = 50
BATCH_SIZE = 100
PATCH_SIZE = (32, 32)
NUM_CHANNELS = 1
MODEL_DIR = "tf_models"
if not os.path.isdir(MODEL_DIR):
os.mkdir(MODEL_DIR)
def run():
tf_file_path = r"H:/RCNN3/data/*.tfrecords"
files = tf.io.match_filenames_once(tf_file_path)
#创建文件队列shuffle参数设置为True,打乱文件
filename_queue = tf.train.string_input_producer(files, shuffle=True)
#实例化TFRecordReader类,准备读取TFRecord文件
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
#解析读取的样例,
features = tf.parse_single_example(
serialized_example,
features={
"image_raw":tf.FixedLenFeature([], tf.string),
"image_ref":tf.FixedLenFeature([], tf.string),
"label":tf.FixedLenFeature([], tf.int64),
"ref_idx":tf.FixedLenFeature([], tf.int64)
})
images_SR = tf.decode_raw(features["image_raw"], tf.uint8)
images_SR = tf.reshape(images_SR, [32, 32, 1], )
images_LR = tf.decode_raw(features["image_ref"], tf.uint8)
images_LR = tf.reshape(images_LR, [32, 32, 1])
labels = tf.cast(features["label"], tf.int32)
ref_idx = tf.cast(features["ref_idx"], tf.int32)
#设置每个batch中样例的个数
batch_size = 100#150*1620=24300张图像,训练集为194400张图像
#用于组合成batch的队列中最多可以缓存的样例的个数
capacity = 1944 + 3 * batch_size
#使用batch()函数将样例组合成batch
#函数原型batch(tensors,batch_size,num_threads,capacity,enqueue_many,shapes,
# dynamic_pad,allow_smaller_final_batch,shared_name,name)
image_SR_batch, image_LR_batch, label_batch = tf.train.batch([images_SR, images_LR, labels],
batch_size=batch_size,num_threads=1,capacity=capacity,)
tf.reset_default_graph()
with tf.Session() as sess:
global_step = tf.train.get_or_create_global_step ()
batch = tf.Variable(0, dtype=tf.int32)
epoch_counter = tf.Variable (0, dtype=tf.int32)
x = tf.placeholder(tf.float32, [batch_size, 32, 32, 1], name="input")
y_ = tf.placeholder(tf.float32, [batch_size], name="output")
net = QualityNet(x, NUM_CHANNELS, SEED=SEED)
net.build_graph()
y = net.forward(x, train=True)
training_step = tf.Variable(0, trainable=False)
averages_class = tf.train.ExponentialMovingAverage(0.99,training_step)
#定义一个更新变量滑动平均值的操作需要向滑动平均类的apply()函数提供一个参数列表
#train_variables()函数返回集合图上Graph.TRAINABLE_VARIABLES中的元素,
#这个集合的元素就是所有没有指定trainable_variables=False的参数
loss = tf.divide(tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(y, y_)))), BATCH_SIZE)
var_list = net.parameters
epoch_inc_op = tf.assign(epoch_counter, epoch_counter + 1)
learning_rate = tf.train.exponential_decay(LEARNING_RATE, training_step, epoch_counter,
LEARNING_RATE_DECAY, staircase=True)
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9, use_nesterov=True)
gvs = optimizer.compute_gradients(loss, var_list)
capped_gvs = [(tf.clip_by_norm(gv[0], 1), gv[1]) for gv in gvs]
train_op = optimizer.apply_gradients(capped_gvs)
saver = tf.train.Saver(var_list, max_to_keep=1)
tf.global_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#一般在这个循环内开始训练,这里设定了训练轮数为3
#在每一轮训练的过程中都会执行一个组合样例为batch的操作并打印出来
for epoch in range(MAX_STEPS):
epoch_loss = 0
for train_data_SR, train_data_LR, train_labels in zip(image_SR_batch, image_LR_batch, label_batch):
feed_dict = {x: train_data_SR,
#x2: train_data_LR,
y_: train_labels}
loss_val = sess.run([train_op, loss], feed_dict=feed_dict)
epoch_loss += loss_val[-1]
#xs1, xs2, ys = sess.run([image_SR_batch, image_LR_batch, label_batch])
# print(xs1, xs2, ys)
sess.run(epoch_inc_op)
print("Epoch loss:", epoch_loss)
saver.save(sess, os.path.join (MODEL_DIR, 'model_' + str(epoch + 1)))
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
run()