1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers for working with signatures in tf.saved_model.save.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from absl import logging 22 23from tensorflow.python.eager import def_function 24from tensorflow.python.eager import function as defun 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_spec 27from tensorflow.python.ops import resource_variable_ops 28from tensorflow.python.saved_model import function_serialization 29from tensorflow.python.saved_model import revived_types 30from tensorflow.python.saved_model import signature_constants 31from tensorflow.python.training.tracking import base 32from tensorflow.python.util import compat 33from tensorflow.python.util import nest 34from tensorflow.python.util.compat import collections_abc 35 36 37DEFAULT_SIGNATURE_ATTR = "_default_save_signature" 38SIGNATURE_ATTRIBUTE_NAME = "signatures" 39# Max number of warnings to show if signature contains normalized input names. 
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5


def _get_signature(function):
  """Returns a `ConcreteFunction` usable as a signature, or None.

  A `tf.function` (`defun.Function` or `def_function.Function`) is converted
  to its concrete function only when it has an `input_signature`. A
  `ConcreteFunction` is returned unchanged. Anything else yields None.

  Args:
    function: The candidate object to convert.

  Returns:
    A `defun.ConcreteFunction`, or None if `function` cannot be converted.
  """
  if (isinstance(function, (defun.Function, def_function.Function)) and
      function.input_signature is not None):
    function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  if not isinstance(function, defun.ConcreteFunction):
    return None
  return function


def _valid_signature(concrete_function):
  """Returns whether concrete function can be converted to a signature."""
  if not concrete_function.outputs:
    # Functions without outputs don't make sense as signatures. We just don't
    # have any way to run an Operation with no outputs as a SignatureDef in the
    # 1.x style.
    return False
  try:
    _validate_inputs(concrete_function)
    _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
  except ValueError:
    return False
  return True


def _validate_inputs(concrete_function):
  """Raises error if input type is tf.Variable.

  Args:
    concrete_function: The `ConcreteFunction` whose flattened
      `structured_input_signature` is inspected.

  Raises:
    ValueError: If any flattened input is a `VariableSpec`.
  """
  if any(isinstance(inp, resource_variable_ops.VariableSpec)
         for inp in nest.flatten(
             concrete_function.structured_input_signature)):
    raise ValueError(("Functions that expect tf.Variable inputs cannot be "
                      "exported as signatures."))


def _get_signature_name_changes(concrete_function):
  """Checks for user-specified signature input names that are normalized.

  Args:
    concrete_function: The `ConcreteFunction` whose inputs are compared.

  Returns:
    A dict mapping each user-specified input name to the name used in the
    function's `FunctionDef` signature. Only inputs whose two names differ are
    included. Inputs without a `_user_specified_name` attribute are skipped.
  """
  name_changes = {}
  for signature_input_name, graph_input in zip(
      concrete_function.function_def.signature.input_arg,
      concrete_function.graph.inputs):
    try:
      user_specified_name = compat.as_str(
          graph_input.op.get_attr("_user_specified_name"))
    except ValueError:
      # Signature input does not have a user-specified name.
      continue
    if signature_input_name.name != user_specified_name:
      name_changes[user_specified_name] = signature_input_name.name
  return name_changes


def find_function_to_export(saveable_view):
  """Function to export, None if no suitable function was found.

  If the root object has a function stored under `DEFAULT_SIGNATURE_ATTR`, that
  function is returned. Otherwise, the root's functions are converted with
  `_get_signature` and filtered with `_valid_signature`. The result is
  returned only if exactly one valid candidate remains.

  Args:
    saveable_view: The view whose `root` object's functions are inspected.

  Returns:
    The function to export, or None.
  """
  # If the user did not specify signatures, check the root object for a function
  # that can be made into a signature.
  functions = saveable_view.list_functions(saveable_view.root)
  signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
  if signature is not None:
    return signature

  # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a
  # user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  possible_signatures = []
  for function in functions.values():
    concrete = _get_signature(function)
    if concrete is not None and _valid_signature(concrete):
      possible_signatures.append(concrete)
  if len(possible_signatures) == 1:
    # The candidate is already a validated ConcreteFunction; converting and
    # validating it a second time would be redundant.
    return possible_signatures[0]
  return None


