• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for export."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import tensor_shape_pb2
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.python.eager import context
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import metrics as metrics_module
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import test
34from tensorflow.python.saved_model import signature_constants
35from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
36
37
class ExportOutputTest(test.TestCase):
  """Tests for `RegressionOutput`, `ClassificationOutput` and `PredictOutput`."""

  @staticmethod
  def _string_input_tensors():
    """Returns the receiver-input dict shared by the signature-def tests.

    The dict maps 'input-1' to a string placeholder of shape [1] named
    'input-tensor-1'. Call it inside `context.graph_mode()`, as every
    placeholder in this class is created there.
    """
    return {
        'input-1':
            array_ops.placeholder(dtypes.string, 1, name='input-tensor-1')
    }

  @staticmethod
  def _expected_tensor_info(name, dtype_name):
    """Builds the TensorInfo expected for a shape-[1] placeholder.

    Args:
      name: Full tensor name, including the output index (e.g. 'x:0').
      dtype_name: Name of a `types_pb2.DataType` value, e.g. 'DT_FLOAT'.

    Returns:
      A `meta_graph_pb2.TensorInfo` with the given name and dtype and a
      single dimension of size 1.
    """
    return meta_graph_pb2.TensorInfo(
        name=name,
        dtype=types_pb2.DataType.Value(dtype_name),
        tensor_shape=tensor_shape_pb2.TensorShapeProto(
            dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]))

  def test_regress_value_must_be_float(self):
    """RegressionOutput rejects a value tensor that is not float32."""
    with context.graph_mode():
      value = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')
      with self.assertRaisesRegex(
          ValueError, 'Regression output value must be a float32 Tensor'):
        export_output_lib.RegressionOutput(value)

  def test_classify_classes_must_be_strings(self):
    """ClassificationOutput rejects a classes tensor that is not a string."""
    with context.graph_mode():
      classes = array_ops.placeholder(dtypes.float32, 1, name='output-tensor-1')
      with self.assertRaisesRegex(
          ValueError, 'Classification classes must be a string Tensor'):
        export_output_lib.ClassificationOutput(classes=classes)

  def test_classify_scores_must_be_float(self):
    """ClassificationOutput rejects a scores tensor that is not float32."""
    with context.graph_mode():
      scores = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')
      with self.assertRaisesRegex(
          ValueError, 'Classification scores must be a float32 Tensor'):
        export_output_lib.ClassificationOutput(scores=scores)

  def test_classify_requires_classes_or_scores(self):
    """ClassificationOutput needs at least one of classes or scores."""
    with self.assertRaisesRegex(
        ValueError,
        'Cannot create a ClassificationOutput with empty arguments'):
      export_output_lib.ClassificationOutput()

  def test_build_standardized_signature_def_regression(self):
    """Tests the regression signature built from one value tensor."""
    with context.graph_mode():
      input_tensors = self._string_input_tensors()
      value = array_ops.placeholder(dtypes.float32, 1, name='output-tensor-1')

      export_output = export_output_lib.RegressionOutput(value)
      actual = export_output.as_signature_def(input_tensors)

      expected = meta_graph_pb2.SignatureDef()
      expected.inputs[signature_constants.REGRESS_INPUTS].CopyFrom(
          self._expected_tensor_info('input-tensor-1:0', 'DT_STRING'))
      expected.outputs[signature_constants.REGRESS_OUTPUTS].CopyFrom(
          self._expected_tensor_info('output-tensor-1:0', 'DT_FLOAT'))
      expected.method_name = signature_constants.REGRESS_METHOD_NAME
      self.assertEqual(actual, expected)

  def test_build_standardized_signature_def_classify_classes_only(self):
    """Tests classification with one output tensor."""
    with context.graph_mode():
      input_tensors = self._string_input_tensors()
      classes = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')

      export_output = export_output_lib.ClassificationOutput(classes=classes)
      actual = export_output.as_signature_def(input_tensors)

      expected = meta_graph_pb2.SignatureDef()
      expected.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
          self._expected_tensor_info('input-tensor-1:0', 'DT_STRING'))
      expected.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
          self._expected_tensor_info('output-tensor-1:0', 'DT_STRING'))
      expected.method_name = signature_constants.CLASSIFY_METHOD_NAME
      self.assertEqual(actual, expected)

  def test_build_standardized_signature_def_classify_both(self):
    """Tests multiple output tensors that include classes and scores."""
    with context.graph_mode():
      input_tensors = self._string_input_tensors()
      classes = array_ops.placeholder(
          dtypes.string, 1, name='output-tensor-classes')
      scores = array_ops.placeholder(
          dtypes.float32, 1, name='output-tensor-scores')

      export_output = export_output_lib.ClassificationOutput(
          scores=scores, classes=classes)
      actual = export_output.as_signature_def(input_tensors)

      expected = meta_graph_pb2.SignatureDef()
      expected.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
          self._expected_tensor_info('input-tensor-1:0', 'DT_STRING'))
      expected.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
          self._expected_tensor_info('output-tensor-classes:0', 'DT_STRING'))
      expected.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
          self._expected_tensor_info('output-tensor-scores:0', 'DT_FLOAT'))
      expected.method_name = signature_constants.CLASSIFY_METHOD_NAME
      self.assertEqual(actual, expected)

  def test_build_standardized_signature_def_classify_scores_only(self):
    """Tests classification without classes tensor."""
    with context.graph_mode():
      input_tensors = self._string_input_tensors()
      scores = array_ops.placeholder(
          dtypes.float32, 1, name='output-tensor-scores')

      export_output = export_output_lib.ClassificationOutput(scores=scores)
      actual = export_output.as_signature_def(input_tensors)

      expected = meta_graph_pb2.SignatureDef()
      expected.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
          self._expected_tensor_info('input-tensor-1:0', 'DT_STRING'))
      expected.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
          self._expected_tensor_info('output-tensor-scores:0', 'DT_FLOAT'))
      expected.method_name = signature_constants.CLASSIFY_METHOD_NAME
      self.assertEqual(actual, expected)

  def test_predict_outputs_valid(self):
    """Tests that no errors are raised when provided outputs are valid."""
    outputs = {
        'output0': constant_op.constant([0]),
        u'output1': constant_op.constant(['foo']),
    }
    export_output_lib.PredictOutput(outputs)

    # A single Tensor (not wrapped in a dict) is accepted too.
    export_output_lib.PredictOutput(constant_op.constant([0]))

  def test_predict_outputs_invalid(self):
    """PredictOutput rejects non-string keys and non-Tensor values."""
    with self.assertRaisesRegex(ValueError,
                                'Prediction output key must be a string'):
      export_output_lib.PredictOutput({1: constant_op.constant([0])})

    # A SparseTensor is not accepted as a prediction value.
    with self.assertRaisesRegex(ValueError,
                                'Prediction output value must be a Tensor'):
      export_output_lib.PredictOutput({
          'prediction1': sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      })
