文档章节

MXNet Scala 学习笔记 二 ---- 创建新的 Operator

Ldpe2G
 Ldpe2G
发布于 2017/09/08 21:06
字数 1633
阅读 89
收藏 0

MXNet Scala包中创建新的操作子

用现有操作子组合

    在MXNet中创建新的操作子有多种方式。第一种最简单的方法就是在前端(比如Python、Scala)
采用现有的操作子来组合,比如实现 Selu  激活函数。简单示例代码如下:

def selu(x: Symbol): Symbol = {
  val alpha = 1.6732632423543772848170429916717f
  val scale = 1.0507009873554804934193349852946f
  val condition = x >= 0f
  val y = Symbol.LeakyReLU()()(Map("data" -> x, "act_type" -> "elu", "slope" -> alpha))
  scale * Symbol.where()()(Map("condition" -> condition, "x" -> x, "y" -> y))
}

更详细的代码可参考:SelfNormNets。    或者实现L1损失函数,L1_Loss

def getAbsLoss(): Symbol = {
  val origin = Symbol.Variable("origin")
  val rec = Symbol.Variable("rec")
  val diff = origin - rec
  val abs = Symbol.abs()()(Map("data" -> diff))
  val mean = Symbol.mean()()(Map("data" -> abs))
  Symbol.MakeLoss()()(Map("data" -> mean))
}

    这种方式比较简单,而且如果你对现有的操作子基本熟悉的话,那么一般的需求基本都能满足。

CustomOp接口

   第二种是相对难度大一点的,比较接近第三种,有时候可能单纯的操作子组合满足不了需求,那就可以

采用继承CustomOp接口的方式,下面举Softmax的例子来解说:

class Softmax(_param: Map[String, String]) extends CustomOp {

  override def forward(sTrain: Boolean, req: Array[String],
    inData: Array[NDArray], outData: Array[NDArray], aux: Array[NDArray]): Unit = {
    val xShape = inData(0).shape
    val x = inData(0).toArray.grouped(xShape(1)).toArray
    val yArr = x.map { it =>
      val max = it.max
      val tmp = it.map(e => Math.exp(e.toDouble - max).toFloat)
      val sum = tmp.sum
      tmp.map(_ / sum)
    }.flatten
    val y = NDArray.empty(xShape, outData(0).context)
    y.set(yArr)
    this.assign(outData(0), req(0), y)
    y.dispose()
  }

  override def backward(req: Array[String], outGrad: Array[NDArray],
    inData: Array[NDArray], outData: Array[NDArray],
    inGrad: Array[NDArray], aux: Array[NDArray]): Unit = {
    val l = inData(1).toArray.map(_.toInt)
    val oShape = outData(0).shape
    val yArr = outData(0).toArray.grouped(oShape(1)).toArray
    l.indices.foreach { i =>
      yArr(i)(l(i)) -= 1.0f
    }
    val y = NDArray.empty(oShape, inGrad(0).context)
    y.set(yArr.flatten)
    this.assign(inGrad(0), req(0), y)
    y.dispose()
  }
}

    首先继承CustomOp抽象类,然后实现forward和backward函数,构造函数参数"_param"可以

当做是能够提取用户在构造Symbol时传入的参数具体例子可以参考CustomOpWithRtc

forward和backward的具体实现大家看源码就清楚了,变量名也很清晰,就是softmax的简化版。

需要注意的是,在算出结果之后,比如forward的y和backward的y之后,这时候需要调用内置的

assign函数把结果赋值给相应的outData或者inGrad。这里的req有几种"write"、"add"、"inplace"和

"null":

def assign(dst: NDArray, req: String, src: NDArray): Unit = req match {
  case "write" | "inplace" => dst.set(src)
  case "add" => dst += src
  case "null" => {}
}

在赋值完之后,因为y是临时申请的NDArray,所以在函数返回前需要调用dispose函数释放内存。

这是在使用Scala包的时候需要注意的地方。而inData数组里面的NDArray对应数据和标签的顺序

是接下来要说的。

    实现好CustomOp之后,需要再继承CustomOpProp抽象类,主要是定义自定义操作子的一些

比如输入输出的格式和相关信息等等。

class SoftmaxProp(needTopGrad: Boolean = false)
 extends CustomOpProp(needTopGrad) {

  override def listArguments(): Array[String] = Array("data", "label")

  override def listOutputs(): Array[String] = Array("output")

  override def inferShape(inShape: Array[Shape]):
    (Array[Shape], Array[Shape], Array[Shape]) = {
    val dataShape = inShape(0)
    val labelShape = Shape(dataShape(0))
    val outputShape = dataShape
    (Array(dataShape, labelShape), Array(outputShape), null)
  }

  override def inferType(inType: Array[DType]):
    (Array[DType], Array[DType], Array[DType]) = {
    (inType, inType.take(1), null)
  }

  override def createOperator(ctx: String, inShapes: Array[Array[Int]],
    inDtypes: Array[Int]): CustomOp = new Softmax(this.kwargs)
}

Operator.register("softmax", new SoftmaxProp)


