文档章节

CNN中文文本分类-基于TensorFlow实现

Gaussic
 Gaussic
发布于 2017/08/30 01:02
字数 1747
阅读 1156
收藏 3

代码地址:Github

转载请注明出处:Gaussic - 写干净的代码

基于CNN的文本分类问题已经有了一定的研究成果,CNN做句子分类的论文可以参看: Convolutional Neural Networks for Sentence Classification

在网上也有了一些开源的实现,例如比较著名的dennybritz大牛的博客Implementing a CNN for Text Classification in TensorFlow基于早期TensorFlow的一个实现版本。

如今,TensorFlow大版本已经升级到了1.3,对很多的网络层实现了更高层次的封装和实现,甚至还整合了如Keras这样优秀的一些高层次框架,使得其易用性大大提升。相比早起的底层代码,如今的实现更加简洁和优雅。

本章的目的是基于TensorFlow的API来重新实现一个在中文文本上的分类器。如果你觉得对你有些许帮助或者疑惑,欢迎star和交流。

数据集

本文采用了清华NLP组提供的THUCNews新闻文本分类数据集的一个子集(原始的数据集大约74万篇文档,训练起来需要花较长的时间)。数据集请自行到THUCTC:一个高效的中文文本分类工具包下载,请遵循数据提供方的开源协议。

本次训练使用了其中的10个分类,每个分类6500条,总共65000条新闻数据。

类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

数据集划分如下:

  • 训练集: 5000*10
  • 验证集: 500*10
  • 测试集: 1000*10

从原数据集生成子集的过程请参看helper下的两个脚本。其中,copy_data.sh用于从每个分类拷贝6500个文件,cnews_group.py用于将多个文件整合到一个文件中。执行该文件后,得到三个数据文件:

  • cnews.train.txt: 训练集(50000条)
  • cnews.val.txt: 验证集(5000条)
  • cnews.test.txt: 测试集(10000条)

预处理

data/cnews_loader.py为数据的预处理文件。

  • read_file():读取上一部分生成的数据文件,将内容和标签分开返回;
  • _build_vocab(): 构建词汇表,这里不需要对文档进行分词,单字的效果已经很好,这一函数会将词汇表存储下来,避免每一次重复处理;
  • _read_vocab(): 读取上一步存储的词汇表,转换为{词:id}表示;
  • _read_category(): 将分类目录固定,转换为{类别: id}表示;
  • _file_to_ids(): 基于上面定义的函数,将数据集从文字转换为id表示;
  • to_words(): 将一条由id表示的数据重新转换为文字;
  • preocess_file(): 一次性处理所有的数据并返回;
  • batch_iter(): 为神经网络的训练准备批次的数据。

经过数据预处理,数据的格式如下:

输入图片说明

配置项

可配置的参数如下所示,在model.py的上部。

class TCNNConfig(object):
    """配置参数"""

    # 模型参数
    embedding_dim = 64      # 词向量维度
    seq_length = 600        # 序列长度
    num_classes = 10        # 类别数
    num_filters = 256       # 卷积核数目
    kernel_size = 5         # 卷积核尺寸
    vocab_size = 5000       # 词汇表达小

    hidden_dim = 128        # 全链接层神经元

    dropout_keep_prob = 0.8 # dropout保留比例
    learning_rate = 1e-3    # 学习率

    batch_size = 128         # 每批训练大小
    num_epochs = 10          # 总迭代轮次

模型

原始的模型如下图所示:

输入图片说明

可看到它使用了多个不同宽度的卷积核然后将它们做了一个max over time pooling转换为一个长的特征向量,再使用softmax进行分类。

实验发现,简单的cnn也能达到较好的效果。

因此在这里使用的是简化版的结构,具体参看model.py

首先在初始化时,需要定义两个placeholder作为输入输出占位符。

def __init__(self, config):
      self.config = config

      self.input_x = tf.placeholder(tf.int32,
          [None, self.config.seq_length], name='input_x')
      self.input_y = tf.placeholder(tf.float32,
          [None, self.config.num_classes], name='input_y')

      self.cnn()

词嵌入将词的id映射为词向量表示,embedding层会在训练时更新。

def input_embedding(self):
    """词嵌入"""
    with tf.device('/cpu:0'):
        embedding = tf.get_variable('embedding',
            [self.config.vocab_size, self.config.embedding_dim])
        _inputs = tf.nn.embedding_lookup(embedding, self.input_x)
    return _inputs

