```c
def model_load(IMG_SHAPE=(224, 224, 3), class_num=10):
# 搭建模型
model = tf.keras.models.Sequential([
# 对模型做归一化的处理,将0-255之间的数字统一处理到0到1之间
tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=IMG_SHAPE),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Conv2D(128, (3, 3), activation='elu',padding='SAME'),
tf.keras.layers.MaxPooling2D((3, 3),padding='SAME',strides = 2),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Flatten(),
# The same 128 dense layers, and 10 output layers as in the pre-convolution example:
tf.keras.layers.Dense(128, activation='elu'),
tf.keras.layers.Dense(64, activation='elu'),
tf.keras.layers.Dense(32, activation='elu'),
tf.keras.layers.Dense(16, activation='elu'),
# 通过softmax函数将模型输出为类名长度的神经元上,激活函数采用softmax对应概率值
tf.keras.layers.Dense(class_num, activation='softmax')
])
model.summary()
# 指明模型的训练参数,优化器为sgd优化器,损失函数为交叉熵损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
怎么把上面的改成class model():的
这不是直接开启套娃模式,外面包个class就行了?
class Model:
def __init__(self,IMG_SHAPE=(224,224,3),class_num=10):
#super.__init__() #如果需要继承父类某些成员需要加上
self.IMG_SHAPE=IMG_SHAPE
self.class_num=class_num
def model_load(self):
# 搭建模型
model = tf.keras.models.Sequential([
# 对模型做归一化的处理,将0-255之间的数字统一处理到0到1之间
tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=self.IMG_SHAPE),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Conv2D(128, (3, 3), activation='elu', padding='SAME'),
tf.keras.layers.MaxPooling2D((3, 3), padding='SAME', strides=2),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Flatten(),
# The same 128 dense layers, and 10 output layers as in the pre-convolution example:
tf.keras.layers.Dense(128, activation='elu'),
tf.keras.layers.Dense(64, activation='elu'),
tf.keras.layers.Dense(32, activation='elu'),
tf.keras.layers.Dense(16, activation='elu'),
# 通过softmax函数将模型输出为类名长度的神经元上,激活函数采用softmax对应概率值
tf.keras.layers.Dense(self.class_num, activation='softmax')
])
model.summary()
# 指明模型的训练参数,优化器为sgd优化器,损失函数为交叉熵损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
是 python ? 把这个函数封装成类?