1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""MetaGraph and related functions.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22from distutils import version as distutils_version # pylint: disable=g-bad-import-order 23import os.path 24import re 25 26import six 27from google.protobuf.any_pb2 import Any 28from google.protobuf import text_format 29 30from tensorflow.core.framework import attr_value_pb2 31from tensorflow.core.framework import graph_pb2 32from tensorflow.core.framework import op_def_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 34from tensorflow.core.protobuf import saver_pb2 35from tensorflow.python.client import pywrap_tf_session as c_api 36from tensorflow.python.eager import context 37from tensorflow.python.framework import error_interpolation 38from tensorflow.python.framework import graph_io 39from tensorflow.python.framework import importer 40from tensorflow.python.framework import op_def_registry 41from tensorflow.python.framework import ops 42from tensorflow.python.framework import versions 43from tensorflow.python.lib.io import file_io 44from tensorflow.python.platform import tf_logging as logging 45from tensorflow.python.util import compat 46 47 48# Prefix to be added to unbound input names so they are easily identifiable. 49_UNBOUND_INPUT_PREFIX = "$unbound_inputs_" 50 51# List of collections that didn't register proto functions, as a result in 52# a previously exported meta_graph the items are of a different data type. 53_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES, 54 ops.GraphKeys.MODEL_VARIABLES, 55 ops.GraphKeys.METRIC_VARIABLES] 56 57 58def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False): 59 """Create a `NodeDef` proto with export_scope stripped. 60 61 Args: 62 from_node_def: A `node_def_pb2.NodeDef` protocol buffer. 63 export_scope: A `string` representing the name scope to remove. 64 unbound_inputs: An array of unbound input names if they exist. 65 clear_devices: Boolean which controls whether to clear device information 66 from node_def. Default false. 67 68 Returns: 69 A `node_def_pb2.NodeDef` protocol buffer. 70 """ 71 node_def = copy.deepcopy(from_node_def) 72 for i, v in enumerate(node_def.input): 73 if (export_scope and 74 not node_def.input[i].lstrip("^").startswith(export_scope)): 75 # Adds "$unbound_inputs_" prefix to the unbound name so they are easily 76 # identifiable. 77 node_def.input[i] = re.sub(r"([\^]|^)(.*)", 78 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2", 79 compat.as_str(v)) 80 unbound_inputs.append(node_def.input[i]) 81 else: 82 node_def.input[i] = ops.strip_name_scope(v, export_scope) 83 node_def.name = compat.as_bytes( 84 ops.strip_name_scope(from_node_def.name, export_scope)) 85 for k, v in six.iteritems(from_node_def.attr): 86 if k == "_class": 87 new_s = [compat.as_bytes( 88 ops.strip_name_scope(s, export_scope)) for s in v.list.s 89 if not export_scope or 90 compat.as_str(s).split("@")[1].startswith(export_scope)] 91 node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue( 92 list=attr_value_pb2.AttrValue.ListValue(s=new_s))) 93 elif node_def.op in ("Enter", "RefEnter") and k == "frame_name": 94 if not export_scope or compat.as_str(v.s).startswith(export_scope): 95 new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope)) 96 node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s)) 97 else: 98 node_def.attr[k].CopyFrom(v) 99 100 if clear_devices: 101 node_def.device = "" 102 103 return node_def 104 105 106def _read_file(filename): 107 """Reads a file containing `GraphDef` and returns the protocol buffer. 108 109 Args: 110 filename: `graph_def` filename including the path. 111 112 Returns: 113 A `GraphDef` protocol buffer. 114 115 Raises: 116 IOError: If the file doesn't exist, or cannot be successfully parsed. 117 """ 118 graph_def = graph_pb2.GraphDef() 119 if not file_io.file_exists(filename): 120 raise IOError("File %s does not exist." % filename) 121 # First try to read it as a binary file. 122 file_content = file_io.FileIO(filename, "rb").read() 123 try: 124 graph_def.ParseFromString(file_content) 125 return graph_def 126 except Exception: # pylint: disable=broad-except 127 pass 128 129 # Next try to read it as a text file. 130 try: 131 text_format.Merge(file_content, graph_def) 132 except text_format.ParseError as e: 133 raise IOError("Cannot parse file %s: %s." % (filename, str(e))) 134 135 return graph_def 136 137 138def ops_used_by_graph_def(graph_def): 139 """Collect the list of ops used by a graph. 140 141 Does not validate that the ops are all registered. 142 143 Args: 144 graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. 145 146 Returns: 147 A list of strings, each naming an op used by the graph. 148 """ 149 # Map function names to definitions 150 name_to_function = {} 151 for fun in graph_def.library.function: 152 name_to_function[fun.signature.name] = fun 153 154 # Collect the list of op names. Since functions can reference functions, we 155 # need a recursive traversal. 156 used_ops = set() # Includes both primitive ops and functions 157 functions_to_process = [] # A subset of used_ops 158 159 def mark_op_as_used(op): 160 if op not in used_ops and op in name_to_function: 161 functions_to_process.append(name_to_function[op]) 162 used_ops.add(op) 163 164 for node in graph_def.node: 165 mark_op_as_used(node.op) 166 while functions_to_process: 167 fun = functions_to_process.pop() 168 for node in fun.node_def: 169 mark_op_as_used(node.op) 170 171 return [op for op in used_ops if op not in name_to_function] 172 173 174def stripped_op_list_for_graph(graph_def): 175 """Collect the stripped OpDefs for ops used by a graph. 176 177 This function computes the `stripped_op_list` field of `MetaGraphDef` and 178 similar protos. The result can be communicated from the producer to the 179 consumer, which can then use the C++ function 180 `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility. 181 182 Args: 183 graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. 184 185 Returns: 186 An `OpList` of ops used by the graph. 187 """ 188 # This is similar to StrippedOpListForGraph in C++, but unlike its 189 # C++ counterpart, this version does not require all ops to be registered. 190 # This is done to support Prelu fusion in tfjs. 191 used_ops = ops_used_by_graph_def(graph_def) 192 op_defs = [] 193 for op in sorted(used_ops): 194 op_def = op_def_registry.get(op) 195 if op_def is not None: 196 op_defs.append(op_def) 197 return op_def_pb2.OpList(op=op_defs) 198 199 200def _get_kind_name(item): 201 """Returns the kind name in CollectionDef. 202 203 Args: 204 item: A data item. 205 206 Returns: 207 The string representation of the kind in CollectionDef. 208 """ 209 if isinstance(item, (six.string_types, six.binary_type)): 210 kind = "bytes_list" 211 elif isinstance(item, six.integer_types): 212 kind = "int64_list" 213 elif isinstance(item, float): 214 kind = "float_list" 215 elif isinstance(item, Any): 216 kind = "any_list" 217 else: 218 kind = "node_list" 219 return kind 220 221 222SAVE_AND_RESTORE_OPS = ["SaveV2", 223 "Save", "SaveSlice", 224 "LegacySave", "LegacySaveSlice", 225 "RestoreV2", 226 "Restore", "RestoreSlice", 227 "LegacyRestore", "LegacyRestoreSlice"] 228 229 230def _op_name(tensor_name): 231 """Extract the Op name from a Tensor name. 232 233 The Op name is everything before a colon, if present, 234 not including any ^ prefix denoting a control dependency. 235 236 Args: 237 tensor_name: the full name of a Tensor in the graph. 238 Returns: 239 The name of the Op of which the given Tensor is an output. 240 Raises: 241 ValueError: if tensor_name is None or empty. 242 """ 243 if not tensor_name: 244 raise ValueError("Tensor name cannot be empty or None.") 245 246 # Control dependency inputs start with ^. 247 if tensor_name.startswith("^"): 248 tensor_name = tensor_name[1:] 249 if ":" in tensor_name: 250 op_name, _ = tensor_name.split(":") 251 return op_name 252 return tensor_name 253 254 255def _get_scope(node_name): 256 """Extract the scope name from a node name. 257 258 The scope name is everything before the final slash, 259 not including any ^ prefix denoting a control dependency. 260 261 Args: 262 node_name: the full name of an Op or a Tensor in the graph. 263 Returns: 264 The deepest named scope containing the node. 265 Raises: 266 ValueError: if tensor_name is None or empty 267 """ 268 if not node_name: 269 raise ValueError("Node name cannot be empty or None.") 270 271 # Control dependency inputs start with ^. 272 if node_name.startswith("^"): 273 node_name = node_name[1:] 274 if "/" in node_name: 275 scope, _ = node_name.rsplit("/", 1) 276 return scope 277 278 return "" 279 280 281def _find_extraneous_saver_nodes(graph_def, saver_def): 282 """Identifies any nodes in the graph_def related to unused Savers. 283 284 This approach assumes that each Saver is cleanly isolated in its own name 285 scope, so we need only identify the scopes associated with extraneous Savers 286 and return all the nodes in those scopes. 287 288 Args: 289 graph_def: a GraphDef proto to evaluate. 290 saver_def: a SaverDef proto referencing Save/Restore ops to be retained. 291 Returns: 292 An iterable of node names that may be safely omitted. 293 """ 294 # TODO(soergel): confirm that the assumption of scope isolation is valid. 295 # If not, we need to walk up the graph from any restore_all nodes, and walk 296 # down the graph from any Save/Restore nodes. I drafted that approach too, 297 # but it seems unnecessarily complex given the name scope solution. 298 299 # load the graph DAG in minimal form, without initializing a full Graph object 300 nodes = { 301 node_def.name: (set(_op_name(x) for x in node_def.input), node_def.op) 302 for node_def in graph_def.node 303 } 304 305 retain_scope_save = None 306 retain_scope_restore = None 307 # It's possible to have no saver if the graph has no Variables 308 if saver_def is not None: 309 save_op_name = _op_name(saver_def.save_tensor_name) 310 restore_op_name = _op_name(saver_def.restore_op_name) 311 312 # The save and restore scopes should always be the same, but if they differ 313 # for some reason, we retain them both to be safe. 314 retain_scope_restore = _get_scope(restore_op_name) + "/" 315 retain_scope_save = _get_scope(save_op_name) + "/" 316 317 all_saver_node_names = set( 318 name for name, (_, op) in nodes.items() if op in SAVE_AND_RESTORE_OPS) 319 320 all_saver_scopes = ( 321 set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names) 322 all_saver_scopes = set(x + "/" for x in all_saver_scopes) 323 324 extraneous_scopes = all_saver_scopes - set([retain_scope_save, 325 retain_scope_restore]) 326 327 extraneous_node_names = set() 328 for name, _ in nodes.items(): 329 for extraneous_scope in extraneous_scopes: 330 if name.startswith(extraneous_scope): 331 extraneous_node_names.add(name) 332 break 333 334 return extraneous_node_names 335 336 337def _should_include_node(node_or_node_name, export_scope, exclude_nodes): 338 """Returns `True` if a node should be included. 339 340 Args: 341 node_or_node_name: A node or `string` node name. 342 export_scope: `string`. Name scope under which to extract the subgraph. The 343 scope name will be stripped from the node definitions for easy import 344 later into new name scopes. 345 exclude_nodes: An iterable of nodes or `string` node names to omit from the 346 export, or None. Note no sanity-checking is done, so this list must be 347 carefully constructed to avoid producing an invalid graph. 348 349 Returns: 350 `True` if the node should be included. 351 """ 352 if not isinstance(node_or_node_name, six.string_types): 353 try: 354 node_name = node_or_node_name.name 355 except AttributeError: 356 # Keep the object that we don't know how to process. 357 return True 358 else: 359 node_name = node_or_node_name 360 361 if exclude_nodes and (node_or_node_name in exclude_nodes 362 or node_name in exclude_nodes): 363 return False 364 365 return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or 366 (not export_scope or node_name.startswith(export_scope))) 367 368 369def add_collection_def(meta_graph_def, key, graph=None, 370 export_scope=None, exclude_nodes=None, 371 override_contents=None): 372 """Adds a collection to MetaGraphDef protocol buffer. 373 374 Args: 375 meta_graph_def: MetaGraphDef protocol buffer. 376 key: One of the GraphKeys or user-defined string. 377 graph: The `Graph` from which to get collections. 378 export_scope: Optional `string`. Name scope to remove. 379 exclude_nodes: An iterable of nodes or `string` node names to omit from the 380 collection, or None. 381 override_contents: An iterable of values to place in the collection, 382 ignoring the current values (if set). 383 """ 384 if graph and not isinstance(graph, ops.Graph): 385 raise TypeError("graph must be of type Graph, not %s", type(graph)) 386 387 if not isinstance(key, six.string_types) and not isinstance(key, bytes): 388 logging.warning("Only collections with string type keys will be " 389 "serialized. This key has %s", type(key)) 390 return 391 392 # Sets graph to default graph if it's not passed in. 393 graph = graph or ops.get_default_graph() 394 395 if override_contents: 396 collection_list = override_contents 397 else: 398 collection_list = graph.get_collection(key) 399 400 # Remove nodes that should not be exported from the collection list. 401 collection_list = [x for x in collection_list if 402 _should_include_node(x, export_scope, exclude_nodes)] 403 if not collection_list: 404 return 405 406 try: 407 col_def = meta_graph_def.collection_def[key] 408 to_proto = ops.get_to_proto_function(key) 409 proto_type = ops.get_collection_proto_type(key) 410 if to_proto: 411 kind = "bytes_list" 412 for x in collection_list: 413 # Additional type check to make sure the returned proto is indeed 414 # what we expect. 415 proto = to_proto(x, export_scope=export_scope) 416 if proto: 417 assert isinstance(proto, proto_type) 418 getattr(col_def, kind).value.append(proto.SerializeToString()) 419 else: 420 kind = _get_kind_name(collection_list[0]) 421 if kind == "node_list": 422 for x in collection_list: 423 if not export_scope or x.name.startswith(export_scope): 424 getattr(col_def, kind).value.append( 425 ops.strip_name_scope(x.name, export_scope)) 426 elif kind == "bytes_list": 427 # NOTE(opensource): This force conversion is to work around the fact 428 # that Python3 distinguishes between bytes and strings. 429 getattr(col_def, kind).value.extend( 430 [compat.as_bytes(x) for x in collection_list]) 431 else: 432 getattr(col_def, kind).value.extend([x for x in collection_list]) 433 except Exception as e: # pylint: disable=broad-except 434 logging.warning("Issue encountered when serializing %s.\n" 435 "Type is unsupported, or the types of the items don't " 436 "match field type in CollectionDef. Note this is a warning " 437 "and probably safe to ignore.\n%s", key, str(e)) 438 if key in meta_graph_def.collection_def: 439 del meta_graph_def.collection_def[key] 440 return 441 442 443def _is_default_attr_value(op_def, attr_name, attr_value): 444 """Checks if given attribute matches the default value in the op def.""" 445 for attr_def in op_def.attr: 446 if attr_def.name == attr_name: 447 if not attr_def.HasField("default_value"): 448 return False 449 # c_api.EqualAttrValueWrapper returns an empty string 450 # if both arguments represent an equivalent AttrValue instance. 451 return not c_api.EqualAttrValueWrapper( 452 attr_value.SerializeToString(), 453 attr_def.default_value.SerializeToString()) 454 return False 455 456 457def strip_graph_default_valued_attrs(meta_graph_def): 458 """Strips default valued attributes for node defs in given MetaGraphDef. 459 460 This method also sets `meta_info_def.stripped_default_attrs` in the given 461 `MetaGraphDef` proto to True. 462 463 Args: 464 meta_graph_def: `MetaGraphDef` protocol buffer 465 466 Returns: 467 None. 468 """ 469 # Map function op names to their function definitions. 470 op_name_to_function = {} 471 for function_def in meta_graph_def.graph_def.library.function: 472 op_name_to_function[function_def.signature.name] = function_def 473 474 def _strip_node_default_valued_attrs(node_def): 475 """Removes default valued attributes from a single node def.""" 476 if node_def.op in op_name_to_function: 477 return 478 479 op_def = op_def_registry.get(node_def.op) 480 if op_def is None: 481 return 482 483 attrs_to_strip = set() 484 for attr_name, attr_value in node_def.attr.items(): 485 if _is_default_attr_value(op_def, attr_name, attr_value): 486 attrs_to_strip.add(attr_name) 487 488 for attr in attrs_to_strip: 489 del node_def.attr[attr] 490 491 # Process all NodeDef instances in graph_def. 492 for node_def in meta_graph_def.graph_def.node: 493 _strip_node_default_valued_attrs(node_def) 494 495 # Process all NodeDef instances in graph_def.library.function. 496 for function_def in meta_graph_def.graph_def.library.function: 497 for function_node_def in function_def.node_def: 498 _strip_node_default_valued_attrs(function_node_def) 499 500 # Tell consumers of this graph that default valued attrs have been stripped. 501 meta_graph_def.meta_info_def.stripped_default_attrs = True 502 503 504def create_meta_graph_def(meta_info_def=None, 505 graph_def=None, 506 saver_def=None, 507 collection_list=None, 508 graph=None, 509 export_scope=None, 510 exclude_nodes=None, 511 clear_extraneous_savers=False, 512 strip_default_attrs=False): 513 # pylint: disable=line-too-long 514 """Construct and returns a `MetaGraphDef` protocol buffer. 515 516 Args: 517 meta_info_def: `MetaInfoDef` protocol buffer. 518 graph_def: `GraphDef` protocol buffer. 519 saver_def: `SaverDef` protocol buffer. 520 collection_list: List of string keys to collect. 521 graph: The `Graph` to create `MetaGraphDef` out of. 522 export_scope: Optional `string`. Name scope to remove. 523 exclude_nodes: An iterable of nodes or `string` node names to omit from all 524 collection, or None. 525 clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS 526 collection. Note this method does not alter the graph, so any 527 extraneous Save/Restore ops should have been removed already, as needed. 528 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 529 removed from the NodeDefs. For a detailed guide, see 530 [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 531 532 Returns: 533 MetaGraphDef protocol buffer. 534 535 Raises: 536 TypeError: If the arguments are not of the correct proto buffer type. 537 """ 538 # pylint: enable=line-too-long 539 # Type check. 540 if graph and not isinstance(graph, ops.Graph): 541 raise TypeError("graph must be of type Graph, not %s", type(graph)) 542 if meta_info_def and not isinstance(meta_info_def, 543 meta_graph_pb2.MetaGraphDef.MetaInfoDef): 544 raise TypeError("meta_info_def must be of type MetaInfoDef, not %s", 545 type(meta_info_def)) 546 if graph_def and not isinstance(graph_def, graph_pb2.GraphDef): 547 raise TypeError("graph_def must be of type GraphDef, not %s", 548 type(graph_def)) 549 if saver_def and not isinstance(saver_def, saver_pb2.SaverDef): 550 raise TypeError("saver_def must be of type SaverDef, not %s", 551 type(saver_def)) 552 553 # Sets graph to default graph if it's not passed in. 554 graph = graph or ops.get_default_graph() 555 556 # Creates a MetaGraphDef proto. 557 meta_graph_def = meta_graph_pb2.MetaGraphDef() 558 # Adds meta_info_def. 559 if not meta_info_def: 560 meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() 561 562 # Set the tf version strings to the current tf build. 563 meta_info_def.tensorflow_version = versions.__version__ 564 meta_info_def.tensorflow_git_version = versions.__git_version__ 565 meta_graph_def.meta_info_def.MergeFrom(meta_info_def) 566 567 # Adds graph_def or the default. 568 if not graph_def: 569 meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True)) 570 else: 571 meta_graph_def.graph_def.MergeFrom(graph_def) 572 573 # Fills in meta_info_def.stripped_op_list using the ops from graph_def. 574 # pylint: disable=g-explicit-length-test 575 if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0: 576 meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( 577 stripped_op_list_for_graph(meta_graph_def.graph_def)) 578 # pylint: enable=g-explicit-length-test 579 580 # Strip default valued attributes in graph_def. 581 if strip_default_attrs: 582 strip_graph_default_valued_attrs(meta_graph_def) 583 584 # Adds saver_def. 585 if saver_def: 586 meta_graph_def.saver_def.MergeFrom(saver_def) 587 588 # Adds collection_list. 589 if collection_list is not None: 590 clist = collection_list 591 else: 592 clist = graph.get_all_collection_keys() 593 594 for ctype in clist: 595 if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS: 596 # Avoid importing Saver here 597 from_proto = ops.get_from_proto_function(ctype) 598 add_collection_def(meta_graph_def, ctype, 599 graph=graph, 600 export_scope=export_scope, 601 exclude_nodes=exclude_nodes, 602 override_contents=[from_proto(saver_def)]) 603 else: 604 add_collection_def(meta_graph_def, ctype, 605 graph=graph, 606 export_scope=export_scope, 607 exclude_nodes=exclude_nodes) 608 return meta_graph_def 609 610 611def read_meta_graph_file(filename): 612 """Reads a file containing `MetaGraphDef` and returns the protocol buffer. 613 614 Args: 615 filename: `meta_graph_def` filename including the path. 616 617 Returns: 618 A `MetaGraphDef` protocol buffer. 619 620 Raises: 621 IOError: If the file doesn't exist, or cannot be successfully parsed. 622 """ 623 meta_graph_def = meta_graph_pb2.MetaGraphDef() 624 if not file_io.file_exists(filename): 625 raise IOError("File %s does not exist." % filename) 626 # First try to read it as a binary file. 627 file_content = file_io.FileIO(filename, "rb").read() 628 try: 629 meta_graph_def.ParseFromString(file_content) 630 return meta_graph_def 631 except Exception: # pylint: disable=broad-except 632 pass 633 634 # Next try to read it as a text file. 635 try: 636 text_format.Merge(file_content.decode("utf-8"), meta_graph_def) 637 except text_format.ParseError as e: 638 raise IOError("Cannot parse file %s: %s." % (filename, str(e))) 639 640 return meta_graph_def 641 642 643def import_scoped_meta_graph(meta_graph_or_file, 644 clear_devices=False, 645 graph=None, 646 import_scope=None, 647 input_map=None, 648 unbound_inputs_col_name="unbound_inputs", 649 restore_collections_predicate=(lambda key: True)): 650 """Recreates a `Graph` saved in a `MetaGraphDef` proto. 651 652 This function takes a `MetaGraphDef` protocol buffer as input. If 653 the argument is a file containing a `MetaGraphDef` protocol buffer , 654 it constructs a protocol buffer from the file content. The function 655 then adds all the nodes from the `graph_def` field to the 656 current graph, recreates the desired collections, and returns a dictionary of 657 all the Variables imported into the name scope. 658 659 In combination with `export_scoped_meta_graph()`, this function can be used to 660 661 * Serialize a graph along with other Python objects such as `QueueRunner`, 662 `Variable` into a `MetaGraphDef`. 663 664 * Restart training from a saved graph and checkpoints. 665 666 * Run inference from a saved graph and checkpoints. 667 668 Args: 669 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 670 the path) containing a `MetaGraphDef`. 671 clear_devices: Boolean which controls whether to clear device information 672 from graph_def. Default false. 673 graph: The `Graph` to import into. If `None`, use the default graph. 674 import_scope: Optional `string`. Name scope into which to import the 675 subgraph. If `None`, the graph is imported to the root name scope. 676 input_map: A dictionary mapping input names (as strings) in `graph_def` to 677 `Tensor` objects. The values of the named input tensors in the imported 678 graph will be re-mapped to the respective `Tensor` values. 679 unbound_inputs_col_name: Collection name for looking up unbound inputs. 680 restore_collections_predicate: a predicate on collection names. A collection 681 named c (i.e whose key is c) will be restored iff 682 1) `restore_collections_predicate(c)` is True, and 683 2) `c != unbound_inputs_col_name`. 684 685 Returns: 686 A dictionary of all the `Variables` imported into the name scope. 687 688 Raises: 689 ValueError: If the graph_def contains unbound inputs. 690 """ 691 return import_scoped_meta_graph_with_return_elements( 692 meta_graph_or_file, clear_devices, graph, import_scope, input_map, 693 unbound_inputs_col_name, restore_collections_predicate)[0] 694 695 696def import_scoped_meta_graph_with_return_elements( 697 meta_graph_or_file, 698 clear_devices=False, 699 graph=None, 700 import_scope=None, 701 input_map=None, 702 unbound_inputs_col_name="unbound_inputs", 703 restore_collections_predicate=(lambda key: True), 704 return_elements=None): 705 """Imports graph from `MetaGraphDef` and returns vars and return elements. 706 707 This function takes a `MetaGraphDef` protocol buffer as input. If 708 the argument is a file containing a `MetaGraphDef` protocol buffer , 709 it constructs a protocol buffer from the file content. The function 710 then adds all the nodes from the `graph_def` field to the 711 current graph, recreates the desired collections, and returns a dictionary of 712 all the Variables imported into the name scope. 713 714 In combination with `export_scoped_meta_graph()`, this function can be used to 715 716 * Serialize a graph along with other Python objects such as `QueueRunner`, 717 `Variable` into a `MetaGraphDef`. 718 719 * Restart training from a saved graph and checkpoints. 720 721 * Run inference from a saved graph and checkpoints. 722 723 Args: 724 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 725 the path) containing a `MetaGraphDef`. 726 clear_devices: Boolean which controls whether to clear device information 727 from graph_def. Default false. 728 graph: The `Graph` to import into. If `None`, use the default graph. 729 import_scope: Optional `string`. Name scope into which to import the 730 subgraph. If `None`, the graph is imported to the root name scope. 731 input_map: A dictionary mapping input names (as strings) in `graph_def` to 732 `Tensor` objects. The values of the named input tensors in the imported 733 graph will be re-mapped to the respective `Tensor` values. 734 unbound_inputs_col_name: Collection name for looking up unbound inputs. 735 restore_collections_predicate: a predicate on collection names. A collection 736 named c (i.e whose key is c) will be restored iff 737 1) `restore_collections_predicate(c)` is True, and 738 2) `c != unbound_inputs_col_name`. 739 return_elements: A list of strings containing operation names in the 740 `MetaGraphDef` that will be returned as `Operation` objects; and/or 741 tensor names in `MetaGraphDef` that will be returned as `Tensor` objects. 742 743 Returns: 744 A tuple of ( 745 dictionary of all the `Variables` imported into the name scope, 746 list of `Operation` or `Tensor` objects from the `return_elements` list). 747 748 Raises: 749 ValueError: If the graph_def contains unbound inputs. 750 751 """ 752 if context.executing_eagerly(): 753 raise ValueError("Exporting/importing meta graphs is not supported when " 754 "eager execution is enabled.") 755 if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): 756 meta_graph_def = meta_graph_or_file 757 else: 758 meta_graph_def = read_meta_graph_file(meta_graph_or_file) 759 760 if unbound_inputs_col_name: 761 for key, col_def in meta_graph_def.collection_def.items(): 762 if key == unbound_inputs_col_name: 763 kind = col_def.WhichOneof("kind") 764 field = getattr(col_def, kind) 765 if field.value and ( 766 not input_map or 767 sorted([compat.as_str(v) for v in field.value]) != 768 sorted(input_map)): 769 raise ValueError("Graph contains unbound inputs: %s. Must " 770 "provide these inputs through input_map." % ",".join( 771 compat.as_str(v) 772 for v in field.value 773 if not input_map or v not in input_map)) 774 break 775 776 # Sets graph to default graph if it's not passed in. 777 graph = graph or ops.get_default_graph() 778 779 # Gathers the list of nodes we are interested in. 780 with graph.as_default(): 781 producer_op_list = None 782 if meta_graph_def.meta_info_def.HasField("stripped_op_list"): 783 producer_op_list = meta_graph_def.meta_info_def.stripped_op_list 784 input_graph_def = meta_graph_def.graph_def 785 # Remove all the explicit device specifications for this node. This helps to 786 # make the graph more portable. 787 if clear_devices: 788 for node in input_graph_def.node: 789 node.device = "" 790 791 scope_to_prepend_to_names = graph.unique_name( 792 import_scope or "", mark_as_used=False) 793 794 imported_return_elements = importer.import_graph_def( 795 input_graph_def, 796 name=(import_scope or scope_to_prepend_to_names), 797 input_map=input_map, 798 producer_op_list=producer_op_list, 799 return_elements=return_elements) 800 801 # TensorFlow versions before 1.9 (not inclusive) exported SavedModels 802 # without a VariableDef.trainable field set. 803 tf_version = meta_graph_def.meta_info_def.tensorflow_version 804 if not tf_version: 805 variables_have_trainable = True 806 else: 807 variables_have_trainable = ( 808 distutils_version.LooseVersion(tf_version) 809 >= distutils_version.LooseVersion("1.9")) 810 811 # Sort collections so we see TRAINABLE_VARIABLES first and can default these 812 # variables to trainable if the value is not set in their VariableDef. 813 sorted_collections = [] 814 if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def: 815 sorted_collections.append( 816 (ops.GraphKeys.TRAINABLE_VARIABLES, 817 meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES])) 818 for key, value in sorted(meta_graph_def.collection_def.items()): 819 if key != ops.GraphKeys.TRAINABLE_VARIABLES: 820 sorted_collections.append((key, value)) 821 822 # Restores all the other collections. 823 variable_objects = {} 824 for key, col_def in sorted_collections: 825 # Don't add unbound_inputs to the new graph. 826 if key == unbound_inputs_col_name: 827 continue 828 if not restore_collections_predicate(key): 829 continue 830 831 kind = col_def.WhichOneof("kind") 832 if kind is None: 833 logging.error("Cannot identify data type for collection %s. Skipping.", 834 key) 835 continue 836 from_proto = ops.get_from_proto_function(key) 837 838 # Temporary change to allow the TFMA evaluator to read metric variables 839 # saved as a bytes list. 840 # TODO(kathywu): Remove this hack once cl/248406059 has been submitted. 841 if key == ops.GraphKeys.METRIC_VARIABLES: 842 # Metric variables will use the same proto functions as GLOBAL_VARIABLES 843 from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES) 844 if from_proto and kind == "bytes_list": 845 proto_type = ops.get_collection_proto_type(key) 846 if key in ops.GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access 847 for value in col_def.bytes_list.value: 848 variable = variable_objects.get(value, None) 849 if variable is None: 850 proto = proto_type() 851 proto.ParseFromString(value) 852 if not variables_have_trainable: 853 # If the VariableDef proto does not contain a "trainable" 854 # property because it was exported before that property was 855 # added, we default it to whether the variable is in the 856 # TRAINABLE_VARIABLES collection. We've sorted 857 # TRAINABLE_VARIABLES to be first, so trainable variables will 858 # be created from that collection. 859 proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES) 860 variable = from_proto( 861 proto, import_scope=scope_to_prepend_to_names) 862 variable_objects[value] = variable 863 graph.add_to_collection(key, variable) 864 else: 865 for value in col_def.bytes_list.value: 866 proto = proto_type() 867 proto.ParseFromString(value) 868 graph.add_to_collection( 869 key, from_proto( 870 proto, import_scope=scope_to_prepend_to_names)) 871 else: 872 field = getattr(col_def, kind) 873 if key in _COMPAT_COLLECTION_LIST: 874 logging.warning( 875 "The saved meta_graph is possibly from an older release:\n" 876 "'%s' collection should be of type 'byte_list', but instead " 877 "is of type '%s'.", key, kind) 878 if kind == "node_list": 879 for value in field.value: 880 col_op = graph.as_graph_element( 881 ops.prepend_name_scope(value, scope_to_prepend_to_names)) 882 graph.add_to_collection(key, col_op) 883 elif kind == "int64_list": 884 # NOTE(opensource): This force conversion is to work around the fact 885 # that Python2 distinguishes between int and long, while Python3 has 886 # only int. 887 for value in field.value: 888 graph.add_to_collection(key, int(value)) 889 else: 890 for value in field.value: 891 graph.add_to_collection( 892 key, ops.prepend_name_scope(value, scope_to_prepend_to_names)) 893 894 var_list = {} 895 variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, 896 scope=scope_to_prepend_to_names) 897 for v in variables: 898 var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v 899 900 return var_list, imported_return_elements 901 902 903def export_scoped_meta_graph(filename=None, 904 graph_def=None, 905 graph=None, 906 export_scope=None, 907 as_text=False, 908 unbound_inputs_col_name="unbound_inputs", 909 clear_devices=False, 910 saver_def=None, 911 clear_extraneous_savers=False, 912 strip_default_attrs=False, 913 save_debug_info=False, 914 **kwargs): 915 """Returns `MetaGraphDef` proto. Optionally writes it to filename. 916 917 This function exports the graph, saver, and collection objects into 918 `MetaGraphDef` protocol buffer with the intention of it being imported 919 at a later time or location to restart training, run inference, or be 920 a subgraph. 921 922 Args: 923 filename: Optional filename including the path for writing the 924 generated `MetaGraphDef` protocol buffer. 925 graph_def: `GraphDef` protocol buffer. 926 graph: The `Graph` to export. If `None`, use the default graph. 927 export_scope: Optional `string`. Name scope under which to extract 928 the subgraph. The scope name will be stripped from the node definitions 929 for easy import later into new name scopes. If `None`, the whole graph 930 is exported. 931 as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto. 932 unbound_inputs_col_name: Optional `string`. If provided, a string collection 933 with the given name will be added to the returned `MetaGraphDef`, 934 containing the names of tensors that must be remapped when importing the 935 `MetaGraphDef`. 936 clear_devices: Boolean which controls whether to clear device information 937 before exporting the graph. 938 saver_def: `SaverDef` protocol buffer. 939 clear_extraneous_savers: Remove any Saver-related information from the 940 graph (both Save/Restore ops and SaverDefs) that are not associated 941 with the provided SaverDef. 942 strip_default_attrs: Set to true if default valued attributes must be 943 removed while exporting the GraphDef. 944 save_debug_info: If `True`, save the GraphDebugInfo to a separate file, 945 which in the same directory of filename and with `_debug` added before the 946 file extension. 947 **kwargs: Optional keyed arguments, including meta_info_def and 948 collection_list. 949 950 Returns: 951 A `MetaGraphDef` proto and dictionary of `Variables` in the exported 952 name scope. 953 954 Raises: 955 ValueError: When the `GraphDef` is larger than 2GB. 956 ValueError: When executing in Eager mode and either `graph_def` or `graph` 957 is undefined. 958 """ 959 if context.executing_eagerly() and not (graph_def is not None and 960 graph is not None): 961 raise ValueError("Exporting/importing meta graphs is not supported when " 962 "Eager Execution is enabled.") 963 graph = graph or ops.get_default_graph() 964 965 exclude_nodes = None 966 unbound_inputs = [] 967 if export_scope or clear_extraneous_savers or clear_devices: 968 if graph_def: 969 new_graph_def = graph_pb2.GraphDef() 970 new_graph_def.versions.CopyFrom(graph_def.versions) 971 new_graph_def.library.CopyFrom(graph_def.library) 972 973 if clear_extraneous_savers: 974 exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def) 975 976 for node_def in graph_def.node: 977 if _should_include_node(node_def.name, export_scope, exclude_nodes): 978 new_node_def = _node_def(node_def, export_scope, unbound_inputs, 979 clear_devices=clear_devices) 980 new_graph_def.node.extend([new_node_def]) 981 graph_def = new_graph_def 982 else: 983 # Only do this complicated work if we want to remove a name scope. 984 graph_def = graph_pb2.GraphDef() 985 # pylint: disable=protected-access 986 graph_def.versions.CopyFrom(graph.graph_def_versions) 987 bytesize = 0 988 989 if clear_extraneous_savers: 990 exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(), 991 saver_def) 992 993 for key in sorted(graph._nodes_by_id): 994 if _should_include_node(graph._nodes_by_id[key].name, 995 export_scope, 996 exclude_nodes): 997 value = graph._nodes_by_id[key] 998 # pylint: enable=protected-access 999 node_def = _node_def(value.node_def, export_scope, unbound_inputs, 1000 clear_devices=clear_devices) 1001 graph_def.node.extend([node_def]) 1002 if value.outputs: 1003 assert "_output_shapes" not in graph_def.node[-1].attr 1004 graph_def.node[-1].attr["_output_shapes"].list.shape.extend([ 1005 output.get_shape().as_proto() for output in value.outputs]) 1006 bytesize += value.node_def.ByteSize() 1007 if bytesize >= (1 << 31) or bytesize < 0: 1008 raise ValueError("GraphDef cannot be larger than 2GB.") 1009 1010 graph._copy_functions_to_graph_def(graph_def, bytesize) # pylint: disable=protected-access 1011 1012 # It's possible that not all the inputs are in the export_scope. 1013 # If we would like such information included in the exported meta_graph, 1014 # add them to a special unbound_inputs collection. 1015 if unbound_inputs_col_name: 1016 # Clears the unbound_inputs collections. 1017 graph.clear_collection(unbound_inputs_col_name) 1018 for k in unbound_inputs: 1019 graph.add_to_collection(unbound_inputs_col_name, k) 1020 1021 var_list = {} 1022 variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, 1023 scope=export_scope) 1024 for v in variables: 1025 if _should_include_node(v, export_scope, exclude_nodes): 1026 var_list[ops.strip_name_scope(v.name, export_scope)] = v 1027 1028 scoped_meta_graph_def = create_meta_graph_def( 1029 graph_def=graph_def, 1030 graph=graph, 1031 export_scope=export_scope, 1032 exclude_nodes=exclude_nodes, 1033 clear_extraneous_savers=clear_extraneous_savers, 1034 saver_def=saver_def, 1035 strip_default_attrs=strip_default_attrs, 1036 **kwargs) 1037 1038 if filename: 1039 graph_io.write_graph( 1040 scoped_meta_graph_def, 1041 os.path.dirname(filename), 1042 os.path.basename(filename), 1043 as_text=as_text) 1044 if save_debug_info: 1045 name, _ = os.path.splitext(filename) 1046 debug_filename = "{name}{ext}".format(name=name, ext=".debug") 1047 1048 # Gets the operation from the graph by the name. Exludes variable nodes, 1049 # so only the nodes in the frozen models are included. 1050 # TODO(liufengdb): fix this for functions. 1051 ops_to_export = [] 1052 for node in scoped_meta_graph_def.graph_def.node: 1053 scoped_op_name = ops.prepend_name_scope(node.name, export_scope) 1054 ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name))) 1055 1056 graph_debug_info = error_interpolation.create_graph_debug_info_def( 1057 ops_to_export) 1058 1059 graph_io.write_graph( 1060 graph_debug_info, 1061 os.path.dirname(debug_filename), 1062 os.path.basename(debug_filename), 1063 as_text=as_text) 1064 1065 return scoped_meta_graph_def, var_list 1066 1067 1068def copy_scoped_meta_graph(from_scope, to_scope, 1069 from_graph=None, to_graph=None): 1070 """Copies a sub-meta_graph from one scope to another. 1071 1072 Args: 1073 from_scope: `String` name scope containing the subgraph to be copied. 1074 to_scope: `String` name scope under which the copied subgraph will reside. 1075 from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the 1076 default graph is use. 1077 to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the 1078 default graph is used. 1079 1080 Returns: 1081 A dictionary of `Variables` that has been copied into `to_scope`. 1082 1083 Raises: 1084 ValueError: If `from_scope` and `to_scope` are the same while 1085 `from_graph` and `to_graph` are also the same. 1086 """ 1087 from_graph = from_graph or ops.get_default_graph() 1088 to_graph = to_graph or ops.get_default_graph() 1089 1090 if from_graph == to_graph and from_scope == to_scope: 1091 raise ValueError("'from_scope' and 'to_scope' need to be different " 1092 "when performing copy in the same graph.") 1093 1094 orig_meta_graph, var_list = export_scoped_meta_graph( 1095 export_scope=from_scope, graph=from_graph) 1096 var_list = import_scoped_meta_graph(orig_meta_graph, 1097 graph=to_graph, 1098 import_scope=to_scope) 1099 return var_list 1100