文档章节

验证码cnn模型

pseudo
 pseudo
发布于 2017/06/25 09:26
字数 538
阅读 100
收藏 0
点赞 0
评论 2
"""基于切割的识别模型"""

import numpy as np
from PIL import Image
from keras import backend as K
from keras import layers
from keras.models import Sequential

from . import img_util

K.set_image_dim_ordering('tf')

img_width, img_height = 50, 70
chars = "23456789abcdefghiklmnpqrstuvwxyz"


class ZXRModel:
    def __init__(self):
        raise NotImplemented

    def inspect(self, file):
        from keras.utils import plot_model
        plot_model(self.model, to_file=file, show_shapes=True)

    def load_weights(self, weight_filepath):
        self.model.load_weights(weight_filepath)

    def predict(self, img: Image.Image):
        data = self.load_img(img)
        indices = self.model.predict_classes(np.asarray(data), batch_size=len(data), verbose=0)
        return [chars[i] for i in indices]

    def train(self, glob_img_path, save_weights_to=None):
        data, label = self._load_data(glob_img_path)
        self._train_with_data(data, label)
        if save_weights_to:
            self.model.save_weights(save_weights_to)

    def _load_data(self, glob_path):
        from glob import iglob
        from os import path

        img_files = iglob(glob_path)

        print('开始加载训练集:', glob_path)
        data, label = list(), list()
        for file in img_files:
            file_name = path.splitext(path.split(file)[-1])[0].split('.')[0].lower()
            try:
                t = self.load_img(Image.open(file))
                if len(file_name) == len(t):
                    data += t
            except Exception as ex:
                from sys import stderr
                print(ex.args[0], file=stderr)
                continue
            else:
                label += [chars.index(i) for i in file_name]
        return data, label

    def train_with_data(self, data, labels):
        from keras.utils import np_utils
        print('训练集: %d, label: %d' % (len(data), len(labels)))
        self.model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])
        self.model.fit(np.asarray(data),
                       np_utils.to_categorical(labels, len(chars)),
                       batch_size=500, epochs=150, shuffle=True,
                       verbose=1)

    def load_img(self, img):
        segs = img_util.split_zxr(img if img.mode == 'RGB' else img.convert('RGB'))
        return [255 - np.asarray(seg, dtype='int32') for seg in segs]


class SimpleModel(ZXRModel):
    def __init__(self):
        """
        vggnet简化版
        """
        _model = Sequential()
        _model.add(layers.InputLayer(input_shape=(img_height, img_width, 1)))
        for i in range(2):
            _model.add(layers.Convolution2D(9 * 3 ** i, (3, 3), border_mode='valid', activation="relu"))
            _model.add(layers.Convolution2D(9 * 3 ** i, (3, 3), border_mode='valid', activation="relu"))
            _model.add(layers.MaxPooling2D((2, 2)))

        _model.add(layers.Flatten())
        _model.add(layers.Dense(output_dim=128, activation='tanh'))
        _model.add(layers.Dense(len(chars), init='normal', activation='softmax'))
        self.model = _model


class VGGModel(ZXRModel):
    def __init__(self, training=False):
        """定制的vggnet模型"""
        drop_rate = 0.5 if training else 1.0
        _model = Sequential()
        _model.add(layers.InputLayer(input_shape=(img_height, img_width, 3)))
        for i in range(1, 3):
            _model.add(layers.Convolution2D(16 * 2 ** i, (3, 3), border_mode='valid', activation="relu"))
            _model.add(layers.Convolution2D(16 * 2 ** i, (3, 3), border_mode='valid', activation="relu"))
            _model.add(layers.MaxPooling2D((2, 2), strides=(2, 2)))

        _model.add(layers.Flatten())
        _model.add(layers.Dense(1024, activation='relu'))
        _model.add(layers.Dropout(drop_rate))
        _model.add(layers.Dense(1024, activation='relu'))
        _model.add(layers.Dropout(drop_rate))
        _model.add(layers.Dense(output_dim=256, activation='relu'))
        _model.add(layers.Dense(len(chars), init='normal', activation='softmax'))
        self.model = _model


