文档章节

TensorFlow中四种 Cross Entropy 算法实现和应用

AllenOR灵感
 AllenOR灵感
发布于 2017/09/10 01:19
字数 2127
阅读 1
收藏 0
点赞 0
评论 0

本文转载自CSDN,这是原文

交叉熵介绍

交叉熵(Cross Entropy)是Loss函数的一种(也称为损失函数或代价函数),用于描述模型预测值与真实值的差距大小,常见的 Loss 函数就是均方平方差(Mean Squared Error),定义如下。


​平方差很好理解,预测值与真实值直接相减,为了避免得到负数取绝对值或者平方,再做平均就是均方平方差。注意这里预测值需要经过sigmoid激活函数,得到取值范围在0到1之间的预测值。

平方差可以表达预测值与真实值的差异,但在分类问题种效果并不如交叉熵好,原因可以参考这篇博文 。交叉熵的定义如下,截图来自这个网站


上面的文章也介绍了交叉熵可以作为Loss函数的原因,首先是交叉熵得到的值一定是正数,其次是预测结果越准确值越小,注意这里用于计算的“a”也是经过sigmoid激活的,取值范围在0到1。如果label是1,预测值也是1的话,前面一项y * ln(a)就是1 * ln(1)等于0,后一项(1 - y) * ln(1 - a)也就是0 * ln(0)等于0,Loss函数为0,反之Loss函数为无限大非常符合我们对Loss函数的定义。

这里多次强调sigmoid激活函数,是因为在多目标或者多分类的问题下有些函数是不可用的,而TensorFlow本身也提供了多种交叉熵算法的实现。

TensorFlow的交叉熵函数

TensorFlow针对分类问题,实现了四个交叉熵函数,分别是:


tf.nn.sigmoid_cross_entropy_with_logits
tf.nn.softmax_cross_entropy_with_logits
tf.nn.sparse_softmax_cross_entropy_with_logits
tf.nn.weighted_cross_entropy_with_logits


详细内容可以参考这个API文档

sigmoid_cross_entropy_with_logits 详解

我们先看sigmoid_cross_entropy_with_logits,为什么呢,因为它的实现和前面的交叉熵算法定义是一样的,也是TensorFlow最早实现的交叉熵算法。这个函数的输入是logits和targets,logits就是神经网络模型中的 W * X矩阵,注意不需要经过sigmoid,而targets的shape和logits相同,就是正确的label值,例如这个模型一次要判断100张图是否包含10种动物,这两个输入的shape都是[100, 10]。注释中还提到这10个分类之间是独立的、不要求是互斥,这种问题我们成为多目标,例如判断图片中是否包含10种动物,label值可以包含多个1或0个1,还有一种问题是多分类问题,例如我们对年龄特征分为5段,只允许5个值有且只有1个值为1,这种问题可以直接用这个函数吗?答案是不可以,我们先来看看sigmoid_cross_entropy_with_logits的代码实现吧。


可以看到这就是标准的Cross Entropy算法实现,对W * X得到的值进行sigmoid激活,保证取值在0到1之间,然后放在交叉熵的函数中计算Loss。对于二分类问题这样做没问题,但对于前面提到的多分类,例如年轻取值范围在0~4,目标值也在0~4,这里如果经过sigmoid后预测值就限制在0到1之间,而且公式中的1 - z就会出现负数,仔细想一下0到4之间还不存在线性关系,如果直接把label值带入计算肯定会有非常大的误差。因此对于多分类问题是不能直接代入的,那其实我们可以灵活变通,把5个年龄段的预测用onehot encoding变成5维的label,训练时当做5个不同的目标来训练即可,但不保证只有一个为1,对于这类问题TensorFlow又提供了基于Softmax的交叉熵函数。

softmax_cross_entropy_with_logits 详解

Softmax本身的算法很简单,就是把所有值用e的n次方计算出来,求和后算每个值占的比率,保证总和为1,一般我们可以认为Softmax出来的就是confidence也就是概率,算法实现如下。


​softmax_cross_entropy_with_logits和sigmoid_cross_entropy_with_logits很不一样,输入是类似的logits和lables的shape一样,但这里要求分类的结果是互斥的,保证只有一个字段有值,例如CIFAR-10中图片只能分一类而不像前面判断是否包含多类动物。想一下问什么会有这样的限制?在函数头的注释中我们看到,这个函数传入的logits是unscaled的,既不做sigmoid也不做softmax,因为函数实现会在内部更高效得使用softmax,对于任意的输入经过softmax都会变成和为1的概率预测值,这个值就可以代入变形的Cross Entroy算法- y * ln(a) - (1 - y) * ln(1 - a)算法中,得到有意义的Loss值了。如果是多目标问题,经过softmax就不会得到多个和为1的概率,而且label有多个1也无法计算交叉熵,因此这个函数只适合单目标的二分类或者多分类问题,TensorFlow函数定义如下。


