文档章节

MXNet Scala 学习笔记 二 ---- 创建新的 Operator

Ldpe2G
 Ldpe2G
发布于 2017/09/08 21:06
字数 1633
阅读 68
收藏 0
点赞 0
评论 0

MXNet Scala包中创建新的操作子

用现有操作子组合

    在MXNet中创建新的操作子有多种方式。第一种最简单的方法就是在前端(比如Python、Scala)
采用现有的操作子来组合,比如实现 Selu  激活函数。简单示例代码如下:

def selu(x: Symbol): Symbol = {
  val alpha = 1.6732632423543772848170429916717f
  val scale = 1.0507009873554804934193349852946f
  val condition = x >= 0f
  val y = Symbol.LeakyReLU()()(Map("data" -> x, "act_type" -> "elu", "slope" -> alpha))
  scale * Symbol.where()()(Map("condition" -> condition, "x" -> x, "y" -> y))
}

更详细的代码可参考:SelfNormNets。    或者实现L1损失函数,L1_Loss

def getAbsLoss(): Symbol = {
  val origin = Symbol.Variable("origin")
  val rec = Symbol.Variable("rec")
  val diff = origin - rec
  val abs = Symbol.abs()()(Map("data" -> diff))
  val mean = Symbol.mean()()(Map("data" -> abs))
  Symbol.MakeLoss()()(Map("data" -> mean))
}

    这种方式比较简单,而且如果你对现有的操作子基本熟悉的话,那么一般的需求基本都能满足。

CustomOp接口

   第二种是相对难度大一点的,比较接近第三种,有时候可能单纯的操作子组合满足不了需求,那就可以

采用继承CustomOp接口的方式,下面举Softmax的例子来解说:

class Softmax(_param: Map[String, String]) extends CustomOp {

  override def forward(sTrain: Boolean, req: Array[String],
    inData: Array[NDArray], outData: Array[NDArray], aux: Array[NDArray]): Unit = {
    val xShape = inData(0).shape
    val x = inData(0).toArray.grouped(xShape(1)).toArray
    val yArr = x.map { it =>
      val max = it.max
      val tmp = it.map(e => Math.exp(e.toDouble - max).toFloat)
      val sum = tmp.sum
      tmp.map(_ / sum)
    }.flatten
    val y = NDArray.empty(xShape, outData(0).context)
    y.set(yArr)
    this.assign(outData(0), req(0), y)
    y.dispose()
  }

  override def backward(req: Array[String], outGrad: Array[NDArray],
    inData: Array[NDArray], outData: Array[NDArray],
    inGrad: Array[NDArray], aux: Array[NDArray]): Unit = {
    val l = inData(1).toArray.map(_.toInt)
    val oShape = outData(0).shape
    val yArr = outData(0).toArray.grouped(oShape(1)).toArray
    l.indices.foreach { i =>
      yArr(i)(l(i)) -= 1.0f
    }
    val y = NDArray.empty(oShape, inGrad(0).context)
    y.set(yArr.flatten)
    this.assign(inGrad(0), req(0), y)
    y.dispose()
  }
}

    首先继承CustomOp抽象类,然后实现forward和backward函数,构造函数参数"_param"可以

当做是能够提取用户在构造Symbol时传入的参数具体例子可以参考CustomOpWithRtc

forward和backward的具体实现大家看源码就清楚了,变量名也很清晰,就是softmax的简化版。

需要注意的是,在算出结果之后,比如forward的y和backward的y之后,这时候需要调用内置的

assign函数把结果赋值给相应的outData或者inGrad。这里的req有几种"write"、"add"、"inplace"和

"null":

def assign(dst: NDArray, req: String, src: NDArray): Unit = req match {
  case "write" | "inplace" => dst.set(src)
  case "add" => dst += src
  case "null" => {}
}

在赋值完之后,因为y是临时申请的NDArray,所以在函数返回前需要调用dispose函数释放内存。

这是在使用Scala包的时候需要注意的地方。而inData数组里面的NDArray对应数据和标签的顺序

是接下来要说的。

    实现好CustomOp之后,需要再继承CustomOpProp抽象类,主要是定义自定义操作子的一些

比如输入输出的格式和相关信息等等。

class SoftmaxProp(needTopGrad: Boolean = false)
 extends CustomOpProp(needTopGrad) {

  override def listArguments(): Array[String] = Array("data", "label")

  override def listOutputs(): Array[String] = Array("output")

  override def inferShape(inShape: Array[Shape]):
    (Array[Shape], Array[Shape], Array[Shape]) = {
    val dataShape = inShape(0)
    val labelShape = Shape(dataShape(0))
    val outputShape = dataShape
    (Array(dataShape, labelShape), Array(outputShape), null)
  }

  override def inferType(inType: Array[DType]):
    (Array[DType], Array[DType], Array[DType]) = {
    (inType, inType.take(1), null)
  }

  override def createOperator(ctx: String, inShapes: Array[Array[Int]],
    inDtypes: Array[Int]): CustomOp = new Softmax(this.kwargs)
}

