关于如何将数据集封装为 Bunch, 可参考《关于 AI 专属数据库的定制的改进》: https://www.jianshu.com/p/29066e70ea5e .
PyTables http://www.pytables.org/ 是 Python 与 HDF5 数据库 / 文件标准 http://www.hdfgroup.org/ 的结合. 它专为优化 I/O 性能、最大限度地利用可用硬件而设计, 并且支持数据压缩.
下面的代码均在 Jupyter Notebook 中完成:
- import sys
- sys.path.append('E:/xinlib')
- from base.filez import DataBunch
- import tables as tb
- import numpy as np
def bunch2hdf5(root):
    """Pack the datasets held by ``DataBunch(root)`` into a single HDF5 file.

    Only Cifar10, Cifar100, MNIST and Fashion MNIST are packed here; users can
    append further datasets themselves.

    Layout of the written file ``X.h5c`` (placed in ``root``):

    * one group per dataset key, titled with that dataset's ``url``;
    * ``cifar100``: ``trainX``, ``testX`` and the coarse/fine train/test label
      arrays;
    * every other dataset: ``trainX``, ``trainY``, ``testX``, ``testY``;
    * ``cifar10`` / ``cifar100`` additionally get one string array per
      ``meta`` entry whose (decoded) key ends with ``'names'``.

    Args:
        root: Path prefix/directory handed to ``DataBunch``; the output file is
            written into it as well (with or without a trailing slash).
    """
    import os  # local import keeps this notebook cell self-contained

    db = DataBunch(root)
    filters = tb.Filters(complevel=7, shuffle=False)
    # (node name, title) pairs for each dataset layout; titles are stored in
    # the file verbatim.
    default_arrays = (
        ('trainX', '训练数据'),
        ('trainY', '训练标签'),
        ('testX', '测试数据'),
        ('testY', '测试标签'),
    )
    cifar100_arrays = (
        ('trainX', '训练数据'),
        ('testX', '测试数据'),
        ('train_coarse_labels', '超类训练标签'),
        ('test_coarse_labels', '超类测试标签'),
        ('train_fine_labels', '子类训练标签'),
        ('test_fine_labels', '子类测试标签'),
    )
    dataset_keys = list(db.keys())
    # The file uses a compressed filter, hence the `.h5c` suffix (`.h5` would
    # work just as well). os.path.join no longer requires `root` to end in '/'.
    out_path = os.path.join(root, 'X.h5c')
    with tb.open_file(out_path, 'w', filters=filters, title='Xinet\'s dataset') as h5:
        for name in dataset_keys:
            bunch = db[name]
            group = h5.create_group('/', name, title=f'{bunch.url}')
            arrays = cifar100_arrays if name == 'cifar100' else default_arrays
            for node_name, title in arrays:
                h5.create_array(group, node_name, getattr(bunch, node_name), title=title)
        # Label names come from the `meta` dict of the CIFAR bunches only; skip
        # a CIFAR dataset that is not present instead of raising KeyError.
        for k in ('cifar10', 'cifar100'):
            if k not in dataset_keys:
                continue
            meta = db[k].meta
            for raw_key in meta.keys():
                # meta keys and label names are bytes; decode them to str.
                key = raw_key.decode()
                if key.endswith('names'):
                    label_names = np.asanyarray([label.decode() for label in meta[raw_key]])
                    h5.create_array(h5.root[k], key, label_names, title='标签名称')
完成 Bunch 到 HDF5 的转换
import os

