np.concatenate合併維度問題,in 1&3通道影像輸入的Unet

我在做一個預測未來七天青花菜生長的unet model
我試著去concatenate 影像,因為我的輸入是八通道(五張單通道,一張三通道[index 5])

appear this error message

        img_batch = np.concatenate(images, axis=3)________line172
        mask_batch = np.concatenate(masks, axis=2)________line173
        y_int = np.argmax(mask_batch, axis=2)  
        y_binary = to_categorical(y_int)  
        yield (img_batch, y_binary ,mask_batch )
error message
發生例外狀況: ValueError
all the input array dimensions for the concatenation axis must match exactly, but along dimension 3, the array at index 0 has size 1 and the array at index 5 has size 3
  File "C:\Labbb\unet_mao0\data.py", line 173, in trainGenerator
    mask_batch = np.concatenate(masks, axis=2)
  File "C:\Labbb\unet_mao0\main.py", line 88, in <module>
    train_history=model.fit_generator(myGene,steps_per_epoch=200,epochs=20,callbacks=[model_checkpoint])
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 3, the array at index 0 has size 1 and the array at index 5 has size 3

錯誤訊息顯示mask image的index 5是三維的導致無法合併,
我認為是因為我所輸入的影像index5 是三通道rgb的影像所導致的
但是我確定mask影像全都是單通道影像

不知道你这个问题是否已经解决, 如果还没有解决的话:
  • 以下回答来自chatgpt:

    可以先使用np.split函数将输入数据拆分成各个通道,然后将单通道影像和三通道影像分开处理。对于单通道影像,可以直接使用np.concatenate进行合并。对于三通道影像,可以使用np.dstack先将三个通道的数据合并为一个三维数组,然后再使用np.concatenate进行合并。

    具体的代码示例如下:

    import numpy as np
    
    # 假设输入数据为8通道,shape为(batch_size, height, width, channels)
    inputs = ...
    
    # 将数据拆分成各个通道,注意index从0开始
    single_channel_images = [inputs[..., i] for i in range(5)]
    rgb_image = inputs[..., 5]
    
    # 对于单通道影像,直接使用np.concatenate进行合并
    merged_single_channel = np.concatenate(single_channel_images, axis=-1)
    
    # 对于三通道影像,先使用np.dstack将三个通道合并为一个三维数组,再使用np.concatenate进行合并
    merged_rgb = np.concatenate([np.dstack(rgb_image[..., i:i+3]) for i in range(0, 3, 3)], axis=-1)
    # 上面的代码等价于merged_rgb = np.dstack(rgb_image)
    
    # 将合并后的影像再次合并成一个输入数组,注意顺序
    merged_inputs = np.concatenate([merged_single_channel, merged_rgb], axis=-1)
    

    这样就可以成功地将八通道输入的Unet中的各个通道合并起来了。


如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^