文档章节

学习笔记TF041:分布式并行

利炳根
 利炳根
发布于 2017/08/20 11:15
字数 2341
阅读 4
收藏 0
点赞 0
评论 0

TensorFlow分布式并行基于gRPC通信框架,一个master负责创建Session,多个worker负责执行计算图任务。

先创建TensorFlow Cluster对象,包含一组task(每个task一台独立机器),分布式执行TensorFlow计算图。一个Cluster切分多个job,一个job是一类特定任务(parameter server ps,worker),每个job可以包含多个task。每个task创建一个server,连接到Cluster,每个task执行在不同机器。也可以一台机器执行多个task(不同GPU)。tf.train.ClusterSpec初始化Cluster对象,初始化信息是Python dict,tf.train.ClusterSpec({"ps":["192.168.233.201:2222"],"worker":["192.168.233.202:2222","192.168.233.203:2222"]}),代表一个parameter server和两个worker,分别在三个不同机器上。每个task,定义自己身份,如server=tf.train.Server(cluster,job_name="ps",task_index=0),机器job定义ps第0台机器。程序中with tf.device("/job:worker/task:7"),限定Variable存放在哪个task或机器。

TensorFlow分布式模式:In-graph replication模型并行,模型计算图不同部分放在不同机器执行;Between-graph replication数据并行,每台机器相同计算图,计算不同batch数据。异步并行,每台机器独立计算梯度,计算完更新到parameter server,不等其他机器。同步并行,等所有机器都完成梯度计算,多个梯度合成统一更新模型参数。同步并行训练,loss下降速度更快,可达到最大精度更高,同步并行速度取决最慢机器,设备速度一致,效率较高。

TensorFlow实现包含1个paramter server、2个worker分布式并行训练程序,MNIST手写数据识别任务示例。写一个完整Python文件,在不同机器不同task执行。载入依赖库。

tf.app.flags定义标记,命令行执行TensorFlow设置参数。命令行指定参数被TensorFlow读取,转flags。设定数据储存目录data_dir默认/tmp/mnist-data,隐藏节点数默认100,训练最大步数train_steps默认1000000,batch_size默认100,学习速率默认0.01。

设定是否使用同步并行标记sync_replicas默认False,命令行执行时可设True开户同步并行。设定需要累计梯度个数更新模型值默认None,代表同步并行积累多少个batch梯度再进行参数更新,设None 为worker数量,所有worker完成一个batch训练后再更新模型参数。

定义ps地址,默认192.168.233.201:2222,根据集群实际情况配置。worker地址设置192.168.233.202:2222和192.168.233.203:2222.设置job_name和task_index FLAG。

flags.FLAGS直接命名FLAGS,简化使用。设置图片尺寸IMAGE_PIXELS 28。

编写程序主函数main,input_data.read_data_sets下载读取MNIST数据集,设置one_hot编码格式。检测命令行输入参数,确保job_name和task_index两个必备参数。显示job_name和task_index,ps、worker所有地址解析成列表ps_spec、worker_spec。

计算总共worker数量,tf.train.ClusterSpec生成TensorFlow Cluster对象,传入参数ps地址信息和worker地址信息。tf.train.Server创建当前机器server,连接Cluster。如当前节点是parameter server,不进行后续操作,server.join等待worker工作。

判断当前机器是否主节点,task_index是否0。定义当前机器worker_device,格式"job:worker/task:0/gpu:0"。多台机器,每台机器1块GPU,总共需要机器数量worker。如一台机器多GPU,一个task管理多个GPU或多个task分别管理。tf.train.replica_device_setter()设置worker资源,worker_device计算资源,ps_device存储模型参数资源。replica_device_setter将模型参数部署在独立ps服务器"/job:ps/cpu:0",训练操作部署在"/job:worker/task:0/gpu:0",本机GPU。创建记录全局训练步数变量global_step。

定义神经网络模型,tf.truncated_normal初始化权重,tf.zeros初始化偏置,创建输入 placeholder,tf.nn.xw_plus_b输入矩阵乘法、偏置操作,ReLU激活函数处理,得到第一个隐层输出hid。tf.nn.xw_plus_b、tf.nn.softmax对第一层输出hid处理,得到网络最终输出y。最后计算损失cross_entropy,定义优化器Adam。

判断是否设置同步训练模式sync_replicas,如果同步模型,先获取同步更新模型参数需要副本数replicas_to_aggregate;如果没有单独设置,worker数作默认值。tf.train.SyncReplicasOptimizer创建同步训练优化器,对原有优化器扩展。传入原有优化器及其他参数(replicas_to_aggregate、total_num_replicas、replica_id),原有优化器改造为同步分布式训练版本。用普通(异步)或同步优化器优化损失cross_entropy。

