• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Import a trackable object from a SavedModel."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import os
23
24from tensorflow.core.protobuf import graph_debug_info_pb2
25from tensorflow.python.distribute import distribution_strategy_context as ds_context
26from tensorflow.python.distribute import values as ds_values
27from tensorflow.python.eager import context
28from tensorflow.python.eager import function
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import custom_gradient
36from tensorflow.python.ops import resource_variable_ops
37from tensorflow.python.ops import variables
38from tensorflow.python.saved_model import function_deserialization
39from tensorflow.python.saved_model import load_v1_in_v2
40from tensorflow.python.saved_model import loader_impl
41from tensorflow.python.saved_model import nested_structure_coder
42from tensorflow.python.saved_model import revived_types
43from tensorflow.python.saved_model import utils_impl as saved_model_utils
44from tensorflow.python.training.tracking import base
45from tensorflow.python.training.tracking import graph_view
46from tensorflow.python.training.tracking import tracking
47from tensorflow.python.training.tracking import util
48from tensorflow.python.util import nest
49from tensorflow.python.util.tf_export import tf_export
50
51
def _unused_handle():
  """Builds a resource placeholder that must never actually be evaluated.

  The placeholder is created under a control dependency on an `Assert` whose
  condition is a boolean placeholder defaulting to False. Executing the
  returned handle therefore trips the assertion with an explanatory message.

  Returns:
    A `dtypes.resource` placeholder tensor.
  """
  always_false = array_ops.placeholder_with_default(False, shape=())
  guard_op = control_flow_ops.Assert(
      always_false,
      ["Trying to access a placeholder that is not supposed to be "
       "executed. This means you are executing a graph generated "
       "from the cross-replica context in an in-replica context."])

  with ops.control_dependencies([guard_op]):
    return array_ops.placeholder(dtype=dtypes.resource)
64
65
class _WrapperFunction(function.ConcreteFunction):
  """A ConcreteFunction whose distributed captures resolve per call context.

  When `load()` runs inside a `tf.distribute` strategy scope, the captured
  inputs of restored functions are distributed variables, and how they should
  be fed depends on where the function is called:

  * In a replica context, each distributed variable is replaced by its
    `handle`, i.e. the component belonging to the current replica.
  * In the cross-replica context, calling the function only makes sense for
    building a graph that is never run (e.g. constructing a functional model).
    Each distributed variable is then replaced by a placeholder guarded by a
    failing assertion, so any attempt to execute it is reported.

  Captured inputs that are not distributed variables are passed through.
  """

  def __init__(self, concrete_function):
    # Shallow copy: this wrapper shares every instance attribute of the
    # function it wraps instead of re-running ConcreteFunction's constructor.
    self.__dict__.update(concrete_function.__dict__)

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):
    replica_mode = ds_context.get_replica_context() is not None

    def resolve(captured):
      """Maps one captured input to what should be fed for this context."""
      if not ds_values.is_distributed_variable(captured):
        return captured
      if replica_mode:
        return captured.handle
      return _unused_handle()

    resolved_inputs = [resolve(captured) for captured in captured_inputs]
    return super(_WrapperFunction, self)._call_flat(
        args, resolved_inputs, cancellation_manager)
