在进行相关平台的练习过程中, 由于要自己导入数据集, 而导入方法在市面上五花八门, 各种库都可以应用, 在这个过程中我准备尝试 torchvision 的库 dataset
torchvision.datasets.ImageFolder
简单应用起来非常简单, 用 torchvision.datasets.ImageFolder 实现图片的导入, 在随后训练过程中用 Datalodar 处理后可按批次取出训练集
class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)
ImageFolder 有这么几个参数, 其中 root 指的是数据所在的文件夹, 其中该文件夹的存储方式应为
root/labels/xxx.jpg
即根据自身分类标签存储在对应标签名的文件夹内
ImageFolder 在读入的过程中会自行加好标签, 最后形成一对对的数据
另外比较常用的就是 transform, 表示对于传入图片的预处理, 如剪裁, 颜色选择等等
比如
- transform_t = transforms.Compose([
- transforms.Resize([64, 64]),
- transforms.Grayscale(num_output_channels=1),
- transforms.ToTensor()]
- )
具体参数可以上网查看
在之后用 DataLodar 处理后虽然的确有 Shuffle 的参数, 但是却只是在一个小批次内进行打乱, 原本是按照类别存储的, 这样的话会导致很严重的过拟合, 为了避免这个, 我决定常识改写一下 Dataset 的类 (主要是看起来 Dataset 看起来改写比较顺手...ImageFolder 还没有看源码并没要对此下手)
但是 Dataset 需要读入一个个的训练数据的位置, 怎么办呢? 我就先写了一个小脚本, 生成一个 txt 文件来存储所有数据的名称 (相对路径), 同时在这一步就进行打乱操作 [一眼看下去甚至会发现 init 的 classnum 参数完全没用上 (捂脸
- import os
- import numpy as np
- '''
- self.target 顺序存储数据集
- self.DataFile 存储根目录
- self.s 存储所有数据
- self.label 存储所有标签及其对应的值
- '''
- class create_list():
- def __init__(self,root,classnum=2):
- self.target=open("./Data.txt",'w')
- self.DataFile=root
- self.s=[]
- self.label={}
- self.datanum=0
- def create(self):
- files=os.listdir(self.DataFile)
- for labels in files:
- tempdata=os.listdir(self.DataFile+"/"+labels)
- self.label[labels]=len(self.label)
- for img in tempdata:
- self.datanum+=1
- self.target.write(self.DataFile+"/"+labels+"/"+img+""+labels+"\n")
- self.s.append([self.DataFile+"/"+labels+"/"+img,labels])
- def detail(self):
- #查看数据数量以及标签对应
- print(self.datanum)
- print(self.label)
- def get_all(self):
- #查看所有数据
- print(self.s)
- def get_root(self):
- #获得根目录
- return self.DataFile
- def shuffle(self):
- #获得打乱的存储 txt
- shuffle_file=open("./Shuffle_Data.txt",'w')
- temp=self.s
- np.random.shuffle(temp)
- for i in temp:
- shuffle_file.write(i[0]+""+str(i[1])+"\n")
- return self.DataFile+"/Shuffle_Data.txt"
- def label_id(self,label):
- #获得该标签对应的值
- return self.label[label]
数据集的存储方式上的要求跟之前的 ImageFolder 一样
最终会生成一个这样的 txt 文件
数据集来源于某 x 光胸片判断...
而 Shuffle 操作就是为了生成打乱后的 txt 文件, 我写的比较简单粗暴... 先将就看吧, 生成后大概就是这个样子
至少真正的做到打乱数据了
完成这个以后, 就可以用此来帮助 DataLodar 了
接下来的代码或许比较辣眼睛... 但是事实证明是有用的, 但是可能 Python 技巧不太熟练所以就会显得很生涩...
我重现的 Dataset 类:
- from PIL import Image
- import torch
- class cDataset(torch.utils.data.Dataset):
- def __init__(self, datatxt, root="", transform=None, target_transform=None, LabelDic=None):
- super(cDataset,self).__init__()
- files = open(root + "/" + datatxt, 'r')
- self.img=[]
- for i in files:
- i = i.rstrip()
- temp = i.split()
- if LabelDic!=None:
- self.img.append((temp[0],LabelDic[temp[1]]))
- else:
- self.img.append((temp[0],temp[0]))
- self.transform = transform
- self.target_transform = target_transform
- def __getitem__(self, index):
- files, label = self.img[index]
- img = Image.open(files).convert('RGB')
- if self.transform is not None:
- img = self.transform(img)
- return img,label
- def __len__(self):
- return len(self.img)
其实直接看就能大概看明白, 主要也就是要实现类里面的几个方法
- class cDataset(torch.utils.data.Dataset):
- def __init__():
- def __getitem__(self, index):
- def __len__(self):
其中 getitm 类似一次次的取出数据, len 就是返回数据集数目
其中 init 的参数我做了稍许调整, 由于我之前的 txt 内标签是字符串, 而为了能让对应生成的 tag 是所要求的, 可以传入一个字典, 如:
LabelDic={"NORMAL":0,"PNEUMONIA":1}
这样就可以在之后转化为数字的标签, onehot 或者怎么怎么样了,,,