文档章节

Spark机器学习工具链-MLflow使用教程

openthings
 openthings
发布于 2018/06/07 09:42
字数 1466
阅读 2155
收藏 1

Spark机器学习工具链-MLflow使用教程

参考:

什么是我们构建的?

在本教程中,我们将演示一个案例,展示数据科学家使用MLFlow端到端地构建一个线性回归模型。如何使用MLflow打包代码,其中代码训练该模型以一种可重用和重复生产的模型格式保存。最后,使用MLflow创建简单的 HTTP server,可以用来进行预测。

我们使用一个数据集来预测酒类质量,基于酒的量化指标如“fixed acidity”, “pH”, “residual sugar”, 等等。数据集来自于 UCI’s machine learning repository[Ref]

你首先需要?

本教程中,我们使用MLflow, conda, 和位于example/tutorial的示范代码,在 MLflow repository。下载相关代码,如下:

git clone https://github.com/databricks/mlflow

训练模型

要做的第一件事是训练一个线性回归模型,有两个hyperparameters: alpha 和 l1_ratio。

使用的代码位于 example/tutorial/train.py,如下

# Read the wine-quality csv file (make sure you're running this from the root of MLflow!)
wine_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "wine-quality.csv")
data = pd.read_csv(wine_path)

# Split the data into training and test sets. (0.75, 0.25) split.
train, test = train_test_split(data)

# The predicted column is "quality" which is a scalar from [3, 9]
train_x = train.drop(["quality"], axis=1)
test_x = test.drop(["quality"], axis=1)
train_y = train[["quality"]]
test_y = test[["quality"]]

alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

with mlflow.start_run():
    lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
    lr.fit(train_x, train_y)

    predicted_qualities = lr.predict(test_x)

    (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)

    print("Elasticnet model (alpha=%f, l1_ratio=%f):" % (alpha, l1_ratio))
    print("  RMSE: %s" % rmse)
    print("  MAE: %s" % mae)
    print("  R2: %s" % r2)

    mlflow.log_param("alpha", alpha)
    mlflow.log_param("l1_ratio", l1_ratio)
    mlflow.log_metric("rmse", rmse)
    mlflow.log_metric("r2", r2)
    mlflow.log_metric("mae", mae)

    mlflow.sklearn.log_model(lr, "model")

在这里,我们使用pandas、numpy和 sklearn APIs 创建简单的机器学习模型。除此之外,我们使用 MLflow tracking APIs记录每一次训练的信息,如 hyperparameters alpha 和 l1_ratio 用于训练的度量,如 root mean square error,用于评估该模型。另外,我们序列化该模型model,以MLflow可以部署的格式保存。

运行代码:

python example/tutorial/train.py

试验其他的 alpha 和 l1_ratio,通过将其作为参数传入train.py,如下:

python example/tutorial/train.py <alpha> <l1_ratio>

运行后,MLflow 记录了相关信息,在目录 mlruns中。

比较模型

下一步,我们使用 MLflow UI 来比较刚才产生的模型。运行mlflow ui在同样的工作目录(包含 mlruns),在浏览器打开 http://localhost:5000

此页面中,可以看到所产生的度量指标,如下:

_images/tutorial-compare.png

从此页面可以看到,较低的 alpha 更适合我们的模型。我们可以使用搜索快速过滤出模型。例如,查询 metrics.rmse < 0.8 将返回所有 root mean squared error 小于 0.8的。更复杂的操作,可以下载 CSV的表格,并使用喜欢的软件来分析。

打包训练代码

现在,我们有了编写好的训练代码,希望将其打包从而让其他的数据科学家可以容易地重用这个模型,或者将其放到远程服务器运行。为了打包,我们使用 MLflow Projects conventions指定代码的依赖和入口点。在 example/tutorial/MLproject 文件中,我们指定project的依赖在 Conda environment file ,名为 conda.yaml, 我们的这个项目有一个入口点,接受两个参数:alpha 和 l1_ratio。如下:

# example/tutorial/MLproject

name: tutorial

conda_env: conda.yaml

entry_points:
  main:
    parameters:
      alpha: float
      l1_ratio: {type: float, default: 0.1}
    command: "python train.py {alpha} {l1_ratio}"
# example/tutorial/conda.yaml

name: tutorial
channels:
  - defaults
dependencies:
  - numpy=1.14.3
  - pandas=0.22.0
  - scikit-learn=0.19.1
  - pip:
    - mlflow

为了运行该项目,简单地调用 mlflow run example/tutorial -P alpha=0.42。运行命令后, MLflow将在新的conda环境中运行训练代码,并且使用在 conda.yaml中指定的依赖软件和模块。

Projects can also be run directly from Github if the repository has a MLproject file in the root. We’ve duplicated this tutorial to the https://github.com/databricks/mlflow-example repository which can be run with mlflow run git@github.com:databricks/mlflow-example.git -P alpha=0.42.

服务模型

现在,我们将 MLproject打包并且识别出最好的model,是时候使用 MLflow Models来部署这个模型了。一个MLflow Model是机器学习模型封装的标准格式,可以用于后续一系列的处理工具。例如,通过real-time serving提供 REST API 或在Spark上的批处理智能推理。

在我们的训练代码中,训练出线性回归模型后,我们启动 MLflow 中的一个函数,保存模型为运行部件。

