添加新類別時concatenate shape被改變

我正在訓練一個多類 Unet,但是當我在 Unet 中添加新類別時,connect shape發生了變化

myGene = trainGenerator(1, folders, ['image1','image2','image3','image4','image6','image7'], ['label1','label2','label3','label4','label6','label7'],data_gen_args,save_to_dir = None)


def trainGenerator
    img_batch = np.concatenate(images, axis=3) #(1, 512, 512, 8)
    mask_batch = np.concatenate(masks, axis=3) #(1, 512, 512, 6)

我的數據是 8channel(5x1channel, 1x3channel, mask_channel=6)

img_batch=(1, 512, 512, 8)
mask_batch=(1, 512, 512, 6)
但是當我添加一個新類別(image0,label0)

image_batch 會變成 (1, 512, 512, 9)

mask_batch=(1, 512, 512, 7)
我檢查了所有數據,通道數是正確的

该回答引用NewBing,仅供参考,不保证完全正确

根据您的描述,您正在使用Unet进行多类分割。当您添加新类别时,您的图像和掩码的通道数发生了变化。这是因为Unet的结构需要输入的图像和掩码具有相同的通道数。因此,当您添加新类别时,您需要相应地更改代码以使其与新通道数兼容。¹


源: 与必应的对话, 2023/6/21
(1) Image segmentation | TensorFlow Core. https://www.tensorflow.org/tutorials/images/segmentation.
(2) Connected-UNets: a deep learning architecture for breast mass .... https://www.nature.com/articles/s41523-021-00358-x.
(3) U-Net Architecture - Object Detection | Coursera. https://www.coursera.org/lecture/convolutional-neural-networks/u-net-architecture-GIIWY.

不知道你这个问题是否已经解决, 如果还没有解决的话:
  • 以下回答来自chatgpt:

    问题标题:Unet的Concatenate Shape改变问题

    问题描述:我正在使用8个通道进行训练,其中包括5个1通道、1个3通道、1个掩码通道和1个新添加的类别(image0, label0)。在添加新类别后,我发现Unet模型中的Concatenate输入形状不再与之前一致,image_batch的维度由(1, 512, 512, 8)变为(1, 512, 512, 9)。我已经检查了所有的数据,并确认每个通道的数量都是正确的,但我仍然很困惑为什么会发生这种情况。请问可能由什么原因导致这种形状变化?请提供具体的解决方案。

    解决方案:根据您提供的信息,这种形状变化的原因可能是新添加的类别(image0, label0)并没有正确处理,导致Concatenate操作时输入形状与预期不符。以下是一些可能的解决方案:

    1.检查新添加的类别:请确保新添加的类别的数据格式与其他通道相同,并且正确地加载到了训练数据集中。您可以使用以下代码检查类别的数据格式:

    import matplotlib.pyplot as plt
    import numpy as np
    
    #假设新添加的类别为label0
    label0 = np.load('path/to/label0.npy')
    print(label0.shape)
    plt.imshow(label0)
    

    2.更新Unet网络结构:您可以尝试更新Unet网络结构,以确保它能够正确地处理新添加的类别(image0, label0)。以下是一些Python代码片段,您可以用它们来更新Unet网络结构:

    import tensorflow as tf
    from tensorflow.keras.layers import concatenate, Conv2DTranspose, Input, MaxPooling2D
    
    def create_unet(input_shape, num_classes):
        # 定义Unet网络结构
        inputs = Input(input_shape)
    
        conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
        conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
        conv2 = Conv2D(128, 3, activation='relu', padding='same')(pool1)
        conv2 = Conv2D(128, 3, activation='relu', padding='same')(conv2)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
        conv3 = Conv2D(256, 3, activation='relu', padding='same')(pool2)
        conv3 = Conv2D(256, 3, activation='relu', padding='same')(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
        conv4 = Conv2D(512, 3, activation='relu', padding='same')(pool3)
        conv4 = Conv2D(512, 3, activation='relu', padding='same')(conv4)
        drop4 = tf.nn.dropout(conv4, rate=0.5)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    
        conv5 = Conv2D(1024, 3, activation='relu', padding='same')(pool4)
        conv5 = Conv2D(1024, 3, activation='relu', padding='same')(conv5)
        drop5 = tf.nn.dropout(conv5, rate=0.5)
    
        up6 = concatenate([Conv2DTranspose(512, 2, strides=(2, 2), padding='same')(drop5), conv4], axis=3)
        conv6 = Conv2D(512, 3, activation='relu', padding='same')(up6)
        conv6 = Conv2D(512, 3, activation='relu', padding='same')(conv6)
    
        up7 = concatenate([Conv2DTranspose(256, 2, strides=(2, 2), padding='same')(conv6), conv3], axis=3)
        conv7 = Conv2D(256, 3, activation='relu', padding='same')(up7)
        conv7 = Conv2D(256, 3, activation='relu', padding='same')(conv7)
    
        up8 = concatenate([Conv2DTranspose(128, 2, strides=(2, 2), padding='same')(conv7), conv2], axis=3)
        conv8 = Conv2D(128, 3, activation='relu', padding='same')(up8)
        conv8 = Conv2D(128, 3, activation='relu', padding='same')(conv8)
    
        up9 = concatenate([Conv2DTranspose(64, 2, strides=(2, 2), padding='same')(conv8), conv1], axis=3)
        conv9 = Conv2D(64, 3, activation='relu', padding='same')(up9)
        conv9 = Conv2D(64, 3, activation='relu', padding='same')(conv9)
    
        conv10 = Conv2D(num_classes, 1, activation='sigmoid')(conv9)
    
        unet = tf.keras.Model(inputs=inputs, outputs=conv10)
    
        return unet
    

    3.检查训练代码:如果您使用了自己编写的训练代码,则需要检查您的代码是否正确处理了新类别的数据形状。例如,在使用ImageDataGenerator生成器时,您需要更新数据生成器的通道数,以确保所有通道都正确处理。以下是一些更改ImageDataGenerator生成器通道数的代码片段:

    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    
    # 假设新添加的类别为label0
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='reflect',
    )
    
    train_generator = train_datagen.flow_from_directory(
        # 设置目标数据集的路径
        directory='/path/to/training/directory',
        # 设置目标图像尺寸
        target_size=(512, 512),
        # 更新标签数量
        class_mode='categorical',
        # 更新通道数
        color_mode='rgb',
        # batch size
        batch_size=16,
    )
    

    总结:Unet的Concatenate Shape改变可能是由于新添加的类别(image0, label0)没有正确处理导致的。您可以检查并更新数据格式、Unet网络结构和训练代码,以解决这个问题。


如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^