Spark 机器学习实践 :Iris数据集的分类

原创
2016/04/12 12:10
阅读数 4.3K

今天试用了一下Spark的机器学习,体验如下:

第一步,导入数据

我们使用Iris数据集,做一个分类,首先要把csv文件导入。这里用到了spark的csv包,不明白为什么这么常见的功能不是内置的,还需要额外加载。

--packages com.databricks:spark-csv_2.11:1.4.0

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
df = sqlContext.read.format('com.databricks.spark.csv')
    .options(header='true', inferschema='true')
    .load('iris.csv')
# Displays the content of the DataFrame to stdout
df.show()

结果如下:

+-----+------------+-----------+------------+-----------+-------+
|rowid|Sepal.Length|Sepal.Width|Petal.Length|Petal.Width|Species|
+-----+------------+-----------+------------+-----------+-------+
|    1|         5.1|        3.5|         1.4|        0.2| setosa|
|    2|         4.9|        3.0|         1.4|        0.2| setosa|
|    3|         4.7|        3.2|         1.3|        0.2| setosa|
|    4|         4.6|        3.1|         1.5|        0.2| setosa|
|    5|         5.0|        3.6|         1.4|        0.2| setosa|
|    6|         5.4|        3.9|         1.7|        0.4| setosa|
|    7|         4.6|        3.4|         1.4|        0.3| setosa|
|    8|         5.0|        3.4|         1.5|        0.2| setosa|
|    9|         4.4|        2.9|         1.4|        0.2| setosa|
|   10|         4.9|        3.1|         1.5|        0.1| setosa|
|   11|         5.4|        3.7|         1.5|        0.2| setosa|
|   12|         4.8|        3.4|         1.6|        0.2| setosa|
|   13|         4.8|        3.0|         1.4|        0.1| setosa|
|   14|         4.3|        3.0|         1.1|        0.1| setosa|
|   15|         5.8|        4.0|         1.2|        0.2| setosa|
|   16|         5.7|        4.4|         1.5|        0.4| setosa|
|   17|         5.4|        3.9|         1.3|        0.4| setosa|
|   18|         5.1|        3.5|         1.4|        0.3| setosa|
|   19|         5.7|        3.8|         1.7|        0.3| setosa|
|   20|         5.1|        3.8|         1.5|        0.3| setosa|
+-----+------------+-----------+------------+-----------+-------+
only showing top 20 rows

第二步,提取特征

Spark要求把分类的标签(label)转换成数值进行计算,这一点没有scklearn方便。Spark提供了StringIndexer功能,可以把字符串转换为索引值。

from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol="Species", outputCol="categoryIndex")
indexed = indexer.fit(df).transform(df)
indexed.show()

提取后,categoryIndex这一列里面就是Species的索引值。

+-----+------------+-----------+------------+-----------+-------+-------------+
|rowid|Sepal.Length|Sepal.Width|Petal.Length|Petal.Width|Species|categoryIndex|
+-----+------------+-----------+------------+-----------+-------+-------------+
|    1|         5.1|        3.5|         1.4|        0.2| setosa|          2.0|
|    2|         4.9|        3.0|         1.4|        0.2| setosa|          2.0|
|    3|         4.7|        3.2|         1.3|        0.2| setosa|          2.0|
|    4|         4.6|        3.1|         1.5|        0.2| setosa|          2.0|
|    5|         5.0|        3.6|         1.4|        0.2| setosa|          2.0|
|    6|         5.4|        3.9|         1.7|        0.4| setosa|          2.0|
|    7|         4.6|        3.4|         1.4|        0.3| setosa|          2.0|
|    8|         5.0|        3.4|         1.5|        0.2| setosa|          2.0|
|    9|         4.4|        2.9|         1.4|        0.2| setosa|          2.0|
|   10|         4.9|        3.1|         1.5|        0.1| setosa|          2.0|
|   11|         5.4|        3.7|         1.5|        0.2| setosa|          2.0|
|   12|         4.8|        3.4|         1.6|        0.2| setosa|          2.0|
|   13|         4.8|        3.0|         1.4|        0.1| setosa|          2.0|
|   14|         4.3|        3.0|         1.1|        0.1| setosa|          2.0|
|   15|         5.8|        4.0|         1.2|        0.2| setosa|          2.0|
|   16|         5.7|        4.4|         1.5|        0.4| setosa|          2.0|
|   17|         5.4|        3.9|         1.3|        0.4| setosa|          2.0|
|   18|         5.1|        3.5|         1.4|        0.3| setosa|          2.0|
|   19|         5.7|        3.8|         1.7|        0.3| setosa|          2.0|
|   20|         5.1|        3.8|         1.5|        0.3| setosa|          2.0|
+-----+------------+-----------+------------+-----------+-------+-------------+
only showing top 20 rows

