1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests the graph freezing tool.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import re 23 24from tensorflow.core.example import example_pb2 25from tensorflow.core.framework import graph_pb2 26from tensorflow.core.protobuf import saver_pb2 27from tensorflow.python.client import session 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import graph_io 30from tensorflow.python.framework import importer 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import test_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import nn 36from tensorflow.python.ops import parsing_ops 37from tensorflow.python.ops import partitioned_variables 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.ops import variables 40from tensorflow.python.platform import test 41from tensorflow.python.saved_model import builder as saved_model_builder 42from tensorflow.python.saved_model import signature_constants 43from tensorflow.python.saved_model import signature_def_utils 44from tensorflow.python.saved_model import tag_constants 45from tensorflow.python.tools import freeze_graph 46from tensorflow.python.training import saver as saver_lib 47 48 49class FreezeGraphTest(test_util.TensorFlowTestCase): 50 51 def _testFreezeGraph(self, saver_write_version): 52 53 checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") 54 checkpoint_state_name = "checkpoint_state" 55 input_graph_name = "input_graph.pb" 56 output_graph_name = "output_graph.pb" 57 58 # We'll create an input graph that has a single variable containing 1.0, 59 # and that then multiplies it by 2. 60 with ops.Graph().as_default(): 61 variable_node = variables.VariableV1(1.0, name="variable_node") 62 output_node = math_ops.multiply(variable_node, 2.0, name="output_node") 63 sess = session.Session() 64 init = variables.global_variables_initializer() 65 sess.run(init) 66 output = sess.run(output_node) 67 self.assertNear(2.0, output, 0.00001) 68 saver = saver_lib.Saver(write_version=saver_write_version) 69 checkpoint_path = saver.save( 70 sess, 71 checkpoint_prefix, 72 global_step=0, 73 latest_filename=checkpoint_state_name) 74 graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) 75 76 # We save out the graph to disk, and then call the const conversion 77 # routine. 78 input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name) 79 input_saver_def_path = "" 80 input_binary = False 81 output_node_names = "output_node" 82 restore_op_name = "save/restore_all" 83 filename_tensor_name = "save/Const:0" 84 output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 85 clear_devices = False 86 87 freeze_graph.freeze_graph( 88 input_graph_path, 89 input_saver_def_path, 90 input_binary, 91 checkpoint_path, 92 output_node_names, 93 restore_op_name, 94 filename_tensor_name, 95 output_graph_path, 96 clear_devices, 97 "", 98 "", 99 "", 100 checkpoint_version=saver_write_version) 101 102 # Now we make sure the variable is now a constant, and that the graph still 103 # produces the expected result. 104 with ops.Graph().as_default(): 105 output_graph_def = graph_pb2.GraphDef() 106 with open(output_graph_path, "rb") as f: 107 output_graph_def.ParseFromString(f.read()) 108 _ = importer.import_graph_def(output_graph_def, name="") 109 110 self.assertEqual(4, len(output_graph_def.node)) 111 for node in output_graph_def.node: 112 self.assertNotEqual("VariableV2", node.op) 113 self.assertNotEqual("Variable", node.op) 114 115 with session.Session() as sess: 116 output_node = sess.graph.get_tensor_by_name("output_node:0") 117 output = sess.run(output_node) 118 self.assertNear(2.0, output, 0.00001) 119 120 def _createTFExampleString(self, feature_name, feature_value): 121 """Create a serialized tensorflow example.""" 122 example = example_pb2.Example() 123 example.features.feature[feature_name].float_list.value.extend([ 124 feature_value]) 125 return example.SerializeToString() 126 127 def _writeDummySavedModel(self, path, feature_name): 128 """Writes a classifier with two input features to the given path.""" 129 with ops.Graph().as_default(): 130 examples = array_ops.placeholder(dtypes.string, name="input_node") 131 feature_configs = { 132 feature_name: parsing_ops.FixedLenFeature(shape=[], 133 dtype=dtypes.float32), 134 } 135 features = parsing_ops.parse_example(examples, feature_configs) 136 feature = features[feature_name] 137 138 variable_node = variables.VariableV1(1.0, name="variable_node") 139 scores = math_ops.multiply(variable_node, feature, name="output_node") 140 class_feature = array_ops.fill(array_ops.shape(feature), 141 "class_%s" % feature_name) 142 classes = array_ops.transpose(class_feature) 143 144 with session.Session() as sess: 145 sess.run(variables.global_variables_initializer()) 146 signature = ( 147 signature_def_utils.classification_signature_def( 148 examples=examples, 149 classes=classes, 150 scores=scores,)) 151 builder = saved_model_builder.SavedModelBuilder(path) 152 builder.add_meta_graph_and_variables( 153 sess, 154 [tag_constants.SERVING], 155 signature_def_map={ 156 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 157 signature, 158 },) 159 builder.save(as_text=True) 160 161 @test_util.run_v1_only("b/120545219") 162 def testFreezeGraphV1(self): 163 self._testFreezeGraph(saver_pb2.SaverDef.V1) 164 165 @test_util.run_v1_only("b/120545219") 166 def testFreezeGraphV2(self): 167 self._testFreezeGraph(saver_pb2.SaverDef.V2) 168 169 def testFreezeMetaGraph(self): 170 tmp_dir = self.get_temp_dir() 171 checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint") 172 checkpoint_state_name = "checkpoint_state" 173 output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") 174 175 with ops.Graph().as_default(): 176 variable_node = variables.VariableV1(1.0, name="variable_node") 177 output_node = math_ops.multiply(variable_node, 2.0, name="output_node") 178 sess = session.Session() 179 init = variables.global_variables_initializer() 180 sess.run(init) 181 output = sess.run(output_node) 182 self.assertNear(2.0, output, 0.00001) 183 saver = saver_lib.Saver() 184 checkpoint_path = saver.save( 185 sess, 186 checkpoint_prefix, 187 global_step=0, 188 latest_filename=checkpoint_state_name) 189 190 input_saver_def_path = "" 191 input_binary = True 192 output_node_names = "output_node" 193 restore_op_name = "save/restore_all" 194 filename_tensor_name = "save/Const:0" 195 clear_devices = False 196 input_meta_graph = checkpoint_path + ".meta" 197 198 freeze_graph.freeze_graph( 199 "", input_saver_def_path, input_binary, checkpoint_path, 200 output_node_names, restore_op_name, filename_tensor_name, 201 output_graph_filename, clear_devices, "", "", "", input_meta_graph) 202 203 # Now we make sure the variable is now a constant, and that the graph still 204 # produces the expected result. 205 with ops.Graph().as_default(): 206 output_graph_def = graph_pb2.GraphDef() 207 with open(output_graph_filename, "rb") as f: 208 output_graph_def.ParseFromString(f.read()) 209 _ = importer.import_graph_def(output_graph_def, name="") 210 211 self.assertEqual(4, len(output_graph_def.node)) 212 for node in output_graph_def.node: 213 self.assertNotEqual("VariableV2", node.op) 214 self.assertNotEqual("Variable", node.op) 215 216 with session.Session() as sess: 217 output_node = sess.graph.get_tensor_by_name("output_node:0") 218 output = sess.run(output_node) 219 self.assertNear(2.0, output, 0.00001) 220 221 def testFreezeSavedModel(self): 222 tmp_dir = self.get_temp_dir() 223 saved_model_dir = os.path.join(tmp_dir, "saved_model_dir") 224 feature_name = "feature" 225 self._writeDummySavedModel(saved_model_dir, feature_name) 226 output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") 227 228 input_saved_model_dir = saved_model_dir 229 output_node_names = "output_node" 230 input_binary = False 231 input_saver_def_path = False 232 restore_op_name = None 233 filename_tensor_name = None 234 clear_devices = False 235 input_meta_graph = False 236 checkpoint_path = None 237 input_graph_filename = None 238 saved_model_tags = tag_constants.SERVING 239 240 freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, 241 input_binary, checkpoint_path, output_node_names, 242 restore_op_name, filename_tensor_name, 243 output_graph_filename, clear_devices, "", "", "", 244 input_meta_graph, input_saved_model_dir, 245 saved_model_tags) 246 247 # Now we make sure the variable is now a constant, and that the graph still 248 # produces the expected result. 249 with ops.Graph().as_default(): 250 output_graph_def = graph_pb2.GraphDef() 251 with open(output_graph_filename, "rb") as f: 252 output_graph_def.ParseFromString(f.read()) 253 _ = importer.import_graph_def(output_graph_def, name="") 254 255 self.assertEqual(8, len(output_graph_def.node)) 256 for node in output_graph_def.node: 257 self.assertNotEqual("VariableV2", node.op) 258 self.assertNotEqual("Variable", node.op) 259 260 feature_value = 2.0 261 example = self._createTFExampleString(feature_name, feature_value) 262 with session.Session() as sess: 263 input_node = sess.graph.get_tensor_by_name("input_node:0") 264 output_node = sess.graph.get_tensor_by_name("output_node:0") 265 output = sess.run(output_node, feed_dict={input_node: [example]}) 266 self.assertNear(feature_value, output, 0.00001) 267 268 def testSinglePartitionedVariable(self): 269 """Ensures partitioned variables fail cleanly with freeze graph.""" 270 checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") 271 checkpoint_state_name = "checkpoint_state" 272 input_graph_name = "input_graph.pb" 273 output_graph_name = "output_graph.pb" 274 275 # Create a graph with partition variables. When weights are partitioned into 276 # a single partition, the weights variable is followed by a identity -> 277 # identity (an additional identity node). 278 partitioner = partitioned_variables.fixed_size_partitioner(1) 279 with ops.Graph().as_default(): 280 with variable_scope.variable_scope("part", partitioner=partitioner): 281 batch_size, height, width, depth = 5, 128, 128, 3 282 input1 = array_ops.zeros( 283 (batch_size, height, width, depth), name="input1") 284 input2 = array_ops.zeros( 285 (batch_size, height, width, depth), name="input2") 286 287 num_nodes = depth 288 filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes]) 289 filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes]) 290 conv = nn.conv2d( 291 input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME") 292 node = math_ops.add(conv, input2, name="test/add") 293 node = nn.relu6(node, name="test/relu6") 294 295 # Save graph and checkpoints. 296 sess = session.Session() 297 sess.run(variables.global_variables_initializer()) 298 299 saver = saver_lib.Saver() 300 checkpoint_path = saver.save( 301 sess, 302 checkpoint_prefix, 303 global_step=0, 304 latest_filename=checkpoint_state_name) 305 graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) 306 307 # Ensure this graph has partition variables. 308 self.assertTrue([ 309 tensor.name.split(":")[0] 310 for op in sess.graph.get_operations() 311 for tensor in op.values() 312 if re.search(r"/part_\d+/", tensor.name) 313 ]) 314 315 # Test freezing graph doesn't make it crash. 316 output_node_names = "save/restore_all" 317 output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 318 319 return_value = freeze_graph.freeze_graph_with_def_protos( 320 input_graph_def=sess.graph_def, 321 input_saver_def=None, 322 input_checkpoint=checkpoint_path, 323 output_node_names=output_node_names, 324 restore_op_name="save/restore_all", # default value 325 filename_tensor_name="save/Const:0", # default value 326 output_graph=output_graph_path, 327 clear_devices=False, 328 initializer_nodes="") 329 self.assertTrue(return_value, -1) 330 331 332if __name__ == "__main__": 333 test.main() 334