root = 'E:/Data/Zip/'
bunch2hdf5(root)
# Re-open the file bunch2hdf5 just wrote (read-only by default), deriving its
# path from `root` instead of repeating the literal.
h5c = tb.open_file(os.path.join(root, 'X.h5c'))
h5c
- File(filename=E:/Data/Zip/X.h5c, title="Xinet's dataset", mode='r', root_uep='/', filters=Filters(complevel=7, complib='zlib', shuffle=False, bitshuffle=False, fletcher32=False, least_significant_digit=None))
- / (RootGroup) "Xinet's dataset"
- /cifar10 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
- /cifar10/label_names (Array(10,)) '标签名称'
- atom := StringAtom(itemsize=10, shape=(), dflt=b'')
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar10/testX (Array(10000, 32, 32, 3)) '测试数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar10/testY (Array(10000,)) '测试标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /cifar10/trainX (Array(50000, 32, 32, 3)) '训练数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar10/trainY (Array(50000,)) '训练标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
- /cifar100/coarse_label_names (Array(20,)) '标签名称'
- atom := StringAtom(itemsize=30, shape=(), dflt=b'')
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar100/fine_label_names (Array(100,)) '标签名称'
- atom := StringAtom(itemsize=13, shape=(), dflt=b'')
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar100/testX (Array(10000, 32, 32, 3)) '测试数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar100/test_coarse_labels (Array(10000,)) '超类测试标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /cifar100/test_fine_labels (Array(10000,)) '子类测试标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /cifar100/trainX (Array(50000, 32, 32, 3)) '训练数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /cifar100/train_coarse_labels (Array(50000,)) '超类训练标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /cifar100/train_fine_labels (Array(50000,)) '子类训练标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /fashion_mnist (Group) 'https://github.com/zalandoresearch/fashion-mnist'
- /fashion_mnist/testX (Array(10000, 28, 28, 1)) '测试数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /fashion_mnist/testY (Array(10000,)) '测试标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /fashion_mnist/trainX (Array(60000, 28, 28, 1)) '训练数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /fashion_mnist/trainY (Array(60000,)) '训练标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /mnist (Group) 'http://yann.lecun.com/exdb/mnist'
- /mnist/testX (Array(10000, 28, 28, 1)) '测试数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /mnist/testY (Array(10000,)) '测试标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
- /mnist/trainX (Array(60000, 28, 28, 1)) '训练数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
- /mnist/trainY (Array(60000,)) '训练标签'
- atom := Int32Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'little'
- chunkshape := None
从上面的结构可以看出, 我将 Cifar10, Cifar100, MNIST, Fashion MNIST 进行了封装, 并且还附带了它们各自的数据集信息, 比如标签名、数据特征 (以数组的形式进行封装) 等.
- %%time
- arr = h5c.root.cifar100.trainX.read() # 读取数据十分快速
- Wall time: 125 ms
- arr.shape
- (50000, 32, 32, 3)
- h5c.root
- / (RootGroup) "Xinet's dataset" children := ['cifar10'(Group),'cifar100'(Group),'fashion_mnist'(Group),'mnist' (Group)]
X.h5c 使用说明
下面我们以 Cifar100 为例, 展示我们自创的数据集 X.h5c 的用法. 我已将其上传到百度网盘 (链接: https://pan.baidu.com/s/1nzaicwHmFZH9Xgf2foSw6Q 密码: bl2e), 可以直接下载使用; 你也可以自己生成. 我更推荐自己生成, 这样可以加深对数据集的理解.
- cifar100 = h5c.root.cifar100
- cifar100
- /cifar100 (Group) 'https://www.cs.toronto.edu/~kriz/cifar.html'
- children := ['coarse_label_names' (Array), 'fine_label_names' (Array), 'testX' (Array), 'test_coarse_labels' (Array), 'test_fine_labels' (Array), 'trainX' (Array), 'train_coarse_labels' (Array), 'train_fine_labels' (Array)]
'coarse_label_names' 指的是粗粒度 (超类) 标签名, 'fine_label_names' 则是细粒度 (子类) 标签名.
可以使用 read() 方法读取节点的全部数据, 也可以通过索引 (如切片 [:]) 读取.
- coarse_label_names = cifar100.coarse_label_names[:]
- # 或者
- coarse_label_names = cifar100.coarse_label_names.read()
- coarse_label_names.astype('str')
- array(['aquatic_mammals', 'fish', 'flowers', 'food_containers',
- 'fruit_and_vegetables', 'household_electrical_devices',
- 'household_furniture', 'insects', 'large_carnivores',
- 'large_man-made_outdoor_things', 'large_natural_outdoor_scenes',
- 'large_omnivores_and_herbivores', 'medium_mammals',
- 'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals',
- 'trees', 'vehicles_1', 'vehicles_2'], dtype='<U30')
- fine_label_names = cifar100.fine_label_names[:].astype('str')
- fine_label_names
- array(['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
- 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
- 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
- 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch',
- 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
- 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
- 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
- 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
- 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter',
- 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate',
- 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road',
- 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',
- 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar',
- 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
- 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
- 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
- 'worm'], dtype='<U13')
'testX' 与 'trainX' 分别代表测试数据和训练数据, 其他节点的含义与此类似 (例如 'train_coarse_labels' 是训练数据的超类标签).
trainX = cifar100.trainX
train_coarse_labels = cifar100.train_coarse_labels
# The assignments above only bind PyTables nodes and display nothing; slice the
# label node to read it into a numpy array so the cell shows the labels.
train_coarse_labels[:]
- array([11, 15, 4, ..., 8, 7, 1])
# Load the whole training set into memory, then report one sample's shape and
# the element dtype (one value per output line).
train_data = trainX[:]
print(train_data[0].shape, train_data.dtype, sep='\n')
- (32, 32, 3)
- uint8
# Iterating the node yields one row (a single image) at a time; stop after one.
for x in cifar100.trainX:
    # Widen before arithmetic: the images are uint8, so `x * 2` would silently
    # wrap around (mod 256) for pixel values above 127.
    y = x.astype(np.int32) * 2
    break
print(y.shape)
- (32, 32, 3)
- h5c.get_node(h5c.root.cifar100, 'trainX')
- /cifar100/trainX (Array(50000, 32, 32, 3)) '训练数据'
- atom := UInt8Atom(shape=(), dflt=0)
- maindim := 0
- flavor := 'numpy'
- byteorder := 'irrelevant'
- chunkshape := None
# Node handles (data rows and their coarse labels) fed to data_iter in the
# batching example.
trainX, train_coarse_labels = cifar100.trainX, cifar100.train_coarse_labels
def data_iter(X, Y, batch_size, shuffle=None):
    """Yield ``(x_batch, y_batch)`` mini-batches from row-aligned X and Y.

    Rows are fetched one index at a time, so for an on-disk node only the rows
    of the current batch are read. Plain numpy arrays work as well.

    Args:
        X: Data rows; anything supporting ``len()`` and integer indexing along
            its first axis (e.g. a PyTables ``Array`` node or a numpy array).
        Y: Labels aligned with ``X`` row by row.
        batch_size: Rows per batch; the final batch may be smaller.
        shuffle: Visit rows in random order (``np.random``) when true. ``None``
            (default) keeps the original rule: shuffle only when ``X.name``
            starts with ``'train'``.

    Yields:
        Tuples of numpy arrays ``(x_batch, y_batch)`` stacked along axis 0.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``X`` and ``Y`` differ in length
            (raised when iteration starts).
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    n = len(X)
    if len(Y) != n:
        raise ValueError(f'X has {n} rows but Y has {len(Y)}')
    if shuffle is None:
        # Original heuristic: only training nodes (named 'train...') are shuffled.
        shuffle = getattr(X, 'name', '').startswith('train')
    idx = np.arange(n)
    if shuffle:
        np.random.shuffle(idx)
    for start in range(0, n, batch_size):
        batch = idx[start:start + batch_size].tolist()
        # Index row by row rather than np.take(X, batch, 0): np.take converts
        # an object lacking a `.take` method with np.asarray first, which for
        # an on-disk node (NOTE(review): believed true for PyTables Array) can
        # read the entire dataset on every batch.
        yield (np.stack([X[i] for i in batch]),
               np.stack([Y[i] for i in batch]))
# Quick look at the first mini-batch only: its image-array shape and labels.
for x, y in data_iter(trainX, train_coarse_labels, batch_size=8):
    print(x.shape, y)
    break
- (8, 32, 32, 3) [ 7 7 0 15 4 8 8 3]
来源: https://www.cnblogs.com/q735613050/p/9244223.html