# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils_impl as utils
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Utility function to build a SignatureDef protocol buffer.

  Every TensorInfo value is copied (via `CopyFrom`) into the new proto, so the
  returned SignatureDef does not share messages with the arguments.

  Args:
    inputs: Inputs of the SignatureDef defined as a proto map of string to
      tensor info. `None` leaves the inputs map empty.
    outputs: Outputs of the SignatureDef defined as a proto map of string to
      tensor info. `None` leaves the outputs map empty.
    method_name: Method name of the SignatureDef as a string. `None` leaves
      the method name unset.

  Returns:
    A SignatureDef protocol buffer constructed based on the supplied arguments.
  """
  signature_def = meta_graph_pb2.SignatureDef()
  if inputs is not None:
    for item in inputs:
      signature_def.inputs[item].CopyFrom(inputs[item])
  if outputs is not None:
    for item in outputs:
      signature_def.outputs[item].CopyFrom(outputs[item])
  if method_name is not None:
    signature_def.method_name = method_name
  return signature_def


@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If examples is `None` or is not a string `Tensor`, if
      predictions is `None`, or if predictions is not a float `Tensor`.
  """
  if examples is None:
    raise ValueError('Regression examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Regression examples must be a string Tensor.')
  if predictions is None:
    raise ValueError('Regression predictions cannot be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression examples must be a string Tensor.')
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output must be a float Tensor.')
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)

  return signature_def


@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples, classes and scores.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`. Note that the ClassificationResponse message
      requires that class labels are strings, not integers or anything else.
      May be `None` if `scores` is provided.
    scores: a float `Tensor`. May be `None` if `classes` is provided.

  Returns:
    A classification-flavored signature_def.

  Raises:
    ValueError: If examples is `None` or is not a string `Tensor`, if classes
      and scores are both `None`, if classes is not a string `Tensor`, or if
      scores is not a float `Tensor`.
  """
  if examples is None:
    raise ValueError('Classification examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification examples must be a string Tensor.')
  if classes is None and scores is None:
    raise ValueError('Classification classes and scores cannot both be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification examples must be a string Tensor.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  # Only the outputs that were actually supplied are added to the signature.
  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be a string Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)

  return signature_def


@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Creates prediction signature from given inputs and outputs.

  This function produces signatures intended for use with the TensorFlow Serving
  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
  imposes no constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If inputs or outputs is `None` or empty.
  """
  # `not x` is true for both `None` and an empty mapping.
  if not inputs:
    raise ValueError('Prediction inputs cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction outputs cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}
  signature_outputs = {key: utils.build_tensor_info(tensor)
                       for key, tensor in outputs.items()}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.PREDICT_METHOD_NAME)

  return signature_def


def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a training SignatureDef.

  Thin wrapper around `_supervised_signature_def` using
  `signature_constants.SUPERVISED_TRAIN_METHOD_NAME` as the method name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: optional dict of string to `Tensor` of output predictions.
    metrics: optional dict of string to `Tensor` representing metric ops.

  Returns:
    A train-flavored signature_def.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  return _supervised_signature_def(
      signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates an evaluation SignatureDef.

  Thin wrapper around `_supervised_signature_def` using
  `signature_constants.SUPERVISED_EVAL_METHOD_NAME` as the method name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: optional dict of string to `Tensor` of output predictions.
    metrics: optional dict of string to `Tensor` representing metric ops.

  Returns:
    An eval-flavored signature_def.

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  return _supervised_signature_def(
      signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Note that this function only requires
  inputs to be not None.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def. Its outputs map is the union of
    `loss`, `predictions` and `metrics`; if the same key appears in more than
    one of them, the later one wins (metrics over predictions over loss).

  Raises:
    ValueError: If inputs is `None` or empty.
  """
  if not inputs:
    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  # Merge every supplied output set into one flat outputs map, in this order.
  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      sig_out = {key: utils.build_tensor_info(tensor)
                 for key, tensor in output_set.items()}
      signature_outputs.update(sig_out)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs, method_name)

  return signature_def


@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: a SignatureDef proto, or `None`.

  Returns:
    True if `signature_def` is a valid classify, regress or predict
    SignatureDef; False otherwise (including when it is `None`).
  """
  if signature_def is None:
    return False
  return (_is_valid_classification_signature(signature_def) or
          _is_valid_regression_signature(signature_def) or
          _is_valid_predict_signature(signature_def))


def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Requires the predict method name and at least one input and one output;
  dtypes are not constrained.
  """
  if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
    return False
  if not signature_def.inputs.keys():
    return False
  if not signature_def.outputs.keys():
    return False
  return True


def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  Requires exactly one DT_STRING input under REGRESS_INPUTS and exactly one
  DT_FLOAT output under REGRESS_OUTPUTS.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  if (set(signature_def.inputs.keys())
      != set([signature_constants.REGRESS_INPUTS])):
    return False
  if (signature_def.inputs[signature_constants.REGRESS_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  if (set(signature_def.outputs.keys())
      != set([signature_constants.REGRESS_OUTPUTS])):
    return False
  if (signature_def.outputs[signature_constants.REGRESS_OUTPUTS].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  Requires exactly one DT_STRING input under CLASSIFY_INPUTS and a non-empty
  set of outputs drawn only from CLASSIFY_OUTPUT_CLASSES (DT_STRING) and
  CLASSIFY_OUTPUT_SCORES (DT_FLOAT).
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  if (set(signature_def.inputs.keys())
      != set([signature_constants.CLASSIFY_INPUTS])):
    return False
  if (signature_def.inputs[signature_constants.CLASSIFY_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  allowed_outputs = set([signature_constants.CLASSIFY_OUTPUT_CLASSES,
                         signature_constants.CLASSIFY_OUTPUT_SCORES])

  if not signature_def.outputs.keys():
    return False
  if set(signature_def.outputs.keys()) - allowed_outputs:
    return False
  if (signature_constants.CLASSIFY_OUTPUT_CLASSES in signature_def.outputs
      and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].dtype
      != types_pb2.DT_STRING):
    return False
  if (signature_constants.CLASSIFY_OUTPUT_SCORES in signature_def.outputs
      and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def op_signature_def(op, key):
  """Creates a signature def with the output pointing to an op.

  Note that op isn't strictly enforced to be an Op object, and may be a Tensor.
  It is recommended to use the build_signature_def() function for Tensors.

  Args:
    op: An Op (or possibly Tensor).
    key: Key to graph element in the SignatureDef outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  # Use build_tensor_info_from_op, which creates a TensorInfo from the element's
  # name.
  return build_signature_def(outputs={key: utils.build_tensor_info_from_op(op)})


def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
    SignatureDef.

  Raises:
    NotFoundError: If `key` is not in the SignatureDef outputs, or if the op
      could not be found in the graph.
  """
  not_found_message = (
      'The {0} could not be found in the graph. Please make sure the '
      'SavedModel was created by the internal _SavedModelBuilder. If you '
      'are using the public API, please make sure the SignatureDef in the '
      'SavedModel does not contain the key "{0}".'.format(key))
  # Check membership before indexing: indexing a message-valued proto map with
  # a missing key inserts an empty TensorInfo, silently mutating the caller's
  # SignatureDef.
  if key not in signature_def.outputs:
    raise errors.NotFoundError(None, None, not_found_message)
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError:
    raise errors.NotFoundError(None, None, not_found_message)