• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers for working with signatures in tf.saved_model.save."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training.tracking import base
from tensorflow.python.util import compat
from tensorflow.python.util import nest

try:
  # Python 3.3+: the collection ABCs live in `collections.abc`; the aliases
  # directly on `collections` are deprecated and were removed in Python 3.10.
  from collections import abc as collections_abc
except ImportError:
  # Python 2: the ABCs are only available from `collections` itself.
  collections_abc = collections
32
33
# Name of the attribute on the root object whose function, when present, is
# exported as the default signature (see find_function_to_export).
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"

# Reserved attribute name under which objects returned by tf.saved_model.load
# expose their SavedModel signatures.
SIGNATURE_ATTRIBUTE_NAME = "signatures"
36
37
def _get_signature(function):
  """Returns a ConcreteFunction for `function`, or None if none is available.

  A `tf.function` (either `defun.Function` or `def_function.Function`) that
  declares an `input_signature` is converted with `get_concrete_function()`.
  Any other value is kept as-is. The result is returned only if it is a
  `defun.ConcreteFunction`; otherwise None is returned.
  """
  has_fixed_signature = (
      isinstance(function, (defun.Function, def_function.Function))
      and function.input_signature is not None)
  candidate = (
      function.get_concrete_function() if has_fixed_signature else function)
  return candidate if isinstance(candidate, defun.ConcreteFunction) else None
45
46
def _valid_signature(concrete_function):
  """Returns whether concrete function can be converted to a signature.

  The function must have at least one output, and `_normalize_outputs` must
  accept its structured outputs without raising ValueError.
  """
  # An output-less function cannot be run as a 1.x-style SignatureDef, so it
  # is rejected before its structured outputs are even inspected.
  if concrete_function.outputs:
    try:
      _normalize_outputs(
          concrete_function.structured_outputs, "unused", "unused")
    except ValueError:
      return False
    else:
      return True
  return False
59
60
def find_function_to_export(saveable_view):
  """Function to export, None if no suitable function was found.

  Used when the user did not specify signatures. A function attached to the
  root object under `DEFAULT_SIGNATURE_ATTR` is returned as-is. Otherwise, if
  exactly one of the root's functions converts to a valid concrete signature,
  that concrete function is returned.

  Args:
    saveable_view: The view being saved. Its `list_functions` method is
      queried for the functions attached to `saveable_view.root`.

  Returns:
    The function to export as the default signature, or None.
  """
  functions = saveable_view.list_functions(saveable_view.root)
  signature = functions.get(DEFAULT_SIGNATURE_ATTR)
  if signature is not None:
    return signature

  # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a
  # user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  candidates = (_get_signature(function) for function in functions.values())
  possible_signatures = [
      concrete for concrete in candidates
      if concrete is not None and _valid_signature(concrete)]
  if len(possible_signatures) == 1:
    # Candidates were already converted and validated above, so the sole one
    # can be returned directly.
    return possible_signatures[0]
  return None
