# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelFnOps tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.client import session
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session


class ModelFnopsTest(test.TestCase):
  """Tests converting `ModelFnOps` (with output alternatives) to a spec."""

  def create_predictions(self):
    """Returns a predictions dict with three constant tensors.

    Returns:
      A dict with keys "probabilities", "scores" (float constants) and
      "classes" (a bytes-string constant), each of length 3.
    """
    probabilities = constant_op.constant([1., 1., 1.])
    scores = constant_op.constant([1., 2., 3.])
    classes = constant_op.constant([b"0", b"1", b"2"])
    return {
        "probabilities": probabilities,
        "scores": scores,
        "classes": classes}

  def create_model_fn_ops(self, predictions, output_alternatives,
                          mode=model_fn.ModeKeys.INFER):
    """Builds a fully populated `ModelFnOps` for the tests below.

    Args:
      predictions: Predictions dict forwarded to `ModelFnOps`.
      output_alternatives: Output alternatives dict (or `None`) forwarded to
        `ModelFnOps`.
      mode: The `ModeKeys` value to construct the `ModelFnOps` with.

    Returns:
      A `ModelFnOps` with constant loss, no-op train op, two eval metrics
      ("metric_key" and "loss"), step-counter hooks and a fresh scaffold.
    """
    return model_fn.ModelFnOps(
        # Forward the caller's mode; it used to be ignored in favour of a
        # hard-coded INFER.
        mode,
        predictions=predictions,
        loss=constant_op.constant([1]),
        train_op=control_flow_ops.no_op(),
        eval_metric_ops={
            "metric_key": (constant_op.constant(1.), control_flow_ops.no_op()),
            "loss": (constant_op.constant(1.), control_flow_ops.no_op()),
        },
        training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
        training_hooks=[basic_session_run_hooks.StepCounterHook()],
        output_alternatives=output_alternatives,
        scaffold=monitored_session.Scaffold())

  def assertEquals_except_export_and_eval_loss(
      self, model_fn_ops, estimator_spec):
    """Asserts the spec mirrors `model_fn_ops`, ignoring exports and eval loss.

    The eval metric keyed "loss" is excluded from the expected metrics; export
    outputs are not compared here.

    Args:
      model_fn_ops: The source `ModelFnOps`.
      estimator_spec: The object returned by `model_fn_ops.estimator_spec()`.
    """
    expected_eval_metric_ops = {
        key: value
        for key, value in six.iteritems(model_fn_ops.eval_metric_ops)
        if key != "loss"}
    self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
    self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
    self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
    self.assertEqual(expected_eval_metric_ops,
                     estimator_spec.eval_metric_ops)
    self.assertAllEqual(model_fn_ops.training_chief_hooks,
                        estimator_spec.training_chief_hooks)
    self.assertAllEqual(model_fn_ops.training_hooks,
                        estimator_spec.training_hooks)
    self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)

  def testEstimatorSpec_except_export(self):
    """Non-export fields carry over when no output alternatives are given."""
    predictions = self.create_predictions()
    model_fn_ops = self.create_model_fn_ops(
        predictions, None, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

  def testEstimatorSpec_export_regression_with_scores(self):
    """A regression head exports "scores" as the regression value."""
    predictions = self.create_predictions()
    output_alternatives = {"regression_head": (
        constants.ProblemType.LINEAR_REGRESSION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      regression_output = estimator_spec.export_outputs["regression_head"]
      self.assertIsInstance(
          regression_output, core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          regression_output.value.eval())

  def testEstimatorSpec_export_regression_with_probabilities(self):
    """Without "scores", a regression head exports "probabilities"."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["scores"]
    output_alternatives = {"regression_head": (
        constants.ProblemType.LINEAR_REGRESSION,
        output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      regression_output = estimator_spec.export_outputs["regression_head"]
      self.assertIsInstance(
          regression_output, core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["probabilities"].eval(),
                          regression_output.value.eval())

  def testEstimatorSpec_export_classification(self):
    """A classification head exports both "scores" and "classes"."""
    predictions = self.create_predictions()
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          classification_output.scores.eval())
      self.assertAllEqual(predictions["classes"].eval(),
                          classification_output.classes.eval())

  def testEstimatorSpec_export_classification_with_missing_scores(self):
    """Without "scores", classification scores come from "probabilities"."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["scores"]
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["probabilities"].eval(),
                          classification_output.scores.eval())
      self.assertAllEqual(predictions["classes"].eval(),
                          classification_output.classes.eval())

  def testEstimatorSpec_export_classification_with_missing_scores_proba(self):
    """With neither "scores" nor "probabilities", exported scores are None."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["scores"]
    del output_alternatives_predictions["probabilities"]
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertIsNone(classification_output.scores)
      self.assertAllEqual(predictions["classes"].eval(),
                          classification_output.classes.eval())

  def testEstimatorSpec_export_classification_with_missing_classes(self):
    """Without "classes", exported classes are None but scores remain."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["classes"]
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          classification_output.scores.eval())
      self.assertIsNone(classification_output.classes)

  def testEstimatorSpec_export_classification_with_nonstring_classes(self):
    """Integer-valued "classes" are not exported (classes is None)."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    output_alternatives_predictions["classes"] = constant_op.constant(
        [1, 2, 3])
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          classification_output.scores.eval())
      self.assertIsNone(classification_output.classes)

  def testEstimatorSpec_export_logistic(self):
    """A logistic-regression head exports a `ClassificationOutput`."""
    predictions = self.create_predictions()
    output_alternatives = {"logistic_head": (
        constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      logistic_output = estimator_spec.export_outputs["logistic_head"]
      self.assertIsInstance(logistic_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          logistic_output.scores.eval())
      self.assertAllEqual(predictions["classes"].eval(),
                          logistic_output.classes.eval())

  def testEstimatorSpec_export_unspecified(self):
    """An UNSPECIFIED head exports the predictions as a `PredictOutput`."""
    predictions = self.create_predictions()
    output_alternatives = {"unspecified_head": (
        constants.ProblemType.UNSPECIFIED, predictions)}

    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      unspecified_output = estimator_spec.export_outputs["unspecified_head"]
      self.assertIsInstance(unspecified_output,
                            core_export_lib.PredictOutput)
      self.assertEqual(predictions, unspecified_output.outputs)

  def testEstimatorSpec_export_multihead(self):
    """The head named in `estimator_spec(...)` is also the default export."""
    predictions = self.create_predictions()
    output_alternatives = {
        "regression_head": (
            constants.ProblemType.LINEAR_REGRESSION, predictions),
        "classification_head": (
            constants.ProblemType.CLASSIFICATION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec("regression_head")
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      regression_output = estimator_spec.export_outputs["regression_head"]
      self.assertIsInstance(
          regression_output, core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          regression_output.value.eval())

      # The same regression output must be reachable under the default
      # serving signature key.
      default_output = estimator_spec.export_outputs[
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
      self.assertIsInstance(default_output,
                            core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          default_output.value.eval())


if __name__ == "__main__":
  test.main()