1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""MetaGraph and related functions.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import os.path 23import re 24 25import six 26from google.protobuf.any_pb2 import Any 27from google.protobuf import text_format 28 29from tensorflow.core.framework import attr_value_pb2 30from tensorflow.core.framework import graph_pb2 31from tensorflow.core.framework import op_def_pb2 32from tensorflow.core.protobuf import graph_debug_info_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 34from tensorflow.core.protobuf import saver_pb2 35from tensorflow.python import pywrap_tensorflow 36from tensorflow.python.eager import context 37from tensorflow.python.framework import error_interpolation 38from tensorflow.python.framework import graph_io 39from tensorflow.python.framework import importer 40from tensorflow.python.framework import op_def_registry 41from tensorflow.python.framework import ops 42from tensorflow.python.framework import versions 43from tensorflow.python.lib.io import file_io 44from tensorflow.python.platform import tf_logging as logging 45from tensorflow.python.util import compat 46 47 48# Prefix to be added to unbound input names so they are easily identifiable. 49_UNBOUND_INPUT_PREFIX = "$unbound_inputs_" 50 51# List of collections that didn't register proto functions, as a result in 52# a previously exported meta_graph the items are of a different data type. 53_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES, 54 ops.GraphKeys.MODEL_VARIABLES] 55 56 57def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False): 58 """Create a `NodeDef` proto with export_scope stripped. 59 60 Args: 61 from_node_def: A `node_def_pb2.NodeDef` protocol buffer. 62 export_scope: A `string` representing the name scope to remove. 63 unbound_inputs: An array of unbound input names if they exist. 64 clear_devices: Boolean which controls whether to clear device information 65 from node_def. Default false. 66 67 Returns: 68 A `node_def_pb2.NodeDef` protocol buffer. 69 """ 70 node_def = copy.deepcopy(from_node_def) 71 for i, v in enumerate(node_def.input): 72 if (export_scope and 73 not node_def.input[i].lstrip("^").startswith(export_scope)): 74 # Adds "$unbound_inputs_" prefix to the unbound name so they are easily 75 # identifiable. 76 node_def.input[i] = re.sub(r"([\^]|^)(.*)", 77 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2", 78 compat.as_str(v)) 79 unbound_inputs.append(node_def.input[i]) 80 else: 81 node_def.input[i] = ops.strip_name_scope(v, export_scope) 82 node_def.name = compat.as_bytes( 83 ops.strip_name_scope(from_node_def.name, export_scope)) 84 for k, v in six.iteritems(from_node_def.attr): 85 if k == "_class": 86 new_s = [compat.as_bytes( 87 ops.strip_name_scope(s, export_scope)) for s in v.list.s 88 if not export_scope or 89 compat.as_str(s).split("@")[1].startswith(export_scope)] 90 node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue( 91 list=attr_value_pb2.AttrValue.ListValue(s=new_s))) 92 elif node_def.op in ("Enter", "RefEnter") and k == "frame_name": 93 if not export_scope or compat.as_str(v.s).startswith(export_scope): 94 new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope)) 95 node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s)) 96 else: 97 node_def.attr[k].CopyFrom(v) 98 99 if clear_devices: 100 node_def.device = "" 101 102 return node_def 103 104 105def _read_file(filename): 106 """Reads a file containing `GraphDef` and returns the protocol buffer. 107 108 Args: 109 filename: `graph_def` filename including the path. 110 111 Returns: 112 A `GraphDef` protocol buffer. 113 114 Raises: 115 IOError: If the file doesn't exist, or cannot be successfully parsed. 116 """ 117 graph_def = graph_pb2.GraphDef() 118 if not file_io.file_exists(filename): 119 raise IOError("File %s does not exist." % filename) 120 # First try to read it as a binary file. 121 file_content = file_io.FileIO(filename, "rb").read() 122 try: 123 graph_def.ParseFromString(file_content) 124 return graph_def 125 except Exception: # pylint: disable=broad-except 126 pass 127 128 # Next try to read it as a text file. 129 try: 130 text_format.Merge(file_content, graph_def) 131 except text_format.ParseError as e: 132 raise IOError("Cannot parse file %s: %s." % (filename, str(e))) 133 134 return graph_def 135 136 137def ops_used_by_graph_def(graph_def): 138 """Collect the list of ops used by a graph. 139 140 Does not validate that the ops are all registered. 141 142 Args: 143 graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. 144 145 Returns: 146 A list of strings, each naming an op used by the graph. 147 """ 148 # Map function names to definitions 149 name_to_function = {} 150 for fun in graph_def.library.function: 151 name_to_function[fun.signature.name] = fun 152 153 # Collect the list of op names. Since functions can reference functions, we 154 # need a recursive traversal. 155 used_ops = set() # Includes both primitive ops and functions 156 functions_to_process = [] # A subset of used_ops 157 158 def mark_op_as_used(op): 159 if op not in used_ops and op in name_to_function: 160 functions_to_process.append(name_to_function[op]) 161 used_ops.add(op) 162 163 for node in graph_def.node: 164 mark_op_as_used(node.op) 165 while functions_to_process: 166 fun = functions_to_process.pop() 167 for node in fun.node_def: 168 mark_op_as_used(node.op) 169 170 return [op for op in used_ops if op not in name_to_function] 171 172 173def stripped_op_list_for_graph(graph_def): 174 """Collect the stripped OpDefs for ops used by a graph. 175 176 This function computes the `stripped_op_list` field of `MetaGraphDef` and 177 similar protos. The result can be communicated from the producer to the 178 consumer, which can then use the C++ function 179 `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility. 180 181 Args: 182 graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`. 183 184 Returns: 185 An `OpList` of ops used by the graph. 186 187 Raises: 188 ValueError: If an unregistered op is used. 189 """ 190 # This is the Python equivalent of StrippedOpListForGraph in C++. 191 # Unfortunately, since the Python op registry can differ from that in C++, we 192 # can't remove the duplication using swig (at least naively). 193 # TODO(irving): Support taking graphs directly. 194 195 used_ops = ops_used_by_graph_def(graph_def) 196 197 # Verify that all used ops are registered. 198 registered_ops = op_def_registry.get_registered_ops() 199 # These internal ops used by functions are not registered, so we need to 200 # whitelist them. # TODO(irving): Do something better here. 201 op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList") 202 for op in used_ops: 203 if op not in registered_ops and op not in op_whitelist: 204 raise ValueError("Op %s is used by the graph, but is not registered" % op) 205 206 # Build the stripped op list in sorted order 207 return op_def_pb2.OpList(op=[registered_ops[op] for op in sorted(used_ops) 208 if op in registered_ops]) 209 210 211def _get_kind_name(item): 212 """Returns the kind name in CollectionDef. 213 214 Args: 215 item: A data item. 216 217 Returns: 218 The string representation of the kind in CollectionDef. 219 """ 220 if isinstance(item, (six.string_types, six.binary_type)): 221 kind = "bytes_list" 222 elif isinstance(item, six.integer_types): 223 kind = "int64_list" 224 elif isinstance(item, float): 225 kind = "float_list" 226 elif isinstance(item, Any): 227 kind = "any_list" 228 else: 229 kind = "node_list" 230 return kind 231 232 233SAVE_AND_RESTORE_OPS = ["SaveV2", 234 "Save", "SaveSlice", 235 "LegacySave", "LegacySaveSlice", 236 "RestoreV2", 237 "Restore", "RestoreSlice", 238 "LegacyRestore", "LegacyRestoreSlice"] 239 240 241def _op_name(tensor_name): 242 """Extract the Op name from a Tensor name. 243 244 The Op name is everything before a colon, if present, 245 not including any ^ prefix denoting a control dependency. 246 247 Args: 248 tensor_name: the full name of a Tensor in the graph. 249 Returns: 250 The name of the Op of which the given Tensor is an output. 251 Raises: 252 ValueError: if tensor_name is None or empty. 253 """ 254 if not tensor_name: 255 raise ValueError("Tensor name cannot be empty or None.") 256 257 # Control dependency inputs start with ^. 258 if tensor_name.startswith("^"): 259 tensor_name = tensor_name[1:] 260 if ":" in tensor_name: 261 op_name, _ = tensor_name.split(":") 262 return op_name 263 return tensor_name 264 265 266def _get_scope(node_name): 267 """Extract the scope name from a node name. 268 269 The scope name is everything before the final slash, 270 not including any ^ prefix denoting a control dependency. 271 272 Args: 273 node_name: the full name of an Op or a Tensor in the graph. 274 Returns: 275 The deepest named scope containing the node. 276 Raises: 277 ValueError: if tensor_name is None or empty 278 """ 279 if not node_name: 280 raise ValueError("Node name cannot be empty or None.") 281 282 # Control dependency inputs start with ^. 283 if node_name.startswith("^"): 284 node_name = node_name[1:] 285 if "/" in node_name: 286 scope, _ = node_name.rsplit("/", 1) 287 return scope 288 289 return "" 290 291 292def _find_extraneous_saver_nodes(graph_def, saver_def): 293 """Identifies any nodes in the graph_def related to unused Savers. 294 295 This approach assumes that each Saver is cleanly isolated in its own name 296 scope, so we need only identify the scopes associated with extraneous Savers 297 and return all the nodes in those scopes. 298 299 Args: 300 graph_def: a GraphDef proto to evaluate. 301 saver_def: a SaverDef proto referencing Save/Restore ops to be retained. 302 Returns: 303 An iterable of node names that may be safely omitted. 304 """ 305 # TODO(soergel): confirm that the assumption of scope isolation is valid. 306 # If not, we need to walk up the graph from any restore_all nodes, and walk 307 # down the graph from any Save/Restore nodes. I drafted that approach too, 308 # but it seems unnecessarily complex given the name scope solution. 309 310 # load the graph DAG in minimal form, without initializing a full Graph object 311 nodes = {node_def.name: 312 (set([_op_name(x) for x in node_def.input]), node_def.op) 313 for node_def in graph_def.node} 314 315 retain_scope_save = None 316 retain_scope_restore = None 317 # It's possible to have no saver if the graph has no Variables 318 if saver_def is not None: 319 save_op_name = _op_name(saver_def.save_tensor_name) 320 restore_op_name = _op_name(saver_def.restore_op_name) 321 322 # The save and restore scopes should always be the same, but if they differ 323 # for some reason, we retain them both to be safe. 324 retain_scope_restore = _get_scope(restore_op_name) + "/" 325 retain_scope_save = _get_scope(save_op_name) + "/" 326 327 all_saver_node_names = set([name for name, (_, op) in nodes.items() 328 if op in SAVE_AND_RESTORE_OPS]) 329 330 all_saver_scopes = (set([_get_scope(x) for x in all_saver_node_names]) 331 - all_saver_node_names) 332 all_saver_scopes = set([x + "/" for x in all_saver_scopes]) 333 334 extraneous_scopes = all_saver_scopes - set([retain_scope_save, 335 retain_scope_restore]) 336 337 extraneous_node_names = set() 338 for name, _ in nodes.items(): 339 for extraneous_scope in extraneous_scopes: 340 if name.startswith(extraneous_scope): 341 extraneous_node_names.add(name) 342 break 343 344 return extraneous_node_names 345 346 347def _should_include_node(node_or_node_name, export_scope, exclude_nodes): 348 """Returns `True` if a node should be included. 349 350 Args: 351 node_or_node_name: A node or `string` node name. 352 export_scope: `string`. Name scope under which to extract the subgraph. The 353 scope name will be stripped from the node definitions for easy import 354 later into new name scopes. 355 exclude_nodes: An iterable of nodes or `string` node names to omit from the 356 export, or None. Note no sanity-checking is done, so this list must be 357 carefully constructed to avoid producing an invalid graph. 358 359 Returns: 360 `True` if the node should be included. 361 """ 362 if not isinstance(node_or_node_name, six.string_types): 363 try: 364 node_name = node_or_node_name.name 365 except AttributeError: 366 # Keep the object that we don't know how to process. 367 return True 368 else: 369 node_name = node_or_node_name 370 371 if exclude_nodes and (node_or_node_name in exclude_nodes 372 or node_name in exclude_nodes): 373 return False 374 375 return (node_name.startswith(_UNBOUND_INPUT_PREFIX) or 376 (not export_scope or node_name.startswith(export_scope))) 377 378 379def add_collection_def(meta_graph_def, key, graph=None, 380 export_scope=None, exclude_nodes=None, 381 override_contents=None): 382 """Adds a collection to MetaGraphDef protocol buffer. 383 384 Args: 385 meta_graph_def: MetaGraphDef protocol buffer. 386 key: One of the GraphKeys or user-defined string. 387 graph: The `Graph` from which to get collections. 388 export_scope: Optional `string`. Name scope to remove. 389 exclude_nodes: An iterable of nodes or `string` node names to omit from the 390 collection, or None. 391 override_contents: An iterable of values to place in the collection, 392 ignoring the current values (if set). 393 """ 394 if graph and not isinstance(graph, ops.Graph): 395 raise TypeError("graph must be of type Graph, not %s", type(graph)) 396 397 if not isinstance(key, six.string_types) and not isinstance(key, bytes): 398 logging.warning("Only collections with string type keys will be " 399 "serialized. This key has %s", type(key)) 400 return 401 402 # Sets graph to default graph if it's not passed in. 403 graph = graph or ops.get_default_graph() 404 405 if override_contents: 406 collection_list = override_contents 407 else: 408 collection_list = graph.get_collection(key) 409 410 # Remove nodes that should not be exported from the collection list. 411 collection_list = [x for x in collection_list if 412 _should_include_node(x, export_scope, exclude_nodes)] 413 if not collection_list: 414 return 415 416 try: 417 col_def = meta_graph_def.collection_def[key] 418 to_proto = ops.get_to_proto_function(key) 419 proto_type = ops.get_collection_proto_type(key) 420 if to_proto: 421 kind = "bytes_list" 422 for x in collection_list: 423 # Additional type check to make sure the returned proto is indeed 424 # what we expect. 425 proto = to_proto(x, export_scope=export_scope) 426 if proto: 427 assert isinstance(proto, proto_type) 428 getattr(col_def, kind).value.append(proto.SerializeToString()) 429 else: 430 kind = _get_kind_name(collection_list[0]) 431 if kind == "node_list": 432 for x in collection_list: 433 if not export_scope or x.name.startswith(export_scope): 434 getattr(col_def, kind).value.append( 435 ops.strip_name_scope(x.name, export_scope)) 436 elif kind == "bytes_list": 437 # NOTE(opensource): This force conversion is to work around the fact 438 # that Python3 distinguishes between bytes and strings. 439 getattr(col_def, kind).value.extend( 440 [compat.as_bytes(x) for x in collection_list]) 441 else: 442 getattr(col_def, kind).value.extend([x for x in collection_list]) 443 except Exception as e: # pylint: disable=broad-except 444 logging.warning("Issue encountered when serializing %s.\n" 445 "Type is unsupported, or the types of the items don't " 446 "match field type in CollectionDef. Note this is a warning " 447 "and probably safe to ignore.\n%s", key, str(e)) 448 if key in meta_graph_def.collection_def: 449 del meta_graph_def.collection_def[key] 450 return 451 452 453def _is_default_attr_value(op_def, attr_name, attr_value): 454 """Checks if given attribute matches the default value in the op def.""" 455 for attr_def in op_def.attr: 456 if attr_def.name == attr_name: 457 if not attr_def.HasField("default_value"): 458 return False 459 # pywrap_tensorflow.EqualAttrValueWrapper returns an empty string 460 # if both arguments represent an equivalent AttrValue instance. 461 return not pywrap_tensorflow.EqualAttrValueWrapper( 462 attr_value.SerializeToString(), 463 attr_def.default_value.SerializeToString()) 464 return False 465 466 467def strip_graph_default_valued_attrs(meta_graph_def): 468 """Strips default valued attributes for node defs in given MetaGraphDef. 469 470 This method also sets `meta_info_def.stripped_default_attrs` in the given 471 `MetaGraphDef` proto to True. 472 473 Args: 474 meta_graph_def: `MetaGraphDef` protocol buffer 475 476 Returns: 477 None. 478 """ 479 # Map function op names to their function definitions. 480 op_name_to_function = {} 481 for function_def in meta_graph_def.graph_def.library.function: 482 op_name_to_function[function_def.signature.name] = function_def 483 484 # Get all registered ops. 485 registered_ops = op_def_registry.get_registered_ops() 486 487 def _strip_node_default_valued_attrs(node_def): 488 """Removes default valued attributes from a single node def.""" 489 if node_def.op in op_name_to_function or node_def.op not in registered_ops: 490 return 491 op_def = registered_ops[node_def.op] 492 493 attrs_to_strip = set() 494 for attr_name, attr_value in node_def.attr.items(): 495 if _is_default_attr_value(op_def, attr_name, attr_value): 496 attrs_to_strip.add(attr_name) 497 498 for attr in attrs_to_strip: 499 del node_def.attr[attr] 500 501 # Process all NodeDef instances in graph_def. 502 for node_def in meta_graph_def.graph_def.node: 503 _strip_node_default_valued_attrs(node_def) 504 505 # Process all NodeDef instances in graph_def.library.function. 506 for function_def in meta_graph_def.graph_def.library.function: 507 for function_node_def in function_def.node_def: 508 _strip_node_default_valued_attrs(function_node_def) 509 510 # Tell consumers of this graph that default valued attrs have been stripped. 511 meta_graph_def.meta_info_def.stripped_default_attrs = True 512 513 514def create_graph_debug_info_def(operations): 515 """Construct and returns a `GraphDebugInfo` protocol buffer. 516 517 Args: 518 operations: An iterable of op.Operation objects having _traceback members. 519 520 Returns: 521 GraphDebugInfo protocol buffer. 522 523 Raises: 524 TypeError: If the arguments are not of the correct proto buffer type. 525 """ 526 # Creates an empty GraphDebugInfoDef proto. 527 graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo() 528 529 # Gets the file names and line numbers for the exported node names. Also 530 # collects the unique file names. 531 all_file_names = set() 532 node_to_trace = {} 533 for op in operations: 534 # Gets the stack trace of the operation and then the file location. 535 node_name = op.name 536 node_to_trace[node_name] = error_interpolation.compute_useful_stack(op) 537 for trace in node_to_trace[node_name]: 538 all_file_names.add(trace[0]) 539 540 # Sets the `files` field in the GraphDebugInfo proto 541 graph_debug_info_def.files.extend(all_file_names) 542 543 # Builds a mapping between file names and index of the `files` field, so we 544 # only store the indexes for the nodes in the GraphDebugInfo. 545 file_to_index = dict( 546 [(y, x) for x, y in enumerate(graph_debug_info_def.files)]) 547 548 # Creates the FileLineCol proto for each node and sets the value in the 549 # GraphDebugInfo proto. We only store the file name index for each node to 550 # save the storage space. 551 for node_name, trace in node_to_trace.items(): 552 trace_def = graph_debug_info_def.traces[node_name] 553 for file_name, line, func, code in trace: 554 file_index = file_to_index[file_name] 555 trace_def.file_line_cols.add( 556 file_index=file_index, line=line, func=func, code=code) 557 558 return graph_debug_info_def 559 560 561def create_meta_graph_def(meta_info_def=None, 562 graph_def=None, 563 saver_def=None, 564 collection_list=None, 565 graph=None, 566 export_scope=None, 567 exclude_nodes=None, 568 clear_extraneous_savers=False, 569 strip_default_attrs=False): 570 # pylint: disable=line-too-long 571 """Construct and returns a `MetaGraphDef` protocol buffer. 572 573 Args: 574 meta_info_def: `MetaInfoDef` protocol buffer. 575 graph_def: `GraphDef` protocol buffer. 576 saver_def: `SaverDef` protocol buffer. 577 collection_list: List of string keys to collect. 578 graph: The `Graph` to create `MetaGraphDef` out of. 579 export_scope: Optional `string`. Name scope to remove. 580 exclude_nodes: An iterable of nodes or `string` node names to omit from all 581 collection, or None. 582 clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS 583 collection. Note this method does not alter the graph, so any 584 extraneous Save/Restore ops should have been removed already, as needed. 585 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 586 removed from the NodeDefs. For a detailed guide, see 587 [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 588 589 Returns: 590 MetaGraphDef protocol buffer. 591 592 Raises: 593 TypeError: If the arguments are not of the correct proto buffer type. 594 """ 595 # pylint: enable=line-too-long 596 # Type check. 597 if graph and not isinstance(graph, ops.Graph): 598 raise TypeError("graph must be of type Graph, not %s", type(graph)) 599 if meta_info_def and not isinstance(meta_info_def, 600 meta_graph_pb2.MetaGraphDef.MetaInfoDef): 601 raise TypeError("meta_info_def must be of type MetaInfoDef, not %s", 602 type(meta_info_def)) 603 if graph_def and not isinstance(graph_def, graph_pb2.GraphDef): 604 raise TypeError("graph_def must be of type GraphDef, not %s", 605 type(graph_def)) 606 if saver_def and not isinstance(saver_def, saver_pb2.SaverDef): 607 raise TypeError("saver_def must be of type SaverDef, not %s", 608 type(saver_def)) 609 610 # Sets graph to default graph if it's not passed in. 611 graph = graph or ops.get_default_graph() 612 613 # Creates a MetaGraphDef proto. 614 meta_graph_def = meta_graph_pb2.MetaGraphDef() 615 # Adds meta_info_def. 616 if not meta_info_def: 617 meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() 618 619 # Set the tf version strings to the current tf build. 620 meta_info_def.tensorflow_version = versions.__version__ 621 meta_info_def.tensorflow_git_version = versions.__git_version__ 622 meta_graph_def.meta_info_def.MergeFrom(meta_info_def) 623 624 # Adds graph_def or the default. 625 if not graph_def: 626 meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True)) 627 else: 628 meta_graph_def.graph_def.MergeFrom(graph_def) 629 630 # Fills in meta_info_def.stripped_op_list using the ops from graph_def. 631 # pylint: disable=g-explicit-length-test 632 if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0: 633 meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( 634 stripped_op_list_for_graph(meta_graph_def.graph_def)) 635 # pylint: enable=g-explicit-length-test 636 637 # Strip default valued attributes in graph_def. 638 if strip_default_attrs: 639 strip_graph_default_valued_attrs(meta_graph_def) 640 641 # Adds saver_def. 642 if saver_def: 643 meta_graph_def.saver_def.MergeFrom(saver_def) 644 645 # Adds collection_list. 646 if collection_list is not None: 647 clist = collection_list 648 else: 649 clist = graph.get_all_collection_keys() 650 651 for ctype in clist: 652 if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS: 653 # Avoid importing Saver here 654 from_proto = ops.get_from_proto_function(ctype) 655 add_collection_def(meta_graph_def, ctype, 656 graph=graph, 657 export_scope=export_scope, 658 exclude_nodes=exclude_nodes, 659 override_contents=[from_proto(saver_def)]) 660 else: 661 add_collection_def(meta_graph_def, ctype, 662 graph=graph, 663 export_scope=export_scope, 664 exclude_nodes=exclude_nodes) 665 return meta_graph_def 666 667 668def read_meta_graph_file(filename): 669 """Reads a file containing `MetaGraphDef` and returns the protocol buffer. 670 671 Args: 672 filename: `meta_graph_def` filename including the path. 673 674 Returns: 675 A `MetaGraphDef` protocol buffer. 676 677 Raises: 678 IOError: If the file doesn't exist, or cannot be successfully parsed. 679 """ 680 meta_graph_def = meta_graph_pb2.MetaGraphDef() 681 if not file_io.file_exists(filename): 682 raise IOError("File %s does not exist." % filename) 683 # First try to read it as a binary file. 684 file_content = file_io.FileIO(filename, "rb").read() 685 try: 686 meta_graph_def.ParseFromString(file_content) 687 return meta_graph_def 688 except Exception: # pylint: disable=broad-except 689 pass 690 691 # Next try to read it as a text file. 692 try: 693 text_format.Merge(file_content.decode("utf-8"), meta_graph_def) 694 except text_format.ParseError as e: 695 raise IOError("Cannot parse file %s: %s." % (filename, str(e))) 696 697 return meta_graph_def 698 699 700def import_scoped_meta_graph(meta_graph_or_file, 701 clear_devices=False, 702 graph=None, 703 import_scope=None, 704 input_map=None, 705 unbound_inputs_col_name="unbound_inputs", 706 restore_collections_predicate=(lambda key: True)): 707 """Recreates a `Graph` saved in a `MetaGraphDef` proto. 708 709 This function takes a `MetaGraphDef` protocol buffer as input. If 710 the argument is a file containing a `MetaGraphDef` protocol buffer , 711 it constructs a protocol buffer from the file content. The function 712 then adds all the nodes from the `graph_def` field to the 713 current graph, recreates the desired collections, and returns a dictionary of 714 all the Variables imported into the name scope. 715 716 In combination with `export_scoped_meta_graph()`, this function can be used to 717 718 * Serialize a graph along with other Python objects such as `QueueRunner`, 719 `Variable` into a `MetaGraphDef`. 720 721 * Restart training from a saved graph and checkpoints. 722 723 * Run inference from a saved graph and checkpoints. 724 725 Args: 726 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 727 the path) containing a `MetaGraphDef`. 728 clear_devices: Boolean which controls whether to clear device information 729 from graph_def. Default false. 730 graph: The `Graph` to import into. If `None`, use the default graph. 731 import_scope: Optional `string`. Name scope into which to import the 732 subgraph. If `None`, the graph is imported to the root name scope. 733 input_map: A dictionary mapping input names (as strings) in `graph_def` to 734 `Tensor` objects. The values of the named input tensors in the imported 735 graph will be re-mapped to the respective `Tensor` values. 736 unbound_inputs_col_name: Collection name for looking up unbound inputs. 737 restore_collections_predicate: a predicate on collection names. A collection 738 named c (i.e whose key is c) will be restored iff 739 1) `restore_collections_predicate(c)` is True, and 740 2) `c != unbound_inputs_col_name`. 741 742 Returns: 743 A dictionary of all the `Variables` imported into the name scope. 744 745 Raises: 746 ValueError: If the graph_def contains unbound inputs. 747 """ 748 return import_scoped_meta_graph_with_return_elements( 749 meta_graph_or_file, clear_devices, graph, import_scope, input_map, 750 unbound_inputs_col_name, restore_collections_predicate)[0] 751 752 753def import_scoped_meta_graph_with_return_elements( 754 meta_graph_or_file, 755 clear_devices=False, 756 graph=None, 757 import_scope=None, 758 input_map=None, 759 unbound_inputs_col_name="unbound_inputs", 760 restore_collections_predicate=(lambda key: True), 761 return_elements=None): 762 """Imports graph from `MetaGraphDef` and returns vars and return elements. 763 764 This function takes a `MetaGraphDef` protocol buffer as input. If 765 the argument is a file containing a `MetaGraphDef` protocol buffer , 766 it constructs a protocol buffer from the file content. The function 767 then adds all the nodes from the `graph_def` field to the 768 current graph, recreates the desired collections, and returns a dictionary of 769 all the Variables imported into the name scope. 770 771 In combination with `export_scoped_meta_graph()`, this function can be used to 772 773 * Serialize a graph along with other Python objects such as `QueueRunner`, 774 `Variable` into a `MetaGraphDef`. 775 776 * Restart training from a saved graph and checkpoints. 777 778 * Run inference from a saved graph and checkpoints. 779 780 Args: 781 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 782 the path) containing a `MetaGraphDef`. 783 clear_devices: Boolean which controls whether to clear device information 784 from graph_def. Default false. 785 graph: The `Graph` to import into. If `None`, use the default graph. 786 import_scope: Optional `string`. Name scope into which to import the 787 subgraph. If `None`, the graph is imported to the root name scope. 788 input_map: A dictionary mapping input names (as strings) in `graph_def` to 789 `Tensor` objects. The values of the named input tensors in the imported 790 graph will be re-mapped to the respective `Tensor` values. 791 unbound_inputs_col_name: Collection name for looking up unbound inputs. 792 restore_collections_predicate: a predicate on collection names. A collection 793 named c (i.e whose key is c) will be restored iff 794 1) `restore_collections_predicate(c)` is True, and 795 2) `c != unbound_inputs_col_name`. 796 return_elements: A list of strings containing operation names in the 797 `MetaGraphDef` that will be returned as `Operation` objects; and/or 798 tensor names in `MetaGraphDef` that will be returned as `Tensor` objects. 799 800 Returns: 801 A tuple of ( 802 dictionary of all the `Variables` imported into the name scope, 803 list of `Operation` or `Tensor` objects from the `return_elements` list). 804 805 Raises: 806 ValueError: If the graph_def contains unbound inputs. 807 808 """ 809 if context.executing_eagerly(): 810 raise ValueError("Exporting/importing meta graphs is not supported when " 811 "eager execution is enabled.") 812 if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): 813 meta_graph_def = meta_graph_or_file 814 else: 815 meta_graph_def = read_meta_graph_file(meta_graph_or_file) 816 817 if unbound_inputs_col_name: 818 for key, col_def in meta_graph_def.collection_def.items(): 819 if key == unbound_inputs_col_name: 820 kind = col_def.WhichOneof("kind") 821 field = getattr(col_def, kind) 822 if field.value and ( 823 not input_map or 824 sorted([compat.as_str(v) for v in field.value]) != 825 sorted(input_map)): 826 raise ValueError("Graph contains unbound inputs: %s. Must " 827 "provide these inputs through input_map." % 828 ",".join([compat.as_str(v) for v in field.value 829 if not input_map or v not in input_map])) 830 break 831 832 # Sets graph to default graph if it's not passed in. 833 graph = graph or ops.get_default_graph() 834 835 # Gathers the list of nodes we are interested in. 836 with graph.as_default(): 837 producer_op_list = None 838 if meta_graph_def.meta_info_def.HasField("stripped_op_list"): 839 producer_op_list = meta_graph_def.meta_info_def.stripped_op_list 840 input_graph_def = meta_graph_def.graph_def 841 # Remove all the explicit device specifications for this node. This helps to 842 # make the graph more portable. 843 if clear_devices: 844 for node in input_graph_def.node: 845 node.device = "" 846 847 scope_to_prepend_to_names = graph.unique_name( 848 import_scope or "", mark_as_used=False) 849 850 imported_return_elements = importer.import_graph_def( 851 input_graph_def, 852 name=(import_scope or scope_to_prepend_to_names), 853 input_map=input_map, 854 producer_op_list=producer_op_list, 855 return_elements=return_elements) 856 857 # Restores all the other collections. 858 variable_objects = {} 859 for key, col_def in sorted(meta_graph_def.collection_def.items()): 860 # Don't add unbound_inputs to the new graph. 861 if key == unbound_inputs_col_name: 862 continue 863 if not restore_collections_predicate(key): 864 continue 865 866 kind = col_def.WhichOneof("kind") 867 if kind is None: 868 logging.error("Cannot identify data type for collection %s. Skipping.", 869 key) 870 continue 871 from_proto = ops.get_from_proto_function(key) 872 if from_proto and kind == "bytes_list": 873 proto_type = ops.get_collection_proto_type(key) 874 if key in ops.GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access 875 for value in col_def.bytes_list.value: 876 variable = variable_objects.get(value, None) 877 if variable is None: 878 proto = proto_type() 879 proto.ParseFromString(value) 880 variable = from_proto( 881 proto, import_scope=scope_to_prepend_to_names) 882 variable_objects[value] = variable 883 graph.add_to_collection(key, variable) 884 else: 885 for value in col_def.bytes_list.value: 886 proto = proto_type() 887 proto.ParseFromString(value) 888 graph.add_to_collection( 889 key, from_proto( 890 proto, import_scope=scope_to_prepend_to_names)) 891 else: 892 field = getattr(col_def, kind) 893 if key in _COMPAT_COLLECTION_LIST: 894 logging.warning( 895 "The saved meta_graph is possibly from an older release:\n" 896 "'%s' collection should be of type 'byte_list', but instead " 897 "is of type '%s'.", key, kind) 898 if kind == "node_list": 899 for value in field.value: 900 col_op = graph.as_graph_element( 901 ops.prepend_name_scope(value, scope_to_prepend_to_names)) 902 graph.add_to_collection(key, col_op) 903 elif kind == "int64_list": 904 # NOTE(opensource): This force conversion is to work around the fact 905 # that Python2 distinguishes between int and long, while Python3 has 906 # only int. 907 for value in field.value: 908 graph.add_to_collection(key, int(value)) 909 else: 910 for value in field.value: 911 graph.add_to_collection( 912 key, ops.prepend_name_scope(value, scope_to_prepend_to_names)) 913 914 var_list = {} 915 variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, 916 scope=scope_to_prepend_to_names) 917 for v in variables: 918 var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v 919 920 return var_list, imported_return_elements 921 922 923def export_scoped_meta_graph(filename=None, 924 graph_def=None, 925 graph=None, 926 export_scope=None, 927 as_text=False, 928 unbound_inputs_col_name="unbound_inputs", 929 clear_devices=False, 930 saver_def=None, 931 clear_extraneous_savers=False, 932 strip_default_attrs=False, 933 save_debug_info=False, 934 **kwargs): 935 """Returns `MetaGraphDef` proto. Optionally writes it to filename. 936 937 This function exports the graph, saver, and collection objects into 938 `MetaGraphDef` protocol buffer with the intention of it being imported 939 at a later time or location to restart training, run inference, or be 940 a subgraph. 941 942 Args: 943 filename: Optional filename including the path for writing the 944 generated `MetaGraphDef` protocol buffer. 945 graph_def: `GraphDef` protocol buffer. 946 graph: The `Graph` to export. If `None`, use the default graph. 947 export_scope: Optional `string`. Name scope under which to extract 948 the subgraph. The scope name will be stripped from the node definitions 949 for easy import later into new name scopes. If `None`, the whole graph 950 is exported. 951 as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto. 952 unbound_inputs_col_name: Optional `string`. If provided, a string collection 953 with the given name will be added to the returned `MetaGraphDef`, 954 containing the names of tensors that must be remapped when importing the 955 `MetaGraphDef`. 956 clear_devices: Boolean which controls whether to clear device information 957 before exporting the graph. 958 saver_def: `SaverDef` protocol buffer. 959 clear_extraneous_savers: Remove any Saver-related information from the 960 graph (both Save/Restore ops and SaverDefs) that are not associated 961 with the provided SaverDef. 962 strip_default_attrs: Set to true if default valued attributes must be 963 removed while exporting the GraphDef. 964 save_debug_info: If `True`, save the GraphDebugInfo to a separate file, 965 which in the same directory of filename and with `_debug` added before the 966 file extension. 967 **kwargs: Optional keyed arguments, including meta_info_def and 968 collection_list. 969 970 Returns: 971 A `MetaGraphDef` proto and dictionary of `Variables` in the exported 972 name scope. 973 974 Raises: 975 ValueError: When the `GraphDef` is larger than 2GB. 976 ValueError: When executing in Eager mode and either `graph_def` or `graph` 977 is undefined. 978 """ 979 if context.executing_eagerly() and not (graph_def is not None and 980 graph is not None): 981 raise ValueError("Exporting/importing meta graphs is not supported when " 982 "Eager Execution is enabled.") 983 graph = graph or ops.get_default_graph() 984 985 exclude_nodes = None 986 unbound_inputs = [] 987 if export_scope or clear_extraneous_savers or clear_devices: 988 if graph_def: 989 new_graph_def = graph_pb2.GraphDef() 990 new_graph_def.versions.CopyFrom(graph_def.versions) 991 new_graph_def.library.CopyFrom(graph_def.library) 992 993 if clear_extraneous_savers: 994 exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def) 995 996 for node_def in graph_def.node: 997 if _should_include_node(node_def.name, export_scope, exclude_nodes): 998 new_node_def = _node_def(node_def, export_scope, unbound_inputs, 999 clear_devices=clear_devices) 1000 new_graph_def.node.extend([new_node_def]) 1001 graph_def = new_graph_def 1002 else: 1003 # Only do this complicated work if we want to remove a name scope. 1004 graph_def = graph_pb2.GraphDef() 1005 # pylint: disable=protected-access 1006 graph_def.versions.CopyFrom(graph.graph_def_versions) 1007 bytesize = 0 1008 1009 if clear_extraneous_savers: 1010 exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(), 1011 saver_def) 1012 1013 for key in sorted(graph._nodes_by_id): 1014 if _should_include_node(graph._nodes_by_id[key].name, 1015 export_scope, 1016 exclude_nodes): 1017 value = graph._nodes_by_id[key] 1018 # pylint: enable=protected-access 1019 node_def = _node_def(value.node_def, export_scope, unbound_inputs, 1020 clear_devices=clear_devices) 1021 graph_def.node.extend([node_def]) 1022 if value.outputs: 1023 assert "_output_shapes" not in graph_def.node[-1].attr 1024 graph_def.node[-1].attr["_output_shapes"].list.shape.extend([ 1025 output.get_shape().as_proto() for output in value.outputs]) 1026 bytesize += value.node_def.ByteSize() 1027 if bytesize >= (1 << 31) or bytesize < 0: 1028 raise ValueError("GraphDef cannot be larger than 2GB.") 1029 1030 graph._copy_functions_to_graph_def(graph_def, bytesize) # pylint: disable=protected-access 1031 1032 # It's possible that not all the inputs are in the export_scope. 1033 # If we would like such information included in the exported meta_graph, 1034 # add them to a special unbound_inputs collection. 1035 if unbound_inputs_col_name: 1036 # Clears the unbound_inputs collections. 1037 graph.clear_collection(unbound_inputs_col_name) 1038 for k in unbound_inputs: 1039 graph.add_to_collection(unbound_inputs_col_name, k) 1040 1041 var_list = {} 1042 variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, 1043 scope=export_scope) 1044 for v in variables: 1045 if _should_include_node(v, export_scope, exclude_nodes): 1046 var_list[ops.strip_name_scope(v.name, export_scope)] = v 1047 1048 scoped_meta_graph_def = create_meta_graph_def( 1049 graph_def=graph_def, 1050 graph=graph, 1051 export_scope=export_scope, 1052 exclude_nodes=exclude_nodes, 1053 clear_extraneous_savers=clear_extraneous_savers, 1054 saver_def=saver_def, 1055 strip_default_attrs=strip_default_attrs, 1056 **kwargs) 1057 1058 if filename: 1059 graph_io.write_graph( 1060 scoped_meta_graph_def, 1061 os.path.dirname(filename), 1062 os.path.basename(filename), 1063 as_text=as_text) 1064 if save_debug_info: 1065 name, _ = os.path.splitext(filename) 1066 debug_filename = "{name}{ext}".format(name=name, ext=".debug") 1067 1068 # Gets the operation from the graph by the name. Exludes variable nodes, 1069 # so only the nodes in the frozen models are included. 1070 ops_to_export = [] 1071 for node in scoped_meta_graph_def.graph_def.node: 1072 scoped_op_name = ops.prepend_name_scope(node.name, export_scope) 1073 ops_to_export.append(graph.get_operation_by_name(scoped_op_name)) 1074 1075 graph_debug_info = create_graph_debug_info_def(ops_to_export) 1076 1077 graph_io.write_graph( 1078 graph_debug_info, 1079 os.path.dirname(debug_filename), 1080 os.path.basename(debug_filename), 1081 as_text=as_text) 1082 1083 return scoped_meta_graph_def, var_list 1084 1085 1086def copy_scoped_meta_graph(from_scope, to_scope, 1087 from_graph=None, to_graph=None): 1088 """Copies a sub-meta_graph from one scope to another. 1089 1090 Args: 1091 from_scope: `String` name scope containing the subgraph to be copied. 1092 to_scope: `String` name scope under which the copied subgraph will reside. 1093 from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the 1094 default graph is use. 1095 to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the 1096 default graph is used. 1097 1098 Returns: 1099 A dictionary of `Variables` that has been copied into `to_scope`. 1100 1101 Raises: 1102 ValueError: If `from_scope` and `to_scope` are the same while 1103 `from_graph` and `to_graph` are also the same. 1104 """ 1105 from_graph = from_graph or ops.get_default_graph() 1106 to_graph = to_graph or ops.get_default_graph() 1107 1108 if from_graph == to_graph and from_scope == to_scope: 1109 raise ValueError("'from_scope' and 'to_scope' need to be different " 1110 "when performing copy in the same graph.") 1111 1112 orig_meta_graph, var_list = export_scoped_meta_graph( 1113 export_scope=from_scope, graph=from_graph) 1114 var_list = import_scoped_meta_graph(orig_meta_graph, 1115 graph=to_graph, 1116 import_scope=to_scope) 1117 return var_list 1118