• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Exports a SavedModel from a Trackable Python object."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import os
23
24from tensorflow.core.framework import versions_pb2
25from tensorflow.core.protobuf import meta_graph_pb2
26from tensorflow.core.protobuf import saved_model_pb2
27from tensorflow.core.protobuf import saved_object_graph_pb2
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.eager import function as defun
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import meta_graph
34from tensorflow.python.framework import ops
35from tensorflow.python.lib.io import file_io
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import resource_variable_ops
39from tensorflow.python.saved_model import builder_impl
40from tensorflow.python.saved_model import constants
41from tensorflow.python.saved_model import function_serialization
42from tensorflow.python.saved_model import nested_structure_coder
43from tensorflow.python.saved_model import revived_types
44from tensorflow.python.saved_model import signature_constants
45from tensorflow.python.saved_model import signature_def_utils
46from tensorflow.python.saved_model import signature_serialization
47from tensorflow.python.saved_model import tag_constants
48from tensorflow.python.saved_model import utils_impl
49from tensorflow.python.training.saving import functional_saver
50from tensorflow.python.training.tracking import base
51from tensorflow.python.training.tracking import graph_view
52from tensorflow.python.training.tracking import object_identity
53from tensorflow.python.training.tracking import tracking
54from tensorflow.python.training.tracking import util
55from tensorflow.python.util import compat
56from tensorflow.python.util.tf_export import tf_export
57
# Dtypes which must not be copied into the exported graph as constants:
# resource/variant tensors reference eager-context state rather than plain
# values (see the capture-copying loop in `_SaveableView.map_resources`).
_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))


# A container for an EagerTensor constant which has been copied to the exported
# Graph. `eager_tensor` is the original capture; `graph_tensor` is the copy
# created in the exported graph.
_CapturedConstant = collections.namedtuple(
    "_CapturedConstant", ["eager_tensor", "graph_tensor"])
