加载已有的checkpoint文件时报错
主要的问题代码如下:
loader = tf.train.Saver(var_list=restore_var)
load(loader, sess, args.model_weights)
全部代码如下:
from __future__ import print_function
import argparse
import os
from PIL import Image
import tensorflow._api.v2.compat.v1 as tf
import numpy as np
from deeplab_resnet import DeepLabResNetModel, decode_labels
IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)
NUM_CLASSES = 21
SAVE_DIR = './output/'
IMAGE_PATH = 'dataset/JPEGImages/2007_000039.jpeg'
RESTORE_FROM = './ini_model/model.ckpt-40000'
def get_arguments():
"""解析控制台参数.
Returns:参数列表
"""
parser = argparse.ArgumentParser(description="DeepLab Network Inference.")
parser.add_argument("--img-path", type=str, default=IMAGE_PATH,
help="Path to the RGB image file.")
parser.add_argument("--model-weights", type=str, default=RESTORE_FROM,
help="Path to the file with model weights.")
parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
help="Number of classes to predict (including background).")
parser.add_argument("--save-dir", type=str, default=SAVE_DIR,
help="Where to save predicted mask.")
return parser.parse_args()
def load(saver, sess, ckpt_path):
'''加载已训练的权重参数.
Args:
saver: TensorFlow Saver 存储器对象.
sess: TensorFlow session.
ckpt_path: checkpoint权重参数文件路径.
'''
saver.restore(sess, ckpt_path)
print("Restored model parameters from {}".format(ckpt_path))
def main():
"""主函数:模型构建和evaluate."""
args = get_arguments()
print(1)
tf.compat.v1.disable_eager_execution()
# 读取图片.
img = tf.image.decode_jpeg(tf.read_file(args.img_path), channels=3)
#image_raw_data = tf.gfile.FastGFile('/home/penglu/Desktop/11.jpg').read()
#img = tf.image.decode_jpeg('./dataset/JPEGImages/2007_000039.jpeg', channels=3)
# 通道转换: RGB --> BGR.
img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
# 减去像素均值.
img -= IMG_MEAN
# 构建DeepLab-ResNet-101网络.
net = DeepLabResNetModel({'data': tf.expand_dims(img, dim=0)}, is_training=False, num_classes=args.num_classes)
# 设定要预加载的网络权重参数
restore_var = tf.global_variables()
# 执行预测.
raw_output = net.layers['fc1_voc12']
raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(img)[0:2,])
raw_output_up = tf.argmax(raw_output_up, dimension=3)
pred = tf.expand_dims(raw_output_up, dim=3)
# 建立tf session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# 执行权重变量初始化
init = tf.global_variables_initializer()
#无法执行sess.run()的原因是tensorflow版本不同导致的,tensorflow版本2.0无法兼容版本1.0.解决办法:tf.compat.v1.disable_eager_execution()
print(2)
sess.run(init)
print(3)
# 加载已有的checkpoint文件
loader = tf.train.Saver(var_list=restore_var)
load(loader, sess, args.model_weights)
print(5)
# 执行推断.
preds = sess.run(pred)
print(6)
msk = decode_labels(preds, num_classes=args.num_classes)
im = Image.fromarray(msk[0])
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
im.save(args.save_dir + 'mask.png')
print('The output file has been saved to {}'.format(args.save_dir + 'mask.png'))
if __name__ == '__main__':
main()
主要报错内容是::tensorflow.python.framework.errors_impl.NotFoundError: Key bn2a_branch1/batch_normalization/beta not found in checkpoint
[[{{node save/RestoreV2}}]]
全部报错内容:
C:\Users\dou\.conda\envs\tensorflow_38\python.exe E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py
1
WARNING:tensorflow:From C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\util\dispatch.py:1082: calling expand_dims (from tensorflow.python.ops.array_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
Use the `axis` argument instead
num_classes: <class 'int'>
E:\学习\图像分割\深度学习图像处理目标检测图像分割\代码\ch11\DeepLab_TF\deeplab_resnet\network.py:314: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
output = tf.layers.batch_normalization(
WARNING:tensorflow:From C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\util\dispatch.py:1082: calling argmax (from tensorflow.python.ops.math_ops) with dimension is deprecated and will be removed in a future version.
Instructions for updating:
Use the `axis` argument instead
2022-07-05 10:26:34.611191: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-05 10:26:35.299105: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudnn64_8.dll'; dlerror: cudnn64_8.dll not found
2022-07-05 10:26:35.299650: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2
2022-07-05 10:26:35.867355: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
3
4
2022-07-05 10:26:39.039911: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at save_restore_v2_ops.cc:228 : NOT_FOUND: Key bn2a_branch1/batch_normalization/beta not found in checkpoint
Traceback (most recent call last):
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 1377, in _do_call
return fn(*args)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 1360, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 1453, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.NotFoundError: Key bn2a_branch1/batch_normalization/beta not found in checkpoint
[[{{node save/RestoreV2}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 1417, in restore
sess.run(self.saver_def.restore_op_name,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 967, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 1190, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\client\session.py", line 1396, in _do_call
raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.NotFoundError: Graph execution error:
Detected at node 'save/RestoreV2' defined at (most recent call last):
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 100, in <module>
main()
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 83, in main
loader = tf.train.Saver(var_list=restore_var)
Node: 'save/RestoreV2'
Key bn2a_branch1/batch_normalization/beta not found in checkpoint
[[{{node save/RestoreV2}}]]
Original stack trace for 'save/RestoreV2':
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 100, in <module>
main()
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 83, in main
loader = tf.train.Saver(var_list=restore_var)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 933, in __init__
self.build()
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 945, in build
self._build(self._filename, build_save=True, build_restore=True)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 973, in _build
self.saver_def = self._builder._build_internal( # pylint: disable=protected-access
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 543, in _build_internal
restore_op = self._AddRestoreOps(filename_tensor, saveables,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 363, in _AddRestoreOps
all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 611, in bulk_restore
return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\ops\gen_io_ops.py", line 1500, in restore_v2
_, _, _op, _outputs = _op_def_library._apply_op_helper(
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 797, in _apply_op_helper
op = g._create_op_internal(op_type_name, inputs, dtypes=None,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\framework\ops.py", line 3754, in _create_op_internal
ret = Operation(
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\framework\ops.py", line 2133, in __init__
self._traceback = tf_stack.extract_stack_for_node(self._c_op)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py", line 66, in get_tensor
return CheckpointReader.CheckpointReader_GetTensor(
RuntimeError: Key _CHECKPOINTABLE_OBJECT_GRAPH not found in checkpoint
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 1428, in restore
names_to_keys = object_graph_key_mapping(save_path)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 1749, in object_graph_key_mapping
object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py", line 71, in get_tensor
error_translator(e)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py", line 31, in error_translator
raise errors_impl.NotFoundError(None, None, error_message)
tensorflow.python.framework.errors_impl.NotFoundError: Key _CHECKPOINTABLE_OBJECT_GRAPH not found in checkpoint
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 100, in <module>
main()
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 85, in main
load(loader, sess, args.model_weights)
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 43, in load
saver.restore(sess, ckpt_path)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 1433, in restore
raise _wrap_restore_error_with_msg(
tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Graph execution error:
Detected at node 'save/RestoreV2' defined at (most recent call last):
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 100, in <module>
main()
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 83, in main
loader = tf.train.Saver(var_list=restore_var)
Node: 'save/RestoreV2'
Key bn2a_branch1/batch_normalization/beta not found in checkpoint
[[{{node save/RestoreV2}}]]
Original stack trace for 'save/RestoreV2':
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 100, in <module>
main()
File "E:/学习/图像分割/深度学习图像处理目标检测图像分割/代码/ch11/DeepLab_TF/inference.py", line 83, in main
loader = tf.train.Saver(var_list=restore_var)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 933, in __init__
self.build()
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 945, in build
self._build(self._filename, build_save=True, build_restore=True)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 973, in _build
self.saver_def = self._builder._build_internal( # pylint: disable=protected-access
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 543, in _build_internal
restore_op = self._AddRestoreOps(filename_tensor, saveables,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 363, in _AddRestoreOps
all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\training\saver.py", line 611, in bulk_restore
return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\ops\gen_io_ops.py", line 1500, in restore_v2
_, _, _op, _outputs = _op_def_library._apply_op_helper(
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 797, in _apply_op_helper
op = g._create_op_internal(op_type_name, inputs, dtypes=None,
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\framework\ops.py", line 3754, in _create_op_internal
ret = Operation(
File "C:\Users\dou\.conda\envs\tensorflow_38\lib\site-packages\tensorflow\python\framework\ops.py", line 2133, in __init__
self._traceback = tf_stack.extract_stack_for_node(self._c_op)
Process finished with exit code 1
找不到 bn2a_branch1/batch_normalization/beta 原因:
可能 你checkpoint之后,代码修改了导致的;