• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""MetaGraph and related functions."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import os.path
23import re
24
25import six
26from google.protobuf.any_pb2 import Any
27from google.protobuf import text_format
28
29from tensorflow.core.framework import attr_value_pb2
30from tensorflow.core.framework import graph_pb2
31from tensorflow.core.framework import op_def_pb2
32from tensorflow.core.protobuf import graph_debug_info_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.core.protobuf import saver_pb2
35from tensorflow.python import pywrap_tensorflow
36from tensorflow.python.eager import context
37from tensorflow.python.framework import error_interpolation
38from tensorflow.python.framework import graph_io
39from tensorflow.python.framework import importer
40from tensorflow.python.framework import op_def_registry
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import versions
43from tensorflow.python.lib.io import file_io
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import compat
46
47
# Prefix added to the names of unbound inputs so they are easily identifiable.
# `_node_def` applies it to inputs whose names do not start with the export
# scope.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that did not register proto functions; as a result, in a
# previously exported meta_graph their items are of a different data type.
_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES,
                           ops.GraphKeys.MODEL_VARIABLES]
55
56
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs whose names do not start with `export_scope` are renamed with the
  `$unbound_inputs_` prefix and appended to `unbound_inputs`. The node name,
  the remaining inputs, the `_class` attr entries and the `frame_name` attr of
  `Enter`/`RefEnter` nodes have `export_scope` stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist. Mutated in
      place: every renamed unbound input is appended to it.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer. `from_node_def` itself is not
    modified; the result is built from a deep copy.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable. A leading "^" is kept in front of the prefix.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      # Keep only the entries that live under export_scope (all of them when
      # no scope is given), with the scope stripped.
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        frame_name = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=frame_name))
      else:
        # The frame lives outside export_scope: keep its name untouched.
        # Previously this path used an unassigned (or stale, left over from a
        # "_class" attr) local and failed.
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
103
104
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary-serialized proto; if that fails it
  is parsed as a text-format proto.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # Read the whole file once; the context manager closes the handle even if
  # the read fails.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  # First try to read it as a binary file.
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best-effort: fall through to the text-format attempt.
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content, graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def
135
136
def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Ops reached only through library functions are included; the function names
  themselves are not. Does not validate that the ops are all registered.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph.
  """
  # Library functions, keyed by the name nodes use to reference them.
  functions_by_name = {
      fn.signature.name: fn for fn in graph_def.library.function}

  seen = set()  # Every op name encountered: primitive ops and functions.
  pending = []  # Functions whose bodies still have to be scanned.

  def _visit(nodes):
    for node in nodes:
      op_name = node.op
      if op_name in seen:
        continue
      seen.add(op_name)
      if op_name in functions_by_name:
        pending.append(functions_by_name[op_name])

  # Functions can reference other functions, so keep draining the worklist
  # until no unseen function remains.
  _visit(graph_def.node)
  while pending:
    _visit(pending.pop().node_def)

  return [name for name in seen if name not in functions_by_name]
171
172
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph, ordered by op name.

  Raises:
    ValueError: If an unregistered op is used.
  """
  # This is the Python equivalent of StrippedOpListForGraph in C++.
  # Unfortunately, since the Python op registry can differ from that in C++, we
  # can't remove the duplication using swig (at least naively).
  # TODO(irving): Support taking graphs directly.

  op_names = ops_used_by_graph_def(graph_def)
  registry = op_def_registry.get_registered_ops()

  # These internal ops used by functions are not registered, so we need to
  # whitelist them.  # TODO(irving): Do something better here.
  unregistered_but_allowed = (
      "_Arg", "_Retval", "_ListToArray", "_ArrayToList")

  # Verify that all used ops are registered.
  for op_name in op_names:
    if op_name in registry or op_name in unregistered_but_allowed:
      continue
    raise ValueError(
        "Op %s is used by the graph, but is not registered" % op_name)

  # Build the stripped op list in sorted order; whitelisted ops have no
  # registered OpDef and are therefore left out.
  op_defs = []
  for op_name in sorted(op_names):
    if op_name in registry:
      op_defs.append(registry[op_name])
  return op_def_pb2.OpList(op=op_defs)