第三步,模型训练和验证:

from pyspark.sql import Row
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import NaiveBayes

# Load and parse the data
def parseRow(row):
    return Row(label=row["categoryIndex"],
               features=Vectors.dense([row["Sepal.Length"],
                   row["Sepal.Width"],
                   row["Petal.Length"],
                   row["Petal.Width"]]))

## Must convert to dataframe after mapping
parsedData = indexed.map(parseRow).toDF()
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
model = nb.fit(parsedData)

predict_data = model.transform(parsedData)
traing_err = predict_data.filter(predict_data['label'] != predict_data['prediction']).count() 
total = predict_data.count()
print traing_err, total, float(traing_err)/total

结果如下,在150个样本的训练集上,有7个预测错误:

7 150 0.0466666666667

这里要注意几点:

  • Spark有两组机器学习的接口pyspark.ml和pyspark.mllib, 前一个是1.3推出的,比较新,功能也更丰富,后一个是0.9版本推出的,功能少一些。这两组API是不兼容的,你可以选一组来使用。

  • 新的接口要求数据集的类型是dataframe

这里把试用mllib的样例也放出来,大家可以比较一下。

from pyspark.mllib.classification import NaiveBayes
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.linalg import Vectors

# Load and parse the data
def parseRow(row):
    return LabeledPoint(row["categoryIndex"],
               Vectors.dense([row["Sepal.Length"],row["Sepal.Width"],row["Petal.Length"],row["Petal.Width"]]))

## Must convert to dataframe after mapping
parsedData = indexed.map(parseRow)
nb = NaiveBayes()
model = nb.train(parsedData)

labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count())
print("Training Error = " + str(trainErr))

结果和使用ml的是一样的:

Training Error = 0.0466666666667

然后我试用了SVM,代码如下:

from pyspark.mllib.classification import SVMWithSGD
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.linalg import Vectors

# Load and parse the data
def parseRow(row):
    return LabeledPoint(row["categoryIndex"],
               Vectors.dense([row["Sepal.Length"],row["Sepal.Width"],row["Petal.Length"],row["Petal.Width"]]))

## Must convert to dataframe after mapping
parsedData = indexed.map(parseRow)
nb = SVMWithSGD()
model = nb.train(parsedData, iterations=10)

labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count())
print("Training Error = " + str(trainErr))

结果出错了:

Py4JJavaError: An error occurred while calling o3397.trainSVMModelWithSGD.
: org.apache.spark.SparkException: Input validation failed.
	at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
	at org.apache.spark.mllib.api.python.PythonMLLibAPI.trainRegressionModel(PythonMLLibAPI.scala:94)
	at org.apache.spark.mllib.api.python.PythonMLLibAPI.trainSVMModelWithSGD(PythonMLLibAPI.scala:233)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:497)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:381)
	at py4j.Gateway.invoke(Gateway.java:259)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:209)
	at java.lang.Thread.run(Thread.java:745)

从这个结果完全看不出任何端倪。后来查找到了后台日志发现了原因:

ERROR DataValidators: Classification labels should be 0 or 1. Found 50 invalid labels

原来Spark实现的SVM方法只能支持二分类,不支持大于二的分类。这个有点坑呀,scklearn好像是支持的。虽然SVM理论是基于二元分类的,但是有办法扩展。

最后分享一个我在提取分类索引的时候的一个坑,因为觉得字符串映射为数值本身逻辑比较简单,我就自己实现了一个,然后去做map。

## ??? does not work
labels = dict()
def get_label(s):
    if labels.get(s) is None:
        print s
        l = len(labels)
        labels[s] = l   
    return labels.get(s)

# Load and parse the data
def parsePoint(row):
    return LabeledPoint(get_label(row["Species"]), [row["Sepal.Length"],row["Sepal.Width"]])

parsedData = df.map(parsePoint)

然而这样做是错的,因为传入map的labels是immutable的,在map方法中是无法修改labels的值的。这样才能保证在分布式运行是的无状态和并行。大家以后用的时候要小心。


展开阅读全文
加载中

作者的其它热门文章

打赏
1
1 收藏
分享
打赏
0 评论
1 收藏
1
分享
返回顶部
顶部