文档章节

TensorFlow学习系列(三):保存/恢复和混合多个模型

AllenOR灵感
 AllenOR灵感
发布于 2017/09/10 01:19
字数 1625
阅读 4
收藏 0
点赞 0
评论 0

这篇教程是翻译Morgan写的TensorFlow教程,作者已经授权翻译,这是原文


目录


TensorFlow学习系列(一):初识TensorFlow

TensorFlow学习系列(二):形状和动态维度

TensorFlow学习系列(三):保存/恢复和混合多个模型

TensorFlow学习系列(四):利用神经网络实现泛逼近器(universal approximator)

TensorFlow学习系列(五):如何使用队列和多线程优化输入管道


在学习这篇博客之前,我希望你已经掌握了Tensorflow基本的操作。如果没有,你可以阅读这篇入门文章

为什么要学习模型的保存和恢复呢?因为这对于避免数据的混乱无序是至关重要的,特别是在你代码中的不同图。

如何保存和加载模型

saver类

在不同的会话中,当需要将数据在硬盘上面进行保存时,那么我们就可以使用Saver这个类。这个Saver构造类允许你去控制3个目标:

  • 目标(The target):这个参数设置目标。在分布式架构的情况下,我们可以指定要计算哪个TF服务器或者“目标”。
  • 图(The graph):这个参数设置保存的图。保存你希望会话处理的图。对于初学者来说,这里有一件棘手的事情就是在Tensorflow中总是有一个默认的图,并且你所有的操作都是在这个图中首先进行。所有,你总是在“默认图范围”内。
  • 配置(The config):这个参数设置配置。你可以使用 ConfigProto 参数来进行配置Tensorflow。点击这里,查看更多信息。

Saver类可以处理你的图中元数据和变量数据的保存和恢复。而我们唯一需要做的是,告诉Saver类我们需要保存哪个图和哪些变量。

在默认情况下,Saver类能处理默认图中包含的所有变量。但是,你也可以去创建很多的Saver类,去保存你想要的任何子图。

import tensorflow as tf

# First, you design your mathematical operations
# We are the default graph scope

# Let's design a variable
v1 = tf.Variable(1. , name="v1")
v2 = tf.Variable(2. , name="v2")
# Let's design an operation
a = tf.add(v1, v2)

# Let's create a Saver object
# By default, the Saver handles every Variables related to the default graph
all_saver = tf.train.Saver() 
# But you can precise which vars you want to save under which name
v2_saver = tf.train.Saver({"v2": v2}) 

# By default the Session handles the default graph and all its included variables
with tf.Session() as sess:
  # Init v and v2   
  sess.run(tf.global_variables_initializer())
  # Now v1 holds the value 1.0 and v2 holds the value 2.0
  # We can now save all those values
  all_saver.save(sess, 'data.chkp')
  # or saves only v2
  v2_saver.save(sess, 'data-v2.chkp')

当你运行了上面的程序之后,如果你去看文件夹,那么你会发现文件夹中存在了七个文件(如下)。在接下来的博客中,我会详细解释这些文件的意义。目前你只需要知道,模型的权重是保存在 .chkp 文件中,模型的图是保存在 .chkp.meta 文件中。

├── checkpoint
├── data-v2.chkp.data-00000-of-00001
├── data-v2.chkp.index
├── data-v2.chkp.meta
├── data.chkp.data-00000-of-00001
├── data.chkp.index
├── data.chkp.meta

恢复操作和其它元数据

我想分享的最后一个信息是,Saver将保存与图有关联的任何元数据。这就意味着,当我们恢复一个模型的时候,我们还同时恢复了所有与图相关的变量、操作和集合。

当我们恢复一个元模型(restore a meta checkpoint)时,实际上我们执行的操作是将恢复的图载入到当前的默认图中。所有当你完成模型恢复之后,你可以在默认图中访问载入的任何内容,比如一个张量,一个操作或者集合。

import tensorflow as tf

# Let's laod a previous meta graph in the current graph in use: usually the default graph
# This actions returns a Saver
saver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')

# We can now access the default graph where all our metadata has been loaded
graph = tf.get_default_graph()

# Finally we can retrieve tensors, operations, etc.
global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')
train_op = graph.get_operation_by_name('loss/train_op')
hyperparameters = tf.get_collection('hyperparameters')

恢复权重

请记住,在实际的环境中,真实的权重只能存在于一个会话中。也就是说,restore 这个操作必须在一个会话中启动,然后将数据权重导入到图中。理解恢复操作的最好方法是将它简单的看做是一种数据初始化操作。

with tf.Session() as sess:
    # To initialize values with saved data
    saver.restore(sess, 'results/model.ckpt-1000-00000-of-00001')
    print(sess.run(global_step_tensor)) # returns 1000

