• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
16
17This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
18variable values stored in a checkpoint file, and output a GraphDef with all of
19the variable ops converted into const ops containing the values of the
20variables.
21
22It's useful to do this when we need to load a single file in C++, especially in
23environments like mobile or embedded where we may not have access to the
24RestoreTensor ops and file loading calls that they rely on.
25
26An example of command-line usage is:
27bazel build tensorflow/python/tools:freeze_graph && \
28bazel-bin/tensorflow/python/tools/freeze_graph \
29--input_graph=some_graph_def.pb \
30--input_checkpoint=model.ckpt-8361242 \
31--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
32
33You can also look at freeze_graph_test.py for an example of how to use it.
34
35"""
36from __future__ import absolute_import
37from __future__ import division
38from __future__ import print_function
39
40import argparse
41import re
42import sys
43
44from google.protobuf import text_format
45
46from tensorflow.core.framework import graph_pb2
47from tensorflow.core.protobuf import saver_pb2
48from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
49from tensorflow.python.client import session
50from tensorflow.python.framework import graph_util
51from tensorflow.python.framework import importer
52from tensorflow.python.platform import app
53from tensorflow.python.platform import gfile
54from tensorflow.python.saved_model import loader
55from tensorflow.python.saved_model import tag_constants
56from tensorflow.python.tools import saved_model_utils
57from tensorflow.python.training import checkpoint_management
58from tensorflow.python.training import py_checkpoint_reader
59from tensorflow.python.training import saver as saver_lib
60
61
62def _has_no_variables(sess):
63  """Determines if the graph has any variables.
64
65  Args:
66    sess: TensorFlow Session.
67
68  Returns:
69    Bool.
70  """
71  for op in sess.graph.get_operations():
72    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
73      return False
74  return True
75
76
def _split_names(names):
  """Splits a comma separated string of names into a list.

  Spaces are removed and empty entries (for example from a trailing comma)
  are dropped, so "a, b," yields ["a", "b"].

  Args:
    names: Comma separated string of names.

  Returns:
    List of non-empty name strings.
  """
  return [name for name in names.replace(" ", "").split(",") if name]


def _restore_from_checkpoint_reader(sess, input_checkpoint,
                                    checkpoint_version):
  """Restores the checkpoint variables that also exist in `sess.graph`.

  A Saver is built from every checkpoint key whose `<key>:0` tensor exists in
  the graph; keys without a matching tensor are skipped.

  Args:
    sess: TensorFlow Session whose graph already holds the model.
    input_checkpoint: Checkpoint prefix to read from and restore.
    checkpoint_version: `saver_pb2.SaverDef.V1` or `saver_pb2.SaverDef.V2`.

  Raises:
    ValueError: If building the Saver fails with a TypeError and the graph
      seems to contain partition variables, or contains no variables at all.
    TypeError: If building the Saver fails with a TypeError for any other
      reason.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
  var_to_shape_map = reader.get_variable_to_shape_map()

  # List of all partition variables. Because the condition is heuristic
  # based, the list could include false positives.
  all_partition_variable_names = [
      tensor.name.split(":")[0]
      for op in sess.graph.get_operations()
      for tensor in op.values()
      if re.search(r"/part_\d+/", tensor.name)
  ]
  has_partition_var = False

  var_list = {}
  for key in var_to_shape_map:
    try:
      tensor = sess.graph.get_tensor_by_name(key + ":0")
    except KeyError:
      # This tensor doesn't exist in the graph (for example it's
      # 'global_step' or a similar housekeeping element) so skip it.
      continue
    if any(key in name for name in all_partition_variable_names):
      has_partition_var = True
    var_list[key] = tensor

  try:
    saver = saver_lib.Saver(
        var_list=var_list, write_version=checkpoint_version)
  except TypeError:
    # `var_list` is required to be a map of variable names to Variable
    # tensors. Partition variables are Identity tensors that cannot be
    # handled by Saver.
    if has_partition_var:
      raise ValueError(
          "Models containing partition variables cannot be converted "
          "from checkpoint files. Please pass in a SavedModel using "
          "the flag --input_saved_model_dir.")
    # Models that have been frozen previously do not contain Variables.
    if _has_no_variables(sess):
      raise ValueError(
          "No variables were found in this model. It is likely the model "
          "was frozen previously. You cannot freeze a graph twice.")
    raise

  saver.restore(sess, input_checkpoint)


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
      Spaces and empty entries are ignored.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_denylist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef`.

  Raises:
    ValueError: If `input_checkpoint` does not exist (when no SavedModel dir
      is given), if no output node name is supplied, or if variables cannot
      be restored from the checkpoint.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError(
        "Input checkpoint '%s' doesn't exist!" % (input_checkpoint,))

  # Parse up front so inputs such as "," or " " are rejected here instead of
  # failing later on an empty node name.
  output_nodes = _split_names(output_node_names) if output_node_names else []
  if not output_nodes:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")
  init_nodes = _split_names(initializer_nodes) if initializer_nodes else []

  # Remove all explicit device specifications from the nodes. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if init_nodes:
        sess.run(init_nodes)
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      _restore_from_checkpoint_reader(sess, input_checkpoint,
                                      checkpoint_version)
      if init_nodes:
        sess.run(init_nodes)

    whitelist = (
        _split_names(variable_names_whitelist)
        if variable_names_whitelist else None)
    denylist = (
        _split_names(variable_names_denylist)
        if variable_names_denylist else None)

    # When a MetaGraphDef is given, its GraphDef is the one that is frozen.
    source_graph_def = (
        input_meta_graph_def.graph_def
        if input_meta_graph_def else input_graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        source_graph_def,
        output_nodes,
        variable_names_whitelist=whitelist,
        variable_names_blacklist=denylist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
241
242
def _parse_input_graph_proto(input_graph, input_binary):
  """Reads a GraphDef file from disk and returns the parsed proto.

  Args:
    input_graph: Path of the GraphDef file.
    input_binary: If true the file is parsed as a binary proto, otherwise as
      protobuf text format.

  Returns:
    The populated `graph_pb2.GraphDef`.

  Raises:
    IOError: When `input_graph` does not exist.
  """
  if not gfile.Exists(input_graph):
    raise IOError("Input graph file '" + input_graph + "' does not exist!")
  graph_def = graph_pb2.GraphDef()
  with gfile.GFile(input_graph, "rb" if input_binary else "r") as graph_file:
    contents = graph_file.read()
  if input_binary:
    graph_def.ParseFromString(contents)
  else:
    text_format.Merge(contents, graph_def)
  return graph_def
255
256
def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses a serialized MetaGraphDef file into a `MetaGraphDef` proto.

  Args:
    input_graph: Path of the MetaGraphDef file to read.
    input_binary: True if the file holds a binary-serialized proto, False if
      it holds the protobuf text format.

  Returns:
    The parsed `MetaGraphDef`.

  Raises:
    IOError: If `input_graph` does not exist.
  """
  if not gfile.Exists(input_graph):
    raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def
270
271
def _parse_input_saver_proto(input_saver, input_binary):
  """Reads a Saver file from disk and returns it as a `SaverDef` proto.

  Args:
    input_saver: Path of the SaverDef file.
    input_binary: If true the file is parsed as a binary proto, otherwise as
      protobuf text format.

  Returns:
    The populated `saver_pb2.SaverDef`.

  Raises:
    IOError: When `input_saver` does not exist.
  """
  if not gfile.Exists(input_saver):
    raise IOError("Input saver file '" + input_saver + "' does not exist!")
  read_mode = "rb" if input_binary else "r"
  with gfile.GFile(input_saver, read_mode) as saver_file:
    serialized = saver_file.read()
  saver_def = saver_pb2.SaverDef()
  if input_binary:
    saver_def.ParseFromString(serialized)
  else:
    text_format.Merge(serialized, saver_def)
  return saver_def
284
285
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_denylist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Loads the input protos from disk and delegates the conversion to
  `freeze_graph_with_def_protos`.

  Args:
    input_graph: A `GraphDef` file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted),
    variable_names_denylist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format. Spaces and empty entries are
                      ignored. A list/tuple/set of tag strings is also
                      accepted, and None selects the empty tag set.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    The frozen `GraphDef` returned by `freeze_graph_with_def_protos`.
  """
  # Normalize the tags once so the MetaGraphDef lookup below and the
  # SavedModel loader agree on the same tag set.
  if saved_model_tags is None:
    tag_list = []
  elif isinstance(saved_model_tags, (list, tuple, set, frozenset)):
    tag_list = [tag for tag in saved_model_tags if tag]
  else:
    tag_list = [
        tag for tag in saved_model_tags.replace(" ", "").split(",") if tag
    ]

  input_graph_def = None
  if input_saved_model_dir:
    # The lookup takes the tags as a single comma separated string.
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, ",".join(tag_list)).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_denylist,
      input_meta_graph_def,
      input_saved_model_dir,
      tag_list,
      checkpoint_version=checkpoint_version)
362
363
def main(unused_args, flags):
  """Runs `freeze_graph` with the values held by the parsed `flags`.

  Args:
    unused_args: Arguments forwarded by `app.run`; ignored.
    flags: Parsed argparse namespace built by `run_main`.

  Raises:
    ValueError: If `flags.checkpoint_version` is neither 1 nor 2.
  """
  requested_version = flags.checkpoint_version
  if requested_version not in (1, 2):
    raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
                     flags.checkpoint_version)
  checkpoint_version = (
      saver_pb2.SaverDef.V1 if requested_version == 1
      else saver_pb2.SaverDef.V2)
  freeze_graph(
      input_graph=flags.input_graph,
      input_saver=flags.input_saver,
      input_binary=flags.input_binary,
      input_checkpoint=flags.input_checkpoint,
      output_node_names=flags.output_node_names,
      restore_op_name=flags.restore_op_name,
      filename_tensor_name=flags.filename_tensor_name,
      output_graph=flags.output_graph,
      clear_devices=flags.clear_devices,
      initializer_nodes=flags.initializer_nodes,
      variable_names_whitelist=flags.variable_names_whitelist,
      variable_names_denylist=flags.variable_names_denylist,
      input_meta_graph=flags.input_meta_graph,
      input_saved_model_dir=flags.input_saved_model_dir,
      saved_model_tags=flags.saved_model_tags,
      checkpoint_version=checkpoint_version)
379
380
def run_main():
  """Main function of freeze_graph.

  Builds the command-line parser, parses the known flags and hands control to
  `app.run`, forwarding any unrecognized arguments as its argv.
  """

  def _parse_bool_flag(text):
    # Only the case-insensitive string "true" is treated as True.
    return text.lower() == "true"

  parser = argparse.ArgumentParser()
  parser.register("type", "bool", _parse_bool_flag)
  add_flag = parser.add_argument

  add_flag(
      "--input_graph",
      default="",
      type=str,
      help="TensorFlow \'GraphDef\' file to load.")
  add_flag(
      "--input_saver",
      default="",
      type=str,
      help="TensorFlow saver file to load.")
  add_flag(
      "--input_checkpoint",
      default="",
      type=str,
      help="TensorFlow variables file to load.")
  add_flag(
      "--checkpoint_version",
      default=2,
      type=int,
      help="Tensorflow variable file format")
  add_flag(
      "--output_graph",
      default="",
      type=str,
      help="Output \'GraphDef\' file name.")
  add_flag(
      "--input_binary",
      default=False,
      type="bool",
      nargs="?",
      const=True,
      help="Whether the input files are in binary format.")
  add_flag(
      "--output_node_names",
      default="",
      type=str,
      help="The name of the output nodes, comma separated.")
  add_flag(
      "--restore_op_name",
      default="save/restore_all",
      type=str,
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  add_flag(
      "--filename_tensor_name",
      default="save/Const:0",
      type=str,
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  add_flag(
      "--clear_devices",
      default=True,
      type="bool",
      nargs="?",
      const=True,
      help="Whether to remove device specifications.")
  add_flag(
      "--initializer_nodes",
      default="",
      type=str,
      help="Comma separated list of initializer nodes to run before freezing.")
  add_flag(
      "--variable_names_whitelist",
      default="",
      type=str,
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  add_flag(
      "--variable_names_denylist",
      default="",
      type=str,
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  add_flag(
      "--input_meta_graph",
      default="",
      type=str,
      help="TensorFlow \'MetaGraphDef\' file to load.")
  add_flag(
      "--input_saved_model_dir",
      default="",
      type=str,
      help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
  add_flag(
      "--saved_model_tags",
      default="serve",
      type=str,
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)
  flags, unparsed = parser.parse_known_args()

  def _main_with_flags(unused_args):
    return main(unused_args, flags)

  app.run(main=_main_with_flags, argv=[sys.argv[0]] + unparsed)


# Entry point when the file is executed directly as a script.
if __name__ == "__main__":
  run_main()
492