def _tensor_spec_signature(signature_function):
  """Builds the keyword -> `TensorSpec` mapping used to trace a signature.

  Args:
    signature_function: The `ConcreteFunction` being exported.

  Returns:
    A dict from each argument keyword (as `str`) to a `TensorSpec` with that
    keyword as its name. Keywords are paired with tensor inputs positionally.
  """
  if signature_function.structured_input_signature is not None:
    # The structured input signature may contain other non-tensor arguments.
    inputs = filter(
        lambda x: isinstance(x, tensor_spec.TensorSpec),
        nest.flatten(signature_function.structured_input_signature,
                     expand_composites=True))
  else:
    # Structured input signature isn't always defined for some functions.
    inputs = signature_function.inputs

  specs = {}
  for keyword, inp in zip(
      signature_function._arg_keywords,  # pylint: disable=protected-access
      inputs):
    keyword = compat.as_str(keyword)
    if isinstance(inp, tensor_spec.TensorSpec):
      spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
    else:
      spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
    specs[keyword] = spec
  return specs


def _create_signature_concrete_function(signature_function, signature_key):
  """Wraps `signature_function` into a concrete function returning a dict.

  The `signature_function` and `signature_key` values are parameters of this
  helper, not loop variables of the caller. As a result, the wrapper closure
  binds exactly the values for this one signature.

  Args:
    signature_function: The `ConcreteFunction` to wrap.
    signature_key: The key the signature will be exported under. It is used
      in error messages from `_normalize_outputs`.

  Returns:
    A `ConcreteFunction` whose outputs are a dictionary of Tensors. It has one
    positional argument when it takes exactly one input, and is keyword-only
    otherwise.
  """
  # Re-wrap the function so that it returns a dictionary of Tensors. This
  # matches the format of 1.x-style signatures.
  @def_function.function
  def signature_wrapper(**kwargs):
    structured_outputs = signature_function(**kwargs)
    return _normalize_outputs(
        structured_outputs, signature_function.name, signature_key)

  final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
      **_tensor_spec_signature(signature_function))
  # pylint: disable=protected-access
  if len(final_concrete._arg_keywords) == 1:
    # If there is only one input to the signature, a very common case, then
    # ordering is unambiguous and we can let people pass a positional
    # argument. Since SignatureDefs are unordered (protobuf "map") multiple
    # arguments means we need to be keyword-only.
    final_concrete._num_positional_args = 1
  else:
    final_concrete._num_positional_args = 0
  # pylint: enable=protected-access
  return final_concrete


