关于keras使用时的一些问题。

目前在研究dann模型,但是遇到一部分大问题。

代码来源:
【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练_戴璞微的学习之路-CSDN博客_mnist-m

def build_feature_extractor():
    """
    这是特征提取子网络的构建函数
    :param image_input: 图像输入张量
    :param name: 输出特征名称
    :return:
    """
    model = tf.keras.Sequential([Conv1D(filters=24, kernel_size=25, strides=1),
                                 # tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 MaxPool1D(pool_size=2, strides=2),
                                 Flatten(),
                                 ])

    return model

def build_image_classify_extractor():
    """
    这是搭建图像分类器模型的函数
    :param image_classify_feature: 图像分类特征张量
    :return:
    """
    model = tf.keras.Sequential([Dense(100),
                                 tf.keras.layers.BatchNormalization(),
                                 Activation('relu'),
                                 tf.keras.layers.Dropout(0.5),
                                 Dense(100,activation='relu'),
                                 tf.keras.layers.Dropout(0.5),
                                 Dense(8,activation='softmax',name="image_cls_pred"),
    ])

    return model

image_input = Input(shape=self.cfg.text_input_shape)
embedding_layers = Embedding(output_dim=self.cfg.text_output_feature,  # 输出向量维度
                                          input_dim=n_symbols,  # 输入向量维度
                                          mask_zero=True,  # 使我们填补的0值在后续训练中不产生影响(屏蔽0值)
                                          weights=[embedding_weights],  # 对数据加权
                                          input_length=self.cfg.max_len)  # 每个特征的长度

# 域分类器与图像分类器的共享特征
feature_encoder = build_feature_extractor()
# 获取图像分类结果和域分类结果张量
image_cls_encoder = build_image_classify_extractor()

image_cls_feature = feature_encoder(embedding_layers(batch_mnist_image_data))
image_cls_pred = image_cls_encoder(image_cls_feature, training=True)

问题是这段中training参数的作用是什么和下文中training是一样么,为什么该部分改成函数式结果不一样呢。

model = Dense(100, activation='relu')(model, training=True)

呃,直接转到函数定义看参数是干啥的就可以了啊

呃,直接转到函数定义看参数是干啥的就可以了啊