• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""MetaGraph and related functions."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22from distutils import version as distutils_version  # pylint: disable=g-bad-import-order
23import os.path
24import re
25
26import six
27from google.protobuf.any_pb2 import Any
28from google.protobuf import text_format
29
30from tensorflow.core.framework import attr_value_pb2
31from tensorflow.core.framework import graph_pb2
32from tensorflow.core.framework import op_def_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.core.protobuf import saver_pb2
35from tensorflow.python.client import pywrap_tf_session as c_api
36from tensorflow.python.eager import context
37from tensorflow.python.framework import error_interpolation
38from tensorflow.python.framework import graph_io
39from tensorflow.python.framework import importer
40from tensorflow.python.framework import op_def_registry
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import versions
43from tensorflow.python.lib.io import file_io
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import compat
46
47
# Marker prepended to the names of unbound inputs so that they are easy to
# spot.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that did not register proto functions; as a result, in a
# previously exported meta_graph their items are of a different data type.
_COMPAT_COLLECTION_LIST = [
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.METRIC_VARIABLES,
]
56
57
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs that do not start with `export_scope` are renamed with the
  `$unbound_inputs_` prefix and recorded in `unbound_inputs`; all other names
  have `export_scope` stripped from them.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list that is appended to, in place, with the prefixed
      names of the inputs found outside `export_scope`.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if export_scope and not v.lstrip("^").startswith(export_scope):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable. The leading "^" of a control input is kept in front of
      # the prefix.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      # Keep only the colocation entries that point inside the export scope.
      new_s = [
          compat.as_bytes(ops.strip_name_scope(s, export_scope))
          for s in v.list.s
          if not export_scope or
          compat.as_str(s).split("@")[1].startswith(export_scope)
      ]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
      else:
        # The frame is named outside of the export scope: keep the name as is.
        # Previously `new_s` was left unbound (or stale from a `_class` attr)
        # on this path.
        new_s = v.s
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
104
105
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary proto; if that fails it is parsed
  as a text-format proto.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # Read inside a `with` block so the file handle is always closed.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  # First try to read it as a binary file.
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best effort: fall through to the text-format parser.
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content, graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def
136
137
def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Does not validate that the ops are all registered.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A sorted list of strings, each naming an op used by the graph. Names of
    functions defined in the graph's library are not included, but the ops
    used inside the functions that are reachable from the graph are.
  """
  # Map function names to definitions
  name_to_function = {}
  for fun in graph_def.library.function:
    name_to_function[fun.signature.name] = fun

  # Collect the list of op names.  Since functions can reference functions, we
  # need a recursive traversal.
  used_ops = set()  # Includes both primitive ops and functions
  functions_to_process = []  # A subset of used_ops

  def mark_op_as_used(op):
    # Queue each function only the first time it is seen so that recursive or
    # mutually-referencing functions cannot loop forever.
    if op not in used_ops and op in name_to_function:
      functions_to_process.append(name_to_function[op])
    used_ops.add(op)

  for node in graph_def.node:
    mark_op_as_used(node.op)
  while functions_to_process:
    fun = functions_to_process.pop()
    for node in fun.node_def:
      mark_op_as_used(node.op)

  # Sorted so that the result does not depend on set iteration order.
  return sorted(op for op in used_ops if op not in name_to_function)
172
173
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph, ordered by op name. Ops for which
    the registry has no `OpDef` are left out.
  """
  # This is similar to StrippedOpListForGraph in C++, but unlike its
  # C++ counterpart, this version does not require all ops to be registered.
  # This is done to support Prelu fusion in tfjs.
  looked_up = (
      op_def_registry.get(op_name)
      for op_name in sorted(ops_used_by_graph_def(graph_def)))
  registered = [op_def for op_def in looked_up if op_def is not None]
  return op_def_pb2.OpList(op=registered)
