• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""MetaGraph and related functions."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22from distutils import version as distutils_version  # pylint: disable=g-bad-import-order
23import os.path
24import re
25
26import six
27from google.protobuf.any_pb2 import Any
28from google.protobuf import text_format
29
30from tensorflow.core.framework import attr_value_pb2
31from tensorflow.core.framework import graph_pb2
32from tensorflow.core.framework import op_def_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.core.protobuf import saver_pb2
35from tensorflow.python.client import pywrap_tf_session as c_api
36from tensorflow.python.eager import context
37from tensorflow.python.framework import error_interpolation
38from tensorflow.python.framework import graph_io
39from tensorflow.python.framework import importer
40from tensorflow.python.framework import op_def_registry
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import versions
43from tensorflow.python.lib.io import file_io
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import compat
46
47
# Marker prepended to the names of inputs that fall outside the export scope,
# so that unbound inputs stand out in an exported graph.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"
50
# Collections that have no registered proto conversion functions. Because of
# that, meta graphs exported earlier may store the items of these collections
# with a different data type.
_COMPAT_COLLECTION_LIST = [
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.METRIC_VARIABLES,
]
56
57
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs that lie outside `export_scope` are renamed with the
  `_UNBOUND_INPUT_PREFIX` marker and recorded in `unbound_inputs`. All other
  inputs, the node name, `_class` colocation entries and in-scope
  Enter/RefEnter frame names have the scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is deep-copied
      and not modified.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list; the (prefixed) names of unbound inputs are
      appended to it in place.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable. A leading "^" (control input) is kept in front of the
      # prefix by the first regex group.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      # Keep only colocation entries ("...@name") that are inside the scope.
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
      else:
        # Frame outside the export scope: keep its original name. Previously
        # `new_s` was read here without being assigned (NameError, or a stale
        # value from an earlier "_class" attr).
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
104
105
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary-serialized `GraphDef`; if that
  fails, it is decoded as UTF-8 and parsed as a text-format `GraphDef`.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed
      (including text content that is not valid UTF-8).
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    # Not a binary proto; fall through to the text-format parser.
    pass

  # Next try to read it as a text file. Start from an empty message in case
  # the failed binary parse left fields partially set, and decode inside the
  # `try` so that non-UTF-8 content is reported as the documented IOError.
  graph_def.Clear()
  try:
    text_format.Merge(file_content.decode("utf-8"), graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def
137
138
def ops_used_by_graph_def(graph_def):
  """Lists the names of all primitive ops a graph relies on.

  Library functions reached from the graph are expanded recursively, either
  because a node's op names a function or because a (Stateful)PartitionedCall
  node references one through its "f" attr. Function names themselves are not
  part of the result. Registration of the ops is not checked.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph.
  """
  functions_by_name = {
      fn.signature.name: fn for fn in graph_def.library.function}

  # Every name seen so far, primitive ops and functions alike.
  seen = set()
  # Function definitions whose bodies still have to be scanned (LIFO).
  pending = []

  def _visit(name):
    if name in functions_by_name and name not in seen:
      pending.append(functions_by_name[name])
    seen.add(name)

  def _visit_node(node_def):
    _visit(node_def.op)
    if node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      _visit(node_def.attr["f"].func.name)

  for node_def in graph_def.node:
    _visit_node(node_def)
  while pending:
    for node_def in pending.pop().node_def:
      _visit_node(node_def)

  return [name for name in seen if name not in functions_by_name]
178
179
def stripped_op_list_for_graph(graph_def):
  """Builds the stripped `OpList` describing the ops a graph uses.

  The result fills the `stripped_op_list` field of `MetaGraphDef` and similar
  protos. A producer can hand it to a consumer, which may then apply the C++
  function `RemoveNewDefaultAttrsFromGraphDef` for better forwards
  compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` with the registered `OpDef` of each op used by the graph, in
    sorted op-name order.
  """
  # Counterpart of StrippedOpListForGraph in C++. Unlike the C++ version, ops
  # missing from the registry are skipped rather than required; this is done
  # to support Prelu fusion in tfjs.
  looked_up = (op_def_registry.get(name)
               for name in sorted(ops_used_by_graph_def(graph_def)))
  return op_def_pb2.OpList(
      op=[op_def for op_def in looked_up if op_def is not None])
204
205
def _get_kind_name(item):
  """Picks the `CollectionDef` field name used to store an item.

  Args:
    item: A data item.

  Returns:
    "bytes_list" for text/bytes, "int64_list" for integers, "float_list" for
    floats, "any_list" for `Any` protos, and "node_list" for anything else.
  """
  # Checked in order; the first matching type wins.
  dispatch = (
      ((six.string_types, six.binary_type), "bytes_list"),
      (six.integer_types, "int64_list"),
      (float, "float_list"),
      (Any, "any_list"),
  )
  for candidate_types, kind_name in dispatch:
    if isinstance(item, candidate_types):
      return kind_name
  return "node_list"
226
227
# Op types that save or restore checkpoints. `_find_extraneous_saver_nodes`
# uses them to locate the name scopes that belong to Savers.
SAVE_AND_RESTORE_OPS = [
    # Save ops.
    "SaveV2",
    "Save",
    "SaveSlice",
    "LegacySave",
    "LegacySaveSlice",
    # Restore ops.
    "RestoreV2",
    "Restore",
    "RestoreSlice",
    "LegacyRestore",
    "LegacyRestoreSlice",
]
234
235
def _op_name(tensor_name):
  """Returns the name of the Op that produces a Tensor.

  That is the part of `tensor_name` before its ":" (if any), without the "^"
  that marks a control-dependency input.

  Args:
    tensor_name: the full name of a Tensor in the graph.
  Returns:
    The name of the Op of which the given Tensor is an output.
  Raises:
    ValueError: if tensor_name is None or empty.
  """
  if not tensor_name:
    raise ValueError("Tensor name cannot be empty or None.")

  # A single leading "^" denotes a control dependency.
  name = tensor_name[1:] if tensor_name.startswith("^") else tensor_name
  if ":" not in name:
    return name
  op_name, _ = name.split(":")
  return op_name
259
260
def _get_scope(node_name):
  """Returns the name scope that directly encloses a node.

  The scope is everything before the last "/" of the name; a leading "^"
  (control-dependency marker) is ignored.

  Args:
    node_name: the full name of an Op or a Tensor in the graph.
  Returns:
    The deepest named scope containing the node, or "" when there is none.
  Raises:
    ValueError: if node_name is None or empty.
  """
  if not node_name:
    raise ValueError("Node name cannot be empty or None.")

  bare_name = node_name[1:] if node_name.startswith("^") else node_name
  # rpartition yields ("", "", bare_name) when the name contains no "/".
  scope, _, _ = bare_name.rpartition("/")
  return scope
285
286
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained,
      or None.
  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes.  I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Only node names and op types are needed; node inputs are not inspected.
  node_ops = {node_def.name: node_def.op for node_def in graph_def.node}

  retain_scopes = set()
  # It's possible to have no saver if the graph has no Variables
  if saver_def is not None:
    save_op_name = _op_name(saver_def.save_tensor_name)
    restore_op_name = _op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    retain_scopes.add(_get_scope(restore_op_name) + "/")
    retain_scopes.add(_get_scope(save_op_name) + "/")

  all_saver_node_names = set(
      name for name, op in node_ops.items() if op in SAVE_AND_RESTORE_OPS)

  all_saver_scopes = (
      set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names)

  # Trailing "/" so that e.g. "save" does not match a node named "save_1/...".
  extraneous_scopes = tuple(
      scope + "/" for scope in all_saver_scopes
      if scope + "/" not in retain_scopes)

  # str.startswith accepts a tuple and tests every prefix in one call; an
  # empty tuple matches nothing.
  return set(name for name in node_ops if name.startswith(extraneous_scopes))
341
342
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Decides whether a node (or node name) belongs in an export.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None.  Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if isinstance(node_or_node_name, six.string_types):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Objects without a `name` cannot be filtered, so they are kept.
      return True

  if exclude_nodes and (node_or_node_name in exclude_nodes or
                        node_name in exclude_nodes):
    return False

  # Inputs marked as unbound during export are always kept.
  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)