231
232
class MockSupervisedOutput(export_output_lib._SupervisedOutput):
  """Minimal concrete subclass of `_SupervisedOutput` for direct testing.

  It lets the tests below exercise the base-class constructor and key
  handling without going through `TrainOutput` or `EvalOutput`.
  """

  def _get_signature_def_fn(self):
    """Stub override (presumably required by the base class); yields None."""
    return None
238
239
class SupervisedOutputTest(test.TestCase):
  """Tests for `_SupervisedOutput`, `TrainOutput` and `EvalOutput`."""

  def test_supervised_outputs_valid(self):
    """Valid dict and single-Tensor outputs are accepted and key-prefixed."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
      metrics = {
          'metrics': (mean, update_op),
          'metrics2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(outputter.loss['loss/my_loss'], loss['my_loss'])
      self.assertEqual(
          outputter.predictions['predictions/output1'], predictions['output1'])
      self.assertEqual(outputter.metrics['metrics/update_op'].name,
                       'mean/update_op:0')
      self.assertEqual(
          outputter.metrics['metrics2/update_op'], metrics['metrics2'][1])

      # A bare Tensor (or a single (value, update_op) tuple for metrics) is
      # accepted too and stored under the bare prefix key.
      outputter = MockSupervisedOutput(
          loss['my_loss'], predictions['output1'], metrics['metrics'])
      self.assertEqual(outputter.loss, {'loss': loss['my_loss']})
      self.assertEqual(
          outputter.predictions, {'predictions': predictions['output1']})
      self.assertEqual(outputter.metrics['metrics/update_op'].name,
                       'mean/update_op:0')

  def test_supervised_outputs_none(self):
    """Predictions and metrics may be None; loss becomes a one-entry dict."""
    outputter = MockSupervisedOutput(
        constant_op.constant([0]), None, None)
    self.assertLen(outputter.loss, 1)
    self.assertIsNone(outputter.predictions)
    self.assertIsNone(outputter.metrics)

  def test_supervised_outputs_invalid(self):
    """Non-Tensor values and non-string keys raise ValueError."""
    with self.assertRaisesRegex(ValueError, 'predictions output value must'):
      MockSupervisedOutput(constant_op.constant([0]), [3], None)
    with self.assertRaisesRegex(ValueError, 'loss output value must'):
      MockSupervisedOutput('str', None, None)
    with self.assertRaisesRegex(ValueError, 'metrics output value must'):
      MockSupervisedOutput(None, None, (15.3, 4))
    with self.assertRaisesRegex(ValueError, 'loss output key must'):
      MockSupervisedOutput({25: 'Tensor'}, None, None)

  def test_supervised_outputs_tuples(self):
    """Tuple keys are joined with '/' before the output prefix is added."""
    with context.graph_mode():
      loss = {('my', 'loss'): constant_op.constant([0])}
      predictions = {(u'output1', '2'): constant_op.constant(['foo'])}
      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
      metrics = {
          ('metrics', '1'): (mean, update_op),
          ('metrics', '2'): (constant_op.constant([0]),
                             constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(set(outputter.loss.keys()), {'loss/my/loss'})
      self.assertEqual(set(outputter.predictions.keys()),
                       {'predictions/output1/2'})
      self.assertEqual(
          set(outputter.metrics.keys()),
          {
              'metrics/1/value', 'metrics/1/update_op', 'metrics/2/value',
              'metrics/2/update_op'
          })

  def test_supervised_outputs_no_prepend(self):
    """Keys that already start with the output prefix are not re-prefixed."""
    with context.graph_mode():
      loss = {'loss': constant_op.constant([0])}
      predictions = {u'predictions': constant_op.constant(['foo'])}
      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
      metrics = {
          'metrics_1': (mean, update_op),
          'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(set(outputter.loss.keys()), {'loss'})
      self.assertEqual(set(outputter.predictions.keys()), {'predictions'})
      self.assertEqual(
          set(outputter.metrics.keys()),
          {
              'metrics_1/value', 'metrics_1/update_op', 'metrics_2/update_op',
              'metrics_2/value'
          })

  def test_train_signature_def(self):
    """TrainOutput's SignatureDef exposes loss, metric and prediction outputs."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
      metrics = {
          'metrics_1': (mean, update_op),
          'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = export_output_lib.TrainOutput(loss, predictions, metrics)

      receiver = {u'features': constant_op.constant(100, shape=(100, 2)),
                  'labels': constant_op.constant(100, shape=(100, 1))}
      sig_def = outputter.as_signature_def(receiver)

      self.assertIn('loss/my_loss', sig_def.outputs)
      self.assertIn('metrics_1/value', sig_def.outputs)
      self.assertIn('metrics_2/value', sig_def.outputs)
      self.assertIn('predictions/output1', sig_def.outputs)
      self.assertIn('features', sig_def.inputs)

  def test_eval_signature_def(self):
    """EvalOutput without metrics exposes only loss and prediction outputs."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}

      outputter = export_output_lib.EvalOutput(loss, predictions, None)

      receiver = {u'features': constant_op.constant(100, shape=(100, 2)),
                  'labels': constant_op.constant(100, shape=(100, 1))}
      sig_def = outputter.as_signature_def(receiver)

      self.assertIn('loss/my_loss', sig_def.outputs)
      self.assertNotIn('metrics/value', sig_def.outputs)
      self.assertIn('predictions/output1', sig_def.outputs)
      self.assertIn('features', sig_def.inputs)

  def test_metric_op_is_tensor(self):
    """Tests that non-Tensor update ops (Operation, Variable) become Tensors."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
      metrics = {
          'metrics_1': (mean, update_op),
          'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op()),
          # Keras metric's update_state() could return a Variable, rather than
          # an Operation or Tensor.
          'keras_1': (constant_op.constant([0.5]),
                      variables.Variable(1.0, name='AssignAddVariableOp_3'))
      }

      # If we get here, the constructor accepted every update-op kind
      # (Tensor, Operation and Variable).
      outputter = MockSupervisedOutput(loss, predictions, metrics)

      # A Tensor update op keeps its original name.
      self.assertStartsWith(
          outputter.metrics['metrics_1/update_op'].name, 'mean/update_op')
      self.assertIsInstance(
          outputter.metrics['metrics_1/update_op'], ops.Tensor)
      self.assertIsInstance(outputter.metrics['metrics_1/value'], ops.Tensor)

      # An Operation update op is replaced by a wrapper Tensor.
      self.assertEqual(outputter.metrics['metrics_2/value'],
                       metrics['metrics_2'][0])
      self.assertStartsWith(
          outputter.metrics['metrics_2/update_op'].name, 'metric_op_wrapper')
      self.assertIsInstance(
          outputter.metrics['metrics_2/update_op'], ops.Tensor)

      # A Variable update op must also be exported as a Tensor.
      self.assertEqual(outputter.metrics['keras_1/value'],
                       metrics['keras_1'][0])
      self.assertIsInstance(
          outputter.metrics['keras_1/update_op'], ops.Tensor)
401
402
if __name__ == '__main__':
  # Run the test cases defined above via TensorFlow's test runner.
  test.main()
405