tensorflow2.0如何实现参数共享

您好,我想请教一下如何在tf2.0中实现参数共享。
我正在尝试实现【Unsupervised Visual Representation Learning by Context Prediction】这篇论文中的模型结构如下图所示:

这个结构的前几层是由相同的两部分组成的,然后在fc7融合为一层,但是这个模型要求fc7层以前并行的两部分的参数是共享的。我不知道该如何在tf2.0里面实现参数共享。要在tf2.0中实现这样的结构该如何做呢?

我尝试的代码如下:

def Alex_net(x):
    conv1    = tf.keras.layers.Conv2D(96,(11,11),activation='relu',strides=(4,4))(x)
    maxpool1 = tf.keras.layers.MaxPooling2D((3,3),strides=(2,2))(conv1)
    conv2    = tf.keras.layers.Conv2D(256,(5,5),activation='relu',padding='same')(maxpool1)
    maxpool2 = tf.keras.layers.MaxPooling2D((3,3),strides=(2,2))(conv2)
    conv3    = tf.keras.layers.Conv2D(384,(3,3),activation='relu',padding='same')(maxpool2)
    conv4    = tf.keras.layers.Conv2D(384,(3,3),activation='relu',padding='same')(conv3)
    conv5    = tf.keras.layers.Conv2D(256,(3,3),activation='relu',padding='same')(conv4)
    maxpool5 = tf.keras.layers.MaxPooling2D((3,3),strides=(2,2))(conv5)
    fc6      = tf.keras.layers.Dense(4096,activation='relu')(maxpool5)
    print(fc6)
    return fc6

def Concat_net(x1,x2):
    input_1 = Alex_net(x1)
    input_2 = Alex_net(x2)   
    concat  = tf.keras.layers.concatenate([input_1,input_2])
    fc7     = tf.keras.layers.Dense(4096,activation='relu')(concat)
    fc8     = tf.keras.layers.Dense(4096,activation='relu')(fc7)
    fc9     = tf.keras.layers.Dense(8,activation='softmax')(fc8)
    C       = fc9
    return C

def final_net(width,height,depth):
    inputshape=(height,width,depth)
    inputs_1  = tf.keras.layers.Input(shape=inputshape)
    inputs_2  = tf.keras.layers.Input(shape=inputshape)
    outputs   = Concat_net(inputs_1, inputs_2)
    model     = tf.keras.Model([inputs_1,inputs_2],outputs,name='concat_NET')
    return model
F=final_net(96,96,3)
F.summary()

但是这样summary打印出来的参数是独立的,并不是共享的,模型只是重复利用了结构。summary的结果如下:

图片说明

要在tf2.0中实现这样的结构该如何做呢?

import tensorflow as tf
from tensorflow import keras

def Concat_net(x1,x2,model):
    input_1 = model.predict(x1)
    input_2 = model.predict(x2)   
    concat  = tf.keras.layers.concatenate([input_1,input_2])
    fc7     = tf.keras.layers.Dense(4096,activation='relu')(concat)
    fc8     = tf.keras.layers.Dense(4096,activation='relu')(fc7)
    fc9     = tf.keras.layers.Dense(8,activation='softmax')(fc8)
    C       = fc9
    return C

def final_net(inputshape,model):
    inputs_1  = tf.keras.layers.Input(shape=inputshape)
    inputs_2  = tf.keras.layers.Input(shape=inputshape)
    outputs   = Concat_net(inputs_1, inputs_2, model)
    model     = tf.keras.Model([inputs_1,inputs_2],outputs,name='concat_NET')
    return model

Alex_net = keras.Sequential([
    keras.layers.Conv2D(96,(11,11),activation='relu',strides=(4,4)),
    keras.layers.MaxPooling2D((3,3),strides=(2,2)),
    keras.layers.Conv2D(256,(5,5),activation='relu',padding='same'),
    keras.layers.MaxPooling2D((3,3),strides=(2,2)),
    keras.layers.Conv2D(384,(3,3),activation='relu',padding='same'),
    keras.layers.Conv2D(384,(3,3),activation='relu',padding='same'),
    keras.layers.Conv2D(256,(3,3),activation='relu',padding='same'),
    keras.layers.MaxPooling2D((3,3),strides=(2,2)),
    keras.layers.Dense(4096,activation='relu')
])
inputshape=(96,96,3)
F=final_net(inputshape,Alex_net)
F.summary()

https://www.cnblogs.com/baby-lily/p/10934131.html

path1跑到fc6处,
path2跑到fc6处,
然后将两次fc6的结果整合后输入tc7