84
85
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a mapping from signature keys to
      functions. Each function must be a `tf.function` with an input signature
      or a concrete function. A non-mapping value is exported under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A dict from signature key to a ConcreteFunction that returns a dictionary
    of Tensors, matching the format of 1.x-style signatures. The dict is empty
    when `signatures` is None.

  Raises:
    ValueError: If a value cannot be turned into a concrete function. Errors
      from `_normalize_outputs` (e.g. nested outputs) may also surface while
      the wrapper is traced.
  """
  if signatures is None:
    return {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))
    concrete_signatures[signature_key] = _wrap_signature_function(
        signature_function, signature_key)
  return concrete_signatures


def _wrap_signature_function(signature_function, signature_key):
  """Re-wraps a concrete function so that it returns a dictionary of Tensors.

  Args:
    signature_function: The ConcreteFunction to export as a signature.
    signature_key: The key the signature is exported under. It is passed to
      `_normalize_outputs` for its error messages.

  Returns:
    A ConcreteFunction whose inputs are keyed by the keyword names of
    `signature_function` and whose outputs are normalized into a dictionary.
  """
  # The wrapper closes over this helper's own arguments rather than over the
  # caller's loop variables, so each wrapper is bound to its own signature.
  @def_function.function
  def signature_wrapper(**kwargs):
    structured_outputs = signature_function(**kwargs)
    return _normalize_outputs(
        structured_outputs, signature_function.name, signature_key)

  # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
  # always match keyword arguments.
  # pylint: disable=protected-access
  tensor_spec_signature = {}
  for keyword, tensor in zip(signature_function._arg_keywords,
                             signature_function.inputs):
    keyword = compat.as_str(keyword)
    tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
        tensor, name=keyword)
  final_concrete = signature_wrapper.get_concrete_function(
      **tensor_spec_signature)
  # If there is only one input to the signature, a very common case, then
  # ordering is unambiguous and we can let people pass a positional argument.
  # Since SignatureDefs are unordered (protobuf "map") multiple arguments means
  # we need to be keyword-only.
  if len(final_concrete._arg_keywords) == 1:
    final_concrete._num_positional_args = 1
  else:
    final_concrete._num_positional_args = 0
  # pylint: enable=protected-access
  return final_concrete
134
135
def _is_flat(sequence):
  """Returns whether `sequence` matches the structure of its flattened form.

  `nest.assert_same_structure` compares `sequence` against
  `nest.flatten(sequence)`. A ValueError or TypeError from that comparison
  means the sequence is not flat.
  """
  flattened = nest.flatten(sequence)
  try:
    nest.assert_same_structure(flattened, sequence)
  except (ValueError, TypeError):
    return False
  return True
145
146
def _normalize_outputs(outputs, function_name, signature_key):
  """Construct an output dictionary from unnormalized function outputs.

  Args:
    outputs: Structured outputs of a signature function. This is a mapping
      whose values must all be Tensors, a flat sequence, or a single value.
    function_name: Name of the function, used only in error messages. It is
      converted with `compat.as_str_any`, so bytes and str both work.
    signature_key: Signature key, used only in the non-flat error message.

  Returns:
    `outputs` itself when it is a mapping. Otherwise a dict mapping
    "output_0", "output_1", ... to the outputs in order; a single
    non-sequence value is treated as a one-element sequence.

  Raises:
    ValueError: If a mapping contains a non-Tensor value, or if non-mapping
      outputs are nested.
  """
  if isinstance(outputs, collections_abc.Mapping):
    for key, value in outputs.items():
      if not isinstance(value, ops.Tensor):
        raise ValueError(
            ("Got a dictionary containing non-Tensor value {} for key {} "
             "in the output of the function {} used to generate a SavedModel "
             "signature. Dictionaries outputs for functions used as signatures "
             "should have one Tensor output per string key.")
            .format(value, key, compat.as_str_any(function_name)))
    return outputs

  original_outputs = outputs
  if not isinstance(outputs, collections_abc.Sequence):
    outputs = [outputs]
  if not _is_flat(outputs):
    # Decode the name the same way as the mapping branch, so that bytes names
    # are not rendered as b'...' in the message.
    raise ValueError(
        ("Got non-flat outputs '{}' from '{}' for SavedModel "
         "signature '{}'. Signatures have one Tensor per output, so "
         "to have predictable names Python functions used to generate "
         "these signatures should avoid outputting Tensors in nested "
         "structures.")
        .format(original_outputs, compat.as_str_any(function_name),
                signature_key))
  return {"output_{}".format(output_index): output
          for output_index, output in enumerate(outputs)}
174
175
# _SignatureMap is immutable to ensure that users do not expect changes to be
# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
# only way to create a _SignatureMap and there is no way to modify it. So we can
# safely ignore/overwrite ".signatures" attributes attached to objects being
# saved if they contain a _SignatureMap. A ".signatures" attribute containing
# any other type (e.g. a regular dict) will raise an exception asking the user
# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A collection of SavedModel signatures.

  A read-only `Mapping` from signature names to the functions implementing
  them. Entries can only be added through the private `_add_signature` method,
  which exists for restoring saved objects.
  """

  def __init__(self):
    # Backing storage: signature name -> function.
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Adds a signature to the _SignatureMap."""
    # Ideally this object would be immutable, but restore is streaming so we do
    # need a private API for adding new signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures)

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return "_SignatureMap({})".format(self._signatures)

  def _list_functions_for_serialization(self):
    """Returns the entries whose values are tf.functions or ConcreteFunctions.

    Entries holding values of any other type are omitted.
    """
    return {
        key: value for key, value in self.items()
        if isinstance(value, (def_function.Function, defun.ConcreteFunction))
    }
212
213
# Registers _SignatureMap with the revived-types machinery under the name
# "signature_map". On restore, `object_factory` builds an empty _SignatureMap
# and entries are added through `setter` (presumably once per restored
# signature -- NOTE(review): confirm against revived_types).
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so no extra information is saved.
            object_factory=lambda proto: _SignatureMap(),
            setter=_SignatureMap._add_signature,  # pylint: disable=protected-access
        ),
    ],
)
226
227
def create_signature_map(signatures):
  """Creates an object containing `signatures`.

  Args:
    signatures: A mapping from signature name to ConcreteFunction, as produced
      by `canonicalize_signatures`.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.

  Raises:
    AssertionError: If an entry violates the invariants that
      `canonicalize_signatures` establishes. These are `assert` checks, so they
      are skipped when Python runs with -O.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # This is true of any signature that came from canonicalize_signatures.
    # Here as a sanity check on saving; crashing on load (e.g. in
    # _add_signature) would be more problematic in case future export changes
    # violated these assertions.
    assert isinstance(func, defun.ConcreteFunction)
    assert isinstance(func.structured_outputs, collections_abc.Mapping)
    # pylint: disable=protected-access
    # Single-input signatures accept one positional argument; all others are
    # keyword-only (mirrors canonicalize_signatures).
    expected_positional_args = 1 if len(func._arg_keywords) == 1 else 0
    assert expected_positional_args == func._num_positional_args
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map
246
247
def validate_saveable_view(saveable_view):
  """Performs signature-related sanity checks on `saveable_view`.

  Only the first dependency of the root object named
  `SIGNATURE_ATTRIBUTE_NAME` is inspected. It is accepted if it is a
  `_SignatureMap` and rejected otherwise.

  Args:
    saveable_view: The view being saved. Its `list_dependencies` method
      supplies the (name, dependency) pairs of `saveable_view.root`.

  Raises:
    ValueError: If that dependency exists and is not a `_SignatureMap`.
  """
  dependencies = saveable_view.list_dependencies(saveable_view.root)
  for dependency_name, dependency in dependencies:
    if dependency_name != SIGNATURE_ATTRIBUTE_NAME:
      continue
    if isinstance(dependency, _SignatureMap):
      break
    raise ValueError(
        ("Exporting an object {} which has an attribute named "
         "'{signatures}'. This is a reserved attribute used to store "
         "SavedModel signatures in objects which come from "
         "`tf.saved_model.load`. Delete this attribute "
         "(e.g. 'del obj.{signatures}') before saving if this shadowing is "
         "acceptable.").format(
             saveable_view.root,
             signatures=SIGNATURE_ATTRIBUTE_NAME))
264