文档章节

【AI实战】动手训练自己的目标检测模型(YOLO篇)

雪饼
 雪饼
发布于 2018/08/14 18:49
字数 2031
阅读 8720
收藏 10

在前面的文章中,已经介绍了基于SSD使用自己的数据训练目标检测模型(见文章:手把手教你训练自己的目标检测模型),本文将基于另一个目标检测模型YOLO,介绍如何使用自己的数据进行训练。

 
YOLO(You only look once)是目前流行的目标检测模型之一,目前最新已经发展到V3版本了,在业界的应用也很广泛。YOLO的基本原理是:首先对输入图像划分成7x7的网格,对每个网格预测2个边框,然后根据阈值去除可能性比较低的目标窗口,最后再使用边框合并的方式去除冗余窗口,得出检测结果,如下图:
 
YOLO的特点就是“快”,但由于YOLO对每个网格只预测一个物体,就容易造成漏检,对物体的尺度相对比较敏感,对于尺度变化较大的物体泛化能力较差。

本文的目标仍旧是在图像中识别检测出可爱的熊猫

基于YOLO使用自己的数据训练目标检测模型,训练过程跟前面文章所介绍的基于SSD训练模型一样,主要步骤如下:
 
1、安装标注工具
本案例采用的标注工具是labelImg,在前面的文章介绍训练SSD模型时有详细介绍了安装方法(见文章:手把手教你训练自己的目标检测模型),在此就不再赘述了。
成功安装后的labelImg标注工具,如下图:

2、标注数据
使用labelImg工具对熊猫照片进行画框标注,自动生成VOC_2007格式的xml文件,保存为训练数据集。操作方式跟前面的文章介绍训练SSD模型的标注方法一样(见文章:手把手教你训练自己的目标检测模型),在此就不再赘述了。

3、配置YOLO
(1)安装Keras
 
本案例选用YOLO的最新V3版本,基于Keras版本。Keras是一个高层神经网络API,以Tensorflow、Theano和CNTK作为后端。由于本案例的基础环境(见文章:AI基础环境搭建)已经安装了tensorflow,因此,Keras底层将会调用tensorflow跑模型。Keras安装方式如下:

# 切换虚拟环境
source activate tensorflow
# 安装keras-gpu版本
conda install keras-gpu
# 如果是安装 keras cpu版本,则执行以下指令
#conda install keras

keras版本的yolo3还依赖于PIL工具包,如果之前没安装的,也要在anaconda中安装

# 安装 PIL
conda install pillow

