Error reading .npy files with a data generator in Keras

This is my first time creating a custom data generator, and I'm not very familiar with them yet.

Why can't np.load load the .npy files? Did I make a mistake when creating them?

All of the image_folders and label_folders are under data/syu/npy. I tried to use the same folder structure as with ImageDataGenerator.

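My inputs are five 1-channel images and one 3-channel image, which I wanted to concatenate into 8 channels. The single-channel images wouldn't concatenate, so I expanded each of them to 3 channels and concatenated everything into an 18-channel .npy file. I did the same for the labels.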
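Here is a minimal sketch of that channel stacking. Random arrays stand in for the real images. The shapes (512, 512) for the single-channel images and (512, 512, 3) for the color image are assumptions taken from dim in main.py, and build_stacks is a hypothetical helper name. One common reason single-channel images won't concatenate is that they are 2-D (H, W) arrays while the color image is 3-D. Giving each of them a trailing channel axis is enough for np.concatenate, so the 8-channel stack also works without repeating channels.

```python
import numpy as np

H, W = 512, 512  # assumption: matches params['dim'] in main.py


def build_stacks(gray_images, color_image):
    """Return (18-channel, 8-channel) stacks from five (H, W) arrays and one (H, W, 3) array."""
    # Route described in the question: repeat each gray image to 3 channels, then concatenate.
    gray_rgb = [np.repeat(g[..., np.newaxis], 3, axis=-1) for g in gray_images]
    stack18 = np.concatenate(gray_rgb + [color_image], axis=-1)

    # Alternative: add a size-1 channel axis so every input has 3 dimensions,
    # which is all np.concatenate needs along the last axis.
    gray_1ch = [g[..., np.newaxis] for g in gray_images]
    stack8 = np.concatenate(gray_1ch + [color_image], axis=-1)
    return stack18, stack8


rng = np.random.default_rng(0)
grays = [rng.random((H, W), dtype=np.float32) for _ in range(5)]
color = rng.random((H, W, 3), dtype=np.float32)
s18, s8 = build_stacks(grays, color)
print(s18.shape, s8.shape)  # (512, 512, 18) (512, 512, 8)
# np.save("sample.npy", s18)  # hypothetical output file name
```

In the 18-channel stack, gray image i occupies channels 3i to 3i+2, and the color image occupies the last three channels.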

TypeError
expected str, bytes or os.PathLike object, not numpy.int32
  File "C:\Labbb\testing\unet_mao0\data.py", line 219, in get_data
    image = np.load(image_path)
  File "C:\Labbb\testing\unet_mao0\data.py", line 205, in __getitem__
    return self.get_data(batch)
  File "C:\Labbb\testing\unet_mao0\main.py", line 105, in <module>
    train_history=model.fit_generator(data_generator, steps_per_epoch=200,epochs=30,callbacks=[model_checkpoint])
TypeError: expected str, bytes or os.PathLike object, not numpy.int32
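main.py
from keras.callbacks import ModelCheckpoint
from data import CustomDataGenerator
# unet() comes from the project's model code, which is not shown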
params = {'dim': (512,512), 'batch_size': 1, 'n_classes': 7, 'n_channels': 18, 'shuffle': True}
image_folders = ['image0', 'image1', 'image2', 'image3', 'image4', 'image6', 'image7']
label_folders = ['label0', 'label1', 'label2', 'label3', 'label4', 'label6', 'label7']  
data_generator = CustomDataGenerator(image_folders, label_folders, **params)
model = unet()
model_checkpoint = ModelCheckpoint('weight/syu/0123467_Custom_test1.hdf5', monitor='loss',verbose=1, save_best_only=True)
data.py
import os
import numpy as np
from keras.utils import Sequence

class CustomDataGenerator(Sequence):
    def __init__(self, image_folders, label_folders,dim=(512,512),  batch_size=1,n_classes=7,n_channels=18,shuffle=True):
        self.image_folders = image_folders
        self.label_folders = label_folders
        self.dim = dim
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.image_paths = []
        self.label_paths = []

        for folder in self.image_folders:
            image_folder_path = os.path.join('data/syu/npy', folder)
            image_files = os.listdir(image_folder_path)
            for file_name in image_files:
                self.image_paths.append(os.path.join(image_folder_path, file_name))

        for folder in self.label_folders:
            label_folder_path = os.path.join('data/syu/npy', folder)
            label_files = os.listdir(label_folder_path)
            for file_name in label_files:
                self.label_paths.append(os.path.join(label_folder_path, file_name))

        self.on_epoch_end()
    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        batch_image_paths = self.image_paths[index * self.batch_size: (index + 1) * self.batch_size]
        batch_label_paths = self.label_paths[index * self.batch_size: (index + 1) * self.batch_size]

        
        batch = zip(batch_image_paths, batch_label_paths)
        
        return self.get_data(batch)
    
    def on_epoch_end(self):
        self.image_paths = np.arange(len(self.image_paths))
        self.label_paths = np.arange(len(self.label_paths))
        if self.shuffle == True:
            np.random.shuffle(self.image_paths)
            np.random.shuffle(self.label_paths)

    def get_data(self, batch):
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim, self.n_classes))

        for i, (image_path, label_path) in enumerate(batch):
            image = np.load(image_path)
            label = np.load(label_path)

            X[i,] = image
            y[i,] = label

        return X, y

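The error message is TypeError: expected str, bytes or os.PathLike object, not numpy.int32. It means the problem is in np.load(). The image_path and label_path you pass to it should be string file paths, but they are actually numpy.int32 values taken from an integer array.
The root cause is in the on_epoch_end() method, where you reassign image_paths and label_paths: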
```python
self.image_paths = np.arange(len(self.image_paths))
self.label_paths = np.arange(len(self.label_paths))
```
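This turns them into integer arrays (0, 1, 2, ...) and throws away the file paths. __init__ calls self.on_epoch_end() as its last step, so this happens before the first batch is even requested. As a result, __getitem__ slices integers, and get_data passes an integer to np.load.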
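A small reproduction of the failure follows. The file names are hypothetical and are never opened. The printed dtype depends on platform and NumPy version. np.arange defaults to int32 on Windows with NumPy 1.x, which matches the numpy.int32 in the traceback, and to int64 on most other setups. reproduce_path_loss is a hypothetical helper name.

```python
import numpy as np


def reproduce_path_loss(paths):
    """Mimic what on_epoch_end() does to the path list, then call np.load on the first entry."""
    paths = np.arange(len(paths))  # the offending reassignment: strings become 0, 1, 2, ...
    first = paths[0]
    try:
        np.load(first)  # np.load expects a path or file object, not an integer
    except TypeError as err:
        return first, err
    return first, None


first, err = reproduce_path_loss(["data/syu/npy/image0/a.npy", "data/syu/npy/image0/b.npy"])
print(type(first).__name__)  # int32 or int64 depending on platform and NumPy version
print(err)
```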
Delete or comment out these two lines so that image_paths and label_paths remain lists of string file paths instead of being reassigned to integer arrays.
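Below is a minimal sketch of the generator with the fix applied. It assumes standalone Keras, as the question's imports suggest; with tf.keras, import Sequence from tensorflow.keras.utils instead. The path lists are never modified. A separate index array, here called indexes, is shuffled instead, so every image stays with its own label. Several pieces are assumptions added for this sketch rather than part of the original code:

- the root parameter;
- the hypothetical _collect helper;
- the sorted() calls;
- sizing the batch by len(pairs);
- the fallback to object when Keras is missing.

sorted() matters because os.listdir returns entries in arbitrary order, and pairing by position only works if the image and label file names line up once sorted.

```python
import os
import numpy as np

try:
    from keras.utils import Sequence
except ImportError:
    # Fallback so the sketch runs without Keras installed; training needs the real Sequence.
    Sequence = object


class CustomDataGenerator(Sequence):
    """Batches of (image, label) arrays read from paired .npy files."""

    def __init__(self, image_folders, label_folders, root='data/syu/npy', dim=(512, 512),
                 batch_size=1, n_classes=7, n_channels=18, shuffle=True):
        super().__init__()  # no-op for object; newer Keras versions warn if it is skipped
        self.dim = dim
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle
        # Files are paired by position, so both folders are listed in sorted order.
        self.image_paths = self._collect(root, image_folders)
        self.label_paths = self._collect(root, label_folders)
        if len(self.image_paths) != len(self.label_paths):
            raise ValueError('number of image files and label files differs')
        self.on_epoch_end()

    @staticmethod
    def _collect(root, folders):
        # Hypothetical helper: full paths of every file in the given folders.
        paths = []
        for folder in folders:
            folder_path = os.path.join(root, folder)
            for file_name in sorted(os.listdir(folder_path)):
                paths.append(os.path.join(folder_path, file_name))
        return paths

    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        # Positions of this batch in the shuffled index array; the path lists stay untouched.
        batch_idx = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        pairs = [(self.image_paths[k], self.label_paths[k]) for k in batch_idx]
        return self.get_data(pairs)

    def on_epoch_end(self):
        # One shared index array keeps every image aligned with its label.
        self.indexes = np.arange(len(self.image_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def get_data(self, pairs):
        # Sized by len(pairs), so a short final batch has no uninitialised rows.
        X = np.empty((len(pairs), *self.dim, self.n_channels), dtype=np.float32)
        y = np.empty((len(pairs), *self.dim, self.n_classes), dtype=np.float32)
        for i, (image_path, label_path) in enumerate(pairs):
            X[i] = np.load(image_path)
            y[i] = np.load(label_path)
        return X, y
```

The constructor call in main.py, CustomDataGenerator(image_folders, label_folders, **params), works unchanged. Calling data_generator[0] directly is a quick way to check that one batch loads before you start training.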
Some other suggestions:

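  1. shuffle is already passed in from main.py through params, which is the right approach. Keep controlling it from outside rather than relying on the default value in __init__.
  2. on_epoch_end does not need to reassign image_paths and label_paths at all. Don't shuffle the two lists independently with np.random.shuffle either, as the original code does. Two separate shuffles put the lists in different orders, so images end up paired with the wrong labels. Instead, shuffle one shared index array, stored in a separate attribute such as self.indexes. Then use it in __getitem__ to pick the matching entries from both lists.
  3. You could also replace the Sequence subclass with a plain Python generator function that uses yield, which is shorter. Note that a Sequence must implement __getitem__ and __len__, so this means replacing the whole class. A Sequence is also the safer choice for multiprocessing, because Keras can then guarantee that each sample is used only once per epoch.
  4. Keras's own ImageDataGenerator implementation is a useful reference.

This should fix the problem.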
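For suggestion 3, here is a minimal sketch of a plain yield-based generator. npy_pair_generator is a hypothetical name, and it takes already-built path lists rather than folder names. It loops forever, as Keras expects from a generator, so steps_per_epoch has to be given to fit_generator (main.py already passes 200).

```python
import numpy as np


def npy_pair_generator(image_paths, label_paths, batch_size=1, shuffle=True, seed=None):
    """Endless generator of (X, y) batches loaded from paired .npy file paths."""
    if len(image_paths) != len(label_paths):
        raise ValueError("image_paths and label_paths must have the same length")
    rng = np.random.default_rng(seed)
    n = len(image_paths)
    while True:  # one pass of the for loop is one epoch
        # A single permutation indexes both lists, so pairs stay aligned.
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            X = np.stack([np.load(image_paths[k]) for k in idx])
            y = np.stack([np.load(label_paths[k]) for k in idx])
            yield X, y
```

Usage would look like model.fit_generator(npy_pair_generator(image_paths, label_paths), steps_per_epoch=200, epochs=30), with image_paths and label_paths built the same way as in the original __init__.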