# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for saved_model.model_utils.export_output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib


class ExportOutputTest(test.TestCase):
  """Tests for RegressionOutput, ClassificationOutput and PredictOutput."""

  def test_regress_value_must_be_float(self):
    """RegressionOutput rejects a non-float32 value tensor."""
    with context.graph_mode():
      value = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')
      with self.assertRaisesRegex(
          ValueError, 'Regression output value must be a float32 Tensor'):
        export_output_lib.RegressionOutput(value)

  def test_classify_classes_must_be_strings(self):
    """ClassificationOutput rejects a non-string classes tensor."""
    with context.graph_mode():
      classes = array_ops.placeholder(dtypes.float32, 1, name='output-tensor-1')
      with self.assertRaisesRegex(
          ValueError, 'Classification classes must be a string Tensor'):
        export_output_lib.ClassificationOutput(classes=classes)

  def test_classify_scores_must_be_float(self):
    """ClassificationOutput rejects a non-float32 scores tensor."""
    with context.graph_mode():
      scores = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')
      with self.assertRaisesRegex(
          ValueError, 'Classification scores must be a float32 Tensor'):
        export_output_lib.ClassificationOutput(scores=scores)

  def test_classify_requires_classes_or_scores(self):
    """ClassificationOutput needs at least one of scores/classes."""
    with self.assertRaisesRegex(
        ValueError, 'At least one of scores and classes must be set.'):
      export_output_lib.ClassificationOutput()

  def test_build_standardized_signature_def_regression(self):
    """Tests the regression SignatureDef built from one input/output."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }
      value = array_ops.placeholder(dtypes.float32, 1, name='output-tensor-1')

      export_output = export_output_lib.RegressionOutput(value)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_float = types_pb2.DataType.Value('DT_FLOAT')
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.REGRESS_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.REGRESS_OUTPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-1:0',
                                        dtype=dtype_float,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.REGRESS_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_build_standardized_signature_def_classify_classes_only(self):
    """Tests classification with one output tensor."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }
      classes = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')

      export_output = export_output_lib.ClassificationOutput(classes=classes)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_build_standardized_signature_def_classify_both(self):
    """Tests multiple output tensors that include classes and scores."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }
      classes = array_ops.placeholder(dtypes.string, 1,
                                      name='output-tensor-classes')
      scores = array_ops.placeholder(dtypes.float32, 1,
                                     name='output-tensor-scores')

      export_output = export_output_lib.ClassificationOutput(
          scores=scores, classes=classes)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_float = types_pb2.DataType.Value('DT_FLOAT')
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-classes:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-scores:0',
                                        dtype=dtype_float,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_build_standardized_signature_def_classify_scores_only(self):
    """Tests classification without classes tensor."""
    with context.graph_mode():
      input_tensors = {
          'input-1':
              array_ops.placeholder(
                  dtypes.string, 1, name='input-tensor-1')
      }

      scores = array_ops.placeholder(dtypes.float32, 1,
                                     name='output-tensor-scores')

      export_output = export_output_lib.ClassificationOutput(
          scores=scores)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      shape = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      dtype_float = types_pb2.DataType.Value('DT_FLOAT')
      dtype_string = types_pb2.DataType.Value('DT_STRING')
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              meta_graph_pb2.TensorInfo(name='input-tensor-1:0',
                                        dtype=dtype_string,
                                        tensor_shape=shape))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
              meta_graph_pb2.TensorInfo(name='output-tensor-scores:0',
                                        dtype=dtype_float,
                                        tensor_shape=shape))

      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)
      self.assertEqual(actual_signature_def, expected_signature_def)

  def test_predict_outputs_valid(self):
    """Tests that no errors are raised when provided outputs are valid."""
    outputs = {
        'output0': constant_op.constant([0]),
        u'output1': constant_op.constant(['foo']),
    }
    export_output_lib.PredictOutput(outputs)

    # Single Tensor is OK too
    export_output_lib.PredictOutput(constant_op.constant([0]))

  def test_predict_outputs_invalid(self):
    """PredictOutput rejects non-string keys and non-Tensor values."""
    with self.assertRaisesRegex(
        ValueError,
        'Prediction output key must be a string'):
      export_output_lib.PredictOutput({1: constant_op.constant([0])})

    with self.assertRaisesRegex(
        ValueError,
        'Prediction output value must be a Tensor'):
      export_output_lib.PredictOutput({
          'prediction1': sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      })


class MockSupervisedOutput(export_output_lib._SupervisedOutput):
  """So that we can test the abstract class methods directly."""

  def _get_signature_def_fn(self):
    pass


class SupervisedOutputTest(test.TestCase):
  """Tests for _SupervisedOutput and its TrainOutput/EvalOutput subclasses."""

  def test_supervised_outputs_valid(self):
    """Tests that no errors are raised when provided outputs are valid."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics': metric_obj,
          'metrics2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      # Stored tensors must be the very objects passed in; assertIs avoids
      # relying on Tensor.__eq__ semantics.
      self.assertIs(outputter.loss['loss/my_loss'], loss['my_loss'])
      self.assertIs(
          outputter.predictions['predictions/output1'], predictions['output1'])
      self.assertEqual(outputter.metrics['metrics/update_op'].name,
                       'metric_op_wrapper:0')
      self.assertIs(
          outputter.metrics['metrics2/update_op'], metrics['metrics2'][1])

      # Single Tensor is OK too
      outputter = MockSupervisedOutput(
          loss['my_loss'], predictions['output1'], metrics['metrics'])
      self.assertEqual(outputter.loss, {'loss': loss['my_loss']})
      self.assertEqual(
          outputter.predictions, {'predictions': predictions['output1']})
      self.assertEqual(outputter.metrics['metrics/update_op'].name,
                       'metric_op_wrapper_1:0')

  def test_supervised_outputs_none(self):
    """Predictions and metrics may be None; loss still yields one entry."""
    outputter = MockSupervisedOutput(
        constant_op.constant([0]), None, None)
    self.assertEqual(len(outputter.loss), 1)
    self.assertIsNone(outputter.predictions)
    self.assertIsNone(outputter.metrics)

  def test_supervised_outputs_invalid(self):
    """Invalid values or keys for any output group raise ValueError."""
    with self.assertRaisesRegex(ValueError, 'predictions output value must'):
      MockSupervisedOutput(constant_op.constant([0]), [3], None)
    with self.assertRaisesRegex(ValueError, 'loss output value must'):
      MockSupervisedOutput('str', None, None)
    with self.assertRaisesRegex(ValueError, 'metrics output value must'):
      MockSupervisedOutput(None, None, (15.3, 4))
    with self.assertRaisesRegex(ValueError, 'loss output key must'):
      MockSupervisedOutput({25: 'Tensor'}, None, None)

  def test_supervised_outputs_tuples(self):
    """Tuple keys are flattened into '/'-joined output names."""
    with context.graph_mode():
      loss = {('my', 'loss'): constant_op.constant([0])}
      predictions = {(u'output1', '2'): constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          ('metrics', '1'):
              metric_obj,
          ('metrics', '2'): (constant_op.constant([0]),
                             constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(set(outputter.loss.keys()), set(['loss/my/loss']))
      self.assertEqual(set(outputter.predictions.keys()),
                       set(['predictions/output1/2']))
      self.assertEqual(
          set(outputter.metrics.keys()),
          set([
              'metrics/1/value', 'metrics/1/update_op', 'metrics/2/value',
              'metrics/2/update_op'
          ]))

  def test_supervised_outputs_no_prepend(self):
    """Keys already named 'loss'/'predictions' are not prefixed again."""
    with context.graph_mode():
      loss = {'loss': constant_op.constant([0])}
      predictions = {u'predictions': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics_1': metric_obj,
          'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)
      self.assertEqual(set(outputter.loss.keys()), set(['loss']))
      self.assertEqual(set(outputter.predictions.keys()), set(['predictions']))
      self.assertEqual(
          set(outputter.metrics.keys()),
          set([
              'metrics_1/value', 'metrics_1/update_op', 'metrics_2/update_op',
              'metrics_2/value'
          ]))

  def test_train_signature_def(self):
    """TrainOutput's SignatureDef exposes loss, metrics, predictions, inputs."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics_1': metric_obj,
          'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
      }

      outputter = export_output_lib.TrainOutput(loss, predictions, metrics)

      receiver = {u'features': constant_op.constant(100, shape=(100, 2)),
                  'labels': constant_op.constant(100, shape=(100, 1))}
      sig_def = outputter.as_signature_def(receiver)

      self.assertIn('loss/my_loss', sig_def.outputs)
      self.assertIn('metrics_1/value', sig_def.outputs)
      self.assertIn('metrics_2/value', sig_def.outputs)
      self.assertIn('predictions/output1', sig_def.outputs)
      self.assertIn('features', sig_def.inputs)

  def test_eval_signature_def(self):
    """EvalOutput without metrics has no 'metrics/value' output."""
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}

      outputter = export_output_lib.EvalOutput(loss, predictions, None)

      receiver = {u'features': constant_op.constant(100, shape=(100, 2)),
                  'labels': constant_op.constant(100, shape=(100, 1))}
      sig_def = outputter.as_signature_def(receiver)

      self.assertIn('loss/my_loss', sig_def.outputs)
      self.assertNotIn('metrics/value', sig_def.outputs)
      self.assertIn('predictions/output1', sig_def.outputs)
      self.assertIn('features', sig_def.inputs)

  def test_metric_op_is_tensor(self):
    """Tests that ops.Operation is wrapped by a tensor for metric_ops.

    Covers both a Metric object and a (value, no_op) tuple: each update op
    must come back as an ops.Tensor named 'metric_op_wrapper*'.
    """
    with context.graph_mode():
      loss = {'my_loss': constant_op.constant([0])}
      predictions = {u'output1': constant_op.constant(['foo'])}
      metric_obj = metrics_module.Mean()
      metric_obj.update_state(constant_op.constant([0]))
      metrics = {
          'metrics_1': metric_obj,
          'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op())
      }

      outputter = MockSupervisedOutput(loss, predictions, metrics)

      self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith(
          'metric_op_wrapper'))
      self.assertIsInstance(outputter.metrics['metrics_1/update_op'],
                            ops.Tensor)
      self.assertIsInstance(outputter.metrics['metrics_1/value'], ops.Tensor)

      # The value tensor of a (value, update_op) tuple is passed through as-is.
      self.assertIs(outputter.metrics['metrics_2/value'],
                    metrics['metrics_2'][0])
      self.assertTrue(outputter.metrics['metrics_2/update_op'].name.startswith(
          'metric_op_wrapper'))
      self.assertIsInstance(outputter.metrics['metrics_2/update_op'],
                            ops.Tensor)


if __name__ == '__main__':
  test.main()