1. Introduction to the MNIST Dataset
First, fetch TensorFlow's built-in MNIST dataset with the following two lines (this uses the TensorFlow 1.x `tensorflow.examples.tutorials` module):

```python
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./data/mnist', one_hot=True)
```
The MNIST dataset provides 55000 training images (mnist.train.num_examples) with 55000 corresponding labels, and 10000 test images (mnist.test.num_examples), likewise with 10000 matching labels. For convenient access, the image and label data come pre-formatted.
The training images (mnist.train.images) form a 55000 × 784 matrix. Each row holds one 28 × 28 × 1 image, with values in [0, 1] representing normalized grayscale pixel intensities.
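To make that layout concrete, here is a small NumPy sketch (not from the original post; the raw image is a random stand-in) showing how a 28 × 28 uint8 image maps to one 784-length row in [0, 1]:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in raw 28 x 28 grayscale image with uint8 pixels in 0..255
# (a real MNIST image would come from the dataset files).
raw = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

# Flatten to a 784-length row and scale to [0, 1] --
# the layout of each row of mnist.train.images.
row = raw.astype(np.float32).reshape(784) / 255.0

print(row.shape)  # (784,)
```

Reversing the transformation with `np.reshape(row, (28, 28))` recovers the image, which is exactly what the plotting script below does.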
The training labels (mnist.train.labels) form a 55000 × 10 matrix. The 10 entries in each row are a one-hot encoding of the digits 0 through 9: exactly one entry is 1, whose index is the digit shown in the corresponding image, and the rest are 0.
The test set is organized the same way, only with a different number of examples.
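The one-hot encoding and its decoding via argmax can be sketched in plain NumPy (illustrative; `one_hot` is a hypothetical helper, not part of the dataset API):

```python
import numpy as np

def one_hot(digit, num_classes=10):
    """Return a one-hot row vector for a digit label, as in mnist.train.labels."""
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[digit] = 1.0
    return vec

label = one_hot(7)
print(label)                  # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
print(int(np.argmax(label)))  # 7 -- argmax recovers the digit
```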
The following script visualizes the first six training images in a 2 × 3 grid, titled with their decoded labels:

```python
import numpy as np
import matplotlib.pyplot as plot
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./data/mnist', one_hot=True)
trainImages = mnist.train.images
trainLabels = mnist.train.labels

plot.figure(1, figsize=(4, 3))
for i in range(6):
    curImage = np.reshape(trainImages[i, :], (28, 28))
    curLabel = np.argmax(trainLabels[i, :])  # decode the one-hot label
    ax = plot.subplot(2, 3, i + 1)           # 2 x 3 grid of subplots
    plot.imshow(curImage, cmap=plot.get_cmap('gray'))
    plot.axis('off')
    ax.set_title(curLabel)
plot.suptitle('MNIST')
plot.show()
```
Next, a single-layer softmax model is trained with gradient descent (TensorFlow 1.x API):

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./data/mnist', one_hot=True)

def train(trainCycle=50000, debug=False):
    inputSize = 784
    outputSize = 10
    batchSize = 64

    inputs = tf.placeholder(tf.float32, shape=[None, inputSize])
    # x * w: [batchSize, 784] x [784, 10] -> [batchSize, 10]
    weights = tf.Variable(tf.random_normal([inputSize, outputSize], 0, 0.1))
    bias = tf.Variable(tf.random_normal([outputSize], 0, 0.1))
    outputs = tf.add(tf.matmul(inputs, weights), bias)
    outputs = tf.nn.softmax(outputs)

    labels = tf.placeholder(tf.float32, shape=[None, outputSize])
    # Mean squared error between softmax outputs and one-hot labels
    # (cross-entropy is the more common choice for classification).
    loss = tf.reduce_mean(tf.square(outputs - labels))

    optimizer = tf.train.GradientDescentOptimizer(0.1)
    trainer = optimizer.minimize(loss)

    # Accuracy: fraction of rows where predicted and true digits agree.
    corrected = tf.equal(tf.argmax(labels, 1), tf.argmax(outputs, 1))
    accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for i in range(trainCycle):
        batch = mnist.train.next_batch(batchSize)
        sess.run([trainer, loss], feed_dict={inputs: batch[0], labels: batch[1]})
        if debug and i % 1000 == 0:
            accuracyValue = sess.run(accuracy, feed_dict={inputs: batch[0], labels: batch[1]})
            print(i, 'train set accuracy:', accuracyValue)

    # Evaluate on the test set
    accuracyValue = sess.run(accuracy, feed_dict={inputs: mnist.test.images,
                                                  labels: mnist.test.labels})
    print('accuracy on test set:', accuracyValue)
    sess.close()

train()
```
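The softmax applied to the linear layer above turns the raw scores into a probability distribution over the 10 digits; a minimal, numerically stable NumPy sketch (illustrative only, not TensorFlow's implementation):

```python
import numpy as np

def softmax(z):
    """Softmax over the last axis, stabilized by subtracting the max."""
    z = z - np.max(z, axis=-1, keepdims=True)  # guard against overflow in exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([2.0, 1.0, 0.1])
probs = softmax(scores)
print(probs)                  # non-negative values summing to 1
print(int(np.argmax(probs)))  # 0 -- the highest-scoring class wins
```

Because softmax is monotonic in each score, taking `argmax` of the probabilities gives the same prediction as taking `argmax` of the raw scores, which is what the accuracy computation relies on.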
Source: https://www.cnblogs.com/laishenghao/p/9576806.html