1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes for different types of export output.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import six 24 25 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.saved_model import signature_def_utils 31 32 33class ExportOutput(object): 34 """Represents an output of a model that can be served. 35 36 These typically correspond to model heads. 37 """ 38 39 __metaclass__ = abc.ABCMeta 40 41 _SEPARATOR_CHAR = '/' 42 43 @abc.abstractmethod 44 def as_signature_def(self, receiver_tensors): 45 """Generate a SignatureDef proto for inclusion in a MetaGraphDef. 46 47 The SignatureDef will specify outputs as described in this ExportOutput, 48 and will use the provided receiver_tensors as inputs. 49 50 Args: 51 receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying 52 input nodes that will be fed. 53 """ 54 pass 55 56 def _check_output_key(self, key, error_label): 57 # For multi-head models, the key can be a tuple. 58 if isinstance(key, tuple): 59 key = self._SEPARATOR_CHAR.join(key) 60 61 if not isinstance(key, six.string_types): 62 raise ValueError( 63 '{} output key must be a string; got {}.'.format(error_label, key)) 64 return key 65 66 def _wrap_and_check_outputs( 67 self, outputs, single_output_default_name, error_label=None): 68 """Wraps raw tensors as dicts and checks type. 69 70 Note that we create a new dict here so that we can overwrite the keys 71 if necessary. 72 73 Args: 74 outputs: A `Tensor` or a dict of string to `Tensor`. 75 single_output_default_name: A string key for use in the output dict 76 if the provided `outputs` is a raw tensor. 77 error_label: descriptive string for use in error messages. If none, 78 single_output_default_name will be used. 79 80 Returns: 81 A dict of tensors 82 83 Raises: 84 ValueError: if the outputs dict keys are not strings or tuples of strings 85 or the values are not Tensors. 86 """ 87 if not isinstance(outputs, dict): 88 outputs = {single_output_default_name: outputs} 89 90 output_dict = {} 91 for key, value in outputs.items(): 92 error_name = error_label or single_output_default_name 93 key = self._check_output_key(key, error_name) 94 if not isinstance(value, ops.Tensor): 95 raise ValueError( 96 '{} output value must be a Tensor; got {}.'.format( 97 error_name, value)) 98 99 output_dict[key] = value 100 return output_dict 101 102 103class ClassificationOutput(ExportOutput): 104 """Represents the output of a classification head. 105 106 Either classes or scores or both must be set. 107 108 The classes `Tensor` must provide string labels, not integer class IDs. 109 110 If only classes is set, it is interpreted as providing top-k results in 111 descending order. 112 113 If only scores is set, it is interpreted as providing a score for every class 114 in order of class ID. 115 116 If both classes and scores are set, they are interpreted as zipped, so each 117 score corresponds to the class at the same index. Clients should not depend 118 on the order of the entries. 119 """ 120 121 def __init__(self, scores=None, classes=None): 122 """Constructor for `ClassificationOutput`. 123 124 Args: 125 scores: A float `Tensor` giving scores (sometimes but not always 126 interpretable as probabilities) for each class. May be `None`, but 127 only if `classes` is set. Interpretation varies-- see class doc. 128 classes: A string `Tensor` giving predicted class labels. May be `None`, 129 but only if `scores` is set. Interpretation varies-- see class doc. 130 131 Raises: 132 ValueError: if neither classes nor scores is set, or one of them is not a 133 `Tensor` with the correct dtype. 134 """ 135 if (scores is not None 136 and not (isinstance(scores, ops.Tensor) 137 and scores.dtype.is_floating)): 138 raise ValueError('Classification scores must be a float32 Tensor; ' 139 'got {}'.format(scores)) 140 if (classes is not None 141 and not (isinstance(classes, ops.Tensor) 142 and dtypes.as_dtype(classes.dtype) == dtypes.string)): 143 raise ValueError('Classification classes must be a string Tensor; ' 144 'got {}'.format(classes)) 145 if scores is None and classes is None: 146 raise ValueError('At least one of scores and classes must be set.') 147 148 self._scores = scores 149 self._classes = classes 150 151 @property 152 def scores(self): 153 return self._scores 154 155 @property 156 def classes(self): 157 return self._classes 158 159 def as_signature_def(self, receiver_tensors): 160 if len(receiver_tensors) != 1: 161 raise ValueError('Classification input must be a single string Tensor; ' 162 'got {}'.format(receiver_tensors)) 163 (_, examples), = receiver_tensors.items() 164 if dtypes.as_dtype(examples.dtype) != dtypes.string: 165 raise ValueError('Classification input must be a single string Tensor; ' 166 'got {}'.format(receiver_tensors)) 167 return signature_def_utils.classification_signature_def( 168 examples, self.classes, self.scores) 169 170 171class RegressionOutput(ExportOutput): 172 """Represents the output of a regression head.""" 173 174 def __init__(self, value): 175 """Constructor for `RegressionOutput`. 176 177 Args: 178 value: a float `Tensor` giving the predicted values. Required. 179 180 Raises: 181 ValueError: if the value is not a `Tensor` with dtype tf.float32. 182 """ 183 if not (isinstance(value, ops.Tensor) and value.dtype.is_floating): 184 raise ValueError('Regression output value must be a float32 Tensor; ' 185 'got {}'.format(value)) 186 self._value = value 187 188 @property 189 def value(self): 190 return self._value 191 192 def as_signature_def(self, receiver_tensors): 193 if len(receiver_tensors) != 1: 194 raise ValueError('Regression input must be a single string Tensor; ' 195 'got {}'.format(receiver_tensors)) 196 (_, examples), = receiver_tensors.items() 197 if dtypes.as_dtype(examples.dtype) != dtypes.string: 198 raise ValueError('Regression input must be a single string Tensor; ' 199 'got {}'.format(receiver_tensors)) 200 return signature_def_utils.regression_signature_def(examples, self.value) 201 202 203class PredictOutput(ExportOutput): 204 """Represents the output of a generic prediction head. 205 206 A generic prediction need not be either a classification or a regression. 207 208 Named outputs must be provided as a dict from string to `Tensor`, 209 """ 210 _SINGLE_OUTPUT_DEFAULT_NAME = 'output' 211 212 def __init__(self, outputs): 213 """Constructor for PredictOutput. 214 215 Args: 216 outputs: A `Tensor` or a dict of string to `Tensor` representing the 217 predictions. 218 219 Raises: 220 ValueError: if the outputs is not dict, or any of its keys are not 221 strings, or any of its values are not `Tensor`s. 222 """ 223 224 self._outputs = self._wrap_and_check_outputs( 225 outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction') 226 227 @property 228 def outputs(self): 229 return self._outputs 230 231 def as_signature_def(self, receiver_tensors): 232 return signature_def_utils.predict_signature_def(receiver_tensors, 233 self.outputs) 234 235 236class _SupervisedOutput(ExportOutput): 237 """Represents the output of a supervised training or eval process.""" 238 __metaclass__ = abc.ABCMeta 239 240 LOSS_NAME = 'loss' 241 PREDICTIONS_NAME = 'predictions' 242 METRICS_NAME = 'metrics' 243 244 METRIC_VALUE_SUFFIX = 'value' 245 METRIC_UPDATE_SUFFIX = 'update_op' 246 247 _loss = None 248 _predictions = None 249 _metrics = None 250 251 def __init__(self, loss=None, predictions=None, metrics=None): 252 """Constructor for SupervisedOutput (ie, Train or Eval output). 253 254 Args: 255 loss: dict of Tensors or single Tensor representing calculated loss. 256 predictions: dict of Tensors or single Tensor representing model 257 predictions. 258 metrics: Dict of metric results keyed by name. 259 The values of the dict can be one of the following: 260 (1) instance of `Metric` class. 261 (2) (metric_value, update_op) tuples, or a single tuple. 262 metric_value must be a Tensor, and update_op must be a Tensor or Op. 263 264 Raises: 265 ValueError: if any of the outputs' dict keys are not strings or tuples of 266 strings or the values are not Tensors (or Operations in the case of 267 update_op). 268 """ 269 270 if loss is not None: 271 loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME) 272 self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME) 273 if predictions is not None: 274 pred_dict = self._wrap_and_check_outputs( 275 predictions, self.PREDICTIONS_NAME) 276 self._predictions = self._prefix_output_keys( 277 pred_dict, self.PREDICTIONS_NAME) 278 if metrics is not None: 279 self._metrics = self._wrap_and_check_metrics(metrics) 280 281 def _prefix_output_keys(self, output_dict, output_name): 282 """Prepend output_name to the output_dict keys if it doesn't exist. 283 284 This produces predictable prefixes for the pre-determined outputs 285 of SupervisedOutput. 286 287 Args: 288 output_dict: dict of string to Tensor, assumed valid. 289 output_name: prefix string to prepend to existing keys. 290 291 Returns: 292 dict with updated keys and existing values. 293 """ 294 295 new_outputs = {} 296 for key, val in output_dict.items(): 297 key = self._prefix_key(key, output_name) 298 new_outputs[key] = val 299 return new_outputs 300 301 def _prefix_key(self, key, output_name): 302 if key.find(output_name) != 0: 303 key = output_name + self._SEPARATOR_CHAR + key 304 return key 305 306 def _wrap_and_check_metrics(self, metrics): 307 """Handle the saving of metrics. 308 309 Metrics is either a tuple of (value, update_op), or a dict of such tuples. 310 Here, we separate out the tuples and create a dict with names to tensors. 311 312 Args: 313 metrics: Dict of metric results keyed by name. 314 The values of the dict can be one of the following: 315 (1) instance of `Metric` class. 316 (2) (metric_value, update_op) tuples, or a single tuple. 317 metric_value must be a Tensor, and update_op must be a Tensor or Op. 318 319 Returns: 320 dict of output_names to tensors 321 322 Raises: 323 ValueError: if the dict key is not a string, or the metric values or ops 324 are not tensors. 325 """ 326 if not isinstance(metrics, dict): 327 metrics = {self.METRICS_NAME: metrics} 328 329 outputs = {} 330 for key, value in metrics.items(): 331 if isinstance(value, tuple): 332 metric_val, metric_op = value 333 else: # value is a keras.Metrics object 334 metric_val = value.result() 335 assert len(value.updates) == 1 # We expect only one update op. 336 metric_op = value.updates[0] 337 key = self._check_output_key(key, self.METRICS_NAME) 338 key = self._prefix_key(key, self.METRICS_NAME) 339 340 val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX 341 op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX 342 if not isinstance(metric_val, ops.Tensor): 343 raise ValueError( 344 '{} output value must be a Tensor; got {}.'.format( 345 key, metric_val)) 346 if not (tensor_util.is_tf_type(metric_op) or 347 isinstance(metric_op, ops.Operation)): 348 raise ValueError( 349 '{} update_op must be a Tensor or Operation; got {}.'.format( 350 key, metric_op)) 351 352 # We must wrap any ops (or variables) in a Tensor before export, as the 353 # SignatureDef proto expects tensors only. See b/109740581 354 metric_op_tensor = metric_op 355 if not isinstance(metric_op, ops.Tensor): 356 with ops.control_dependencies([metric_op]): 357 metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') 358 359 outputs[val_name] = metric_val 360 outputs[op_name] = metric_op_tensor 361 362 return outputs 363 364 @property 365 def loss(self): 366 return self._loss 367 368 @property 369 def predictions(self): 370 return self._predictions 371 372 @property 373 def metrics(self): 374 return self._metrics 375 376 @abc.abstractmethod 377 def _get_signature_def_fn(self): 378 """Returns a function that produces a SignatureDef given desired outputs.""" 379 pass 380 381 def as_signature_def(self, receiver_tensors): 382 signature_def_fn = self._get_signature_def_fn() 383 return signature_def_fn( 384 receiver_tensors, self.loss, self.predictions, self.metrics) 385 386 387class TrainOutput(_SupervisedOutput): 388 """Represents the output of a supervised training process. 389 390 This class generates the appropriate signature def for exporting 391 training output by type-checking and wrapping loss, predictions, and metrics 392 values. 393 """ 394 395 def _get_signature_def_fn(self): 396 return signature_def_utils.supervised_train_signature_def 397 398 399class EvalOutput(_SupervisedOutput): 400 """Represents the output of a supervised eval process. 401 402 This class generates the appropriate signature def for exporting 403 eval output by type-checking and wrapping loss, predictions, and metrics 404 values. 405 """ 406 407 def _get_signature_def_fn(self): 408 return signature_def_utils.supervised_eval_signature_def 409