文档章节

【AI实战】手把手教你训练自己的目标检测模型(SSD篇)

雪饼
 雪饼
发布于 2018/08/14 00:32
字数 2765
阅读 10963
收藏 17

目标检测是AI的一项重要应用,通过目标检测模型能在图像中把人、动物、汽车、飞机等目标物体检测出来,甚至还能将物体的轮廓描绘出来,就像下面这张图,是不是很酷炫呢,嘿嘿

在动手训练自己的目标检测模型之前,建议先了解一下目标检测模型的原理(见文章:大话目标检测经典模型RCNN、Fast RCNN、Faster RCNN以及Mark R-CNN),这样才会更加清楚模型的训练过程。

本文将在我们前面搭建好的AI实战基础环境上(见文章:AI基础环境搭建),基于SSD算法,介绍如何使用自己的数据训练目标检测模型。SSD,全称Single Shot MultiBox Detector(单镜头多盒检测器),是Wei Liu在ECCV 2016上提出的一种目标检测算法,是目前流行的主要检测框架之一。

本案例要做的识别便是在图像中识别出熊猫,可爱吧,呵呵
 
下面按照以下过程介绍如何使用自己的数据训练目标检测模型:
 
1、安装标注工具
要使用自己的数据来训练模型,首先得先作数据标注,也就是先要告诉机器图像里面有什么物体、物体在位置在哪里,有了这些信息后才能来训练模型。
(1)标注数据文件
目前流行的数据标注文件格式主要有VOC_2007、VOC_2012,该文本格式来源于Pascal VOC标准数据集,这是衡量图像分类识别能力的重要基准之一。本文采用VOC_2007数据格式文件,以xml格式存储,如下:

其中重要的信息有:
filename:图片的文件名
name:标注的物体名称
xmin、ymin、xmax、ymax:物体位置的左上角、右下角坐标