cnn模型中,首先定义一个一维卷积层,再使用tf.reduce_max实现global max pooling。再接两个dense层分别做映射和分类。使用交叉熵损失函数,Adam优化器,并且计算准确率。这里有许多参数可调,大部分可以通过调整TCNNConfig类即可。

def cnn(self):
      """cnnc模型"""
      embedding_inputs = self.input_embedding()

      with tf.name_scope("cnn"):
          # cnn 与全局最大池化
          conv = tf.layers.conv1d(embedding_inputs,
              self.config.num_filters,
              self.config.kernel_size, name='conv')

          # global max pooling
          gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

      with tf.name_scope("score"):
          # 全连接层,后面接dropout以及relu激活
          fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
          fc = tf.contrib.layers.dropout(fc,
              self.config.dropout_keep_prob)
          fc = tf.nn.relu(fc)

          # 分类器
          self.logits = tf.layers.dense(fc, self.config.num_classes,
              name='fc2')
          self.pred_y = tf.nn.softmax(self.logits)

      with tf.name_scope("loss"):
          # 损失函数,交叉熵
          cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
              logits=self.logits, labels=self.input_y)
          self.loss = tf.reduce_mean(cross_entropy)

      with tf.name_scope("optimize"):
          # 优化器
          optimizer = tf.train.AdamOptimizer(
              learning_rate=self.config.learning_rate)
          self.optim = optimizer.minimize(self.loss)

      with tf.name_scope("accuracy"):
          # 准确率
          correct_pred = tf.equal(tf.argmax(self.input_y, 1),
              tf.argmax(self.pred_y, 1))
          self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

训练与验证

这一部分详见代码,已经做了许多的注释,浅显易懂,具体不在此叙述。

在设定迭代轮次为5的时候,测试集达到了95.72%的准确率,可见效果还是很理想的。

Loading data...
Time usage: 0:00:16
Constructing Model...
Training and evaluating...
Iter:      1, Train Loss:    2.3, Train Acc:  10.94%, Val Loss:    2.3, Val Acc:  10.06%, Time: 0:00:01
Iter:    201, Train Loss:   0.37, Train Acc:  87.50%, Val Loss:   0.58, Val Acc:  81.70%, Time: 0:00:06
Iter:    401, Train Loss:   0.22, Train Acc:  90.62%, Val Loss:   0.34, Val Acc:  91.16%, Time: 0:00:11
Iter:    601, Train Loss:   0.17, Train Acc:  95.31%, Val Loss:   0.28, Val Acc:  92.16%, Time: 0:00:16
Iter:    801, Train Loss:   0.18, Train Acc:  95.31%, Val Loss:   0.25, Val Acc:  93.12%, Time: 0:00:21
Iter:   1001, Train Loss:   0.12, Train Acc:  95.31%, Val Loss:   0.28, Val Acc:  91.52%, Time: 0:00:26
Iter:   1201, Train Loss:  0.085, Train Acc:  96.88%, Val Loss:   0.24, Val Acc:  92.92%, Time: 0:00:31
Iter:   1401, Train Loss:  0.098, Train Acc:  95.31%, Val Loss:   0.22, Val Acc:  93.40%, Time: 0:00:36
Iter:   1601, Train Loss:  0.042, Train Acc:  98.44%, Val Loss:   0.19, Val Acc:  94.70%, Time: 0:00:41
Iter:   1801, Train Loss:  0.035, Train Acc: 100.00%, Val Loss:   0.19, Val Acc:  94.88%, Time: 0:00:46
Iter:   2001, Train Loss:  0.011, Train Acc: 100.00%, Val Loss:    0.2, Val Acc:  94.38%, Time: 0:00:51
Iter:   2201, Train Loss:    0.1, Train Acc:  96.88%, Val Loss:    0.2, Val Acc:  94.60%, Time: 0:00:56
Iter:   2401, Train Loss:  0.015, Train Acc: 100.00%, Val Loss:   0.19, Val Acc:  94.88%, Time: 0:01:01
Iter:   2601, Train Loss:  0.017, Train Acc: 100.00%, Val Loss:   0.21, Val Acc:  94.24%, Time: 0:01:06
Iter:   2801, Train Loss: 0.0014, Train Acc: 100.00%, Val Loss:   0.17, Val Acc:  95.36%, Time: 0:01:11
Iter:   3001, Train Loss:  0.074, Train Acc:  98.44%, Val Loss:   0.23, Val Acc:  94.02%, Time: 0:01:17
Iter:   3201, Train Loss:  0.033, Train Acc:  98.44%, Val Loss:   0.15, Val Acc:  96.26%, Time: 0:01:22
Iter:   3401, Train Loss: 0.0087, Train Acc: 100.00%, Val Loss:    0.2, Val Acc:  94.38%, Time: 0:01:27
Iter:   3601, Train Loss: 0.0088, Train Acc: 100.00%, Val Loss:   0.23, Val Acc:  94.16%, Time: 0:01:32
Iter:   3801, Train Loss:  0.015, Train Acc: 100.00%, Val Loss:   0.26, Val Acc:  93.08%, Time: 0:01:37
Test Loss:   0.17, Test Acc:  95.72%