198
199
def _get_kind_name(item):
  """Returns the kind name in CollectionDef.

  Args:
    item: A data item.

  Returns:
    The string representation of the kind in CollectionDef: "bytes_list" for
    text or binary strings, "int64_list" for integers, "float_list" for
    floats, "any_list" for `Any` protos and "node_list" for everything else.
  """
  if isinstance(item, (six.string_types, six.binary_type)):
    return "bytes_list"
  if isinstance(item, six.integer_types):
    return "int64_list"
  if isinstance(item, float):
    return "float_list"
  if isinstance(item, Any):
    return "any_list"
  return "node_list"
220
221
# Op types that mark a node as a Saver's save or restore op.
SAVE_AND_RESTORE_OPS = [
    "SaveV2",
    "Save",
    "SaveSlice",
    "LegacySave",
    "LegacySaveSlice",
    "RestoreV2",
    "Restore",
    "RestoreSlice",
    "LegacyRestore",
    "LegacyRestoreSlice",
]
228
229
230def _op_name(tensor_name):
231  """Extract the Op name from a Tensor name.
232
233  The Op name is everything before a colon, if present,
234  not including any ^ prefix denoting a control dependency.
235
236  Args:
237    tensor_name: the full name of a Tensor in the graph.
238  Returns:
239    The name of the Op of which the given Tensor is an output.
240  Raises:
241    ValueError: if tensor_name is None or empty.
242  """
243  if not tensor_name:
244    raise ValueError("Tensor name cannot be empty or None.")
245
246  # Control dependency inputs start with ^.
247  if tensor_name.startswith("^"):
248    tensor_name = tensor_name[1:]
249  if ":" in tensor_name:
250    op_name, _ = tensor_name.split(":")
251    return op_name
252  return tensor_name
253
254
255def _get_scope(node_name):
256  """Extract the scope name from a node name.
257
258  The scope name is everything before the final slash,
259  not including any ^ prefix denoting a control dependency.
260
261  Args:
262    node_name: the full name of an Op or a Tensor in the graph.
263  Returns:
264    The deepest named scope containing the node.
265  Raises:
266    ValueError: if tensor_name is None or empty
267  """
268  if not node_name:
269    raise ValueError("Node name cannot be empty or None.")
270
271  # Control dependency inputs start with ^.
272  if node_name.startswith("^"):
273    node_name = node_name[1:]
274  if "/" in node_name:
275    scope, _ = node_name.rsplit("/", 1)
276    return scope
277
278  return ""
279
280
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained.
  Returns:
    An iterable of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes.  I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # load the graph DAG in minimal form, without initializing a full Graph object
  graph_nodes = {}
  for node in graph_def.node:
    input_op_names = set(_op_name(inp) for inp in node.input)
    graph_nodes[node.name] = (input_op_names, node.op)

  # Scopes of the Saver that has to be kept.
  # It's possible to have no saver if the graph has no Variables
  scopes_to_retain = set()
  if saver_def is not None:
    save_op = _op_name(saver_def.save_tensor_name)
    restore_op = _op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    scopes_to_retain.add(_get_scope(restore_op) + "/")
    scopes_to_retain.add(_get_scope(save_op) + "/")

  saver_node_names = set(
      name for name, (_, op_type) in graph_nodes.items()
      if op_type in SAVE_AND_RESTORE_OPS)

  # A scope whose name is itself a saver node name is not treated as a scope.
  saver_scopes = set(_get_scope(name) for name in saver_node_names)
  saver_scopes -= saver_node_names
  saver_scope_prefixes = set(scope + "/" for scope in saver_scopes)

  extraneous_prefixes = tuple(saver_scope_prefixes - scopes_to_retain)

  # str.startswith accepts a tuple of prefixes and matches any of them.
  return set(
      name for name in graph_nodes if name.startswith(extraneous_prefixes))
335
336
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None.  Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if isinstance(node_or_node_name, six.string_types):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True

  # An exclusion may list either the object itself or its name.
  if exclude_nodes:
    if node_or_node_name in exclude_nodes or node_name in exclude_nodes:
      return False

  # Unbound inputs are always kept, whatever the export scope is.
  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)