(2)安装标注工具
如果要标注的图像有很多,那就需要一张一张手动去计算位置信息,制作xml文件,这样的效率就太低了。
所幸,有一位大神开源了一个数据标注工具labelImg,可以通过可视化的操作界面进行画框标注,就能自动生成VOC格式的xml文件了。该工具是基于Python语言编写的,这样就支持在Windows、Linux的跨平台运行,实在是良心之作啊。安装方式如下:
a. 下载源代码
通过访问labelImg的github页面(https://github.com/tzutalin/labelImg),下载源代码。可通过git进行clone,也可以直接下载成zip压缩格式的文件。
 
在本案例中直接下载成zip文件。
b.安装编译
解压labelImg的zip文件,得到LabelImg-master文件夹。
labelImg的界面是使用PyQt编写的,由于我们搭建的基础环境使用了最新版本的anaconda已经自带了PyQt5,在python3的环境下,只需再安装lxml即可,进入LabelImg-master目录进行编译,代码如下:

#激活虚拟环境
source activate tensorflow
#在python3环境中安装PyQt5(anaconda已自带),如果是在python2环境下,则要安装PyQt4,PyQt4的安装方式如下
#conda install -c anaconda pyqt=4.11.4
#安装xml
conda install xml
#编译
make qt5py3
#打开标注工具
python3 labelImg.py

成功打开labelImg标注工具的界面如下:

2、标注数据
成功安装了标注工具后,现在就来开始标注数据了。
(1)创建文件夹
按照VOC数据集的要求,创建以下文件夹
Annotations:用于存放标注后的xml文件
ImageSets/Main:用于存放训练集、测试集、验收集的文件列表
JPEGImages:用于存放原始图像

(2)标注数据
将熊猫图片集放在JPEGImages文件夹里面(熊猫的美照请找度娘要哦~),注意图片的格式必须是jpg格式的。
打开labelImg标注工具,然后点击左侧的工具栏“Open Dir”按钮,选择刚才放熊猫的JPEGImages文件夹。这时,主界面将会自动加载第一张熊猫照片。

点击左侧工具栏的“Create RectBox”按钮,然后在主界面上点击拉个矩形框,将熊猫圈出来。圈定后,将会弹出一个对话框,用于输入标注物体的名称,输入panda作为熊猫的名称。

然后点击左侧工具栏的“Save”按钮,选择刚才创建的Annotations作为保存目录,系统将自动生成voc_2007格式的xml文件保存起来。这样就完成了一张熊猫照片的物体标注了。

接下来点击左侧工具栏的“Next Image”进入下一张图像,按照以上步骤,画框、输入名称、保存,如此反复,直到把所有照片都标注好,保存起来。

(3)划分训练集、测试集、验证集
完成所有熊猫照片的标注后,还要将数据集划分下训练集、测试集和验证集。
在github上下载一个自动划分的脚本(https://github.com/EddyGao/make_VOC2007/blob/master/make_main_txt.py)
然后执行以下代码

python make_main_txt.py

将会按照脚本里面设置的比例,自动拆分训练集、测试集和验证集,将相应的文件名列表保存在里面。

3、配置SSD
(1)下载SSD代码
由于本案例是基于tensorflow的,因此,在github上下载一个基于tensorflow的SSD,地址是 https://github.com/balancap/SSD-Tensorflow
 
以zip文件的方式下载下来,然后解压,得到SSD-Tensorflow-master文件夹
(2)转换文件格式
将voc_2007格式的文件转换为tfrecord格式,tfrecord数据文件tensorflow中的一种将图像数据和标签统一存储的二进制文件,能更加快速地在tensorflow中复制、移动、读取和存储等。
SSD-Tensorflow-master提供了转换格式的脚本,转换代码如下:

DATASET_DIR=./panda_voc2007/
OUTPUT_DIR=./panda_tfrecord/
python SSD-Tensorflow-master/tf_convert_data.py --dataset_name=pascalvoc --dataset_dir=${DATASET_DIR} --output_name=voc_2007_train --output_dir=${OUTPUT_DIR}

(3)修改物体类别
由于是我们自定义的物体,因此,要修改SSD-Tensorflow-master中关于物体类别的定义,打开SSD-Tensorflow-master/datasets/pascalvoc_common.py文件,进行修改,将VOC_LABELS中的其它无关类别全部删掉,增加panda的名称、ID、类别,如下:

VOC_LABELS = {
    'none': (0, 'Background'),
'panda': (1, 'Animal'),
}

4、下载预训练模型
SSD-Tensorflow提供了预训练好的模型,基于VGG模型(要了解VGG模型详情,请阅读文章:大话经典CNN经典模型VGG),如下表:
 
但这些预训练的模型文件都是存储在drive.google.com上,因此,无法直接下载。只能通过“你懂的”方式进行下载,在这里下载SSD-300 VGG-based预训练模型,得到文件:VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt.zip,然后进行解压

5、训练模型
终于把标注文件、SSD模型都准备好了,现在准备开始来训练了。
在训练模型之前,有个参数要修改下,打开SSD-Tensorflow-master/train_ssd_network.py找到里面的DATA_FORMAT参数项,如果是使用cpu训练则值为NHWC,如果是使用gpu训练则值为NCHW,如下:

DATA_FORMAT = 'NCHW'  # gpu
# DATA_FORMAT = 'NHWC'    # cpu

现在终于可以开始来训练了,打开终端,切换conda虚拟环境

source activate tensorflow

然后执行以下命令,开始训练

# 使用预训练好的 vgg_ssd_300 模型 
DATASET_DIR=./ panda_tfrecord
TRAIN_DIR=./panda_model
CHECKPOINT_PATH=./model_pre_train/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt
python3 SSD-Tensorflow-master/train_ssd_network.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=pascalvoc_2007 \
    --dataset_split_name=train \
    --model_name=ssd_300_vgg \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --save_summaries_secs=60 \
    --save_interval_secs=600 \
    --weight_decay=0.0005 \
    --optimizer=adam \
    --learning_rate=0.0001 \
    --batch_size=16

其中,根据自己电脑的性能情况,设置batch_size的值,值越大表示批量处理的数量越大,对机器性能的要求越高。如果电脑性能普通的,则可以设置为8,甚至4,土豪请忽略。
学习率learning_rate也可以根据实际情况调整,学习率越小则越精确,训练的时间也越长,学习率越大则可缩短训练时间,但就会降低精准度。

在这里使用预训练好的模型,SSD将会锁定VGG模型的一些参数进行训练,这样能在较短的时间内完成训练。

6、使用模型
SSD模型训练好了,现在要来使用了,使用的方式也很简单。
SSD-Tensorflow-master自带了一个notebooks脚本,可通过jupyter直接使用模型。
先安装jupyter,安装方式如下:

conda install jupyter

然后启动jupyter-notebook,代码如下:

jupyter-notebook SSD-Tensorflow-master/notebooks/ssd_notebook.ipynb

启动后在SSD 300 Model的代码块设置模型的路径和名称

然后在最后的代码块中,设置要测试的图像路径path

然后点击菜单“Cell”,点击子菜单“Run All”,便能按顺序全部执行代码,并显示出结果出来

执行后,可爱的熊猫就被圈出来了

经过以上步骤,我们便使用了自己的数据完成了目标检测模型的训练。只要以后还有物体检测的需求,然后找相关的图片集进行标注,标注后进行模型训练,就能完成一个定制化的目标检测模型了,非常方便,希望本案例对大家能有所帮助。

 

关注本人公众号“大数据与人工智能Lab”(BigdataAILab),然后回复“代码”关键字可获取 完整源代码。

 

推荐相关阅读

 

© 著作权归作者所有

雪饼

雪饼

粉丝 410
博文 61
码字总数 134328
作品 0
广州
私信 提问
加载中

评论(25)

这个昵称不太火
这个昵称不太火
您好,训练完测试的时候一张图片胡乱框出了许多框是什么原因呢?
雪饼
雪饼 博主

引用来自“vitalik”的评论

是不是跟检测的目标大小有关系?我需要检测的目标大小时 60*60 px 大小,这个大小是不是太小了,使用这个预训练模型无法检测到?

引用来自“雪饼”的评论

应该不是大小问题。你的图片在matlab转换后,大小有变化吗,因为原先标注后的内容是带有位置信息的,另外,训练的次数、时长足够吗,如果训练不够,也会导致无法识别

引用来自“vitalik”的评论

分辨率没有变化,体积变小了。我的训练集比较小,只有40张图纸,检测的目标也不复杂,是图纸上的标注符号,在每张图纸上的形态基本一致,通常是有旋转,因为感觉检测的目标比较简单,条件限制使用的是cpu训练,所以训练的步数只有1500步,请指导下问题可能出在哪里?
图片太少
vitalik
vitalik

引用来自“vitalik”的评论

是不是跟检测的目标大小有关系?我需要检测的目标大小时 60*60 px 大小,这个大小是不是太小了,使用这个预训练模型无法检测到?

引用来自“雪饼”的评论

应该不是大小问题。你的图片在matlab转换后,大小有变化吗,因为原先标注后的内容是带有位置信息的,另外,训练的次数、时长足够吗,如果训练不够,也会导致无法识别
分辨率没有变化,体积变小了。我的训练集比较小,只有40张图纸,检测的目标也不复杂,是图纸上的标注符号,在每张图纸上的形态基本一致,通常是有旋转,因为感觉检测的目标比较简单,条件限制使用的是cpu训练,所以训练的步数只有1500步,请指导下问题可能出在哪里?
雪饼
雪饼 博主

引用来自“vitalik”的评论

是不是跟检测的目标大小有关系?我需要检测的目标大小时 60*60 px 大小,这个大小是不是太小了,使用这个预训练模型无法检测到?
应该不是大小问题。你的图片在matlab转换后,大小有变化吗,因为原先标注后的内容是带有位置信息的,另外,训练的次数、时长足够吗,如果训练不够,也会导致无法识别
vitalik
vitalik
是不是跟检测的目标大小有关系?我需要检测的目标大小时 60*60 px 大小,这个大小是不是太小了,使用这个预训练模型无法检测到?
vitalik
vitalik
请问博主,我之前因为黑白图像训练出现的问题已经解决了,我现在使用matlab先把黑白单通道图像转换成三通道,重新训练,新的问题是训练的模型使用ssd_notebook.ipynb测试时无法识别目标,无框无数,请指导问题原因,感谢。
雪饼
雪饼 博主

引用来自“wait1ess”的评论

请问博主,yolo_video.py应该怎么用 如果我想简单点直接使用yolo3的权重不训练的话
官网上有介绍,python yolo_video.py [video_path] [output_path (optional)]
wait1ess
wait1ess
请问博主,yolo_video.py应该怎么用 如果我想简单点直接使用yolo3的权重不训练的话
雪饼
雪饼 博主

引用来自“vitalik”的评论

博主你好,我在最后进行使用模型时出现这个错误,请问怎么解决
<img alt="" height="658" src="https://oscimg.oschina.net/oscnet/93e043e3f43d9e90893a69345d556312842.jpg" width="659" />
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1126 'which has shape %r' %
1127 (np_val.shape, subfeed_t.name,
-> 1128 str(subfeed_t.get_shape())))
1129 if not self.graph.is_feedable(subfeed_t):
1130 raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (1654, 2344) for Tensor 'Placeholder_3:0', which has shape '(?, ?, 3)'

引用来自“雪饼”的评论

你的图片是彩色的吗,不要转灰度,直接输入RGB的彩色图片

引用来自“vitalik”的评论

谢谢指导问题原因,我找到解决方法了🙏
👍
vitalik
vitalik

引用来自“vitalik”的评论

博主你好,我在最后进行使用模型时出现这个错误,请问怎么解决
<img alt="" height="658" src="https://oscimg.oschina.net/oscnet/93e043e3f43d9e90893a69345d556312842.jpg" width="659" />
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1126 'which has shape %r' %
1127 (np_val.shape, subfeed_t.name,
-> 1128 str(subfeed_t.get_shape())))
1129 if not self.graph.is_feedable(subfeed_t):
1130 raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (1654, 2344) for Tensor 'Placeholder_3:0', which has shape '(?, ?, 3)'

