现在有一个tensorflow模型已经训练完成存储在本地,我想在A线程中使用load-model加载这个模型,然后将这个模型传递给B线程,在B线程调用这个模型进行预测。但现在运行的时候报错
说是由于在加载模型时没有正确设置默认图形所导致的。我应该怎么做才能够在B线程运行这个模型。
我不能在B线程加载模型,因为在B线程中加载这个模型的话,加载模型会比较慢,而且B线程需要不断循环使用。所以不能再B线程在加载模型。只能在其他线程预加载完模型,传递给他。
问题答案由GPT-4Plus模型(语音)与博主程序员之梦@共同编写:
你遇到的问题是因为TensorFlow的计算是基于图(graph)的,而每个线程有自己的默认图。所以在一个线程中加载模型,然后传递到另一个线程使用会出现错误。
解决这个问题有两种方法:
python
import tensorflow as tf
# 定义默认图
graph = tf.compat.v1.get_default_graph()
# 在线程A中加载模型
with graph.as_default():
model = tf.keras.models.load_model('model.h5')
# 在线程B中使用model计算,指定使用线程A的图
with tf.compat.v1.Session(graph=graph) as sess:
output = model.predict(input)
python
# 在线程A中将Keras模型保存为SavedModel格式
model.save('saved_model')
# 在线程B中加载SavedModel并计算
model = tf.saved_model.load('saved_model')
output = model.predict(input)
SavedModel是TensorFlow的序列化格式,将模型结构和权重等全部封装在其中,各个线程加载后可以直接使用,无需指定图计算。
所以,推荐的方法是将Keras模型导出为TensorFlow SavedModel格式,然后在各个线程中加载SavedModel独立计算。
针对该问题,可以通过以下步骤解决:
在A线程中加载tensorflow模型,并设置默认图形,然后将模型以变量的形式保存起来。
在B线程中加载A线程中保存的模型变量,这样就不需要频繁地加载模型以影响性能。
在B线程中调用模型进行预测时,需要创建一个新的计算图,并将保存的模型变量加载到新的计算图中,同时设置默认图形,以便使用默认图形进行运算。
以下是具体的代码实现:
A线程中加载tensorflow模型:
import tensorflow as tf
# 加载模型
sess = tf.Session()
graph = tf.get_default_graph()
with graph.as_default():
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, 'model.ckpt')
# 获取模型中的变量
var1 = graph.get_tensor_by_name('var1:0')
var2 = graph.get_tensor_by_name('var2:0')
# 保存模型变量
model_vars = {'var1': var1, 'var2': var2}
B线程中加载A线程中保存的模型变量并调用模型进行预测:
import tensorflow as tf
# 加载A线程中保存的模型变量,并创建一个新的计算图
new_graph = tf.Graph()
with new_graph.as_default():
# 加载模型变量
var1 = tf.Variable(model_vars['var1'], name='var1')
var2 = tf.Variable(model_vars['var2'], name='var2')
# 进行预测
input_data = tf.placeholder(tf.float32, shape=(None, 100))
output = tf.add(tf.matmul(input_data, var1), var2)
prediction = tf.nn.softmax(output)
# 设置默认图形
sess = tf.Session()
with sess.as_default():
init = tf.global_variables_initializer()
sess.run(init)
tf.get_default_graph().finalize()
# 使用预测模型进行预测
prediction_val = sess.run(prediction, feed_dict={input_data: X_test})
需要注意的是,在B线程中调用模型进行预测时,需要重新创建一个新的计算图,并将保存的模型变量加载到新的计算图中,以避免与A线程中默认图形产生冲突。同时,在使用预测模型进行预测时,也需要设置默认图形,以便使用默认图形进行运算。
import tensorflow as tf
# 加载模型
sess = tf.Session()
graph = tf.get_default_graph()
with graph.as_default():
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, 'model.ckpt')
# 获取模型中的变量
var1 = graph.get_tensor_by_name('var1:0')
var2 = graph.get_tensor_by_name('var2:0')
# 保存模型变量
model_vars = {'var1': var1, 'var2': var2}