class AlexNet(ZXRModel):
    def __init__(self, training=False):
        drop_rate = 0.5 if training else 1.0
        _model = Sequential()
        _model.add(layers.InputLayer((img_height, img_width, 3)))
        _model.add(layers.Convolution2D(16, (5, 5), strides=(2, 2), activation='relu'))
        _model.add(layers.MaxPool2D((2, 2), strides=(2, 2)))
        _model.add(layers.Convolution2D(32, (3, 3), strides=(1, 1), activation='relu'))
        _model.add(layers.MaxPool2D((2, 2), strides=(2, 2)))

        # for _ in range(3):
        _model.add(layers.Convolution2D(96, (2, 2), strides=(1, 1), activation='relu'))
        # _model.add(layers.Convolution2D(64, (3, 3), strides=(1, 1), activation='relu'))

        _model.add(layers.Flatten())
        _model.add(layers.Dense(512, activation='relu'))
        _model.add(layers.Dropout(drop_rate))
        _model.add(layers.Dense(512, activation='relu'))
        _model.add(layers.Dropout(drop_rate))
        _model.add(layers.Dense(output_dim=256, activation='relu'))
        _model.add(layers.Dense(len(chars), init='normal', activation='softmax'))
        self.model = _model

© 著作权归作者所有

共有 人打赏支持
pseudo

pseudo

粉丝 76
博文 37
码字总数 35469
作品 3
朝阳
程序员
加载中

评论(2)

zyj789
zyj789
麻烦问下:您代码中的from . import img_util 引入的 img_util 是你自己写的吗?如果是的话,能否把代码贴下!谢谢!!!
zyj789
zyj789
麻烦问下:from . import img_util 引入的 img_util 是你自己写的吗?如果是的话,能否把代码贴下!谢谢!!!
tensorflow 实现端到端的OCR:二代身份证号识别

最近在研究OCR识别相关的东西,最终目标是能识别身份证上的所有中文汉字+数字,不过本文先设定一个小目标,先识别定长为18的身份证号,当然本文的思路也是可以复用来识别定长的验证码识别的。...

某杰
2017/08/08
0
0
裤҉裆҉里҉的҉霸҉气҉/verification-decoder

四位验证码CNN识别 1.参考 [1] 街道多位数字CNN识别,神经网络架构参考 [2] 关于CNN的详细解释,深度学习入门必备 [3] 验证码生成参考类 2.支持 [1] Python3.6.1 or >=3.5 [2] TensorFlow 1....

裤҉裆҉里҉的҉霸҉气҉
2017/11/28
0
0
CNN破解简单验证码(Tensorflow实现)

使用CNN破解一下自己生成的图片验证码,因为电脑性能不行,只破解四位的数字验证码,代码实现中可以对符号、字符和数字混合的验证码进行破解,原理相同,有高性能GPU的童鞋可以试试玩玩。CNN...

cskywit
02/01
0
0
ADDA模型实现

去年有看了几篇domain adaptation相关的论文,这里想实现一篇最简单好用的模型,Adversarial Discriminative Domain Adaptation,作者提出了针对adversrarial adaptation 的一个通用框架,并...

Slyne_D
01/21
0
0
机器学习-4:DeepLN之CNN解析

开篇废话: 很感谢谭哥的开篇废话这四个字,让我把一些废话说出来了,是时候还给谭哥了。因为废话太多会让人感觉,没有能力净废话。 今天我开始从头学习CNN,上一篇MachineLN之深度学习入门坑...

MachineLP
01/10
0
0
Github近期最有趣的10款机器学习开源项目

-01- Face Recognition #世界上最简单的人脸识别库 本项目号称世界上最简单的人脸识别库,可使用 Python 和命令行进行调用。该库使用 dlib 顶尖的深度学习人脸识别技术构建,在户外脸部检测数...

技术小能手
01/03
0
2
【AI 工程师】掌握这10个项目,秒杀90%面试者!

