• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests the graph freezing tool."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import re
23
24from absl.testing import parameterized
25
26from tensorflow.core.example import example_pb2
27from tensorflow.core.framework import graph_pb2
28from tensorflow.core.protobuf import saver_pb2
29from tensorflow.python.client import session
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import graph_io
32from tensorflow.python.framework import importer
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import nn
38from tensorflow.python.ops import parsing_ops
39from tensorflow.python.ops import partitioned_variables
40from tensorflow.python.ops import variable_scope
41from tensorflow.python.ops import variables
42from tensorflow.python.platform import test
43from tensorflow.python.saved_model import builder as saved_model_builder
44from tensorflow.python.saved_model import signature_constants
45from tensorflow.python.saved_model import signature_def_utils
46from tensorflow.python.saved_model import tag_constants
47from tensorflow.python.tools import freeze_graph
48from tensorflow.python.training import saver as saver_lib
49
50
class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Exercises freeze_graph on checkpoints, MetaGraphs and SavedModels."""

  def _loadFrozenGraphDef(self, graph_path):
    """Parses the binary GraphDef at `graph_path` and imports it.

    The graph is imported (with an empty name prefix) into whatever graph is
    the default at call time, so callers should wrap the call in
    `ops.Graph().as_default()`.

    Args:
      graph_path: Path of a serialized `GraphDef` file.

    Returns:
      The parsed `graph_pb2.GraphDef`.
    """
    graph_def = graph_pb2.GraphDef()
    with open(graph_path, "rb") as f:
      graph_def.ParseFromString(f.read())
    importer.import_graph_def(graph_def, name="")
    return graph_def

  def _assertNoVariableOps(self, graph_def):
    """Asserts that no node in `graph_def` is a Variable/VariableV2 op."""
    for node in graph_def.node:
      self.assertNotEqual("VariableV2", node.op)
      self.assertNotEqual("Variable", node.op)

  def _verifyFrozenMultiplyGraph(self, output_graph_path):
    """Checks the frozen `variable_node * 2.0` graph stored on disk.

    Verifies that the graph has four nodes, none of which is a variable op,
    and that evaluating "output_node:0" still yields 2.0.

    Args:
      output_graph_path: Path of the frozen binary `GraphDef`.
    """
    with ops.Graph().as_default():
      output_graph_def = self._loadFrozenGraphDef(output_graph_path)

      self.assertEqual(4, len(output_graph_def.node))
      self._assertNoVariableOps(output_graph_def)

      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)

  def _testFreezeGraph(self, saver_write_version):
    """Freezes a one-variable graph saved with the given checkpoint format.

    Args:
      saver_write_version: Checkpoint format enum. It is used both as the
        `write_version` of the Saver that writes the checkpoint and as the
        `checkpoint_version` handed to `freeze_graph`.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that has a single variable containing 1.0,
    # and that then multiplies it by 2.
    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      # Close the session as soon as the checkpoint and graph are on disk.
      with session.Session() as sess:
        init = variables.global_variables_initializer()
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver(write_version=saver_write_version)
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # We save out the graph to disk, and then call the const conversion
    # routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
    clear_devices = False

    freeze_graph.freeze_graph(
        input_graph_path,
        input_saver_def_path,
        input_binary,
        checkpoint_path,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph_path,
        clear_devices,
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    self._verifyFrozenMultiplyGraph(output_graph_path)

  def _createTFExampleString(self, feature_name, feature_value):
    """Returns a serialized Example holding one float feature.

    Args:
      feature_name: Key of the feature inside the Example.
      feature_value: Value appended to that feature's float list.

    Returns:
      The serialized `example_pb2.Example`.
    """
    example = example_pb2.Example()
    example.features.feature[feature_name].float_list.value.extend([
        feature_value])
    return example.SerializeToString()

  def _writeDummySavedModel(self, path, feature_name, tags):
    """Writes a classifier with a single float input feature to `path`.

    The model parses `feature_name` out of serialized Examples fed to
    "input_node" and multiplies it by a variable initialized to 1.0, exposing
    the product as "output_node".

    Args:
      path: Directory handed to `SavedModelBuilder`.
      feature_name: Name of the scalar float32 feature to parse.
      tags: Tags passed to `add_meta_graph_and_variables`.
    """
    with ops.Graph().as_default():
      examples = array_ops.placeholder(dtypes.string, name="input_node")
      feature_configs = {
          feature_name: parsing_ops.FixedLenFeature(shape=[],
                                                    dtype=dtypes.float32),
      }
      features = parsing_ops.parse_example(examples, feature_configs)
      feature = features[feature_name]

      variable_node = variables.VariableV1(1.0, name="variable_node")
      scores = math_ops.multiply(variable_node, feature, name="output_node")
      class_feature = array_ops.fill(array_ops.shape(feature),
                                     "class_%s" % feature_name)
      classes = array_ops.transpose(class_feature)

      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        signature = (
            signature_def_utils.classification_signature_def(
                examples=examples,
                classes=classes,
                scores=scores,))
        builder = saved_model_builder.SavedModelBuilder(path)
        builder.add_meta_graph_and_variables(
            sess,
            tags,
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature,
            },
        )
        builder.save(as_text=True)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV1(self):
    """Freezes a graph whose checkpoint was written in the V1 format."""
    self._testFreezeGraph(saver_pb2.SaverDef.V1)

  @test_util.run_v1_only("b/120545219")
  def testFreezeGraphV2(self):
    """Freezes a graph whose checkpoint was written in the V2 format."""
    self._testFreezeGraph(saver_pb2.SaverDef.V2)

  def testFreezeMetaGraph(self):
    """Freezes a graph supplied via the checkpoint's ".meta" file."""
    tmp_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
      # Close the session as soon as the checkpoint has been written.
      with session.Session() as sess:
        init = variables.global_variables_initializer()
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)

    input_saver_def_path = ""
    input_binary = True
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    input_meta_graph = checkpoint_path + ".meta"

    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, "", "", "", input_meta_graph)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    self._verifyFrozenMultiplyGraph(output_graph_filename)

  @parameterized.named_parameters(
      ("empty_tags_set", "", []),
      ("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
  def testFreezeSavedModel(self, tags_string, tags_list):
    """Freezes a SavedModel written with the given tags.

    Args:
      tags_string: Tags in the string form passed to `freeze_graph`.
      tags_list: The same tags in the list form used to write the SavedModel.
    """
    tmp_dir = self.get_temp_dir()
    saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
    feature_name = "feature"
    self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
    output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")

    input_saved_model_dir = saved_model_dir
    output_node_names = "output_node"
    input_binary = False
    input_saver_def_path = False
    restore_op_name = None
    filename_tensor_name = None
    clear_devices = False
    input_meta_graph = False
    checkpoint_path = None
    input_graph_filename = None
    saved_model_tags = tags_string

    freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_filename, clear_devices, "", "", "",
                              input_meta_graph, input_saved_model_dir,
                              saved_model_tags)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      output_graph_def = self._loadFrozenGraphDef(output_graph_filename)

      # The expected size depends on whether the graph contains nodes whose
      # name mentions ParseExampleV2.
      if any(u"ParseExampleV2" in node.name for node in output_graph_def.node):
        expected_node_count = 10
      else:
        expected_node_count = 8
      self.assertEqual(expected_node_count, len(output_graph_def.node))
      self._assertNoVariableOps(output_graph_def)

      feature_value = 2.0
      example = self._createTFExampleString(feature_name, feature_value)
      with session.Session() as sess:
        input_node = sess.graph.get_tensor_by_name("input_node:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node, feed_dict={input_node: [example]})
        self.assertNear(feature_value, output, 0.00001)

  def testSinglePartitionedVariable(self):
    """Ensures partitioned variables fail cleanly with freeze graph."""
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Create a graph with partition variables. When weights are partitioned into
    # a single partition, the weights variable is followed by a identity ->
    # identity (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    with ops.Graph().as_default():
      with variable_scope.variable_scope("part", partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros(
            (batch_size, height, width, depth), name="input1")
        input2 = array_ops.zeros(
            (batch_size, height, width, depth), name="input2")

        num_nodes = depth
        filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
        filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
        conv = nn.conv2d(
            input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
        node = math_ops.add(conv, input2, name="test/add")
        node = nn.relu6(node, name="test/relu6")

      # Save graph and checkpoints, closing the session once done.
      with session.Session() as sess:
        sess.run(variables.global_variables_initializer())

        saver = saver_lib.Saver()
        checkpoint_path = saver.save(
            sess,
            checkpoint_prefix,
            global_step=0,
            latest_filename=checkpoint_state_name)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

        # Ensure this graph has partition variables.
        self.assertTrue([
            tensor.name.split(":")[0]
            for op in sess.graph.get_operations()
            for tensor in op.values()
            if re.search(r"/part_\d+/", tensor.name)
        ])

        # Snapshot the GraphDef while the session is still open so the freeze
        # call below does not depend on a closed session.
        input_graph_def = sess.graph_def

    # Freezing this graph is expected to fail cleanly with a ValueError.
    output_node_names = "save/restore_all"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    with self.assertRaises(ValueError):
      freeze_graph.freeze_graph_with_def_protos(
          input_graph_def=input_graph_def,
          input_saver_def=None,
          input_checkpoint=checkpoint_path,
          output_node_names=output_node_names,
          restore_op_name="save/restore_all",  # default value
          filename_tensor_name="save/Const:0",  # default value
          output_graph=output_graph_path,
          clear_devices=False,
          initializer_nodes="")
340
341
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == "__main__":
  test.main()
344