Operator.register("softmax", new SoftmaxProp)


//定义网络构造
val data = Symbol.Variable("data")
val label = Symbol.Variable("label")
val fc1 = Symbol.FullyConnected("fc1")()(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation("relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected("fc2")()(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation("relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected("fc3")()(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.Custom("softmax")()(Map("data" -> fc3, "label" -> label, "op_type" -> "softmax"))

    needTopGrad参数表示在backward的时候是否需要来自顶层的梯度,因为softmax是损失层,

一般放在网络最后,所以是不需要顶层的梯度,所以这里默认为false。然后listArguments函数

和listOutputs函数是定义该层操作子的输入与输出。这里listArugments的顺序与forward中的inData

顺序对应,同时listArugments还可以定义该层的输入还可以定义该层的参数,比如卷积层需要权值,

也需要在这里定义。inferShape就是根据输入的形状来推导网络的输出与参数形状,这些需要自己实现。

inferType的实现是可选的,支持多种数据类型DType。最后再实现createOperator函数,kwargs成员

变量存储了用户在构造Symbol时传入的参数,比如上面代码定义网络构造的代码,最后通过调用

Symbol.Custom函数然后根据你注册的操作子的名称就是"op_type"参数找到你自己实现的操作子。

用户还可以传任意的自定义参数,string->string,具体例子可以参考CenterLossCustomOpWithRtc

CustomOp使用注意事项

    在使用CustomOp创建操作子的时候需要注意的是,因为这种方式实现的操作子不是用已有的

操作子组合,而是用前端自己实现的,所以在保存训练模型的时候,尽管保存的模型定义的json文件

中会包含你的操作子,但是这个保存的模型直接给其他的用户是用不了的,因为在载入的时候会

报错,找不到你自己定义操作子,你需要把源码也给其他人,这是需要注意的地方。

    而对Scala包的CustomOp内部实现感兴趣的读者可以去看看源码:CustomOp实现源码

JNI部分源码。对于实现CustomOp这个功能,可以说是我参与MXNet项目以来提交的所有的pr中

难度最大之一了,因为对JNI不是很熟,还记得当时是一边google一边debug弄了差不多两周才

搞好这个功能。不过这个过程中也算是学到了不少东西。

C++

    最后一种就是采用C++来实现了,也是难度最大的,其实CustomOp是这种方式的简化版,

CustomOp其实在后端也对应也有C++的实现,相当于在后端有一个CustomOp来调用

前端定义的CustomOp,具体源码:

https://github.com/apache/incubator-mxnet/tree/master/src/operator/custom

而Scala包Symbol类的操作子定义是采用macro的方式生成的,自动与C++这边同步,所以

只要你按照文档在C++端定义好新的操作子,那么在Scala包这边就能用。

怎么用C++自动以新的操作子这里就不详细展开了具体可以参考文档源码

 

    

 

 

© 著作权归作者所有

共有 人打赏支持
Ldpe2G
粉丝 17
博文 18
码字总数 29621
作品 0
广州
程序员
MXNet/Gluon 中网络和参数的存取方式

Gluon是MXNet的高层封装,网络设计简单易用,与Keras类似。随着深度学习技术的普及,类似于Gluon这种,高层封装的深度学习框架,被越来越多的开发者接受和使用。 在开发深度学习算法时,必然...

SpikeKing ⋅ 05/29 ⋅ 0

业界 | MXNet开放支持Keras,高效实现CNN与RNN的分布式训练

  选自AWS Machine Learning Blog   作者:Lai Wei、Kalyanee Chendke、Aaron Markham、Sandeep Krishnamurthy   机器之心编译   参与:路、王淑婷      今日 AWS 发布博客宣布 ...

机器之心 ⋅ 05/22 ⋅ 0

MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

雷锋网(公众号:雷锋网) AI 研习社按,近期,AWS 表示 MXNet 支持 Keras 2,开发者可以使用 Keras-MXNet 更加方便快捷地实现 CNN 及 RNN 分布式训练。AI 研习社将 AWS 官方博文编译如下。 Ke...

孔令双 ⋅ 05/23 ⋅ 0

windows下编译mxnet并使用C++接口开发

大多数情况下,mxnet都使用python接口进行机器学习程序的编写,方便快捷,但是有的时候,需要把机器学习训练和识别的程序部署到生产版的程序中去,比如游戏或者云服务,此时采用C++等高级语言...

u012234115 ⋅ 05/29 ⋅ 0

资源 | DMLC团队发布GluonCV和GluonNLP:两种简单易用的DL工具箱

  选自Gluon   机器之心编译   参与:思源、李亚洲      近日,DMLC 发布了简单易用的深度学习工具箱 GluonCV 和 GluonNLP,它们分别为计算机视觉和自然语言处理提供了顶级的算法实...

机器之心 ⋅ 04/26 ⋅ 0

资源 | 从VGG到ResNet,你想要的MXNet预训练模型轻松学

  选自AWS Blog   作者:Julien Simon   机器之心编译   参与:Pedro、路      本文介绍了如何利用 Apache MXNet 预训练出的多个模型。每个模型在特定图像上的表现略有不同,训练...

机器之心 ⋅ 05/20 ⋅ 0

亚马逊AWS首席科学家:从图像理解到语音识别,我们是如何研究和量化机器学习的

亚马逊一直致力于寻求机器学习多域模型的解决方案,以及多领域的应用如何能够在云上进行计算。 在今日于北京召开的《麻省理工科技评论》新兴科技峰会EmTech China上,作为亚马逊旗下最赚钱云...

行者武松 ⋅ 04/12 ⋅ 0

GPU云服务器深度学习性能模型初探

1 背景   得益于GPU强大的计算能力,深度学习近年来在图像处理、语音识别、自然语言处理等领域取得了重大突破,GPU服务器几乎成了深度学习加速的标配。   阿里云GPU云服务器在公有云上提...

撷峰 ⋅ 04/23 ⋅ 0

CV 深度学习工具包 GluonCV 开源,助你实现重要论文复现

雷锋网(公众号:雷锋网) AI 研习社按,日前,MXNet 作者李沐在 Apache MXNet 中文公众号发文,宣布开源计算机视觉深度学习工具包 GluonCV,在文中,他详细介绍了做 GluonCV 工具包的原因,以...

思颖 ⋅ 04/28 ⋅ 0

Pycharm远程调试之Docker debug

关于连接Linux Docker 我们以前使用的是Docker Toolbox,在配置的时候pycharm也是自动填充的是关于Docker Toolbox的信息,看来是对 Docker Toolbox的支持是比较好的。 我们需要了解以下几件事...

JungleKing ⋅ 06/13 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

spring Email

使用spring发Email其实就是使用spring自己封装携带的一个javamail.JavaMailSenderImpl类而已。这个类可以当一个普通的java对象来使用,也可以通过把它配置变成spring Bean的方式然后注入使用...

BobwithB ⋅ 19分钟前 ⋅ 0

spark 整理的一些知识

Spark 知识点 请描述spark RDD原理与特征? RDD全称是resilient distributed dataset(具有弹性的分布式数据集)。一个RDD仅仅是一个分布式的元素集合。在Spark中,所有工作都表示为创建新的...

tuoleisi77 ⋅ 23分钟前 ⋅ 0

思考

时间一天天过感觉自己有在成长吗?最怕的是时光匆匆而过,自己没有收获!下面总结下最近自己的思考。 认识自己 认识另一个自己,人们常说要虚心听取别人意见和建议。然而人往往是很难做到的,...

hello_hp ⋅ 23分钟前 ⋅ 0

IT行业的变革就像世界杯德国对战墨西哥一样难以预测[图]

最近在观看世界杯,尤其是昨天的比赛,上一届卫冕冠军德国队居然0:1告负墨西哥,这创造了历史,首先是墨西哥从来没赢过德国队,其次是德国队36年来首站没输过,再差也是打平,而这次,德国队...

原创小博客 ⋅ 42分钟前 ⋅ 0

解决CentOS6、7,/etc/sysconfig/下没有iptables的问题

一、Centos 6版本解决办法: 1.任意运行一条iptables防火墙规则配置命令: iptables -P OUTPUT ACCEPT 2.对iptables服务进行保存: service iptables save 3.重启iptables服务: service ...

寰宇01 ⋅ 52分钟前 ⋅ 2

数据库备份和恢复

备份:mysqldump -u root -p 数据库>磁盘路径 恢复:mysql -u root -p 数据库<sql脚本的磁盘路径

anlve ⋅ 今天 ⋅ 0

发生了什么?Linus 又发怒了?

在一个 Linux 内核 4.18-rc1 的 Pull Request 中,开发者 Andy Shevchenko 表示其在对设备属性框架进行更新时,移除了 union 别名,这引发了 Linus 的暴怒。 这一次 Linus Torvalds 发怒的原...

问题终结者 ⋅ 今天 ⋅ 0

在树莓派上搭建一个maven仓库

在树莓派上搭建一个maven仓库 20180618 lambo init 项目说明 家里有台树莓派性能太慢。想搭建一个maven私服, 使用nexus或者 jfrog-artifactory 运行的够呛。怎么办呢,手写一个吧.所在这个...

林小宝 ⋅ 今天 ⋅ 0

Spring发展历程总结

转自与 https://www.cnblogs.com/RunForLove/p/4641672.html 目前很多公司的架构,从Struts2迁移到了SpringMVC。你有想过为什么不使用Servlet+JSP来构建Java web项目,而是采用SpringMVC呢?...

onedotdot ⋅ 今天 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部