再补充一点,对于多分类问题,例如我们的年龄分为5类,并且人工编码为0、1、2、3、4,因为输出值是5维的特征,因此我们需要人工做onehot encoding分别编码为00001、00010、00100、01000、10000,才可以作为这个函数的输入。理论上我们不做onehot encoding也可以,做成和为1的概率分布也可以,但需要保证是和为1,和不为1的实际含义不明确,TensorFlow的C++代码实现计划检查这些参数,可以提前提醒用户避免误用。

sparse_softmax_cross_entropy_with_logits 详解

sparse_softmax_cross_entropy_with_logits是softmax_cross_entropy_with_logits的易用版本,除了输入参数不同,作用和算法实现都是一样的。前面提到softmax_cross_entropy_with_logits的输入必须是类似onehot encoding的多维特征,但CIFAR-10、ImageNet和大部分分类场景都只有一个分类目标,label值都是从0编码的整数,每次转成onehot encoding比较麻烦,有没有更好的方法呢?答案就是用sparse_softmax_cross_entropy_with_logits,它的第一个参数logits和前面一样,shape是[batch_size, num_classes],而第二个参数labels以前也必须是[batch_size, num_classes]否则无法做Cross Entropy,这个函数改为限制更强的[batch_size],而值必须是从0开始编码的int32或int64,而且值范围是[0, num_class),如果我们从1开始编码或者步长大于1,会导致某些label值超过这个范围,代码会直接报错退出。这也很好理解,TensorFlow通过这样的限制才能知道用户传入的3、6或者9对应是哪个class,最后可以在内部高效实现类似的onehot encoding,这只是简化用户的输入而已,如果用户已经做了onehot encoding那可以直接使用不带“sparse”的softmax_cross_entropy_with_logits函数。

weighted_sigmoid_cross_entropy_with_logits详解

weighted_sigmoid_cross_entropy_with_logits是sigmoid_cross_entropy_with_logits的拓展版,输入参数和实现和后者差不多,可以多支持一个pos_weight参数,目的是可以增加或者减小正样本在算Cross Entropy时的Loss。实现原理很简单,在传统基于sigmoid的交叉熵算法上,正样本算出的值乘以某个系数接口,算法实现如下。


总结

这就是TensorFlow目前提供的有关Cross Entropy的函数实现,用户需要理解多目标和多分类的场景,根据业务需求(分类目标是否独立和互斥)来选择基于sigmoid或者softmax的实现,如果使用sigmoid目前还支持加权的实现,如果使用softmax我们可以自己做onehot coding或者使用更易用的sparse_softmax_cross_entropy_with_logits函数。

TensorFlow提供的Cross Entropy函数基本cover了多目标和多分类的问题,但如果同时是多目标多分类的场景,肯定是无法使用softmax_cross_entropy_with_logits,如果使用sigmoid_cross_entropy_with_logits我们就把多分类的特征都认为是独立的特征,而实际上他们有且只有一个为1的非独立特征,计算Loss时不如Softmax有效。这里可以预测下,未来TensorFlow社区将会实现更多的op解决类似的问题,我们也期待更多人参与TensorFlow贡献算法和代码 :)

本文转载自:http://www.jianshu.com/p/122b2113dae2

共有 人打赏支持
AllenOR灵感
粉丝 10
博文 2634
码字总数 82983
作品 0
程序员
TensorFlow——训练神经网络模型

TensorFlow训练神经网络模型的步骤: (1)定义神经网络的结构和向前传播的输出结果 (2)定义损失函数以及选择反向传播优化的算法 (3)生成会话(tf.Session),并且在训练数据上反复运行反...

飞天小橘子
04/23
0
0
TensorFlow基本原理,入门教程网址

TensorFlow 进阶 Python代码的目的是用来 构建这个可以在外部运行的计算图,以及 安排计算图的哪一部分应该被运行。 http://tensorfly.cn/ github 地址 : https://github.com/tensorflow/te...

寒月谷
06/05
0
0
实战|TensorFlow 实践之手写体数字识别!

本文的主要目的是教会大家运用google开源的深度学习框架tensorflow来实现手写体数字识别,给出两种模型,一种是利用机器学习中的softmax regression作分类器,另一种将是搭建一个深度神经网络...

j2iayu7y
04/17
0
0
tensorflow学习笔记(第一天)-MNIST机器学习入门

MNIST机器学习入门 这个是关于tensorflow的中文文档:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html MNIST是一个入门级的计算机视觉数据集,这个就相当...

a870542373
04/13
0
0
关于tensorflow使用的一些简单的问题

