文档章节

mxnet常用的数据评估指标

q
 qinhui99
发布于 2017/06/21 17:56
字数 557
阅读 263
收藏 0

mxnet最近更新很多文档,其中就包括了常用的数据评估指标。相关文档参考:http://mxnet.io/api/python/metric.html#overview

仔细看了一下,大部分都是分类用的评估指标,线性回归的很少。我猜可能是因为mxnet做线性回归不太行的缘故。下面把常用的指标列出来,以备查看:

#encoding=utf-8
'''
测试用的校验指标. 包括准确率,F1指标、topk和Perplexity
http://mxnet.io/api/python/metric.html#overview
'''
import mxnet as mx
import numpy as np
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]

#1、最常用的准确率Accuracy
eval_metrics_1 = mx.metric.Accuracy()
#2、分类的综合评估指标F1. This F1 score only supports binary classification
eval_metrics_2 = mx.metric.F1()
eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in [eval_metrics_1, eval_metrics_2]:
    eval_metrics.add(child_metric)

eval_metrics.update(labels = labels, preds = predicts)
print eval_metrics.get() #(['accuracy', 'f1'], [0.6666666666666666, 0.8])

#3、平均绝对误差MAE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
mean_absolute_error = mx.metric.MAE()
mean_absolute_error.update(labels = labels, preds = predicts)
print mean_absolute_error.get() #('mae', 0.5)

#4、均方差MSE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
mean_squared_error = mx.metric.MSE()
mean_squared_error.update(labels = labels, preds = predicts)
print mean_squared_error.get()#('mse', 0.375)

#5、标准差RMSE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
root_mean_squared_error = mx.metric.RMSE()
root_mean_squared_error.update(labels = labels, preds = predicts)
print root_mean_squared_error.get()#('rmse', 0.61237245798110962)

#6、交叉熵CrossEntropy
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]
ce = mx.metric.CrossEntropy()
ce.update(labels, predicts)#('cross-entropy', 0.57159948348999023)
print ce.get()

#7、前K项指标,top k指标。k越大,值越大,因为包含的可能性更高。
np.random.seed(999)
top_k = 3 #前3项指标
labels = [mx.nd.array([2, 6, 9, 2, 3, 4, 7, 8, 9, 6])]
predicts = [mx.nd.array(np.random.rand(10, 10))]
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print acc.get() # ('top_k_accuracy_3', 0.3)

top_k = 5 #前5项指标
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print acc.get() # ('top_k_accuracy_5', 0.6)

'''
8、Perplexity指标。
Perplexity is a measurement of how well a probability distribution or model predicts a sample.
A low perplexity indicates the model is good at predicting the sample.
简单来说,perplexity就是对于语言模型所估计的一句话出现的概率.Perplexity其实表示的是average branch factor,
大概可以翻译为平均分支系数。即平均来说,我们预测下一个词时有多少种选择。
摘录自:http://blog.csdn.net/luo123n/article/details/48902815
'''
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]
perp = mx.metric.Perplexity(ignore_label=None)
perp.update(labels, predicts)
print perp.get() #('Perplexity', 1.7710976285155853)

© 著作权归作者所有

q
粉丝 66
博文 73
码字总数 34091
作品 0
深圳
程序员
私信 提问
业界 | MXNet开放支持Keras,高效实现CNN与RNN的分布式训练

  选自AWS Machine Learning Blog   作者:Lai Wei、Kalyanee Chendke、Aaron Markham、Sandeep Krishnamurthy   机器之心编译   参与:路、王淑婷      今日 AWS 发布博客宣布 ...

机器之心
2018/05/22
0
0
MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

雷锋网(公众号:雷锋网) AI 研习社按,近期,AWS 表示 MXNet 支持 Keras 2,开发者可以使用 Keras-MXNet 更加方便快捷地实现 CNN 及 RNN 分布式训练。AI 研习社将 AWS 官方博文编译如下。 Ke...

孔令双
2018/05/23
0
0
MXNet/Gluon 中网络和参数的存取方式

Gluon是MXNet的高层封装,网络设计简单易用,与Keras类似。随着深度学习技术的普及,类似于Gluon这种,高层封装的深度学习框架,被越来越多的开发者接受和使用。 在开发深度学习算法时,必然...

SpikeKing
2018/05/29
0
0
windows下编译mxnet并使用C++训练模型

版权声明:原创文章如需转载,请在左侧博主描述栏目扫码联系我并取得授权,谢谢 https://blog.csdn.net/u012234115/article/details/80503086 大多数情况下,mxnet都使用python接口进行机器学...

踏莎行hyx
2018/05/29
0
0
云上深度学习实践(二)-云上MXNet实践

目录 云上深度学习实践(一)-GPU云服务器TensorFlow单机多卡训练性能实践 云上深度学习实践(二)-云上MXNet实践 1 MXNet 简介 1.1 MXNet特点 MXNet是一个全功能,灵活可编程和高扩展性的深...

撷峰
2018/07/13
0
0

没有更多内容

加载失败,请刷新页面

加载更多

总结

一、设计模式 简单工厂:一个简单而且比较杂的工厂,可以创建任何对象给你 复杂工厂:先创建一种基础类型的工厂接口,然后各自集成实现这个接口,但是每个工厂都是这个基础类的扩展分类,spr...

BobwithB
26分钟前
2
0
java内存模型

前言 Java作为一种面向对象的,跨平台语言,其对象、内存等一直是比较难的知识点。而且很多概念的名称看起来又那么相似,很多人会傻傻分不清楚。比如本文我们要讨论的JVM内存结构、Java内存模...

ls_cherish
30分钟前
2
0
友元函数强制转换

友元函数强制转换 p522

天王盖地虎626
昨天
5
0
js中实现页面跳转(返回前一页、后一页)

本文转载于:专业的前端网站➸js中实现页面跳转(返回前一页、后一页) 一:JS 重载页面,本地刷新,返回上一页 复制代码代码如下: <a href="javascript:history.go(-1)">返回上一页</a> <a h...

前端老手
昨天
5
0
JAVA 利用时间戳来判断TOKEN是否过期

import java.time.Instant;import java.time.LocalDateTime;import java.time.ZoneId;import java.time.ZoneOffset;import java.time.format.DateTimeFormatter;/** * @descri......

huangkejie
昨天
4
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部