# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import sys

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics.
_LOAD_V2_LABEL = "load_v2"


def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed."""
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")

  assert_op = control_flow_ops.Assert(
      array_ops.placeholder_with_default(False, shape=()),
      [error_message])

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)


class _WrapperFunction(function.ConcreteFunction):
80  """A class wraps a concrete function to handle different distributed contexts.
81
82  The reason for wrapping a concrete function is because the _captured_inputs
83  fields used for in-replica context and cross-replica context are different.
84  When `load()` is called from within a tf.distribute.strategy scope, the
85  captured inputs are distributed variables. When using these distributed
86  variables during calling the function, we need different approaches when it is
87  in-replica and when it is not in-replica. When it is in replica, naturally we
88  should use the corresponding component of the distributed variable; when it is
89  not in-replica, calling the function should mean that it is constructing a
90  graph that is not actually going to be used. A typical use case is when
91  constructing a functional model. In this case, return a placeholder with a
92  control dependency to ensure that is never accessed.
93  """

  def __init__(self, concrete_function):
    # Shallow copy the concrete_function
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_handle(x):
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_unused_handle(x):
      return _unused_handle() if distribute_utils.is_distributed_variable(x)   \
          else x

    if (ds_context.get_replica_context() is not None or
        values_util.is_saving_non_distributed()):
      # If we're in the replica context or are saving a non-distributed version
      # of the model, we resolve the captured variables to the corresponding
      # resource handles. In both situations we call var.handle, but it behaves
      # differently. In the replica context, var.handle resolves to the replica
      # local variable handle if the variable is replicated. When saving a
      # non-distributed version of the model, var.handle resolves to the
      # primary variable handle, since we only save one copy of a replicated
      # variable.
      captured_inputs = list(map(get_handle, captured_inputs))
    else:  # cross-replica context
      captured_inputs = list(map(get_unused_handle, captured_inputs))
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
                                                    cancellation_manager)


class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, save_options, filters):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library, wrapper_function=_WrapperFunction))
    self._checkpoint_options = ckpt_options
    self._save_options = save_options

    # Stores user-defined node_filters argument.
    self._node_filters = filters
    # Stores map of string paths to integers.
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs to
      # the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked dependencies.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get a list of all integer node ids to load, or None if all nodes should be
    # loaded. This list includes ids of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    self._load_all()

    if not save_options.experimental_skip_checkpoint:
      self._restore_checkpoint()
      for node in self._nodes:
        if isinstance(node, tracking.CapturableResource):
          init_op = node._initialize()  # pylint: disable=protected-access
          if not context.executing_eagerly():
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _convert_node_paths_to_ints(self):
    """Maps all string node paths in node_filters to the int node ids."""
    if self._node_filters is None:
      return None
    path_to_int = {}
    for node_id in self._node_filters:
      int_node_id = None
      if isinstance(node_id, str):
        node_path = node_id.split(".")
        if node_path[0] != "root":
          raise ValueError(
              "When passing string identifiers to node_filters, the first name"
              f" must be root. Received {node_path[0]}.")
        int_node_id = 0
        for n, name in enumerate(node_path[1:]):
          int_node_id = self._find_node_child(
              int_node_id, name, ".".join(node_path[:n+2]))
        path_to_int[node_id] = int_node_id
      else:
        raise TypeError("Elements in node_filters must be strings.")
    return path_to_int

  def _retrieve_all_filtered_nodes(self):
    """Traverses through the object graph to get the IDs of all nodes to load.

    As a side-effect, if node_filters is a dictionary that contains already-
    created objects, then the dependencies tracked by those objects will be
    added to node_filters.

    Returns:
      List of all nodes to load, or None if all nodes should be loaded.

    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to nodes_to_load. "
              f"Object at {node_path} is expected to be a checkpointable (i.e. "
              "'trackable') TensorFlow object (e.g. tf.Variable, tf.Module or "
              "Keras layer).")
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if node already tracks the child reference, in which case add the
        # child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            setter = lambda *args: None

            self._loaded_nodes[reference.node_id] = (child_object, setter)

        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes

  def _find_node_child(self, node_id, child_name, path):
    for reference in self._proto.nodes[node_id].children:
      if reference.local_name == child_name:
        return reference.node_id
    raise ValueError(f"Unable to find node {path}.")

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()

    self._create_saveable_object_factories()

  def _create_saveable_object_factories(self):
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      node._self_saveable_object_factories = {}  # pylint: disable=protected-access
      for name, saveable_object_proto in proto.saveable_objects.items():
        node._self_saveable_object_factories[name] = (  # pylint: disable=protected-access
            saveable_object_util.restored_saved_object_factory(
                self.get(saveable_object_proto.save_function),
                self.get(saveable_object_proto.restore_function)))

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)

  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children."""
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__`, add a class method
      # that allows `obj()` syntax to work. This is done per-instance to
      # allow `callable` to be used to find out if an object is callable.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_functions_structures(self):
    """Sets up the structured inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works with
      # outputs that are convertible to Tensors, and the conversion always
      # happens. For example, tf.TensorShape([2, 3]) will be converted to a
      # Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure, and the unpacking logic cares only about structure
      # and types.
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))
      concrete_function._initialize_function_spec()  # pylint: disable=protected-access

  def _setup_functions_captures(self):
    """Sets up captures and variables in restored functions."""
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id, name)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function, note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
      if bound_inputs:
        for bound_input, internal_capture in zip(
            bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
          if distribute_utils.is_distributed_variable(bound_input):
            concrete_function.graph.capture_distributed_variable(
                bound_input, internal_capture)
          else:
            concrete_function.graph.replace_capture(bound_input,
                                                    internal_capture)
            if internal_capture.dtype == dtypes.resource:
              if resource_variable_ops.is_resource_variable(bound_input):
                try:
                  handle = bound_input.handle
                except ValueError:
                  # For mirrored variables we'll copy handle data for components
                  # as they get captured.
                  pass
                else:
                  handle_data_util.copy_handle_data(handle, internal_capture)
              else:
                handle_data_util.copy_handle_data(bound_input, internal_capture)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(bound_input)

  def _get_tensor_from_node(self, node_id, fn_name):
    """Resolves a node id into a tensor to be captured for a function."""
    if self._node_filters is not None and self._nodes[node_id] is None:
      raise ValueError(
          f"Error when processing nodes_to_load. Function '{fn_name}' requires "
          "inputs/variables that are not loaded when nodes_to_load="
          f"{self._node_filters}.")

    with ops.init_scope():
      obj = self._nodes[node_id]
      if distribute_utils.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.Asset):
        return obj.asset_path
      elif tensor_util.is_tf_type(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError(f"Cannot convert node {obj} to tensor.")

  def _initialize_loaded_nodes(self):
    nodes = {}
    node_setters = {}
    for node_id, (node, setter) in self._loaded_nodes.items():
      nodes[node_id] = node
      node_setters[node_id] = setter
    return nodes, node_setters

  def _iter_all_nodes(self):
    if self._filtered_nodes is None:
      return enumerate(self._proto.nodes)
    else:
      return [(node_id, self._proto.nodes[node_id])
              for node_id in self._filtered_nodes]

  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting dependencies.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()

    for _, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in self._iter_all_nodes():
      if node_id in slot_variable_node_ids or nodes.get(node_id) is not None:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto, node_id)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in self._iter_all_nodes():
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    # If root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._save_options.allow_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
      load_status.assert_nontrivial_match()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
      load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    if not context.executing_eagerly():
      # When running in eager mode, the `restore` call above has already run
      # and restored the state of trackables, and calling
      # `position.restore_ops()` would re-run the restore. In graph mode, that
      # call returns a cached list of ops that must run to restore the object
      # at that position. We have to wire them into the initializers of the
      # objects so that they get initialized properly when using common
      # practices (e.g. the ones used by ManagedSession) without further user
      # action.
      for object_id, obj in dict(checkpoint.object_by_proto_id).items():
        position = base.CheckpointPosition(checkpoint=checkpoint,
                                           proto_id=object_id)
        restore_ops = position.restore_ops()
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif isinstance(obj, lookup_ops.LookupInterface):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
          else:
            raise NotImplementedError(
                f"Unable to restore state of object {obj} from the checkpoint.")

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrite func names in the debug info by using the concrete func names."""
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info

  def get(self, node_id):
    if isinstance(node_id, str):
      node_id = self._node_path_to_id[node_id]
    return self._nodes[node_id]

  def _recreate(self, proto, node_id):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
        "captured_tensor": functools.partial(
            self._get_tensor_from_fn, proto.captured_tensor),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError(f"Unknown SavedObject type: {kind}. Expected one of "
                       f"{list(factory.keys())}.")
    return factory[kind]()

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto=None, node_id=None):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the classes of
    # the object instances that have a `__call__` property.

    class _UserObject(tracking.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_asset(self, proto):
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    asset = tracking.Asset(filename)
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
    return asset, setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables with
    # a lower priority such that a potential distributed variable_creator_scope
    # can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      return variables.Variable(
          shape=proto.shape,
          dtype=proto.dtype,
          name=name,
          trainable=trainable,
          synchronization=synchronization,
          aggregation=aggregation), setattr

  def _recreate_constant(self, proto):
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant, setattr

  def _get_tensor_from_fn(self, proto):
    outer_graph = self._concrete_functions[proto.concrete_function].graph
    captured_tensor = outer_graph.get_tensor_by_name(proto.name)
    return captured_tensor, setattr

  def _recreate_resource(self, proto):
    return _RestoredResource(device=proto.device), _setattr_and_track


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super(_RestoredResource, self).__init__(device=device)

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  # _list_functions_for_serialization expects Function objects, but unlike
  # _create_resource and _initialize, _destroy_resource didn't always exist in
  # older TrackableResource implementations, so this default stub must be a
  # Function.
  @def_function.function
  def _destroy_resource(self):
    raise RuntimeError()

  def _list_functions_for_serialization(self, unused_serialization_cache):
    # Overwrite this method to avoid having the base class implementation
    # re-wrap the polymorphic functions into another layer of `tf.function`.
    functions = {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
        "_destroy_resource": self._destroy_resource,
    }
    return functions


def _call_attribute(instance, *args, **kwargs):
  return instance.__call__(*args, **kwargs)


def _setattr_and_track(obj, name, value):
  """Sets new attribute and marks it as a dependency if Trackable."""
  setattr(obj, name, value)
  if isinstance(value, base.Trackable):
    obj._track_trackable(value, name)  # pylint:disable=protected-access


@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the names of the attributes connecting the
  objects.

  *Example 1*

  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True

  *Example 2*
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>>
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> new_variable.numpy()
  5.0

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental, so use it with care.

  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  return load_internal(export_dir, tags, options, filters=filters)


@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. The object
  returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
  `.fit`, `.predict`, etc. methods). A few attributes and functions are still
  available: `.variables`, `.trainable_variables` and `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded with
  the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches)`: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the SavedModel
    and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a checkpoint
    saved from `tf.compat.v1.Saver`, as sketched below.
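
    For example, a minimal sketch (the checkpoint path `/tmp/v1_ckpt` is
    illustrative and assumes a checkpoint written by `tf.compat.v1.train.Saver`
    for the same graph):

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    imported.restore("/tmp/v1_ckpt")  # Overwrites the imported variable values.
    ```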

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the last
  `tf.saved_model.save` file operation.
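
  For example, a minimal sketch of a consumer process that waits for the
  atomically written `saved_model.pb` before loading (the directory path and
  polling interval are illustrative):

  ```python
  import os
  import time

  saved_model_dir = "/tmp/exported_model"
  while not os.path.exists(os.path.join(saved_model_dir, "saved_model.pb")):
    time.sleep(1)  # The producer has not finished writing yet.
  imported = tf.saved_model.load(saved_model_dir)
  ```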

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to the trackable objects, functions, and debug info with
    which it was saved.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  result = load_internal(export_dir, tags, options)["root"]
  return result


def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems that require a
    # byteswap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.")
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                       " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
928