我现在有一个用Unet已经训练成功的网络,用于迁移学习。但是他的网络只训练了9个类别,我想读取他的权重并扩充到14个类别再进行训练(因为这14个类别里有9个是和他一样的),请问应该怎么实现呢?
代码如下
#这是原本9个类别的模型
used_model = "smUnet"
MULTI_GPU = False
BACKBONE = 'vgg19'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'softmax'
ENCODER_FREEZE = False
model = sm.Unet(classes=9, activation=ACTIVATION,input_shape=(None, None, 4),encoder_freeze=ENCODER_FREEZE, backbone_name=BACKBONE, encoder_weights=None, decoder_block_type = 'transpose')
我想实现的是获取他的weights 并按照14个类别进行训练。 因为这14个类别里有9个类别是和他训练的一样