哪位可以分享一下 Keras VGG16 fine-tune(微调)的程序代码?希望程序能够正常运行。QQ:1246365615
下面是一个 Jupyter Notebook 文件:
{
"cells": [
{
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from keras.applications.vgg16 import VGG16\n",
"from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau\n",
"from keras.layers import Input, Dense, Dropout, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras.models import Sequential, Model\n",
"from keras.optimizers import Adam\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"np.random.seed(7)"
],
"outputs": [],
"metadata": {
"_cell_guid": "ad5a3ddc-02c5-4699-9e83-c58a09b9af25",
"_uuid": "fbcb0242449f7516052a04145a2119656907ea87"
},
"cell_type": "code",
"execution_count": 1
},
{
"source": [
"def make_df(path, mode):\n",
" \"\"\"\n",
" params\n",
" --------\n",
" path(str): path to json\n",
" mode(str): \"train\" or \"test\"\n",
"\n",
" outputs\n",
" --------\n",
" X(np.array): list of images shape=(None, 75, 75, 3)\n",
" Y(np.array): list of labels shape=(None,)\n",
" df(pd.DataFrame): data frame from json\n",
" \"\"\"\n",
" df = pd.read_json(path)\n",
" df.inc_angle = df.inc_angle.replace('na', 0)\n",
" X = _get_scaled_imgs(df)\n",
" if mode == \"test\":\n",
" return X, df\n",
"\n",
" Y = np.array(df['is_iceberg'])\n",
"\n",
" idx_tr = np.where(df.inc_angle > 0)\n",
"\n",
" X = X[idx_tr[0]]\n",
" Y = Y[idx_tr[0], ...]\n",
"\n",
" return X, Y\n",
"\n",
"\n",
"def _get_scaled_imgs(df):\n",
" imgs = []\n",
"\n",
" for i, row in df.iterrows():\n",
" band_1 = np.array(row['band_1']).reshape(75, 75)\n",
" band_2 = np.array(row['band_2']).reshape(75, 75)\n",
" band_3 = band_1 + band_2\n",
"\n",
" a = (band_1 - band_1.mean()) / (band_1.max() - band_1.min())\n",
" b = (band_2 - band_2.mean()) / (band_2.max() - band_2.min())\n",
" c = (band_3 - band_3.mean()) / (band_3.max() - band_3.min())\n",
"\n",
" imgs.append(np.dstack((a, b, c)))\n",
"\n",
" return np.array(imgs)"
],
"outputs": [],
"metadata": {
"collapsed": true
},
"cell_type": "code",
"execution_count": 2
},
{
"source": [
"def SmallCNN():\n",
" model = Sequential()\n",
"\n",
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu',\n",
" input_shape=(75, 75, 3)))\n",
" model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))\n",
" model.add(Dropout(0.2))\n",
"\n",
" model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))\n",
" model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
" model.add(Dropout(0.2))\n",
"\n",
" model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))\n",
" model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
" model.add(Dropout(0.3))\n",
"\n",
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n",
" model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
" model.add(Dropout(0.3))\n",
"\n",
" model.add(Flatten())\n",
" model.add(Dense(512, activation='relu'))\n",
" model.add(Dropout(0.2))\n",
"\n",
" model.add(Dense(256, activation='relu'))\n",
" model.add(Dropout(0.2))\n",
"\n",
" model.add(Dense(1, activation=\"sigmoid\"))\n",
"\n",
" return model"
],
"outputs": [],
"metadata": {
"collapsed": true
},
"cell_type": "code",
"execution_count": 3
},
{
"source": [
"def Vgg16():\n",
" input_tensor = Input(shape=(75, 75, 3))\n",
" vgg16 = VGG16(include_top=False, weights='imagenet',\n",
" input_tensor=input_tensor)\n",
"\n",
" top_model = Sequential()\n",
" top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))\n",
" top_model.add(Dense(512, activation='relu'))\n",
" top_model.add(Dropout(0.5))\n",
" top_model.add(Dense(256, activation='relu'))\n",
" top_model.add(Dropout(0.5))\n",
" top_model.add(Dense(1, activation='sigmoid'))\n",
"\n",
" model = Model(input=vgg16.input, output=top_model(vgg16.output))\n",
" for layer in model.layers[:13]:\n",
" layer.trainable = False\n",
"\n",
" return model"
],
"outputs": [],
"metadata": {
"collapsed": true
},
"cell_type": "code",
"execution_count": 4
},
{
"source": [
"if __name__ == \"__main__\":\n",
" x, y = make_df(\"../input/train.json\", \"train\")\n",
" xtr, xval, ytr, yval = train_test_split(x, y, test_size=0.25,\n",
" random_state=7)\n",
" model = SmallCNN()\n",
" #model = Vgg16()\n",
" optimizer = Adam(lr=0.001, decay=0.0)\n",
" model.compile(loss='binary_crossentropy', optimizer=optimizer,\n",
" metrics=['accuracy'])\n",
"\n",
" earlyStopping = EarlyStopping(monitor='val_loss', patience=20, verbose=0,\n",
" mode='min')\n",
" ckpt = ModelCheckpoint('.model.hdf5', save_best_only=True,\n",
" monitor='val_loss', mode='min')\n",
" reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1,\n",
" patience=7, verbose=1, epsilon=1e-4,\n",
" mode='min')\n",
"\n",
" gen = ImageDataGenerator(horizontal_flip=True,\n",
" vertical_flip=True,\n",
" width_shift_range=0,\n",
" height_shift_range=0,\n",
" channel_shift_range=0,\n",
" zoom_range=0.2,\n",
" rotation_range=10)\n",
" gen.fit(xtr)\n",
" model.fit_generator(gen.flow(xtr, ytr, batch_size=32),\n",
" steps_per_epoch=len(xtr), epochs=1,\n",
" callbacks=[earlyStopping, ckpt, reduce_lr_loss],\n",
" validation_data=(xval, yval))\n",
"\n",
" model.load_weights(filepath='.model.hdf5')\n",
" score = model.evaluate(xtr, ytr, verbose=1)\n",
" print('Train score:', score[0], 'Train accuracy:', score[1])\n",
"\n",
" xtest, df_test = make_df(\"../input/test.json\", \"test\")\n",
" pred_test = model.predict(xtest)\n",
" pred_test = pred_test.reshape((pred_test.shape[0]))\n",
" submission = pd.DataFrame({'id': df_test[\"id\"], 'is_iceberg': pred_test})\n",
" submission.to_csv('submission.csv', index=False)"
],
"outputs": [],
"metadata": {},
"cell_type": "code",
"execution_count": 5
}
],
"nbformat": 4,
"nbformat_minor": 1,
"metadata": {
"language_info": {
"name": "python",
"version": "3.6.3",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py",
"mimetype": "text/x-python"
},
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
}
}
}
代码来源:https://www.kaggle.com/takuok/keras-smallcnn-and-vgg16-fine-tuning/code
其他参考资料:
- http://marubon-ds.blogspot.com/2017/09/vgg16-fine-tuning-model.html (需要科学上网)
- https://flyyufelix.github.io/2016/10/03/fine-tuning-in-keras-part1.html