系列目录:
机器学习系列 - 第 0 篇 - 开发工具与 tensorflow 环境搭建
机器学习系列 - 第 1 篇 - 感知机识别手写数字(mnist 例子分析)
机器学习系列 - 第 2 篇 - CNN 识别手写数字(mnist 例子分析)
第 0 篇已经搭好了开发环境, 本文详细介绍用感知机识别手写数字 (mnist 例子) 的过程, 希望依照本文的步骤, 每个人能清楚理解 mnist 例子, 动手实践. 继续阅读之前, 最好你已经了解以下知识点(不然你迟早会回来的):
向量和矩阵基础运算规则 https://wenku.baidu.com/view/67a500c50029bd64783e2cdb.html
tensorflow 基础(session,graph,tensor 等) http://www.tensorfly.cn/tfdoc/get_started/basic_usage.html
感知机 https://www.zybuluo.com/hanbingtao/note/433855
交叉熵 https://www.zhihu.com/question/41252833
最小梯度下降算法 https://www.zybuluo.com/hanbingtao/note/448086
一 图片与数据分析
mnist 例子会自动下载下列数据
train-*.gz 是用来训练的数据, t10k-*gz 是用来测试的数据, 这些数据都不是原始的图片, 而是经过处理变成了二级制的文件, 这里重点分析一下这个数据格式.
1 image 文件
手写数字的图片是 28*28 的灰度图片, 图片中每个像素点的值范围是 0-255(黑色是 0, 白色是 255), 图片文件是按照这样格式写的:
魔法值(32 位)+ 图片数量(32 位)+ 图片宽(32 位)+ 图片长(32 位)+ 所有图数据
(1) 魔法值: 文件标识, train-images-idx3-ubyte 文件的 magic 值是 2051
(2) 所有图数据: 单张图数据 28*28=784 个 uint8, 所以所有图 N, 就是 N*784 个 uint8
用 16 进制查看 image 文件, 结果如下图所示:
2 label 文件
label 文件记录的是与 image 顺序一一对应的图片实际值, 范围是 0-9. 文件的格式是:
魔法值(32 位)+ 标签数量(32 位)+ 所有标签数据
(1)train-labels-idx1-ubyte 文件的 magic 值是 2049
(2) 所有标签数据: 每个标签是一个 uint8, 所以所有标签 就是 N 个 uint8
用 16 进制查看 label 文件, 结果如下图所示:
从上可以看到, 图片数据从 ea60 之后开始, 前 4 张图片分别是 5 0 4 1, 我把它们还原成图片:
还原图片的代码如下:
通过对文件数据的分析与还原, 我们能清楚知道数据格式是怎样的, 帮组我们在编码过程中处理数据时不会产生困惑, 同时在文章后面会用到自己制作手写字体来验证模型的准确性.
二 mnist 分析
1. 文件 import
mnist_reader 是我写的, 用来读取自己制作的手写字体(在第三部分有给出代码):
2. 训练过程
(1) 读取数据
one_hot 参数为 True 的意思是把图片 28*28 的二维数组处理为一维数组[784], 这样处理之后, 所有像素点都是一个特征输入, 最终变成只是分析像素点对预测结果的影响, 丢失了图片的结构信息. 因为这只是练习, 所以知道问题所在就好, 暂时不讨论该方法的优劣.
(2)模型构建
weight 为什么是一个 [784*10] 的数组, 因为图片数据是一个 784 的数组, 而每一张图片可能是 0-9 这十个数字中的一个, 每一张图片预测的结果有 10 种可能, 对应这张图片是 0,1, ...9 的的概率值. 同理, 偏置 b 也与实际输出的维数相同.
x 变量是在分批训练过程中用来存放一批图片, 所以它没有指定行数量, 让系统自动推导.
y_变量是在分批训练过程中, 用来存放一批标签数据, 它与 x 一一对应
y 变量是我们的模型函数 f(x)=wx+b 的结果 经过 softmax 函数处理的输出
(3)初始化 session
tensorflow 的变量需要先初始化才能用.
(4)模型训练
每次 100 张图片为一批, 循环 1000 次训练. 每一批的数量大小怎么定没有规定, 太小会导致训练出来的模型预测结果不理想, 太大会需要训练很久. 需要根据自己的样本数量来定批大小和循环次数. 主要考虑的如何合理设定这两个值的大小, 在有限的样本数量时, 尽量避免梯度下降算法 可能找到局部最小值 而非全局最小值的问题.
(5)模型评估
模型评估的规则是取预测值概率最大的对应数字与实际值相比较, 然后统计预测正常的占比, 为了方便理解, 假设训练 9 个样本, 下列代码的输出结果:
correct_prediction=tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) ->
- [ True True True False False True False False False]
- tf.reduce_mean(tf.cast(correct_prediction, "float")) -> 0.44444445
根据测试样本, 样本的准确率可以达到 91% 左右:
- training prediction: 0.9138
- training done
然后再通过不断调整 批量大小和循环次数, 发现最高的准确率基本都在 91%-92% 之间, 所以基本可以认定该模型的极限准确率在 92% 左右.
(6)模型保存
一个模型保存了如下 4 个文件:
checkpoint: 记录当前目录下有哪些模型
mnist.data-00000-of-00001: 模型的所有参数值 (如 w 和 b)
mnist.index:
mnist.meta: 模型的图结构信息
三. 模型使用
1. 制作自己的手写数字图片
根据 第一 部分最后 输出 mnist 样子的图片, 发现了它是黑色底, 即大部分像素点都是 0, 组成数字的部分是白色, 像素值都基本在 128-255, 所以参照它来做成的图片, 才能用训练好的模型来识别数字. 此前我自己学习过程, 按 tensorflow 中文社区 http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html 里面的例子是白底黑字, 制作手写字根本识别不了.
正确的只需按照下面两步就可以制作跟 mnist 例子一样的图片:
最后保存为 png 图片就可以了.
以下是我制作的图片, 0-9 数字各 9 张, 名字规则是 数字 + 序号, 即名字的第一字符标识图片里面写的是什么数字, 好处是不用写 label 文件, 在读取图片的时候可以自动解析出 label:
2. 加载图片数据
加载图片数据的重点是理解 mnist 里面的数据时怎样的格式, 这个我不去解释, 大家自己调试一下, 看内存数据. 读取图片的代码如下:
3. 识别
(1)使用训练好的模型
(2)可以不指定 specify, 全部测试, 我这个分开为每个数字单独测试, 看看各自的准确率
(3)结果
(4)结果分析
看到以上预测结果, 惊不惊喜, 意不意外? 跟训练评估时的 91% 准确率相差甚远. 是什么原因导致的呢? 回头去看我上面制作的手写图片, 你会发现跟它例子自带的数据有非常大的区别, 我制作的时候, 故意把数字写的大小不一, 位置不同.
再看 "mnist 分析" 部分第 (1) 点 "读取数据" 里面说到的, 把图片数据转为一个 [784] 数组, 丢失图片原有的结构信息, 单纯分析各像素点的值对图片的影响, 当在 28*28 大小的区域里面, 书写的数字大小和位置有很大差异的时候, 肯定是识别不出来.
四. 总结
至此, 我们完成了整个感知机识别手写数字的训练和预测过程, 在这过程中, 对于一个新手, 训练和测试的代码并不难, 难点是要理解它背后实现的基础, 就是文章一开始我列出的知识点. 如果你已经完全懂了本文中每一行代码的意思, 那么恭喜你, 机器学习的 Hello World 已经完成了.
来源: https://juejin.im/post/5aed31b3f265da0b9a69d7ec