373
374
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Items are first filtered with `_should_include_node`. If `key` has a
  registered to-proto function, every item is stored as a serialized proto in
  `bytes_list`; otherwise the storage kind is derived from the type of the
  first item. Serialization problems do not raise: a warning is logged and
  the collection is removed from `meta_graph_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string. Non-string keys are
      skipped with a warning.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set). An empty value falls back to the
      graph's collection.

  Raises:
    TypeError: If `graph` is given and is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          # Explicit type check rather than `assert`, which is stripped under
          # `python -O`. The error is caught and logged below.
          if not isinstance(proto, proto_type):
            raise TypeError(
                "to_proto function for collection %s returned %s, expected "
                "%s." % (key, type(proto), proto_type))
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Drop the partially written collection.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
447
448
def _is_default_attr_value(op_def, attr_name, attr_value):
  """Tells whether `attr_value` equals the declared default of an attr.

  Returns False when `op_def` has no attr named `attr_name` or when that attr
  declares no default value.
  """
  attr_def = next((a for a in op_def.attr if a.name == attr_name), None)
  if attr_def is None or not attr_def.HasField("default_value"):
    return False
  # c_api.EqualAttrValueWrapper returns an empty string when both serialized
  # AttrValues are equivalent, so an empty result means "is the default".
  difference = c_api.EqualAttrValueWrapper(
      attr_value.SerializeToString(),
      attr_def.default_value.SerializeToString())
  return not difference
461
462
def strip_graph_default_valued_attrs(meta_graph_def):
  """Deletes node attributes whose value equals the op's declared default.

  Every `NodeDef` in `meta_graph_def.graph_def`, including those inside
  library functions, is edited in place. Nodes whose op names a library
  function, or whose op is not in the op registry, are left untouched.
  Afterwards `meta_info_def.stripped_default_attrs` is set to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def
  function_names = set(fn.signature.name for fn in graph_def.library.function)

  def _strip(node_def):
    """Removes the default-valued attrs of a single NodeDef, in place."""
    if node_def.op in function_names:
      return
    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return
    # Collect first, then delete, so the map is not mutated while iterating.
    redundant = [name for name, value in node_def.attr.items()
                 if _is_default_attr_value(op_def, name, value)]
    for name in redundant:
      del node_def.attr[name]

  # Top-level graph nodes first, then the nodes of every library function.
  for node_def in graph_def.node:
    _strip(node_def)
  for function_def in graph_def.library.function:
    for node_def in function_def.node_def:
      _strip(node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
508
509
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. Its contents are copied into
      the result and the argument itself is not modified; the TensorFlow
      version fields of the result are always set to the current build.
    graph_def: `GraphDef` protocol buffer. If not given, `graph` is serialized
      (with shapes) instead.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. Defaults to all
      collection keys of `graph`.
    graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the
      default graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection (it is written with `saver_def` only).  Note this method
        does not alter the graph, so any extraneous Save/Restore ops should
        have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check. Messages are %-formatted here; passing the type as a second
  # exception argument would leave "%s" unformatted in the message.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def. It is copied first so that the version strings below
  # are written into the result rather than into the caller's proto.
  info = meta_graph_def.meta_info_def
  if meta_info_def:
    info.MergeFrom(meta_info_def)

  # Set the tf version strings to the current tf build.
  info.tensorflow_version = versions.__version__
  info.tensorflow_git_version = versions.__git_version__

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
615
616
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary-serialized `MetaGraphDef`; if that
  fails, it is decoded as UTF-8 and parsed as a text-format `MetaGraphDef`.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed
      (including text content that is not valid UTF-8).
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    # Not a binary proto; fall through to the text-format parser.
    pass

  # Next try to read it as a text file. Start from an empty message in case
  # the failed binary parse left fields partially set, and decode inside the
  # `try` so that non-UTF-8 content is reported as the documented IOError.
  meta_graph_def.Clear()
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return meta_graph_def
648
649
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  Accepts either a `MetaGraphDef` protocol buffer or the name of a file that
  contains one (in which case the proto is read from the file). All nodes of
  its `graph_def` are added to the target graph, the selected collections are
  recreated, and a dictionary of the Variables imported into the name scope
  is returned.

  Together with `export_scoped_meta_graph()`, this makes it possible to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  imported = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # The first element is the dictionary of imported variables; no
  # return_elements are requested here.
  return imported[0]
701
702
703def import_scoped_meta_graph_with_return_elements(
704    meta_graph_or_file,
705    clear_devices=False,
706    graph=None,
707    import_scope=None,
708    input_map=None,
709    unbound_inputs_col_name="unbound_inputs",
710    restore_collections_predicate=(lambda key: True),
711    return_elements=None):
712  """Imports graph from `MetaGraphDef` and returns vars and return elements.
713
714  This function takes a `MetaGraphDef` protocol buffer as input. If
715  the argument is a file containing a `MetaGraphDef` protocol buffer ,
716  it constructs a protocol buffer from the file content. The function
717  then adds all the nodes from the `graph_def` field to the
718  current graph, recreates the desired collections, and returns a dictionary of
719  all the Variables imported into the name scope.
720
721  In combination with `export_scoped_meta_graph()`, this function can be used to
722
723  * Serialize a graph along with other Python objects such as `QueueRunner`,
724    `Variable` into a `MetaGraphDef`.
725
726  * Restart training from a saved graph and checkpoints.
727
728  * Run inference from a saved graph and checkpoints.
729
730  Args:
731    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
732      the path) containing a `MetaGraphDef`.
733    clear_devices: Boolean which controls whether to clear device information
734      from graph_def. Default false.
735    graph: The `Graph` to import into. If `None`, use the default graph.
736    import_scope: Optional `string`. Name scope into which to import the
737      subgraph. If `None`, the graph is imported to the root name scope.
738    input_map: A dictionary mapping input names (as strings) in `graph_def` to
739      `Tensor` objects. The values of the named input tensors in the imported
740      graph will be re-mapped to the respective `Tensor` values.
741    unbound_inputs_col_name: Collection name for looking up unbound inputs.
742    restore_collections_predicate: a predicate on collection names. A collection
743      named c (i.e whose key is c) will be restored iff
744      1) `restore_collections_predicate(c)` is True, and
745      2) `c != unbound_inputs_col_name`.
746    return_elements:  A list of strings containing operation names in the
747      `MetaGraphDef` that will be returned as `Operation` objects; and/or
748      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
749
750  Returns:
751    A tuple of (
752      dictionary of all the `Variables` imported into the name scope,
753      list of `Operation` or `Tensor` objects from the `return_elements` list).
754
755  Raises:
756    ValueError: If the graph_def contains unbound inputs.
757
758  """
759  if context.executing_eagerly():
760    raise ValueError("Exporting/importing meta graphs is not supported when "
761                     "eager execution is enabled.")
762  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
763    meta_graph_def = meta_graph_or_file
764  else:
765    meta_graph_def = read_meta_graph_file(meta_graph_or_file)
766
767  if unbound_inputs_col_name:
768    for key, col_def in meta_graph_def.collection_def.items():
769      if key == unbound_inputs_col_name:
770        kind = col_def.WhichOneof("kind")
771        field = getattr(col_def, kind)
772        if field.value and (
773            not input_map or
774            sorted([compat.as_str(v) for v in field.value]) !=
775            sorted(input_map)):
776          raise ValueError("Graph contains unbound inputs: %s. Must "
777                           "provide these inputs through input_map." % ",".join(
778                               compat.as_str(v)
779                               for v in field.value
780                               if not input_map or v not in input_map))
781        break
782
783  # Sets graph to default graph if it's not passed in.
784  graph = graph or ops.get_default_graph()
785
786  # Gathers the list of nodes we are interested in.
787  with graph.as_default():
788    producer_op_list = None
789    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
790      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
791    input_graph_def = meta_graph_def.graph_def
792    # Remove all the explicit device specifications for this node. This helps to
793    # make the graph more portable.
794    if clear_devices:
795      for node in input_graph_def.node:
796        node.device = ""
797
798    scope_to_prepend_to_names = graph.unique_name(
799        import_scope or "", mark_as_used=False)
800
801    imported_return_elements = importer.import_graph_def(
802        input_graph_def,
803        name=(import_scope or scope_to_prepend_to_names),
804        input_map=input_map,
805        producer_op_list=producer_op_list,
806        return_elements=return_elements)
807
808    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
809    # without a VariableDef.trainable field set.
810    tf_version = meta_graph_def.meta_info_def.tensorflow_version
811    if not tf_version:
812      variables_have_trainable = True
813    else:
814      variables_have_trainable = (
815          distutils_version.LooseVersion(tf_version)
816          >= distutils_version.LooseVersion("1.9"))
817
818    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
819    # variables to trainable if the value is not set in their VariableDef.
820    sorted_collections = []
821    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
822      sorted_collections.append(
823          (ops.GraphKeys.TRAINABLE_VARIABLES,
824           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
825    for key, value in sorted(meta_graph_def.collection_def.items()):
826      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
827        sorted_collections.append((key, value))
828
829    # Restores all the other collections.
830    variable_objects = {}
831    for key, col_def in sorted_collections:
832      # Don't add unbound_inputs to the new graph.
833      if key == unbound_inputs_col_name:
834        continue
835      if not restore_collections_predicate(key):
836        continue
837
838      kind = col_def.WhichOneof("kind")
839      if kind is None:
840        logging.error("Cannot identify data type for collection %s. Skipping.",
841                      key)
842        continue
843      from_proto = ops.get_from_proto_function(key)
844
845      # Temporary change to allow the TFMA evaluator to read metric variables
846      # saved as a bytes list.
847      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
848      if key == ops.GraphKeys.METRIC_VARIABLES:
849        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
850        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
851      if from_proto and kind == "bytes_list":
852        proto_type = ops.get_collection_proto_type(key)
853        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
854          for value in col_def.bytes_list.value:
855            variable = variable_objects.get(value, None)
856            if variable is None:
857              proto = proto_type()
858              proto.ParseFromString(value)
859              if not variables_have_trainable:
860                # If the VariableDef proto does not contain a "trainable"
861                # property because it was exported before that property was
862                # added, we default it to whether the variable is in the
863                # TRAINABLE_VARIABLES collection. We've sorted
864                # TRAINABLE_VARIABLES to be first, so trainable variables will
865                # be created from that collection.
866                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
867              variable = from_proto(
868                  proto, import_scope=scope_to_prepend_to_names)
869              variable_objects[value] = variable
870            graph.add_to_collection(key, variable)
871        else:
872          for value in col_def.bytes_list.value:
873            proto = proto_type()
874            proto.ParseFromString(value)
875            graph.add_to_collection(
876                key, from_proto(
877                    proto, import_scope=scope_to_prepend_to_names))
878      else:
879        field = getattr(col_def, kind)
880        if key in _COMPAT_COLLECTION_LIST:
881          logging.warning(
882              "The saved meta_graph is possibly from an older release:\n"
883              "'%s' collection should be of type 'byte_list', but instead "
884              "is of type '%s'.", key, kind)
885        if kind == "node_list":
886          for value in field.value:
887            col_op = graph.as_graph_element(
888                ops.prepend_name_scope(value, scope_to_prepend_to_names))
889            graph.add_to_collection(key, col_op)
890        elif kind == "int64_list":
891          # NOTE(opensource): This force conversion is to work around the fact
892          # that Python2 distinguishes between int and long, while Python3 has
893          # only int.
894          for value in field.value:
895            graph.add_to_collection(key, int(value))
896        else:
897          for value in field.value:
898            graph.add_to_collection(
899                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))
900
901    var_list = {}
902    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
903                                     scope=scope_to_prepend_to_names)
904    for v in variables:
905      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
906
907  return var_list, imported_return_elements
908
909
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: Optional `GraphDef` protocol buffer to export. When filtering
      is requested (`export_scope`, `clear_extraneous_savers` or
      `clear_devices`), the nodes are taken from this `GraphDef` if it is
      given, and from `graph` otherwise.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also save the
      GraphDebugInfo to a separate file in the same directory as `filename`.
      That file's name is `filename` with its extension replaced by `.debug`
      (e.g. `/tmp/model.meta` -> `/tmp/model.debug`). Has no effect when
      `filename` is not provided.
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A tuple `(meta_graph_def, var_list)`: the `MetaGraphDef` proto, and a
    dictionary mapping variable names (with `export_scope` stripped) to the
    `Variables` in the exported name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Filter the caller-supplied GraphDef node by node into a new GraphDef
      # that keeps its versions and function library.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Build the GraphDef directly from the live graph, keeping only the
      # selected nodes (this is the path used to remove a name scope).
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Walk nodes in ascending node id (presumably creation order).
      for key in sorted(graph._nodes_by_id):
        op = graph._nodes_by_id[key]
        # pylint: enable=protected-access
        if not _should_include_node(op.name, export_scope, exclude_nodes):
          continue
        node_def = _node_def(op.node_def, export_scope, unbound_inputs,
                             clear_devices=clear_devices)
        graph_def.node.extend([node_def])
        if op.outputs:
          # Attach the statically known output shapes to the exported node.
          assert "_output_shapes" not in graph_def.node[-1].attr
          graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
              output.get_shape().as_proto() for output in op.outputs])
        # Python ints never overflow and ByteSize() is non-negative, so only
        # the upper bound needs checking.
        bytesize += op.node_def.ByteSize()
        if bytesize >= (1 << 31):
          raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  # Map each exported global variable to its name relative to export_scope.
  var_list = {
      ops.strip_name_scope(v.name, export_scope): v
      for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                    scope=export_scope)
      if _should_include_node(v, export_scope, exclude_nodes)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The extension of `filename` is replaced (not suffixed) by ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Gets the operation from the graph by the name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        # Node names were stripped of export_scope; restore the full name to
        # look the op up in the original graph.
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
1073
1074
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` in `from_graph` is exported with
  `export_scoped_meta_graph` and then imported into `to_graph` under
  `to_scope` with `import_scoped_meta_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`, as
    returned by `import_scoped_meta_graph`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  # The exporter also returns the *source* variables; the caller only wants
  # the variables created in `to_graph` by the import, so discard them.
  orig_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  return import_scoped_meta_graph(orig_meta_graph,
                                  graph=to_graph,
                                  import_scope=to_scope)
1107