笔者在前两年为了验证一些神经网络问题曾经在Ubuntu上安装了Tensorflow,这个好像运行并没有什么太大的问题,但近期又在Windows(Win10)下安装使用最新的Tensorflow,不知为什么总是存在一些...

chenhu73
06/28
0
0
通过Python来学习人工智能!事半功倍!TensorFlow之入门篇!

     MNIST数据集介绍   MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:      MNIST数据集是含标注信息的,以上图片分别代表5, 0, 4和1。   MNIST数据集的官网...

中国机器人
06/21
0
0
[2018-07-08] tensorflow 创建线性回归(1)

OverView: 今天突然想起以前写过一个用BP算法的iris分类器, 加上最近面试把线性规划的思想和实现又看了一遍. (1) 数据集介绍 (2) tensorflow 实现分类器 (3) tensorflow实现模型评估 (1) 数据...

斐波那契的数字
07/08
0
0
Keras 深度学习框架介绍----一起来慢慢走进deep learning

Introduce Keras是一个高级API,用Python编写,能够在TensorFlow、Theano或CNTK上运行。Keras提供了一个简单和模块化的API来创建和训练神经网络,隐藏了大部分复杂的细节。 How to install k...

qq_15642411
04/20
0
0
TensorFlow Tutorial-1

1、Why TensorFlow? 网上有关介绍太多了,我就不多说了,这里主要注重使用。 2、Programing model 2.1、Big Idea: 将数值的计算转化为图(computational graph),任何tensorflow的计算都是...

戬杨Jason
2017/08/05
0
0
TensorFlow官方文档学习|TensorFlow运作方式入门

认识MINIST数据集 from tensorflow.examples.tutorials.mnist import input_datamnist = inputdata.readdatasets("MNISTdata/", one_hot=True) print (mnist.train.images.shape)print (mnis......

darlingwood2013
2017/03/12
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

vue-router懒加载

1. vue-router懒加载定义 当路由被访问的时候才加载对应组件 2. vue-router懒加载作用 当构建的项目比较大的时候,懒加载可以分割代码块,提高页面的初始加载效率。 ###3. vue-router懒加载实...

不负好时光
6分钟前
0
0
庆祝法国队夺冠:用Python放一场烟花秀

天天敲代码的朋友,有没有想过代码也可以变得很酷炫又浪漫?今天就教大家用Python模拟出绽放的烟花庆祝昨晚法国队夺冠,工作之余也可以随时让程序为自己放一场烟花秀。 这个有趣的小项目并不...

猫咪编程
8分钟前
0
0
SpringBoot | 第七章:过滤器、监听器、拦截器

前言 在实际开发过程中,经常会碰见一些比如系统启动初始化信息、统计在线人数、在线用户数、过滤敏高词汇、访问权限控制(URL级别)等业务需求。这些对于业务来说一般上是无关的,业务方是无需...

oKong
22分钟前
4
0
存储结构分四类:顺序存储、链接存储、索引存储 和 散列存储

存储结构分四类:顺序存储、链接存储、索引存储 和 散列存储 存储结构分四类:顺序存储、链接存储、索引存储 和 散列存储。 顺序结构和链接结构适用在内存结构中。 顺序表每个单元都是按物理...

DannyCoder
32分钟前
0
0
Firefox 61已经为Ubuntu 提供支持

最新和最好的Mozilla Firefox 61 “Quantum”网络浏览器已经为Ubuntu Linux操作系统的用户提供了支持,现在可以通过官方软件库进行更新。 Mozilla于2018年6月26日发布了Firefox 61版本,该版...

六库科技
59分钟前
0
0
Win10升级后执行系统封装(Sysprep)报错

开始封装 一年多以前开始给公司封装Win10系统,便于统一给公司电脑初始化携带各种软件的系统,致力于装完既可以开发的状态。那时候最新的版本是Win10 1703版本,自然就以他为母盘,然后结合V...

lyunweb
今天
39
0
php 性能优化

#什么情况下会遇到性能问题 PHP 语法使用的不恰当

to_be_better
今天
0
0
Jenkins 构建触发器操作详解

前言 跑自动化用例每次用手工点击jenkins出发自动化用例太麻烦了,我们希望能每天固定时间跑,这样就不用管了,坐等收测试报告结果就行。 一、定时构建语法 * * * * * (五颗星,中间用空格隔...

覃光林
今天
0
0
IDEA配置技巧

超详细设置Idea类注释模板和方法注释模板 idea去掉注解param下划线 JetBrains全系列破解

AK灬
今天
0
0
rsync通过服务同步/Linux系统日志/screen工具

rsync通过服务同步 分为服务端(机器A) 和客户端(机器B) 机器A操作编辑/etc/rsyncd.conf配置文件 [root@yolks1 ~]# vim /etc/rsyncd.conf 文件中添加以下配置 port=873 ...

Hi_Yolks
今天
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部