最近完成了一个以图搜图的项目, 项目总共用时三个多月. 记录一下项目中用到机器学习的地方, 以及各种踩过的坑. 总的来说, 项目分为一下几个部分:
一, 训练目标函数
1, 设定基础模型
2, 添加新层
3, 冻结 base 层
4, 编译模型
5, 训练
6, 保存模型
二, 特征提取
三, 创建索引
四, 构建服务
1,flask 开发
2,Gunicorn 异步, 增加服务稳健性
3,Supervisor 部署监控服务
五, 总结
一, 训练目标函数
项目是在预训练模型 vgg16 的基础上进行微调(fine_tune), 并将特征的维度从原先的 2048 维降为 1024 维度.
模型的微调又分为以下几个步骤:
1, 设定基础模型
本次采用预训练的 VGG16 基础模型, 利用其 bottleneck 特征
- #设定基础模型
- base_model =VGG16(weights='./model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',include_top=False)
- #指定权重路径
- # include_top= False 不加载三层全连接层
2, 添加新层
将自己要目标图片, 简单分类, 统计类别(在训练模型时需要指定类别)
- # 添加新层
- def add_new_last_layer(base_model, nb_classes):
- '''
- 添加最后的层
- :param base_model: 预训练模型
- :param nb_classes: 分类数量
- :return: 新的 model
- '''
- x = base_model.output
- x = GlobalAveragePooling2D()(x)
- x = Dense(128, activation='relu')(x) #输出的特征维度 88
- predictions = Dense(nb_classes, activation='softmax')(x)
- model = Model(input=base_model.input, output=predictions)
- return model
3, 冻结 base 层
以前的参数可以使用预训练好的参数, 不需要重新训练, 所以需要冻结, 不让其改变.
- def freeze_base_layer(model, base_model):
- for layer in base_model.layers:
- layer.trainable = False
4, 编译模型
- model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics= ['accuracy'])
- # optimizer: 优化器
- # loss: 损失函数, 多类的对数损失需要将分类标签转换为 (将标签转化为形如(nb_samples, nb_classes) 的二值序列)
- # metrics: 列表, 包含评估模型在训练和测试时的网络性能的指标准备训练数据.
5, 训练
- # 数据准备
- IM_WIDTH, IM_HEIGHT = 224,224
- train_dir = './refine_img_data/train'
- val_dir = './refine_img_data/test'
- nb_classes = 5
- np_epoch = 3
- batch_size = 16
- nb_train_samples = get_nb_files(train_dir)
- nb_classes = len(glob.glob(train_dir + '/*'))
- nb_val_samples = get_nb_files(val_dir)
- # 根据现有数据, 设置新数据生成参数
- train_datagen = ImageDataGenerator(
- preprocessing_function=preprocess_input,
- rotation_range=30,
- width_shift_range=0.2,
- height_shift_range=0.2,
- shear_range=0.2,
- zoom_range=0.2,
- horizontal_flip=True
- )
- test_datagen = ImageDataGenerator(
- preprocessing_function=preprocess_input,
- rotation_range=30,
- width_shift_range=0.2,
- height_shift_range=0.2,
- shear_range=0.2,
- zoom_range=0.2,
- horizontal_flip=True
- )
- # 从文件夹获取数据
- train_generator = train_datagen.flow_from_directory(
- train_dir,
- target_size=(IM_WIDTH, IM_HEIGHT),
- batch_size=batch_size,
- class_mode='categorical'
- )
- validation_generator = test_datagen.flow_from_directory(
- val_dir,
- target_size=(IM_WIDTH, IM_HEIGHT),
- batch_size=batch_size,
- class_mode='categorical'
- )
- # 训练
- history_t1 = model.fit_generator(
- train_generator,
- epochs=1,
- steps_per_epoch=10,
- validation_data=validation_generator,
- validation_steps=10,
- class_weight='auto'
- )
6, 保存模型
将模型保存到指定路径一般保存为. h5 格式
model.save('/model/test_model.h5')
二, 特征提取
加载我们训练好的模型, 根据需要, 取指定层的特征.
- # 可用 model.summary() 查看模型结构
- # 根据模型提取图片特征
- target_size = (224,224)
- def my_feature(mod, path):
- img = image.load_img(path,target_size=target_size)
- img = image.img_to_array(img)
- img = np.expand_dims(img, axis=0)
- img = preprocess_input(img)
- return mod.predict(img)
- # 创建模型, 获取指定层特征
- model_path = './model/my_model.h5'
- base_model = load_model(model_path)
- model = Model(inputs=base_model.input, outputs=base_model.get_layer('dense_1').output)
- # 提取特征
- img_path = './my_img/bus.jpg'
- feat = my_feature(model,img_path) # shape 为 (1,128)
- print(feat)
- print(feat.shape)
- # 注意, 当需要提取的图片特征数量较大, 比如千万以上, 需要的时间是比较长的, 这时我们可以采用多核与批处理来进行 (python 由于 GIL 的问题对多线程不友好).
- def pre_processs_image(path):
- if path is not None and os.path.exists(path) and len(path)> 10:
- try:
- img = cv2.imread(path, cv2.IMREAD_COLOR)
- img = cv2.resize(img, (224, 224))
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img = img.transpose(2, 0, 1)
- return [material_id,img, flag]
- except Exception as err:
- traceback.print_exc()
- return None
- else:
- logging.error('could not find path:' + path)
- return None
- #cpu 部分, 调用多核处理函数, 指定核数为 20
- with ProcessPoolExecutor(max_workers=20) as executor:
- feat_paras = list(executor.map(pre_processs_image,, material_batch))
- # GPU 部分采用批处理
- # TODO
三, 创建索引
此处我们使用 Facebook 开源的近邻索引框架 faiss .
- # create index
- d = 128
- nlist = 100 # 切分数量
- nprobe = 8 # 每次查找分片数量
- quantizer_img = faiss.IndexFlatL2(d) #根据欧式距离创建索引
- image_index = None
- model_index = None
- if image_feat_array is not None and len(img_feat_list)> 100:
- image_index = faiss.IndexIVFFlat(quantizer_img, d, nlist, faiss.METRIC_L2)
- image_index.train(image_feat_array)
- image_index.add_with_ids(image_feat_array,image_id_array)
- image_index.nprobe = nprobe
- image_index.dont_dealloc_me = quantizer_img
- # 保存当前索引到指定路径
- faiss.write_index(img_index,path)
- # 测试当前索引
- temp_feat = img_feat_list[1]
- res_2 = image_index.search(temp_feat, k=5)
- logging.info('image search result is:' + str(res_2))
四, 构建服务
1,flask 开发
参考文档 http://docs.jinkan.org/docs/flask/quickstart.html#a-minimal-application
2,Gunicorn 异步, 增加服务稳健性
基础语法:
- Gunicorn -w process_num -b ip:port -k 'gevent' fileName:app
- # 注意: 此处不选择 - k 'gevent'则为同步运行
同步部署:
gunicorn -b 0.0.0.0:9090 my_service:app
异步部署:
gunicorn -b 0.0.0.0:9090 -k gevent my_service:app
用了 Gunicorn 来部署应用后, 对比 flask , qps 提升了一倍. 原 flask 框架中由于我的接口中 request 了其他的接口, 线程在此处会阻塞, 导致程序非常容易假死. 改用后, 稳定又了极大的提升.
3,Supervisor 部署监控服务
可参考以下文档 https://www.cnblogs.com/gjack/p/8076419.html
五, 总结
项目到这个地方, 基本的服务框架已经有了. 许多地方只说了大体思路, 但是结构是完整. 文中的许多用了许多方法工具, 如 gunicorn 的异步等, 但是原理却不甚了解, 还需要花功夫去学习. 由于上线压力大, 时间紧, 许多地方来不及仔细琢磨, 肯定有不少纰漏, 后面再查漏补缺吧.
来源: https://www.cnblogs.com/yaolin1228/p/9557588.html