• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""ModelFnOps tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.contrib.learn.python.learn.estimators import constants
24from tensorflow.contrib.learn.python.learn.estimators import model_fn
25from tensorflow.python.client import session
26from tensorflow.python.estimator.export import export_output as core_export_lib
27from tensorflow.python.framework import constant_op
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.platform import test
30from tensorflow.python.saved_model import signature_constants
31from tensorflow.python.training import basic_session_run_hooks
32from tensorflow.python.training import monitored_session
33
34
class ModelFnopsTest(test.TestCase):
  """Tests for `ModelFnOps.estimator_spec()`, including export outputs.

  Each test builds a `ModelFnOps`, converts it with `estimator_spec()`, checks
  that the shared fields were carried over, and then inspects the entries of
  `export_outputs` produced for the configured output alternatives.
  """

  def create_predictions(self):
    """Returns a predictions dict with probabilities, scores and classes."""
    return {
        "probabilities": constant_op.constant([1., 1., 1.]),
        "scores": constant_op.constant([1., 2., 3.]),
        "classes": constant_op.constant([b"0", b"1", b"2"]),
    }

  def create_model_fn_ops(self, predictions, output_alternatives,
                          mode=model_fn.ModeKeys.INFER):
    """Builds a `ModelFnOps` with every field populated.

    Args:
      predictions: Dict of prediction tensors stored on the result.
      output_alternatives: Value forwarded as `output_alternatives`. The tests
        in this class pass either `None` or a dict mapping a head name to a
        `(problem_type, predictions_dict)` tuple.
      mode: Mode forwarded to `ModelFnOps`; defaults to `ModeKeys.INFER`.

    Returns:
      A `model_fn.ModelFnOps` carrying a constant loss, no-op train op, two
      eval metrics ("metric_key" and "loss"), step-counter hooks and a fresh
      scaffold.
    """
    return model_fn.ModelFnOps(
        # Forward the caller's mode; it used to be ignored in favour of a
        # hard-coded INFER.
        mode,
        predictions=predictions,
        loss=constant_op.constant([1]),
        train_op=control_flow_ops.no_op(),
        eval_metric_ops={
            "metric_key": (constant_op.constant(1.), control_flow_ops.no_op()),
            "loss": (constant_op.constant(1.), control_flow_ops.no_op()),
        },
        training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
        training_hooks=[basic_session_run_hooks.StepCounterHook()],
        output_alternatives=output_alternatives,
        scaffold=monitored_session.Scaffold())

  def assertEquals_except_export_and_eval_loss(
      self, model_fn_ops, estimator_spec):
    """Asserts the spec mirrors `model_fn_ops`, ignoring exports and eval loss.

    Args:
      model_fn_ops: The source `ModelFnOps`.
      estimator_spec: The object returned by `model_fn_ops.estimator_spec()`.
    """
    # The spec's eval metrics are expected to be the original metrics minus
    # the entry stored under the "loss" key.
    expected_eval_metric_ops = {
        key: value
        for key, value in six.iteritems(model_fn_ops.eval_metric_ops)
        if key != "loss"
    }
    self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
    self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
    self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
    self.assertEqual(expected_eval_metric_ops,
                     estimator_spec.eval_metric_ops)
    self.assertAllEqual(model_fn_ops.training_chief_hooks,
                        estimator_spec.training_chief_hooks)
    self.assertAllEqual(model_fn_ops.training_hooks,
                        estimator_spec.training_hooks)
    self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)

  def testEstimatorSpec_except_export(self):
    """Without output alternatives, the shared fields still carry over."""
    predictions = self.create_predictions()
    model_fn_ops = self.create_model_fn_ops(
        predictions, None, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

  def testEstimatorSpec_export_regression_with_scores(self):
    """A regression head exports `scores` as the regression value."""
    predictions = self.create_predictions()
    output_alternatives = {"regression_head": (
        constants.ProblemType.LINEAR_REGRESSION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      regression_output = estimator_spec.export_outputs["regression_head"]
      self.assertIsInstance(
          regression_output, core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          regression_output.value.eval())

  def testEstimatorSpec_export_regression_with_probabilities(self):
    """With `scores` absent, the regression value is `probabilities`."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["scores"]
    output_alternatives = {"regression_head": (
        constants.ProblemType.LINEAR_REGRESSION,
        output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      regression_output = estimator_spec.export_outputs["regression_head"]
      self.assertIsInstance(
          regression_output, core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["probabilities"].eval(),
                          regression_output.value.eval())

  def testEstimatorSpec_export_classification(self):
    """A classification head exports both `scores` and `classes`."""
    predictions = self.create_predictions()
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          classification_output.scores.eval())
      self.assertAllEqual(predictions["classes"].eval(),
                          classification_output.classes.eval())

  def testEstimatorSpec_export_classification_with_missing_scores(self):
    """With `scores` absent, classification scores are `probabilities`."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["scores"]
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["probabilities"].eval(),
                          classification_output.scores.eval())
      self.assertAllEqual(predictions["classes"].eval(),
                          classification_output.classes.eval())

  def testEstimatorSpec_export_classification_with_missing_scores_proba(self):
    """With `scores` and `probabilities` absent, exported scores are None."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["scores"]
    del output_alternatives_predictions["probabilities"]
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertIsNone(classification_output.scores)
      self.assertAllEqual(predictions["classes"].eval(),
                          classification_output.classes.eval())

  def testEstimatorSpec_export_classification_with_missing_classes(self):
    """With `classes` absent, exported classes are None."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    del output_alternatives_predictions["classes"]
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          classification_output.scores.eval())
      self.assertIsNone(classification_output.classes)

  def testEstimatorSpec_export_classification_with_nonstring_classes(self):
    """Integer-valued `classes` are not exported (classes is None)."""
    predictions = self.create_predictions()
    output_alternatives_predictions = predictions.copy()
    output_alternatives_predictions["classes"] = constant_op.constant(
        [1, 2, 3])
    output_alternatives = {"classification_head": (
        constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      classification_output = estimator_spec.export_outputs[
          "classification_head"]
      self.assertIsInstance(classification_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          classification_output.scores.eval())
      self.assertIsNone(classification_output.classes)

  def testEstimatorSpec_export_logistic(self):
    """A logistic-regression head exports a `ClassificationOutput`."""
    predictions = self.create_predictions()
    output_alternatives = {"logistic_head": (
        constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      logistic_output = estimator_spec.export_outputs["logistic_head"]
      self.assertIsInstance(logistic_output,
                            core_export_lib.ClassificationOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          logistic_output.scores.eval())
      self.assertAllEqual(predictions["classes"].eval(),
                          logistic_output.classes.eval())

  def testEstimatorSpec_export_unspecified(self):
    """An unspecified problem type exports the raw predictions dict."""
    predictions = self.create_predictions()
    output_alternatives = {"unspecified_head": (
        constants.ProblemType.UNSPECIFIED, predictions)}

    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    estimator_spec = model_fn_ops.estimator_spec()
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      unspecified_output = estimator_spec.export_outputs["unspecified_head"]
      self.assertIsInstance(unspecified_output,
                            core_export_lib.PredictOutput)
      self.assertEqual(predictions, unspecified_output.outputs)

  def testEstimatorSpec_export_multihead(self):
    """With two heads, the named head is also exported as the default."""
    predictions = self.create_predictions()
    output_alternatives = {
        "regression_head": (
            constants.ProblemType.LINEAR_REGRESSION, predictions),
        "classification_head": (
            constants.ProblemType.CLASSIFICATION, predictions)}
    model_fn_ops = self.create_model_fn_ops(
        predictions, output_alternatives, mode=model_fn.ModeKeys.INFER)

    # "regression_head" is passed to estimator_spec(); the assertions below
    # check it appears under both its own name and the default serving key.
    estimator_spec = model_fn_ops.estimator_spec("regression_head")
    self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

    with session.Session():
      regression_output = estimator_spec.export_outputs["regression_head"]
      self.assertIsInstance(
          regression_output, core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          regression_output.value.eval())

      default_output = estimator_spec.export_outputs[
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
      self.assertIsInstance(default_output,
                            core_export_lib.RegressionOutput)
      self.assertAllEqual(predictions["scores"].eval(),
                          default_output.value.eval())
295
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
298