我的博客即将搬运同步至腾讯云+社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=1nbzb7lxkx56n

© 著作权归作者所有

共有 人打赏支持
Gaussic
粉丝 400
博文 28
码字总数 66971
作品 0
宝山
私信 提问
中文文本分类对比(经典方法和CNN)

背景介绍 笔者实验室项目正好需要用到文本分类,作为NLP领域最经典的场景之一,文本分类积累了大量的技术实现方法,如果将是否使用深度学习技术作为标准来衡量,实现方法大致可以分成两类: ...

bupt_周小瑜
2017/12/31
0
0
tensorflow 实现端到端的OCR:二代身份证号识别

最近在研究OCR识别相关的东西,最终目标是能识别身份证上的所有中文汉字+数字,不过本文先设定一个小目标,先识别定长为18的身份证号,当然本文的思路也是可以复用来识别定长的验证码识别的。...

某杰
2017/08/08
0
0
史上最全TensorFlow学习资源汇总

来源 悦动智能(公众号ID:aibbtcom) 本篇文章将为大家总结TensorFlow纯干货学习资源,非常适合新手学习,建议大家收藏。 ▌一 、TensorFlow教程资源 1)适合初学者的TensorFlow教程和代码示...

悦动智能
04/12
0
0
开源 FAQ 问答系统 - AnyQ

AnyQ(ANswer Your Questions) AnyQ(ANswer Your Questions) 开源项目主要包含面向FAQ集合的问答系统框架、文本语义匹配工具SimNet。 问答系统框架采用了配置化、插件化的设计,各功能均通过插...

zenggang1988
07/12
0
0
基于tensorflow+CNN的新闻文本分类

2018年10月4日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural network的简称,中文叫做卷积神经网络。 文本分类是NLP(自然语言处...

潇洒坤
10/04
0
0

没有更多内容

加载失败,请刷新页面

加载更多

数据进一步优化篇:千万级数据下的Mysql优化

前言 平时在写一些小web系统时,我们总会对mysql不以为然。然而真正的系统易用应该讲数据量展望拓展到千万级别来考虑。因此,今天下午实在是无聊的慌,自己随手搭建一个千万级的数据库,然后...

hansonwong
8分钟前
0
0
【亲测】centos 7 下安装cuDNN

【亲测】centos 7 下安装cuDNN cudnn: https://developer.nvidia.com/compute/machine-learning/cudnn/secure/v7.4.1.5/prod/10.0_20181108/cudnn-10.0-linux-x64-v7.4.1.5.tgz cudnn code ......

Goopand
18分钟前
0
0
说一说$emit和$on

一、$emit 1、this $emit('自定义事件名',要传送的数据); 2、触发当前实例上的事件,要传递的数据会传给监听器; 二、$on 1、VM.$on('事件名',callback) --------------------callback回调...

文文1
19分钟前
0
0
画出wav文件声音数据的波形曲线

wav文件的格式都有介绍 另外:wav总播放时间长度:如何得到WAV文件播放的总时间? 1、直接读取wav文件头信息,从文件起始地址偏移28个字节长度为4个字节保存的是每秒钟播放的字节数,从文件起...

whoisliang
35分钟前
1
0
0030-如何在CDH中安装Kudu&Spark2&Kafka

1.概述 在CDH的默认安装包中,是不包含Kafka,Kudu和Spark2的,需要单独下载特定的Parcel包才能安装相应服务。本文档主要描述在离线环境下,在CentOS6.5操作系统上基于CDH5.12.1集群,使用C...

Hadoop实操
35分钟前
0
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部