1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""MetaGraph and related functions.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22from distutils import version as distutils_version # pylint: disable=g-bad-import-order 23import os.path 24import re 25 26import six 27from google.protobuf.any_pb2 import Any 28from google.protobuf import text_format 29 30from tensorflow.core.framework import attr_value_pb2 31from tensorflow.core.framework import graph_pb2 32from tensorflow.core.framework import op_def_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 34from tensorflow.core.protobuf import saver_pb2 35from tensorflow.python.client import pywrap_tf_session as c_api 36from tensorflow.python.eager import context 37from tensorflow.python.framework import error_interpolation 38from tensorflow.python.framework import graph_io 39from tensorflow.python.framework import importer 40from tensorflow.python.framework import op_def_registry 41from tensorflow.python.framework import ops 42from tensorflow.python.framework import versions 43from tensorflow.python.lib.io import file_io 44from tensorflow.python.platform import tf_logging as logging 45from tensorflow.python.util import compat 46 47 48# Prefix to be added to unbound input names 
so they are easily identifiable. 49_UNBOUND_INPUT_PREFIX = "$unbound_inputs_" 50 51# List of collections that didn't register proto functions, as a result in 52# a previously exported meta_graph the items are of a different data type. 53_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES, 54 ops.GraphKeys.MODEL_VARIABLES, 55 ops.GraphKeys.METRIC_VARIABLES] 56 57 58def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False): 59 """Create a `NodeDef` proto with export_scope stripped. 60 61 Args: 62 from_node_def: A `node_def_pb2.NodeDef` protocol buffer. 63 export_scope: A `string` representing the name scope to remove. 64 unbound_inputs: An array of unbound input names if they exist. 65 clear_devices: Boolean which controls whether to clear device information 66 from node_def. Default false. 67 68 Returns: 69 A `node_def_pb2.NodeDef` protocol buffer. 70 """ 71 node_def = copy.deepcopy(from_node_def) 72 for i, v in enumerate(node_def.input): 73 if (export_scope and 74 not node_def.input[i].lstrip("^").startswith(export_scope)): 75 # Adds "$unbound_inputs_" prefix to the unbound name so they are easily 76 # identifiable. 
77 node_def.input[i] = re.sub(r"([\^]|^)(.*)", 78 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2", 79 compat.as_str(v)) 80 unbound_inputs.append(node_def.input[i]) 81 else: 82 node_def.input[i] = ops.strip_name_scope(v, export_scope) 83 node_def.name = compat.as_bytes( 84 ops.strip_name_scope(from_node_def.name, export_scope)) 85 for k, v in six.iteritems(from_node_def.attr): 86 if k == "_class": 87 new_s = [compat.as_bytes( 88 ops.strip_name_scope(s, export_scope)) for s in v.list.s 89 if not export_scope or 90 compat.as_str(s).split("@")[1].startswith(export_scope)] 91 node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue( 92 list=attr_value_pb2.AttrValue.ListValue(s=new_s))) 93 elif node_def.op in ("Enter", "RefEnter") and k == "frame_name": 94 if not export_scope or compat.as_str(v.s).startswith(export_scope): 95 new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope)) 96 node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s)) 97 else: 98 node_def.attr[k].CopyFrom(v) 99 100 if clear_devices: 101 node_def.device = "" 102 103 return node_def 104 105 106def _read_file(filename): 107 """Reads a file containing `GraphDef` and returns the protocol buffer. 108 109 Args: 110 filename: `graph_def` filename including the path. 111 112 Returns: 113 A `GraphDef` protocol buffer. 114 115 Raises: 116 IOError: If the file doesn't exist, or cannot be successfully parsed. 117 """ 118 graph_def = graph_pb2.GraphDef() 119 if not file_io.file_exists(filename): 120 raise IOError("File %s does not exist." % filename) 121 # First try to read it as a binary file. 122 file_content = file_io.FileIO(filename, "rb").read() 123 try: 124 graph_def.ParseFromString(file_content) 125 return graph_def 126 except Exception: # pylint: disable=broad-except 127 pass 128 129 # Next try to read it as a text file. 130 try: 131 text_format.Merge(file_content, graph_def) 132 except text_format.ParseError as e: 133 raise IOError("Cannot parse file %s: %s." 
% (filename, str(e))) 134 135 return graph_def 136 137 138def ops_used_by_graph_def(graph_def): 139 """Collect the list of ops used by a graph. 140 141 Does not validate that the ops are all registered. 142 143 Args: 144 graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. 145 146 Returns: 147 A list of strings, each naming an op used by the graph. 148 """ 149 # Map function names to definitions 150 name_to_function = {} 151 for fun in graph_def.library.function: 152 name_to_function[fun.signature.name] = fun 153 154 # Collect the list of op names. Since functions can reference functions, we 155 # need a recursive traversal. 156 used_ops = set() # Includes both primitive ops and functions 157 functions_to_process = [] # A subset of used_ops 158 159 def mark_op_as_used(op): 160 if op not in used_ops and op in name_to_function: 161 functions_to_process.append(name_to_function[op]) 162 used_ops.add(op) 163 164 def process_node(node): 165 mark_op_as_used(node.op) 166 if node.op in ["PartitionedCall", "StatefulPartitionedCall"]: 167 mark_op_as_used(node.attr["f"].func.name) 168 169 for node in graph_def.node: 170 process_node(node) 171 while functions_to_process: 172 fun = functions_to_process.pop() 173 for node in fun.node_def: 174 process_node(node) 175 176 return [op for op in used_ops if op not in name_to_function] 177 178 179def stripped_op_list_for_graph(graph_def): 180 """Collect the stripped OpDefs for ops used by a graph. 181 182 This function computes the `stripped_op_list` field of `MetaGraphDef` and 183 similar protos. The result can be communicated from the producer to the 184 consumer, which can then use the C++ function 185 `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility. 186 187 Args: 188 graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. 189 190 Returns: 191 An `OpList` of ops used by the graph. 
192 """ 193 # This is similar to StrippedOpListForGraph in C++, but unlike its 194 # C++ counterpart, this version does not require all ops to be registered. 195 # This is done to support Prelu fusion in tfjs. 196 used_ops = ops_used_by_graph_def(graph_def) 197 op_defs = [] 198 for op in sorted(used_ops): 199 op_def = op_def_registry.get(op) 200 if op_def is not None: 201 op_defs.append(op_def) 202 return op_def_pb2.OpList(op=op_defs) 203 204 205def _get_kind_name(item): 206 """Returns the kind name in CollectionDef. 207 208 Args: 209 item: A data item. 210 211 Returns: 212 The string representation of the kind in CollectionDef. 213 """ 214 if isinstance(item, (six.string_types, six.binary_type)): 215 kind = "bytes_list" 216 elif isinstance(item, six.integer_types): 217 kind = "int64_list" 218 elif isinstance(item, float): 219 kind = "float_list" 220 elif isinstance(item, Any): 221 kind = "any_list" 222 else: 223 kind = "node_list" 224 return kind 225 226 227SAVE_AND_RESTORE_OPS = ["SaveV2", 228 "Save", "SaveSlice", 229 "LegacySave", "LegacySaveSlice", 230 "RestoreV2", 231 "Restore", "RestoreSlice", 232 "LegacyRestore", "LegacyRestoreSlice"] 233 234 235def _op_name(tensor_name): 236 """Extract the Op name from a Tensor name. 237 238 The Op name is everything before a colon, if present, 239 not including any ^ prefix denoting a control dependency. 240 241 Args: 242 tensor_name: the full name of a Tensor in the graph. 243 Returns: 244 The name of the Op of which the given Tensor is an output. 245 Raises: 246 ValueError: if tensor_name is None or empty. 247 """ 248 if not tensor_name: 249 raise ValueError("Tensor name cannot be empty or None.") 250 251 # Control dependency inputs start with ^. 252 if tensor_name.startswith("^"): 253 tensor_name = tensor_name[1:] 254 if ":" in tensor_name: 255 op_name, _ = tensor_name.split(":") 256 return op_name 257 return tensor_name 258 259 260def _get_scope(node_name): 261 """Extract the scope name from a node name. 
262 263 The scope name is everything before the final slash, 264 not including any ^ prefix denoting a control dependency. 265 266 Args: 267 node_name: the full name of an Op or a Tensor in the graph. 268 Returns: 269 The deepest named scope containing the node. 270 Raises: 271 ValueError: if tensor_name is None or empty 272 """ 273 if not node_name: 274 raise ValueError("Node name cannot be empty or None.") 275 276 # Control dependency inputs start with ^. 277 if node_name.startswith("^"): 278 node_name = node_name[1:] 279 if "/" in node_name: 280 scope, _ = node_name.rsplit("/", 1) 281 return scope 282 283 return "" 284 285 286def _find_extraneous_saver_nodes(graph_def, saver_def): 287 """Identifies any nodes in the graph_def related to unused Savers. 288 289 This approach assumes that each Saver is cleanly isolated in its own name 290 scope, so we need only identify the scopes associated with extraneous Savers 291 and return all the nodes in those scopes. 292 293 Args: 294 graph_def: a GraphDef proto to evaluate. 295 saver_def: a SaverDef proto referencing Save/Restore ops to be retained. 296 Returns: 297 An iterable of node names that may be safely omitted. 298 """ 299 # TODO(soergel): confirm that the assumption of scope isolation is valid. 300 # If not, we need to walk up the graph from any restore_all nodes, and walk 301 # down the graph from any Save/Restore nodes. I drafted that approach too, 302 # but it seems unnecessarily complex given the name scope solution. 
303 304 # load the graph DAG in minimal form, without initializing a full Graph object 305 nodes = { 306 node_def.name: (set(_op_name(x) for x in node_def.input), node_def.op) 307 for node_def in graph_def.node 308 } 309 310 retain_scope_save = None 311 retain_scope_restore = None 312 # It's possible to have no saver if the graph has no Variables 313 if saver_def is not None: 314 save_op_name = _op_name(saver_def.save_tensor_name) 315 restore_op_name = _op_name(saver_def.restore_op_name) 316 317 # The save and restore scopes should always be the same, but if they differ 318 # for some reason, we retain them both to be safe. 319 retain_scope_restore = _get_scope(restore_op_name) + "/" 320 retain_scope_save = _get_scope(save_op_name) + "/" 321 322 all_saver_node_names = set( 323 name for name, (_, op) in nodes.items() if op in SAVE_AND_RESTORE_OPS) 324 325 all_saver_scopes = ( 326 set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names) 327 all_saver_scopes = set(x + "/" for x in all_saver_scopes) 328 329 extraneous_scopes = all_saver_scopes - set([retain_scope_save, 330 retain_scope_restore]) 331 332 extraneous_node_names = set() 333 for name, _ in nodes.items(): 334 for extraneous_scope in extraneous_scopes: 335 if name.startswith(extraneous_scope): 336 extraneous_node_names.add(name) 337 break 338 339 return extraneous_node_names 340 341 342def _should_include_node(node_or_node_name, export_scope, exclude_nodes): 343 """Returns `True` if a node should be included. 344 345 Args: 346 node_or_node_name: A node or `string` node name. 347 export_scope: `string`. Name scope under which to extract the subgraph. The 348 scope name will be stripped from the node definitions for easy import 349 later into new name scopes. 350 exclude_nodes: An iterable of nodes or `string` node names to omit from the 351 export, or None. Note no sanity-checking is done, so this list must be 352 carefully constructed to avoid producing an invalid graph. 
353 354 Returns: 355 `True` if the node should be included. 356 """ 357 if not isinstance(node_or_node_name, six.string_types): 358 try: 359 node_name = node_or_node_name.name 360 except AttributeError: 361 # Keep the object that we don't know how to process. 362 return True 363 else: 364 node_name = node_or_node_name 365 366 if exclude_nodes and (node_or_node_name in exclude_nodes 367 or node_name in exclude_nodes): 368 return False 369 370 return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or 371 (not export_scope or node_name.startswith(export_scope))) 372 373 374def add_collection_def(meta_graph_def, key, graph=None, 375 export_scope=None, exclude_nodes=None, 376 override_contents=None): 377 """Adds a collection to MetaGraphDef protocol buffer. 378 379 Args: 380 meta_graph_def: MetaGraphDef protocol buffer. 381 key: One of the GraphKeys or user-defined string. 382 graph: The `Graph` from which to get collections. 383 export_scope: Optional `string`. Name scope to remove. 384 exclude_nodes: An iterable of nodes or `string` node names to omit from the 385 collection, or None. 386 override_contents: An iterable of values to place in the collection, 387 ignoring the current values (if set). 388 """ 389 if graph and not isinstance(graph, ops.Graph): 390 raise TypeError("graph must be of type Graph, not %s", type(graph)) 391 392 if not isinstance(key, six.string_types) and not isinstance(key, bytes): 393 logging.warning("Only collections with string type keys will be " 394 "serialized. This key has %s", type(key)) 395 return 396 397 # Sets graph to default graph if it's not passed in. 398 graph = graph or ops.get_default_graph() 399 400 if override_contents: 401 collection_list = override_contents 402 else: 403 collection_list = graph.get_collection(key) 404 405 # Remove nodes that should not be exported from the collection list. 
406 collection_list = [x for x in collection_list if 407 _should_include_node(x, export_scope, exclude_nodes)] 408 if not collection_list: 409 return 410 411 try: 412 col_def = meta_graph_def.collection_def[key] 413 to_proto = ops.get_to_proto_function(key) 414 proto_type = ops.get_collection_proto_type(key) 415 if to_proto: 416 kind = "bytes_list" 417 for x in collection_list: 418 # Additional type check to make sure the returned proto is indeed 419 # what we expect. 420 proto = to_proto(x, export_scope=export_scope) 421 if proto: 422 assert isinstance(proto, proto_type) 423 getattr(col_def, kind).value.append(proto.SerializeToString()) 424 else: 425 kind = _get_kind_name(collection_list[0]) 426 if kind == "node_list": 427 for x in collection_list: 428 if not export_scope or x.name.startswith(export_scope): 429 getattr(col_def, kind).value.append( 430 ops.strip_name_scope(x.name, export_scope)) 431 elif kind == "bytes_list": 432 # NOTE(opensource): This force conversion is to work around the fact 433 # that Python3 distinguishes between bytes and strings. 434 getattr(col_def, kind).value.extend( 435 [compat.as_bytes(x) for x in collection_list]) 436 else: 437 getattr(col_def, kind).value.extend([x for x in collection_list]) 438 except Exception as e: # pylint: disable=broad-except 439 logging.warning("Issue encountered when serializing %s.\n" 440 "Type is unsupported, or the types of the items don't " 441 "match field type in CollectionDef. 
Note this is a warning " 442 "and probably safe to ignore.\n%s", key, str(e)) 443 if key in meta_graph_def.collection_def: 444 del meta_graph_def.collection_def[key] 445 return 446 447 448def _is_default_attr_value(op_def, attr_name, attr_value): 449 """Checks if given attribute matches the default value in the op def.""" 450 for attr_def in op_def.attr: 451 if attr_def.name == attr_name: 452 if not attr_def.HasField("default_value"): 453 return False 454 # c_api.EqualAttrValueWrapper returns an empty string 455 # if both arguments represent an equivalent AttrValue instance. 456 return not c_api.EqualAttrValueWrapper( 457 attr_value.SerializeToString(), 458 attr_def.default_value.SerializeToString()) 459 return False 460 461 462def strip_graph_default_valued_attrs(meta_graph_def): 463 """Strips default valued attributes for node defs in given MetaGraphDef. 464 465 This method also sets `meta_info_def.stripped_default_attrs` in the given 466 `MetaGraphDef` proto to True. 467 468 Args: 469 meta_graph_def: `MetaGraphDef` protocol buffer 470 471 Returns: 472 None. 473 """ 474 # Map function op names to their function definitions. 475 op_name_to_function = {} 476 for function_def in meta_graph_def.graph_def.library.function: 477 op_name_to_function[function_def.signature.name] = function_def 478 479 def _strip_node_default_valued_attrs(node_def): 480 """Removes default valued attributes from a single node def.""" 481 if node_def.op in op_name_to_function: 482 return 483 484 op_def = op_def_registry.get(node_def.op) 485 if op_def is None: 486 return 487 488 attrs_to_strip = set() 489 for attr_name, attr_value in node_def.attr.items(): 490 if _is_default_attr_value(op_def, attr_name, attr_value): 491 attrs_to_strip.add(attr_name) 492 493 for attr in attrs_to_strip: 494 del node_def.attr[attr] 495 496 # Process all NodeDef instances in graph_def. 
497 for node_def in meta_graph_def.graph_def.node: 498 _strip_node_default_valued_attrs(node_def) 499 500 # Process all NodeDef instances in graph_def.library.function. 501 for function_def in meta_graph_def.graph_def.library.function: 502 for function_node_def in function_def.node_def: 503 _strip_node_default_valued_attrs(function_node_def) 504 505 # Tell consumers of this graph that default valued attrs have been stripped. 506 meta_graph_def.meta_info_def.stripped_default_attrs = True 507 508 509def create_meta_graph_def(meta_info_def=None, 510 graph_def=None, 511 saver_def=None, 512 collection_list=None, 513 graph=None, 514 export_scope=None, 515 exclude_nodes=None, 516 clear_extraneous_savers=False, 517 strip_default_attrs=False): 518 # pylint: disable=line-too-long 519 """Construct and returns a `MetaGraphDef` protocol buffer. 520 521 Args: 522 meta_info_def: `MetaInfoDef` protocol buffer. 523 graph_def: `GraphDef` protocol buffer. 524 saver_def: `SaverDef` protocol buffer. 525 collection_list: List of string keys to collect. 526 graph: The `Graph` to create `MetaGraphDef` out of. 527 export_scope: Optional `string`. Name scope to remove. 528 exclude_nodes: An iterable of nodes or `string` node names to omit from all 529 collection, or None. 530 clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS 531 collection. Note this method does not alter the graph, so any 532 extraneous Save/Restore ops should have been removed already, as needed. 533 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 534 removed from the NodeDefs. For a detailed guide, see 535 [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 536 537 Returns: 538 MetaGraphDef protocol buffer. 539 540 Raises: 541 TypeError: If the arguments are not of the correct proto buffer type. 542 """ 543 # pylint: enable=line-too-long 544 # Type check. 
545 if graph and not isinstance(graph, ops.Graph): 546 raise TypeError("graph must be of type Graph, not %s", type(graph)) 547 if meta_info_def and not isinstance(meta_info_def, 548 meta_graph_pb2.MetaGraphDef.MetaInfoDef): 549 raise TypeError("meta_info_def must be of type MetaInfoDef, not %s", 550 type(meta_info_def)) 551 if graph_def and not isinstance(graph_def, graph_pb2.GraphDef): 552 raise TypeError("graph_def must be of type GraphDef, not %s", 553 type(graph_def)) 554 if saver_def and not isinstance(saver_def, saver_pb2.SaverDef): 555 raise TypeError("saver_def must be of type SaverDef, not %s", 556 type(saver_def)) 557 558 # Sets graph to default graph if it's not passed in. 559 graph = graph or ops.get_default_graph() 560 561 # Creates a MetaGraphDef proto. 562 meta_graph_def = meta_graph_pb2.MetaGraphDef() 563 # Adds meta_info_def. 564 if not meta_info_def: 565 meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() 566 567 # Set the tf version strings to the current tf build. 568 meta_info_def.tensorflow_version = versions.__version__ 569 meta_info_def.tensorflow_git_version = versions.__git_version__ 570 meta_graph_def.meta_info_def.MergeFrom(meta_info_def) 571 572 # Adds graph_def or the default. 573 if not graph_def: 574 meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True)) 575 else: 576 meta_graph_def.graph_def.MergeFrom(graph_def) 577 578 # Fills in meta_info_def.stripped_op_list using the ops from graph_def. 579 # pylint: disable=g-explicit-length-test 580 if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0: 581 meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( 582 stripped_op_list_for_graph(meta_graph_def.graph_def)) 583 # pylint: enable=g-explicit-length-test 584 585 # Strip default valued attributes in graph_def. 586 if strip_default_attrs: 587 strip_graph_default_valued_attrs(meta_graph_def) 588 589 # Adds saver_def. 
590 if saver_def: 591 meta_graph_def.saver_def.MergeFrom(saver_def) 592 593 # Adds collection_list. 594 if collection_list is not None: 595 clist = collection_list 596 else: 597 clist = graph.get_all_collection_keys() 598 599 for ctype in clist: 600 if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS: 601 # Avoid importing Saver here 602 from_proto = ops.get_from_proto_function(ctype) 603 add_collection_def(meta_graph_def, ctype, 604 graph=graph, 605 export_scope=export_scope, 606 exclude_nodes=exclude_nodes, 607 override_contents=[from_proto(saver_def)]) 608 else: 609 add_collection_def(meta_graph_def, ctype, 610 graph=graph, 611 export_scope=export_scope, 612 exclude_nodes=exclude_nodes) 613 return meta_graph_def 614 615 616def read_meta_graph_file(filename): 617 """Reads a file containing `MetaGraphDef` and returns the protocol buffer. 618 619 Args: 620 filename: `meta_graph_def` filename including the path. 621 622 Returns: 623 A `MetaGraphDef` protocol buffer. 624 625 Raises: 626 IOError: If the file doesn't exist, or cannot be successfully parsed. 627 """ 628 meta_graph_def = meta_graph_pb2.MetaGraphDef() 629 if not file_io.file_exists(filename): 630 raise IOError("File %s does not exist." % filename) 631 # First try to read it as a binary file. 632 file_content = file_io.FileIO(filename, "rb").read() 633 try: 634 meta_graph_def.ParseFromString(file_content) 635 return meta_graph_def 636 except Exception: # pylint: disable=broad-except 637 pass 638 639 # Next try to read it as a text file. 640 try: 641 text_format.Merge(file_content.decode("utf-8"), meta_graph_def) 642 except text_format.ParseError as e: 643 raise IOError("Cannot parse file %s: %s." 
% (filename, str(e))) 644 645 return meta_graph_def 646 647 648def import_scoped_meta_graph(meta_graph_or_file, 649 clear_devices=False, 650 graph=None, 651 import_scope=None, 652 input_map=None, 653 unbound_inputs_col_name="unbound_inputs", 654 restore_collections_predicate=(lambda key: True)): 655 """Recreates a `Graph` saved in a `MetaGraphDef` proto. 656 657 This function takes a `MetaGraphDef` protocol buffer as input. If 658 the argument is a file containing a `MetaGraphDef` protocol buffer , 659 it constructs a protocol buffer from the file content. The function 660 then adds all the nodes from the `graph_def` field to the 661 current graph, recreates the desired collections, and returns a dictionary of 662 all the Variables imported into the name scope. 663 664 In combination with `export_scoped_meta_graph()`, this function can be used to 665 666 * Serialize a graph along with other Python objects such as `QueueRunner`, 667 `Variable` into a `MetaGraphDef`. 668 669 * Restart training from a saved graph and checkpoints. 670 671 * Run inference from a saved graph and checkpoints. 672 673 Args: 674 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 675 the path) containing a `MetaGraphDef`. 676 clear_devices: Boolean which controls whether to clear device information 677 from graph_def. Default false. 678 graph: The `Graph` to import into. If `None`, use the default graph. 679 import_scope: Optional `string`. Name scope into which to import the 680 subgraph. If `None`, the graph is imported to the root name scope. 681 input_map: A dictionary mapping input names (as strings) in `graph_def` to 682 `Tensor` objects. The values of the named input tensors in the imported 683 graph will be re-mapped to the respective `Tensor` values. 684 unbound_inputs_col_name: Collection name for looking up unbound inputs. 685 restore_collections_predicate: a predicate on collection names. 
A collection 686 named c (i.e whose key is c) will be restored iff 687 1) `restore_collections_predicate(c)` is True, and 688 2) `c != unbound_inputs_col_name`. 689 690 Returns: 691 A dictionary of all the `Variables` imported into the name scope. 692 693 Raises: 694 ValueError: If the graph_def contains unbound inputs. 695 """ 696 return import_scoped_meta_graph_with_return_elements( 697 meta_graph_or_file, clear_devices, graph, import_scope, input_map, 698 unbound_inputs_col_name, restore_collections_predicate)[0] 699 700 701def import_scoped_meta_graph_with_return_elements( 702 meta_graph_or_file, 703 clear_devices=False, 704 graph=None, 705 import_scope=None, 706 input_map=None, 707 unbound_inputs_col_name="unbound_inputs", 708 restore_collections_predicate=(lambda key: True), 709 return_elements=None): 710 """Imports graph from `MetaGraphDef` and returns vars and return elements. 711 712 This function takes a `MetaGraphDef` protocol buffer as input. If 713 the argument is a file containing a `MetaGraphDef` protocol buffer , 714 it constructs a protocol buffer from the file content. The function 715 then adds all the nodes from the `graph_def` field to the 716 current graph, recreates the desired collections, and returns a dictionary of 717 all the Variables imported into the name scope. 718 719 In combination with `export_scoped_meta_graph()`, this function can be used to 720 721 * Serialize a graph along with other Python objects such as `QueueRunner`, 722 `Variable` into a `MetaGraphDef`. 723 724 * Restart training from a saved graph and checkpoints. 725 726 * Run inference from a saved graph and checkpoints. 727 728 Args: 729 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 730 the path) containing a `MetaGraphDef`. 731 clear_devices: Boolean which controls whether to clear device information 732 from graph_def. Default false. 733 graph: The `Graph` to import into. If `None`, use the default graph. 
734 import_scope: Optional `string`. Name scope into which to import the 735 subgraph. If `None`, the graph is imported to the root name scope. 736 input_map: A dictionary mapping input names (as strings) in `graph_def` to 737 `Tensor` objects. The values of the named input tensors in the imported 738 graph will be re-mapped to the respective `Tensor` values. 739 unbound_inputs_col_name: Collection name for looking up unbound inputs. 740 restore_collections_predicate: a predicate on collection names. A collection 741 named c (i.e whose key is c) will be restored iff 742 1) `restore_collections_predicate(c)` is True, and 743 2) `c != unbound_inputs_col_name`. 744 return_elements: A list of strings containing operation names in the 745 `MetaGraphDef` that will be returned as `Operation` objects; and/or 746 tensor names in `MetaGraphDef` that will be returned as `Tensor` objects. 747 748 Returns: 749 A tuple of ( 750 dictionary of all the `Variables` imported into the name scope, 751 list of `Operation` or `Tensor` objects from the `return_elements` list). 752 753 Raises: 754 ValueError: If the graph_def contains unbound inputs. 755 756 """ 757 if context.executing_eagerly(): 758 raise ValueError("Exporting/importing meta graphs is not supported when " 759 "eager execution is enabled.") 760 if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): 761 meta_graph_def = meta_graph_or_file 762 else: 763 meta_graph_def = read_meta_graph_file(meta_graph_or_file) 764 765 if unbound_inputs_col_name: 766 for key, col_def in meta_graph_def.collection_def.items(): 767 if key == unbound_inputs_col_name: 768 kind = col_def.WhichOneof("kind") 769 field = getattr(col_def, kind) 770 if field.value and ( 771 not input_map or 772 sorted([compat.as_str(v) for v in field.value]) != 773 sorted(input_map)): 774 raise ValueError("Graph contains unbound inputs: %s. Must " 775 "provide these inputs through input_map." 
% ",".join( 776 compat.as_str(v) 777 for v in field.value 778 if not input_map or v not in input_map)) 779 break 780 781 # Sets graph to default graph if it's not passed in. 782 graph = graph or ops.get_default_graph() 783 784 # Gathers the list of nodes we are interested in. 785 with graph.as_default(): 786 producer_op_list = None 787 if meta_graph_def.meta_info_def.HasField("stripped_op_list"): 788 producer_op_list = meta_graph_def.meta_info_def.stripped_op_list 789 input_graph_def = meta_graph_def.graph_def 790 # Remove all the explicit device specifications for this node. This helps to 791 # make the graph more portable. 792 if clear_devices: 793 for node in input_graph_def.node: 794 node.device = "" 795 796 scope_to_prepend_to_names = graph.unique_name( 797 import_scope or "", mark_as_used=False) 798 799 imported_return_elements = importer.import_graph_def( 800 input_graph_def, 801 name=(import_scope or scope_to_prepend_to_names), 802 input_map=input_map, 803 producer_op_list=producer_op_list, 804 return_elements=return_elements) 805 806 # TensorFlow versions before 1.9 (not inclusive) exported SavedModels 807 # without a VariableDef.trainable field set. 808 tf_version = meta_graph_def.meta_info_def.tensorflow_version 809 if not tf_version: 810 variables_have_trainable = True 811 else: 812 variables_have_trainable = ( 813 distutils_version.LooseVersion(tf_version) 814 >= distutils_version.LooseVersion("1.9")) 815 816 # Sort collections so we see TRAINABLE_VARIABLES first and can default these 817 # variables to trainable if the value is not set in their VariableDef. 
818 sorted_collections = [] 819 if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def: 820 sorted_collections.append( 821 (ops.GraphKeys.TRAINABLE_VARIABLES, 822 meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES])) 823 for key, value in sorted(meta_graph_def.collection_def.items()): 824 if key != ops.GraphKeys.TRAINABLE_VARIABLES: 825 sorted_collections.append((key, value)) 826 827 # Restores all the other collections. 828 variable_objects = {} 829 for key, col_def in sorted_collections: 830 # Don't add unbound_inputs to the new graph. 831 if key == unbound_inputs_col_name: 832 continue 833 if not restore_collections_predicate(key): 834 continue 835 836 kind = col_def.WhichOneof("kind") 837 if kind is None: 838 logging.error("Cannot identify data type for collection %s. Skipping.", 839 key) 840 continue 841 from_proto = ops.get_from_proto_function(key) 842 843 # Temporary change to allow the TFMA evaluator to read metric variables 844 # saved as a bytes list. 845 # TODO(kathywu): Remove this hack once cl/248406059 has been submitted. 846 if key == ops.GraphKeys.METRIC_VARIABLES: 847 # Metric variables will use the same proto functions as GLOBAL_VARIABLES 848 from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES) 849 if from_proto and kind == "bytes_list": 850 proto_type = ops.get_collection_proto_type(key) 851 if key in ops.GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access 852 for value in col_def.bytes_list.value: 853 variable = variable_objects.get(value, None) 854 if variable is None: 855 proto = proto_type() 856 proto.ParseFromString(value) 857 if not variables_have_trainable: 858 # If the VariableDef proto does not contain a "trainable" 859 # property because it was exported before that property was 860 # added, we default it to whether the variable is in the 861 # TRAINABLE_VARIABLES collection. 
We've sorted 862 # TRAINABLE_VARIABLES to be first, so trainable variables will 863 # be created from that collection. 864 proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES) 865 variable = from_proto( 866 proto, import_scope=scope_to_prepend_to_names) 867 variable_objects[value] = variable 868 graph.add_to_collection(key, variable) 869 else: 870 for value in col_def.bytes_list.value: 871 proto = proto_type() 872 proto.ParseFromString(value) 873 graph.add_to_collection( 874 key, from_proto( 875 proto, import_scope=scope_to_prepend_to_names)) 876 else: 877 field = getattr(col_def, kind) 878 if key in _COMPAT_COLLECTION_LIST: 879 logging.warning( 880 "The saved meta_graph is possibly from an older release:\n" 881 "'%s' collection should be of type 'byte_list', but instead " 882 "is of type '%s'.", key, kind) 883 if kind == "node_list": 884 for value in field.value: 885 col_op = graph.as_graph_element( 886 ops.prepend_name_scope(value, scope_to_prepend_to_names)) 887 graph.add_to_collection(key, col_op) 888 elif kind == "int64_list": 889 # NOTE(opensource): This force conversion is to work around the fact 890 # that Python2 distinguishes between int and long, while Python3 has 891 # only int. 
892 for value in field.value: 893 graph.add_to_collection(key, int(value)) 894 else: 895 for value in field.value: 896 graph.add_to_collection( 897 key, ops.prepend_name_scope(value, scope_to_prepend_to_names)) 898 899 var_list = {} 900 variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, 901 scope=scope_to_prepend_to_names) 902 for v in variables: 903 var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v 904 905 return var_list, imported_return_elements 906 907