209
210
def _get_kind_name(item):
  """Returns the kind name in CollectionDef.

  Args:
    item: A data item.

  Returns:
    The string representation of the kind in CollectionDef: "bytes_list" for
    strings/bytes, "int64_list" for integers, "float_list" for floats,
    "any_list" for `Any` protos and "node_list" for everything else.
  """
  # Checked in order; the first matching entry wins.
  kinds_by_type = (
      ((six.string_types, six.binary_type), "bytes_list"),
      (six.integer_types, "int64_list"),
      (float, "float_list"),
      (Any, "any_list"),
  )
  for accepted_types, kind_name in kinds_by_type:
    if isinstance(item, accepted_types):
      return kind_name
  return "node_list"
231
232
# Op types that `_find_extraneous_saver_nodes` treats as belonging to a Saver.
SAVE_AND_RESTORE_OPS = [
    "SaveV2",
    "Save",
    "SaveSlice",
    "LegacySave",
    "LegacySaveSlice",
    "RestoreV2",
    "Restore",
    "RestoreSlice",
    "LegacyRestore",
    "LegacyRestoreSlice",
]
239
240
241def _op_name(tensor_name):
242  """Extract the Op name from a Tensor name.
243
244  The Op name is everything before a colon, if present,
245  not including any ^ prefix denoting a control dependency.
246
247  Args:
248    tensor_name: the full name of a Tensor in the graph.
249  Returns:
250    The name of the Op of which the given Tensor is an output.
251  Raises:
252    ValueError: if tensor_name is None or empty.
253  """
254  if not tensor_name:
255    raise ValueError("Tensor name cannot be empty or None.")
256
257  # Control dependency inputs start with ^.
258  if tensor_name.startswith("^"):
259    tensor_name = tensor_name[1:]
260  if ":" in tensor_name:
261    op_name, _ = tensor_name.split(":")
262    return op_name
263  return tensor_name
264
265
266def _get_scope(node_name):
267  """Extract the scope name from a node name.
268
269  The scope name is everything before the final slash,
270  not including any ^ prefix denoting a control dependency.
271
272  Args:
273    node_name: the full name of an Op or a Tensor in the graph.
274  Returns:
275    The deepest named scope containing the node.
276  Raises:
277    ValueError: if tensor_name is None or empty
278  """
279  if not node_name:
280    raise ValueError("Node name cannot be empty or None.")
281
282  # Control dependency inputs start with ^.
283  if node_name.startswith("^"):
284    node_name = node_name[1:]
285  if "/" in node_name:
286    scope, _ = node_name.rsplit("/", 1)
287    return scope
288
289  return ""
290
291
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained.
  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes.  I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Load the graph DAG in minimal form, without initializing a full Graph
  # object: node name -> (op names of its inputs, op type).
  graph_nodes = {}
  for node in graph_def.node:
    input_op_names = {_op_name(inp) for inp in node.input}
    graph_nodes[node.name] = (input_op_names, node.op)

  # Scope prefixes that belong to the Saver we were asked to keep.
  # It's possible to have no saver if the graph has no Variables.
  prefixes_to_keep = set()
  if saver_def is not None:
    save_op = _op_name(saver_def.save_tensor_name)
    restore_op = _op_name(saver_def.restore_op_name)
    # The save and restore scopes should always be the same, but if they
    # differ for some reason, we retain them both to be safe.
    prefixes_to_keep.add(_get_scope(restore_op) + "/")
    prefixes_to_keep.add(_get_scope(save_op) + "/")

  saver_node_names = {
      name for name, (_, op_type) in graph_nodes.items()
      if op_type in SAVE_AND_RESTORE_OPS}

  # A scope that is itself the name of a saver node is not treated as a scope.
  saver_scopes = (
      {_get_scope(name) for name in saver_node_names} - saver_node_names)
  doomed_prefixes = {scope + "/" for scope in saver_scopes} - prefixes_to_keep

  return {
      name for name in graph_nodes
      if any(name.startswith(prefix) for prefix in doomed_prefixes)}