引用来自“雪饼”的评论

你的图片是彩色的吗,不要转灰度,直接输入RGB的彩色图片
谢谢指导问题原因,我找到解决方法了🙏
【AI实战】手把手教你文字识别(检测篇二:AdvancedEAST、PixelLink方法)

自然场景下的文字检测是深度学习的重要应用,在之前的文章中已经介绍过了在简单场景、复杂场景下的文字检测方法,包括MSER+NMS、CTPN、SegLink、EAST等方法,详见文章: 【AI实战】手把手教你...

雪饼
06/24
3.1K
8
【AI实战】手把手教你文字识别(识别篇:LSTM+CTC, CRNN, chineseocr方法)

文字识别是AI的一个重要应用场景,文字识别过程一般由图像输入、预处理、文本检测、文本识别、结果输出等环节组成。 其中,文本检测、文本识别是最核心的环节。文本检测方面,在前面的文章中...

雪饼
07/07
5.4K
8
【图解AI:动图】各种类型的卷积,你认全了吗?

卷积(convolution)是深度学习中非常有用的计算操作,主要用于提取图像的特征。在近几年来深度学习快速发展的过程中,卷积从标准卷积演变出了反卷积、可分离卷积、分组卷积等各种类型,以适...

雪饼
06/20
349
0
【AI实战】手把手教你深度学习文字识别(文字检测篇:基于MSER, CTPN, SegLink, EAST等方法)