def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer. When it is given together with
      `export_scope`, `clear_extraneous_savers` or `clear_devices`, a new
      `GraphDef` is built from its nodes; otherwise a `GraphDef` is built from
      `graph` when one of those options is set.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also writes the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named like `filename` but with its file extension replaced by `.debug`.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary of the `Variables` in
    the exported name scope, keyed by the variable name with `export_scope`
    stripped.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  # When executing eagerly, both `graph_def` and `graph` must be passed
  # explicitly; anything less is rejected.
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  # Handed to `_node_def` for every exported node; presumably it records the
  # names of inputs that fall outside `export_scope` here -- the list is
  # copied into the `unbound_inputs_col_name` collection further down.
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # A GraphDef was supplied: build a filtered copy of it, keeping its
      # versions and function library.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # No GraphDef was supplied: rebuild one node by node from `graph`. This
      # complicated work is only done when a name scope, extraneous savers or
      # device information has to be removed.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      # Running total of the source node sizes; checked against the 2GB limit
      # below and then passed on to `_copy_functions_to_graph_def`.
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Visit the graph's operations in ascending id order.
      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            # Attach the shape proto of every output to the exported node.
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          # The size is measured on the operation's own `node_def`, not on the
          # rewritten copy that was appended above.
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections. Note that this modifies the
      # collection on `graph` itself.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  # Collect the global variables under `export_scope`, keyed by their names
  # with the scope stripped.
  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file sits next to `filename`: same base name, with the
      # extension replaced by ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Gets the operation from the graph by the name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        # Exported node names have `export_scope` stripped; put it back to
        # look the operation up in `graph`.
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list


def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The copy is a round trip: the subgraph under `from_scope` is exported with
  `export_scoped_meta_graph` and the result is imported under `to_scope` with
  `import_scoped_meta_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  # Copying a scope onto itself inside one graph is rejected.
  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  # The variable dictionary returned by the export is discarded in favour of
  # the one returned by the import into `to_graph`.
  orig_meta_graph, var_list = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  var_list = import_scoped_meta_graph(orig_meta_graph,
                                      graph=to_graph,
                                      import_scope=to_scope)
  return var_list