345
346
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None.  Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if isinstance(node_or_node_name, six.string_types):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True

  # An explicit exclusion, by object or by name, always wins.
  if exclude_nodes:
    if node_or_node_name in exclude_nodes or node_name in exclude_nodes:
      return False

  # Unbound inputs are always kept, whatever the scope.
  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  # No scope restriction: everything else is kept too.
  if not export_scope:
    return True
  return node_name.startswith(export_scope)
377
378
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Items are serialized with the proto function registered for `key` when one
  exists; otherwise the `CollectionDef` kind is chosen from the type of the
  first item. Serialization problems are logged as a warning and the partially
  written collection is removed again; they are not raised.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).

  Raises:
    TypeError: If `graph` is given and is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Format the message eagerly; passing the type as a second exception
    # argument leaves the "%s" placeholder unfilled.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # No registered proto function: the first item decides the kind for the
      # whole collection.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Drop the half-written entry so the MetaGraphDef stays consistent.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
451
452
453def _is_default_attr_value(op_def, attr_name, attr_value):
454  """Checks if given attribute matches the default value in the op def."""
455  for attr_def in op_def.attr:
456    if attr_def.name == attr_name:
457      if not attr_def.HasField("default_value"):
458        return False
459      # pywrap_tensorflow.EqualAttrValueWrapper returns an empty string
460      # if both arguments represent an equivalent AttrValue instance.
461      return not pywrap_tensorflow.EqualAttrValueWrapper(
462          attr_value.SerializeToString(),
463          attr_def.default_value.SerializeToString())
464  return False
465
466
def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  Both the top-level nodes and the nodes inside library functions are
  processed. Nodes whose op is a library function, or is not a registered op,
  are left untouched. This method also sets
  `meta_info_def.stripped_default_attrs` in the given `MetaGraphDef` proto to
  True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def

  # Names under which library functions are invoked as ops.
  function_op_names = set(
      fn.signature.name for fn in graph_def.library.function)

  # Get all registered ops.
  op_registry = op_def_registry.get_registered_ops()

  def _prune_defaults(node):
    """Removes default valued attributes from a single node def."""
    op_type = node.op
    if op_type in function_op_names or op_type not in op_registry:
      return
    op_def = op_registry[op_type]
    # Collect first, delete afterwards, so the attr map is not mutated while
    # it is being iterated.
    default_valued = [
        name for name, value in node.attr.items()
        if _is_default_attr_value(op_def, name, value)]
    for name in default_valued:
      del node.attr[name]

  # Process all NodeDef instances in graph_def.
  for node in graph_def.node:
    _prune_defaults(node)

  # Process all NodeDef instances in graph_def.library.function.
  for fn in graph_def.library.function:
    for node in fn.node_def:
      _prune_defaults(node)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
512
513
def create_graph_debug_info_def(operations):
  """Construct and returns a `GraphDebugInfo` protocol buffer.

  Args:
    operations: An iterable of operations. Each one must expose a `name` and be
      accepted by `error_interpolation.compute_useful_stack`.

  Returns:
    GraphDebugInfo protocol buffer.
  """
  debug_info = graph_debug_info_pb2.GraphDebugInfo()

  # Gather the stack of every operation, keyed by node name, together with the
  # set of distinct file names those stacks mention.
  stacks_by_node = {}
  file_names = set()
  for op in operations:
    node_name = op.name
    stack = error_interpolation.compute_useful_stack(op)
    stacks_by_node[node_name] = stack
    for frame in stack:
      file_names.add(frame[0])

  # Sets the `files` field in the GraphDebugInfo proto.
  debug_info.files.extend(file_names)

  # Position of each file name inside `files`; traces refer to files by this
  # index rather than by name to save storage space.
  index_of_file = {
      file_name: position
      for position, file_name in enumerate(debug_info.files)}

  # Creates the FileLineCol entries for each node.
  for node_name, stack in stacks_by_node.items():
    node_trace = debug_info.traces[node_name]
    for file_name, line, func, code in stack:
      node_trace.file_line_cols.add(
          file_index=index_of_file[file_name],
          line=line,
          func=func,
          code=code)

  return debug_info