//定义网络构造
val data = Symbol.Variable("data")
val label = Symbol.Variable("label")
val fc1 = Symbol.FullyConnected("fc1")()(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation("relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected("fc2")()(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation("relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected("fc3")()(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.Custom("softmax")()(Map("data" -> fc3, "label" -> label, "op_type" -> "softmax"))

    needTopGrad参数表示在backward的时候是否需要来自顶层的梯度,因为softmax是损失层,

一般放在网络最后,所以是不需要顶层的梯度,所以这里默认为false。然后listArguments函数

和listOutputs函数是定义该层操作子的输入与输出。这里listArugments的顺序与forward中的inData

顺序对应,同时listArugments还可以定义该层的输入还可以定义该层的参数,比如卷积层需要权值,

也需要在这里定义。inferShape就是根据输入的形状来推导网络的输出与参数形状,这些需要自己实现。

inferType的实现是可选的,支持多种数据类型DType。最后再实现createOperator函数,kwargs成员

变量存储了用户在构造Symbol时传入的参数,比如上面代码定义网络构造的代码,最后通过调用

Symbol.Custom函数然后根据你注册的操作子的名称就是"op_type"参数找到你自己实现的操作子。

用户还可以传任意的自定义参数,string->string,具体例子可以参考CenterLossCustomOpWithRtc

CustomOp使用注意事项

    在使用CustomOp创建操作子的时候需要注意的是,因为这种方式实现的操作子不是用已有的

操作子组合,而是用前端自己实现的,所以在保存训练模型的时候,尽管保存的模型定义的json文件

中会包含你的操作子,但是这个保存的模型直接给其他的用户是用不了的,因为在载入的时候会

报错,找不到你自己定义操作子,你需要把源码也给其他人,这是需要注意的地方。

    而对Scala包的CustomOp内部实现感兴趣的读者可以去看看源码:CustomOp实现源码

JNI部分源码。对于实现CustomOp这个功能,可以说是我参与MXNet项目以来提交的所有的pr中

难度最大之一了,因为对JNI不是很熟,还记得当时是一边google一边debug弄了差不多两周才

搞好这个功能。不过这个过程中也算是学到了不少东西。

C++

    最后一种就是采用C++来实现了,也是难度最大的,其实CustomOp是这种方式的简化版,

CustomOp其实在后端也对应也有C++的实现,相当于在后端有一个CustomOp来调用

前端定义的CustomOp,具体源码:

https://github.com/apache/incubator-mxnet/tree/master/src/operator/custom

而Scala包Symbol类的操作子定义是采用macro的方式生成的,自动与C++这边同步,所以

只要你按照文档在C++端定义好新的操作子,那么在Scala包这边就能用。

怎么用C++自动以新的操作子这里就不详细展开了具体可以参考文档源码

 

    

 

 

© 著作权归作者所有

共有 人打赏支持
Ldpe2G
粉丝 17
博文 18
码字总数 29621
作品 0
广州
程序员
运用 MXNet Scala API 接口进行图像分类

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题 Image Classification with MXNet Scala Inference API,作者为 Qing Lan, Roshani Nagmote 翻译 | 朱茵 整理 | 凡江 随...

雷锋字幕组
07/20
0
0
MXNet/Gluon 中网络和参数的存取方式

Gluon是MXNet的高层封装,网络设计简单易用,与Keras类似。随着深度学习技术的普及,类似于Gluon这种,高层封装的深度学习框架,被越来越多的开发者接受和使用。 在开发深度学习算法时,必然...

SpikeKing
05/29
0
0
云上深度学习实践(二)-云上MXNet实践

目录 云上深度学习实践(一)-GPU云服务器TensorFlow单机多卡训练性能实践 云上深度学习实践(二)-云上MXNet实践 1 MXNet 简介 1.1 MXNet特点 MXNet是一个全功能,灵活可编程和高扩展性的深...

撷峰
07/13
0
0
业界 | MXNet开放支持Keras,高效实现CNN与RNN的分布式训练

  选自AWS Machine Learning Blog   作者:Lai Wei、Kalyanee Chendke、Aaron Markham、Sandeep Krishnamurthy   机器之心编译   参与:路、王淑婷      今日 AWS 发布博客宣布 ...

机器之心
05/22
0
0
MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

雷锋网(公众号:雷锋网) AI 研习社按,近期,AWS 表示 MXNet 支持 Keras 2,开发者可以使用 Keras-MXNet 更加方便快捷地实现 CNN 及 RNN 分布式训练。AI 研习社将 AWS 官方博文编译如下。 Ke...

孔令双
05/23
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

python标准输入输出

input() 读取键盘输入 input() 函数从标准输入读入一行文本,默认的标准输入是键盘。 input 可以接收一个Python表达式作为输入,并将运算结果返回。 print()和format()输出 format()输出...

colinux
25分钟前
0
0
Python 核心编程 (全)

浅拷贝和深拷贝 1.浅拷贝:是对于一个对象的顶层拷贝,通俗的理解是:拷贝了引用,并没有拷贝内容。相当于把变量里面指向的一个地址给了另一个变量就是浅拷贝,而没有创建一个新的对象,如a...

代码打碟手
37分钟前
0
0
PHP 对象比数组省内存?错!数组比对象省内存?错!

刚刚一个群里有人引出了 PHP 数组和对象占用内存谁多谁少的问题。我想起之前我好像也测试过这个问题,和群里人说的对象比数组节省内存的结论相反,我得出的是数组比对象节省内存。 但今天,我...

宇润
54分钟前
1
0
memcached命令行及其用法

21.5 memcached命令行 创建数据 yum install -y telnet 利用telnet命令连接memcached数据库 telnet 127.0.0.1 11211 #写入数据 set key2 0 30 212STORED 这个是错误的示范,因为0 30 已经...

lyy549745
54分钟前
0
0
Maven私服

Maven私服 一、简介 当多人项目开发的时候,尤其聚合项目开发,项目和项目之间需要有依赖关系,通过maven私服,可以保存互相依赖的jar包,这样的话就可把多个项目整合到一起。 如下图: Inst...

星汉
57分钟前
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部