我在学习时间卷积网络(Temporal convolutional network ),跟着别人的代码学习,在看到有一行代码的后面又加了一个‘(x)’,这个x是上一步的变量x,我想知道这是什么意思,代码怎么运行的?在残差函数中也有这个现象。
from keras.models import Model
from keras.layers import add,Input,Conv1D,Activation,Flatten,Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#残差块
def ResBlock(x,filters,kernel_size,dilation_rate):
r=Conv1D(filters,kernel_size,padding='same',dilation_rate=dilation_rate,activation='relu')(x) #第一卷积
r=Conv1D(filters,kernel_size,padding='same',dilation_rate=dilation_rate)(r) #第二卷积
if x.shape[-1]==filters:
shortcut=x
else:
shortcut=Conv1D(filters,kernel_size,padding='same')(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o
#序列模型
def TCN(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(28,28))
x=ResBlock(inputs,filters=32,kernel_size=3,dilation_rate=1)
x=ResBlock(x,filters=32,kernel_size=3,dilation_rate=2)
x=ResBlock(x,filters=16,kernel_size=3,dilation_rate=4)
x=Flatten()(x)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=30,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1])
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
TCN(train_x,train_y,valid_x,valid_y,test_x,test_y)
后缀一个(x)是什么意思呢
没百度出来
知道后缀一个(x)的含义
a=b()(x)
这看起来很怪吗
如果你知道函数b的返回值是一个函数,像这样
def b():
def c():
...
return c
还怪吗
b(),其实就是c
b()(x)其实就是c(x)
最终其实就是a=c(x)
这就好比多维list中每一项也是一个list,所以你可以list[0][0][0]这样不断的往下写,直到下面的元素不再是个可迭代对象为止
函数的返回值如果还是函数,那么你也可以这样往下嵌套的写调用