# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.util.tf_export import tf_export


@tf_export('saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Utility function to build a SignatureDef protocol buffer.

  Each entry of `inputs` and `outputs` is copied into the new proto with
  `CopyFrom`, so the returned SignatureDef holds its own copies of the
  supplied TensorInfo messages.

  Args:
    inputs: Inputs of the SignatureDef defined as a proto map of string to
      tensor info. Skipped when `None`.
    outputs: Outputs of the SignatureDef defined as a proto map of string to
      tensor info. Skipped when `None`.
    method_name: Method name of the SignatureDef as a string. Left unset when
      `None`.

  Returns:
    A SignatureDef protocol buffer constructed based on the supplied arguments.
  """
  signature_def = meta_graph_pb2.SignatureDef()
  if inputs is not None:
    # Iterate keys and index (rather than .items()) so any object supporting
    # iteration plus __getitem__ keeps working, as before.
    for item in inputs:
      signature_def.inputs[item].CopyFrom(inputs[item])
  if outputs is not None:
    for item in outputs:
      signature_def.outputs[item].CopyFrom(outputs[item])
  if method_name is not None:
    signature_def.method_name = method_name
  return signature_def


@tf_export('saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None`, is not an `ops.Tensor`, or its dtype
      is not `DT_STRING`; or if `predictions` is `None` or its dtype is not
      `DT_FLOAT`.
  """
  if examples is None:
    raise ValueError('Regression examples cannot be None.')
  # Reject non-Tensor values before trying to build a TensorInfo from them.
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Regression examples must be a string Tensor.')
  if predictions is None:
    raise ValueError('Regression predictions cannot be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression examples must be a string Tensor.')
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output must be a float Tensor.')
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)


