1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests the graph freezing tool.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import re 23 24from absl.testing import parameterized 25 26from tensorflow.core.example import example_pb2 27from tensorflow.core.framework import graph_pb2 28from tensorflow.core.protobuf import saver_pb2 29from tensorflow.python.client import session 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import graph_io 32from tensorflow.python.framework import importer 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import test_util 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import nn 38from tensorflow.python.ops import parsing_ops 39from tensorflow.python.ops import partitioned_variables 40from tensorflow.python.ops import variable_scope 41from tensorflow.python.ops import variables 42from tensorflow.python.platform import test 43from tensorflow.python.saved_model import builder as saved_model_builder 44from tensorflow.python.saved_model import signature_constants 45from tensorflow.python.saved_model import signature_def_utils 46from tensorflow.python.saved_model import tag_constants 47from tensorflow.python.tools import freeze_graph 48from tensorflow.python.training import saver as saver_lib 49 50 51class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase): 52 53 def _testFreezeGraph(self, saver_write_version): 54 55 checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") 56 checkpoint_state_name = "checkpoint_state" 57 input_graph_name = "input_graph.pb" 58 output_graph_name = "output_graph.pb" 59 60 # We'll create an input graph that has a single variable containing 1.0, 61 # and that then multiplies it by 2. 62 with ops.Graph().as_default(): 63 variable_node = variables.VariableV1(1.0, name="variable_node") 64 output_node = math_ops.multiply(variable_node, 2.0, name="output_node") 65 sess = session.Session() 66 init = variables.global_variables_initializer() 67 sess.run(init) 68 output = sess.run(output_node) 69 self.assertNear(2.0, output, 0.00001) 70 saver = saver_lib.Saver(write_version=saver_write_version) 71 checkpoint_path = saver.save( 72 sess, 73 checkpoint_prefix, 74 global_step=0, 75 latest_filename=checkpoint_state_name) 76 graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) 77 78 # We save out the graph to disk, and then call the const conversion 79 # routine. 80 input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name) 81 input_saver_def_path = "" 82 input_binary = False 83 output_node_names = "output_node" 84 restore_op_name = "save/restore_all" 85 filename_tensor_name = "save/Const:0" 86 output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 87 clear_devices = False 88 89 freeze_graph.freeze_graph( 90 input_graph_path, 91 input_saver_def_path, 92 input_binary, 93 checkpoint_path, 94 output_node_names, 95 restore_op_name, 96 filename_tensor_name, 97 output_graph_path, 98 clear_devices, 99 "", 100 "", 101 "", 102 checkpoint_version=saver_write_version) 103 104 # Now we make sure the variable is now a constant, and that the graph still 105 # produces the expected result. 106 with ops.Graph().as_default(): 107 output_graph_def = graph_pb2.GraphDef() 108 with open(output_graph_path, "rb") as f: 109 output_graph_def.ParseFromString(f.read()) 110 _ = importer.import_graph_def(output_graph_def, name="") 111 112 self.assertEqual(4, len(output_graph_def.node)) 113 for node in output_graph_def.node: 114 self.assertNotEqual("VariableV2", node.op) 115 self.assertNotEqual("Variable", node.op) 116 117 with session.Session() as sess: 118 output_node = sess.graph.get_tensor_by_name("output_node:0") 119 output = sess.run(output_node) 120 self.assertNear(2.0, output, 0.00001) 121 122 def _createTFExampleString(self, feature_name, feature_value): 123 """Create a serialized tensorflow example.""" 124 example = example_pb2.Example() 125 example.features.feature[feature_name].float_list.value.extend([ 126 feature_value]) 127 return example.SerializeToString() 128 129 def _writeDummySavedModel(self, path, feature_name, tags): 130 """Writes a classifier with two input features to the given path.""" 131 with ops.Graph().as_default(): 132 examples = array_ops.placeholder(dtypes.string, name="input_node") 133 feature_configs = { 134 feature_name: parsing_ops.FixedLenFeature(shape=[], 135 dtype=dtypes.float32), 136 } 137 features = parsing_ops.parse_example(examples, feature_configs) 138 feature = features[feature_name] 139 140 variable_node = variables.VariableV1(1.0, name="variable_node") 141 scores = math_ops.multiply(variable_node, feature, name="output_node") 142 class_feature = array_ops.fill(array_ops.shape(feature), 143 "class_%s" % feature_name) 144 classes = array_ops.transpose(class_feature) 145 146 with session.Session() as sess: 147 sess.run(variables.global_variables_initializer()) 148 signature = ( 149 signature_def_utils.classification_signature_def( 150 examples=examples, 151 classes=classes, 152 scores=scores,)) 153 builder = saved_model_builder.SavedModelBuilder(path) 154 builder.add_meta_graph_and_variables( 155 sess, 156 tags, 157 signature_def_map={ 158 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 159 signature, 160 }, 161 ) 162 builder.save(as_text=True) 163 164 @test_util.run_v1_only("b/120545219") 165 def testFreezeGraphV1(self): 166 self._testFreezeGraph(saver_pb2.SaverDef.V1) 167 168 @test_util.run_v1_only("b/120545219") 169 def testFreezeGraphV2(self): 170 self._testFreezeGraph(saver_pb2.SaverDef.V2) 171 172 def testFreezeMetaGraph(self): 173 tmp_dir = self.get_temp_dir() 174 checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint") 175 checkpoint_state_name = "checkpoint_state" 176 output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") 177 178 with ops.Graph().as_default(): 179 variable_node = variables.VariableV1(1.0, name="variable_node") 180 output_node = math_ops.multiply(variable_node, 2.0, name="output_node") 181 sess = session.Session() 182 init = variables.global_variables_initializer() 183 sess.run(init) 184 output = sess.run(output_node) 185 self.assertNear(2.0, output, 0.00001) 186 saver = saver_lib.Saver() 187 checkpoint_path = saver.save( 188 sess, 189 checkpoint_prefix, 190 global_step=0, 191 latest_filename=checkpoint_state_name) 192 193 input_saver_def_path = "" 194 input_binary = True 195 output_node_names = "output_node" 196 restore_op_name = "save/restore_all" 197 filename_tensor_name = "save/Const:0" 198 clear_devices = False 199 input_meta_graph = checkpoint_path + ".meta" 200 201 freeze_graph.freeze_graph( 202 "", input_saver_def_path, input_binary, checkpoint_path, 203 output_node_names, restore_op_name, filename_tensor_name, 204 output_graph_filename, clear_devices, "", "", "", input_meta_graph) 205 206 # Now we make sure the variable is now a constant, and that the graph still 207 # produces the expected result. 208 with ops.Graph().as_default(): 209 output_graph_def = graph_pb2.GraphDef() 210 with open(output_graph_filename, "rb") as f: 211 output_graph_def.ParseFromString(f.read()) 212 _ = importer.import_graph_def(output_graph_def, name="") 213 214 self.assertEqual(4, len(output_graph_def.node)) 215 for node in output_graph_def.node: 216 self.assertNotEqual("VariableV2", node.op) 217 self.assertNotEqual("Variable", node.op) 218 219 with session.Session() as sess: 220 output_node = sess.graph.get_tensor_by_name("output_node:0") 221 output = sess.run(output_node) 222 self.assertNear(2.0, output, 0.00001) 223 224 @parameterized.named_parameters( 225 ("empty_tags_set", "", []), 226 ("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING])) 227 def testFreezeSavedModel(self, tags_string, tags_list): 228 tmp_dir = self.get_temp_dir() 229 saved_model_dir = os.path.join(tmp_dir, "saved_model_dir") 230 feature_name = "feature" 231 self._writeDummySavedModel(saved_model_dir, feature_name, tags_list) 232 output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") 233 234 input_saved_model_dir = saved_model_dir 235 output_node_names = "output_node" 236 input_binary = False 237 input_saver_def_path = False 238 restore_op_name = None 239 filename_tensor_name = None 240 clear_devices = False 241 input_meta_graph = False 242 checkpoint_path = None 243 input_graph_filename = None 244 saved_model_tags = tags_string 245 246 freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, 247 input_binary, checkpoint_path, output_node_names, 248 restore_op_name, filename_tensor_name, 249 output_graph_filename, clear_devices, "", "", "", 250 input_meta_graph, input_saved_model_dir, 251 saved_model_tags) 252 253 # Now we make sure the variable is now a constant, and that the graph still 254 # produces the expected result. 255 with ops.Graph().as_default(): 256 output_graph_def = graph_pb2.GraphDef() 257 with open(output_graph_filename, "rb") as f: 258 output_graph_def.ParseFromString(f.read()) 259 _ = importer.import_graph_def(output_graph_def, name="") 260 261 if any(u"ParseExampleV2" in node.name for node in output_graph_def.node): 262 expected_node_count = 10 263 else: 264 expected_node_count = 8 265 self.assertEqual(expected_node_count, len(output_graph_def.node)) 266 for node in output_graph_def.node: 267 self.assertNotEqual("VariableV2", node.op) 268 self.assertNotEqual("Variable", node.op) 269 270 feature_value = 2.0 271 example = self._createTFExampleString(feature_name, feature_value) 272 with session.Session() as sess: 273 input_node = sess.graph.get_tensor_by_name("input_node:0") 274 output_node = sess.graph.get_tensor_by_name("output_node:0") 275 output = sess.run(output_node, feed_dict={input_node: [example]}) 276 self.assertNear(feature_value, output, 0.00001) 277 278 def testSinglePartitionedVariable(self): 279 """Ensures partitioned variables fail cleanly with freeze graph.""" 280 checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") 281 checkpoint_state_name = "checkpoint_state" 282 input_graph_name = "input_graph.pb" 283 output_graph_name = "output_graph.pb" 284 285 # Create a graph with partition variables. When weights are partitioned into 286 # a single partition, the weights variable is followed by a identity -> 287 # identity (an additional identity node). 288 partitioner = partitioned_variables.fixed_size_partitioner(1) 289 with ops.Graph().as_default(): 290 with variable_scope.variable_scope("part", partitioner=partitioner): 291 batch_size, height, width, depth = 5, 128, 128, 3 292 input1 = array_ops.zeros( 293 (batch_size, height, width, depth), name="input1") 294 input2 = array_ops.zeros( 295 (batch_size, height, width, depth), name="input2") 296 297 num_nodes = depth 298 filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes]) 299 filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes]) 300 conv = nn.conv2d( 301 input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME") 302 node = math_ops.add(conv, input2, name="test/add") 303 node = nn.relu6(node, name="test/relu6") 304 305 # Save graph and checkpoints. 306 sess = session.Session() 307 sess.run(variables.global_variables_initializer()) 308 309 saver = saver_lib.Saver() 310 checkpoint_path = saver.save( 311 sess, 312 checkpoint_prefix, 313 global_step=0, 314 latest_filename=checkpoint_state_name) 315 graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) 316 317 # Ensure this graph has partition variables. 318 self.assertTrue([ 319 tensor.name.split(":")[0] 320 for op in sess.graph.get_operations() 321 for tensor in op.values() 322 if re.search(r"/part_\d+/", tensor.name) 323 ]) 324 325 # Test freezing graph doesn't make it crash. 326 output_node_names = "save/restore_all" 327 output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 328 329 with self.assertRaises(ValueError): 330 freeze_graph.freeze_graph_with_def_protos( 331 input_graph_def=sess.graph_def, 332 input_saver_def=None, 333 input_checkpoint=checkpoint_path, 334 output_node_names=output_node_names, 335 restore_op_name="save/restore_all", # default value 336 filename_tensor_name="save/Const:0", # default value 337 output_graph=output_graph_path, 338 clear_devices=False, 339 initializer_nodes="") 340 341 342if __name__ == "__main__": 343 test.main() 344