使用tensorflow跑模型时,出现错误AttributeError: 'BatchDataset' object has no attribute 'ndim'
history=model.fit(train_dataset,batch_size=batch_size, epochs=10,validation_data=val_dataset)
AttributeError                            Traceback (most recent call last)
<ipython-input> in <module>
----> 1 history=model.fit(train_dataset,batch_size=batch_size, epochs=10,validation_data=val_dataset)
~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1141 sample_weight=sample_weight,
1142 class_weight=class_weight,
-> 1143 batch_size=batch_size)
1144 # Prepare validation data.
1145 if validation_data:
~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size)
763 feed_input_shapes,
764 check_batch_axis=False, # Don't enforce the batch size.
--> 765 exception_prefix='input')
766
767 if y is not None:
~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
148 data = data.values if data.__class__.__name__ == 'DataFrame' else data
149 data = [data]
--> 150 data = [standardize_single_array(x) for x in data]
151
152 if len(data) != len(names):
~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training_utils.py in <listcomp>(.0)
148 data = data.values if data.__class__.__name__ == 'DataFrame' else data
149 data = [data]
--> 150 data = [standardize_single_array(x) for x in data]
151
152 if len(data) != len(names):
~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training_utils.py in standardize_single_array(x)
86 elif tensor_util.is_tensor(x):
87 return x
---> 88 elif x.ndim == 1:
89 x = np.expand_dims(x, 1)
90 return x
AttributeError: 'BatchDataset' object has no attribute 'ndim'
A
我感觉是版本问题,但是找不到哪里改
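试试把这个放在代码开头(注意:`TF_KERAS` 环境变量似乎只被 bert4keras、keras-bert 等第三方库读取,对 TensorFlow 自带的 tf.keras 可能没有作用。从回溯看,根本原因更可能是 TensorFlow 版本过旧:`tensorflow\python\keras\_impl` 路径属于早期 TF 1.x,其 `standardize_single_array` 没有处理 `tf.data.Dataset` 的分支,于是落到了 `x.ndim`。建议升级到较新的版本(约 TF 1.9 以上,最好直接使用 2.x),并且在给 `model.fit` 传入 Dataset 时不要再指定 `batch_size`):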
os.environ['TF_KERAS'] = '1' # 必须使用tf.keras,注意,这里要放在引用keras之前
参考链接:
https://blog.csdn.net/weixin_37251044/article/details/124730342
这是相关代码
import time
import sys
import os
import tensorflow as tf
os.environ['TF_KERAS'] = '1'
from tensorflow import keras
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
from pyrsgis import raster
%matplotlib inline
sys.path.append("..")
def _read_tiff_hwc(path):
    """Read all bands of a raster file with pyrsgis as a float32 (H, W, C) tensor.

    The transpose [1, 2, 0] moves the leading axis to the end, i.e. it assumes
    ``raster.read`` returns a (bands, rows, cols) array.
    NOTE(review): a single-band file may come back 2-D, which this transpose
    would reject — confirm against the label files.
    """
    _, pixels = raster.read(path, bands='all')
    return tf.transpose(tf.cast(pixels, dtype=tf.float32), [1, 2, 0])


# NOTE(review): an identical `read_images` is defined again right after this
# one, and that later definition is the one actually used. Delete one copy.
def read_images(images_dir, gt_images_dir, is_train=True):
    """Load images (and, for training, their ground-truth label images) into memory.

    Each directory listing is sorted by file name, and images and labels are
    paired by position in those sorted lists.

    Args:
        images_dir: directory holding the input images.
        gt_images_dir: directory holding the label images; unused when
            ``is_train`` is False.
        is_train: if True, read raster files from both directories with
            pyrsgis; otherwise decode PNG files from ``images_dir`` only.

    Returns:
        ``(images, gt_images)``, two lists of float32 (H, W, C) tensors, when
        ``is_train`` is True; otherwise a list of decoded PNG tensors.

    Raises:
        ValueError: if, for training, the two directories contain different
            numbers of files (they must pair up one-to-one).
    """
    image_names = sorted(os.listdir(images_dir))

    if not is_train:
        images = []
        for name in image_names:
            encoded = tf.io.read_file(os.path.join(images_dir, name))
            images.append(tf.image.decode_png(encoded))
        return images

    gt_names = sorted(os.listdir(gt_images_dir))
    # Previously gt_images was sized by the image count, so a mismatch caused
    # an IndexError or left trailing None labels; fail clearly instead.
    if len(gt_names) != len(image_names):
        raise ValueError(
            "images_dir has %d files but gt_images_dir has %d; "
            "they must match one-to-one" % (len(image_names), len(gt_names))
        )

    images = []
    for name in image_names:
        image = _read_tiff_hwc(os.path.join(images_dir, name))
        print(image.shape)
        images.append(image)

    gt_images = []
    for name in gt_names:
        gt_image = _read_tiff_hwc(os.path.join(gt_images_dir, name))
        print(gt_image.shape)
        gt_images.append(gt_image)

    return images, gt_images
def _read_tiff_hwc(path):
    """Read all bands of a raster file with pyrsgis as a float32 (H, W, C) tensor.

    The transpose [1, 2, 0] moves the leading axis to the end, i.e. it assumes
    ``raster.read`` returns a (bands, rows, cols) array.
    NOTE(review): a single-band file may come back 2-D, which this transpose
    would reject — confirm against the label files.
    """
    _, pixels = raster.read(path, bands='all')
    return tf.transpose(tf.cast(pixels, dtype=tf.float32), [1, 2, 0])


def read_images(images_dir, gt_images_dir, is_train=True):
    """Load images (and, for training, their ground-truth label images) into memory.

    Each directory listing is sorted by file name, and images and labels are
    paired by position in those sorted lists.

    Args:
        images_dir: directory holding the input images.
        gt_images_dir: directory holding the label images; unused when
            ``is_train`` is False.
        is_train: if True, read raster files from both directories with
            pyrsgis; otherwise decode PNG files from ``images_dir`` only.

    Returns:
        ``(images, gt_images)``, two lists of float32 (H, W, C) tensors, when
        ``is_train`` is True; otherwise a list of decoded PNG tensors.

    Raises:
        ValueError: if, for training, the two directories contain different
            numbers of files (they must pair up one-to-one).
    """
    image_names = sorted(os.listdir(images_dir))

    if not is_train:
        images = []
        for name in image_names:
            encoded = tf.io.read_file(os.path.join(images_dir, name))
            images.append(tf.image.decode_png(encoded))
        return images

    gt_names = sorted(os.listdir(gt_images_dir))
    # Previously gt_images was sized by the image count, so a mismatch caused
    # an IndexError or left trailing None labels; fail clearly instead.
    if len(gt_names) != len(image_names):
        raise ValueError(
            "images_dir has %d files but gt_images_dir has %d; "
            "they must match one-to-one" % (len(image_names), len(gt_names))
        )

    images = []
    for name in image_names:
        image = _read_tiff_hwc(os.path.join(images_dir, name))
        print(image.shape)
        images.append(image)

    gt_images = []
    for name in gt_names:
        gt_image = _read_tiff_hwc(os.path.join(gt_images_dir, name))
        print(gt_image.shape)
        gt_images.append(gt_image)

    return images, gt_images
# Label colours and class names: COLORMAP[i] is the RGB colour of class
# CALSSES[i]. (The misspelled name CALSSES is kept in case other code uses it.)
COLORMAP = [[255, 255, 255], [1, 1, 1], [0, 0, 0]]
CALSSES = ['fcs', 'cs', 'bj']

# Lookup table from a colour to its class index. The key is R + G + B, the
# same per-pixel sum gt_image_indices computes. Its maximum is 3 * 255 = 765,
# so 256 * 3 = 768 slots suffice (a collision-free 256**3 table keyed on
# (R * 256 + G) * 256 + B would need gt_image_indices changed to match).
# Caveats of the sum key:
#   * different colours can share a sum, which is rejected below;
#   * any pixel whose sum matches no COLORMAP entry reads the zero-initialised
#     slot, i.e. it silently becomes class 0.
colormap2label = np.zeros(256 * 3, dtype=np.uint8)
_key_owner = {}
for i, colormap in enumerate(COLORMAP):
    key = colormap[0] + colormap[1] + colormap[2]
    if key in _key_owner:
        raise ValueError(
            "COLORMAP entries %d and %d have the same channel sum %d"
            % (_key_owner[key], i, key)
        )
    _key_owner[key] = i
    colormap2label[key] = i
colormap2label = tf.convert_to_tensor(colormap2label)
def gt_image_indices(colormap, colormap2label):
    """Convert a colour label image into a per-pixel class-index map.

    Each pixel's key is the sum of channels 0, 1 and 2 (after casting to
    int32). The key is looked up in ``colormap2label``, so it must match how
    that table was built.

    Args:
        colormap: label image indexed as [row, col, channel] with at least
            three channels.
        colormap2label: 1-D lookup tensor mapping a channel sum to a class index.

    Returns:
        A (rows, cols) tensor of class indices with the dtype of ``colormap2label``.
    """
    pixels = tf.cast(colormap, dtype=tf.int32)
    # Explicit indexing (not a reduce over all channels) so extra channels are
    # ignored and fewer than three channels still raises.
    keys = pixels[:, :, 0] + pixels[:, :, 1] + pixels[:, :, 2]
    # tf.gather on a 1-D table equals the old expand_dims + gather_nd pair.
    return tf.gather(colormap2label, keys)
# NOTE(review): `get_dataset` is not defined in this snippet; presumably it
# builds a tf.data.Dataset of (image, label-index) pairs from read_images and
# gt_image_indices — confirm.
tr_images_dir = r"D:/zch/data running/data1_256/1/"
tr_labels_dir = r"D:/zch/data running/data1_256/2/"
train_dataset = get_dataset(tr_images_dir, tr_labels_dir, colormap2label, is_train=True)

# NOTE(review): the validation directories are the training directories, so
# validation metrics measure fit on the training data, not generalisation.
# Point these at a held-out split.
v_images_dir = "D:/zch/data running/data1_256/1/"
v_gt_images_dir = "D:/zch/data running/data1_256/2/"
val_dataset = get_dataset(v_images_dir, v_gt_images_dir, colormap2label, is_train=True)

batch_size = 8
# Batching is done on the Datasets themselves. Shuffling only matters for
# training; shuffling the validation set just costs a large buffer.
train_dataset = train_dataset.shuffle(buffer_size=19000).batch(batch_size)
val_dataset = val_dataset.batch(batch_size)

from builders.model_builder import builder

model, base_model = builder(3, input_size=(256, 256), model='UNet', base_model=None)
# `lr` is a deprecated alias: newer Keras optimizers warn and ignore it (which
# silently trains at the default learning rate) or reject it outright.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)

# The original AttributeError came from an old tf.keras: its
# standardize_single_array has no tf.data.Dataset branch and falls through to
# `x.ndim`. Use a TensorFlow release whose Model.fit accepts Datasets
# (roughly 1.9+, preferably 2.x). With Dataset input, batch_size must not be
# passed to fit; the batch size comes from `.batch(batch_size)` above.
history = model.fit(train_dataset, epochs=10, validation_data=val_dataset)