1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Import a trackable object from a SavedModel.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import os 23 24from tensorflow.core.protobuf import graph_debug_info_pb2 25from tensorflow.python.distribute import distribution_strategy_context as ds_context 26from tensorflow.python.distribute import values as ds_values 27from tensorflow.python.eager import context 28from tensorflow.python.eager import function 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import custom_gradient 36from tensorflow.python.ops import resource_variable_ops 37from tensorflow.python.ops import variables 38from tensorflow.python.saved_model import function_deserialization 39from tensorflow.python.saved_model import load_v1_in_v2 40from tensorflow.python.saved_model import loader_impl 41from tensorflow.python.saved_model import nested_structure_coder 42from tensorflow.python.saved_model import revived_types 43from tensorflow.python.saved_model import utils_impl as saved_model_utils 44from tensorflow.python.training.tracking import base 45from tensorflow.python.training.tracking import graph_view 46from tensorflow.python.training.tracking import tracking 47from tensorflow.python.training.tracking import util 48from tensorflow.python.util import nest 49from tensorflow.python.util.tf_export import tf_export 50 51 52def _unused_handle(): 53 """Returns a placeholder as a handle that is not supposed to be accessed.""" 54 error_message = ("Trying to access a placeholder that is not supposed to be " 55 "executed. This means you are executing a graph generated " 56 "from the cross-replica context in an in-replica context.") 57 58 assert_op = control_flow_ops.Assert( 59 array_ops.placeholder_with_default(False, shape=()), 60 [error_message]) 61 62 with ops.control_dependencies([assert_op]): 63 return array_ops.placeholder(dtype=dtypes.resource) 64 65 66class _WrapperFunction(function.ConcreteFunction): 67 """A class wraps a concrete function to handle different distributed contexts. 68 69 The reason for wrapping a concrete function is because the _captured_inputs 70 fields used for in-replica context and cross-replica context are different. 71 When `load()` is called from within a tf.distribute.strategy scope, the 72 captured inputs are distributed variables. When using these distributed 73 variables during calling the function, we need different approaches when it is 74 in-replica and when it is not in-replica. When it is in replica, naturally we 75 should use the corresponding component of the distributed variable; when it is 76 not in-replica, calling the function should mean that it is constructing a 77 graph that is not actually going to be used. A typical use case is when 78 constructing a functional model. In this case, return a placeholder with a 79 control dependency to ensure that is never accessed. 80 """ 81 82 def __init__(self, concrete_function): 83 # Shallow copy the concrete_function 84 self.__dict__.update(vars(concrete_function)) 85 86 def _call_flat(self, args, captured_inputs, cancellation_manager=None): 87 88 def get_in_replica_handle(x): 89 return x.handle if ds_values.is_distributed_variable(x) else x 90 91 def get_cross_replica_handle(x): 92 return _unused_handle() if ds_values.is_distributed_variable(x) else x 93 94 if ds_context.get_replica_context() is not None: # in-replica context 95 captured_inputs = list(map(get_in_replica_handle, captured_inputs)) 96 else: # cross-replica context 97 captured_inputs = list( 98 map(get_cross_replica_handle, captured_inputs)) 99 return super(_WrapperFunction, self)._call_flat(args, captured_inputs, 100 cancellation_manager) 101 102 103class Loader(object): 104 """Helper class to load an object-based SavedModel.""" 105 106 def __init__(self, object_graph_proto, saved_model_proto, export_dir): 107 meta_graph = saved_model_proto.meta_graphs[0] 108 self._asset_file_def = meta_graph.asset_file_def 109 self._operation_attributes = { 110 node.name: node.attr for node in meta_graph.graph_def.node} 111 self._proto = object_graph_proto 112 self._export_dir = export_dir 113 self._concrete_functions = ( 114 function_deserialization.load_function_def_library( 115 meta_graph.graph_def.library)) 116 117 for name, concrete_function in self._concrete_functions.items(): 118 # Wrap all the concrete function so that they are capable of dealing with 119 # both in replica and cross replica cases. 120 self._concrete_functions[name] = _WrapperFunction(concrete_function) 121 122 self._load_all() 123 self._restore_checkpoint() 124 125 for node in self._nodes: 126 if isinstance(node, tracking.CapturableResource): 127 init_op = node._initialize() # pylint: disable=protected-access 128 if not context.executing_eagerly(): 129 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 130 131 def _load_all(self): 132 """Loads all nodes and functions from the SavedModel and their edges.""" 133 self._load_nodes() 134 self._load_edges() 135 # TODO(b/124045874): There are limitations with functions whose captures 136 # trigger other functions to be executed. For now it is only guaranteed to 137 # work if the captures of a function only trigger functions without 138 # captures. 139 self._setup_functions_structures() 140 self._setup_functions_captures() 141 142 def _load_edges(self): 143 """Adds edges from objects to other objects and functions.""" 144 for node_id, object_proto in enumerate(self._proto.nodes): 145 self._add_object_graph_edges(object_proto, node_id) 146 147 def _add_object_graph_edges(self, proto, node_id): 148 """Adds edges from an object to its children.""" 149 obj = self._nodes[node_id] 150 setter = self._node_setters[node_id] 151 152 for reference in proto.children: 153 setter(obj, reference.local_name, self._nodes[reference.node_id]) 154 # Note: if an object has an attribute `__call__` add a class method 155 # that allows `obj()` syntax to work. This is done per-instance to 156 # allow `callable` to be used to find out if an object is callable. 157 if reference.local_name == "__call__" and not callable(obj): 158 setattr(type(obj), "__call__", _call_attribute) 159 160 def _setup_functions_structures(self): 161 """Setup structure for inputs and outputs of restored functions.""" 162 coder = nested_structure_coder.StructureCoder() 163 for name, proto in sorted(self._proto.concrete_functions.items()): 164 concrete_function = self._concrete_functions[name] 165 # By setting the structured_outputs directly, we can rely on this 166 # function_lib.ConcreteFunction object to perform the output repacking 167 # logic. The only limitation of that logic is that it only works 168 # with output that is convertible to Tensors and the conversion 169 # always happens. For example tf.TensorShape([2, 3]) will be 170 # converted to Tensor representing [2, 3]. 171 original_outputs = coder.decode_proto(proto.output_signature) 172 # The original_outputs here had Tensors converted to TensorSpecs, so 173 # the restored function's structured_outputs field will not be 174 # exactly the same. Fortunately the repacking logic cares only about 175 # the structure. 176 # TODO(vbardiovsky): Should we just replicate the structures, with 177 # Nones instead of real objects? 178 concrete_function._func_graph.structured_outputs = original_outputs # pylint: disable=protected-access 179 concrete_function._func_graph.structured_input_signature = ( # pylint: disable=protected-access 180 coder.decode_proto(proto.canonicalized_input_signature)) 181 182 def _setup_functions_captures(self): 183 """Setup captures and variables in restored functions.""" 184 concrete_functions = sorted(self._proto.concrete_functions.items()) 185 for name, proto in concrete_functions: 186 concrete_function = self._concrete_functions[name] 187 bound_inputs = [ 188 self._get_tensor_from_node(node_id) 189 for node_id in proto.bound_inputs] 190 bound_variables = [ 191 self._nodes[node_id] 192 for node_id in proto.bound_inputs 193 if self._proto.nodes[node_id].WhichOneof("kind") == "variable" 194 ] 195 # TODO(andresp): This is only injecting the captured inputs into the 196 # concrete function, note that we did not modify the FuncGraph 197 # itself. 198 concrete_function._captured_inputs = bound_inputs # pylint: disable=protected-access 199 concrete_function._func_graph.variables = bound_variables # pylint: disable=protected-access 200 if bound_inputs: 201 for bound_input, internal_capture in zip( 202 bound_inputs, concrete_function.inputs[-len(bound_inputs):]): 203 if ds_values.is_distributed_variable(bound_input): 204 concrete_function.graph.capture_distributed_variable( 205 bound_input, internal_capture) 206 else: 207 concrete_function.graph.replace_capture(bound_input, 208 internal_capture) 209 if internal_capture.dtype == dtypes.resource: 210 if resource_variable_ops.is_resource_variable(bound_input): 211 try: 212 handle = bound_input.handle 213 except ValueError: 214 # For mirrored variables we'll copy handle data for components 215 # as they get captured. 216 pass 217 else: 218 custom_gradient.copy_handle_data(handle, internal_capture) 219 else: 220 custom_gradient.copy_handle_data(bound_input, internal_capture) 221 # Setting "captures" first means "capture" won't create a new 222 # placeholder for this input. 223 concrete_function.graph.capture(bound_input) 224 225 def _get_tensor_from_node(self, node_id): 226 """Resolves a node id into a tensor to be captured for a function.""" 227 with ops.init_scope(): 228 obj = self._nodes[node_id] 229 if ds_values.is_distributed_variable(obj): 230 return obj 231 elif resource_variable_ops.is_resource_variable(obj): 232 return obj.handle 233 elif isinstance(obj, tracking.Asset): 234 return obj.asset_path 235 elif tensor_util.is_tensor(obj): 236 return obj 237 elif isinstance(obj, tracking.CapturableResource): 238 # Note: this executes restored functions in the CapturableResource. 239 return obj.resource_handle 240 raise ValueError("Can't convert node %s to tensor" % (type(obj))) 241 242 def _load_nodes(self): 243 """Load all saved objects.""" 244 # Maps from node ids to recreated objects 245 nodes = {} 246 # Maps from node ids to setter functions (same signature as setattr) for 247 # setting dependencies. 248 node_setters = {} 249 250 # Figure out which objects are slot variables. These objects are created 251 # with Optimizer.add_slot rather than _recreate_variable. 252 slot_variable_node_ids = set() 253 for proto in self._proto.nodes: 254 for slot_variable_proto in proto.slot_variables: 255 slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id) 256 257 # Re-create everything except slot variables. 258 for node_id, proto in enumerate(self._proto.nodes): 259 if node_id in slot_variable_node_ids: 260 # Defer recreating slot variables so we can use the public Optimizer 261 # interface. 262 continue 263 node, setter = self._recreate(proto, node_id) 264 nodes[node_id] = node 265 node_setters[node_id] = setter 266 267 # Now that we have created the variables being optimized, we have enough 268 # information to re-create slot variables for them. 269 for node_id, proto in enumerate(self._proto.nodes): 270 optimizer_object = nodes[node_id] 271 for slot_variable_proto in proto.slot_variables: 272 optimized_variable = nodes[ 273 slot_variable_proto.original_variable_node_id] 274 slot_variable = optimizer_object.add_slot( 275 var=optimized_variable, 276 slot_name=slot_variable_proto.slot_name) 277 nodes[slot_variable_proto.slot_variable_node_id] = slot_variable 278 node_setters[slot_variable_proto.slot_variable_node_id] = setattr 279 280 self._nodes = [nodes[node_id] for node_id in range(len(self._proto.nodes))] 281 self._node_setters = node_setters 282 283 @property 284 def _expect_partial_checkpoint(self): 285 """Whether to expect that some objects aren't loaded. 286 287 This should be set to True in subclasses of the Loader class which generate 288 a trackable object with an object graph that is different from the graph 289 in the SavedModel. Setting this property to True suppresses the warnings 290 that are printed out when there are unused parts of the checkpoint or 291 object. 292 293 Returns: 294 boolean 295 """ 296 return False 297 298 def _restore_checkpoint(self): 299 """Load state from checkpoint into the deserialized objects.""" 300 variables_path = saved_model_utils.get_variables_path(self._export_dir) 301 # TODO(andresp): Clean use of private methods of TrackableSaver. 302 # pylint: disable=protected-access 303 saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0))) 304 with ops.device("CPU"): 305 saver._file_prefix_placeholder = constant_op.constant(variables_path) 306 if self._expect_partial_checkpoint: 307 load_status = saver.restore(variables_path).expect_partial() 308 else: 309 load_status = saver.restore(variables_path) 310 load_status.assert_existing_objects_matched() 311 checkpoint = load_status._checkpoint 312 313 # When running in eager mode, the `restore` call above has already run and 314 # restored the state of trackables, call `position.restore_ops()` will 315 # return an empty list as there is nothing left to do. In graph mode, that 316 # will return the list of ops that must run to restore the object on that 317 # position. We have to wire them in the initializers of the objects so that 318 # they get initialized properly when using common practices (e.g. the ones 319 # used by ManagedSession) without further user action. 320 for object_id, obj in dict(checkpoint.object_by_proto_id).items(): 321 position = base.CheckpointPosition(checkpoint=checkpoint, 322 proto_id=object_id) 323 restore_ops = position.restore_ops() 324 if restore_ops: 325 if resource_variable_ops.is_resource_variable(obj): 326 obj._initializer_op = restore_ops 327 else: 328 raise NotImplementedError( 329 ("Missing functionality to restore state of object " 330 "%r from the checkpoint." % obj)) 331 332 def adjust_debug_info_func_names(self, debug_info): 333 """Rewrite func names in the debug info by using the concrete func names.""" 334 output_debug_info = graph_debug_info_pb2.GraphDebugInfo() 335 output_debug_info.files[:] = debug_info.files 336 for key in debug_info.traces: 337 node, func = key.split("@") 338 new_func = "" 339 if func in self._concrete_functions: 340 new_func = self._concrete_functions[func].function_def.signature.name 341 output_debug_info.traces[node + "@" + new_func].CopyFrom( 342 debug_info.traces[key]) 343 return output_debug_info 344 345 def get(self, node_id): 346 return self._nodes[node_id] 347 348 def _recreate(self, proto, node_id): 349 """Creates a Python object from a SavedObject protocol buffer.""" 350 factory = { 351 "user_object": ( 352 lambda: self._recreate_user_object(proto.user_object, node_id)), 353 "asset": lambda: self._recreate_asset(proto.asset), 354 "function": lambda: self._recreate_function(proto.function), 355 "bare_concrete_function": functools.partial( 356 self._recreate_bare_concrete_function, 357 proto.bare_concrete_function), 358 "variable": lambda: self._recreate_variable(proto.variable), 359 "constant": lambda: self._recreate_constant(proto.constant), 360 "resource": lambda: self._recreate_resource(proto.resource), 361 } 362 kind = proto.WhichOneof("kind") 363 if kind not in factory: 364 raise ValueError("Unknown SavedObject type: %r" % kind) 365 return factory[kind]() 366 367 def _recreate_user_object(self, proto, node_id): 368 """Instantiates a SavedUserObject.""" 369 looked_up = revived_types.deserialize(proto) 370 if looked_up is None: 371 return self._recreate_base_user_object(proto, node_id) 372 return looked_up 373 374 def _recreate_base_user_object(self, proto, node_id): 375 del proto, node_id 376 # Note: each user object has its own class. This allows making each one 377 # individually callable by adding a `__call__` method to the classes of 378 # the objects instances that have a `__call__` property. 379 380 class _UserObject(tracking.AutoTrackable): 381 pass 382 383 return _UserObject(), setattr 384 385 def _recreate_asset(self, proto): 386 filename = os.path.join( 387 saved_model_utils.get_assets_dir(self._export_dir), 388 self._asset_file_def[proto.asset_file_def_index].filename) 389 return tracking.Asset(filename), setattr 390 391 def _recreate_function(self, proto): 392 return function_deserialization.recreate_function( 393 proto, self._concrete_functions), setattr 394 395 def _recreate_bare_concrete_function(self, proto): 396 return function_deserialization.setup_bare_concrete_function( 397 proto, self._concrete_functions), setattr 398 399 def _recreate_variable(self, proto): 400 name = proto.name if proto.name else None 401 if name is not None: 402 dbg_name = name 403 else: 404 dbg_name = "<variable loaded from saved model>" 405 synchronization, aggregation, trainable = ( 406 variables.validate_synchronization_aggregation_trainable( 407 proto.synchronization, proto.aggregation, proto.trainable, 408 name=dbg_name)) 409 410 def uninitialized_variable_creator(next_creator, **kwargs): 411 """A variable creator that creates uninitialized variables.""" 412 del next_creator 413 return resource_variable_ops.UninitializedVariable(**kwargs) 414 415 # Create a variable_creator_scope that creates uninitialized variables with 416 # a lower priority such that a potential distributed variable_creator_scope 417 # can take precedence. 418 with ops.get_default_graph()._variable_creator_scope( # pylint: disable=protected-access 419 uninitialized_variable_creator, 420 priority=50): 421 return variables.Variable( 422 shape=proto.shape, 423 dtype=proto.dtype, 424 name=name, 425 trainable=trainable, 426 synchronization=synchronization, 427 aggregation=aggregation), setattr 428 429 def _recreate_constant(self, proto): 430 tensor_proto = self._operation_attributes[proto.operation]["value"].tensor 431 ndarray = tensor_util.MakeNdarray(tensor_proto) 432 if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string: 433 with ops.device("CPU"): 434 imported_constant = constant_op.constant(ndarray) 435 else: 436 imported_constant = constant_op.constant(ndarray) 437 return imported_constant, setattr 438 439 def _recreate_resource(self, proto): 440 return _RestoredResource(device=proto.device), setattr 441 442 443# TODO(b/124205571,b/124092991): Solve destruction of resources. 444class _RestoredResource(tracking.TrackableResource): 445 """Restored SavedResource.""" 446 447 def __init__(self, device=""): 448 super(_RestoredResource, self).__init__(device=device) 449 self._destroy_resource_fn = None 450 451 def _create_resource(self): 452 raise RuntimeError() 453 454 def _initialize(self): 455 raise RuntimeError() 456 457 @property 458 def _destroy_resource(self): 459 return self._destroy_resource_fn 460 461 @_destroy_resource.setter 462 def _destroy_resource(self, destroy_resource_fn): 463 self._resource_deleter = tracking.CapturableResourceDeleter( 464 destroy_resource_fn) 465 self._destroy_resource_fn = destroy_resource_fn 466 467 def _list_functions_for_serialization(self, unused_serialization_cache): 468 # Overwrite this method to avoid the implementation of 469 # base class to re-wrap the polymorphic functions into 470 # another layer of `tf.function`. 471 functions = { 472 "_create_resource": self._create_resource, 473 "_initialize": self._initialize, 474 } 475 if self._destroy_resource: 476 functions.update(_destroy_resource=self._destroy_resource) 477 return functions 478 479 480def _call_attribute(instance, *args, **kwargs): 481 return instance.__call__(*args, **kwargs) 482 483 484@tf_export("saved_model.load", v1=["saved_model.load_v2"]) 485def load(export_dir, tags=None): 486 """Load a SavedModel from `export_dir`. 487 488 Signatures associated with the SavedModel are available as functions: 489 490 ```python 491 imported = tf.saved_model.load(path) 492 f = imported.signatures["serving_default"] 493 print(f(x=tf.constant([[1.]]))) 494 ``` 495 496 Objects exported with `tf.saved_model.save` additionally have trackable 497 objects and functions assigned to attributes: 498 499 ```python 500 exported = tf.train.Checkpoint(v=tf.Variable(3.)) 501 exported.f = tf.function( 502 lambda x: exported.v * x, 503 input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 504 tf.saved_model.save(exported, path) 505 imported = tf.saved_model.load(path) 506 assert 3. == imported.v.numpy() 507 assert 6. == imported.f(x=tf.constant(2.)).numpy() 508 ``` 509 510 _Loading Keras models_ 511 512 Keras models are trackable, so they can be saved to SavedModel. The object 513 returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have 514 `.fit`, `.predict`, etc. methods). A few attributes and functions are still 515 available: `.variables`, `.trainable_variables` and `.__call__`. 516 517 ```python 518 model = tf.keras.Model(...) 519 tf.saved_model.save(model, path) 520 imported = tf.saved_model.load(path) 521 outputs = imported(inputs) 522 ``` 523 524 Use `tf.keras.models.load_model` to restore the Keras model. 525 526 _Importing SavedModels from TensorFlow 1.x_ 527 528 SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat 529 graph instead of `tf.function` objects. These SavedModels will be loaded with 530 the following attributes: 531 532 * `.signatures`: A dictionary mapping signature names to functions. 533 * `.prune(feeds, fetches) `: A method which allows you to extract 534 functions for new subgraphs. This is equivalent to importing the SavedModel 535 and naming feeds and fetches in a Session from TensorFlow 1.x. 536 537 ```python 538 imported = tf.saved_model.load(path_to_v1_saved_model) 539 pruned = imported.prune("x:0", "out:0") 540 pruned(tf.ones([])) 541 ``` 542 543 See `tf.compat.v1.wrap_function` for details. 544 * `.variables`: A list of imported variables. 545 * `.graph`: The whole imported graph. 546 * `.restore(save_path)`: A function that restores variables from a checkpoint 547 saved from `tf.compat.v1.Saver`. 548 549 _Consuming SavedModels asynchronously_ 550 551 When consuming SavedModels asynchronously (the producer is a separate 552 process), the SavedModel directory will appear before all files have been 553 written, and `tf.saved_model.load` will fail if pointed at an incomplete 554 SavedModel. Rather than checking for the directory, check for 555 "saved_model_dir/saved_model.pb". This file is written atomically as the last 556 `tf.saved_model.save` file operation. 557 558 Args: 559 export_dir: The SavedModel directory to load from. 560 tags: A tag or sequence of tags identifying the MetaGraph to load. Optional 561 if the SavedModel contains a single MetaGraph, as for those exported from 562 `tf.saved_model.save`. 563 564 Returns: 565 A trackable object with a `signatures` attribute mapping from signature 566 keys to functions. If the SavedModel was exported by `tf.saved_model.load`, 567 it also points to trackable objects, functions, debug info which it has been 568 saved. 569 570 Raises: 571 ValueError: If `tags` don't match a MetaGraph in the SavedModel. 572 """ 573 return load_internal(export_dir, tags) 574 575 576def load_internal(export_dir, tags=None, loader_cls=Loader): 577 """Loader implementation.""" 578 if tags is not None and not isinstance(tags, set): 579 # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered 580 # sequences for nest.flatten, so we put those through as-is. 581 tags = nest.flatten(tags) 582 saved_model_proto, debug_info = ( 583 loader_impl.parse_saved_model_with_debug_info(export_dir)) 584 585 if (len(saved_model_proto.meta_graphs) == 1 and 586 saved_model_proto.meta_graphs[0].HasField("object_graph_def")): 587 meta_graph_def = saved_model_proto.meta_graphs[0] 588 if (tags is not None 589 and set(tags) != set(meta_graph_def.meta_info_def.tags)): 590 raise ValueError( 591 ("The SavedModel at {} has one MetaGraph with tags {}, but got an " 592 "incompatible argument tags={} to tf.saved_model.load. You may omit " 593 "it, pass 'None', or pass matching tags.") 594 .format(export_dir, meta_graph_def.meta_info_def.tags, tags)) 595 object_graph_proto = meta_graph_def.object_graph_def 596 with ops.init_scope(): 597 loader = loader_cls(object_graph_proto, 598 saved_model_proto, 599 export_dir) 600 root = loader.get(0) 601 if isinstance(loader, Loader): 602 root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info) 603 root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version 604 root.tensorflow_git_version = ( 605 meta_graph_def.meta_info_def.tensorflow_git_version) 606 else: 607 with ops.init_scope(): 608 root = load_v1_in_v2.load(export_dir, tags) 609 root.graph_debug_info = debug_info 610 return root 611