1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers for working with signatures in tf.saved_model.save.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from tensorflow.python.eager import def_function 24from tensorflow.python.eager import function as defun 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_spec 27from tensorflow.python.saved_model import revived_types 28from tensorflow.python.saved_model import signature_constants 29from tensorflow.python.training.tracking import base 30from tensorflow.python.util import compat 31from tensorflow.python.util import nest 32 33 34DEFAULT_SIGNATURE_ATTR = "_default_save_signature" 35SIGNATURE_ATTRIBUTE_NAME = "signatures" 36 37 38def _get_signature(function): 39 if (isinstance(function, (defun.Function, def_function.Function)) and 40 function.input_signature is not None): 41 function = function.get_concrete_function() 42 if not isinstance(function, defun.ConcreteFunction): 43 return None 44 return function 45 46 47def _valid_signature(concrete_function): 48 """Returns whether concrete function can be converted to a signature.""" 49 if not concrete_function.outputs: 50 # Functions without outputs don't make sense as signatures. We just don't 51 # have any way to run an Operation with no outputs as a SignatureDef in the 52 # 1.x style. 53 return False 54 try: 55 _normalize_outputs(concrete_function.structured_outputs, "unused", "unused") 56 except ValueError: 57 return False 58 return True 59 60 61def find_function_to_export(saveable_view): 62 """Function to export, None if no suitable function was found.""" 63 # If the user did not specify signatures, check the root object for a function 64 # that can be made into a signature. 65 functions = saveable_view.list_functions(saveable_view.root) 66 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None) 67 if signature is not None: 68 return signature 69 70 # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a 71 # user decides to annotate more functions with tf.function and suddenly 72 # serving that model way later in the process stops working. 73 possible_signatures = [] 74 for function in functions.values(): 75 concrete = _get_signature(function) 76 if concrete is not None and _valid_signature(concrete): 77 possible_signatures.append(concrete) 78 if len(possible_signatures) == 1: 79 single_function = possible_signatures[0] 80 signature = _get_signature(single_function) 81 if signature and _valid_signature(signature): 82 return signature 83 return None 84 85 86def canonicalize_signatures(signatures): 87 """Converts `signatures` into a dictionary of concrete functions.""" 88 if signatures is None: 89 return {} 90 if not isinstance(signatures, collections.Mapping): 91 signatures = { 92 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures} 93 concrete_signatures = {} 94 for signature_key, function in signatures.items(): 95 signature_function = _get_signature(function) 96 if signature_function is None: 97 raise ValueError( 98 ("Expected a TensorFlow function to generate a signature for, but " 99 "got {}. Only `tf.functions` with an input signature or " 100 "concrete functions can be used as a signature.").format(function)) 101 102 # Re-wrap the function so that it returns a dictionary of Tensors. This 103 # matches the format of 1.x-style signatures. 104 # pylint: disable=cell-var-from-loop 105 @def_function.function 106 def signature_wrapper(**kwargs): 107 structured_outputs = signature_function(**kwargs) 108 return _normalize_outputs( 109 structured_outputs, signature_function.name, signature_key) 110 # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names 111 # always match keyword arguments. 112 tensor_spec_signature = {} 113 for keyword, tensor in zip( 114 signature_function._arg_keywords, # pylint: disable=protected-access 115 signature_function.inputs): 116 keyword = compat.as_str(keyword) 117 tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor( 118 tensor, name=keyword) 119 final_concrete = signature_wrapper.get_concrete_function( 120 **tensor_spec_signature) 121 # pylint: disable=protected-access 122 if len(final_concrete._arg_keywords) == 1: 123 # If there is only one input to the signature, a very common case, then 124 # ordering is unambiguous and we can let people pass a positional 125 # argument. Since SignatureDefs are unordered (protobuf "map") multiple 126 # arguments means we need to be keyword-only. 127 final_concrete._num_positional_args = 1 128 else: 129 final_concrete._num_positional_args = 0 130 # pylint: enable=protected-access 131 concrete_signatures[signature_key] = final_concrete 132 # pylint: enable=cell-var-from-loop 133 return concrete_signatures 134 135 136def _is_flat(sequence): 137 sequence_flat = nest.flatten(sequence) 138 try: 139 nest.assert_same_structure(sequence_flat, sequence) 140 return True 141 except ValueError: 142 return False 143 except TypeError: 144 return False 145 146 147def _normalize_outputs(outputs, function_name, signature_key): 148 """Construct an output dictionary from unnormalized function outputs.""" 149 if isinstance(outputs, collections.Mapping): 150 for key, value in outputs.items(): 151 if not isinstance(value, ops.Tensor): 152 raise ValueError( 153 ("Got a dictionary containing non-Tensor value {} for key {} " 154 "in the output of the function {} used to generate a SavedModel " 155 "signature. Dictionaries outputs for functions used as signatures " 156 "should have one Tensor output per string key.") 157 .format(value, key, compat.as_str_any(function_name))) 158 return outputs 159 else: 160 original_outputs = outputs 161 if not isinstance(outputs, collections.Sequence): 162 outputs = [outputs] 163 if not _is_flat(outputs): 164 raise ValueError( 165 ("Got non-flat outputs '{}' from '{}' for SavedModel " 166 "signature '{}'. Signatures have one Tensor per output, so " 167 "to have predictable names Python functions used to generate " 168 "these signatures should avoid outputting Tensors in nested " 169 "structures.") 170 .format(original_outputs, function_name, signature_key)) 171 return {("output_{}".format(output_index)): output 172 for output_index, output 173 in enumerate(outputs)} 174 175 176# _SignatureMap is immutable to ensure that users do not expect changes to be 177# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the 178# only way to create a _SignatureMap and there is no way to modify it. So we can 179# safely ignore/overwrite ".signatures" attributes attached to objects being 180# saved if they contain a _SignatureMap. A ".signatures" attribute containing 181# any other type (e.g. a regular dict) will raise an exception asking the user 182# to first "del obj.signatures" if they want it overwritten. 183class _SignatureMap(collections.Mapping, base.Trackable): 184 """A collection of SavedModel signatures.""" 185 186 def __init__(self): 187 self._signatures = {} 188 189 def _add_signature(self, name, concrete_function): 190 """Adds a signature to the _SignatureMap.""" 191 # Ideally this object would be immutable, but restore is streaming so we do 192 # need a private API for adding new signatures to an existing object. 193 self._signatures[name] = concrete_function 194 195 def __getitem__(self, key): 196 return self._signatures[key] 197 198 def __iter__(self): 199 return iter(self._signatures) 200 201 def __len__(self): 202 return len(self._signatures) 203 204 def __repr__(self): 205 return "_SignatureMap({})".format(self._signatures) 206 207 def _list_functions_for_serialization(self): 208 return { 209 key: value for key, value in self.items() 210 if isinstance(value, (def_function.Function, defun.ConcreteFunction)) 211 } 212 213 214revived_types.register_revived_type( 215 "signature_map", 216 lambda obj: isinstance(obj, _SignatureMap), 217 versions=[revived_types.VersionedTypeRegistration( 218 # Standard dependencies are enough to reconstruct the trackable 219 # items in dictionaries, so we don't need to save any extra information. 220 object_factory=lambda proto: _SignatureMap(), 221 version=1, 222 min_producer_version=1, 223 min_consumer_version=1, 224 setter=_SignatureMap._add_signature # pylint: disable=protected-access 225 )]) 226 227 228def create_signature_map(signatures): 229 """Creates an object containing `signatures`.""" 230 signature_map = _SignatureMap() 231 for name, func in signatures.items(): 232 # This true of any signature that came from canonicalize_signatures. Here as 233 # a sanity check on saving; crashing on load (e.g. in _add_signature) would 234 # be more problematic in case future export changes violated these 235 # assertions. 236 assert isinstance(func, defun.ConcreteFunction) 237 assert isinstance(func.structured_outputs, collections.Mapping) 238 # pylint: disable=protected-access 239 if len(func._arg_keywords) == 1: 240 assert 1 == func._num_positional_args 241 else: 242 assert 0 == func._num_positional_args 243 signature_map._add_signature(name, func) 244 # pylint: enable=protected-access 245 return signature_map 246 247 248def validate_saveable_view(saveable_view): 249 """Performs signature-related sanity checks on `saveable_view`.""" 250 for name, dep in saveable_view.list_dependencies( 251 saveable_view.root): 252 if name == SIGNATURE_ATTRIBUTE_NAME: 253 if not isinstance(dep, _SignatureMap): 254 raise ValueError( 255 ("Exporting an object {} which has an attribute named " 256 "'{signatures}'. This is a reserved attribute used to store " 257 "SavedModel signatures in objects which come from " 258 "`tf.saved_model.load`. Delete this attribute " 259 "(e.g. 'del obj.{signatures}') before saving if this shadowing is " 260 "acceptable.").format( 261 saveable_view.root, 262 signatures=SIGNATURE_ATTRIBUTE_NAME)) 263 break 264