1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for SavedModel simple save functionality.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import variables 25from tensorflow.python.platform import test 26from tensorflow.python.saved_model import loader 27from tensorflow.python.saved_model import signature_constants 28from tensorflow.python.saved_model import simple_save 29from tensorflow.python.saved_model import tag_constants 30 31 32class SimpleSaveTest(test.TestCase): 33 34 def _init_and_validate_variable(self, variable_name, variable_value): 35 v = variables.Variable(variable_value, name=variable_name) 36 self.evaluate(variables.global_variables_initializer()) 37 self.assertEqual(variable_value, self.evaluate(v)) 38 return v 39 40 def _check_variable_info(self, actual_variable, expected_variable): 41 self.assertEqual(actual_variable.name, expected_variable.name) 42 self.assertEqual(actual_variable.dtype, expected_variable.dtype) 43 self.assertEqual(len(actual_variable.shape), len(expected_variable.shape)) 44 for i in range(len(actual_variable.shape)): 45 self.assertEqual(actual_variable.shape[i], expected_variable.shape[i]) 46 47 def _check_tensor_info(self, actual_tensor_info, expected_tensor): 48 self.assertEqual(actual_tensor_info.name, expected_tensor.name) 49 self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype) 50 self.assertEqual( 51 len(actual_tensor_info.tensor_shape.dim), len(expected_tensor.shape)) 52 for i in range(len(actual_tensor_info.tensor_shape.dim)): 53 self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size, 54 expected_tensor.shape[i]) 55 56 def testSimpleSave(self): 57 """Test simple_save that uses the default parameters.""" 58 export_dir = os.path.join(test.get_temp_dir(), 59 "test_simple_save") 60 61 # Force the test to run in graph mode. 62 # This tests a deprecated v1 API that both requires a session and uses 63 # functionality that does not work with eager tensors (such as 64 # build_tensor_info as called by predict_signature_def). 65 with ops.Graph().as_default(): 66 # Initialize input and output variables and save a prediction graph using 67 # the default parameters. 68 with self.session(graph=ops.Graph()) as sess: 69 var_x = self._init_and_validate_variable("var_x", 1) 70 var_y = self._init_and_validate_variable("var_y", 2) 71 inputs = {"x": var_x} 72 outputs = {"y": var_y} 73 simple_save.simple_save(sess, export_dir, inputs, outputs) 74 75 # Restore the graph with a valid tag and check the global variables and 76 # signature def map. 77 with self.session(graph=ops.Graph()) as sess: 78 graph = loader.load(sess, [tag_constants.SERVING], export_dir) 79 collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 80 81 # Check value and metadata of the saved variables. 82 self.assertEqual(len(collection_vars), 2) 83 self.assertEqual(1, collection_vars[0].eval()) 84 self.assertEqual(2, collection_vars[1].eval()) 85 self._check_variable_info(collection_vars[0], var_x) 86 self._check_variable_info(collection_vars[1], var_y) 87 88 # Check that the appropriate signature_def_map is created with the 89 # default key and method name, and the specified inputs and outputs. 90 signature_def_map = graph.signature_def 91 self.assertEqual(1, len(signature_def_map)) 92 self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, 93 list(signature_def_map.keys())[0]) 94 95 signature_def = signature_def_map[ 96 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 97 self.assertEqual(signature_constants.PREDICT_METHOD_NAME, 98 signature_def.method_name) 99 100 self.assertEqual(1, len(signature_def.inputs)) 101 self._check_tensor_info(signature_def.inputs["x"], var_x) 102 self.assertEqual(1, len(signature_def.outputs)) 103 self._check_tensor_info(signature_def.outputs["y"], var_y) 104 105 106if __name__ == "__main__": 107 test.main() 108