559
560
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. If given, its
      `tensorflow_version` and `tensorflow_git_version` fields are overwritten
      in place with the current build's values.
    graph_def: `GraphDef` protocol buffer. If not given, it is taken from
      `graph.as_graph_def(add_shapes=True)`.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. If `None`, all collection
      keys of the graph are used.
    graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check. The messages are formatted eagerly; passing the type as a
  # second exception argument leaves the "%s" placeholder unfilled.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if not meta_info_def:
    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

  # Set the tf version strings to the current tf build.
  meta_info_def.tensorflow_version = versions.__version__
  meta_info_def.tensorflow_git_version = versions.__git_version__
  meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # A stripped_op_list supplied through meta_info_def is kept as is.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Replace the SAVERS collection contents with the single saver built
      # from saver_def.
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
666
667
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary-serialized proto; if that fails it
  is decoded as UTF-8 and parsed as a text-format proto.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # Read the whole file once; the context manager closes the handle even if
  # the read fails.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  # First try to read it as a binary file.
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best-effort: fall through to the text-format attempt.
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return meta_graph_def
698
699
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  Thin wrapper around `import_scoped_meta_graph_with_return_elements`: all
  arguments are forwarded unchanged and only the first element of its result
  (the dictionary of imported variables) is returned.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function then adds
  all the nodes from the `graph_def` field to the current graph, recreates the
  desired collections, and returns a dictionary of all the Variables imported
  into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  imported = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # The second element (return elements) is not needed here.
  return imported[0]
