运行以下代码出现报错,但我的内存还有几百G,训练一个几百M的数据集完全够,难道和tensorflow版本是cpu有关系?
import glob
import tensorflow as tf
from tensorflow import keras
import numpy as np
from PIL import ImageDraw,Image
import os
import matplotlib.pyplot as plt
result_path = r"Mydataset_data\result"
if not os.path.exists(result_path):
os.makedirs(result_path)
net = keras.models.load_model('model.h5')
path = r"Mydataset_data\test"
img_path = glob.glob(path+'*png:')
for i, img in enumerate(img_path):
print(img)
label = img.split('.')[1: 5]
print(label)
print(type(label))
label = [int(i) for i in label]
x_img = tf.io.read_file(img)
x_img = tf.image.decode_jpeg(x_img,channels = 3)
x_img = (tf.cast(x_img, dtype=tf.float32) / 255. - 0.5) * 2
x_img_x = tf.expand_dims(x_img, axis=0)
out = net(x_img_x)
out = tf.squeeze(out, axis=0)
print(out)
out_put = out.numpy() * 300
print(out_put)
img_data = np.array(((x_img * 0.5 + 0.5) * 255), dtype=np.int8)
img = Image.fromarray(img_data, "RGB")
draw = ImageDraw.Draw(img)
draw.rectangle(label, outline = "red",width = 5)
draw.rectangle(out_put, outline="blue",width = 5)
img.show()
img.save(r"(0)/(1}-png".format(result_path, str(i)))
print(1)
if i == 500:
break
以下内容由CHATGPT及阿里嘎多学长共同生成、有用望采纳:
这个错误通常是因为模型训练时需要的内存超过了系统可用内存。有几种解决方法:
你可以尝试减少批次大小,并在每个训练步骤后清除未使用的变量。具体来说,可以在每个训练步骤后使用tf.keras.backend.clear_session()
方法清除未使用的变量,例如:
for i in range(num_steps):
# 训练一个批次的数据
loss = model.train_on_batch(x_batch, y_batch)
# 清除未使用的变量
tf.keras.backend.clear_session()
此外,你可以尝试使用更轻量级的模型或使用云计算平台来提供更多内存。
内存还有几百G,你是不是搞错了硬盘和内存?我服务器的内存也就是128GB
不知道你这个问题是否已经解决, 如果还没有解决的话: