/// <summary>
/// Loads the MNIST training and test sets (downloading the archives via
/// <c>Helper.MaybeDownload</c> when needed), moves the first <paramref name="validationSize"/>
/// training examples into a separate validation set, and optionally one-hot encodes the labels.
/// </summary>
/// <param name="trainDir">Folder containing the dataset archive files.</param>
/// <param name="numClasses">
/// Number of classes used for the one-hot label arrays; pass -1 to skip one-hot encoding.
/// </param>
/// <param name="validationSize">
/// How many leading training examples to hold out for validation (0 = no validation set).
/// </param>
/// <exception cref="System.ArgumentOutOfRangeException">
/// <paramref name="validationSize"/> is negative or larger than the number of training images.
/// </exception>
public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000)
{
    // Fail fast on a negative size before doing any download/decoding work.
    if (validationSize < 0)
    {
        throw new System.ArgumentOutOfRangeException(
            nameof(validationSize),
            validationSize,
            "Validation size must not be negative.");
    }

    const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
    const string TrainImagesName = "train-images-idx3-ubyte.gz";
    const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
    const string TestImagesName = "t10k-images-idx3-ubyte.gz";
    const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";

    // Fetch (if necessary) and decode the training and test data.
    TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
    TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
    TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
    TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);

    // The upper bound can only be checked once the training set size is known.
    if (validationSize > TrainImages.Length)
    {
        throw new System.ArgumentOutOfRangeException(
            nameof(validationSize),
            validationSize,
            $"Validation size should be between 0 and {TrainImages.Length}.");
    }

    // Split with Take/Skip instead of Pick: Pick treats last == 0 as "to the end",
    // so Pick(x, 0, 0) would put the whole training set into validation when validationSize is 0.
    ValidationImages = TrainImages.Take(validationSize).ToArray();
    ValidationLabels = TrainLabels.Take(validationSize).ToArray();
    // Everything after the validation prefix stays as training data.
    TrainImages = TrainImages.Skip(validationSize).ToArray();
    TrainLabels = TrainLabels.Skip(validationSize).ToArray();

    // Convert numeric labels to one-hot vectors, e.g.
    // label 3 = [0,0,0,1,0,0,0,0,0,0], label 0 = [1,0,0,0,0,0,0,0,0,0].
    if (numClasses != -1)
    {
        OneHotTrainLabels = OneHot(TrainLabels, numClasses);
        OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
        OneHotTestLabels = OneHot(TestLabels, numClasses);
    }
}
/// <summary>
/// Copies the elements of <paramref name="source"/> in the half-open index range
/// [<paramref name="first"/>, <paramref name="last"/>) into a new array.
/// </summary>
/// <typeparam name="T">Element type of the array.</typeparam>
/// <param name="source">Array to take the slice from.</param>
/// <param name="first">Zero-based index of the first element to include.</param>
/// <param name="last">
/// Exclusive end index. The value 0 is a sentinel meaning "through the end of
/// <paramref name="source"/>", so an empty slice starting at index 0 cannot be requested.
/// </param>
/// <returns>
/// A new array. Out-of-range arguments do not throw; they follow LINQ Skip/Take semantics
/// (an inverted range yields an empty array, an end past the length is truncated).
/// </returns>
T[] Pick<T>(T[] source, int first, int last)
{
    int end = last != 0 ? last : source.Length;
    int length = end - first;
    return source.Skip(first).Take(length).ToArray();
}
/// <summary>
/// Creates a <see cref="Mnist"/> instance and loads the data sets into it via <c>ReadDataSets</c>.
/// </summary>
/// <param name="trainDir">
/// Folder containing the dataset files. When null or empty, a hard-coded path from the
/// original author's development machine is used, so callers on other machines should pass
/// their own folder.
/// </param>
/// <returns>The populated instance.</returns>
public static Mnist Load(string trainDir = null)
{
    // The published listing spaced out the CJK path segments ("D:\ 人工智能 \ C# 代码 \ ..."),
    // apparently a text auto-spacing artifact that points at different folders; the spaces are
    // removed here. NOTE(review): confirm against the real folder name on that machine.
    const string DefaultTrainDir = @"D:\人工智能\C#代码\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data";

    var mnist = new Mnist();
    mnist.ReadDataSets(string.IsNullOrEmpty(trainDir) ? DefaultTrainDir : trainDir);
    return mnist;
}
来源: https://www.cnblogs.com/haoyifei/p/9028235.html