在新图中导入预训练模型

至此,你应该已经明白了如何去保存和恢复一个模型。然而,我们还可以使用一些技巧去帮助你更快的保存和恢复一个模型。比如:

  • 一个图的输出能成为另一个图的输入吗?

答案是确定的。但是目前我的做法是先将第一个图进行保存,然后在另一个图中进行恢复。但是这种方案感觉很笨重,我不知道是否有更好的方法。

但是这种方法确实能工作,除非你想要去重新训练第一个图。在这种情况下,你需要将输入的梯度重新输入到第一张图中的特定的训练步骤中。我想你已经被这种复杂的方案给逼疯了把。:-)

  • 我可以在一个图中混合不同的图吗?

答案当然是肯定的,但是你必须非常小心命名空间。这种方法有一点好处是,简化了一切。比如,你可以预加载一个VGG-19模型。然后访问图中的任何节点,并执行你自己的后续操作,从而训练一整个完整的模型。

如果你只想微调你自己的节点,那么你可以在你想要的地方中断梯度。

import tensorflow as tf

# Load the VGG-16 model in the default graph
vgg_saver = tf.train.import_meta_graph(dir + '/vgg/results/vgg-16.meta')
# Access the graph
vgg_graph = tf.get_default_graph()

# Retrieve VGG inputs
self.x_plh = vgg_graph.get_tensor_by_name('input:0')

# Choose which node you want to connect your own graph
output_conv =vgg_graph.get_tensor_by_name('conv1_2:0')
# output_conv =vgg_graph.get_tensor_by_name('conv2_2:0')
# output_conv =vgg_graph.get_tensor_by_name('conv3_3:0')
# output_conv =vgg_graph.get_tensor_by_name('conv4_3:0')
# output_conv =vgg_graph.get_tensor_by_name('conv5_3:0')

# Stop the gradient for fine-tuning
output_conv_sg = tf.stop_gradient(output_conv) # It's an identity function

# Build further operations
output_conv_shape = output_conv_sg.get_shape().as_list()
W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32], initializer=tf.random_normal_initializer(stddev=1e-1))
b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))
z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1
a = tf.nn.relu(z1)

References:

http://stackoverflow.com/questions/38947658/tensorflow-saving-into-loading-a-graph-from-a-file

http://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow?rq=1

http://stackoverflow.com/questions/39468640/tensorflow-freeze-graph-py-the-name-save-const0-refers-to-a-tensor-which-doe?rq=1

http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python

http://stackoverflow.com/questions/34500052/tensorflow-saving-and-restoring-session?noredirect=1&lq=1

http://stackoverflow.com/questions/35687678/using-a-pre-trained-word-embedding-word2vec-or-glove-in-tensorflow

https://github.com/jtoy/awesome-tensorflow

本文转载自:http://www.jianshu.com/p/8487db911d9a

共有 人打赏支持
AllenOR灵感
粉丝 10
博文 2139
码字总数 82983
作品 0
程序员
史上最全TensorFlow学习资源汇总

来源 悦动智能(公众号ID:aibbtcom) 本篇文章将为大家总结TensorFlow纯干货学习资源,非常适合新手学习,建议大家收藏。 ▌一 、TensorFlow教程资源 1)适合初学者的TensorFlow教程和代码示...

悦动智能
04/12
0
0
13- 深度学习之神经网络核心原理与算法-TensorFlow介绍与框架挑选

TensorFlow以及TensorFlow的应用 支持深度学习的框架。torch caffe TensorFlow 简介 使用图(Graph)来表示计算任务 图中的节点被称为op(operation) 一个op获取0个或多个tensor,执行计算,产生...

天涯明月笙
06/01
0
0
资源 | 概率编程工具:TensorFlow Probability官方简介

  选自Medium   作者:Josh Dillon、Mike Shwe、Dustin Tran   机器之心编译   参与:白妤昕、李泽南      在 2018 年 TensorFlow 开发者峰会上,谷歌发布了 TensorFlow Probabi...

机器之心
04/22
0
0
【干货】史上最全的Tensorflow学习资源汇总,速藏!