(2)下载yolo3源代码
在keras-yolo3的github上下载源代码(https://github.com/qqwweee/keras-yolo3),使用git进行clone或者直接下载成zip压缩文件。

(3)导入PyCharm
打开PyCharm,新建项目,将keras-yolo3的源代码导入到PyCharm中

4、下载预训练模型
YOLO官网上提供了YOLOv3模型训练好的权重文件,把它下载保存到电脑上。下载地址为https://pjreddie.com/media/files/yolov3.weights

5、训练模型
接下来到了关键的步骤:训练模型。在训练模型之前,还有几项准备工作要做。
(1)转换标注数据文件
YOLO采用的标注数据文件,每一行由文件所在路径、标注框的位置(左上角、右下角)、类别ID组成,格式为:image_file_path x_min,y_min,x_max,y_max,class_id
例子如下:
 
这种文件格式跟前面制作好的VOC_2007标注文件的格式不一样,Keras-yolo3里面提供了voc格式转yolo格式的转换脚本 voc_annotation.py
在转换格式之前,先打开voc_annotation.py文件,修改里面的classes的值。例如本案例在voc_2007中标注的熊猫的物体命名为panda,因此voc_annotation.py修改为:

import xml.etree.ElementTree as ET
from os import getcwd

sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

classes = ["panda"]

新建文件夹VOCdevkit/VOC2007,将熊猫的标注数据文件夹Annotations、ImageSets、JPEGImages放到文件夹VOCdevkit/VOC2007里面,然后执行转换脚本,代码如下:

mkdir VOCdevkit
mkdir VOCdevkit/VOC2007
mv Annotations VOCdevkit/VOC2007
mv ImageSets VOCdevkit/VOC2007
mv JPEGImages VOCdevkit/VOC2007

source activate tensorflow
python voc_annotation.py

转换后,将会自动生成yolo格式的文件,包括训练集、测试集、验证集。

(2)创建类别文件
在PyCharm导入的keras-yolo3源代码中,在model_data目录里面新建一个类别文件my_class.txt,将标注物体的类别写到里面,每行一个类别,如下:

(3)转换权重文件
将前面下载的yolo权重文件yolov3.weights转换成适合Keras的模型文件,转换代码如下:

source activate tensorflow
python convert.py -w yolov3.cfg yolov3.weights model_data/yolo_weights.h5

(4)修改训练文件的路径配置
修改train.py里面的相关路径配置,主要有:annotation_path、classes_path、weights_path

其中,train.py里面的batch_size默认是32(第57行),指每次处理时批量处理的数量,数值越大对机器的性能要求越高,因此可根据电脑的实际情况进行调高或调低

(5)训练模型
经过以上的配置后,终于全部都准备好了,执行train.py就可以开始进行训练。

训练后的模型,默认保存路径为logs/000/trained_weights_final.h5,可以根据需要进行修改,位于train.py的第85行,可修改模型保存的名称。

6、使用模型
完成模型的训练之后,调用yolo.py即可使用我们训练好的模型。
首先,修改yolo.py里面的模型路径、类别文件路径,如下:

class YOLO(object):
    _defaults = {
        "model_path": 'logs/000/trained_weights_final.h5',
        "anchors_path": 'model_data/yolo_anchors.txt',
        "classes_path": 'model_data/my_classes.txt',
        "score" : 0.3,
        "iou" : 0.45,
        "model_image_size" : (416, 416),
        "gpu_num" : 1,
    }

通过调用 YOLO类就能使用YOLO模型,为方便测试,在yolo.py最后增加以下代码,只要修改图像路径后,就能使用自己的yolo模型了

if __name__ == '__main__':
    yolo=YOLO()
    path = '/data/work/tensorflow/data/panda_test/1.jpg'
    try:
        image = Image.open(path)
    except:
        print('Open Error! Try again!')
    else:
        r_image, _ = yolo.detect_image(image)
        r_image.show()

    yolo.close_session()

执行后,可爱的熊猫就被乖乖圈出来了,呵呵

通过以上步骤,我们又学习了基于YOLO来训练自己的目标检测模型,这样在应用中可以结合实际需求,使用SSD、YOLO训练自己的数据,并从中选择出效果更好的目标检测模型。

 

关注本人公众号“大数据与人工智能Lab”(BigdataAILab),然后回复“代码”关键字可获取 完整源代码

 

推荐相关阅读

 

© 著作权归作者所有

雪饼

雪饼

粉丝 408
博文 61
码字总数 134328
作品 0
广州
私信 提问
加载中

评论(25)

石头来了
你好, 博主,能够共享下 你的数据集, 近千张的数据集制作要很长时间吧
s
shihun
请问怎么改成调用摄像头进行实时物体检测?
雪饼
雪饼 博主

引用来自“这个昵称不太火”的评论

您好,最后测试的时候path为某一张图片的路径,可以最后报错TypeError: 'JpegImageFile' object is not iterable
,path指向的那张图片只有一个熊猫,却框出来好多个框,这是怎么回事呢?

引用来自“这个昵称不太火”的评论

把r_image, _ = yolo.detect_image(image)改成r_image= yolo.detect_image(image)就好了,detect_image()函数返回一个对象
点赞
这个昵称不太火
这个昵称不太火

引用来自“这个昵称不太火”的评论

您好,最后测试的时候path为某一张图片的路径,可以最后报错TypeError: 'JpegImageFile' object is not iterable
,path指向的那张图片只有一个熊猫,却框出来好多个框,这是怎么回事呢?
把r_image, _ = yolo.detect_image(image)改成r_image= yolo.detect_image(image)就好了,detect_image()函数返回一个对象
这个昵称不太火
这个昵称不太火
您好,最后测试的时候path为某一张图片的路径,可以最后报错TypeError: 'JpegImageFile' object is not iterable
,path指向的那张图片只有一个熊猫,却框出来好多个框,这是怎么回事呢?
qizidog
qizidog

引用来自“wanguzgen”的评论

为什么我按照你的操作老是不行

引用来自“雪饼”的评论

哪里有出现问题

引用来自“御剑飞星”的评论

tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[32,52,52,256] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
   [[Node: leaky_re_lu_14/LeakyRelu = Maximum[T=DT_FLOAT, _class=["loc:@training_1/Adam/gradients/batch_normalization_14/cond/Merge_grad/cond_grad"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](leaky_re_lu_14/LeakyRelu/mul, batch_normalization_14/cond/Merge)]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

引用来自“雪饼”的评论

你电脑内存有多大,是不是不够

引用来自“qizidog”的评论

你好,雪饼,我在复现你的yolo教程(https://my.oschina.net/u/876354/blog/1927881)时遇到了和御剑飞星相同的问题。能否留下联系方式,比如qq,希望能向你单独请教一下。报错发生在train.py执行过程中,应该不是内存的问题,我用gpu和非gpu版的tf跑下来都一样。教程中有些文件的目录有点小问题,我已经根据实际情况修正过了,迫切地希望和你取得联系!谢谢。

引用来自“雪饼”的评论

你好,可尝试将train.py里面的batch_size改小,例如改为16或者8,这样会占用内存小一些,再重新训练模型
好的,非常感谢,确实是这样的,还是自己的gpu太菜了的锅
wait1ess
wait1ess
请问博主 如果在Google colab上进行 路径含有空格会报错 这时候怎么处理 就是读取lines的时候
雪饼
雪饼 博主

引用来自“逝去的记忆”的评论

请问训练图片的大小有没有要求

引用来自“逝去的记忆”的评论

我的抱错信息:
File "/home/rainbomsea/github/keras-yolo3/train.py", line 191, in
_main()
File "/home/rainbomsea/github/keras-yolo3/train.py", line 84, in _main
callbacks=[logging, checkpoint, reduce_lr, early_stopping])
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/engine/training.py", line 1418, in fit_generator
initial_epoch=initial_epoch)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/engine/training_generator.py", line 217, in fit_generator
class_weight=class_weight)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/engine/training.py", line 1217, in train_on_batch
outputs = self.train_function(ins)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/backend/tensorf

引用来自“RainbomSea”的评论

我的bitch_size已经设为8了, 还是抱错, 应该不是内存问题, 代码的input_shape 是 (416, 416) 不过我的训练图片大小是800 x 800 的 , 我不知到是不是这个问题, 还是代码已经处理类图片大小的问题
你的内存有多大,有用GPU吗,显存多大。batch_size改4,再试试,在内存很有限的情况下,有时也会设置为4
雪饼
雪饼 博主

引用来自“逝去的记忆”的评论

请问训练图片的大小有没有要求

引用来自“RainbomSea”的评论

我的抱错信息:
File "/home/rainbomsea/github/keras-yolo3/train.py", line 191, in
_main()
File "/home/rainbomsea/github/keras-yolo3/train.py", line 84, in _main
callbacks=[logging, checkpoint, reduce_lr, early_stopping])
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/engine/training.py", line 1418, in fit_generator
initial_epoch=initial_epoch)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/engine/training_generator.py", line 217, in fit_generator
class_weight=class_weight)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/engine/training.py", line 1217, in train_on_batch
outputs = self.train_function(ins)
File "/home/rainbomsea/github/keras-yolo3/venv/lib/python3.6/site-packages/keras/backend/tensorf
能否截报错的关键信息来看一下。现在附上来的只是一些路径,看不出提示什么
雪饼
雪饼 博主

