我在unet 內訓練模型,底下是在data.py traingenerator yield遇到的錯誤訊息 ,主要是yield回傳給model.compile時發生錯誤
data.py traingenerator
img_batch = np.concatenate(images, axis=3) #(1, 512, 512, 8) (512x512 image,5*1channel 1*3chainnel)
mask_batch = np.concatenate(masks, axis=3) #(1, 512, 512, 6) (6 classes)
sample_weights = np.ones(9540)
y_int = np.argmax(mask_batch, axis=2)
y_binary = to_categorical(y_int) # Convert integer labels to binary vectors
yield (img_batch, y_binary, mask_batch)
error message
發生例外狀況: ValueError
Found a sample_weight array with shape (1, 512, 512, 6). In order to use timestep-wise sample weights, you should specify sample_weight_mode="temporal" in compile(); founssd "None" instead. If you just mean to use sample-wise weights, make sure your sample_weight array is 1D.
File "C:\Labbb\unet_mao0\main.py", line 90, in <module>
train_history=model.fit_generator(myGene,steps_per_epoch=200,epochs=20,callbacks=[model_checkpoint])
ValueError: Found a sample_weight array with shape (1, 512, 512, 6). In order to use timestep-wise sample weights, you should specify sample_weight_mode="temporal" in compile(); founssd "None" instead. If you just mean to use sample-wise weights, make sure your sample_weight array is 1D.
yield (img_batch, y_binary, mask_batch) img_batch(四維) shape 必須等於 mask_batch mask_batch又必須等於model.compile的sample_weight 的格式也就是一維/二維
該怎麼解決這個問題
这个问题看起来是由于 model.compile
函数中的 sample_weight
参数的格式不正确导致的。根据错误消息,“Found a sample_weight array with shape (1, 512, 512, 6)”表明 sample_weight
是一个四维数组。但是,在样本权重中,每个样本应只有一个权重值(即对应于该样本的标签),因此,sample_weight
应该是一个形如 (n_samples,),(n_samples, 1) 或 (n_batches, n_samples_per_batch) 的数组。
我建议你将 sample_weight
参数设置为一个形状为 (1, 1, 1, n_samples) 的四维数组,其中 n_samples
等于你的 img_batch
的样本数,然后再将其传递给 model.compile
函数。例如:
sample_weights = np.ones((1, 1, 1, img_batch.shape[0]))
model.compile(loss='binary_crossentropy', optimizer='adam', sample_weight_mode='temporal')
在这里,sample_weights
是一个形状为 (1, 1, 1, n_samples) 的数组,每个元素的值都为 1。在对模型进行编译时,确保设置了 sample_weight_mode='temporal'
,以使模型知道你要使用 timestep-wise sample weights。如果你只想使用 sample-wise weights,可以将 sample_weights
改为一个形如 (n_samples,) 的一维数组,或者在 model.fit_generator
中使用 sample_weight
参数。