上次使用 Google ML Engine 跑了一下 TensorFlow Object Detection API 中的 Quick Start(http://www.cnblogs.com/take-fetter/p/8384564.html), 但是遇到了很多错误, 索性放弃了
这两天索性从自己的数据集开始制作手掌识别器先放运行结果吧
所有代码文件可在 https://github.com/takefetter/hand-detection 查看
使用前所需要的准备: 1.clone tensorflow models(site:https://github.com/tensorflow/models)
2. 在 model/research 目录下运行 setup.py 安装 object detection API
3. 其余必要条件: 安装 tensorflow(版本需大于等于 1.4),opencv-python 等必须的 package
4. 安装 Google Cloud SDK, 激活免费试用 300 美金 (需要一张信用卡来验证) 和在命令行中使用 gcloud init 设置等
准备数据集
(关于手的图片的 dataset 仍旧使用的 dlib 训练 (site:http://www.cnblogs.com/take-fetter/p/8321158.html) 中的 Hand Images Databases - https://www.mutah.edu.jo/biometrix/hand-images-databases.html 提供的数据集, 只不过这次使用了 WEHI 系列的图片(MOHI 的图片我也试过, 导入后会导致 standard-gpu 版的训练无法进行(内存不足)), 作为示例目前我只使用了 1-50 人的共计 250 张图片)
tensorflow 训练的数据集需为 TFRecord 格式, 我们需要对训练数据进行标注, 但是我并没有找到直接可以标注生成的工具, 索性有工具可以生成 Pascal VOC 格式的 xml 文件 https://github.com/tzutalin/labelImg, 推荐将图片文件放于 research/images 中, 保存 xml 文件夹位于 research/images/xmls 中
根据你要训练的数据集, 创建. pbtxt 文件
转换为 tfrecord 格式
完成图片标注后在 xmls 文件夹中运行 xml_to_csv.py 即可生成 csv 文件, 再通过 create_hand_tfrecord.py 即可将图片转换为 hand.record 文件
需要注意的是, 如果你需要训练的数据集和我这里的不一样的话, create_hand_tfrecord.py 的 todo 部分需要与你的. pbtxt 文件内的内容一致
(方法参考至 https://github.com/datitran/raccoon_dataset 使用本作者的文件还可以完成划分测试集和分析数据等功能, 当然我这里并没有使用)
下载预训练模型
重新开始一个模型的训练时间是很长的时间, 而 tensorflow model zoo 为我们提供好了预训练的模型(site:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models-coco-models), 选择并下载一个 我选择的是
速度最快的 ssd_mobilenet_v1, 下载后解压可找到 3 个含有 ckpt 的文件, 如图
之后还需下载并配置 model 对应的 config 文件 (https://github.com/tensorflow/models/tree/master/research/object_detection/samples/configs) 并修改文件中的内容
需要修改的地方有:
num_classes: 改为 pbtxt 中类的数目
PATH_TO_BE_CONFIGURED 的部分改为相应的目录
num_steps 定义了学习的上限 默认是 200000 可自己更改, 训练过程中也可以随时停止
上传文件并在 Google Cloud Platform 中训练
1. 上传 3 个 ckpt 文件以及 config 文件和. record 文件
到 google cloud 控制台 - 存储目录下, 创建存储分区(这里使用 takefetter_hand_detector), 并新建 data 文件夹, 拖拽上传到该目录中 生成后的目录如下
- + takefetter_hand_detector/
- + data/
- - faster_rcnn_resnet101_pets.config
- - model.ckpt.index
- - model.ckpt.meta
- - model.ckpt.data-00000-of-00001
- - pet_label_map.pbtxt
- - pet_train.record
- - pet_val.record
2. 打包 tf slim 和 object detection
在 research 目录下运行
- python setup.py sdist
- (cd slim && python setup.py sdist)
3. 创建机器学习任务
在 research 目录下运行此命令 开始训练
- gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \
- --runtime-version 1.4 \
- --job-dir=gs://takefetter_hand_detector/train \
- --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
- --module-name object_detection.train \
- --region us-central1 \
- --config object_detection/samples/cloud/cloud.yml \
- -- \
- --train_dir=gs://takefetter_hand_detector/train \
- --pipeline_config_path=gs://takefetter_hand_detector/data/ssd_inception_v2_coco.config
需要注意的地方有
windows 下需要放在同一行运行 并删除 \
cloud.yml 文件中的内容可以自行更改, 我这里的设置为
- trainingInput:
- runtimeVersion: "1.4"
- scaleTier: CUSTOM
- masterType: standard_gpu
- workerCount: 2
- workerType: standard_gpu
- parameterServerCount: 2
- parameterServerType: standard
在提交任务后在 机器学习引擎 - 作业中即可看到具体情况, 每运行几千次后在 takefetter_hand_detector/train 中存储对应 cheakpoint 的文件 如图
之后下载需要的 cheak 的 3 个文件 复制到 research 目录下(这里用 30045 的 3 个文件作为示例), 并将 research/object_detectIon 目录下的 export_inference_graph.py 复制到 research 目录下 运行例如
- python object_detection/export_inference_graph.py \
- --input_type image_tensor \
- --pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v1_hand.config \
- --trained_checkpoint_prefix model.ckpt-30045 \
- --output_directory exported_graphs
在运行完成后 research 目录中会生成文件夹 exported_graphs_30045 包含的文件如图所示
拷贝 frozen_inference_graph.pb 和 pbtxt 文件到 test/hand_inference_graph 文件夹, 并运行 hand_detector.py 即可得到如文章开头的结果
后记:
1. 如果需要视频实时的 hand tracking, 可使用 https://github.com/victordibia/handtracking 在我的渣本上 FPS 太低了......
2. 我目前使用的数据集还是较小训练次数也比较少, 很容易出现一些误识别的情况, 之后还会加大数据集和训练次数
3. 换用其他 model 应该也会显著改善识别精确度
感谢:
- https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md
- https://github.com/victordibia/handtracking
- https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/
- https://github.com/datitran/raccoon_dataset
- https://www.mutah.edu.jo/biometrix/hand-images-databases.html
来源: https://www.cnblogs.com/take-fetter/p/8438747.html