import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
#%% Load the pre-trained model
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3
# ResNet101 network
# The default input size is 224*224*3; use the input_shape argument to change it
# include_top: whether to keep the final fully connected classification layers
# weights: which dataset the pre-trained weights were trained on
pre_trained_model = ResNet101(input_shape = (75, 75, 3),  # input size
                              include_top = False,        # drop the final fully connected layers
                              weights = 'imagenet')
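#%% (Optional) Inspect the loaded backbone
# A small sanity check, not part of the original pipeline: print how many layers the
# backbone has and the feature-map shape it produces for 75x75 inputs, so the size of
# the Flatten/Dense head built below is easy to anticipate.
print('Backbone layers:', len(pre_trained_model.layers))
print('Backbone output shape:', pre_trained_model.output_shape)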
#%% Build the network
# The backbone already comes with trained weights, so freeze it and do not train it again
for layer in pre_trained_model.layers:
    layer.trainable = False
# Flatten the feature maps to feed the fully connected layers
x = layers.Flatten()(pre_trained_model.output)
# Add a fully connected layer; this part is trained from scratch
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
# Output layer: binary classification only needs a single neuron
x = layers.Dense(1, activation='sigmoid')(x)
# Assemble the model
model = Model(pre_trained_model.input, x)
# Compile the model (the `lr` argument is deprecated in TF2; use `learning_rate`)
model.compile(optimizer = Adam(learning_rate=0.001),
              loss = 'binary_crossentropy',
              metrics = ['acc'])
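#%% (Optional) Confirm that only the new head will be trained
# A minimal check, assuming the freezing above took effect: count trainable vs. frozen
# parameters. Only the two Dense layers added on top should contribute trainable weights.
from tensorflow.keras import backend as K
trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
frozen_count = int(np.sum([K.count_params(w) for w in model.non_trainable_weights]))
print('Trainable parameters:', trainable_count)
print('Frozen parameters:', frozen_count)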
#%% Build the image data generators
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
train_datagen = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1.0/255.)
# flow_from_directory takes a folder path and yields batches of augmented/normalized data indefinitely.
# directory: target folder containing one sub-folder per class label; each sub-folder counts as one class,
#     and any jpg, png or bmp image inside those sub-folders is used by the generator
# target_size: integer tuple, default (256, 256); images are resized to this size
# batch_size: number of images per batch
# color_mode: one of 'grayscale' or 'rgb', default 'rgb'
# classes: optional list of sub-folder names, e.g. ['dogs', 'cats']; default None, in which case classes are inferred
# class_mode: determines the form of the returned labels; default 'categorical' returns 2D one-hot labels,
#     'binary' returns 1D binary labels
train_generator = train_datagen.flow_from_directory(train_dir,
                                                    batch_size = 20,
                                                    class_mode = 'binary',
                                                    target_size = (75, 75))
validation_generator = test_datagen.flow_from_directory(validation_dir,
                                                        batch_size = 20,
                                                        class_mode = 'binary',
                                                        target_size = (75, 75))
#%% Build the callbacks
# Prepare the model saving directory.
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'garbage_model.h5'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
filepath = os.path.join(save_dir, model_name)
# Step learning-rate schedule: start at 1e-3 and shrink it after epochs 80, 120, 160 and 180
def lr_schedule(epoch):
    lr = 1e-3
    if epoch > 180:
        lr *= 0.5e-3
    elif epoch > 160:
        lr *= 1e-3
    elif epoch > 120:
        lr *= 1e-2
    elif epoch > 80:
        lr *= 1e-1
    print('Learning rate: ', lr)
    return lr
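#%% (Optional) Preview the learning-rate schedule
# A quick illustration of the step schedule above; with 100 training epochs only the
# first drop (after epoch 80) is ever reached.
for sample_epoch in [0, 81, 121, 161, 181]:
    lr_schedule(sample_epoch)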
class myCallback(tf.keras.callbacks.Callback):
    # Stop training early once training accuracy exceeds 95%
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get('acc', 0) > 0.95:
            print("\nReached 95% accuracy so cancelling training!")
            self.model.stop_training = True
# This callback saves the model to filepath after each epoch (here only when val_acc improves)
checkpoint = ModelCheckpoint(filepath=filepath,
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True)
# Learning-rate scheduler
# The schedule function takes the epoch index (an integer, starting from 0) and returns a new learning rate (float)
lr_scheduler = LearningRateScheduler(lr_schedule)
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=0.5e-6)
callbacks = [checkpoint, lr_reducer, lr_scheduler, myCallback()]
#%% Train
# fit_generator is deprecated in TF2; model.fit accepts generators directly
history = model.fit(
    train_generator,
    validation_data = validation_generator,
    steps_per_epoch = 100,
    epochs = 100,
    validation_steps = 50,
    verbose = 1,
    callbacks = callbacks)
#%% Save the model
save_dir = os.path.join(os.getcwd(), 'saved_models')
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_name = 'last_model.h5'
filepath = os.path.join(save_dir, model_name)
print('saving model')
model.save(filepath)
print('save model finished!')
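#%% (Optional) Reload and evaluate the saved model
# A minimal sketch, assuming the validation generator defined above is still available:
# load the model back from disk and report its loss and accuracy on the validation data.
restored_model = tf.keras.models.load_model(filepath)
restored_loss, restored_acc = restored_model.evaluate(validation_generator, steps=50, verbose=0)
print('Restored model - val loss: %.4f, val acc: %.4f' % (restored_loss, restored_acc))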
#%% Plot training curves
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()