import os

import cv2
import numpy as np
from tensorflow.keras import layers, models


def unet_train():
    height = 512
    width = 512
    path = 'C:/Users/asus/Desktop/unet_datasets/'
    input_name = os.listdir(path + 'train_image')
    n = len(input_name)
    print(n)
    X_train, y_train = [], []
    # Files are expected to be named (1).png, (2).png, ..., (n).png in
    # train_image/, with identically named masks in train_label/.
    for i in range(1, n + 1):
        print("Reading image %d" % i)
        img = cv2.imread(path + 'train_image/(%d).png' % i)
        label = cv2.imread(path + 'train_label/(%d).png' % i)
        X_train.append(img)
        y_train.append(label)
    X_train = np.array(X_train)
    y_train = np.array(y_train)
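
    # Note: pixels stay in the raw 0-255 range (no normalization), which is
    # why the mean-squared-error training loss quoted near model.fit() below
    # bottoms out around 250 rather than near zero.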
    def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same'):
        # Convolution -> batch norm -> LeakyReLU building block.
        x = layers.Conv2D(nb_filter, kernel_size, strides=strides, padding=padding)(x)
        x = layers.BatchNormalization(axis=3)(x)
        x = layers.LeakyReLU(alpha=0.1)(x)
        return x

    def Conv2dT_BN(x, filters, kernel_size, strides=(2, 2), padding='same'):
        # Transposed convolution (2x upsampling) -> batch norm -> LeakyReLU.
        x = layers.Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)(x)
        x = layers.BatchNormalization(axis=3)(x)
        x = layers.LeakyReLU(alpha=0.1)(x)
        return x
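
    # Encoder: four conv stages, each halving the spatial size with 2x2 max
    # pooling (512 -> 256 -> 128 -> 64 -> 32) while doubling the filter count.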
    inpt = layers.Input(shape=(height, width, 3))
    conv1 = Conv2d_BN(inpt, 8, (3, 3))
    conv1 = Conv2d_BN(conv1, 8, (3, 3))
    pool1 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv1)
    conv2 = Conv2d_BN(pool1, 16, (3, 3))
    conv2 = Conv2d_BN(conv2, 16, (3, 3))
    pool2 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv2)
    conv3 = Conv2d_BN(pool2, 32, (3, 3))
    conv3 = Conv2d_BN(conv3, 32, (3, 3))
    pool3 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv3)
    conv4 = Conv2d_BN(pool3, 64, (3, 3))
    conv4 = Conv2d_BN(conv4, 64, (3, 3))
    pool4 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv4)

    conv5 = Conv2d_BN(pool4, 128, (3, 3))
    conv5 = layers.Dropout(0.5)(conv5)
    conv5 = Conv2d_BN(conv5, 128, (3, 3))
    conv5 = layers.Dropout(0.5)(conv5)
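
    # Decoder: each Conv2dT_BN doubles the spatial size; concatenating the
    # matching encoder feature map (skip connection) is what makes this a U-Net.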
    convt1 = Conv2dT_BN(conv5, 64, (3, 3))
    concat1 = layers.concatenate([conv4, convt1], axis=3)
    concat1 = layers.Dropout(0.5)(concat1)
    conv6 = Conv2d_BN(concat1, 64, (3, 3))
    conv6 = Conv2d_BN(conv6, 64, (3, 3))

    convt2 = Conv2dT_BN(conv6, 32, (3, 3))
    concat2 = layers.concatenate([conv3, convt2], axis=3)
    concat2 = layers.Dropout(0.5)(concat2)
    conv7 = Conv2d_BN(concat2, 32, (3, 3))
    conv7 = Conv2d_BN(conv7, 32, (3, 3))

    convt3 = Conv2dT_BN(conv7, 16, (3, 3))
    concat3 = layers.concatenate([conv2, convt3], axis=3)
    concat3 = layers.Dropout(0.5)(concat3)
    conv8 = Conv2d_BN(concat3, 16, (3, 3))
    conv8 = Conv2d_BN(conv8, 16, (3, 3))

    convt4 = Conv2dT_BN(conv8, 8, (3, 3))
    concat4 = layers.concatenate([conv1, convt4], axis=3)
    concat4 = layers.Dropout(0.5)(concat4)
    conv9 = Conv2d_BN(concat4, 8, (3, 3))
    conv9 = Conv2d_BN(conv9, 8, (3, 3))
    conv9 = layers.Dropout(0.5)(conv9)
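
    # 3-channel ReLU output trained with mean squared error: the network
    # regresses the label image pixel-wise rather than doing per-pixel
    # softmax classification.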
    outpt = layers.Conv2D(filters=3, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')(conv9)
    model = models.Model(inpt, outpt)
    model.compile(optimizer='adam',
                  loss='mean_squared_error',
                  metrics=['accuracy'])
    model.summary()
    print(np.max(X_train))
    print(np.max(y_train))
    print(X_train.shape)
    # Tune epochs and batch_size for your setup; a batch_size that is too
    # large will run out of memory. On an 11 GB GPU, batch_size 15-20 is
    # about the limit; training loss settled around 250 with accuracy
    # around 95%.
    model.fit(X_train, y_train, epochs=100, batch_size=15)
    model.save('unet.h5')
    print('unet.h5 saved successfully!')
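
# A minimal entry point, assuming the script is meant to be run directly;
# the original code does not show how unet_train() is invoked, so this
# guard is a sketch rather than part of the source.
if __name__ == '__main__':
    unet_train()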