http://www.cnblogs.com/zhiyishou/p/5651321.html
具体代码见
这是我对 cup, glasses 训练的识别
faster-rcnn 在 fast-rcnn 的基础上加了 rpn 来将整个训练都置于 GPU 内,以用来提高效率,这里我们将使用 ImageNet 的数据集来在 faster-rcnn 上来训练自己的分类器.从上可下载到很多类别的 Image 与 bounding box annotation 来进行训练(每一个类别下的 annotation 都少于等于 image 的个数,所以我们从 annotation 来建立索引).
在
lib/dataset/factory.py
中提供了 coco 与 voc 的数据集获取方法,而我们要做的就是在这里加上我们自己的 ImageNet 获取方法,我们先来建立 ImageNet 数据获取主文件.coco 与 pascal_voc 的获取都是继承于父类 imdb,所以我们可根据 pascal_voc 的获取方法来做模板修改完成我们的 ImageNet 类.
创建 ImageNet 类
由于在 faster-rcnn 里使用 rpn 来代替了 selective_search,所以我们可以在使用时直接略过有关 selective_search 的方法,根据 pascal_voc 类做模板,我们需要留下的方法有:
__init__ //初始化
image_path_at //根据数据集列表的index来取图片绝对地址
image_path_from_index //配合上面
_load_image_set_index //获取数据集列表
_gt_roidb //获取ground-truth数据
rpn_roidb //获取region proposal数据
_load_rpn_roidb //根据gt_roidb生成rpn_roidb数据并合成
_load_psacal_annotation //加载annotation文件并对bounding box进行数据整理
__init__: def __init__(self, image_set) : imdb.__init__(self, 'imagenet') self._image_set = image_set self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")#类别与对应的wnid,
可以修改成自己要训练的类别self._class_wnids = {
'cup': 'n03147509',
'glasses': 'n04272054'
}#类别,
修改类别时同时要修改这里self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses']) self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))#bounding box annotation文件的目录self._xml_path = os.path.join(self._data_path, "Annotations") self._image_ext = '.JPEG'#我们使用xml文件名来做数据集的索引#the xml file name and each one corresponding to image file name self._image_index = self._load_xml_filenames() self._salt = str(uuid.uuid4()) self._comp_id = 'comp4'self.config = {
'cleanup': True,
'use_salt': True,
'use_diff': False,
'matlab_eval': False,
'rpn_file': None,
'min_size': 2
}
assert os.path.exists(self._data_path),
\'Path does not exist: {}'.format(self._data_path) image_path_at def image_path_at(self, i) : #使用index来从xml_filenames取到filename,
生成绝对路径
return self.image_path_from_image_filename(self._image_index[i])
image_path_from_image_filename(类似 pascal_voc 中的 image_path_from_index)
def image_path_from_image_filename(self, image_filename) : image_path = os.path.join(self._data_path, 'Images', image_filename + self._image_ext) assert os.path.exists(image_path),
\'Path does not exist: {}'.format(image_path) return image_path
_load_xml_filenames(类似 pascal_voc 中的_load_image_set_index)
def _load_xml_filenames(self) : #从Annotations文件夹中拿取到bounding box annotation文件名#用来做数据集的索引xml_folder_path = os.path.join(self._data_path, "Annotations") assert os.path.exists(xml_folder_path),
\'Path does not exist: {}'.format(xml_folder_path) for dirpath,
dirnames,
filenames in os.walk(xml_folder_path) : xml_filenames = [xml_filename.split(".")[0]
for xml_filename in filenames]
return xml_filenames
gt_roidb
def gt_roidb(self) : #Ground - Truth数据缓存cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') if os.path.exists(cache_file) : with open(cache_file, 'rb') as fid: roidb = cPickle.load(fid) print '{} gt roidb loaded from {}'.format(self.name, cache_file) return roidb#从xml中获取Ground - Truth数据gt_roidb = [self._load_imagenet_annotation(xml_filename) for xml_filename in self._image_index] with open(cache_file, 'wb') as fid: cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) print 'wrote gt roidb to {}'.format(cache_file) return gt_roidb
rpn_roidb
def rpn_roidb(self) : #根据gt_roidb生成rpn_roidb,并进行合并gt_roidb = self.gt_roidb() rpn_roidb = self._load_rpn_roidb(gt_roidb) roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb) return roidb
_load_rpn_roidb
def _load_rpn_roidb(self, gt_roidb) : filename = self.config['rpn_file'] print 'loading {}'.format(filename) assert os.path.exists(filename),
\'rpn data not found at: {}'.format(filename) with open(filename, 'rb') as f: box_list = cPickle.load(f) return self.create_roidb_from_box_list(box_list, gt_roidb)
_load_imagenet_annotation(类似于 pascal_voc 中的_load_pascal_annotation)
def _load_imagenet_annotation(self, xml_filename) : #从annotation的xml文件中拿取bounding box数据filepath = os.path.join(self._data_path, 'Annotations', xml_filename + '.xml')#这里使用了ap,是我写的一个annotation parser,在后面贴出代码#它会返回这个xml文件的wnid,
图像文件名,以及里面包含的注解物体wnid,
image_name,
objects = ap.parse(filepath) num_objs = len(objects) boxes = np.zeros((num_objs, 4), dtype = np.uint16) gt_classes = np.zeros((num_objs), dtype = np.int32) overlaps = np.zeros((num_objs, self.num_classes), dtype = np.float32) seg_areas = np.zeros((num_objs), dtype = np.float32)#Load object bounding boxes into a data frame.
for ix,
obj in enumerate(objects) : box = obj["box"] x1 = box['xmin'] y1 = box['ymin'] x2 = box['xmax'] y2 = box['ymax']#如果这个bounding box并不是我们想要学习的类别,那则跳过#go next
if the wnid not exist in declared classes
try: cls = self._class_to_ind[obj["wnid"]] except KeyError: print "wnid %s isn't show in given" % obj["wnid"]
continue boxes[ix, :] = [x1, y1, x2, y2] gt_classes[ix] = cls overlaps[ix, cls] = 1.0 seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) overlaps = scipy.sparse.csr_matrix(overlaps) return {
'boxes': boxes,
'gt_classes': gt_classes,
'gt_overlaps': overlaps,
'flipped': False,
'seg_areas': seg_areas
}
annotation_parser.py 文件
import os import xml.dom.minidom def getText(node) : return node.firstChild.nodeValue def getWnid(node) : return getText(node.getElementsByTagName("name")[0]) def getImageName(node) : return getText(node.getElementsByTagName("filename")[0]) def getObjects(node) : objects = []
for obj in node.getElementsByTagName("object") : objects.append({
"wnid": getText(obj.getElementsByTagName("name")[0]),
"box": {
"xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
"ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
"xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
"ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
}
}) return objects def parse(filepath) : dom = xml.dom.minidom.parse(filepath) root = dom.documentElement image_name = getImageName(root) wnid = getWnid(root) objects = getObjects(root) return wnid,
image_name,
objects
则对数据结构的要求是:
| ---data | ---imagenet | ---Annotations | ---n03147509 | ---n03147509_ * .xml | ---... | ---n04272054 | ---n04272054_ * .xml | ---... | ---Images | ---n03147508_ * .JPEG | ---... | ---n04272054_ * .JPEG | ---...
同时我在 github 上也提供了 draw 方法,可以用来将 bounding box 画于 Image 文件上,用来甄别该 annotation 的正确性
训练
这样,我们的 ImageNet 类则是生成好了,下面我们则可以训练我们的数据,但是在开始之前,还有一件事情,那就是修改 prototxt 中的与类别数目有关的值,我将
models/pascal_voc
拷贝到了
models/imagenet
进行修改,比如我想要训练 ZF,如果使用的是 train_faster_rcnn_alt_opt.py,则需要修改
models/imagenet/ZF/faster_rcnn_alt_opt/
下的所有 pt 文件里的内容,用如下的法则去替换:
//num为类别的个数
input - data - >num_classes = num class_score - >num_output = num bbox_pred - >num_output = num * 4
我这里使用 train_faster_rcnn_alt_opt.py 进行的训练,这样的话则需要把添加的
models/imagenet
作为可选项
//pt_type 则是添加的选择项,默认使用psacal_voc的models
. / tools / train_faster_rcnn_alt_opt.py--gpu 0\--net_name ZF\--weights data / imagenet_models / ZF.v2.caffemodel[optional]\--imdb imagenet\--cfg experiments / cfgs / faster_rcnn_alt_opt.yml\--pt_type imagenet
识别
这里我们则需要使用刚训练出来的模型进行识别
#就像demo.py一样,但是使用训练的models,我创建了tools / classify.py来单独识别prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt') caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/' + NETS[args.demo_net][0] + '_faster_rcnn_final.caffemodel')
同样,在识别前我们要对识别方法里的 Classes 进行修改,修改成你自己训练的类别后
执行
. / tools / classify.py--net zf
则可对
data/demo
下的图片文件使用训练的 zf 网络进行识别
Have fun
来源: