• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests the graph freezing tool."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import re
23
24from tensorflow.core.example import example_pb2
25from tensorflow.core.framework import graph_pb2
26from tensorflow.core.protobuf import saver_pb2
27from tensorflow.python.client import session
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import graph_io
30from tensorflow.python.framework import importer
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import test_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import nn
36from tensorflow.python.ops import parsing_ops
37from tensorflow.python.ops import partitioned_variables
38from tensorflow.python.ops import variable_scope
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import test
41from tensorflow.python.saved_model import builder as saved_model_builder
42from tensorflow.python.saved_model import signature_constants
43from tensorflow.python.saved_model import signature_def_utils
44from tensorflow.python.saved_model import tag_constants
45from tensorflow.python.tools import freeze_graph
46from tensorflow.python.training import saver as saver_lib
47
48
class FreezeGraphTest(test_util.TensorFlowTestCase):
  """Tests for `freeze_graph`, which turns checkpointed variables into consts."""

  def _readGraphDef(self, graph_path):
    """Parses a binary-serialized `GraphDef` from a file.

    Args:
      graph_path: Path of the frozen graph file written by `freeze_graph`.

    Returns:
      The parsed `graph_pb2.GraphDef`.
    """
    graph_def = graph_pb2.GraphDef()
    with open(graph_path, "rb") as f:
      graph_def.ParseFromString(f.read())
    return graph_def

  def _assertVariablesFrozen(self, graph_def, expected_node_count):
    """Checks that a frozen graph has the expected size and no variable ops.

    Args:
      graph_def: The `GraphDef` produced by `freeze_graph`.
      expected_node_count: Exact number of nodes `graph_def` must contain.
    """
    self.assertEqual(expected_node_count, len(graph_def.node))
    for node in graph_def.node:
      self.assertNotEqual("VariableV2", node.op)
      self.assertNotEqual("Variable", node.op)

  def _testFreezeGraph(self, saver_write_version):
    """Freezes a checkpointed text graph and checks the frozen result.

    Args:
      saver_write_version: `saver_pb2.SaverDef` format version. It is used both
        to write the checkpoint and as `checkpoint_version` for
        `freeze_graph`.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that has a single variable containing 1.0,
    # and that then multiplies it by 2.
    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      # Use a context manager so the session is closed once the checkpoint
      # and graph have been written.
      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver(write_version=saver_write_version)
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # We save out the graph to disk, and then call the const conversion
    # routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
    clear_devices = False

    freeze_graph.freeze_graph(
        input_graph_path,
        input_saver_def_path,
        input_binary,
        checkpoint_path,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph_path,
        clear_devices,
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      output_graph_def = self._readGraphDef(output_graph_path)
      importer.import_graph_def(output_graph_def, name="")
      self._assertVariablesFrozen(output_graph_def, 4)

      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)

  def _createTFExampleString(self, feature_name, feature_value):
    """Returns a serialized `Example` proto holding one float feature.

    Args:
      feature_name: Key of the feature inside `features.feature`.
      feature_value: The single float value stored under `feature_name`.

    Returns:
      The `Example` serialized to a byte string.
    """
    example = example_pb2.Example()
    example.features.feature[feature_name].float_list.value.extend(
        [feature_value])
    return example.SerializeToString()

  def _writeDummySavedModel(self, path, feature_name):
    """Writes a single-feature classifier SavedModel to `path`.

    The model does the following:
      * Parses serialized `Example` protos fed to "input_node".
      * Multiplies the float feature `feature_name` by a variable initialized
        to 1.0, producing "output_node".
      * Exports a classification signature under the default serving key,
        tagged with `tag_constants.SERVING`.

    Args:
      path: Export directory handed to `SavedModelBuilder`.
      feature_name: Name of the float feature (shape `[]` per example) that
        the model reads.
    """
    with ops.Graph().as_default():
      examples = array_ops.placeholder(dtypes.string, name="input_node")
      feature_configs = {
          feature_name:
              parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.float32),
      }
      features = parsing_ops.parse_example(examples, feature_configs)
      feature = features[feature_name]

      variable_node = variables.VariableV1(1.0, name="variable_node")
      scores = math_ops.multiply(variable_node, feature, name="output_node")
      class_feature = array_ops.fill(array_ops.shape(feature),
                                     "class_%s" % feature_name)
      classes = array_ops.transpose(class_feature)

      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        signature = signature_def_utils.classification_signature_def(
            examples=examples, classes=classes, scores=scores)
        builder = saved_model_builder.SavedModelBuilder(path)
        builder.add_meta_graph_and_variables(
            sess,
            [tag_constants.SERVING],
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature,
            })
        builder.save(as_text=True)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV1(self):
    self._testFreezeGraph(saver_pb2.SaverDef.V1)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV2(self):
    self._testFreezeGraph(saver_pb2.SaverDef.V2)

  def testFreezeMetaGraph(self):
    """Freezes a graph given only its checkpoint and `.meta` file."""
    tmp_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)

    input_saver_def_path = ""
    input_binary = True
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    input_meta_graph = checkpoint_path + ".meta"

    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, "", "", "", input_meta_graph)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      output_graph_def = self._readGraphDef(output_graph_filename)
      importer.import_graph_def(output_graph_def, name="")
      self._assertVariablesFrozen(output_graph_def, 4)

      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)

  def testFreezeSavedModel(self):
    """Freezes a graph loaded from a SavedModel directory."""
    tmp_dir = self.get_temp_dir()
    saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
    feature_name = "feature"
    self._writeDummySavedModel(saved_model_dir, feature_name)
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    input_saved_model_dir = saved_model_dir
    output_node_names = "output_node"
    input_binary = False
    input_saver_def_path = False
    restore_op_name = None
    filename_tensor_name = None
    clear_devices = False
    input_meta_graph = False
    checkpoint_path = None
    input_graph_filename = None
    saved_model_tags = tag_constants.SERVING

    freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_filename, clear_devices, "", "", "",
                              input_meta_graph, input_saved_model_dir,
                              saved_model_tags)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      output_graph_def = self._readGraphDef(output_graph_filename)
      importer.import_graph_def(output_graph_def, name="")
      self._assertVariablesFrozen(output_graph_def, 8)

      feature_value = 2.0
      example = self._createTFExampleString(feature_name, feature_value)
      with session.Session() as sess:
        input_node = sess.graph.get_tensor_by_name("input_node:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node, feed_dict={input_node: [example]})
        self.assertNear(feature_value, output, 0.00001)

  def testSinglePartitionedVariable(self):
    """Ensures partitioned variables fail cleanly with freeze graph."""
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Create a graph with partition variables. When weights are partitioned into
    # a single partition, the weights variable is followed by a identity ->
    # identity (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    with ops.Graph().as_default():
      with variable_scope.variable_scope("part", partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros(
            (batch_size, height, width, depth), name="input1")
        input2 = array_ops.zeros(
            (batch_size, height, width, depth), name="input2")

        num_nodes = depth
        filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
        filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
        conv = nn.conv2d(
            input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
        node = math_ops.add(conv, input2, name="test/add")
        node = nn.relu6(node, name="test/relu6")

      # Save graph and checkpoints.
      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())

        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

        # Ensure this graph has partition variables.
        self.assertTrue(
            any(
                re.search(r"/part_\d+/", tensor.name)
                for op in sess.graph.get_operations()
                for tensor in op.values()))

        # Capture the GraphDef while the session is still open; it is used
        # below after the session has been closed.
        input_graph_def = sess.graph_def

    # Test freezing graph doesn't make it crash.
    output_node_names = "save/restore_all"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    return_value = freeze_graph.freeze_graph_with_def_protos(
        input_graph_def=input_graph_def,
        input_saver_def=None,
        input_checkpoint=checkpoint_path,
        output_node_names=output_node_names,
        restore_op_name="save/restore_all",  # default value
        filename_tensor_name="save/Const:0",  # default value
        output_graph=output_graph_path,
        clear_devices=False,
        initializer_nodes="")
    # `freeze_graph_with_def_protos` reports failure by returning -1. The
    # previous `assertTrue(return_value, -1)` treated -1 as the failure
    # message, and -1 is truthy, so that check could never fail.
    self.assertNotEqual(-1, return_value)
331
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner, which runs all unit tests
  # defined in this module.
  test.main()
334