同步训练模式,主节点,opt.get_chief_queue_runner创建队列执行器,opt.get_init_tokens_op创建全局参数初始化器。

生成本地参数初始化操作init_op,创建临时训练目录,tf.train_Supervisor创建分布式训练监督器,传入参数is_chief、train_dir、init_op。Supervisor管理task参与到分布式训练。

设置Session参数,allow_soft_placement设True,代表操作在指定device不能执行时转到其他device执行。

如果主节点,显示初始化Session,其他节点显示等待主节点初始化操作。执行sv.prepate_or_wait_for_session()。

如果处于同步模型主节点,sv.start_queue_runners执行队列化执行器chief_queue_runner,执行全局参数初始化器init_tokens_op。

训练过程,记录worker执行训练启动时间,初始化本地训练步数local_step,进入训练循环。每步训练,从nnist.train.next_batch读取一个batch数据,生成feed_dict,调train_step执行训练。当全局训练步数达到预设最大值,停止训练。

训练结束,展示总训练时间,在验证数据上计算预测结果损失cross_entropy,展示。

在主程序执行tf.app.run()启动main函数,全部代码保存到distributed.py文件。3台不同机器分别执行distributed.py启动3个task,每次执行distributed.py,传入job_name、task_index指定worker身份。

分别在三台机器192.168.233.201、192.168.233.202、192.168.233.203执行python distributed.py。

同步模式,加上--sync_replicas=True。global_step,异步时,全局步数是所有worker训练步数和,同步时是多少轮并行训练。

#from __future__ import absolute_import
#from __future__ import division
#from __future__ import print_function
import math
#import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                    "Directory for storing mnist data")
#flags.DEFINE_boolean("download_only", False,
#                     "Only perform downloading of data; Do not proceed to "
#                     "session preparation, model definition or training")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task the performs the variable "
                     "initialization ")
#flags.DEFINE_integer("num_gpus", 2,
#                     "Total number of gpus for each machine."
#                     "If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update"
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_integer("hidden_units", 100,
                     "Number of units in the hidden layer of the NN")
flags.DEFINE_integer("train_steps", 1000000,
                     "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
flags.DEFINE_boolean("sync_replicas", False,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
#flags.DEFINE_boolean(
#    "existing_servers", False, "Whether servers already exists. If True, "
#    "will use the worker hosts via their GRPC URLs (one client process "
#    "per worker host). Otherwise, will create an in-process TensorFlow "
#    "server.")
flags.DEFINE_string("ps_hosts","192.168.233.201:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "192.168.233.202:2223,192.168.233.203:2224",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None,"job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def main(unused_argv):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
#  if FLAGS.download_only:
#    sys.exit(0)
  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index =="":
    raise ValueError("Must specify an explicit `task_index`")
  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)
  #Construct the cluster and start the server
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")
  # Get the number of workers.
  num_workers = len(worker_spec)
  cluster = tf.train.ClusterSpec({
      "ps": ps_spec,
      "worker": worker_spec})
  #if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
  server = tf.train.Server(
      cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
  if FLAGS.job_name == "ps":
    server.join()
  is_chief = (FLAGS.task_index == 0)

#  if FLAGS.num_gpus > 0:
#    if FLAGS.num_gpus < num_workers:
#      raise ValueError("number of gpus is less than number of workers")
#    # Avoid gpu allocation conflict: now allocate task_num -> #gpu 
#    # for each worker in the corresponding machine
#    gpu = (FLAGS.task_index % FLAGS.num_gpus)
#    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
#  elif FLAGS.num_gpus == 0:
#    # Just allocate the CPU to worker server
#    cpu = 0
#    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
#  # The device setter will automatically place Variables ops on separate
#  # parameter servers (ps). The non-Variable ops will be placed on the workers.
#  # The ps use CPU and workers use corresponding GPU

  worker_device = "/job:worker/task:%d/gpu:0" % FLAGS.task_index
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # Variables of the hidden layer
    hid_w = tf.Variable(
        tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
    # Variables of the softmax layer
    sm_w = tf.Variable(
        tf.truncated_normal(
            [FLAGS.hidden_units, 10],
            stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
        name="sm_w")
    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
    # Ops: located on the worker specified with FLAGS.task_index
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, [None, 10])
    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    hid = tf.nn.relu(hid_lin)
    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    if FLAGS.sync_replicas:
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate
      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          replica_id=FLAGS.task_index,
          name="mnist_sync_replicas")
    train_step = opt.minimize(cross_entropy, global_step=global_step)
    if FLAGS.sync_replicas and is_chief:
      # Initial token and chief queue runners required by the sync_replicas mode
      chief_queue_runner = opt.get_chief_queue_runner()
      init_tokens_op = opt.get_init_tokens_op()
    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp()
    sv = tf.train.Supervisor(
        is_chief=is_chief,
        logdir=train_dir,
        init_op=init_op,
        recovery_wait_secs=1,
        global_step=global_step)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
    # The chief worker (task_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." %
        FLAGS.task_index)
#    if FLAGS.existing_servers:
#      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
#      print("Using existing server at: %s" % server_grpc_url)
#
#      sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
#    else:
    sess = sv.prepare_or_wait_for_session(server.target,
                                        config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)
    if FLAGS.sync_replicas and is_chief:
      # Chief worker will start the chief queue runner and call the init op
      print("Starting chief queue runner and running init_tokens_op")
      sv.start_queue_runners(sess, [chief_queue_runner])
      sess.run(init_tokens_op)
    # Perform training
    time_begin = time.time()
    print("Training begins @ %f" % time_begin)
    local_step = 0
    while True:
      # Training feed
      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
      train_feed = {x: batch_xs, y_: batch_ys}
      _, step = sess.run([train_step, global_step], feed_dict=train_feed)
      local_step += 1
      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)" %
            (now, FLAGS.task_index, local_step, step))
      if step >= FLAGS.train_steps:
        break
    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)
    # Validation feed
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g" %
          (FLAGS.train_steps, val_xent))
