• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers for working with signatures in tf.saved_model.save."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl import logging
22
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import function as defun
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.ops import resource_variable_ops
28from tensorflow.python.saved_model import function_serialization
29from tensorflow.python.saved_model import revived_types
30from tensorflow.python.saved_model import signature_constants
31from tensorflow.python.training.tracking import base
32from tensorflow.python.util import compat
33from tensorflow.python.util import nest
34from tensorflow.python.util.compat import collections_abc
35
36
# Name of the function on the root object that, when present, is used as the
# default signature by `find_function_to_export`.
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"

# Reserved attribute used to store SavedModel signatures on objects that come
# from `tf.saved_model.load` (see `validate_saveable_view`).
SIGNATURE_ATTRIBUTE_NAME = "signatures"

# Upper bound on how many "input names will be renamed" warnings are logged
# by `canonicalize_signatures`.
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5
41
42
def _get_signature(function):
  """Resolves `function` to a concrete function usable as a signature.

  A `tf.function` (either `Function` flavor) that carries an input signature
  is first turned into its single concrete function.

  Args:
    function: A candidate object, typically a tf.function or a concrete
      function.

  Returns:
    A `ConcreteFunction`, or None if `function` cannot be resolved to one
    (for example, a tf.function without an input signature).
  """
  is_tf_function = isinstance(function, (defun.Function, def_function.Function))
  if is_tf_function and function.input_signature is not None:
    function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  return function if isinstance(function, defun.ConcreteFunction) else None
50
51
def _valid_signature(concrete_function):
  """Returns whether concrete function can be converted to a signature.

  A function qualifies when it has at least one output, accepts no
  tf.Variable inputs, and its structured outputs can be normalized into a
  dictionary of string keys to Tensors.
  """
  # There is no way to run an Operation with no outputs as a 1.x-style
  # SignatureDef, so output-less functions never make sense as signatures.
  if not concrete_function.outputs:
    return False
  try:
    _validate_inputs(concrete_function)
    # Only the validation side effect matters here; the function name and
    # signature key are placeholders that appear only in error messages.
    _normalize_outputs(
        concrete_function.structured_outputs, "unused", "unused")
  except ValueError:
    is_valid = False
  else:
    is_valid = True
  return is_valid
65
66
def _validate_inputs(concrete_function):
  """Raises error if input type is tf.Variable.

  Args:
    concrete_function: The function whose structured input signature is
      inspected.

  Raises:
    ValueError: If any flattened input spec is a `VariableSpec`.
  """
  flat_input_specs = nest.flatten(concrete_function.structured_input_signature)
  for input_spec in flat_input_specs:
    if isinstance(input_spec, resource_variable_ops.VariableSpec):
      raise ValueError("Functions that expect tf.Variable inputs cannot be "
                       "exported as signatures.")
74
75
def _get_signature_name_changes(concrete_function):
  """Checks for user-specified signature input names that are normalized.

  Pairs each `input_arg` of the function's `FunctionDef` signature with the
  corresponding graph input, and compares the arg name against the graph
  input op's `_user_specified_name` attribute.

  Args:
    concrete_function: The concrete function to inspect.

  Returns:
    A dict mapping each user-given name to the (different) name that appears
    in the function signature. Inputs without a user-specified name, or whose
    names are unchanged, are omitted.
  """
  renamed = {}
  arg_defs = concrete_function.function_def.signature.input_arg
  placeholders = concrete_function.graph.inputs
  for arg_def, placeholder in zip(arg_defs, placeholders):
    try:
      user_name = compat.as_str(
          placeholder.op.get_attr("_user_specified_name"))
    except ValueError:
      # This input has no user-specified name; nothing to compare.
      continue
    if arg_def.name != user_name:
      renamed[user_name] = arg_def.name
  return renamed
92
93
def find_function_to_export(saveable_view):
  """Function to export, None if no suitable function was found.

  Used when the user did not specify signatures. The root object's functions
  are checked in this order:

  1. If a function is registered under `DEFAULT_SIGNATURE_ATTR`, it is
     returned as-is.
  2. Otherwise, if exactly one function resolves to a concrete function that
     is a valid signature, that concrete function is returned.

  Args:
    saveable_view: The view of the object graph being saved. Its
      `list_functions(root)` result is treated as a dict of functions.

  Returns:
    The function to export as the default signature, or None.
  """
  functions = saveable_view.list_functions(saveable_view.root)
  signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
  if signature is not None:
    return signature

  # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a
  # user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  possible_signatures = []
  for function in functions.values():
    concrete = _get_signature(function)
    if concrete is not None and _valid_signature(concrete):
      possible_signatures.append(concrete)
  if len(possible_signatures) == 1:
    # Each candidate is already a concrete function that passed
    # `_valid_signature` above, so resolving and validating it again would
    # only repeat the same work.
    return possible_signatures[0]
  return None