引用来自“RainbomSea”的评论

请问训练图片的大小有没有要求
在模型里面会自动对图片的大小进行调整,对输入没有要求
【AI实战】快速掌握TensorFlow(二):计算图、会话

在前面的文章中,我们已经完成了AI基础环境的搭建(见文章:Ubuntu + Anaconda + TensorFlow + GPU + PyCharm搭建AI基础环境),以及初步了解了TensorFlow的特点和基本操作(见文章:快速掌握...

雪饼
2018/08/20
1K
1
【AI实战】手把手教你训练自己的目标检测模型(SSD篇)

目标检测是AI的一项重要应用,通过目标检测模型能在图像中把人、动物、汽车、飞机等目标物体检测出来,甚至还能将物体的轮廓描绘出来,就像下面这张图,是不是很酷炫呢,嘿嘿 在动手训练自己...

雪饼
2018/08/14
10.8K
25
【AI实战】快速掌握TensorFlow(三):激励函数

到现在我们已经了解了TensorFlow的特点和基本操作(见文章:快速掌握TensorFlow(一)),以及TensorFlow计算图、会话的操作(见文章:快速掌握TensorFlow(二)),接下来我们将继续学习掌握...

雪饼
2018/08/30
1K
0
【AI实战】训练第一个AI模型:MNIST手写数字识别模型

在上篇文章中,我们已经把AI的基础环境搭建好了(见文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基础环境),接下来将基于tensorflow训练第一个AI模型:MNIST手写数字识别模型。...

雪饼
2018/08/11
3.6K
0
【AI实战】手把手教你文字识别(检测篇二:AdvancedEAST、PixelLink方法)

自然场景下的文字检测是深度学习的重要应用,在之前的文章中已经介绍过了在简单场景、复杂场景下的文字检测方法,包括MSER+NMS、CTPN、SegLink、EAST等方法,详见文章: 【AI实战】手把手教你...

雪饼
06/24
3K
8

没有更多内容

加载失败,请刷新页面

加载更多

OSChina 周一乱弹 —— 年迈渔夫遭黑帮袭抢

Osc乱弹歌单(2019)请戳(这里) 【今日歌曲】 @tom_tdhzz :#今日歌曲推荐# 分享Elvis Presley的单曲《White Christmas》: 《White Christmas》- Elvis Presley 手机党少年们想听歌,请使劲...

小小编辑
56分钟前
222
11
CentOS7.6中安装使用fcitx框架

内容目录 一、为什么要使用fcitx?二、安装fcitx框架三、安装搜狗输入法 一、为什么要使用fcitx? Gnome3桌面自带的输入法框架为ibus,而在使用ibus时会时不时出现卡顿无法输入的现象。 搜狗和...

技术训练营
昨天
5
0
《Designing.Data-Intensive.Applications》笔记 四

第九章 一致性与共识 分布式系统最重要的的抽象之一是共识(consensus):让所有的节点对某件事达成一致。 最终一致性(eventual consistency)只提供较弱的保证,需要探索更高的一致性保证(stro...

丰田破产标志
昨天
8
0
docker 使用mysql

1, 进入容器 比如 myslq1 里面进行操作 docker exec -it mysql1 /bin/bash 2. 退出 容器 交互: exit 3. mysql 启动在容器里面,并且 可以本地连接mysql docker run --name mysql1 --env MY...

之渊
昨天
10
0
python数据结构

1、字符串及其方法(案例来自Python-100-Days) def main(): str1 = 'hello, world!' # 通过len函数计算字符串的长度 print(len(str1)) # 13 # 获得字符串首字母大写的...

huijue
昨天
6
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部