1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Converts checkpoint variables into Const ops in a standalone GraphDef file. 16 17This script is designed to take a GraphDef proto, a SaverDef proto, and a set of 18variable values stored in a checkpoint file, and output a GraphDef with all of 19the variable ops converted into const ops containing the values of the 20variables. 21 22It's useful to do this when we need to load a single file in C++, especially in 23environments like mobile or embedded where we may not have access to the 24RestoreTensor ops and file loading calls that they rely on. 25 26An example of command-line usage is: 27bazel build tensorflow/python/tools:freeze_graph && \ 28bazel-bin/tensorflow/python/tools/freeze_graph \ 29--input_graph=some_graph_def.pb \ 30--input_checkpoint=model.ckpt-8361242 \ 31--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax 32 33You can also look at freeze_graph_test.py for an example of how to use it. 34 35""" 36from __future__ import absolute_import 37from __future__ import division 38from __future__ import print_function 39 40import argparse 41import re 42import sys 43 44from google.protobuf import text_format 45 46from tensorflow.core.framework import graph_pb2 47from tensorflow.core.protobuf import saver_pb2 48from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef 49from tensorflow.python.client import session 50from tensorflow.python.framework import graph_util 51from tensorflow.python.framework import importer 52from tensorflow.python.platform import app 53from tensorflow.python.platform import gfile 54from tensorflow.python.saved_model import loader 55from tensorflow.python.saved_model import tag_constants 56from tensorflow.python.tools import saved_model_utils 57from tensorflow.python.training import checkpoint_management 58from tensorflow.python.training import py_checkpoint_reader 59from tensorflow.python.training import saver as saver_lib 60 61 62def _has_no_variables(sess): 63 """Determines if the graph has any variables. 64 65 Args: 66 sess: TensorFlow Session. 67 68 Returns: 69 Bool. 70 """ 71 for op in sess.graph.get_operations(): 72 if op.type.startswith("Variable") or op.type.endswith("VariableOp"): 73 return False 74 return True 75 76 77def freeze_graph_with_def_protos(input_graph_def, 78 input_saver_def, 79 input_checkpoint, 80 output_node_names, 81 restore_op_name, 82 filename_tensor_name, 83 output_graph, 84 clear_devices, 85 initializer_nodes, 86 variable_names_whitelist="", 87 variable_names_denylist="", 88 input_meta_graph_def=None, 89 input_saved_model_dir=None, 90 saved_model_tags=None, 91 checkpoint_version=saver_pb2.SaverDef.V2): 92 """Converts all variables in a graph and checkpoint into constants. 93 94 Args: 95 input_graph_def: A `GraphDef`. 96 input_saver_def: A `SaverDef` (optional). 97 input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking 98 priority. Typically the result of `Saver.save()` or that of 99 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 100 V1/V2. 101 output_node_names: The name(s) of the output nodes, comma separated. 102 restore_op_name: Unused. 103 filename_tensor_name: Unused. 104 output_graph: String where to write the frozen `GraphDef`. 105 clear_devices: A Bool whether to remove device specifications. 106 initializer_nodes: Comma separated string of initializer nodes to run before 107 freezing. 108 variable_names_whitelist: The set of variable names to convert (optional, by 109 default, all variables are converted). 110 variable_names_denylist: The set of variable names to omit converting 111 to constants (optional). 112 input_meta_graph_def: A `MetaGraphDef` (optional), 113 input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file 114 and variables (optional). 115 saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to 116 load, in string format (optional). 117 checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1 118 or saver_pb2.SaverDef.V2) 119 120 Returns: 121 Location of the output_graph_def. 122 """ 123 del restore_op_name, filename_tensor_name # Unused by updated loading code. 124 125 # 'input_checkpoint' may be a prefix if we're using Saver V2 format 126 if (not input_saved_model_dir and 127 not checkpoint_management.checkpoint_exists(input_checkpoint)): 128 raise ValueError("Input checkpoint '" + input_checkpoint + 129 "' doesn't exist!") 130 131 if not output_node_names: 132 raise ValueError( 133 "You need to supply the name of a node to --output_node_names.") 134 135 # Remove all the explicit device specifications for this node. This helps to 136 # make the graph more portable. 137 if clear_devices: 138 if input_meta_graph_def: 139 for node in input_meta_graph_def.graph_def.node: 140 node.device = "" 141 elif input_graph_def: 142 for node in input_graph_def.node: 143 node.device = "" 144 145 if input_graph_def: 146 _ = importer.import_graph_def(input_graph_def, name="") 147 with session.Session() as sess: 148 if input_saver_def: 149 saver = saver_lib.Saver( 150 saver_def=input_saver_def, write_version=checkpoint_version) 151 saver.restore(sess, input_checkpoint) 152 elif input_meta_graph_def: 153 restorer = saver_lib.import_meta_graph( 154 input_meta_graph_def, clear_devices=True) 155 restorer.restore(sess, input_checkpoint) 156 if initializer_nodes: 157 sess.run(initializer_nodes.replace(" ", "").split(",")) 158 elif input_saved_model_dir: 159 if saved_model_tags is None: 160 saved_model_tags = [] 161 loader.load(sess, saved_model_tags, input_saved_model_dir) 162 else: 163 var_list = {} 164 reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint) 165 var_to_shape_map = reader.get_variable_to_shape_map() 166 167 # List of all partition variables. Because the condition is heuristic 168 # based, the list could include false positives. 169 all_partition_variable_names = [ 170 tensor.name.split(":")[0] 171 for op in sess.graph.get_operations() 172 for tensor in op.values() 173 if re.search(r"/part_\d+/", tensor.name) 174 ] 175 has_partition_var = False 176 177 for key in var_to_shape_map: 178 try: 179 tensor = sess.graph.get_tensor_by_name(key + ":0") 180 if any(key in name for name in all_partition_variable_names): 181 has_partition_var = True 182 except KeyError: 183 # This tensor doesn't exist in the graph (for example it's 184 # 'global_step' or a similar housekeeping element) so skip it. 185 continue 186 var_list[key] = tensor 187 188 try: 189 saver = saver_lib.Saver( 190 var_list=var_list, write_version=checkpoint_version) 191 except TypeError as e: 192 # `var_list` is required to be a map of variable names to Variable 193 # tensors. Partition variables are Identity tensors that cannot be 194 # handled by Saver. 195 if has_partition_var: 196 raise ValueError( 197 "Models containing partition variables cannot be converted " 198 "from checkpoint files. Please pass in a SavedModel using " 199 "the flag --input_saved_model_dir.") 200 # Models that have been frozen previously do not contain Variables. 201 elif _has_no_variables(sess): 202 raise ValueError( 203 "No variables were found in this model. It is likely the model " 204 "was frozen previously. You cannot freeze a graph twice.") 205 return 0 206 else: 207 raise e 208 209 saver.restore(sess, input_checkpoint) 210 if initializer_nodes: 211 sess.run(initializer_nodes.replace(" ", "").split(",")) 212 213 variable_names_whitelist = ( 214 variable_names_whitelist.replace(" ", "").split(",") 215 if variable_names_whitelist else None) 216 variable_names_denylist = ( 217 variable_names_denylist.replace(" ", "").split(",") 218 if variable_names_denylist else None) 219 220 if input_meta_graph_def: 221 output_graph_def = graph_util.convert_variables_to_constants( 222 sess, 223 input_meta_graph_def.graph_def, 224 output_node_names.replace(" ", "").split(","), 225 variable_names_whitelist=variable_names_whitelist, 226 variable_names_blacklist=variable_names_denylist) 227 else: 228 output_graph_def = graph_util.convert_variables_to_constants( 229 sess, 230 input_graph_def, 231 output_node_names.replace(" ", "").split(","), 232 variable_names_whitelist=variable_names_whitelist, 233 variable_names_blacklist=variable_names_denylist) 234 235 # Write GraphDef to file if output path has been given. 236 if output_graph: 237 with gfile.GFile(output_graph, "wb") as f: 238 f.write(output_graph_def.SerializeToString()) 239 240 return output_graph_def 241 242 243def _parse_input_graph_proto(input_graph, input_binary): 244 """Parses input tensorflow graph into GraphDef proto.""" 245 if not gfile.Exists(input_graph): 246 raise IOError("Input graph file '" + input_graph + "' does not exist!") 247 input_graph_def = graph_pb2.GraphDef() 248 mode = "rb" if input_binary else "r" 249 with gfile.GFile(input_graph, mode) as f: 250 if input_binary: 251 input_graph_def.ParseFromString(f.read()) 252 else: 253 text_format.Merge(f.read(), input_graph_def) 254 return input_graph_def 255 256 257def _parse_input_meta_graph_proto(input_graph, input_binary): 258 """Parses input tensorflow graph into MetaGraphDef proto.""" 259 if not gfile.Exists(input_graph): 260 raise IOError("Input meta graph file '" + input_graph + "' does not exist!") 261 input_meta_graph_def = MetaGraphDef() 262 mode = "rb" if input_binary else "r" 263 with gfile.GFile(input_graph, mode) as f: 264 if input_binary: 265 input_meta_graph_def.ParseFromString(f.read()) 266 else: 267 text_format.Merge(f.read(), input_meta_graph_def) 268 print("Loaded meta graph file '" + input_graph) 269 return input_meta_graph_def 270 271 272def _parse_input_saver_proto(input_saver, input_binary): 273 """Parses input tensorflow Saver into SaverDef proto.""" 274 if not gfile.Exists(input_saver): 275 raise IOError("Input saver file '" + input_saver + "' does not exist!") 276 mode = "rb" if input_binary else "r" 277 with gfile.GFile(input_saver, mode) as f: 278 saver_def = saver_pb2.SaverDef() 279 if input_binary: 280 saver_def.ParseFromString(f.read()) 281 else: 282 text_format.Merge(f.read(), saver_def) 283 return saver_def 284 285 286def freeze_graph(input_graph, 287 input_saver, 288 input_binary, 289 input_checkpoint, 290 output_node_names, 291 restore_op_name, 292 filename_tensor_name, 293 output_graph, 294 clear_devices, 295 initializer_nodes, 296 variable_names_whitelist="", 297 variable_names_denylist="", 298 input_meta_graph=None, 299 input_saved_model_dir=None, 300 saved_model_tags=tag_constants.SERVING, 301 checkpoint_version=saver_pb2.SaverDef.V2): 302 """Converts all variables in a graph and checkpoint into constants. 303 304 Args: 305 input_graph: A `GraphDef` file to load. 306 input_saver: A TensorFlow Saver file. 307 input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt. 308 input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking 309 priority. Typically the result of `Saver.save()` or that of 310 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 311 V1/V2. 312 output_node_names: The name(s) of the output nodes, comma separated. 313 restore_op_name: Unused. 314 filename_tensor_name: Unused. 315 output_graph: String where to write the frozen `GraphDef`. 316 clear_devices: A Bool whether to remove device specifications. 317 initializer_nodes: Comma separated list of initializer nodes to run before 318 freezing. 319 variable_names_whitelist: The set of variable names to convert (optional, by 320 default, all variables are converted), 321 variable_names_denylist: The set of variable names to omit converting 322 to constants (optional). 323 input_meta_graph: A `MetaGraphDef` file to load (optional). 324 input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and 325 variables (optional). 326 saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to 327 load, in string format. 328 checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1 329 or saver_pb2.SaverDef.V2). 330 Returns: 331 String that is the location of frozen GraphDef. 332 """ 333 input_graph_def = None 334 if input_saved_model_dir: 335 input_graph_def = saved_model_utils.get_meta_graph_def( 336 input_saved_model_dir, saved_model_tags).graph_def 337 elif input_graph: 338 input_graph_def = _parse_input_graph_proto(input_graph, input_binary) 339 input_meta_graph_def = None 340 if input_meta_graph: 341 input_meta_graph_def = _parse_input_meta_graph_proto( 342 input_meta_graph, input_binary) 343 input_saver_def = None 344 if input_saver: 345 input_saver_def = _parse_input_saver_proto(input_saver, input_binary) 346 return freeze_graph_with_def_protos( 347 input_graph_def, 348 input_saver_def, 349 input_checkpoint, 350 output_node_names, 351 restore_op_name, 352 filename_tensor_name, 353 output_graph, 354 clear_devices, 355 initializer_nodes, 356 variable_names_whitelist, 357 variable_names_denylist, 358 input_meta_graph_def, 359 input_saved_model_dir, 360 [tag for tag in saved_model_tags.replace(" ", "").split(",") if tag], 361 checkpoint_version=checkpoint_version) 362 363 364def main(unused_args, flags): 365 if flags.checkpoint_version == 1: 366 checkpoint_version = saver_pb2.SaverDef.V1 367 elif flags.checkpoint_version == 2: 368 checkpoint_version = saver_pb2.SaverDef.V2 369 else: 370 raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" % 371 flags.checkpoint_version) 372 freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, 373 flags.input_checkpoint, flags.output_node_names, 374 flags.restore_op_name, flags.filename_tensor_name, 375 flags.output_graph, flags.clear_devices, flags.initializer_nodes, 376 flags.variable_names_whitelist, flags.variable_names_denylist, 377 flags.input_meta_graph, flags.input_saved_model_dir, 378 flags.saved_model_tags, checkpoint_version) 379 380 381def run_main(): 382 """Main function of freeze_graph.""" 383 parser = argparse.ArgumentParser() 384 parser.register("type", "bool", lambda v: v.lower() == "true") 385 parser.add_argument( 386 "--input_graph", 387 type=str, 388 default="", 389 help="TensorFlow \'GraphDef\' file to load.") 390 parser.add_argument( 391 "--input_saver", 392 type=str, 393 default="", 394 help="TensorFlow saver file to load.") 395 parser.add_argument( 396 "--input_checkpoint", 397 type=str, 398 default="", 399 help="TensorFlow variables file to load.") 400 parser.add_argument( 401 "--checkpoint_version", 402 type=int, 403 default=2, 404 help="Tensorflow variable file format") 405 parser.add_argument( 406 "--output_graph", 407 type=str, 408 default="", 409 help="Output \'GraphDef\' file name.") 410 parser.add_argument( 411 "--input_binary", 412 nargs="?", 413 const=True, 414 type="bool", 415 default=False, 416 help="Whether the input files are in binary format.") 417 parser.add_argument( 418 "--output_node_names", 419 type=str, 420 default="", 421 help="The name of the output nodes, comma separated.") 422 parser.add_argument( 423 "--restore_op_name", 424 type=str, 425 default="save/restore_all", 426 help="""\ 427 The name of the master restore operator. Deprecated, unused by updated \ 428 loading code. 429 """) 430 parser.add_argument( 431 "--filename_tensor_name", 432 type=str, 433 default="save/Const:0", 434 help="""\ 435 The name of the tensor holding the save path. Deprecated, unused by \ 436 updated loading code. 437 """) 438 parser.add_argument( 439 "--clear_devices", 440 nargs="?", 441 const=True, 442 type="bool", 443 default=True, 444 help="Whether to remove device specifications.") 445 parser.add_argument( 446 "--initializer_nodes", 447 type=str, 448 default="", 449 help="Comma separated list of initializer nodes to run before freezing.") 450 parser.add_argument( 451 "--variable_names_whitelist", 452 type=str, 453 default="", 454 help="""\ 455 Comma separated list of variables to convert to constants. If specified, \ 456 only those variables will be converted to constants.\ 457 """) 458 parser.add_argument( 459 "--variable_names_denylist", 460 type=str, 461 default="", 462 help="""\ 463 Comma separated list of variables to skip converting to constants.\ 464 """) 465 parser.add_argument( 466 "--input_meta_graph", 467 type=str, 468 default="", 469 help="TensorFlow \'MetaGraphDef\' file to load.") 470 parser.add_argument( 471 "--input_saved_model_dir", 472 type=str, 473 default="", 474 help="Path to the dir with TensorFlow \'SavedModel\' file and variables.") 475 parser.add_argument( 476 "--saved_model_tags", 477 type=str, 478 default="serve", 479 help="""\ 480 Group of tag(s) of the MetaGraphDef to load, in string format,\ 481 separated by \',\'. For tag-set contains multiple tags, all tags \ 482 must be passed in.\ 483 """) 484 flags, unparsed = parser.parse_known_args() 485 486 my_main = lambda unused_args: main(unused_args, flags) 487 app.run(main=my_main, argv=[sys.argv[0]] + unparsed) 488 489 490if __name__ == "__main__": 491 run_main() 492