# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""MetaGraph and related functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
from distutils import version as distutils_version  # pylint: disable=g-bad-import-order
import os.path
import re

import six
from google.protobuf.any_pb2 import Any
from google.protobuf import text_format

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat


# Prefix to be added to unbound input names so they are easily identifiable.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# List of collections that didn't register proto functions; as a result, in a
# previously exported meta_graph the items are of a different data type.
_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES,
                           ops.GraphKeys.MODEL_VARIABLES,
                           ops.GraphKeys.METRIC_VARIABLES]


def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def


def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content, graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def


def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Does not validate that the ops are all registered.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph.
  """
  # Map function names to definitions
  name_to_function = {}
  for fun in graph_def.library.function:
    name_to_function[fun.signature.name] = fun

  # Collect the list of op names. Since functions can reference functions, we
  # need a recursive traversal.
  used_ops = set()  # Includes both primitive ops and functions
  functions_to_process = []  # A subset of used_ops

  def mark_op_as_used(op):
    if op not in used_ops and op in name_to_function:
      functions_to_process.append(name_to_function[op])
    used_ops.add(op)

  def process_node(node):
    mark_op_as_used(node.op)
    if node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
      mark_op_as_used(node.attr["f"].func.name)

  for node in graph_def.node:
    process_node(node)
  while functions_to_process:
    fun = functions_to_process.pop()
    for node in fun.node_def:
      process_node(node)

  return [op for op in used_ops if op not in name_to_function]


def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.
  """
  # This is similar to StrippedOpListForGraph in C++, but unlike its
  # C++ counterpart, this version does not require all ops to be registered.
  # This is done to support Prelu fusion in tfjs.
  used_ops = ops_used_by_graph_def(graph_def)
  op_defs = []
  for op in sorted(used_ops):
    op_def = op_def_registry.get(op)
    if op_def is not None:
      op_defs.append(op_def)
  return op_def_pb2.OpList(op=op_defs)
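
# Illustrative usage sketch (not exercised by this module): computing the
# stripped op list for a graph built elsewhere.
#
#   g = ops.Graph()
#   with g.as_default():
#     pass  # build some ops here
#   op_list = stripped_op_list_for_graph(g.as_graph_def())
#   print([op_def.name for op_def in op_list.op])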


def _get_kind_name(item):
  """Returns the kind name in CollectionDef.

  Args:
    item: A data item.

  Returns:
    The string representation of the kind in CollectionDef.
  """
  if isinstance(item, (six.string_types, six.binary_type)):
    kind = "bytes_list"
  elif isinstance(item, six.integer_types):
    kind = "int64_list"
  elif isinstance(item, float):
    kind = "float_list"
  elif isinstance(item, Any):
    kind = "any_list"
  else:
    kind = "node_list"
  return kind


SAVE_AND_RESTORE_OPS = ["SaveV2",
                        "Save", "SaveSlice",
                        "LegacySave", "LegacySaveSlice",
                        "RestoreV2",
                        "Restore", "RestoreSlice",
                        "LegacyRestore", "LegacyRestoreSlice"]


def _op_name(tensor_name):
  """Extract the Op name from a Tensor name.

  The Op name is everything before a colon, if present,
  not including any ^ prefix denoting a control dependency.

  Args:
    tensor_name: the full name of a Tensor in the graph.
  Returns:
    The name of the Op of which the given Tensor is an output.
  Raises:
    ValueError: if tensor_name is None or empty.
  """
  if not tensor_name:
    raise ValueError("Tensor name cannot be empty or None.")

  # Control dependency inputs start with ^.
  if tensor_name.startswith("^"):
    tensor_name = tensor_name[1:]
  if ":" in tensor_name:
    op_name, _ = tensor_name.split(":")
    return op_name
  return tensor_name


def _get_scope(node_name):
  """Extract the scope name from a node name.

  The scope name is everything before the final slash,
  not including any ^ prefix denoting a control dependency.

  Args:
    node_name: the full name of an Op or a Tensor in the graph.
  Returns:
    The deepest named scope containing the node.
  Raises:
    ValueError: if node_name is None or empty.
  """
  if not node_name:
    raise ValueError("Node name cannot be empty or None.")

  # Control dependency inputs start with ^.
  if node_name.startswith("^"):
    node_name = node_name[1:]
  if "/" in node_name:
    scope, _ = node_name.rsplit("/", 1)
    return scope

  return ""


def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained.
  Returns:
    An iterable of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes. I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # load the graph DAG in minimal form, without initializing a full Graph object
  nodes = {
      node_def.name: (set(_op_name(x) for x in node_def.input), node_def.op)
      for node_def in graph_def.node
  }

  retain_scope_save = None
  retain_scope_restore = None
  # It's possible to have no saver if the graph has no Variables
  if saver_def is not None:
    save_op_name = _op_name(saver_def.save_tensor_name)
    restore_op_name = _op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    retain_scope_restore = _get_scope(restore_op_name) + "/"
    retain_scope_save = _get_scope(save_op_name) + "/"

  all_saver_node_names = set(
      name for name, (_, op) in nodes.items() if op in SAVE_AND_RESTORE_OPS)

  all_saver_scopes = (
      set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names)
  all_saver_scopes = set(x + "/" for x in all_saver_scopes)

  extraneous_scopes = all_saver_scopes - set([retain_scope_save,
                                              retain_scope_restore])

  extraneous_node_names = set()
  for name, _ in nodes.items():
    for extraneous_scope in extraneous_scopes:
      if name.startswith(extraneous_scope):
        extraneous_node_names.add(name)
        break

  return extraneous_node_names


def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None. Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if not isinstance(node_or_node_name, six.string_types):
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True
  else:
    node_name = node_or_node_name

  if exclude_nodes and (node_or_node_name in exclude_nodes
                        or node_name in exclude_nodes):
    return False

  return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or
          (not export_scope or node_name.startswith(export_scope)))


def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
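
# Illustrative sketch (not exercised here; `g` and "my_tags" are placeholder
# names for an existing ops.Graph and a user-defined collection key):
#
#   mgd = meta_graph_pb2.MetaGraphDef()
#   g.add_to_collection("my_tags", "tag_a")
#   add_collection_def(mgd, "my_tags", graph=g)
#   # Unregistered string-valued collections are serialized as a bytes_list,
#   # so mgd.collection_def["my_tags"].bytes_list should now hold b"tag_a".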


def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks if given attribute matches the default value in the op def."""
  for attr_def in op_def.attr:
    if attr_def.name == attr_name:
      if not attr_def.HasField("default_value"):
        return False
      # c_api.EqualAttrValueWrapper returns an empty string
      # if both arguments represent an equivalent AttrValue instance.
      return not c_api.EqualAttrValueWrapper(
          attr_value.SerializeToString(),
          attr_def.default_value.SerializeToString())
  return False


def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  # Map function op names to their function definitions.
  op_name_to_function = {}
  for function_def in meta_graph_def.graph_def.library.function:
    op_name_to_function[function_def.signature.name] = function_def

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    if node_def.op in op_name_to_function:
      return

    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return

    attrs_to_strip = set()
    for attr_name, attr_value in node_def.attr.items():
      if _is_default_attr_value(op_def, attr_name, attr_value):
        attrs_to_strip.add(attr_name)

    for attr in attrs_to_strip:
      del node_def.attr[attr]

  # Process all NodeDef instances in graph_def.
  for node_def in meta_graph_def.graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in meta_graph_def.graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
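
# Illustrative sketch (assumes `mgd` is a MetaGraphDef produced by
# create_meta_graph_def below):
#
#   strip_graph_default_valued_attrs(mgd)
#   assert mgd.meta_info_def.stripped_default_attrs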


def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Constructs and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collections, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
      collection. Note this method does not alter the graph, so any
      extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if not meta_info_def:
    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()

  # Set the tf version strings to the current tf build.
  meta_info_def.tensorflow_version = versions.__version__
  meta_info_def.tensorflow_git_version = versions.__git_version__
  meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
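
# Illustrative sketch (not exercised here): building a MetaGraphDef for a
# graph and all of its collections.
#
#   with ops.Graph().as_default() as g:
#     pass  # build the model here
#   mgd = create_meta_graph_def(graph=g, strip_default_attrs=True)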


def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return meta_graph_def


def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  return import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file, clear_devices, graph, import_scope, input_map,
      unbound_inputs_col_name, restore_collections_predicate)[0]
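
# Illustrative round trip (scope names below are placeholders, not part of
# the API):
#
#   meta_graph_def, _ = export_scoped_meta_graph(export_scope="model")
#   with ops.Graph().as_default():
#     var_list = import_scoped_meta_graph(meta_graph_def,
#                                         import_scope="imported_model")
#   # var_list maps scope-stripped variable names to the imported Variables.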


def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e. whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." % ",".join(
                               compat.as_str(v)
                               for v in field.value
                               if not input_map or v not in input_map))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          distutils_version.LooseVersion(tf_version)
          >= distutils_version.LooseVersion("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
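
# Illustrative sketch (the file path, scope, and tensor name are placeholders):
#
#   vars_map, (logits,) = import_scoped_meta_graph_with_return_elements(
#       "/tmp/model.meta", import_scope="new_scope",
#       return_elements=["logits:0"])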


def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which is in the same directory as filename and has `_debug` added before
      the file extension.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Gets the operation from the graph by name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
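
# Illustrative sketch (the filename and "hidden1" scope are placeholders):
#
#   meta_graph_def, var_list = export_scoped_meta_graph(
#       filename="/tmp/hidden1.meta", export_scope="hidden1",
#       clear_devices=True)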


def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  orig_meta_graph, var_list = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  var_list = import_scoped_meta_graph(orig_meta_graph,
                                      graph=to_graph,
                                      import_scope=to_scope)
  return var_list
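
# Illustrative sketch (scope names are placeholders): duplicate everything
# under "tower_0" into "tower_1" within the default graph.
#
#   copied_vars = copy_scoped_meta_graph(from_scope="tower_0",
#                                        to_scope="tower_1")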