if __name__ == "__main__":
  tf.app.run()

参考资料: 《TensorFlow实战》

欢迎付费咨询(150元每小时),我的微信:qingxingfengzi

© 著作权归作者所有

共有 人打赏支持
利炳根
粉丝 11
博文 60
码字总数 136346
作品 0
深圳
《BIG DATA大数据日知录 架构和算法》读书笔记

1.数据分片和路由 Hash Hash H(Key) = hash(key) mod K 虚拟桶(Virtual Buckets) 先hash到桶,在Hash,多加一层Hash便于扩展 一致性Hash 分布式Hash表(DHT),P2P对等网络,构成环,节点加...

selfless ⋅ 2016/06/18 ⋅ 4

Mahout安装与配置笔记

一、硬件环境 操作系统:Linux ubuntu-13.04-desktop-i386 jdk安装版本:jdk-7u51-linux-i586 Hadoop版本:Hadoop-1.1.1(一个Namenode,三个Datanode部署) 二、安装步骤 在Mahout安装之前读...

kartik ⋅ 2014/06/01 ⋅ 0

Golang学习笔记目录

Golang 介绍 Go语言是谷歌2009发布的第二款开源编程语言。 Go语言专门针对多处理器系统应用程序的编程进行了优化,使用Go编译的程序可以媲美C或C++代码的速度,而且更加安全、支持并行进程。...

ChainZhang ⋅ 2017/12/26 ⋅ 0

【Java并发性和多线程】Java并发性和多线程介绍

本文为转载学习 原文链接:http://tutorials.jenkov.com/java-concurrency/index.html 译文链接:http://ifeve.com/java-concurrency-thread/ 在过去单CPU时代,单任务在一个时间点只能执行单...

heroShane ⋅ 2014/01/28 ⋅ 0

01_线程基础(一)之笔记

1.1 并发编程的学习的目的 + 我们为什么要去学习并发编程? + 第一点,这对面试非常重要,是企业面试程序员的标准,是考察要素: 1. 考察我公司技术你是否熟悉50%以上,或我们公司有特殊的技...

圣洁之子 ⋅ 02/08 ⋅ 0

如何学习Liunx和个人学习大纲

提醒不甘平凡的我 2009大学刚刚毕业(不是正规大学,没来北京之前,在家里不是打架就是闲逛,家里人怕早晚出事,索性花钱上了这个“大学”)找了一家IT培训机构。当时的培训机构出名两家:“...

jcpokai521 ⋅ 2017/02/28 ⋅ 0

[Erlang 0060] Joe Armstrong 论文《面向软件错误构建可靠的分布式系统》笔记

周末读了两篇论文"On Designing and Deploying Internet-Scale Services"和Joe Armstrong的论文"面对软件错误构建可靠的分布式系统",这两篇论文实战内容相当多,整理笔记于此,备忘. On Design...