2017年人工智能给了我们太多的惊喜和变化,从今年开始,国际巨头们纷纷开始大踏步地战略转向——从移动优先转向AI优先:3月份的微软、4月份的Facebook、5月份的Google、6月份的苹果……乃至前...

dev_csdn
2017/12/14
0
0
教程 如何使用Keras集成多个卷积网络并实现共同预测

     在统计学和机器学习领域,集成方法(ensemble method)使用多种学习算法以获得更好的预测性能(相比单独使用其中任何一种算法)。和统计力学中的统计集成(通常是无穷集合)不同,一...

深度学习
2017/12/15
0
0
详解卷积神经网络(CNN)在语音识别中的应用

欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者:侯艺馨 前言 总结目前语音识别的发展现状,dnn、rnn/lstm和cnn算是语音识别中几个比较主流的方向。2012年,微软邓力和俞栋老...

腾讯云社区
2017/12/01
0
0
详解卷积神经网络(CNN)在语音识别中的应用

欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者:侯艺馨 前言 总结目前语音识别的发展现状,dnn、rnn/lstm和cnn算是语音识别中几个比较主流的方向。2012年,微软邓力和俞栋老...

腾讯云社区
2017/12/01
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

spring boot中swagger2使用

1.pom.xml中添加 <dependency> <groupId>io.springfox</groupId> <artifactId>springfox-swagger2</artifactId> <version>2.9.2</version>......

说回答
6分钟前
0
0
tomcat虚拟路径的几种配置方法

tomcat虚拟路径的几种配置方法 一般我们都是直接引用webapps下面的web项目,如果我们要部署一个在其它地方的WEB项目,这就要在TOMCAT中设置虚拟路径了,Tomcat的加载web顺序是先加载 $Tomcat_ho...

Helios51
18分钟前
1
0
Mac 安装jupyter notebook的过程

MAC台式机 python:mac下自带Python 2.7.10 1.先升级了pip安装工具:sudo python -m pip install --upgrade --force pip 2.安装setuptools 工具:sudo pip install setuptools==33.1.1 3.安装......

火力全開
24分钟前
0
0
导航守卫解释与例子

“导航”表示路由正在发生改变。 正如其名,vue-router 提供的导航守卫主要用来通过跳转或取消的方式守卫导航。有多种机会植入路由导航过程中:全局的, 单个路由独享的, 或者组件级的。 记住...

tianyawhl
24分钟前
0
0
Java日志框架-logback配置文件多环境日志配置(开发、测试、生产)(原始解决方法)

说明:这种方式应该算是最通用的,原理是通过判断标签实现。 <!-- if-then form --> <if condition="some conditional expression"> <then> ... </then> </if> ......

浮躁的码农
38分钟前
1
0
FTP传输时的两种登录方式和区别

登录方式 匿名登录 用户名为: anonymous。 密码为:任何合法 email 地址。 授权登录 用户名为:用户在远程系统中的用户帐号。 密码为:用户在远程系统中的用户密码。 区别 匿名登录 只能访问...

寰宇01
39分钟前
0
0
plsql developer 配置监听(不安装oracle客户端)

plsql developer 配置监听(不安装oracle客户端)

微小宝
46分钟前
1
0
数据库(分库分表)中间件对比

本人的宗旨就是,能copy的,绝对不手写。 分区:对业务透明,分区只不过把存放数据的文件分成了许多小块,例如mysql中的一张表对应三个文件.MYD,MYI,frm。 根据一定的规则把数据文件(MYD)和索...

奔跑吧代码
50分钟前
2
0
Netty与Reactor模式详解

在学习Reactor模式之前,我们需要对“I/O的四种模型”以及“什么是I/O多路复用”进行简单的介绍,因为Reactor是一个使用了同步非阻塞的I/O多路复用机制的模式。 I/O的四种模型 I/0 操作 主要...

hutaishi
56分钟前
1
0
【2018.07.16学习笔记】【linux高级知识 20.16-20.19】

20.16/20.17 shell中的函数 20.18 shell中的数组 20.19 告警系统需求分析

lgsxp
今天
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部