文档章节

【AI实战】快速掌握TensorFlow(四):损失函数

雪饼
 雪饼
发布于 2018/09/02 00:24
字数 2082
阅读 2274
收藏 10

在前面的文章中,我们已经学习了TensorFlow激励函数的操作使用方法(见文章:快速掌握TensorFlow(三)),今天我们将继续学习TensorFlow

本文主要是学习掌握TensorFlow的损失函数。

一、什么是损失函数

损失函数(loss function)是机器学习中非常重要的内容,它是度量模型输出值与目标值的差异,也就是作为评估模型效果的一种重要指标,损失函数越小,表明模型的鲁棒性就越好。

 

二、怎样使用损失函数

在TensorFlow中训练模型时,通过损失函数告诉TensorFlow预测结果相比目标结果是好还是坏。在多种情况下,我们会给出模型训练的样本数据和目标数据,损失函数即是比较预测值与给定的目标值之间的差异。

下面将介绍在TensorFlow中常用的损失函数。

 

1、回归模型的损失函数

首先讲解回归模型的损失函数,回归模型是预测连续因变量的。为方便介绍,先定义预测结果(-1至1的等差序列)、目标结果(目标值为0),代码如下:

import tensorflow as tf

sess=tf.Session()

y_pred=tf.linspace(-1., 1., 100)

y_target=tf.constant(0.)

注意,在实际训练模型时,预测结果是模型输出的结果值,目标结果是样本提供的。

 

(1)L1正则损失函数(即绝对值损失函数)

L1正则损失函数是对预测值与目标值的差值求绝对值,公式如下:

在TensorFlow中调用方式如下:

loss_l1_vals=tf.abs(y_pred-y_target)

loss_l1_out=sess.run(loss_l1_vals)

L1正则损失函数在目标值附近不平滑,会导致模型不能很好地收敛。

 

(2)L2正则损失函数(即欧拉损失函数)

L2正则损失函数是预测值与目标值差值的平方和,公式如下:

当对L2取平均值,就变成均方误差(MSE, mean squared error),公式如下:

在TensorFlow中调用方式如下:

# L2损失

loss_l2_vals=tf.square(y_pred - y_target)

loss_l2_out=sess.run(loss_l2_vals)

# 均方误差

loss_mse_vals= tf.reduce.mean(tf.square(y_pred - y_target))

loss_mse_out = sess.run(loss_mse_vals)

L2正则损失函数在目标值附近有很好的曲度,离目标越近收敛越慢,是非常有用的损失函数。

 

L1、L2正则损失函数如下图所示:

 

(3)Pseudo-Huber 损失函数

Huber损失函数经常用于回归问题,它是分段函数,公式如下:

从这个公式可以看出当残差(预测值与目标值的差值,即y-f(x) )很小的时候,损失函数为L2范数,残差大的时候,为L1范数的线性函数。

 

Peseudo-Huber损失函数是Huber损失函数的连续、平滑估计,在目标附近连续,公式如下:

该公式依赖于参数delta,delta越大,则两边的线性部分越陡峭。

 

在TensorFlow中的调用方式如下:

delta=tf.constant(0.25)

loss_huber_vals = tf.mul(tf.square(delta), tf.sqrt(1. + tf.square(y_target – y_pred)/delta)) – 1.)

loss_huber_out = sess.run(loss_huber_vals)

L1、L2、Huber损失函数的对比图如下,其中Huber的delta取0.25、5两个值:

2、分类模型的损失函数

分类损失函数主要用于评估预测分类结果,重新定义预测值(-3至5的等差序列)和目标值(目标值为1),如下:

y_pred=tf.linspace(-3., 5., 100)

y_target=tf.constant(1.)

y_targets=tf.fill([100, ], 1.)

(1)Hinge损失函数

Hinge损失常用于二分类问题,主要用来评估向量机算法,但有时也用来评估神经网络算法,公式如下:

在TensorFlow中的调用方式如下:

loss_hinge_vals = tf.maximum(0., 1. – tf.mul(y_target, y_pred))

loss_hinge_out = sess.run(loss_hinge_vals)

上面的代码中,目标值为1,当预测值离1越近,则损失函数越小,如下图:

(2)两类交叉熵(Cross-entropy)损失函数

交叉熵来自于信息论,是分类问题中使用广泛的损失函数。交叉熵刻画了两个概率分布之间的距离,当两个概率分布越接近时,它们的交叉熵也就越小,给定两个概率分布p和q,则距离如下:

对于两类问题,当一个概率p=y,则另一个概率q=1-y,因此代入化简后的公式如下:

在TensorFlow中的调用方式如下:

loss_ce_vals = tf.mul(y_target, tf.log(y_pred)) – tf.mul((1. – y_target), tf.log(1. – y_pred))

loss_ce_out = sess.run(loss_ce_vals)

Cross-entropy损失函数主要应用在二分类问题上,预测值为概率值,取值范围为[0,1],损失函数图如下:

(3)Sigmoid交叉熵损失函数

与上面的两类交叉熵类似,只是将预测值y_pred值通过sigmoid函数进行转换,再计算交叉熵损失。在TensorFlow中有内置了该函数,调用方式如下:

loss_sce_vals=tf.nn.sigmoid_cross_entropy_with_logits(y_pred, y_targets)

loss_sce_out=sess.run(loss_sce_vals)

由于sigmoid函数会将输入值变小很多,从而平滑了预测值,使得sigmoid交叉熵在预测值离目标值比较远时,其损失的增长没有那么的陡峭。与两类交叉熵的比较图如下:

