1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Classes and methods related to model_fn (deprecated). 17 18This module and all its submodules are deprecated. See 19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 20for migration instructions. 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import collections 28 29import six 30 31from tensorflow.contrib.framework import get_graph_from_inputs 32from tensorflow.contrib.learn.python.learn.estimators import constants 33from tensorflow.contrib.learn.python.learn.estimators import metric_key 34from tensorflow.contrib.learn.python.learn.estimators import prediction_key 35from tensorflow.python.estimator import model_fn as core_model_fn_lib 36from tensorflow.python.estimator.export import export_output as core_export_lib 37from tensorflow.python.framework import dtypes 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import sparse_tensor 40from tensorflow.python.framework import tensor_shape 41from tensorflow.python.ops import array_ops 42from tensorflow.python.platform import tf_logging as logging 43from tensorflow.python.saved_model import signature_constants 44from tensorflow.python.training import session_run_hook 
from tensorflow.python.util.deprecation import deprecated


class ModeKeys(object):
  """Standard names for model modes (deprecated).

  THIS CLASS IS DEPRECATED.

  The following standard keys are defined:

  * `TRAIN`: training mode.
  * `EVAL`: evaluation mode.
  * `INFER`: inference mode.
  """

  TRAIN = 'train'
  EVAL = 'eval'
  INFER = 'infer'

  @classmethod
  def validate(cls, key):
    """Checks that `key` is one of the standard mode keys.

    Args:
      key: The mode value to check.

    Raises:
      ValueError: If `key` is not `TRAIN`, `EVAL` or `INFER`.
    """
    if key not in (cls.TRAIN, cls.EVAL, cls.INFER):
      raise ValueError('Invalid mode %s.' % key)