101
102
class Loader(object):
  """Helper class to load an object-based SavedModel.

  All loading happens in the constructor. It deserializes the function
  library, recreates every node of the object graph, and wires up edges and
  function captures. It then restores saved state from the SavedModel's
  variables checkpoint. Recreated objects are available through
  `get(node_id)`; node 0 is the root object.
  """

  def __init__(self, object_graph_proto, saved_model_proto, export_dir):
    """Recreates the objects described by `object_graph_proto`.

    Args:
      object_graph_proto: The object graph proto of the SavedModel's
        MetaGraph; its `nodes` and `concrete_functions` fields are read.
      saved_model_proto: The parsed SavedModel proto. Only its first
        MetaGraph is used.
      export_dir: The SavedModel directory, used to locate the assets and the
        variables checkpoint.
    """
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    # Node attributes keyed by operation name; used to recreate constants.
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library))

    for name, concrete_function in self._concrete_functions.items():
      # Wrap all the concrete functions so that they are capable of dealing
      # with both in-replica and cross-replica cases. Only existing keys are
      # reassigned, so the dict's size does not change while iterating.
      self._concrete_functions[name] = _WrapperFunction(concrete_function)

    self._load_all()
    self._restore_checkpoint()

    # Initialize restored resources. Outside of eager execution the
    # initializer op is registered in the TABLE_INITIALIZERS collection for
    # graph-mode initialization.
    for node in self._nodes:
      if isinstance(node, tracking.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in enumerate(self._proto.nodes):
      self._add_object_graph_edges(object_proto, node_id)

  def _add_object_graph_edges(self, proto, node_id):
    """Attaches each child listed in `proto.children` to node `node_id`.

    Args:
      proto: The node's proto; its `children` references are read.
      node_id: Index of the already-recreated parent object.
    """
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__`, add a class method
      # that allows `obj()` syntax to work. The method is installed on
      # `type(obj)`; plain restored user objects each get their own class
      # (see `_recreate_base_user_object`), so for them this only affects
      # `obj`, and `callable` can still tell callable objects apart.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_functions_structures(self):
    """Setup structure for inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works
      # with output that is convertible to Tensors and the conversion
      # always happens. For example tf.TensorShape([2, 3]) will be
      # converted to Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure.
      # TODO(vbardiovsky): Should we just replicate the structures, with
      # Nones instead of real objects?
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))

  def _setup_functions_captures(self):
    """Setup captures and variables in restored functions.

    For every concrete function, the nodes listed in its `bound_inputs` are
    resolved to capturable values. They are injected as the function's
    captured inputs and mapped onto the function's trailing internal
    placeholders.
    """
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function, note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
      if bound_inputs:
        # Captures occupy the last len(bound_inputs) inputs of the function.
        for bound_input, internal_capture in zip(
            bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
          if ds_values.is_distributed_variable(bound_input):
            concrete_function.graph.capture_distributed_variable(
                bound_input, internal_capture)
          else:
            concrete_function.graph.replace_capture(bound_input,
                                                    internal_capture)
            if internal_capture.dtype == dtypes.resource:
              if resource_variable_ops.is_resource_variable(bound_input):
                try:
                  handle = bound_input.handle
                except ValueError:
                  # For mirrored variables we'll copy handle data for components
                  # as they get captured.
                  pass
                else:
                  custom_gradient.copy_handle_data(handle, internal_capture)
              else:
                custom_gradient.copy_handle_data(bound_input, internal_capture)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(bound_input)

  def _get_tensor_from_node(self, node_id):
    """Resolves a node id into a tensor to be captured for a function.

    Distributed variables are returned as-is; resource variables yield their
    handle, assets their path, and capturable resources their resource
    handle. Tensors are returned unchanged.

    Raises:
      ValueError: If the node is none of the supported kinds.
    """
    with ops.init_scope():
      obj = self._nodes[node_id]
      if ds_values.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.Asset):
        return obj.asset_path
      elif tensor_util.is_tensor(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))

  def _load_nodes(self):
    """Load all saved objects.

    Populates `self._nodes` (recreated objects, indexed by node id) and
    `self._node_setters` (node id -> setattr-like function used to attach
    children). Slot variables are created last through their owner's
    `add_slot`, once the variables they belong to exist. This does not
    depend on the relative order of slot-variable and optimizer node ids.
    """
    # Maps from node ids to recreated objects
    nodes = {}
    # Maps from node ids to setter functions (same signature as setattr) for
    # setting dependencies.
    node_setters = {}

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()
    for proto in self._proto.nodes:
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in enumerate(self._proto.nodes):
      if node_id in slot_variable_node_ids:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto, node_id)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in enumerate(self._proto.nodes):
      if not proto.slot_variables:
        # Only nodes that own slot variables need to be looked up. Skipping
        # the others avoids a KeyError for slot-variable nodes that precede
        # their owner and so have not been created yet.
        continue
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    self._nodes = [nodes[node_id] for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters

  @property
  def _expect_partial_checkpoint(self):
    """Whether to expect that some objects aren't loaded.

    This should be set to True in subclasses of the Loader class which generate
    a trackable object with an object graph that is different from the graph
    in the SavedModel. Setting this property to True suppresses the warnings
    that are printed out when there are unused parts of the checkpoint or
    object.

    Returns:
      boolean
    """
    return False

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects.

    Raises:
      NotImplementedError: If a non-resource-variable object still needs
        restore ops after the restore call.
    """
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._expect_partial_checkpoint:
      load_status = saver.restore(variables_path).expect_partial()
    else:
      load_status = saver.restore(variables_path)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, so calling `position.restore_ops()`
    # returns an empty list as there is nothing left to do. In graph mode, it
    # returns the list of ops that must run to restore the object at that
    # position. We have to wire them into the initializers of the objects so
    # that they get initialized properly when using common practices (e.g. the
    # ones used by ManagedSession) without further user action.
    for object_id, obj in dict(checkpoint.object_by_proto_id).items():
      position = base.CheckpointPosition(checkpoint=checkpoint,
                                         proto_id=object_id)
      restore_ops = position.restore_ops()
      if restore_ops:
        if resource_variable_ops.is_resource_variable(obj):
          obj._initializer_op = restore_ops
        else:
          # Wrap `obj` in a tuple so that tuple-like objects cannot break the
          # %-formatting.
          raise NotImplementedError(
              ("Missing functionality to restore state of object "
               "%r from the checkpoint." % (obj,)))

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrite func names in the debug info by using the concrete func names.

    Each trace key is expected to have the form "<node>@<func>". `<func>` is
    replaced by the signature name of the loaded concrete function registered
    under that name, or by "" when no such function was loaded.

    Args:
      debug_info: A debug info proto with `files` and `traces` fields.

    Returns:
      A new `GraphDebugInfo` proto with the same files and rewritten trace
      keys.

    Raises:
      ValueError: If a trace key does not contain exactly one "@".
    """
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info

  def get(self, node_id):
    """Returns the recreated object for `node_id` (0 is the root)."""
    return self._nodes[node_id]

  def _recreate(self, proto, node_id):
    """Creates a Python object from a SavedObject protocol buffer.

    Returns:
      A tuple `(obj, setter)` where `setter` has the signature of `setattr`
      and is used to attach children to `obj`.

    Raises:
      ValueError: If the proto's `kind` is not supported.
    """
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError("Unknown SavedObject type: %r" % kind)
    return factory[kind]()

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject, preferring a registered revived type."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto, node_id):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the classes of
    # the objects instances that have a `__call__` property.

    class _UserObject(tracking.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_asset(self, proto):
    """Recreates an Asset pointing into the SavedModel's assets directory."""
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    return tracking.Asset(filename), setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    """Creates an uninitialized variable matching the saved properties.

    Values are restored later from the checkpoint.
    """
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables with
    # a lower priority such that a potential distributed variable_creator_scope
    # can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      return variables.Variable(
          shape=proto.shape,
          dtype=proto.dtype,
          name=name,
          trainable=trainable,
          synchronization=synchronization,
          aggregation=aggregation), setattr

  def _recreate_constant(self, proto):
    """Rebuilds a constant from the `value` attr of its graph operation."""
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    # String constants are pinned to the CPU.
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant, setattr

  def _recreate_resource(self, proto):
    return _RestoredResource(device=proto.device), setattr
441
442
443# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """A TrackableResource recreated from a saved `resource` node.

  `_create_resource` and `_initialize` only raise here. Presumably the
  restored functions attached to the object during loading replace them.
  Assigning `_destroy_resource` also installs a matching resource deleter.
  """

  def __init__(self, device=""):
    super(_RestoredResource, self).__init__(device=device)
    self._destroy_resource_fn = None

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  @property
  def _destroy_resource(self):
    return self._destroy_resource_fn

  @_destroy_resource.setter
  def _destroy_resource(self, destroy_resource_fn):
    deleter = tracking.CapturableResourceDeleter(destroy_resource_fn)
    self._resource_deleter = deleter
    self._destroy_resource_fn = destroy_resource_fn

  def _list_functions_for_serialization(self, unused_serialization_cache):
    # Overridden so the base class does not re-wrap these already-polymorphic
    # functions in another layer of `tf.function`.
    functions = dict(
        _create_resource=self._create_resource,
        _initialize=self._initialize)
    destroy_fn = self._destroy_resource
    if destroy_fn:
      functions["_destroy_resource"] = destroy_fn
    return functions
478
479
480def _call_attribute(instance, *args, **kwargs):
481  return instance.__call__(*args, **kwargs)
482
483
@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. The object
  returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
  `.fit`, `.predict`, etc. methods). A few attributes and functions are still
  available: `.variables`, `.trainable_variables` and `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded with
  the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches)`: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the SavedModel
    and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a checkpoint
    saved from `tf.compat.v1.Saver`.

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the last
  `tf.saved_model.save` file operation.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions and a `graph_debug_info` attribute. If the SavedModel
    was exported by `tf.saved_model.save`, the object also exposes the
    trackable objects and functions it was saved with.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  return load_internal(export_dir, tags)
574
575
def load_internal(export_dir, tags=None, loader_cls=Loader):
  """Loader implementation shared by `load` and subclass-based loaders.

  A SavedModel with exactly one MetaGraph that carries an object graph is
  loaded with `loader_cls`. Anything else goes through the TF1 compatibility
  loader (`load_v1_in_v2`).

  Args:
    export_dir: The SavedModel directory to load from.
    tags: Optional tag or collection of tags selecting the MetaGraph.
    loader_cls: The `Loader` class (or subclass) used for object-based
      SavedModels.

  Returns:
    The root trackable object, with `graph_debug_info` attached.

  Raises:
    ValueError: If `tags` are given for an object-based SavedModel and do not
      match its MetaGraph's tags.
  """
  # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
  # sequences for nest.flatten, so we put those through as-is.
  if tags is not None and not isinstance(tags, set):
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  meta_graphs = saved_model_proto.meta_graphs
  is_object_based = (len(meta_graphs) == 1 and
                     meta_graphs[0].HasField("object_graph_def"))
  if not is_object_based:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info
    return root

  meta_graph_def = meta_graphs[0]
  meta_info = meta_graph_def.meta_info_def
  if tags is not None and set(tags) != set(meta_info.tags):
    raise ValueError(
        ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
         "incompatible argument tags={} to tf.saved_model.load. You may omit "
         "it, pass 'None', or pass matching tags.")
        .format(export_dir, meta_info.tags, tags))

  with ops.init_scope():
    loader = loader_cls(meta_graph_def.object_graph_def,
                        saved_model_proto,
                        export_dir)
    root = loader.get(0)
    if isinstance(loader, Loader):
      root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
  root.tensorflow_version = meta_info.tensorflow_version
  root.tensorflow_git_version = meta_info.tensorflow_git_version
  return root
611