# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for export."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib


class ExportOutputTest(test.TestCase):

  def test_regress_value_must_be_float(self):
    with context.graph_mode():
      value = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')
      with self.assertRaisesRegexp(
          ValueError, 'Regression output value must be a float32 Tensor'):
        export_output_lib.RegressionOutput(value)

  def test_classify_classes_must_be_strings(self):
    with context.graph_mode():
      classes = array_ops.placeholder(dtypes.float32, 1, name='output-tensor-1')
      with self.assertRaisesRegexp(
          ValueError, 'Classification classes must be a string Tensor'):
        export_output_lib.ClassificationOutput(classes=classes)

  def test_classify_scores_must_be_float(self):
    with context.graph_mode():
      scores = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')
      with self.assertRaisesRegexp(
          ValueError, 'Classification scores must be a float32 Tensor'):
        export_output_lib.ClassificationOutput(scores=scores)

  def test_classify_requires_classes_or_scores(self):
    with self.assertRaisesRegexp(
        ValueError, 'At least one of scores and classes must be set.'):
      export_output_lib.ClassificationOutput()

  def test_build_standardized_signature_def_regression(self):
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }
      value = array_ops.placeholder(dtypes.float32, 1, name='output-tensor-1')

      export_output = export_output_lib.RegressionOutput(value)
      actual_signature_def = export_output.as_signature_def(input_tensors)

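      # Note: the expected TensorInfo names end in ':0' because they refer to
      # the first output of the placeholder ops created above.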
      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_float = types_pb2.DataType.Value('DT_FLOAT')
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.REGRESS_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.REGRESS_OUTPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-1:0',
                                        dtype=dtype_float,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.REGRESS_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_build_standardized_signature_def_classify_classes_only(self):
    """Tests classification with one output tensor."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }
      classes = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')

      export_output = export_output_lib.ClassificationOutput(classes=classes)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_build_standardized_signature_def_classify_both(self):
    """Tests multiple output tensors that include classes and scores."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }
      classes = array_ops.placeholder(dtypes.string, 1,
                                      name='output-tensor-classes')
      scores = array_ops.placeholder(dtypes.float32, 1,
                                     name='output-tensor-scores')

      export_output = export_output_lib.ClassificationOutput(
          scores=scores, classes=classes)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_float = types_pb2.DataType.Value('DT_FLOAT')
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-classes:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-scores:0',
                                        dtype=dtype_float,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_build_standardized_signature_def_classify_scores_only(self):
    """Tests classification without a classes tensor."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }

      scores = array_ops.placeholder(dtypes.float32, 1,
                                     name='output-tensor-scores')

      export_output = export_output_lib.ClassificationOutput(
          scores=scores)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_float = types_pb2.DataType.Value('DT_FLOAT')
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-scores:0',
                                        dtype=dtype_float,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_predict_outputs_valid(self):
    """Tests that no errors are raised when provided outputs are valid."""
    outputs = {
        'output0': constant_op.constant([0]),
        u'output1': constant_op.constant(['foo']),
    }
    export_output_lib.PredictOutput(outputs)

    # Single Tensor is OK too
    export_output_lib.PredictOutput(constant_op.constant([0]))

  def test_predict_outputs_invalid(self):
    with self.assertRaisesRegexp(
        ValueError,
        'Prediction output key must be a string'):
      export_output_lib.PredictOutput({1: constant_op.constant([0])})

    with self.assertRaisesRegexp(
        ValueError,
        'Prediction output value must be a Tensor'):
      export_output_lib.PredictOutput({
          'prediction1': sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      })


class MockSupervisedOutput(export_output_lib._SupervisedOutput):
  """So that we can test the abstract class methods directly."""

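  # _SupervisedOutput is abstract; stubbing out _get_signature_def_fn
  # (presumably its only abstract hook) makes the class instantiable so the
  # shared loss/predictions/metrics handling can be exercised directly.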
  def _get_signature_def_fn(self):
    pass


class SupervisedOutputTest(test.TestCase):

  def test_supervised_outputs_valid(self):
    """Tests that no errors are raised when provided outputs are valid."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics': metric_obj,
          'metrics2': (constant_op.constant([0]), constant_op.constant([10]))
      }

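      # Metrics may be given either as a Metric object or as a
      # (value_tensor, update_op) pair; exported keys carry group prefixes
      # such as 'loss/' and 'predictions/', and metric keys gain '/value'
      # and '/update_op' suffixes.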
      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(outputter.loss['loss/my_loss'], loss['my_loss'])
      self.assertEqual(
          outputter.predictions['predictions/output1'], predictions['output1'])
      self.assertEqual(outputter.metrics['metrics/update_op'].name,
                       'metric_op_wrapper:0')
      self.assertEqual(
          outputter.metrics['metrics2/update_op'], metrics['metrics2'][1])

      # Single Tensor is OK too
      outputter = MockSupervisedOutput(
          loss['my_loss'], predictions['output1'], metrics['metrics'])
      self.assertEqual(outputter.loss, {'loss': loss['my_loss']})
      self.assertEqual(
          outputter.predictions, {'predictions': predictions['output1']})
      self.assertEqual(outputter.metrics['metrics/update_op'].name,
                       'metric_op_wrapper_1:0')

  def test_supervised_outputs_none(self):
    outputter = MockSupervisedOutput(
        constant_op.constant([0]), None, None)
    self.assertEqual(len(outputter.loss), 1)
    self.assertEqual(outputter.predictions, None)
    self.assertEqual(outputter.metrics, None)

  def test_supervised_outputs_invalid(self):
    with self.assertRaisesRegexp(ValueError, 'predictions output value must'):
      MockSupervisedOutput(constant_op.constant([0]), [3], None)
    with self.assertRaisesRegexp(ValueError, 'loss output value must'):
      MockSupervisedOutput('str', None, None)
    with self.assertRaisesRegexp(ValueError, 'metrics output value must'):
      MockSupervisedOutput(None, None, (15.3, 4))
    with self.assertRaisesRegexp(ValueError, 'loss output key must'):
      MockSupervisedOutput({25: 'Tensor'}, None, None)

  def test_supervised_outputs_tuples(self):
    """Tests that tuple keys are joined with '/' in the exported names."""
    with context.graph_mode():
      loss = {('my', 'loss'): constant_op.constant([0])}
      predictions = {(u'output1', '2'): constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          ('metrics', '1'):
              metric_obj,
          ('metrics', '2'): (constant_op.constant([0]),
                             constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(set(outputter.loss.keys()), set(['loss/my/loss']))
      self.assertEqual(set(outputter.predictions.keys()),
                       set(['predictions/output1/2']))
      self.assertEqual(
          set(outputter.metrics.keys()),
          set([
              'metrics/1/value', 'metrics/1/update_op', 'metrics/2/value',
              'metrics/2/update_op'
          ]))

  def test_supervised_outputs_no_prepend(self):
    """Tests that group names are not prepended when already present."""
    with context.graph_mode():
      loss = {'loss': constant_op.constant([0])}
      predictions = {u'predictions': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics_1': metric_obj,
          'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(set(outputter.loss.keys()), set(['loss']))
      self.assertEqual(set(outputter.predictions.keys()), set(['predictions']))
      self.assertEqual(
          set(outputter.metrics.keys()),
          set([
              'metrics_1/value', 'metrics_1/update_op', 'metrics_2/update_op',
              'metrics_2/value'
          ]))

  def test_train_signature_def(self):
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics_1': metric_obj,
          'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = export_output_lib.TrainOutput(loss, predictions, metrics)

      receiver = {u'features': constant_op.constant(100, shape=(100, 2)),
                  'labels': constant_op.constant(100, shape=(100, 1))}
      sig_def = outputter.as_signature_def(receiver)

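      # Receiver tensors become the signature inputs; the prefixed loss,
      # metric, and prediction keys become the signature outputs.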
      self.assertTrue('loss/my_loss' in sig_def.outputs)
      self.assertTrue('metrics_1/value' in sig_def.outputs)
      self.assertTrue('metrics_2/value' in sig_def.outputs)
      self.assertTrue('predictions/output1' in sig_def.outputs)
      self.assertTrue('features' in sig_def.inputs)

  def test_eval_signature_def(self):
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}

      outputter = export_output_lib.EvalOutput(loss, predictions, None)

      receiver = {u'features': constant_op.constant(100, shape=(100, 2)),
                  'labels': constant_op.constant(100, shape=(100, 1))}
      sig_def = outputter.as_signature_def(receiver)

      self.assertTrue('loss/my_loss' in sig_def.outputs)
      self.assertFalse('metrics/value' in sig_def.outputs)
      self.assertTrue('predictions/output1' in sig_def.outputs)
      self.assertTrue('features' in sig_def.inputs)

  def test_metric_op_is_tensor(self):
    """Tests that ops.Operation is wrapped by a tensor for metric_ops."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics_1': metric_obj,
          'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op())
      }

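      # control_flow_ops.no_op() returns an Operation, not a Tensor; update
      # ops are expected to be wrapped in a 'metric_op_wrapper' tensor so
      # they can be referenced from a SignatureDef.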
      outputter = MockSupervisedOutput(loss, predictions, metrics)

      self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith(
          'metric_op_wrapper'))
      self.assertTrue(
          isinstance(outputter.metrics['metrics_1/update_op'], ops.Tensor))
      self.assertTrue(
          isinstance(outputter.metrics['metrics_1/value'], ops.Tensor))

      self.assertEqual(outputter.metrics['metrics_2/value'],
                       metrics['metrics_2'][0])
      self.assertTrue(outputter.metrics['metrics_2/update_op'].name.startswith(
          'metric_op_wrapper'))
      self.assertTrue(
          isinstance(outputter.metrics['metrics_2/update_op'], ops.Tensor))


if __name__ == '__main__':
  test.main()