@tf_export('saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None`. Note that the ClassificationResponse
      message requires that class labels are strings, not integers or anything
      else.
    scores: a float `Tensor`, or `None`. At least one of `classes` and
      `scores` must be provided.

  Returns:
    A classification-flavored signature_def. Its outputs contain only the
    entries for whichever of `classes` / `scores` were provided.

  Raises:
    ValueError: If `examples` is `None`, is not an `ops.Tensor`, or its dtype
      is not `DT_STRING`; if `classes` and `scores` are both `None`; if
      `classes` is given and its dtype is not `DT_STRING`; or if `scores` is
      given and its dtype is not `DT_FLOAT`.
  """
  if examples is None:
    raise ValueError('Classification examples cannot be None.')
  # Reject non-Tensor values before trying to build a TensorInfo from them.
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification examples must be a string Tensor.')
  if classes is None and scores is None:
    raise ValueError('Classification classes and scores cannot both be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification examples must be a string Tensor.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  # Each output is optional; only the ones supplied are added.
  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be a string Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)


@tf_export('saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Creates prediction signature from given inputs and outputs.

  This function produces signatures intended for use with the TensorFlow Serving
  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
  imposes no constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # `None` is falsy, so a single truthiness test covers both None and empty.
  if not inputs:
    raise ValueError('Prediction inputs cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction outputs cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}
  signature_outputs = {key: utils.build_tensor_info(tensor)
                       for key, tensor in outputs.items()}

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.PREDICT_METHOD_NAME)


@tf_export('saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: A SignatureDef proto, or `None`.

  Returns:
    False if `signature_def` is `None`; otherwise True exactly when it is a
    well-formed classify, regress, or predict signature.
  """
  if signature_def is None:
    return False
  return (_is_valid_classification_signature(signature_def) or
          _is_valid_regression_signature(signature_def) or
          _is_valid_predict_signature(signature_def))


def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Args:
    signature_def: A SignatureDef proto.

  Returns:
    True if its method name is PREDICT_METHOD_NAME and it has at least one
    input and at least one output; False otherwise.
  """
  if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
    return False
  if not signature_def.inputs.keys():
    return False
  if not signature_def.outputs.keys():
    return False
  return True


def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  Args:
    signature_def: A SignatureDef proto.

  Returns:
    True if its method name is REGRESS_METHOD_NAME, its only input is
    REGRESS_INPUTS with dtype DT_STRING, and its only output is
    REGRESS_OUTPUTS with dtype DT_FLOAT; False otherwise.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  # Exactly one input, under the canonical key, holding serialized examples.
  if (set(signature_def.inputs.keys())
      != {signature_constants.REGRESS_INPUTS}):
    return False
  if (signature_def.inputs[signature_constants.REGRESS_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  # Exactly one output, under the canonical key, with float predictions.
  if (set(signature_def.outputs.keys())
      != {signature_constants.REGRESS_OUTPUTS}):
    return False
  if (signature_def.outputs[signature_constants.REGRESS_OUTPUTS].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  Args:
    signature_def: A SignatureDef proto.

  Returns:
    True if its method name is CLASSIFY_METHOD_NAME, its only input is
    CLASSIFY_INPUTS with dtype DT_STRING, and its outputs are a non-empty
    subset of {CLASSIFY_OUTPUT_CLASSES, CLASSIFY_OUTPUT_SCORES} where classes
    (if present) is DT_STRING and scores (if present) is DT_FLOAT; False
    otherwise.
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  if (set(signature_def.inputs.keys())
      != {signature_constants.CLASSIFY_INPUTS}):
    return False
  if (signature_def.inputs[signature_constants.CLASSIFY_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  allowed_outputs = {signature_constants.CLASSIFY_OUTPUT_CLASSES,
                     signature_constants.CLASSIFY_OUTPUT_SCORES}

  # At least one output is required, and no output outside the allowed pair.
  if not signature_def.outputs.keys():
    return False
  if set(signature_def.outputs.keys()) - allowed_outputs:
    return False
  # Each optional output, when present, must carry the expected dtype.
  if (signature_constants.CLASSIFY_OUTPUT_CLASSES in signature_def.outputs
      and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].dtype
      != types_pb2.DT_STRING):
    return False
  if (signature_constants.CLASSIFY_OUTPUT_SCORES in signature_def.outputs
      and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def _get_shapes_from_tensor_info_dict(tensor_info_dict):
  """Returns a map of keys to TensorShape objects.

  Args:
    tensor_info_dict: map with TensorInfo proto as values.

  Returns:
    Map with the same keys, whose values are `tensor_shape.TensorShape`
    objects built from each TensorInfo's `tensor_shape` field.
  """
  return {
      key: tensor_shape.TensorShape(tensor_info.tensor_shape)
      for key, tensor_info in tensor_info_dict.items()
  }


def _get_types_from_tensor_info_dict(tensor_info_dict):
  """Returns a map of keys to DType objects.

  Args:
    tensor_info_dict: map with TensorInfo proto as values.

  Returns:
    Map with the same keys, whose values are `dtypes.DType` objects built
    from each TensorInfo's `dtype` enum value.
  """
  return {
      key: dtypes.DType(tensor_info.dtype)
      for key, tensor_info in tensor_info_dict.items()
  }


def get_signature_def_input_shapes(signature):
  """Returns map of input names to their shapes.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from string to TensorShape objects.
  """
  return _get_shapes_from_tensor_info_dict(signature.inputs)


def get_signature_def_input_types(signature):
  """Returns map of input names to their types.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from string to DType objects.
  """
  return _get_types_from_tensor_info_dict(signature.inputs)


def get_signature_def_output_shapes(signature):
  """Returns map of output names to their shapes.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from string to TensorShape objects.
  """
  return _get_shapes_from_tensor_info_dict(signature.outputs)


def get_signature_def_output_types(signature):
  """Returns map of output names to their types.

  Args:
    signature: SignatureDef proto.

  Returns:
    Map from string to DType objects.
  """
  return _get_types_from_tensor_info_dict(signature.outputs)