唐玄奘 ⋅ 2017/12/03 ⋅ 0

分布式计算平台开源了!

哪儿可以获取代码? https://github.com/chaoyongzhou/ebgn 如何编译? 执行make bgn即可,请常见main函数修改以提供不同的功能 平台能干什么? 如果你想做分布式计算、并行计算,可以选择该...

Hansoul ⋅ 2013/09/18 ⋅ 2

Git 学习笔记 (4):分支管理

4.1 多分支存在的意义 Git 不仅拥有分布式这样优秀的版本控制系统运作方式,而且具有强大的分支管理能力。多分支的存在,为项目的多人协作开发提供了便利,而强大的分支管理能力又进一步增强...

sailboat征帆 ⋅ 04/29 ⋅ 0

机器学习(三)--- scala学习笔记

Scala是一门多范式的编程语言,一种类似Java的编程语言,设计初衷是实现可伸缩的语言、并集成面向对象编程和函数式编程的各种特性。 Spark是UC Berkeley AMP lab所开源的类Hadoop MapReduce的...

云栖希望。 ⋅ 2017/12/04 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

SpringCloud 微服务 (六) 服务通信 RestTemplate

壹 通信的方式主要有两种,Http 和 RPC SpringCloud使用的是Http方式通信, Dubbo的通信方式是RPC 记录学习SpringCloud的restful方式: RestTemplate (本篇)、Feign 贰 RestTemplate 类似 Http...

___大侠 ⋅ 7分钟前 ⋅ 0

React创建组件的三种方式

1.无状态函数式组建 无状态函数式组件,也就是你无法使用State,也无法使用组件的生命周期方法,这就决定了函数组件都是展示性组件,接收Props,渲染DOM,而不关注其他逻辑。 无状态函数式组...

kimyeongnam ⋅ 13分钟前 ⋅ 0

react 判断实例类型

今天在写组件的时候想通过判断内部子元素不同而在父元素上应用不同的class,于是首先要解决的就是如何判断子元素的类型。 这里附上一个讲的很全面的文章: https://www.cnblogs.com/onepixel...

球球 ⋅ 20分钟前 ⋅ 0

Centos7备份数据到百度网盘

一、关于 有时候我们需要进行数据备份,如果能自动将数据备份到百度网盘,那将会非常方便。百度网盘有较大的存储空间,而且不怕数据丢失,安全可靠。下面简单的总结一下如何使用 bypy 实现百...

zctzl ⋅ 34分钟前 ⋅ 0

开启远程SSH

SSH默认没有开启账号密码登陆,需要再配置表中修改: vim /etc/ssh/sshd_configPermitRootLogin yes #是否可以使用root账户登陆PasswordAuthentication yes #是都开启密码登陆ser...

Kefy ⋅ 37分钟前 ⋅ 0

Zookeeper3.4.11+Hadoop2.7.6+Hbase2.0.0搭建分布式集群

有段时间没更新博客了,趁着最近有点时间,来完成之前关于集群部署方面的知识。今天主要讲一讲Zookeeper+Hadoop+Hbase分布式集群的搭建,在我前几篇的集群搭建的博客中已经分别讲过了Zookeep...

海岸线的曙光 ⋅ 44分钟前 ⋅ 0

js保留两位小数方法总结

本文是小编针对js保留两位小数这个大家经常遇到的经典问题整理了在各种情况下的函数写法以及遇到问题的分析,以下是全部内容: 一、我们首先从经典的“四舍五入”算法讲起 1、四舍五入的情况...

孟飞阳 ⋅ 今天 ⋅ 0

python log

python log 处理方式 log_demo.py: 日志代码。 #! /usr/bin/env python# -*- coding: utf-8 -*-# __author__ = "Q1mi""""logging配置"""import osimport logging.config# 定义三种......

inidcard ⋅ 今天 ⋅ 0

mysql 中的信息数据库以及 shell 查询 sql

Information_schema 是 MySQL 自带的信息数据库,里面的“表”保存着服务器当前的实时信息。它提供了访问数据库元数据的方式。 什么是元数据呢?元数据是关于数据的数据,如数据库名或表名,...

blackfoxya ⋅ 今天 ⋅ 0

maven配置阿里云镜像享受飞的感觉

1.在maven目录下的conf/setting.xml中找到mirrors添加如下内容,对所有使用改maven打包的项目生效。 <mirror> <id>alimaven</id> <name>aliyun maven</name> <url>http://maven.al......

kalnkaya ⋅ 今天 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部