367
368
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Nothing is added when `key` is not a string (a warning is logged) or when no
  item of the collection passes the export filter. If serializing the items
  raises, a warning is logged and any entry for `key` is removed again from
  `meta_graph_def.collection_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).

  Raises:
    TypeError: If `graph` is given but is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Use `%` so that the type is actually interpolated into the message.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # Without a registered to_proto function the kind of the whole
      # collection is taken from its first item.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Do not leave a partially filled collection behind.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
441
442
443def _is_default_attr_value(op_def, attr_name, attr_value):
444  """Checks if given attribute matches the default value in the op def."""
445  for attr_def in op_def.attr:
446    if attr_def.name == attr_name:
447      if not attr_def.HasField("default_value"):
448        return False
449      # c_api.EqualAttrValueWrapper returns an empty string
450      # if both arguments represent an equivalent AttrValue instance.
451      return not c_api.EqualAttrValueWrapper(
452          attr_value.SerializeToString(),
453          attr_def.default_value.SerializeToString())
454  return False
455
456
def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  Both the nodes of the graph and the nodes of the functions in its library
  are processed. This method also sets
  `meta_info_def.stripped_default_attrs` in the given `MetaGraphDef` proto to
  True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def

  # Names of the ops that are functions defined in the graph's library.
  function_op_names = set(
      function_def.signature.name
      for function_def in graph_def.library.function)

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    # Nodes that call a library function are left untouched.
    if node_def.op in function_op_names:
      return

    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return

    # Collect first, delete afterwards: the attr map must not change while it
    # is being iterated.
    default_valued = [
        attr_name for attr_name, attr_value in node_def.attr.items()
        if _is_default_attr_value(op_def, attr_name, attr_value)
    ]
    for attr_name in default_valued:
      del node_def.attr[attr_name]

  # Process all NodeDef instances in graph_def.
  for node_def in graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
502
503
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. If given, its
      `tensorflow_version` and `tensorflow_git_version` fields are overwritten
      in place with the versions of the current build.
    graph_def: `GraphDef` protocol buffer. If not given, the `GraphDef` of
      `graph` is used.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. If `None`, all the
      collection keys of `graph` are used.
    graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check. The messages are built with `%` so that the offending type is
  # actually interpolated into them.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if not meta_info_def:
    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

  # Set the tf version strings to the current tf build.
  meta_info_def.tensorflow_version = versions.__version__
  meta_info_def.tensorflow_git_version = versions.__git_version__
  meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def,
  # unless the caller already supplied one through meta_info_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Only the Saver rebuilt from `saver_def` goes into the SAVERS
      # collection.
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
609
610
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary proto; if that fails it is decoded
  as UTF-8 and parsed as a text-format proto.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # Read inside a `with` block so the file handle is always closed.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  # First try to read it as a binary file.
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best effort: fall through to the text-format parser.
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return meta_graph_def
641
642
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This is a thin wrapper around
  `import_scoped_meta_graph_with_return_elements()` that requests no return
  elements and returns only the first item of its result.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  imported = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # Only the first item of the result, the imported variables, is returned.
  return imported[0]
694
695
696def import_scoped_meta_graph_with_return_elements(
697    meta_graph_or_file,
698    clear_devices=False,
699    graph=None,
700    import_scope=None,
701    input_map=None,
702    unbound_inputs_col_name="unbound_inputs",
703    restore_collections_predicate=(lambda key: True),
704    return_elements=None):
705  """Imports graph from `MetaGraphDef` and returns vars and return elements.
706
707  This function takes a `MetaGraphDef` protocol buffer as input. If
708  the argument is a file containing a `MetaGraphDef` protocol buffer ,
709  it constructs a protocol buffer from the file content. The function
710  then adds all the nodes from the `graph_def` field to the
711  current graph, recreates the desired collections, and returns a dictionary of
712  all the Variables imported into the name scope.
713
714  In combination with `export_scoped_meta_graph()`, this function can be used to
715
716  * Serialize a graph along with other Python objects such as `QueueRunner`,
717    `Variable` into a `MetaGraphDef`.
718
719  * Restart training from a saved graph and checkpoints.
720
721  * Run inference from a saved graph and checkpoints.
722
723  Args:
724    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
725      the path) containing a `MetaGraphDef`.
726    clear_devices: Boolean which controls whether to clear device information
727      from graph_def. Default false.
728    graph: The `Graph` to import into. If `None`, use the default graph.
729    import_scope: Optional `string`. Name scope into which to import the
730      subgraph. If `None`, the graph is imported to the root name scope.
731    input_map: A dictionary mapping input names (as strings) in `graph_def` to
732      `Tensor` objects. The values of the named input tensors in the imported
733      graph will be re-mapped to the respective `Tensor` values.
734    unbound_inputs_col_name: Collection name for looking up unbound inputs.
735    restore_collections_predicate: a predicate on collection names. A collection
736      named c (i.e whose key is c) will be restored iff
737      1) `restore_collections_predicate(c)` is True, and
738      2) `c != unbound_inputs_col_name`.
739    return_elements:  A list of strings containing operation names in the
740      `MetaGraphDef` that will be returned as `Operation` objects; and/or
741      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
742
743  Returns:
744    A tuple of (
745      dictionary of all the `Variables` imported into the name scope,
746      list of `Operation` or `Tensor` objects from the `return_elements` list).
747
748  Raises:
749    ValueError: If the graph_def contains unbound inputs.
750
751  """
752  if context.executing_eagerly():
753    raise ValueError("Exporting/importing meta graphs is not supported when "
754                     "eager execution is enabled.")
755  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
756    meta_graph_def = meta_graph_or_file
757  else:
758    meta_graph_def = read_meta_graph_file(meta_graph_or_file)
759
760  if unbound_inputs_col_name:
761    for key, col_def in meta_graph_def.collection_def.items():
762      if key == unbound_inputs_col_name:
763        kind = col_def.WhichOneof("kind")
764        field = getattr(col_def, kind)
765        if field.value and (
766            not input_map or
767            sorted([compat.as_str(v) for v in field.value]) !=
768            sorted(input_map)):
769          raise ValueError("Graph contains unbound inputs: %s. Must "
770                           "provide these inputs through input_map." % ",".join(
771                               compat.as_str(v)
772                               for v in field.value
773                               if not input_map or v not in input_map))
774        break
775
776  # Sets graph to default graph if it's not passed in.
777  graph = graph or ops.get_default_graph()
778
779  # Gathers the list of nodes we are interested in.
780  with graph.as_default():
781    producer_op_list = None
782    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
783      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
784    input_graph_def = meta_graph_def.graph_def
785    # Remove all the explicit device specifications for this node. This helps to
786    # make the graph more portable.
787    if clear_devices:
788      for node in input_graph_def.node:
789        node.device = ""
790
791    scope_to_prepend_to_names = graph.unique_name(
792        import_scope or "", mark_as_used=False)
793
794    imported_return_elements = importer.import_graph_def(
795        input_graph_def,
796        name=(import_scope or scope_to_prepend_to_names),
797        input_map=input_map,
798        producer_op_list=producer_op_list,
799        return_elements=return_elements)
800
801    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
802    # without a VariableDef.trainable field set.
803    tf_version = meta_graph_def.meta_info_def.tensorflow_version
804    if not tf_version:
805      variables_have_trainable = True
806    else:
807      variables_have_trainable = (
808          distutils_version.LooseVersion(tf_version)
809          >= distutils_version.LooseVersion("1.9"))
810
811    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
812    # variables to trainable if the value is not set in their VariableDef.
813    sorted_collections = []
814    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
815      sorted_collections.append(
816          (ops.GraphKeys.TRAINABLE_VARIABLES,
817           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
818    for key, value in sorted(meta_graph_def.collection_def.items()):
819      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
820        sorted_collections.append((key, value))
821
822    # Restores all the other collections.
823    variable_objects = {}
824    for key, col_def in sorted_collections:
825      # Don't add unbound_inputs to the new graph.
826      if key == unbound_inputs_col_name:
827        continue
828      if not restore_collections_predicate(key):
829        continue
830
831      kind = col_def.WhichOneof("kind")
832      if kind is None:
833        logging.error("Cannot identify data type for collection %s. Skipping.",
834                      key)
835        continue
836      from_proto = ops.get_from_proto_function(key)
837
838      # Temporary change to allow the TFMA evaluator to read metric variables
839      # saved as a bytes list.
840      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
841      if key == ops.GraphKeys.METRIC_VARIABLES:
842        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
843        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
844      if from_proto and kind == "bytes_list":
845        proto_type = ops.get_collection_proto_type(key)
846        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
847          for value in col_def.bytes_list.value:
848            variable = variable_objects.get(value, None)
849            if variable is None:
850              proto = proto_type()
851              proto.ParseFromString(value)
852              if not variables_have_trainable:
853                # If the VariableDef proto does not contain a "trainable"
854                # property because it was exported before that property was
855                # added, we default it to whether the variable is in the
856                # TRAINABLE_VARIABLES collection. We've sorted
857                # TRAINABLE_VARIABLES to be first, so trainable variables will
858                # be created from that collection.
859                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
860              variable = from_proto(
861                  proto, import_scope=scope_to_prepend_to_names)
862              variable_objects[value] = variable
863            graph.add_to_collection(key, variable)
864        else:
865          for value in col_def.bytes_list.value:
866            proto = proto_type()
867            proto.ParseFromString(value)
868            graph.add_to_collection(
869                key, from_proto(
870                    proto, import_scope=scope_to_prepend_to_names))
871      else:
872        field = getattr(col_def, kind)
873        if key in _COMPAT_COLLECTION_LIST:
874          logging.warning(
875              "The saved meta_graph is possibly from an older release:\n"
876              "'%s' collection should be of type 'byte_list', but instead "
877              "is of type '%s'.", key, kind)
878        if kind == "node_list":
879          for value in field.value:
880            col_op = graph.as_graph_element(
881                ops.prepend_name_scope(value, scope_to_prepend_to_names))
882            graph.add_to_collection(key, col_op)
883        elif kind == "int64_list":
884          # NOTE(opensource): This force conversion is to work around the fact
885          # that Python2 distinguishes between int and long, while Python3 has
886          # only int.
887          for value in field.value:
888            graph.add_to_collection(key, int(value))
889        else:
890          for value in field.value:
891            graph.add_to_collection(
892                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))
893
894    var_list = {}
895    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
896                                     scope=scope_to_prepend_to_names)
897    for v in variables:
898      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
899
900  return var_list, imported_return_elements
901
902
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Builds a `MetaGraphDef` proto and optionally writes it to `filename`.

  The graph, saver and collection objects are packed into a `MetaGraphDef`
  protocol buffer so that they can be imported later (or elsewhere) to resume
  training, run inference, or be used as a subgraph.

  Args:
    filename: Optional path, including the file name, to which the generated
      `MetaGraphDef` is written. Nothing is written when it is empty.
    graph_def: `GraphDef` protocol buffer. When any of `export_scope`,
      `clear_extraneous_savers` or `clear_devices` is set, a new `GraphDef` is
      assembled from its nodes; if it is not given, the nodes are gathered
      from `graph` instead.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided (and the graph is
      being filtered as described under `graph_def`), the collection with this
      name on `graph` is cleared and refilled with the names of the tensors
      that must be remapped when importing the `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` (and `filename` is set), also write the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named like `filename` but with its extension replaced by `.debug`.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list. They are forwarded to `create_meta_graph_def`.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB (checked while nodes
      are gathered from `graph`).
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and (graph_def is None or graph is None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  if not graph:
    graph = ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  needs_filtering = export_scope or clear_extraneous_savers or clear_devices
  if needs_filtering:
    if graph_def:
      # Rebuild the caller-supplied GraphDef, keeping only the wanted nodes.
      filtered_graph_def = graph_pb2.GraphDef()
      filtered_graph_def.versions.CopyFrom(graph_def.versions)
      filtered_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for source_node in graph_def.node:
        if not _should_include_node(source_node.name, export_scope,
                                    exclude_nodes):
          continue
        filtered_graph_def.node.extend([
            _node_def(source_node, export_scope, unbound_inputs,
                      clear_devices=clear_devices)
        ])
      graph_def = filtered_graph_def
    else:
      # No GraphDef was supplied: assemble one node by node from the graph.
      graph_def = graph_pb2.GraphDef()
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      total_bytes = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for node_id in sorted(graph._nodes_by_id):  # pylint: disable=protected-access
        op = graph._nodes_by_id[node_id]  # pylint: disable=protected-access
        if not _should_include_node(op.name, export_scope, exclude_nodes):
          continue
        graph_def.node.extend([
            _node_def(op.node_def, export_scope, unbound_inputs,
                      clear_devices=clear_devices)
        ])
        if op.outputs:
          # Record the static output shapes on the node that was just added.
          exported_node = graph_def.node[-1]
          assert "_output_shapes" not in exported_node.attr
          exported_node.attr["_output_shapes"].list.shape.extend(
              [tensor.get_shape().as_proto() for tensor in op.outputs])
        total_bytes += op.node_def.ByteSize()
        if total_bytes >= (1 << 31) or total_bytes < 0:
          raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, total_bytes)  # pylint: disable=protected-access

    # Some inputs may live outside export_scope. When requested, publish
    # their names through a dedicated collection, replacing any old content.
    if unbound_inputs_col_name:
      graph.clear_collection(unbound_inputs_col_name)
      for unbound_name in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, unbound_name)

  # Map scope-stripped variable names to the exported global variables.
  var_list = {
      ops.strip_name_scope(var.name, export_scope): var
      for var in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                      scope=export_scope)
      if _should_include_node(var, export_scope, exclude_nodes)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if not filename:
    return scoped_meta_graph_def, var_list

  out_dir, out_name = os.path.split(filename)
  graph_io.write_graph(
      scoped_meta_graph_def, out_dir, out_name, as_text=as_text)

  if save_debug_info:
    # The debug info goes next to `filename`, with the extension swapped.
    debug_filename = os.path.splitext(filename)[0] + ".debug"

    # Gets the operation from the graph by the name. Excludes variable nodes,
    # so only the nodes in the frozen models are included.
    # TODO(liufengdb): fix this for functions.
    ops_to_export = [
        ("", graph.get_operation_by_name(
            ops.prepend_name_scope(exported.name, export_scope)))
        for exported in scoped_meta_graph_def.graph_def.node
    ]
    graph_debug_info = error_interpolation.create_graph_debug_info_def(
        ops_to_export)

    debug_dir, debug_name = os.path.split(debug_filename)
    graph_io.write_graph(
        graph_debug_info, debug_dir, debug_name, as_text=as_text)

  return scoped_meta_graph_def, var_list
1066
1067
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` is exported from `from_graph` with
  `export_scoped_meta_graph` and then imported into `to_graph` under
  `to_scope` with `import_scoped_meta_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  if not from_graph:
    from_graph = ops.get_default_graph()
  if not to_graph:
    to_graph = ops.get_default_graph()

  # Copying a scope onto itself inside a single graph is rejected.
  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  # Only the exported proto is needed; the exporter's variable dictionary
  # refers to the source scope and is discarded.
  exported_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  return import_scoped_meta_graph(
      exported_meta_graph, graph=to_graph, import_scope=to_scope)
1100