文字检测是文字识别过程中的一个非常重要的环节,文字检测的主要目标是将图片中的文字区域位置检测出来,以便于进行后面的文字识别,只有找到了文本所在区域,才能对其内容进行识别。 文字检...

雪饼
05/27
3.5K
6
【AI实战】训练第一个AI模型:MNIST手写数字识别模型

在上篇文章中,我们已经把AI的基础环境搭建好了(见文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基础环境),接下来将基于tensorflow训练第一个AI模型:MNIST手写数字识别模型。...

雪饼
2018/08/11
3.7K
0

没有更多内容

加载失败,请刷新页面

加载更多

哪些情况下适合使用云服务器?

我们一直在说云服务器价格适中,具备弹性扩展机制,适合部署中小规模的网站或应用。那么云服务器到底适用于哪些情况呢?如果您需要经常原始计算能力,那么使用独立服务器就能满足需求,因为他...

云漫网络Ruan
今天
5
0
Java 中的 String 有没有长度限制

转载: https://juejin.im/post/5d53653f5188257315539f9a String是Java中很重要的一个数据类型,除了基本数据类型以外,String是被使用的最广泛的了,但是,关于String,其实还是有很多东西...

低至一折起
今天
17
0
OpenStack 简介和几种安装方式总结

OpenStack :是一个由NASA和Rackspace合作研发并发起的,以Apache许可证授权的自由软件和开放源代码项目。项目目标是提供实施简单、可大规模扩展、丰富、标准统一的云计算管理平台。OpenSta...

小海bug
昨天
11
0
DDD(五)

1、引言 之前学习了解了DDD中实体这一概念,那么接下来需要了解的就是值对象、唯一标识。值对象,值就是数字1、2、3,字符串“1”,“2”,“3”,值时对象的特征,对象是一个事物的具体描述...

MrYuZixian
昨天
9
0
解决Mac下VSCode打开zsh乱码

1.乱码问题 iTerm2终端使用Zsh,并且配置Zsh主题,该主题主题需要安装字体来支持箭头效果,在iTerm2中设置这个字体,但是VSCode里这个箭头还是显示乱码。 iTerm2展示如下: VSCode展示如下: 2...

HelloDeveloper
昨天
9
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部