class ModelFnOps(
    collections.namedtuple('ModelFnOps', [
        'predictions', 'loss', 'train_op', 'eval_metric_ops',
        'output_alternatives', 'training_chief_hooks', 'training_hooks',
        'scaffold', 'mode'
    ])):
  """Ops returned from a model_fn.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecated(None, 'When switching to tf.estimator.Estimator, use '
              'tf.estimator.EstimatorSpec. You can use the `estimator_spec`'
              ' method to create an equivalent one.')
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              output_alternatives=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None):
    """Creates a validated `ModelFnOps` instance.

    For a multi-headed model, the predictions dict here will contain the outputs
    of all of the heads. However: at serving time, requests will be made
    specifically for one or more heads, and the RPCs used for these requests may
    differ by problem type (i.e., regression, classification, other). The
    purpose of the output_alternatives dict is to aid in exporting a SavedModel
    from which such head-specific queries can be served. These
    output_alternatives will be combined with input_alternatives (see
    `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying
    the valid requests that can be served from this model.

    For a single-headed model, it is still advisable to provide
    output_alternatives with a single entry, because this is how the problem
    type is communicated for export and serving. If output_alternatives is not
    given, the resulting SavedModel will support only one head of unspecified
    type.

    Args:
      mode: One of `ModeKeys`. Specifies if this is training, evaluation or
        prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, such as `Tensor`.
      output_alternatives: a dict of
        `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where
        `submodel_name` is a submodel identifier that should be consistent
        across the pipeline (here likely taken from the name of each `Head`,
        for models that use them), `problem_type` is a `ProblemType`,
        `tensor_name` is a symbolic name for an output Tensor possibly but not
        necessarily taken from `PredictionKey`, and `Tensor` is the
        corresponding output Tensor itself.
      training_chief_hooks: A list of `SessionRunHook` objects that will be
        run on the chief worker during training.
      training_hooks: A list of `SessionRunHook` objects that will be run on
        all workers during training.
      scaffold: A `tf.train.Scaffold` object that can be used to set
        initialization, saver, and more to be used in training.

    Returns:
      A validated `ModelFnOps` object.

    Raises:
      ValueError: If validation fails.
      TypeError: If any hook is not a `SessionRunHook`.
    """
    ModeKeys.validate(mode)

    # Assert all ops are from the same graph. A dict passed as a single item
    # is presumably not a graph element itself (NOTE(review): confirm against
    # get_graph_from_inputs), so unpack dict-valued predictions and check each
    # prediction tensor individually.
    if isinstance(predictions, dict):
      graph_inputs = list(predictions.values())
    else:
      graph_inputs = [predictions]
    get_graph_from_inputs(graph_inputs + [loss, train_op])

    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    elif not isinstance(train_op, ops.Operation):
      # TODO(ptucker): Should this be allowed? Consider raising error.
      train_op = ops.convert_to_tensor(train_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      # Accept unknown shapes and any shape holding exactly one element; the
      # latter is reshaped to a true scalar below.
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode in (ModeKeys.INFER, ModeKeys.EVAL):
        raise ValueError('Missing predictions.')
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: sparse_tensor.convert_to_tensor_or_sparse_tensor(v)
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = sparse_tensor.convert_to_tensor_or_sparse_tensor(
            predictions)

    # Validate eval_metric_ops
    if eval_metric_ops is None:
      eval_metric_ops = {}
    else:
      if not isinstance(eval_metric_ops, dict):
        raise ValueError('eval_metric_ops must be a dict.')

    # Validate hooks
    if training_chief_hooks is None:
      training_chief_hooks = []
    if training_hooks is None:
      training_hooks = []
    for hook in training_hooks + training_chief_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError('All hooks returned from model_fn must be '
                        'SessionRunHook instances, got instance of %s: %s' %
                        (type(hook), hook))

    return super(ModelFnOps, cls).__new__(
        cls,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        output_alternatives=output_alternatives,
        training_chief_hooks=training_chief_hooks,
        training_hooks=training_hooks,
        scaffold=scaffold,
        mode=mode)

  def estimator_spec(self, default_serving_output_alternative_key=None):
    """Creates an equivalent `EstimatorSpec`.

    Args:
      default_serving_output_alternative_key: Required for multiple heads. If
        you have multiple entries in `output_alternatives` dict (comparable to
        multiple heads), `EstimatorSpec` requires a default head that will be
        used if a Servo request does not explicitly mention which head to infer
        on. Pass the key of the output alternative here that you want to
        designate as default. A separate ExportOutput for this default head
        will be added to the export_outputs dict with the special key
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
        already an entry in output_alternatives with this special key.

    Returns:
      Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`

    Raises:
      ValueError: If problem type is unknown; if a default serving head is
        needed but `default_serving_output_alternative_key` is not a key of
        `output_alternatives`; or if `self.mode` is not a valid `ModeKeys`
        value.
    """
    def _scores(output_tensors):
      """Returns the SCORES tensor, falling back to PROBABILITIES."""
      scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
      if scores is None:
        scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
      return scores

    def _classes(output_tensors):  # pylint: disable=missing-docstring
      classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
      if classes is None:
        logging.warning(
            'classes is None, Servo inference will not have class ids.')
        return None
      elif classes.dtype != dtypes.string:
        # Servo classification can only serve string classes
        logging.warning(
            'classes is not string, Servo inference will not have class ids.')
        return None

      return classes

    def _export_output(problem_type, predictions):  # pylint: disable=missing-docstring
      if problem_type == constants.ProblemType.LINEAR_REGRESSION:
        return core_export_lib.RegressionOutput(_scores(predictions))

      if (problem_type == constants.ProblemType.CLASSIFICATION or
          problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
        return core_export_lib.ClassificationOutput(
            scores=_scores(predictions), classes=_classes(predictions))

      if problem_type == constants.ProblemType.UNSPECIFIED:
        return core_export_lib.PredictOutput(predictions)

      raise ValueError('Unknown problem_type=%s' % problem_type)

    # Converts output_alternatives
    export_outputs_dict = None
    if self.output_alternatives:
      output_alternatives = self.output_alternatives
      # Adds default output_alternative if needed.
      if (len(output_alternatives) > 1 and
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
          output_alternatives):
        if default_serving_output_alternative_key not in output_alternatives:
          # Fail with an actionable message instead of a bare KeyError (e.g.
          # `KeyError: None` when the argument was simply omitted).
          raise ValueError(
              'default_serving_output_alternative_key is required when '
              'output_alternatives has multiple entries and must be one of '
              '%s; got %r.' % (list(output_alternatives.keys()),
                               default_serving_output_alternative_key))
        # Copy so the default entry is not added to self.output_alternatives.
        output_alternatives = output_alternatives.copy()
        output_alternatives[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                output_alternatives[default_serving_output_alternative_key])
      export_outputs_dict = {key: _export_output(*val) for key, val in
                             output_alternatives.items()}

    def _get_eval_metric_ops():
      """Returns self.eval_metric_ops without loss metric."""
      return {
          key: value
          for key, value in six.iteritems(self.eval_metric_ops)
          if key != metric_key.MetricKey.LOSS
      }

    # Convert the contrib mode enum to the core mode enum. `__new__` validates
    # `mode`, but instances created through namedtuple helpers such as
    # `_replace` can bypass it, so validate again rather than risk an unbound
    # `core_mode`.
    ModeKeys.validate(self.mode)
    core_mode = {
        ModeKeys.TRAIN: core_model_fn_lib.ModeKeys.TRAIN,
        ModeKeys.EVAL: core_model_fn_lib.ModeKeys.EVAL,
        ModeKeys.INFER: core_model_fn_lib.ModeKeys.PREDICT,
    }[self.mode]

    return core_model_fn_lib.EstimatorSpec(
        mode=core_mode,
        predictions=self.predictions,
        loss=self.loss,
        train_op=self.train_op,
        eval_metric_ops=_get_eval_metric_ops(),
        export_outputs=export_outputs_dict,
        training_chief_hooks=self.training_chief_hooks,
        training_hooks=self.training_hooks,
        scaffold=self.scaffold)