转载请注明作者:
Google Machine Learning Recipes 7
-
Github 工程地址
欢迎 Star,也欢迎到
- mnist = learn.datasets.load_dataset('mnist')
恩,就是这么简单,一行代码下载解压 mnist 数据,每个 img 已经灰度化成长 784 的数组,每个 label 已经 one-hot 成长度 10 的数组
在我的看 One-hot 是什么东西
- data = mnist.train.images labels = np.asarray(mnist.train.labels, dtype = np.int32) test_data = mnist.test.images test_labels = np.asarray(mnist.test.labels, dtype = np.int32) max_examples = 10000 data = data[: max_examples] labels = labels[: max_examples]
- def display(i) : img = test_data[i] plt.title('Example %d. Label: %d' % (i, test_labels[i])) plt.imshow(img.reshape((28, 28)), cmap = plt.cm.gray_r) plt.show()
用 matplotlib 展示灰度图
- feature_columns = learn.infer_real_valued_columns_from_input(data)
- classifier = learn.LinearClassifier(feature_columns = feature_columns, n_classes = 10) classifier.fit(data, labels, batch_size = 100, steps = 1000)
注意要制定 n_classes 为 labels 的数量
- result = classifier.evaluate(test_data, test_labels) print result["accuracy"]
速度非常快,而且准确率达到 91.4%
可以只预测某张图,并查看预测是否跟实际图形一致
- #here 's one it gets right
- print ("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0]))
- display(0)
- # and one it gets wrong
- print ("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8]))
- display(8)'
- weights = classifier.weights_ a.imshow(weights.T[i].reshape(28, 28), cmap = plt.cm.seismic)
来源: http://www.cnblogs.com/hellocwh/p/5783249.html