• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SavedModel utils."""
16
import os
import re

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
27
28
def tearDownModule():
  """Deletes the shared temp dir that every test in this module writes under."""
  shared_temp_dir = test.get_temp_dir()
  file_io.delete_recursively(shared_temp_dir)
31
32
class SavedModelUtilTest(test.TestCase):
  """Tests for reading SavedModels back through `saved_model_utils`."""

  def _init_and_validate_variable(self, sess, variable_name, variable_value):
    """Creates and initializes a variable, then checks its value.

    Intended to be called inside a `self.session(...)` block: the variable is
    added to the current default graph and read back with `v.eval()`, which
    relies on a default session being active.

    Args:
      sess: The active session, used to run the global variable initializer.
      variable_name: Name for the new variable.
      variable_value: Initial value; also the value the variable is asserted
        to hold after initialization.
    """
    v = variables.Variable(variable_value, name=variable_name)
    sess.run(variables.global_variables_initializer())
    self.assertEqual(variable_value, v.eval())

  @test_util.deprecated_graph_mode_only
  def testReadSavedModelValid(self):
    """A SavedModel with one meta graph and one tag is read back intact."""
    saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
    builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
    with self.session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
    builder.save()

    actual_saved_model_pb = saved_model_utils.read_saved_model(saved_model_dir)
    self.assertLen(actual_saved_model_pb.meta_graphs, 1)
    meta_info_def = actual_saved_model_pb.meta_graphs[0].meta_info_def
    self.assertLen(meta_info_def.tags, 1)
    self.assertEqual(meta_info_def.tags[0], tag_constants.TRAINING)

  def testReadSavedModelInvalid(self):
    """Reading a directory with no SavedModel file raises IOError."""
    saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
    # The path is embedded in a regex, so escape it: temp paths can contain
    # regex metacharacters (e.g. Windows backslashes or dots).
    expected_error = ("SavedModel file does not exist at: %s" %
                      re.escape(saved_model_dir))
    with self.assertRaisesRegex(IOError, expected_error):
      saved_model_utils.read_saved_model(saved_model_dir)

  def testGetSavedModelTagSets(self):
    """Tag-sets are reported per meta graph, in the order they were added."""
    saved_model_dir = os.path.join(test.get_temp_dir(), "test_tags")
    builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
    # Force test to run in graph mode since SavedModelBuilder.save requires a
    # session to work.
    with ops.Graph().as_default():
      # Graph with a single variable. SavedModel invoked to:
      # - add with weights.
      # - a single tag (from predefined constants).
      with self.session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])

      # Graph that updates the single variable. SavedModel invoked to:
      # - simply add the model (weights are not updated).
      # - a single tag (from predefined constants).
      with self.session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 43)
        builder.add_meta_graph([tag_constants.SERVING])

      # Graph that updates the single variable. SavedModel is invoked:
      # - to add the model (weights are not updated).
      # - multiple predefined tags.
      with self.session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 44)
        builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])

      # Graph that updates the single variable. SavedModel is invoked:
      # - to add the model (weights are not updated).
      # - multiple predefined tags for serving on TPU.
      with self.session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 44)
        builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])

      # Graph that updates the single variable. SavedModel is invoked:
      # - to add the model (weights are not updated).
      # - multiple custom tags.
      with self.session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 45)
        builder.add_meta_graph(["foo", "bar"])

      # Save the SavedModel to disk.
      builder.save()

    actual_tags = saved_model_utils.get_saved_model_tag_sets(saved_model_dir)
    expected_tags = [["train"], ["serve"], ["serve", "gpu"], ["serve", "tpu"],
                     ["foo", "bar"]]
    self.assertEqual(expected_tags, actual_tags)

  def testGetMetaGraphInvalidTagSet(self):
    """Unknown tag-sets make `get_meta_graph_def` raise RuntimeError."""
    saved_model_dir = os.path.join(test.get_temp_dir(), "test_invalid_tags")
    builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
    # Force test to run in graph mode since SavedModelBuilder.save requires a
    # session to work.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(sess, ["a", "b"])
      builder.save()

    # Sanity check: the tag-set that was actually saved (given as a
    # comma-separated string) resolves without error, so the failure below is
    # caused only by the unknown tags.
    saved_model_utils.get_meta_graph_def(saved_model_dir, "a,b")

    with self.assertRaisesRegex(RuntimeError, "associated with tag-set"):
      saved_model_utils.get_meta_graph_def(saved_model_dir, "c,d")
127
128
if __name__ == "__main__":
  # Run the test cases in this module via TensorFlow's test entry point.
  test.main()
131