## mxnet常用的数据评估指标 原

q
qinhui99

mxnet最近更新很多文档，其中就包括了常用的数据评估指标。相关文档参考：http://mxnet.io/api/python/metric.html#overview

```#encoding=utf-8
'''

http://mxnet.io/api/python/metric.html#overview
'''
import mxnet as mx
import numpy as np
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]

#1、最常用的准确率Accuracy
eval_metrics_1 = mx.metric.Accuracy()
#2、分类的综合评估指标F1. This F1 score only supports binary classification
eval_metrics_2 = mx.metric.F1()
eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in [eval_metrics_1, eval_metrics_2]:

eval_metrics.update(labels = labels, preds = predicts)
print eval_metrics.get() #(['accuracy', 'f1'], [0.6666666666666666, 0.8])

#3、平均绝对误差MAE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
mean_absolute_error = mx.metric.MAE()
mean_absolute_error.update(labels = labels, preds = predicts)
print mean_absolute_error.get() #('mae', 0.5)

#4、均方差MSE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
mean_squared_error = mx.metric.MSE()
mean_squared_error.update(labels = labels, preds = predicts)
print mean_squared_error.get()#('mse', 0.375)

#5、标准差RMSE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
root_mean_squared_error = mx.metric.RMSE()
root_mean_squared_error.update(labels = labels, preds = predicts)
print root_mean_squared_error.get()#('rmse', 0.61237245798110962)

#6、交叉熵CrossEntropy
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]
ce = mx.metric.CrossEntropy()
ce.update(labels, predicts)#('cross-entropy', 0.57159948348999023)
print ce.get()

#7、前K项指标，top k指标。k越大，值越大,因为包含的可能性更高。
np.random.seed(999)
top_k = 3 #前3项指标
labels = [mx.nd.array([2, 6, 9, 2, 3, 4, 7, 8, 9, 6])]
predicts = [mx.nd.array(np.random.rand(10, 10))]
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print acc.get() # ('top_k_accuracy_3', 0.3)

top_k = 5 #前5项指标
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print acc.get() # ('top_k_accuracy_5', 0.6)

'''
8、Perplexity指标。
Perplexity is a measurement of how well a probability distribution or model predicts a sample.
A low perplexity indicates the model is good at predicting the sample.

'''
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]
perp = mx.metric.Perplexity(ignore_label=None)
perp.update(labels, predicts)
print perp.get() #('Perplexity', 1.7710976285155853)```

q

### qinhui99

选自AWS Machine Learning Blog 　　作者：Lai Wei、Kalyanee Chendke、Aaron Markham、Sandeep Krishnamurthy 　　机器之心编译 　　参与：路、王淑婷 　　 　　今日 AWS 发布博客宣布 ...

2018/05/22
0
0
MXNet 宣布支持 Keras 2，可更加方便快捷地实现 CNN 及 RNN 分布式训练

2018/05/23
0
0
MXNet/Gluon 中网络和参数的存取方式

Gluon是MXNet的高层封装，网络设计简单易用，与Keras类似。随着深度学习技术的普及，类似于Gluon这种，高层封装的深度学习框架，被越来越多的开发者接受和使用。 在开发深度学习算法时，必然...

SpikeKing
2018/05/29
0
0
windows下编译mxnet并使用C++训练模型

2018/05/29
0
0

2018/07/13
0
0

BobwithB
26分钟前
2
0
java内存模型

ls_cherish
30分钟前
2
0

5
0
js中实现页面跳转（返回前一页、后一页）

5
0
JAVA 利用时间戳来判断TOKEN是否过期

import java.time.Instant;import java.time.LocalDateTime;import java.time.ZoneId;import java.time.ZoneOffset;import java.time.format.DateTimeFormatter;/** * @descri......

huangkejie

4
0