一 、Tensorflow教程资源: 1)适合初学者的Tensorflow教程和代码示例:(https://github.com/aymericdamien/TensorFlow-Examples)该教程不光提供了一些经典的数据集,更是从实现最简单的“Hel...

技术小能手
04/16
0
0
一步步上手TensorFlow——基础知识

之前发过了几篇关于机器学习的帖子,使用的框架多为TensorFlow。 TensorFlow 是一个用于人工智能的开源神器。作为常用的机器学习框架,可被用于语音识别或图像识别等多项机器学习和深度学习领...

BlackBlog__
05/14
0
0
《Scikit-Learn与TensorFlow机器学习实用指南》第9章 启动并运行TensorFlow

第9章 启动并运行TensorFlow 来源:ApacheCN《Sklearn 与 TensorFlow 机器学习实用指南》翻译项目 译者:@akonwang @WilsonQu 校对:@Lisanaaa @飞龙 TensorFlow 是一款用于数值计算的强大的...

apachecn_飞龙
04/23
0
0
入门 | TensorFlow的动态图工具Eager怎么用?这是一篇极简教程

  选自Github   作者:Madalina Buzau   机器之心编译   参与:王淑婷、泽南      去年 11 月,Google Brain 团队发布了 Eager Execution,一个由运行定义的新接口,为 TensorFl...

机器之心
06/14
0
0
扣丁学堂浅谈将TensorFlow的模型网络导出为单个文件的方法

今天给大家分享的是将TensorFlow的模型网络导出为单个文件的方法,喜欢Python开发的小伙伴和扣丁学堂Python在线学习小编一块来看一下吧。 有时候,我们需要将TensorFlow的模型导出为单个文件...

扣丁学堂
06/01
0
0
在Tensorflow Serving上部署基于LSTM的文本分类模型

一些重要的概念 Servables Servables 是客户端请求执行计算的基础对象,大小和粒度是灵活的。 Servables 不会管理自己的运行周期。 典型的Servables包括: Servable Versions Tensorflow ser...

liyonghong
02/02
0
0
如何部署tensorflow训练的模型

最近深度学习算法被广泛研究和应用,而tensorflow则是被应用最为广泛的工具。tensorflow训练的模型被应用在线上时,主要有3种方式(本文主要讨论java方向的应用): 1:java代码重写预测代码(...

lirainbow0
05/29
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

32.filter表案例 nat表应用 (iptables)

10.15 iptables filter表案例 10.16/10.17/10.18 iptables nat表应用 10.15 iptables filter表案例: ~1. 写一个具体的iptables小案例,需求是把80端口、22端口、21 端口放行。但是,22端口我...

王鑫linux
今天
0
0
shell中的函数&shell中的数组&告警系统需求分析

20.16/20.17 shell中的函数 20.18 shell中的数组 20.19 告警系统需求分析

影夜Linux
今天
0
0
Linux网络基础、Linux防火墙

Linux网络基础 ip addr 命令 :查看网口信息 ifconfig命令:查看网口信息,要比ip addr更明了一些 centos 7默认没安装ifconfig命令,可以使用yum install -y net-tools命令来安装。 ifconfig...

李超小牛子
今天
1
0
[机器学习]回归--Decision Tree Regression

CART决策树又称分类回归树,当数据集的因变量为连续性数值时,该树算法就是一个回归树,可以用叶节点观察的均值作为预测值;当数据集的因变量为离散型数值时,该树算法就是一个分类树,可以很...

wangxuwei
昨天
1
0
Redis做分布式无锁CAS的问题

因为Redis本身是单线程的,具备原子性,所以可以用来做分布式无锁的操作,但会有一点小问题。 public interface OrderService { public String getOrderNo();} public class OrderRe...

算法之名
昨天
9
0
143. Reorder List - LeetCode

Question 143. Reorder List Solution 题目大意:给一个链表,将这个列表分成前后两部分,后半部分反转,再将这两分链表的节点交替连接成一个新的链表 思路 :先将链表分成前后两部分,将后部...

yysue
昨天
1
0
数据结构与算法1

第一个代码,描述一个被称为BankAccount的类,该类模拟了银行中的账户操作。程序建立了一个开户金额,显示金额,存款,取款并显示余额。 主要的知识点联系为类的含义,构造函数,公有和私有。...

沉迷于编程的小菜菜
昨天
1
0
从为什么别的队伍总比你的快说起

在机场候检排队的时候,大多数情况下,别的队伍都要比自己所在的队伍快,并常常懊悔当初怎么没去那个队。 其实,最快的队伍只能有一个,而排队之前并不知道那个队快。所以,如果有六个队伍你...

我是菜鸟我骄傲
昨天
1
0
分布式事务常见的解决方案

随着互联网的发展,越来越多的多服务相互之间的调用,这时候就产生了一个问题,在单项目情况下很容易实现的事务控制(通过数据库的acid控制),变得不那么容易。 这时候就产生了多种方案: ...

小海bug
昨天
3
0
python从零学——scrapy初体验

python从零学——scrapy初体验 近日因为一些事情,需要从网上爬取一些东西,故而想通过使用爬虫来顺便学习下强大的python。现将一些学习中遇到的问题记录下来,以便日后查询 1. 开发环境的准...

咾咔叽
昨天
1
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部