# TensorFlow：将训练好的模型 freeze（即把权重固化到图里面），并使用该模型进行预测

2017/12/05 18:58

ML 主要分为训练和预测两个阶段。本教程将训练好的模型 freeze 并保存下来。freeze 的含义是把模型的图结构和训练得到的权重固化到同一个文件中，这样加载 freeze 后的模型之后，就可以立即用于预测。

#-*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np

# ---- Model definition --------------------------------------------------------
# Placeholders: a batch of 10 input features and one 0/1 label per sample.
with tf.variable_scope('Placeholder'):
    inputs_placeholder = tf.placeholder(tf.float32, name='inputs_placeholder', shape=[None, 10])
    labels_placeholder = tf.placeholder(tf.float32, name='labels_placeholder', shape=[None, 1])

# Two single-unit ReLU branches over the same inputs; the network output `y`
# is their average, shape (batch, 1).
with tf.variable_scope('NN'):
    W1 = tf.get_variable('W1', shape=[10, 1], initializer=tf.random_normal_initializer(stddev=1e-1))
    b1 = tf.get_variable('b1', shape=[1], initializer=tf.constant_initializer(0.1))
    W2 = tf.get_variable('W2', shape=[10, 1], initializer=tf.random_normal_initializer(stddev=1e-1))
    b2 = tf.get_variable('b2', shape=[1], initializer=tf.constant_initializer(0.1))

    a = tf.nn.relu(tf.matmul(inputs_placeholder, W1) + b1)
    a2 = tf.nn.relu(tf.matmul(inputs_placeholder, W2) + b2)
    # Network output, consumed by the Loss and Accuracy scopes below.
    y = tf.divide(tf.add(a, a2), 2)

# Squared-error loss, halved so that its gradient w.r.t. y is (y - label).
with tf.variable_scope('Loss'):
    loss = tf.reduce_sum(tf.square(y - labels_placeholder) / 2)

# "Accuracy/predictions" is the output node kept by the freeze script:
# True where the network output exceeds 0.5.
with tf.variable_scope('Accuracy'):
    predictions = tf.greater(y, 0.5, name="predictions")
    correct_predictions = tf.equal(predictions, tf.cast(labels_placeholder, tf.bool), name="correct_predictions")
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

# Training op executed once per batch by the session loop.
adam = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = adam.minimize(loss)

# ---- Synthetic data and batching ---------------------------------------------
def make_dataset(n_samples, n_features=10, threshold=45):
    """Return ``(inputs, labels)`` for the toy "is the row sum large?" task.

    inputs: integer array of shape (n_samples, n_features), values in 0..9.
    labels: float32 array of shape (n_samples, 1); 1.0 where the row sum is
    greater than ``threshold``, else 0.0. With the defaults, 45 is the
    expected row sum (10 draws with mean 4.5).
    """
    features = np.random.choice(10, size=[n_samples, n_features])
    targets = (np.sum(features, axis=1) > threshold).reshape(-1, 1).astype(np.float32)
    return features, targets


def make_batches(inputs, labels, batch_size):
    """Split ``inputs``/``labels`` into consecutive ``[inputs_chunk, labels_chunk]`` pairs.

    Every pair holds ``batch_size`` rows except possibly the last one, which
    holds the remainder. Empty input yields an empty list.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %d" % batch_size)
    return [
        [inputs[start:start + batch_size], labels[start:start + batch_size]]
        for start in range(0, len(inputs), batch_size)
    ]


inputs, labels = make_dataset(10000)
print('inputs.shape:', inputs.shape)
print('labels.shape:', labels.shape)

test_inputs, test_labels = make_dataset(100)
print('test_inputs.shape:', test_inputs.shape)
print('test_labels.shape:', test_labels.shape)

batch_size = 32
epochs = 10

n_full_batches = len(inputs) // batch_size
remainder = len(inputs) - n_full_batches * batch_size
print("%d items in batch of %d gives us %d full batches and %d batches of %d items" % (
    len(inputs),
    batch_size,
    n_full_batches,
    1 if remainder else 0,  # at most one partial batch
    remainder)
)
batches = make_batches(inputs, labels, batch_size)
print("Number of batches: %d" % len(batches))
# batches[k] is [inputs_chunk, labels_chunk]; report the number of rows.
print("Size of full batch: %d" % len(batches[0][0]))
print("Size of final batch: %d" % len(batches[-1][0]))

# ---- Training ----------------------------------------------------------------
global_count = 0  # number of training steps run so far

# tf.train.Supervisor can manage initialization and saving instead of a plain
# Session; see the Supervisor version of this script later in the article.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(epochs):
        for batch in batches:
            train_loss, _ = sess.run([loss, train_op], feed_dict={
                inputs_placeholder: batch[0],
                labels_placeholder: batch[1]
            })

            # Report test-set accuracy every 100 training steps.
            if global_count % 100 == 0:
                acc = sess.run(accuracy, feed_dict={
                    inputs_placeholder: test_inputs,
                    labels_placeholder: test_labels
                })
                print('accuracy: %f' % acc)
            global_count += 1

    acc = sess.run(accuracy, feed_dict={
        inputs_placeholder: test_inputs,
        labels_placeholder: test_labels
    })
    print("final accuracy: %f" % acc)

    # The model must be saved while the session (which holds the trained
    # variable values) is still open. Saver.save will not create a missing
    # parent directory, so make sure it exists first.
    tf.gfile.MakeDirs('results')
    saver = tf.train.Saver()
    last_chkp = saver.save(sess, 'results/graph.chkp')

# Print every op name; these names are what the freeze/load scripts refer to.
for op in tf.get_default_graph().get_operations():
    print(op.name)


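训练脚本运行结束后，results 目录下会生成以下 checkpoint 相关文件：

checkpoint：文本文件，记录最近一次保存的 checkpoint 路径（freeze 脚本通过 tf.train.get_checkpoint_state 读取它）

.meta：存放的是图结构

.index：存放的是变量到权重数据的索引

.data：存放的是权重参数

freeze 脚本的主要步骤如下：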

1、恢复我们保存的图

2、开启一个 Session，然后载入该图所需的权重

3、删除与预测无关的结点（只保留计算输出结点所必需的部分），并把变量转换为常量

4、将处理好的模型序列化之后保存

#-*- coding:utf-8 -*-
import os, argparse
import tensorflow as tf
from tensorflow.python.framework import graph_util

# Directory containing this script.
# NOTE(review): `dir` shadows the builtin and is not used below; kept only so
# existing references to it keep working.
dir = os.path.dirname(os.path.realpath(__file__))


def freeze_graph(model_folder, output_node_names="Accuracy/predictions"):
    """Freeze the latest checkpoint in ``model_folder`` into ``frozen_model.pb``.

    Rebuilds the graph from the checkpoint's ``.meta`` file and restores the
    saved weights. It then converts the variables needed to compute
    ``output_node_names`` into constants, dropping nodes those outputs do not
    depend on. The resulting GraphDef is written next to the checkpoint.

    Args:
        model_folder: directory containing the ``checkpoint`` state file
            written by ``tf.train.Saver.save``.
        output_node_names: comma-separated string (or list) of op names the
            frozen graph must still be able to compute. These depend on the
            model definition; the default is the prediction op of the
            training script.

    Raises:
        ValueError: if ``model_folder`` contains no checkpoint state.
    """
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    # get_checkpoint_state returns None when there is no "checkpoint" file;
    # report that clearly instead of failing with an AttributeError.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError("No checkpoint found in %r" % (model_folder,))
    input_checkpoint = checkpoint.model_checkpoint_path

    # Put the frozen graph in the same directory as the checkpoint. os.path
    # copes with either path separator and with paths that have no directory.
    output_graph = os.path.join(os.path.dirname(input_checkpoint), "frozen_model.pb")

    # Only the nodes needed to compute these outputs survive freezing.
    if isinstance(output_node_names, str):
        output_node_names = output_node_names.split(",")
    output_node_names = [name.strip() for name in output_node_names]

    # Strip device placements so TensorFlow can choose where to run the ops
    # when the frozen model is loaded.
    clear_devices = True

    # Rebuild the graph structure from the .meta file; this also returns a
    # Saver that can restore the checkpointed variables.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        # Load the trained weights into the session's variables.
        saver.restore(sess, input_checkpoint)

        # Replace every variable reachable from the output nodes with a
        # constant holding its restored value: the weights become part of
        # the graph itself, hence "frozen".
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names,
        )

        # Serialize the frozen GraphDef to disk.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Default to "results", the folder the training script saves into, so
    # running without arguments does not pass None to freeze_graph.
    parser.add_argument("--model_folder", type=str, default="results",
                        help="Model folder to export (default: results)")
    args = parser.parse_args()

    freeze_graph(args.model_folder)


#-*- coding:utf-8 -*-
import argparse
import tensorflow as tf

def load_graph(frozen_graph_filename):
    """Load a frozen GraphDef file and import it into a fresh ``tf.Graph``.

    Every imported op name gets the ``prefix/`` scope, e.g. the input
    placeholder becomes ``prefix/Placeholder/inputs_placeholder``.

    Args:
        frozen_graph_filename: path to the ``.pb`` file written by the freeze
            script.

    Returns:
        The new ``tf.Graph`` containing the imported ops.
    """
    # Read and parse the serialized GraphDef.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a new graph (not the process-wide default graph) so that
    # repeated loads do not collide.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="prefix",
            producer_op_list=None
        )
    return graph

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")
    args = parser.parse_args()

    # Load the graph whose weights have been frozen into constants.
    graph = load_graph(args.frozen_model_filename)

    # List every operation. op.name is the operation's name; op.values() gives
    # the tensors it produces. Input and output nodes are operations too, so
    # their names appear here, e.g.:
    #   prefix/Placeholder/inputs_placeholder
    #   ...
    #   prefix/Accuracy/predictions
    for op in graph.get_operations():
        print(op.name, op.values())

    # To feed and fetch we need *tensor* names, not op names:
    # "prefix/Placeholder/inputs_placeholder" is the op,
    # "prefix/Placeholder/inputs_placeholder:0" is its first output tensor.
    x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
    y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')

    with tf.Session(graph=graph) as sess:
        y_out = sess.run(y, feed_dict={
            x: [[3, 5, 7, 4, 5, 1, 1, 1, 1, 1]]  # row sum 29 < 45
        })
        # predictions is a boolean tensor (y > 0.5); a well-trained model
        # should print [[False]] for this input.
        print(y_out)
    print("finish")


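1、在预测过程中，把 freeze 后的模型加载进来之后，只需要找到输入 tensor 和目标（输出）tensor，再把数据 feed 给输入 tensor 即可。

2、要注意区分 tensor 的名字和 op 的名字：op 的名字形如 prefix/Placeholder/inputs_placeholder，而该 op 输出的 tensor 的名字还要加上输出序号，形如 prefix/Placeholder/inputs_placeholder:0。

调用 get_tensor_by_name 时一定要使用 tensor 的名字（带 :0），例如：x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')

3、要获取图中 op 的名字及其对应 tensor 的名字，可以使用如下代码：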

    # Print each operation in the loaded graph next to the tensors it outputs.
# op.name is the operation name; op.values() yields the tensors it produces,
# whose names carry an output index suffix such as ":0".
# Input and output nodes are operations as well, so their names appear here.
ops_in_graph = graph.get_operations()
for operation in ops_in_graph:
    print(operation.name, operation.values())

=============================================================================================================================

下面是训练脚本的另一个版本：用 tf.train.Supervisor 管理 Session（由它负责变量初始化，并通过 sv.saver 保存模型），其余部分与上面相同。

#-*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np

# ---- Model definition --------------------------------------------------------
# Placeholders: a batch of 10 input features and one 0/1 label per sample.
with tf.variable_scope('Placeholder'):
    inputs_placeholder = tf.placeholder(tf.float32, name='inputs_placeholder', shape=[None, 10])
    labels_placeholder = tf.placeholder(tf.float32, name='labels_placeholder', shape=[None, 1])

# Two single-unit ReLU branches over the same inputs; the network output `y`
# is their average, shape (batch, 1).
with tf.variable_scope('NN'):
    W1 = tf.get_variable('W1', shape=[10, 1], initializer=tf.random_normal_initializer(stddev=1e-1))
    b1 = tf.get_variable('b1', shape=[1], initializer=tf.constant_initializer(0.1))
    W2 = tf.get_variable('W2', shape=[10, 1], initializer=tf.random_normal_initializer(stddev=1e-1))
    b2 = tf.get_variable('b2', shape=[1], initializer=tf.constant_initializer(0.1))

    a = tf.nn.relu(tf.matmul(inputs_placeholder, W1) + b1)
    a2 = tf.nn.relu(tf.matmul(inputs_placeholder, W2) + b2)
    # Network output, consumed by the Loss and Accuracy scopes below.
    y = tf.divide(tf.add(a, a2), 2)

# Squared-error loss, halved so that its gradient w.r.t. y is (y - label).
with tf.variable_scope('Loss'):
    loss = tf.reduce_sum(tf.square(y - labels_placeholder) / 2)

# "Accuracy/predictions" is the output node kept by the freeze script:
# True where the network output exceeds 0.5.
with tf.variable_scope('Accuracy'):
    predictions = tf.greater(y, 0.5, name="predictions")
    correct_predictions = tf.equal(predictions, tf.cast(labels_placeholder, tf.bool), name="correct_predictions")
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

# Training op executed once per batch by the session loop.
adam = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = adam.minimize(loss)

# ---- Synthetic data and batching ---------------------------------------------
def make_dataset(n_samples, n_features=10, threshold=45):
    """Return ``(inputs, labels)`` for the toy "is the row sum large?" task.

    inputs: integer array of shape (n_samples, n_features), values in 0..9.
    labels: float32 array of shape (n_samples, 1); 1.0 where the row sum is
    greater than ``threshold``, else 0.0. With the defaults, 45 is the
    expected row sum (10 draws with mean 4.5).
    """
    features = np.random.choice(10, size=[n_samples, n_features])
    targets = (np.sum(features, axis=1) > threshold).reshape(-1, 1).astype(np.float32)
    return features, targets


def make_batches(inputs, labels, batch_size):
    """Split ``inputs``/``labels`` into consecutive ``[inputs_chunk, labels_chunk]`` pairs.

    Every pair holds ``batch_size`` rows except possibly the last one, which
    holds the remainder. Empty input yields an empty list.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %d" % batch_size)
    return [
        [inputs[start:start + batch_size], labels[start:start + batch_size]]
        for start in range(0, len(inputs), batch_size)
    ]


inputs, labels = make_dataset(10000)
print('inputs.shape:', inputs.shape)
print('labels.shape:', labels.shape)

test_inputs, test_labels = make_dataset(100)
print('test_inputs.shape:', test_inputs.shape)
print('test_labels.shape:', test_labels.shape)

batch_size = 32
epochs = 10

n_full_batches = len(inputs) // batch_size
remainder = len(inputs) - n_full_batches * batch_size
print("%d items in batch of %d gives us %d full batches and %d batches of %d items" % (
    len(inputs),
    batch_size,
    n_full_batches,
    1 if remainder else 0,  # at most one partial batch
    remainder)
)
batches = make_batches(inputs, labels, batch_size)
print("Number of batches: %d" % len(batches))
# batches[k] is [inputs_chunk, labels_chunk]; report the number of rows.
print("Size of full batch: %d" % len(batches[0][0]))
print("Size of final batch: %d" % len(batches[-1][0]))

# ---- Training (Supervisor variant) ---------------------------------------------
global_count = 0  # number of training steps run so far

# tf.train.Supervisor creates a Saver (sv.saver) and initializes the variables
# when managed_session() opens, so no explicit initializer call is needed.
sv = tf.train.Supervisor()
with sv.managed_session() as sess:
    for i in range(epochs):
        for batch in batches:
            train_loss, _ = sess.run([loss, train_op], feed_dict={
                inputs_placeholder: batch[0],
                labels_placeholder: batch[1]
            })

            # Report test-set accuracy every 100 training steps.
            if global_count % 100 == 0:
                acc = sess.run(accuracy, feed_dict={
                    inputs_placeholder: test_inputs,
                    labels_placeholder: test_labels
                })
                print('accuracy: %f' % acc)
            global_count += 1

    acc = sess.run(accuracy, feed_dict={
        inputs_placeholder: test_inputs,
        labels_placeholder: test_labels
    })
    print("final accuracy: %f" % acc)

    # Save while the session still holds the trained values. Saver.save will
    # not create a missing parent directory, so make sure it exists first.
    tf.gfile.MakeDirs('results')
    sv.saver.save(sess, 'results/graph.chkp')

# Print every op name; these names are what the freeze/load scripts refer to.
for op in tf.get_default_graph().get_operations():
    print(op.name)

0 评论
0 收藏
0