TensorFlow model training fails with 'BatchDataset' object has no attribute 'ndim'

Symptoms and background

When fitting a model with TensorFlow, the call fails with AttributeError: 'BatchDataset' object has no attribute 'ndim'.

Relevant code

history = model.fit(train_dataset, batch_size=batch_size, epochs=10, validation_data=val_dataset)
Output and error message

AttributeError                            Traceback (most recent call last)
<ipython-input-...> in <module>()
----> 1 history=model.fit(train_dataset,batch_size=batch_size, epochs=10,validation_data=val_dataset)

~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1141         sample_weight=sample_weight,
   1142         class_weight=class_weight,
-> 1143         batch_size=batch_size)
   1144     # Prepare validation data.
   1145     if validation_data:

~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size)
    763         feed_input_shapes,
    764         check_batch_axis=False,  # Don't enforce the batch size.
--> 765         exception_prefix='input')
    766 
    767     if y is not None:

~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    148     data = data.values if data.__class__.__name__ == 'DataFrame' else data
    149     data = [data]
--> 150   data = [standardize_single_array(x) for x in data]
    151 
    152   if len(data) != len(names):

~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training_utils.py in <listcomp>(.0)
    148     data = data.values if data.__class__.__name__ == 'DataFrame' else data
    149     data = [data]
--> 150   data = [standardize_single_array(x) for x in data]
    151 
    152   if len(data) != len(names):

~\Anaconda3\envs\zktest\lib\site-packages\tensorflow\python\keras\_impl\keras\engine\training_utils.py in standardize_single_array(x)
     86   elif tensor_util.is_tensor(x):
     87     return x
---> 88   elif x.ndim == 1:
     89     x = np.expand_dims(x, 1)
     90   return x

AttributeError: 'BatchDataset' object has no attribute 'ndim'


What I've tried

I suspect it's a version issue, but I can't find what to change.

The version suspicion looks right: the traceback runs through tensorflow\python\keras\_impl\..., which is the internal Keras copy bundled with older TensorFlow 1.x releases, and its fit() only accepts array-like inputs (anything with an .ndim attribute), not tf.data Dataset objects. Upgrading TensorFlow is the most direct fix, since Dataset input to fit() works in newer 1.x releases and all of 2.x. Also try putting this at the very top of your script:
os.environ['TF_KERAS'] = '1'  # force tf.keras; note this must be set before keras is imported
Reference:
https://blog.csdn.net/weixin_37251044/article/details/124730342
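
A quick check of the installed version (a minimal sketch; the exact cutoff is approximate):

import tensorflow as tf
print(tf.__version__)  # Dataset input to model.fit() needs a newer release
                       # (roughly TF 1.9+, or any TF 2.x)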

Here is the full code:


import os
os.environ['TF_KERAS'] = '1'  # must be set before anything from keras is imported

import time
import sys
import tensorflow as tf
from tensorflow import keras
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pyrsgis import raster
%matplotlib inline
sys.path.append("..")

def read_images(images_dir, gt_images_dir, is_train=True):  # load imagery and labels into memory
    if is_train:
        images_dir_list = sorted(os.listdir(images_dir))
        gt_images_dir_list = sorted(os.listdir(gt_images_dir))
        # each tensor has shape (h, w, c) after the transpose below
        images = [None] * len(images_dir_list)
        gt_images = [None] * len(gt_images_dir_list)
        for i, filename in enumerate(images_dir_list):
            ds, images_temp = raster.read(images_dir + filename, bands='all')  # read TIFF
            images_temp = tf.cast(images_temp, dtype=tf.float32)
            images[i] = tf.transpose(images_temp, [1, 2, 0])  # (c, h, w) -> (h, w, c)
            print(images[i].shape)
        for i, filename in enumerate(gt_images_dir_list):
            ds, gt_images_temp = raster.read(gt_images_dir + filename, bands='all')  # read TIFF
            gt_images_temp = tf.cast(gt_images_temp, dtype=tf.float32)
            gt_images[i] = tf.transpose(gt_images_temp, [1, 2, 0])
            print(gt_images[i].shape)
        return images, gt_images
    else:
        images_dir_list = sorted(os.listdir(images_dir))
        images = [None] * len(images_dir_list)
        for i, filename in enumerate(images_dir_list):
            images_tmp = tf.io.read_file('%s/%s' % (images_dir, filename))
            images[i] = tf.image.decode_png(images_tmp)
        return images
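
For reference, a hypothetical call; the paths below are placeholders, not the original directories:

# Hypothetical usage with placeholder paths:
images, gt_images = read_images('D:/data/images/', 'D:/data/labels/', is_train=True)
print(len(images), images[0].shape)  # lists of (h, w, c) float32 tensors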


# define color and class constants
COLORMAP = [[255, 255, 255], [1, 1, 1], [0, 0, 0]]
CLASSES = ['fcs', 'cs', 'bj']

# build a lookup table mapping each color to its class label
# A 1-D uint8 table of size 256 * 3 = 768, indexed by the per-pixel channel sum
# (at most 255 * 3 = 765). The classic VOC recipe uses a 256**3 table indexed by
# (R * 256 + G) * 256 + B; the plain sum suffices here because the three colors
# above all have distinct sums.
colormap2label = np.zeros(256 * 3, dtype=np.uint8)
for i, colormap in enumerate(COLORMAP):
    colormap2label[colormap[0] + colormap[1] + colormap[2]] = i  # channel sum as index
colormap2label = tf.convert_to_tensor(colormap2label)
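
As a sanity check on the table (765, 3, and 0 are the channel sums of the three COLORMAP entries above):

# Each color's channel sum is unique, so the sum alone identifies the class:
# [255,255,255] -> 765 -> 0, [1,1,1] -> 3 -> 1, [0,0,0] -> 0 -> 2
print(int(colormap2label[765]), int(colormap2label[3]), int(colormap2label[0]))  # 0 1 2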

# Note: an earlier bug was that colormap2label and idx did not correspond (now fixed).
def gt_image_indices(colormap, colormap2label):
    colormap = tf.cast(colormap, dtype=tf.int32)  # cast to integer
    idx = colormap[:, :, 0] + colormap[:, :, 1] + colormap[:, :, 2]  # per-pixel channel sum
    idx = tf.expand_dims(idx, -1)  # (h, w) -> (h, w, 1) so gather_nd indexes the table
    return tf.gather_nd(colormap2label, idx)  # (h, w) map of class indices
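
get_dataset is called below but its definition was not posted. A minimal sketch of what it plausibly does, assuming it pairs each image with its label map via gt_image_indices; the structure here is a guess, not the original code:

def get_dataset(images_dir, gt_images_dir, colormap2label, is_train=True):
    # Hypothetical reconstruction: load everything into memory, convert each RGB
    # label image to an (h, w) map of class indices, then wrap both in a Dataset.
    images, gt_images = read_images(images_dir, gt_images_dir, is_train)
    labels = [gt_image_indices(gt, colormap2label) for gt in gt_images]
    return tf.data.Dataset.from_tensor_slices((tf.stack(images), tf.stack(labels)))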


tr_images_dir = r"D:/zch/data running/data1_256/1/"
tr_labels_dir = r"D:/zch/data running/data1_256/2/"
train_dataset = get_dataset(tr_images_dir, tr_labels_dir, colormap2label, is_train=True)

v_images_dir = r"D:/zch/data running/data1_256/1/"
v_gt_images_dir = r"D:/zch/data running/data1_256/2/"
val_dataset = get_dataset(v_images_dir, v_gt_images_dir, colormap2label, is_train=True)

batch_size = 8
train_dataset = train_dataset.shuffle(buffer_size=19000).batch(batch_size)
val_dataset = val_dataset.shuffle(buffer_size=10000).batch(batch_size)
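
Optional, and only on newer TF: a prefetch() at the end of the pipeline overlaps input loading with training (tf.data.AUTOTUNE exists from TF 2.4; earlier releases spell it tf.data.experimental.AUTOTUNE):

# Overlap data loading with training (TF 2.4+ spelling shown):
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)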

from builders.model_builder import builder
model, base_model = builder(3, input_size=(256, 256), model='UNet', base_model=None)
model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-5),  # `lr` is the legacy name; newer TF uses `learning_rate`
              loss=keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])


history = model.fit(train_dataset, batch_size=batch_size, epochs=10, validation_data=val_dataset)
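
One more detail once the TensorFlow version is new enough to accept Dataset input: batch_size must then be dropped from fit(), because the dataset is already batched and Keras rejects batch_size combined with Dataset input. A sketch of the corrected call:

# The dataset is already batched, so omit batch_size:
history = model.fit(train_dataset, epochs=10, validation_data=val_dataset)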