65
66
class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `add_object` appear in the object graph and any checkpoints
  generated from it, even if they are not dependencies of the node they were
  attached to in the saving program. For example a `.signatures` attribute is
  added to exported SavedModel root objects without modifying the root object
  itself.

  Also tracks functions attached to objects in the graph, through the caching
  `list_functions` method. Enumerating functions only through this method
  ensures that we get a consistent view of functions, even if object attributes
  create new functions every time they are accessed.
  """

  def __init__(self, root):
    super(_AugmentedGraphView, self).__init__(root)
    # Maps parent object -> {name -> dependency} for objects attached with
    # `add_object`.
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()
    # Cache of object -> {name -> function}, filled lazily by
    # `list_functions` so repeated enumeration stays consistent.
    self._functions = object_identity.ObjectIdentityDictionary()

  def add_object(self, parent_node, name_in_parent, subgraph_root):
    """Attach an object to `parent_node`, overriding any existing dependency."""
    extras = self._extra_dependencies.setdefault(parent_node, {})
    extras[name_in_parent] = subgraph_root

  def list_dependencies(self, obj):
    """Overrides a parent method to include `add_object` objects."""
    overrides = self._extra_dependencies.get(obj, {})
    emitted_names = set()
    # Emit the parent's dependencies first, substituting any override with the
    # same name; then emit overrides whose names were not already used.
    for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
      emitted_names.add(name)
      yield base.TrackableReference(name, overrides.get(name, dep))
    for name, dep in overrides.items():
      if name not in emitted_names:
        yield base.TrackableReference(name, dep)

  def list_functions(self, obj):
    """Returns a cached name-to-function mapping attached to `obj`."""
    functions = self._functions.get(obj, None)
    if functions is None:
      # pylint: disable=protected-access
      functions = obj._list_functions_for_serialization()
      self._functions[obj] = functions
    return functions
114
115
class _SaveableView(object):
  """Provides a frozen view over a trackable root.

  This class helps creating a single stable view over an object to save. The
  saving code should access properties and functions via this class and not via
  the original object as there are cases where an object construct their
  trackable attributes and functions dynamically per call and will yield
  different objects if invoked more than once.

  Changes to the graph, for example adding objects, must happen in
  `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is
  constructed. Changes after the `_SaveableView` has been constructed will be
  ignored.
  """

  def __init__(self, checkpoint_view):
    self.checkpoint_view = checkpoint_view
    trackable_objects, node_ids, slot_variables = (
        self.checkpoint_view.objects_ids_and_slot_variables())
    # `nodes` is the flat list of everything that becomes a SavedObjectGraph
    # node; `node_ids` maps each node to its index in `nodes`.
    self.nodes = trackable_objects
    self.node_ids = node_ids
    # Filled in by `map_resources`: eager tensor (resource handle, asset path,
    # or copied constant) -> index of the node which owns it.
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
    self.slot_variables = slot_variables
    self.concrete_functions = []

    # Also add `Function`s as nodes.
    # NOTE: `self.nodes` is appended to while iterating over this snapshot.
    nodes_without_functions = list(self.nodes)
    seen_function_names = set()
    for node in nodes_without_functions:
      for function in checkpoint_view.list_functions(node).values():
        if function not in self.node_ids:
          self.node_ids[function] = len(self.nodes)
          self.nodes.append(function)
        if isinstance(function, def_function.Function):
          # Force listing the concrete functions for the side effects:
          #  - populate the cache for functions that have an input_signature
          #  and have not been called.
          #  - force side effects of creation of concrete functions, e.g. create
          #  variables on first run.
          concrete_functions = (
              function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
        else:
          concrete_functions = [function]
        for concrete_function in concrete_functions:
          # Deduplicate concrete functions by name across all nodes.
          if concrete_function.name not in seen_function_names:
            seen_function_names.add(concrete_function.name)
            self.concrete_functions.append(concrete_function)

  @property
  def root(self):
    # The root object is always the first node returned by
    # `objects_ids_and_slot_variables`.
    return self.nodes[0]

  def fill_object_graph_proto(self, proto):
    """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
    for node_id, node in enumerate(self.nodes):
      assert self.node_ids[node] == node_id
      object_proto = proto.nodes.add()
      object_proto.slot_variables.extend(self.slot_variables.get(node, ()))
      # Functions and captured constants have no dependencies or attached
      # functions of their own; only trackable objects get children.
      if isinstance(node, (def_function.Function, defun.ConcreteFunction,
                           _CapturedConstant)):
        continue
      for child in self.checkpoint_view.list_dependencies(node):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[child.ref]
        child_proto.local_name = child.name
      # Attached functions are recorded as children too, keyed by their
      # attribute name on `node`.
      for local_name, ref_function in (
          self.checkpoint_view.list_functions(node).items()):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[ref_function]
        child_proto.local_name = local_name

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        # Create a fresh resource handle in the exported graph for each
        # tracked resource.
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        # Copy the variable uninitialized (the saver restores values later)
        # and remember both the object and handle mappings.
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    # Remaining eager-tensor captures of functions are copied into the
    # exported graph as constants, unless their dtype cannot be copied
    # (resource/variant) or they were already handled above.
    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (isinstance(capture, ops.EagerTensor)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          copied_tensor = constant_op.constant(capture.numpy())
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          # Both the eager tensor and its wrapper node resolve to the same id.
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
245
246
247def _tensor_dict_to_tensorinfo(tensor_dict):
248  return {key: utils_impl.build_tensor_info_internal(value)
249          for key, value in tensor_dict.items()}
250
251
252def _map_captures_to_created_tensors(
253    original_captures, resource_map):
254  """Maps eager tensors captured by a function to Graph resources for export.
255
256  Args:
257    original_captures: A dictionary mapping from tensors captured by the
258      function to interior placeholders for those tensors (inside the function
259      body).
260    resource_map: A dictionary mapping from resource tensors owned by the eager
261      context to resource tensors in the exported graph.
262
263  Returns:
264    A list of stand-in tensors which belong to the exported graph, corresponding
265    to the function's captures.
266
267  Raises:
268    AssertionError: If the function references a resource which is not part of
269      `resource_map`.
270  """
271  export_captures = []
272  for exterior, interior in original_captures.items():
273    mapped_resource = resource_map.get(exterior, None)
274    if mapped_resource is None:
275      raise AssertionError(
276          ("Tried to export a function which references untracked object {}."
277           "TensorFlow objects (e.g. tf.Variable) captured by functions must "
278           "be tracked by assigning them to an attribute of a tracked object "
279           "or assigned to an attribute of the main object directly.")
280          .format(interior))
281    export_captures.append(mapped_resource)
282  return export_captures
283
284
285def _map_function_arguments_to_created_inputs(
286    function_arguments, signature_key, function_name):
287  """Creates exterior placeholders in the exported graph for function arguments.
288
289  Functions have two types of inputs: tensors captured from the outside (eager)
290  context, and arguments to the function which we expect to receive from the
291  user at each call. `_map_captures_to_created_tensors` replaces
292  captured tensors with stand-ins (typically these are resource dtype tensors
293  associated with variables). `_map_function_inputs_to_created_inputs` runs over
294  every argument, creating a new placeholder for each which will belong to the
295  exported graph rather than the function body.
296
297  Args:
298    function_arguments: A list of argument placeholders in the function body.
299    signature_key: The name of the signature being exported, for error messages.
300    function_name: The name of the function, for error messages.
301
302  Returns:
303    A tuple of (mapped_inputs, exterior_placeholders)
304      mapped_inputs: A list with entries corresponding to `function_arguments`
305        containing all of the inputs of the function gathered from the exported
306        graph (both captured resources and arguments).
307      exterior_argument_placeholders: A dictionary mapping from argument names
308        to placeholders in the exported graph, containing the explicit arguments
309        to the function which a user is expected to provide.
310
311  Raises:
312    ValueError: If argument names are not unique.
313  """
314  # `exterior_argument_placeholders` holds placeholders which are outside the
315  # function body, directly contained in a MetaGraph of the SavedModel. The
316  # function body itself contains nearly identical placeholders used when
317  # running the function, but these exterior placeholders allow Session-based
318  # APIs to call the function using feeds and fetches which name Tensors in the
319  # MetaGraph.
320  exterior_argument_placeholders = {}
321  mapped_inputs = []
322  for placeholder in function_arguments:
323    # `export_captures` contains an exhaustive set of captures, so if we don't
324    # find the input there then we now know we have an argument.
325    user_input_name = compat.as_str_any(
326        placeholder.op.get_attr("_user_specified_name"))
327    # If the internal placeholders for a function have names which were
328    # uniquified by TensorFlow, then a single user-specified argument name
329    # must refer to multiple Tensors. The resulting signatures would be
330    # confusing to call. Instead, we throw an exception telling the user to
331    # specify explicit names.
332    if user_input_name != placeholder.op.name:
333      # This should be unreachable, since concrete functions may not be
334      # generated with non-unique argument names.
335      raise ValueError(
336          ("Got non-flat/non-unique argument names for SavedModel "
337           "signature '{}': more than one argument to '{}' was named '{}'. "
338           "Signatures have one Tensor per named input, so to have "
339           "predictable names Python functions used to generate these "
340           "signatures should avoid *args and Tensors in nested "
341           "structures unless unique names are specified for each. Use "
342           "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
343           "input.")
344          .format(signature_key, compat.as_str_any(function_name),
345                  user_input_name))
346    arg_placeholder = array_ops.placeholder(
347        shape=placeholder.shape,
348        dtype=placeholder.dtype,
349        name="{}_{}".format(signature_key, user_input_name))
350    exterior_argument_placeholders[user_input_name] = arg_placeholder
351    mapped_inputs.append(arg_placeholder)
352  return mapped_inputs, exterior_argument_placeholders
353
354
def _call_function_with_mapped_captures(function, args, resource_map):
  """Calls `function` in the exported graph, using mapped resource captures."""
  stand_in_captures = _map_captures_to_created_tensors(
      function.graph.captures, resource_map)
  all_inputs = args + stand_in_captures
  # Bypass the normal calling convention: the captured resource tensors we
  # feed here were not part of the original function definition.
  # pylint: disable=protected-access
  flat_outputs = function._inference_function.call(
      context.context(), all_inputs)
  return function._build_call_outputs(flat_outputs)
367
368
369def _generate_signatures(signature_functions, resource_map):
370  """Validates and calls `signature_functions` in the default graph.
371
372  Args:
373    signature_functions: A dictionary mapping string keys to concrete TensorFlow
374      functions (e.g. from `signature_serialization.canonicalize_signatures`)
375      which will be used to generate SignatureDefs.
376    resource_map: A dictionary mapping from resource tensors in the eager
377      context to resource tensors in the Graph being exported. This dictionary
378      is used to re-bind resources captured by functions to tensors which will
379      exist in the SavedModel.
380
381  Returns:
382    Each function in the `signature_functions` dictionary is called with
383    placeholder Tensors, generating a function call operation and output
384    Tensors. The placeholder Tensors, the function call operation, and the
385    output Tensors from the function call are part of the default Graph.
386
387    This function then returns a dictionary with the same structure as
388    `signature_functions`, with the concrete functions replaced by SignatureDefs
389    implicitly containing information about how to call each function from a
390    TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference
391    the generated placeholders and Tensor outputs by name.
392
393    The caller is expected to include the default Graph set while calling this
394    function as a MetaGraph in a SavedModel, including the returned
395    SignatureDefs as part of that MetaGraph.
396  """
397  signatures = {}
398  for signature_key, function in sorted(signature_functions.items()):
399    if function.graph.captures:
400      argument_inputs = function.graph.inputs[:-len(function.graph.captures)]
401    else:
402      argument_inputs = function.graph.inputs
403    mapped_inputs, exterior_argument_placeholders = (
404        _map_function_arguments_to_created_inputs(
405            argument_inputs, signature_key, function.name))
406    outputs = _call_function_with_mapped_captures(
407        function, mapped_inputs, resource_map)
408    signatures[signature_key] = signature_def_utils.build_signature_def(
409        _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
410        _tensor_dict_to_tensorinfo(outputs),
411        method_name=signature_constants.PREDICT_METHOD_NAME)
412  return signatures
413
414
415def _trace_resource_initializers(accessible_objects):
416  """Create concrete functions from `TrackableResource` objects."""
417  resource_initializers = []
418
419  def _wrap_initializer(obj):
420    obj._initialize()  # pylint: disable=protected-access
421    return constant_op.constant(1.)  # Dummy control output
422
423  def _wrap_obj_initializer(obj):
424    return lambda: _wrap_initializer(obj)
425
426  for obj in accessible_objects:
427    if isinstance(obj, tracking.TrackableResource):
428      resource_initializers.append(def_function.function(
429          _wrap_obj_initializer(obj),
430          # All inputs are captures.
431          input_signature=[]).get_concrete_function())
432  return resource_initializers
433
434
# Accumulates state about external asset files gathered during export; filled
# in by `_process_asset` and consumed when building the MetaGraph.
_AssetInfo = collections.namedtuple(
    "_AssetInfo", [
        # List of AssetFileDef protocol buffers
        "asset_defs",
        # Map from asset variable resource Tensors to their init ops
        "asset_initializers_by_resource",
        # Map from base asset filenames to full paths
        "asset_filename_map",
        # Map from TrackableAsset to index of corresponding AssetFileDef
        "asset_index"])
445
446
def _process_asset(trackable_asset, asset_info, resource_map):
  """Add `trackable_asset` to `asset_info` and `resource_map`.

  Args:
    trackable_asset: A `TrackableAsset` whose file should be referenced by the
      exported SavedModel.
    asset_info: The `_AssetInfo` being accumulated for this export; mutated.
    resource_map: Mapping from eager tensors to exported-graph tensors; the
      asset path tensor is mapped to a newly created asset variable.
  """
  original_variable = trackable_asset.asset_path
  # Read the asset's current path eagerly; this function runs while building
  # the (non-eager) exported graph.
  with context.eager_mode():
    original_path = original_variable.numpy()
  path = builder_impl.get_asset_filename_to_add(
      asset_filepath=original_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  # The asset's path is fed through this placeholder at load time and held in
  # a variable so functions can capture it.
  asset_path_initializer = array_ops.placeholder(
      shape=original_variable.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  asset_variable = resource_variable_ops.ResourceVariable(
      asset_path_initializer)
  asset_info.asset_filename_map[path] = original_path
  asset_def = meta_graph_pb2.AssetFileDef()
  asset_def.filename = path
  asset_def.tensor_info.name = asset_path_initializer.name
  asset_info.asset_defs.append(asset_def)
  asset_info.asset_initializers_by_resource[original_variable] = (
      asset_variable.initializer)
  # The AssetFileDef just appended is the last one, hence len - 1.
  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
  resource_map[original_variable] = asset_variable
473
474
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.

  Returns:
    A tuple of (asset_info, exported_graph):
      asset_info: An _AssetInfo, which contains information to help creating
        the SavedModel.
      exported_graph: The `ops.Graph` the MetaGraph was built from.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      asset_dependencies = []
      # An initializer which captures an asset path must run after that
      # asset's own initializer, so collect those as control dependencies.
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    # `init_op` runs after every resource and asset initializer.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  saver = functional_saver.Saver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph
543
544
def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`."""
  # The SavedObjectGraph mirrors the TrackableObjectGraph proto in the
  # checkpoint and eventually lands in the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  coder = nested_structure_coder.StructureCoder()
  for function in saveable_view.concrete_functions:
    serialized_function = function_serialization.serialize_concrete_function(
        function, saveable_view.captured_tensor_node_ids, coder)
    if serialized_function is not None:
      proto.concrete_functions[function.name].CopyFrom(serialized_function)

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index)
  return proto
563
564
def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto.

  Args:
    obj: A node from the `_SaveableView` (a trackable object, a function, or a
      `_CapturedConstant`).
    proto: The `SavedObject` proto to fill in.
    asset_file_def_index: Map from `TrackableAsset` objects to the index of
      their AssetFileDef in the MetaGraph.
  """
  # NOTE: the isinstance chain is order-sensitive; more specific cases must
  # stay ahead of the generic user-object fallback at the end.
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.TrackableResource):
    proto.resource.SetInParent()
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
594
595
596@tf_export("saved_model.save",
597           v1=["saved_model.save", "saved_model.experimental.save"])
598def save(obj, export_dir, signatures=None):
599  # pylint: disable=line-too-long
600  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
601
602  Example usage:
603
604  ```python
605  class Adder(tf.train.Checkpoint):
606
607    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
608    def add(self, x):
609      return x + x + 1.
610
611  to_export = Adder()
612  tf.saved_model.save(to_export, '/tmp/adder')
613  ```
614
615  The resulting SavedModel is then servable with an input named "x", its value
616  having any shape and dtype float32.
617
618  The optional `signatures` argument controls which methods in `obj` will be
619  available to programs which consume `SavedModel`s, for example serving
620  APIs. Python functions may be decorated with
621  `@tf.function(input_signature=...)` and passed as signatures directly, or
622  lazily with a call to `get_concrete_function` on the method decorated with
623  `@tf.function`.
624
625  If the `signatures` argument is omitted, `obj` will be searched for
626  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
627  method will be used as the default signature for the SavedModel. This behavior
628  is expected to change in the future, when a corresponding
629  `tf.saved_model.load` symbol is added. At that point signatures will be
630  completely optional, and any `@tf.function` attached to `obj` or its
631  dependencies will be exported for use with `load`.
632
633  When invoking a signature in an exported SavedModel, `Tensor` arguments are
634  identified by name. These names will come from the Python function's argument
635  names by default. They may be overridden by specifying a `name=...` argument
636  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
637  multiple `Tensor`s are passed through a single argument to the Python
638  function.
639
640  The outputs of functions used as `signatures` must either be flat lists, in
641  which case outputs will be numbered, or a dictionary mapping string keys to
642  `Tensor`, in which case the keys will be used to name outputs.
643
644  Signatures are available in objects returned by `tf.saved_model.load` as a
645  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
646  on an object with a custom `.signatures` attribute will raise an exception.
647
648  Since `tf.keras.Model` objects are also Trackable, this function can be
649  used to export Keras models. For example, exporting with a signature
650  specified:
651
652  ```python
653  class Model(tf.keras.Model):
654
655    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
656    def serve(self, serialized):
657      ...
658
659  m = Model()
660  tf.saved_model.save(m, '/tmp/saved_model/')
661  ```
662
663  Exporting from a function without a fixed signature:
664
665  ```python
666  class Model(tf.keras.Model):
667
668    @tf.function
669    def call(self, x):
670      ...
671
672  m = Model()
673  tf.saved_model.save(
674      m, '/tmp/saved_model/',
675      signatures=m.call.get_concrete_function(
676          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
677  ```
678
679  `tf.keras.Model` instances constructed from inputs and outputs already have a
680  signature and so do not require a `@tf.function` decorator or a `signatures`
681  argument. If neither are specified, the model's forward pass is exported.
682
683  ```python
684  x = input_layer.Input((4,), name="x")
685  y = core.Dense(5, name="out")(x)
686  model = training.Model(x, y)
687  tf.saved_model.save(model, '/tmp/saved_model/')
688  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
689  # with shape [None, 5]
690  ```
691
692  Variables must be tracked by assigning them to an attribute of a tracked
693  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
694  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
695  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
696  uses, and an exported `Checkpoint` object may be restored as a training
697  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
698  "variables/" subdirectory. Currently variables are the only stateful objects
699  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
700  in the future.
701
702  `tf.function` does not hard-code device annotations from outside the function
703  body, instead using the calling context's device. This means for example that
704  exporting a model which runs on a GPU and serving it on a CPU will generally
705  work, with some exceptions. `tf.device` annotations inside the body of the
706  function will be hard-coded in the exported model; this type of annotation is
707  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with
708  device-specific layouts, may cause issues. Currently a `DistributionStrategy`
709  is another exception: active distribution strategies will cause device
710  placements to be hard-coded in a function. Exporting a single-device
711  computation and importing under a `DistributionStrategy` is not currently
712  supported, but may be in the future.
713
714  SavedModels exported with `tf.saved_model.save` [strip default-valued
715  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
716  automatically, which removes one source of incompatibilities when the consumer
717  of a SavedModel is running an older TensorFlow version than the
718  producer. There are however other sources of incompatibilities which are not
719  handled automatically, such as when the exported model contains operations
720  which the consumer does not have definitions for.
721
722  The current implementation of `tf.saved_model.save` targets serving use-cases,
723  but omits information which will be necessary for the planned future
724  implementation of `tf.saved_model.load`. Exported models using the current
725  `save` implementation, and other existing SavedModels, will not be compatible
726  with `tf.saved_model.load` when it is implemented. Further, `save` will in the
727  future attempt to export `@tf.function`-decorated methods which it does not
728  currently inspect, so some objects which are exportable today will raise
729  exceptions on export in the future (e.g. due to complex/non-serializable
730  default arguments). Such backwards-incompatible API changes are expected only
731  prior to the TensorFlow 2.0 release.
732
733  Args:
734    obj: A trackable object to export.
735    export_dir: A directory in which to write the SavedModel.
736    signatures: Optional, either a `tf.function` with an input signature
737      specified or the result of `f.get_concrete_function` on a
738      `@tf.function`-decorated function `f`, in which case `f` will be used to
739      generate a signature for the SavedModel under the default serving
740      signature key. `signatures` may also be a dictionary, in which case it
741      maps from signature keys to either `tf.function` instances with input
742      signatures or concrete functions. The keys of such a dictionary may be
743      arbitrary strings, but will typically be from the
744      `tf.saved_model.signature_constants` module.
745
746  Raises:
747    ValueError: If `obj` is not trackable.
748
749  @compatibility(eager)
750  Not supported when graph building. From TensorFlow 1.x,
751  `tf.enable_eager_execution()` must run first. May not be called from within a
752  function body.
753  @end_compatibility
754  """
  # Fail fast when not executing eagerly. Distinguish the two possible causes
  # by peeking at the outermost context: if `init_scope` is eager, we were
  # called from inside a traced `@tf.function`; otherwise we are in TF 1.x
  # graph building without eager execution enabled.
  if not context.executing_eagerly():
    with ops.init_scope():
      if context.executing_eagerly():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
      else:
        raise AssertionError(
            "tf.saved_model.save is not supported when graph building. "
            "tf.enable_eager_execution() must run first when calling it from "
            "TensorFlow 1.x.")
  # pylint: enable=line-too-long
  # Only Trackable objects can be exported (see docstring Raises section).
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))

  # View over `obj`'s dependency graph used for both checkpointing and
  # signature discovery. (_AugmentedGraphView is defined earlier in this
  # file — presumably it augments the plain object graph; not visible here.)
  checkpoint_graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    # No explicit signatures: search `obj` for a default exportable
    # `@tf.function`, as described in the docstring.
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  # Normalize whatever the caller passed (tf.function, concrete function, or
  # dict of either) into a canonical signature mapping, then attach that map
  # to the object graph under the reserved `.signatures` attribute so it is
  # saved (and restorable) alongside the object's other dependencies.
  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view)
  saveable_view = _SaveableView(checkpoint_graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()
  # Saver responsible for writing the variables/ checkpoint below.
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  # Populates `meta_graph_def` from the frozen view; returns asset bookkeeping
  # (filename map + index) and the temporary graph built for export.
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, saveable_view, signatures)
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)
  # So far we've just been generating protocol buffers with no I/O. Now we write
  # the checkpoint, copy assets into the assets directory, and write out the
  # SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  # <export_dir>/saved_model.pb; as_bytes keeps the join consistent if the
  # caller passed a bytes path.
  path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  # Embed the serialized object graph (with asset indices resolved) into the
  # meta graph before serializing the whole SavedModel to disk.
  object_graph_proto = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
  file_io.write_string_to_file(path, saved_model.SerializeToString())
  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)
819