117
118
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Each signature function is resolved to a concrete function and passed
  through `function_serialization.wrap_cached_variables`. It is then
  re-wrapped so that it takes keyword `TensorSpec` inputs and returns a
  dictionary of Tensors, which matches the format of 1.x-style signatures.

  Args:
    signatures: None, a single function, or a mapping from signature key to
      function. Each function must be a `tf.function` with an input signature
      or a concrete function. A single non-mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A tuple `(concrete_signatures, wrapped_functions)`:
      concrete_signatures: dict from signature key to the final
        `ConcreteFunction` with dict outputs.
      wrapped_functions: dict from each original concrete function to the
        function returned for it by
        `function_serialization.wrap_cached_variables`. A function reused
        under several keys is wrapped only once.

  Raises:
    ValueError: If a value cannot be resolved to a concrete function, if a
      function expects tf.Variable inputs, or (while tracing the wrapper) if
      its outputs cannot be normalized into a string-keyed dict of Tensors.
  """
  if signatures is None:
    return {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  num_normalized_signatures_counter = 0
  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Reuse the wrapped version if this function was already seen under
    # another signature key.
    signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    wrapped_functions[original_function] = signature_function
    _validate_inputs(signature_function)

    # Warn (a bounded number of times) about input names that will be
    # rewritten in the SavedModel.
    if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      signature_name_changes = _get_signature_name_changes(signature_function)
      if signature_name_changes:
        num_normalized_signatures_counter += 1
        logging.warning(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(signature_name_changes.keys()),
            ", ".join(signature_name_changes.values()))

    signature_wrapper = _build_signature_wrapper(
        signature_function, signature_key)
    tensor_spec_signature = _build_tensor_spec_signature(signature_function)
    final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
  return concrete_signatures, wrapped_functions


def _build_signature_wrapper(signature_function, signature_key):
  """Returns a tf.function that calls `signature_function` with dict outputs.

  The outputs of `signature_function` are passed through `_normalize_outputs`,
  so the wrapper returns a dictionary of Tensors. `signature_function` and
  `signature_key` are bound as parameters of this helper rather than captured
  from a loop variable, so the wrapper always refers to the function it was
  built for.

  Args:
    signature_function: The (wrapped) concrete function to call.
    signature_key: The signature key, used only in error messages.

  Returns:
    A `def_function.Function` accepting keyword arguments only.
  """

  @def_function.function
  def signature_wrapper(**kwargs):
    structured_outputs = signature_function(**kwargs)
    return _normalize_outputs(
        structured_outputs, signature_function.name, signature_key)

  return signature_wrapper


def _build_tensor_spec_signature(signature_function):
  """Maps each keyword argument of `signature_function` to a `TensorSpec`.

  Keywords come from `signature_function._arg_keywords` and are paired, in
  order, with the TensorSpecs of its structured input signature. When no
  structured input signature is available, they are paired with its input
  Tensors instead. Each resulting spec is named after its keyword.

  Args:
    signature_function: The concrete function whose inputs are described.

  Returns:
    A dict from keyword (as `str`) to `TensorSpec`.
  """
  if signature_function.structured_input_signature is not None:
    # The structured input signature may contain other non-tensor arguments.
    inputs = [
        x for x in nest.flatten(signature_function.structured_input_signature,
                                expand_composites=True)
        if isinstance(x, tensor_spec.TensorSpec)]
  else:
    # Structured input signature isn't always defined for some functions.
    inputs = signature_function.inputs

  tensor_spec_signature = {}
  for keyword, inp in zip(
      signature_function._arg_keywords,  # pylint: disable=protected-access
      inputs):
    keyword = compat.as_str(keyword)
    if isinstance(inp, tensor_spec.TensorSpec):
      spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
    else:
      spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
    tensor_spec_signature[keyword] = spec
  return tensor_spec_signature
194
195
def _normalize_outputs(outputs, function_name, signature_key):
  """Construct an output dictionary from unnormalized function outputs.

  A mapping is used as-is. A non-string sequence is turned into
  `{"output_0": ..., "output_1": ...}`. Any other value, including a single
  `str`/`bytes`, is treated as a single output `{"output_0": value}`.

  Args:
    outputs: The structured outputs of a function.
    function_name: Name of the function, used only in error messages.
    signature_key: The signature key, used only in error messages.

  Returns:
    A mapping from string keys to Tensors.

  Raises:
    ValueError: If a key is not a string or a value is not a Tensor.
  """
  # Convert `outputs` to a dictionary (if it's not one already).
  if not isinstance(outputs, collections_abc.Mapping):
    # Text and bytes are Sequences too, but they represent one (invalid)
    # output rather than a list of outputs. Don't split them into characters,
    # which would produce misleading errors (or none at all for "").
    if (not isinstance(outputs, collections_abc.Sequence) or
        isinstance(outputs, compat.bytes_or_text_types)):
      outputs = [outputs]
    outputs = {("output_{}".format(output_index)): output
               for output_index, output
               in enumerate(outputs)}

  # Check that the keys of `outputs` are strings and the values are Tensors.
  for key, value in outputs.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          ("Got a dictionary with a non-string key {!r} in the output of the "
           "function {} used to generate the SavedModel signature {!r}.")
          .format(key, compat.as_str_any(function_name), signature_key))
    if not isinstance(value, ops.Tensor):
      raise ValueError(
          ("Got a non-Tensor value {!r} for key {!r} in the output of the "
           "function {} used to generate the SavedModel signature {!r}. "
           "Outputs for functions used as signatures must be a single Tensor, "
           "a sequence of Tensors, or a dictionary from string to Tensor.")
          .format(value, key, compat.as_str_any(function_name), signature_key))

  return outputs
222
223
224# _SignatureMap is immutable to ensure that users do not expect changes to be
225# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
226# only way to create a _SignatureMap and there is no way to modify it. So we can
227# safely ignore/overwrite ".signatures" attributes attached to objects being
228# saved if they contain a _SignatureMap. A ".signatures" attribute containing
229# any other type (e.g. a regular dict) will raise an exception asking the user
230# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A collection of SavedModel signatures.

  A read-only mapping from signature key to function. The only way to add
  entries is the private `_add_signature` hook, which restore uses.
  """

  def __init__(self):
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Adds a signature to the _SignatureMap.

    This object would ideally be immutable. Because restore is streaming, a
    private API is still needed to add new signatures to an existing map.
    """
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures.keys())

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return "_SignatureMap(%s)" % (self._signatures,)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Returns the entries whose values are tf.functions or concrete functions."""
    function_types = (def_function.Function, defun.ConcreteFunction)
    serializable = {}
    for key, value in self.items():
      if isinstance(value, function_types):
        serializable[key] = value
    return serializable