(4)加权交叉熵损失函数

加权交叉熵损失函数是Sigmoid交叉熵损失函数的加权,是对正目标的加权。假定权重为0.5,在TensorFlow中的调用方式如下:

weight = tf.constant(0.5)

loss_wce_vals = tf.nn.weighted_cross_entropy_with_logits(y)vals, y_targets, weight)

loss_wce_out = sess.run(loss_wce_vals)

(5)Softmax交叉熵损失函数

Softmax交叉熵损失函数是作用于非归一化的输出结果,只针对单个目标分类计算损失。

通过softmax函数将输出结果转化成概率分布,从而便于输入到交叉熵里面进行计算(交叉熵要求输入为概率),softmax定义如下:

结合前面的交叉熵定义公式,则Softmax交叉熵损失函数公式如下:

在TensorFlow中调用方式如下:

y_pred=tf.constant([[1., -3., 10.]]

y_target=tf.constant([[0.1, 0.02, 0.88]])

loss_sce_vals=tf.nn.softmax_cross_entropy_with_logits(y_pred, y_target)

loss_sce_out=sess.run(loss_sce_vals)

用于回归相关的损失函数,对比图如下:

3、总结

下面对各种损失函数进行一个总结,如下表所示:

在实际使用中,对于回归问题经常会使用MSE均方误差(L2取平均)计算损失,对于分类问题经常会使用Sigmoid交叉熵损失函数。

大家在使用时,还要根据实际的场景、具体的模型,选择使用的损失函数,希望本文对你有帮助。

接下来的“快速掌握TensorFlow”系列文章,还会有更多讲解TensorFlow的精彩内容,敬请期待。

 

欢迎关注本人的微信公众号“大数据与人工智能Lab”(BigdataAILab),获取更多信息

 

推荐相关阅读

 

关注本人公众号“大数据与人工智能Lab”(BigdataAILab),获取更多信息

© 著作权归作者所有

雪饼

雪饼

粉丝 412
博文 61
码字总数 134328
作品 0
广州
私信 提问
【AI实战】快速掌握TensorFlow(三):激励函数

到现在我们已经了解了TensorFlow的特点和基本操作(见文章:快速掌握TensorFlow(一)),以及TensorFlow计算图、会话的操作(见文章:快速掌握TensorFlow(二)),接下来我们将继续学习掌握...

雪饼
2018/08/30
1K
0
【AI实战】快速掌握TensorFlow(二):计算图、会话

在前面的文章中,我们已经完成了AI基础环境的搭建(见文章:Ubuntu + Anaconda + TensorFlow + GPU + PyCharm搭建AI基础环境),以及初步了解了TensorFlow的特点和基本操作(见文章:快速掌握...

雪饼
2018/08/20
1K
1
【AI实战】快速掌握Tensorflow(一):基本操作

Tensorflow是Google开源的深度学习框架,来自于Google Brain研究项目,在Google第一代分布式机器学习框架DistBelief的基础上发展起来。Tensorflow于2015年11月在GitHub上开源,在2016年4月补...

雪饼
2018/08/18
2.3K
0
【AI实战】训练第一个AI模型:MNIST手写数字识别模型

在上篇文章中,我们已经把AI的基础环境搭建好了(见文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基础环境),接下来将基于tensorflow训练第一个AI模型:MNIST手写数字识别模型。...

雪饼
2018/08/11
3.7K
0
【AI实战】手把手教你训练自己的目标检测模型(SSD篇)

目标检测是AI的一项重要应用,通过目标检测模型能在图像中把人、动物、汽车、飞机等目标物体检测出来,甚至还能将物体的轮廓描绘出来,就像下面这张图,是不是很酷炫呢,嘿嘿 在动手训练自己...

雪饼
2018/08/14
11.1K
25

没有更多内容

加载失败,请刷新页面

加载更多

排序––快速排序(二)

根据排序––快速排序(一)的描述,现准备写一个快速排序的主体框架: 1、首先需要设置一个枢轴元素即setPivot(int i); 2、然后需要与枢轴元素进行比较即int comparePivot(int j); 3、最后...

FAT_mt
今天
4
0
mysql概览

学习知识,首先要有一个总体的认识。以下为mysql概览 1-架构图 2-Detail csdn |简书 | 头条 | SegmentFault 思否 | 掘金 | 开源中国 |

程序员深夜写bug
今天
10
0
golang微服务框架go-micro 入门笔记2.2 micro工具之微应用利器micro web

micro web micro 功能非常强大,本文将详细阐述micro web 命令行的功能 阅读本文前你可能需要进行如下知识储备 golang分布式微服务框架go-micro 入门笔记1:搭建go-micro环境, golang微服务框架...

非正式解决方案
今天
6
0
前端——使用base64编码在页面嵌入图片

因为页面中插入一个图片都要写明图片的路径——相对路径或者绝对路径。而除了具体的网站图片的图片地址,如果是在自己电脑文件夹里的图片,当我们的HTML文件在别人电脑上打开的时候图片则由于...

被毒打的程序猿
今天
9
0
Flutter 系列之Dart语言概述

Dart语言与其他语言究竟有什么不同呢?在已有的编程语言经验的基础上,我们该如何快速上手呢?本篇文章从编程语言中最重要的组成部分,也就是基础语法与类型变量出发,一起来学习Dart吧 一、...

過愙
今天
6
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部