# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils_impl as utils
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Utility function to build a SignatureDef protocol buffer.

  Each supplied TensorInfo is copied into the new proto with `CopyFrom`, so
  later changes to the argument maps do not affect the returned SignatureDef.
  Arguments left as `None` leave the corresponding field empty/unset.

  Args:
    inputs: Inputs of the SignatureDef defined as a proto map of string to
      tensor info.
    outputs: Outputs of the SignatureDef defined as a proto map of string to
      tensor info.
    method_name: Method name of the SignatureDef as a string.

  Returns:
    A SignatureDef protocol buffer constructed based on the supplied arguments.
  """
  signature_def = meta_graph_pb2.SignatureDef()
  if inputs is not None:
    for item in inputs:
      signature_def.inputs[item].CopyFrom(inputs[item])
  if outputs is not None:
    for item in outputs:
      signature_def.outputs[item].CopyFrom(outputs[item])
  if method_name is not None:
    signature_def.method_name = method_name
  return signature_def


@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If examples is `None` or is not a string `Tensor`, or if
      predictions is `None` or is not a float `Tensor`.
  """
  if examples is None:
    raise ValueError('Regression examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Regression examples must be a string Tensor.')
  if predictions is None:
    raise ValueError('Regression predictions cannot be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression examples must be a string Tensor.')
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output must be a float Tensor.')
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)

  return signature_def


@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Either `classes` or `scores` may be `None`, but not both; only the non-`None`
  ones are added to the signature outputs.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`.  Note that the ClassificationResponse message
      requires that class labels are strings, not integers or anything else.
    scores: a float `Tensor`.

  Returns:
    A classification-flavored signature_def.

  Raises:
    ValueError: If examples is `None` or is not a string `Tensor`, if classes
      and scores are both `None`, if classes is not a string `Tensor`, or if
      scores is not a float `Tensor`.
  """
  if examples is None:
    raise ValueError('Classification examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification examples must be a string Tensor.')
  if classes is None and scores is None:
    raise ValueError('Classification classes and scores cannot both be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification examples must be a string Tensor.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be a string Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)

  return signature_def


@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Creates prediction signature from given inputs and outputs.

  This function produces signatures intended for use with the TensorFlow Serving
  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
  imposes no constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If inputs or outputs is `None` or empty.
  """
  if inputs is None or not inputs:
    raise ValueError('Prediction inputs cannot be None or empty.')
  if outputs is None or not outputs:
    raise ValueError('Prediction outputs cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}
  signature_outputs = {key: utils.build_tensor_info(tensor)
                       for key, tensor in outputs.items()}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.PREDICT_METHOD_NAME)

  return signature_def


# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a supervised-training signature.

  Thin wrapper around `_supervised_signature_def` using
  `signature_constants.SUPERVISED_TRAIN_METHOD_NAME` as the method name.

  Args:
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output
      predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train-flavored signature_def.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  return _supervised_signature_def(
      signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a supervised-evaluation signature.

  Thin wrapper around `_supervised_signature_def` using
  `signature_constants.SUPERVISED_EVAL_METHOD_NAME` as the method name.

  Args:
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output
      predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    An eval-flavored signature_def.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  return _supervised_signature_def(
      signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Note that this function only requires
  inputs to be not None.

  The `loss`, `predictions` and `metrics` dicts are merged into a single
  outputs map in that order; if the same key appears in more than one of them,
  the entry from the later dict wins.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def.

  Raises:
    ValueError: If inputs is `None` or empty. Outputs are not validated and may
      all be `None`.
  """
  if inputs is None or not inputs:
    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      sig_out = {key: utils.build_tensor_info(tensor)
                 for key, tensor in output_set.items()}
      # dict.update: later output sets overwrite colliding keys.
      signature_outputs.update(sig_out)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs, method_name)

  return signature_def
# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/signature_def_utils.py)


@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: a SignatureDef proto, or `None`.

  Returns:
    True if `signature_def` is not `None` and is a valid classification,
    regression, or predict signature; False otherwise.
  """
  if signature_def is None:
    return False
  return (_is_valid_classification_signature(signature_def) or
          _is_valid_regression_signature(signature_def) or
          _is_valid_predict_signature(signature_def))


def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Requires the predict method name and at least one input and one output; no
  dtype constraints are imposed.
  """
  if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
    return False
  if not signature_def.inputs.keys():
    return False
  if not signature_def.outputs.keys():
    return False
  return True


def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  Requires the regress method name, exactly one DT_STRING input keyed
  `REGRESS_INPUTS`, and exactly one DT_FLOAT output keyed `REGRESS_OUTPUTS`.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  if (set(signature_def.inputs.keys())
      != set([signature_constants.REGRESS_INPUTS])):
    return False
  if (signature_def.inputs[signature_constants.REGRESS_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  if (set(signature_def.outputs.keys())
      != set([signature_constants.REGRESS_OUTPUTS])):
    return False
  if (signature_def.outputs[signature_constants.REGRESS_OUTPUTS].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  Requires the classify method name, exactly one DT_STRING input keyed
  `CLASSIFY_INPUTS`, and a non-empty set of outputs drawn only from
  `CLASSIFY_OUTPUT_CLASSES` (DT_STRING) and `CLASSIFY_OUTPUT_SCORES`
  (DT_FLOAT).
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  if (set(signature_def.inputs.keys())
      != set([signature_constants.CLASSIFY_INPUTS])):
    return False
  if (signature_def.inputs[signature_constants.CLASSIFY_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  allowed_outputs = set([signature_constants.CLASSIFY_OUTPUT_CLASSES,
                         signature_constants.CLASSIFY_OUTPUT_SCORES])

  if not signature_def.outputs.keys():
    return False
  # Any output key outside the allowed pair makes the signature unservable.
  if set(signature_def.outputs.keys()) - allowed_outputs:
    return False
  if (signature_constants.CLASSIFY_OUTPUT_CLASSES in signature_def.outputs
      and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].dtype
      != types_pb2.DT_STRING):
    return False
  if (signature_constants.CLASSIFY_OUTPUT_SCORES in signature_def.outputs
      and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def op_signature_def(op, key):
  """Creates a signature def with the output pointing to an op.

  Note that op isn't strictly enforced to be an Op object, and may be a Tensor.
  It is recommended to use the build_signature_def() function for Tensors.

  Args:
    op: An Op (or possibly Tensor).
    key: Key to graph element in the SignatureDef outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  # Use build_tensor_info_from_op, which creates a TensorInfo from the element's
  # name.
  return build_signature_def(outputs={key: utils.build_tensor_info_from_op(op)})


def _op_not_found_error(key):
  """Builds the NotFoundError raised when `key` cannot be resolved to an op."""
  return errors.NotFoundError(
      None, None,
      'The {0} could not be found in the graph. Please make sure the '
      'SavedModel was created by the internal _SavedModelBuilder. If you '
      'are using the public API, please make sure the SignatureDef in the '
      'SavedModel does not contain the key "{0}".'.format(key))


def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  The `signature_def` proto is not modified, even when `key` is missing.

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
    SignatureDef.

  Raises:
    NotFoundError: If `key` is not in the SignatureDef outputs, or if the op
      could not be found in the graph.
  """
  # Check membership explicitly: indexing a missing key on a protobuf map field
  # inserts a default (empty) TensorInfo, which would silently mutate the
  # caller's SignatureDef before the lookup fails.
  if key not in signature_def.outputs:
    raise _op_not_found_error(key)
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError:
    raise _op_not_found_error(key)