def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a mapping from signature key to
      function. A single (non-mapping) value is registered under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. Each function
      must be a `tf.function` with an input signature or a concrete function.

  Returns:
    A tuple `(concrete_signatures, wrapped_functions)`.
    `concrete_signatures` maps each signature key to a concrete function whose
    outputs are a dictionary of Tensors. `wrapped_functions` maps each original
    concrete function to the result of
    `function_serialization.wrap_cached_variables` applied to it. That result
    is computed once per distinct original function.

  Raises:
    ValueError: If a value cannot be converted to a concrete function, or if
      a function expects `tf.Variable` inputs.
  """
  if signatures is None:
    return {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  num_normalized_signatures_counter = 0
  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Reuse an existing wrapper when the same function backs several keys.
    signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    wrapped_functions[original_function] = signature_function
    _validate_inputs(signature_function)
    if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      signature_name_changes = _get_signature_name_changes(signature_function)
      if signature_name_changes:
        num_normalized_signatures_counter += 1
        logging.warning(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(signature_name_changes.keys()),
            ", ".join(signature_name_changes.values()))
    concrete_signatures[signature_key] = _create_signature_concrete_function(
        signature_function, signature_key)
  return concrete_signatures, wrapped_functions


def _normalize_outputs(outputs, function_name, signature_key):
  """Construct an output dictionary from unnormalized function outputs.

  A Mapping is used as-is. A Sequence becomes `{"output_<i>": item}`. Any
  other single value is treated as a one-element sequence.

  Args:
    outputs: The structured outputs of a function.
    function_name: Function name, used only in error messages.
    signature_key: Signature key, used only in error messages.

  Returns:
    A Mapping from string keys to `ops.Tensor` values.

  Raises:
    ValueError: If a key is not a string, or a value is not an `ops.Tensor`.
  """
  # Convert `outputs` to a dictionary (if it's not one already).
  if not isinstance(outputs, collections_abc.Mapping):
    if not isinstance(outputs, collections_abc.Sequence):
      outputs = [outputs]
    outputs = {("output_{}".format(output_index)): output
               for output_index, output
               in enumerate(outputs)}

  # Check that the keys of `outputs` are strings and the values are Tensors.
  for key, value in outputs.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          ("Got a dictionary with a non-string key {!r} in the output of the "
           "function {} used to generate the SavedModel signature {!r}.")
          .format(key, compat.as_str_any(function_name), signature_key))
    if not isinstance(value, ops.Tensor):
      raise ValueError(
          ("Got a non-Tensor value {!r} for key {!r} in the output of the "
           "function {} used to generate the SavedModel signature {!r}. "
           "Outputs for functions used as signatures must be a single Tensor, "
           "a sequence of Tensors, or a dictionary from string to Tensor.")
          .format(value, key, compat.as_str_any(function_name), signature_key))

  return outputs


# _SignatureMap is immutable to ensure that users do not expect changes to be
# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
# only way to create a _SignatureMap and there is no way to modify it. So we can
# safely ignore/overwrite ".signatures" attributes attached to objects being
# saved if they contain a _SignatureMap. A ".signatures" attribute containing
# any other type (e.g. a regular dict) will raise an exception asking the user
# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A collection of SavedModel signatures.

  It is a read-only Mapping from signature name to function. Entries are added
  only through the private `_add_signature` method.
  """

  def __init__(self):
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Adds a signature to the _SignatureMap."""
    # Ideally this object would be immutable, but restore is streaming so we do
    # need a private API for adding new signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures)

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return "_SignatureMap({})".format(self._signatures)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Returns the entries whose values are functions or concrete functions."""
    return {
        key: value for key, value in self.items()
        if isinstance(value, (def_function.Function, defun.ConcreteFunction))
    }


revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _SignatureMap(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_SignatureMap._add_signature  # pylint: disable=protected-access
    )])


def create_signature_map(signatures):
  """Creates an object containing `signatures`.

  Args:
    signatures: A mapping from signature name to `ConcreteFunction`, as
      produced by `canonicalize_signatures`.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # This is true of any signature that came from canonicalize_signatures.
    # Here as a sanity check on saving; crashing on load (e.g. in
    # _add_signature) would be more problematic in case future export changes
    # violated these assertions.
    assert isinstance(func, defun.ConcreteFunction)
    assert isinstance(func.structured_outputs, collections_abc.Mapping)
    # pylint: disable=protected-access
    if len(func._arg_keywords) == 1:
      assert 1 == func._num_positional_args
    else:
      assert 0 == func._num_positional_args
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map


def validate_saveable_view(saveable_view):
  """Performs signature-related sanity checks on `saveable_view`.

  Args:
    saveable_view: The view whose root object's dependencies are checked.

  Raises:
    ValueError: If the root has a dependency named `SIGNATURE_ATTRIBUTE_NAME`
      that is not a `_SignatureMap`.
  """
  for name, dep in saveable_view.list_dependencies(
      saveable_view.root):
    if name == SIGNATURE_ATTRIBUTE_NAME:
      if not isinstance(dep, _SignatureMap):
        raise ValueError(
            ("Exporting an object {} which has an attribute named "
             "'{signatures}'. This is a reserved attribute used to store "
             "SavedModel signatures in objects which come from "
             "`tf.saved_model.load`. Delete this attribute "
             "(e.g. 'del obj.{signatures}') before saving if this shadowing is "
             "acceptable.").format(
                 saveable_view.root,
                 signatures=SIGNATURE_ATTRIBUTE_NAME))
      break