最近做 Kaggle 的图像分类比赛: RSNA Intracranial Hemorrhage Detection()以及阅读 Yolov3
源码的时候接触到深度学习训练时一个有趣的技巧, 那就是构造生成器 generator 并且用 keras 的 fit_generator 来批量生成数据, 释放内存, 该方法适合于大规模数据集的训练. 一个 DataGenerator 是 keras 的 Sequence 类的继承类, 一般要包含__len__,__getitem__, on_epoch_end 等方法, 例如下面的批量图片数据生成器:
- class DataGenerator(keras.utils.Sequence):
- def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512),
- img_dir, *args, **kwargs):
- """
- self.list_IDs: 存放所有需要训练的图片文件名的列表.
- self.labels: 记录图片标注的分类信息的 pandas.DataFrame 数据类型, 已经预先给定.
- self.batch_size: 每次批量生成, 训练的样本大小.
- self.img_size: 训练的图片尺寸.
- self.img_dir: 图片在电脑中存放的路径.
- """
- self.list_IDs = list_IDs
- self.labels = labels
- self.batch_size = batch_size
- self.img_size = img_size
- self.img_dir = img_dir
- self.on_epoch_end()
- def __len__(self):
- """
- 返回生成器的长度, 也就是总共分批生成数据的次数.
- """
- return int(ceil(len(self.list_IDs) / self.batch_size))
- def __getitem__(self, index):
- """
- 该函数返回每次我们需要的经过处理的数据.
- """
- indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
- list_IDs_temp = [self.list_IDs[k] for k in indices]
- X, Y = self.__data_generation(list_IDs_temp)
- return X, Y
- def on_epoch_end(self):
- """
- 该函数将在训练时每一个 epoch 结束的时候自动执行, 在这里是随机打乱索引次序以方便下一 batch 运行.
- """
- self.indices = np.arange(len(self.list_IDs))
- np.random.shuffle(self.indices)
- def __data_generation(self, list_IDs_temp):
- """
- 给定文件名, 生成数据.
- """
- X = np.empty((self.batch_size, *self.img_size, 1))
- Y = np.empty((self.batch_size, 6), dtype=np.float32)
- for i, ID in enumerate(list_IDs_temp):
- X[i,] = mpimg.imread(self.img_dir+ID+".png")
- Y[i,] = self.labels.loc[ID].values
- return X, Y
有了这个生成器, 我们就可以用 fit_generator 方法进行训练, 格式套路如下:
- model.fit_generator(generator,
- steps_per_epoch=...,
- epochs=...,
- verbose=...,
- callbacks=...,
- validation_data=...,
- validation_steps=...,
- validation_freq=...,
- class_weight=None=...,
- max_queue_size=...
- workers=...,
- use_multiprocessing=...,
- )
除此以外我们还可以搞批量预测:
model.predict_generator()
来源: http://www.bubuko.com/infodetail-3224450.html