sample_weight 陣列維度錯誤訊息 跟 traingenerator yield In Unet

在unet訓練時出現以下錯誤 [(512,3072) 是 concatenate 6x512*512 image]
main.py

發生例外狀況: ValueError
Found a sample_weight array with shape (1, 512, 3072, 1). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.
  File "C:\Labbb\unet_mao0\main.py", line 88, in <module>
    train_history=model.fit_generator(myGene,steps_per_epoch=200,epochs=20,callbacks=[model_checkpoint])
ValueError: Found a sample_weight array with shape (1, 512, 3072, 1). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.

sample_weight_mode='temporal' 因為錯誤訊息顯示ample_weight array with shape (1, 512, 3072, 1),我想保留二三維度的值

model.py

    model.compile(sample_weight_mode='temporal',optimizer=Adam(lr=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])  
but I dont know where can I adjust this array's shape

data.py

img_batch = np.concatenate(images, axis=3)
        mask_batch = np.concatenate(masks, axis=2)
        y_int = np.argmax(mask_batch, axis=2)  
        y_binary = to_categorical(y_int)  
        yield (img_batch, y_binary ,mask_batch )

我猜是因為yield 會還傳到main.py跟model.py 但是我不知道要怎麼去調整,我調整過,他的確會影響回傳的陣列但是每次都會出現錯誤

来自GPT的回答:


根据错误信息,问题出在了sample_weight数组的维度上。目前,sample_weight的维度是(1, 512, 3072, 1),而模型要求是一个二维的样本权重数组。

首先,需要明确一下你的训练数据的样本权重是如何定义和生成的。如果你没有明确指定样本权重,可以尝试将model.fit_generator中的sample_weight_mode参数设置为None,即移除样本权重的使用。

如果你确实需要使用样本权重,可以尝试调整数据生成器data.py中的代码,使其生成正确维度的样本权重数组。

以下是一种可能的修改方式:

python
Copy code
def data_gen(img_folder, mask_folder, batch_size):
    # 生成器代码
    while True:
        # 生成批量数据
        ...
        # 生成对应的样本权重
        sample_weight_batch = generate_sample_weight(img_batch, mask_batch)
        
        yield (img_batch, y_binary, mask_batch, sample_weight_batch)
其中,generate_sample_weight函数是根据你的需求自定义的生成样本权重的函数,它应该生成一个与img_batch和mask_batch具有相同维度的二维数组,用于指定每个样本的权重。

同时,确保在model.compile中移除sample_weight_mode参数,因为这个参数用于指定在时间步级别上对样本权重进行加权,而不是我们当前的情况。

当你进行这些修改后,应该能够正确传递样本权重到模型中进行训练,而不再出现维度错误的问题。记得根据你的需求自定义样本权重的生成函数以及相应的逻辑。