260
261
# Register `_SignatureMap` as a revivable type under the identifier
# "signature_map". `setter` points at `_add_signature`, the private hook that
# lets restore add entries to an otherwise read-only map.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[
        revived_types.VersionedTypeRegistration(
            # Standard dependencies are enough to reconstruct the trackable
            # items in dictionaries, so no extra information needs saving.
            object_factory=lambda proto: _SignatureMap(),
            setter=_SignatureMap._add_signature,  # pylint: disable=protected-access
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
        )
    ])
274
275
def create_signature_map(signatures):
  """Creates an object containing `signatures`.

  Args:
    signatures: A mapping from signature key to concrete function, as
      produced by `canonicalize_signatures`.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # Every signature that came from canonicalize_signatures satisfies these
    # invariants. They are re-checked here as a sanity check while saving,
    # because crashing on load (e.g. in _add_signature) would be more
    # problematic if future export changes ever violated them.
    assert isinstance(func, defun.ConcreteFunction)
    assert isinstance(func.structured_outputs, collections_abc.Mapping)
    # pylint: disable=protected-access
    # A single-input signature allows one positional argument; any other
    # arity must be keyword-only.
    expected_positional_args = 1 if len(func._arg_keywords) == 1 else 0
    assert expected_positional_args == func._num_positional_args
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map
294
295
def validate_saveable_view(saveable_view):
  """Performs signature-related sanity checks on `saveable_view`.

  Only the first dependency of the root named `SIGNATURE_ATTRIBUTE_NAME` is
  examined. It is accepted if it is a `_SignatureMap`, which may be safely
  overwritten on save.

  Args:
    saveable_view: The view of the object graph being saved.

  Raises:
    ValueError: If the root has a `signatures` dependency that is not a
      `_SignatureMap`.
  """
  root = saveable_view.root
  for dependency_name, dependency in saveable_view.list_dependencies(root):
    if dependency_name != SIGNATURE_ATTRIBUTE_NAME:
      continue
    if isinstance(dependency, _SignatureMap):
      return
    raise ValueError(
        ("Exporting an object {} which has an attribute named "
         "'{signatures}'. This is a reserved attribute used to store "
         "SavedModel signatures in objects which come from "
         "`tf.saved_model.load`. Delete this attribute "
         "(e.g. 'del obj.{signatures}') before saving if this shadowing is "
         "acceptable.").format(
             saveable_view.root,
             signatures=SIGNATURE_ATTRIBUTE_NAME))
312