MNIST 简介
一个手写数字识别库, 世界上最权威的, 美国邮政系统开发的, 手写内容是 0-9 的内容, 手写内容采集于美国人口调查局的员工和高中生. 包括 6 万张训练图片和 1 万张测试图片构成的, 每张图片都是 28*28 大小, 而且都是黑白色构成.
MINIST 实验包含了四个文件, 其中 train-images-idx3-ubyte 是 60000 个图片样本, train-labels-idx1-ubyte 是这 60000 个图片对应的数字标签, t10k-images-idx3-ubyte 是用于测试的样本, t10k-labels-idx1-ubyte 是测试样本对应的数字标签.
我们以测试集中的一个图片为例来说明图片的存储形式:
MNIST 图片并不是传统意义上的 PNG 或者 jpg 格式的图片, 因为 PNG 或者 jpg 的图片格式, 会带有很多干扰信息 (如: 数据块, 图片头, 图片尾, 长度等等), 这些图片会被处理成很简易的数组, 图片长度为 28, 宽度也为 28, 总像素为 2828=784, 在 MNIST 存储的就是一个长度为 784 的数组, 数组中的每个值表示每个点的 RGB 值, 其中黑色用 0 表示, 白色用 255 表示. 我们可以将数组转成 2828 的二维数组, 如下图所示, 可以看出这是一个表示的是数字 5 的图片.
image.PNG
如果把像素写成图片, 图片是这样的:
image.PNG
通过 MNIST 训练模型
在 BP 神经网络中, 层数, 节点个数, 学习速率, 训练集, 训练次数, 都会影响到最终模型的泛化能力. 因此, 在设计模型时, 节点的个数, 学习速率的大小, 以及训练次数都是需要考虑的.
本实例中设置神经网络层数为 3 层, 其中输入特征为 784 个, 每层节点数分别为 300,100,10 个, 学习速率设置为 0.5, 迭代周期为 30, 批量设置 60 个. 通过训练该模型在 MNIST 测试集上的平均准确率为 96.68 % 左右.
- public static void main(String[] args) {
- // 三层网络, 各层节点数为 784*300*10 输入特征 784 个 隐藏层节点 300 个 输出层节点 10 个
- int[] nodeNum = {784, 300,100, 10};
- // 周期被定义为向前和向后传播中所有批次的单次训练迭代.
- int epoch = 30;
- // 每次批量的样本数
- int batchSize = 60;
- double learningRate=0.5;
- NetTrainAndTest.train(nodeNum, epoch, batchSize,learningRate);
- }
对模型进行序列化
为了 "一次训练, 多次使用", 我们对训练好的模型进行序列化存储, 后续即可通过反序列化的方式读取恢复模型.
- /**
- * 通过序列化方式存储模型
- *
- * @param fileName 模型存放的文件名
- */
- public static <T> void saveModel(String fileName, T obj) {
- try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(fileName));
- ObjectOutputStream oos = new ObjectOutputStream(bos)) {
- oos.writeObject(obj);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
- /**
- * 恢复模型
- *
- * @param fileName 模型持久化的存放位置 文件名
- * <p>
- *//@SuppressWarnings("unchecked")
- */
- public static <T> T restoreModel(String fileName) {
- try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(fileName));
- ObjectInputStream ois = new ObjectInputStream(bis)) {
- return (T) ois.readObject();
- } catch (IOException | ClassNotFoundException e) {
- throw new RuntimeException(e);
- }
- }
来源: http://www.jianshu.com/p/195e916d45e0