1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes for different types of export output.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import six 24 25 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.saved_model import signature_def_utils 30 31 32class ExportOutput(object): 33 """Represents an output of a model that can be served. 34 35 These typically correspond to model heads. 36 """ 37 38 __metaclass__ = abc.ABCMeta 39 40 _SEPARATOR_CHAR = '/' 41 42 @abc.abstractmethod 43 def as_signature_def(self, receiver_tensors): 44 """Generate a SignatureDef proto for inclusion in a MetaGraphDef. 45 46 The SignatureDef will specify outputs as described in this ExportOutput, 47 and will use the provided receiver_tensors as inputs. 48 49 Args: 50 receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying 51 input nodes that will be fed. 52 """ 53 pass 54 55 def _check_output_key(self, key, error_label): 56 # For multi-head models, the key can be a tuple. 57 if isinstance(key, tuple): 58 key = self._SEPARATOR_CHAR.join(key) 59 60 if not isinstance(key, six.string_types): 61 raise ValueError( 62 '{} output key must be a string; got {}.'.format(error_label, key)) 63 return key 64 65 def _wrap_and_check_outputs( 66 self, outputs, single_output_default_name, error_label=None): 67 """Wraps raw tensors as dicts and checks type. 68 69 Note that we create a new dict here so that we can overwrite the keys 70 if necessary. 71 72 Args: 73 outputs: A `Tensor` or a dict of string to `Tensor`. 74 single_output_default_name: A string key for use in the output dict 75 if the provided `outputs` is a raw tensor. 76 error_label: descriptive string for use in error messages. If none, 77 single_output_default_name will be used. 78 79 Returns: 80 A dict of tensors 81 82 Raises: 83 ValueError: if the outputs dict keys are not strings or tuples of strings 84 or the values are not Tensors. 85 """ 86 if not isinstance(outputs, dict): 87 outputs = {single_output_default_name: outputs} 88 89 output_dict = {} 90 for key, value in outputs.items(): 91 error_name = error_label or single_output_default_name 92 key = self._check_output_key(key, error_name) 93 if not isinstance(value, ops.Tensor): 94 raise ValueError( 95 '{} output value must be a Tensor; got {}.'.format( 96 error_name, value)) 97 98 output_dict[key] = value 99 return output_dict 100 101 102class ClassificationOutput(ExportOutput): 103 """Represents the output of a classification head. 104 105 Either classes or scores or both must be set. 106 107 The classes `Tensor` must provide string labels, not integer class IDs. 108 109 If only classes is set, it is interpreted as providing top-k results in 110 descending order. 111 112 If only scores is set, it is interpreted as providing a score for every class 113 in order of class ID. 114 115 If both classes and scores are set, they are interpreted as zipped, so each 116 score corresponds to the class at the same index. Clients should not depend 117 on the order of the entries. 118 """ 119 120 def __init__(self, scores=None, classes=None): 121 """Constructor for `ClassificationOutput`. 122 123 Args: 124 scores: A float `Tensor` giving scores (sometimes but not always 125 interpretable as probabilities) for each class. May be `None`, but 126 only if `classes` is set. Interpretation varies-- see class doc. 127 classes: A string `Tensor` giving predicted class labels. May be `None`, 128 but only if `scores` is set. Interpretation varies-- see class doc. 129 130 Raises: 131 ValueError: if neither classes nor scores is set, or one of them is not a 132 `Tensor` with the correct dtype. 133 """ 134 if (scores is not None 135 and not (isinstance(scores, ops.Tensor) 136 and scores.dtype.is_floating)): 137 raise ValueError('Classification scores must be a float32 Tensor; ' 138 'got {}'.format(scores)) 139 if (classes is not None 140 and not (isinstance(classes, ops.Tensor) 141 and dtypes.as_dtype(classes.dtype) == dtypes.string)): 142 raise ValueError('Classification classes must be a string Tensor; ' 143 'got {}'.format(classes)) 144 if scores is None and classes is None: 145 raise ValueError('At least one of scores and classes must be set.') 146 147 self._scores = scores 148 self._classes = classes 149 150 @property 151 def scores(self): 152 return self._scores 153 154 @property 155 def classes(self): 156 return self._classes 157 158 def as_signature_def(self, receiver_tensors): 159 if len(receiver_tensors) != 1: 160 raise ValueError('Classification input must be a single string Tensor; ' 161 'got {}'.format(receiver_tensors)) 162 (_, examples), = receiver_tensors.items() 163 if dtypes.as_dtype(examples.dtype) != dtypes.string: 164 raise ValueError('Classification input must be a single string Tensor; ' 165 'got {}'.format(receiver_tensors)) 166 return signature_def_utils.classification_signature_def( 167 examples, self.classes, self.scores) 168 169 170class RegressionOutput(ExportOutput): 171 """Represents the output of a regression head.""" 172 173 def __init__(self, value): 174 """Constructor for `RegressionOutput`. 175 176 Args: 177 value: a float `Tensor` giving the predicted values. Required. 178 179 Raises: 180 ValueError: if the value is not a `Tensor` with dtype tf.float32. 181 """ 182 if not (isinstance(value, ops.Tensor) and value.dtype.is_floating): 183 raise ValueError('Regression output value must be a float32 Tensor; ' 184 'got {}'.format(value)) 185 self._value = value 186 187 @property 188 def value(self): 189 return self._value 190 191 def as_signature_def(self, receiver_tensors): 192 if len(receiver_tensors) != 1: 193 raise ValueError('Regression input must be a single string Tensor; ' 194 'got {}'.format(receiver_tensors)) 195 (_, examples), = receiver_tensors.items() 196 if dtypes.as_dtype(examples.dtype) != dtypes.string: 197 raise ValueError('Regression input must be a single string Tensor; ' 198 'got {}'.format(receiver_tensors)) 199 return signature_def_utils.regression_signature_def(examples, self.value) 200 201 202class PredictOutput(ExportOutput): 203 """Represents the output of a generic prediction head. 204 205 A generic prediction need not be either a classification or a regression. 206 207 Named outputs must be provided as a dict from string to `Tensor`, 208 """ 209 _SINGLE_OUTPUT_DEFAULT_NAME = 'output' 210 211 def __init__(self, outputs): 212 """Constructor for PredictOutput. 213 214 Args: 215 outputs: A `Tensor` or a dict of string to `Tensor` representing the 216 predictions. 217 218 Raises: 219 ValueError: if the outputs is not dict, or any of its keys are not 220 strings, or any of its values are not `Tensor`s. 221 """ 222 223 self._outputs = self._wrap_and_check_outputs( 224 outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction') 225 226 @property 227 def outputs(self): 228 return self._outputs 229 230 def as_signature_def(self, receiver_tensors): 231 return signature_def_utils.predict_signature_def(receiver_tensors, 232 self.outputs) 233 234 235class _SupervisedOutput(ExportOutput): 236 """Represents the output of a supervised training or eval process.""" 237 __metaclass__ = abc.ABCMeta 238 239 LOSS_NAME = 'loss' 240 PREDICTIONS_NAME = 'predictions' 241 METRICS_NAME = 'metrics' 242 243 METRIC_VALUE_SUFFIX = 'value' 244 METRIC_UPDATE_SUFFIX = 'update_op' 245 246 _loss = None 247 _predictions = None 248 _metrics = None 249 250 def __init__(self, loss=None, predictions=None, metrics=None): 251 """Constructor for SupervisedOutput (ie, Train or Eval output). 252 253 Args: 254 loss: dict of Tensors or single Tensor representing calculated loss. 255 predictions: dict of Tensors or single Tensor representing model 256 predictions. 257 metrics: Dict of metric results keyed by name. 258 The values of the dict can be one of the following: 259 (1) instance of `Metric` class. 260 (2) (metric_value, update_op) tuples, or a single tuple. 261 metric_value must be a Tensor, and update_op must be a Tensor or Op. 262 263 Raises: 264 ValueError: if any of the outputs' dict keys are not strings or tuples of 265 strings or the values are not Tensors (or Operations in the case of 266 update_op). 267 """ 268 269 if loss is not None: 270 loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME) 271 self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME) 272 if predictions is not None: 273 pred_dict = self._wrap_and_check_outputs( 274 predictions, self.PREDICTIONS_NAME) 275 self._predictions = self._prefix_output_keys( 276 pred_dict, self.PREDICTIONS_NAME) 277 if metrics is not None: 278 self._metrics = self._wrap_and_check_metrics(metrics) 279 280 def _prefix_output_keys(self, output_dict, output_name): 281 """Prepend output_name to the output_dict keys if it doesn't exist. 282 283 This produces predictable prefixes for the pre-determined outputs 284 of SupervisedOutput. 285 286 Args: 287 output_dict: dict of string to Tensor, assumed valid. 288 output_name: prefix string to prepend to existing keys. 289 290 Returns: 291 dict with updated keys and existing values. 292 """ 293 294 new_outputs = {} 295 for key, val in output_dict.items(): 296 key = self._prefix_key(key, output_name) 297 new_outputs[key] = val 298 return new_outputs 299 300 def _prefix_key(self, key, output_name): 301 if key.find(output_name) != 0: 302 key = output_name + self._SEPARATOR_CHAR + key 303 return key 304 305 def _wrap_and_check_metrics(self, metrics): 306 """Handle the saving of metrics. 307 308 Metrics is either a tuple of (value, update_op), or a dict of such tuples. 309 Here, we separate out the tuples and create a dict with names to tensors. 310 311 Args: 312 metrics: Dict of metric results keyed by name. 313 The values of the dict can be one of the following: 314 (1) instance of `Metric` class. 315 (2) (metric_value, update_op) tuples, or a single tuple. 316 metric_value must be a Tensor, and update_op must be a Tensor or Op. 317 318 Returns: 319 dict of output_names to tensors 320 321 Raises: 322 ValueError: if the dict key is not a string, or the metric values or ops 323 are not tensors. 324 """ 325 if not isinstance(metrics, dict): 326 metrics = {self.METRICS_NAME: metrics} 327 328 outputs = {} 329 for key, value in metrics.items(): 330 if isinstance(value, tuple): 331 metric_val, metric_op = value 332 else: # value is a keras.Metrics object 333 metric_val = value.result() 334 assert len(value.updates) == 1 # We expect only one update op. 335 metric_op = value.updates[0] 336 key = self._check_output_key(key, self.METRICS_NAME) 337 key = self._prefix_key(key, self.METRICS_NAME) 338 339 val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX 340 op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX 341 if not isinstance(metric_val, ops.Tensor): 342 raise ValueError( 343 '{} output value must be a Tensor; got {}.'.format( 344 key, metric_val)) 345 if (not isinstance(metric_op, ops.Tensor) and 346 not isinstance(metric_op, ops.Operation)): 347 raise ValueError( 348 '{} update_op must be a Tensor or Operation; got {}.'.format( 349 key, metric_op)) 350 351 # We must wrap any ops in a Tensor before export, as the SignatureDef 352 # proto expects tensors only. See b/109740581 353 metric_op_tensor = metric_op 354 if isinstance(metric_op, ops.Operation): 355 with ops.control_dependencies([metric_op]): 356 metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') 357 358 outputs[val_name] = metric_val 359 outputs[op_name] = metric_op_tensor 360 361 return outputs 362 363 @property 364 def loss(self): 365 return self._loss 366 367 @property 368 def predictions(self): 369 return self._predictions 370 371 @property 372 def metrics(self): 373 return self._metrics 374 375 @abc.abstractmethod 376 def _get_signature_def_fn(self): 377 """Returns a function that produces a SignatureDef given desired outputs.""" 378 pass 379 380 def as_signature_def(self, receiver_tensors): 381 signature_def_fn = self._get_signature_def_fn() 382 return signature_def_fn( 383 receiver_tensors, self.loss, self.predictions, self.metrics) 384 385 386class TrainOutput(_SupervisedOutput): 387 """Represents the output of a supervised training process. 388 389 This class generates the appropriate signature def for exporting 390 training output by type-checking and wrapping loss, predictions, and metrics 391 values. 392 """ 393 394 def _get_signature_def_fn(self): 395 return signature_def_utils.supervised_train_signature_def 396 397 398class EvalOutput(_SupervisedOutput): 399 """Represents the output of a supervised eval process. 400 401 This class generates the appropriate signature def for exporting 402 eval output by type-checking and wrapping loss, predictions, and metrics 403 values. 404 """ 405 406 def _get_signature_def_fn(self): 407 return signature_def_utils.supervised_eval_signature_def 408