751
752
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  target graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Note that the device fields are cleared
      in place, so a `MetaGraphDef` passed in by the caller is modified.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not match the keys of `input_map`.

  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A collection with no populated `kind` holds no values, hence no
        # unbound inputs; `getattr(col_def, None)` would raise TypeError.
        if kind is not None:
          field = getattr(col_def, kind)
          # Decode once so that both the equality check and the "missing"
          # filter compare strings against the (string) keys of `input_map`.
          unbound_names = [compat.as_str(v) for v in field.value]
          if unbound_names and (
              not input_map or
              sorted(unbound_names) != sorted(input_map)):
            raise ValueError("Graph contains unbound inputs: %s. Must "
                             "provide these inputs through input_map." %
                             ",".join([name for name in unbound_names
                                       if not input_map or
                                       name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    # Computed before the import (without reserving the name) so that the
    # same prefix can be applied to the restored collections below.
    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # Restores all the other collections.
    # Maps a serialized variable proto to the object built from it, so a
    # variable listed in several variable collections is created only once.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    # Key the imported global variables by their names relative to the scope.
    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
921
922
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Builds a `MetaGraphDef` for a graph or name scope; optionally saves it.

  The graph, saver, and collection objects are packed into a `MetaGraphDef`
  protocol buffer so that they can be imported at a later time or location to
  restart training, run inference, or be used as a subgraph.

  Args:
    filename: Optional filename including the path. When given, the generated
      `MetaGraphDef` protocol buffer is written to it.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided and the graph is
      rewritten (`export_scope`, `clear_extraneous_savers` or `clear_devices`
      is set), the collection with this name on `graph` is cleared and refilled
      with the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` (and `filename` is given), also write the
      GraphDebugInfo to a separate file in the same directory, named like
      `filename` but with its extension replaced by `.debug`.
    **kwargs: Optional keyed arguments forwarded to `create_meta_graph_def()`,
      including meta_info_def and collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary of the `Variables` in
    the exported name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  missing_graph_info = graph_def is None or graph is None
  if context.executing_eagerly() and missing_graph_info:
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  if not graph:
    graph = ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  needs_rewrite = export_scope or clear_extraneous_savers or clear_devices
  if needs_rewrite:
    if graph_def:
      # Rebuild the caller-provided GraphDef, keeping only the wanted nodes.
      source_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.versions.CopyFrom(source_graph_def.versions)
      graph_def.library.CopyFrom(source_graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(source_graph_def,
                                                     saver_def)

      for source_node in source_graph_def.node:
        if not _should_include_node(source_node.name, export_scope,
                                    exclude_nodes):
          continue
        graph_def.node.extend([
            _node_def(source_node, export_scope, unbound_inputs,
                      clear_devices=clear_devices)])
    else:
      # No GraphDef was supplied: assemble one node by node from `graph`.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      total_bytes = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for node_id in sorted(graph._nodes_by_id):
        op = graph._nodes_by_id[node_id]
        if not _should_include_node(op.name, export_scope, exclude_nodes):
          continue
        graph_def.node.extend([
            _node_def(op.node_def, export_scope, unbound_inputs,
                      clear_devices=clear_devices)])
        if op.outputs:
          # Record the static output shapes on the node just appended.
          exported_node = graph_def.node[-1]
          assert "_output_shapes" not in exported_node.attr
          exported_node.attr["_output_shapes"].list.shape.extend(
              [out.get_shape().as_proto() for out in op.outputs])
        total_bytes += op.node_def.ByteSize()
        if not 0 <= total_bytes < (1 << 31):
          raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, total_bytes)
      # pylint: enable=protected-access

    # Some inputs may live outside export_scope. Publish their names through
    # the dedicated collection so that an importer knows what to remap.
    if unbound_inputs_col_name:
      graph.clear_collection(unbound_inputs_col_name)
      for unbound_name in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, unbound_name)

  # Global variables under export_scope, keyed by their scope-relative names.
  scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                          scope=export_scope)
  var_list = {
      ops.strip_name_scope(v.name, export_scope): v
      for v in scoped_variables
      if _should_include_node(v, export_scope, exclude_nodes)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  # Nothing to write without a destination.
  if not filename:
    return scoped_meta_graph_def, var_list

  graph_io.write_graph(
      scoped_meta_graph_def,
      os.path.dirname(filename),
      os.path.basename(filename),
      as_text=as_text)

  if save_debug_info:
    base_name = os.path.splitext(filename)[0]
    debug_filename = "{name}{ext}".format(name=base_name, ext=".debug")

    # Look up, by scoped name, the graph operation behind every exported node;
    # debug info is generated for exactly those operations.
    ops_to_export = [
        graph.get_operation_by_name(
            ops.prepend_name_scope(node.name, export_scope))
        for node in scoped_meta_graph_def.graph_def.node
    ]

    graph_io.write_graph(
        create_graph_debug_info_def(ops_to_export),
        os.path.dirname(debug_filename),
        os.path.basename(debug_filename),
        as_text=as_text)

  return scoped_meta_graph_def, var_list
1084
1085
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Duplicates the sub-meta_graph under one name scope into another scope.

  The subgraph is exported from `from_graph` with `export_scoped_meta_graph()`
  and the result is imported into `to_graph` with `import_scoped_meta_graph()`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  if not from_graph:
    from_graph = ops.get_default_graph()
  if not to_graph:
    to_graph = ops.get_default_graph()

  # Copying a scope onto itself within a single graph is rejected.
  if from_graph == to_graph:
    if from_scope == to_scope:
      raise ValueError("'from_scope' and 'to_scope' need to be different "
                       "when performing copy in the same graph.")

  # Only the exported proto is needed; the import yields the new variables.
  exported_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  return import_scoped_meta_graph(
      exported_meta_graph, graph=to_graph, import_scope=to_scope)
1118