mlflow.sklearn.log_model(lr, "model")

为了浏览这个 artifact,我们再次使用UI。点击页面中的列表,如下。

_images/tutorial-artifact.png

在下面,我们看到对 mlflow.sklearn.log_model 的调用产生了两个文件,在/Users/mlflow/mlflow-prototype/mlruns/0/7c1a0d5c42844dcdb8f5191146925174/artifacts/model。第一个 MLmodel 是元数据文件,告诉MLflow如何载入模型。第二个文件 model.pkl 是我们训练的线性回归模型的序列化。

在这个例子中,我们演示使用 MLmodel 格式通过MLflow部署一个本地的REST server,用于进行预测。

部署上服务器,运行:

mlflow sklearn serve /Users/mlflow/mlflow-prototype/mlruns/0/7c1a0d5c42844dcdb8f5191146925174/artifacts/model -p 1234

注意:

该版本Python必须与运行mlflow sklearn的一致。否则,可能会报错: UnicodeDecodeError: 'ascii' codec can't decode byte 0x9f in position 1: ordinal not in range(128) or raise ValueError, "unsupported pickle protocol: %d".

预测服务调用,运行:

curl -X POST -H "Content-Type:application/json" --data '[{"fixed acidity": 6.2, "volatile acidity": 0.66, "citric acid": 0.48, "residual sugar": 1.2, "chlorides": 0.029, "free sulfur dioxide": 29, "total sulfur dioxide": 75, "density": 0.98, "pH": 3.33, "sulphates": 0.39, "alcohol": 12.8}]' http://127.0.0.1:1234/invocations

# RESPONSE
# {"predictions": [6.379428821398614]}

更多资源

感谢完成这个教程,更多的参考 MLflow TrackingMLflow ProjectsMLflow Models等内容。

© 著作权归作者所有

openthings
粉丝 324
博文 1140
码字总数 689435
作品 1
东城
架构师
私信 提问
AirFlow/NiFi/MLFlow/KubeFlow进展

大数据分析中,进行流程化的批处理是必不可少的。传统的大数据处理大部分是基于关系数据库系统,难以实现大规模扩展;主流的基于Hadoop/Spark体系总体性能较强,但使用复杂、扩展能力弱。大数...

openthings
06/21
437
0
钉钉群直播【MLFlow和spark在机器学习方面的进展、Project Hydrogen和spark在深度学习方面的进展 】

直播主题: 【MLFlow和spark在机器学习方面的进展、Project Hydrogen和spark在深度学习方面的进展 】 时间: 6月19日 19:30-20:30 分享嘉宾: 江宇,阿里云EMR技术专家。从事Hadoop内核开发...

EMR
06/17
0
0
机器学习管理平台 MLFlow

最近工作很忙,博客一直都没有更新。抽时间给大家介绍一下Databrick开源的机器学习管理平台-MLFlow。 谈起Databrick,相信即使是不熟悉机器学习和大数据的工程湿们也都有所了解,它由Spark的...

naughty
2018/07/21
2.7K
1
【短文】Spark危机与机遇杂谈

MLFlow 昨天发了一篇文章Spark团队新作MLFlow 解决了什么问题 描述了我对MLFlow的一些看法,现在想来,Spark团队是非常聪明的,AI同学都有自己的社区,自己的生态,Spark则是在工程研发群体具...

祝威廉
2018/06/07
0
0
Apache Spark 技术团队开源机器学习平台 MLflow

近日,来自 Databricks 的 Matei Zaharia 宣布推出开源机器学习平台 MLflow 。Matei Zaharia 是 Apache Spark 和 Apache Mesos 的核心作者,也是 Databrick 的首席技术专家。Databrick 是由 ...

王练
2018/06/08
2.5K
0

没有更多内容

加载失败,请刷新页面

加载更多

MongoDB系列-在复制集(replication)以及分片(Shard)中创建索引

关注我,可以获取最新知识、经典面试题以及微服务技术分享   在使用MongoDB时,在创建索引会涉及到在复制集(replication)以及分片(Shard)中创建,为了最大限度地减少构建索引的影响,在副本...

ccww_
32分钟前
31
0
SAP HANA数据库multi container模式JDBC链接connection refused

报错如下信息 com.sap.db.jdbc.exceptions.JDBCDriverException: SAP DBTech JDBC: Cannot connect to jdbc:sap://xxx.xxx.xxx.xxx:30015 [Cannot connect to host xxx.xxx.xxx.xxx:30015 [C......

flash胜龙
57分钟前
53
0
c++ 虚基类

c++ 虚基类 p556

天王盖地虎626
今天
102
0
k8s删除Terminating状态的命名空间

背景: 我们都知道在k8s中namespace有两种常见的状态,即Active和Terminating状态,其中后者一般会比较少见,只有当对应的命名空间下还存在运行的资源,但是该命名空间被删除时才会出现所谓的...

Andy-xu
今天
106
0
seata源码阅读笔记

seata源码阅读笔记 本文没有seata的使用方法,怎么使用seata可以参考官方示例,详细的很。 本文基于v0.8.0版本,本文没贴代码。 seata中的三个重要部分: TC:事务协调器,维护全局事务和分支...

东都大狼狗
今天
62
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部