• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for export utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import tempfile
23import time
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.platform import test
31from tensorflow.python.saved_model import signature_constants
32from tensorflow.python.saved_model import signature_def_utils
33from tensorflow.python.saved_model.model_utils import export_output
34from tensorflow.python.saved_model.model_utils import export_utils
35from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys
36
37
class ExportTest(test_util.TensorFlowTestCase):
  """Tests for `export_utils` signature building and export-dir helpers."""

  def _make_export_outputs(self):
    """Creates the three export outputs shared by the signature-def tests.

    Returns:
      A tuple `(output_1, output_2, output_3, export_outputs)`. The first three
      are constant tensors; `export_outputs` maps the default serving key to a
      `RegressionOutput` over `output_1`, "head-2" to a `ClassificationOutput`
      over `output_2`, and "head-3" to a `PredictOutput` over `output_3`.
    """
    output_1 = constant_op.constant([1.])
    output_2 = constant_op.constant(["2"])
    output_3 = constant_op.constant(["3"])
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            export_output.RegressionOutput(value=output_1),
        "head-2":
            export_output.ClassificationOutput(classes=output_2),
        "head-3":
            export_output.PredictOutput(outputs={"some_output_3": output_3}),
    }
    return output_1, output_2, output_3, export_outputs

  def test_build_all_signature_defs_without_receiver_alternatives(self):
    """One signature def is built per export output for a single receiver."""
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that depends on graph-only functions such
    # as build_tensor_info.
    with ops.Graph().as_default():
      receiver_tensor = array_ops.placeholder(dtypes.string)
      output_1, output_2, output_3, export_outputs = (
          self._make_export_outputs())

      signature_defs = export_utils.build_all_signature_defs(
          receiver_tensor, export_outputs)

      expected_signature_defs = {
          "serving_default":
              signature_def_utils.regression_signature_def(
                  receiver_tensor, output_1),
          "head-2":
              signature_def_utils.classification_signature_def(
                  receiver_tensor, output_2, None),
          "head-3":
              signature_def_utils.predict_signature_def(
                  {"input": receiver_tensor}, {"some_output_3": output_3})
      }

      self.assertDictEqual(expected_signature_defs, signature_defs)

  def test_build_all_signature_defs_with_dict_alternatives(self):
    """Dict-valued receiver alternatives only yield valid predict signatures."""
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that depends on graph-only functions such
    # as build_tensor_info.
    with ops.Graph().as_default():
      receiver_tensor = array_ops.placeholder(dtypes.string)
      receiver_tensors_alternative_1 = {
          "foo": array_ops.placeholder(dtypes.int64),
          "bar": array_ops.sparse_placeholder(dtypes.float32)
      }
      receiver_tensors_alternatives = {"other": receiver_tensors_alternative_1}
      output_1, output_2, output_3, export_outputs = (
          self._make_export_outputs())

      signature_defs = export_utils.build_all_signature_defs(
          receiver_tensor, export_outputs, receiver_tensors_alternatives)

      # Note that the alternatives 'other:serving_default' and
      # 'other:head-2' are invalid, because regression and classification
      # signatures must take a single string input.  Here we verify that
      # these invalid signatures are not included in the returned signature
      # defs.
      expected_signature_defs = {
          "serving_default":
              signature_def_utils.regression_signature_def(
                  receiver_tensor, output_1),
          "head-2":
              signature_def_utils.classification_signature_def(
                  receiver_tensor, output_2, None),
          "head-3":
              signature_def_utils.predict_signature_def(
                  {"input": receiver_tensor}, {"some_output_3": output_3}),
          "other:head-3":
              signature_def_utils.predict_signature_def(
                  receiver_tensors_alternative_1, {"some_output_3": output_3})
      }

      self.assertDictEqual(expected_signature_defs, signature_defs)

  def test_build_all_signature_defs_with_single_alternatives(self):
    """Single-tensor receiver alternatives are keyed under "input"."""
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that depends on graph-only functions such
    # as build_tensor_info.
    with ops.Graph().as_default():
      receiver_tensor = array_ops.placeholder(dtypes.string)
      receiver_tensors_alternative_1 = array_ops.placeholder(dtypes.int64)
      receiver_tensors_alternative_2 = array_ops.sparse_placeholder(
          dtypes.float32)
      # Note we are passing single Tensors as values of
      # receiver_tensors_alternatives, where normally that is a dict.
      # In this case a dict will be created using the default receiver tensor
      # name "input".
      receiver_tensors_alternatives = {
          "other1": receiver_tensors_alternative_1,
          "other2": receiver_tensors_alternative_2
      }
      output_1, output_2, output_3, export_outputs = (
          self._make_export_outputs())

      signature_defs = export_utils.build_all_signature_defs(
          receiver_tensor, export_outputs, receiver_tensors_alternatives)

      # Note that the alternatives 'other1:serving_default', 'other1:head-2',
      # 'other2:serving_default' and 'other2:head-2' are invalid, because
      # regression and classification signatures must take a single string
      # input.  Here we verify that these invalid signatures are not included
      # in the returned signature defs.
      expected_signature_defs = {
          "serving_default":
              signature_def_utils.regression_signature_def(
                  receiver_tensor, output_1),
          "head-2":
              signature_def_utils.classification_signature_def(
                  receiver_tensor, output_2, None),
          "head-3":
              signature_def_utils.predict_signature_def(
                  {"input": receiver_tensor}, {"some_output_3": output_3}),
          "other1:head-3":
              signature_def_utils.predict_signature_def(
                  {"input": receiver_tensors_alternative_1},
                  {"some_output_3": output_3}),
          "other2:head-3":
              signature_def_utils.predict_signature_def(
                  {"input": receiver_tensors_alternative_2},
                  {"some_output_3": output_3})
      }

      self.assertDictEqual(expected_signature_defs, signature_defs)

  def test_build_all_signature_defs_export_outputs_required(self):
    """Passing `None` as export_outputs raises a descriptive ValueError."""
    receiver_tensor = constant_op.constant(["11"])

    # `^` anchors the pattern at the start of the message, matching the
    # original startswith() check.
    with self.assertRaisesRegex(ValueError,
                                r"^`export_outputs` must be a dict"):
      export_utils.build_all_signature_defs(receiver_tensor, None)

  def test_get_timestamped_export_dir(self):
    """Successive export dirs are 10-digit, strictly increasing timestamps."""
    # Join with a real separator so the base is a subdirectory of the fresh
    # temp dir rather than a sibling path with "export/" glued onto its name.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export")
    export_dir_1 = export_utils.get_timestamped_export_dir(
        export_dir_base)
    # The timestamps have one-second resolution, so wait long enough for the
    # next call to observe a later second.
    time.sleep(2)
    export_dir_2 = export_utils.get_timestamped_export_dir(
        export_dir_base)
    time.sleep(2)
    export_dir_3 = export_utils.get_timestamped_export_dir(
        export_dir_base)

    # Export directories should be named using a timestamp that is seconds
    # since epoch.  Such a timestamp is 10 digits long.
    time_1 = os.path.basename(export_dir_1)
    self.assertEqual(10, len(time_1))
    time_2 = os.path.basename(export_dir_2)
    self.assertEqual(10, len(time_2))
    time_3 = os.path.basename(export_dir_3)
    self.assertEqual(10, len(time_3))

    self.assertLess(int(time_1), int(time_2))
    self.assertLess(int(time_2), int(time_3))

  def test_get_temp_export_dir(self):
    """Both str and bytes export dirs map to a bytes "temp-" sibling path."""
    export_dir = os.path.join("tmp", "export", "1576013284")
    tmp_export_dir = export_utils.get_temp_export_dir(export_dir)
    self.assertEqual(tmp_export_dir,
                     os.path.join(b"tmp", b"export", b"temp-1576013284"))

    export_dir = os.path.join(b"tmp", b"export", b"1576013284")
    tmp_export_dir = export_utils.get_temp_export_dir(export_dir)
    self.assertEqual(tmp_export_dir,
                     os.path.join(b"tmp", b"export", b"temp-1576013284"))

  def test_build_all_signature_defs_serving_only(self):
    """Non-serving outputs are only exported when serving_only=False."""
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that depends on graph-only functions such
    # as build_tensor_info.
    with ops.Graph().as_default():
      receiver_tensor = {"input": array_ops.placeholder(dtypes.string)}
      output_1 = constant_op.constant([1.])
      export_outputs = {
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              export_output.PredictOutput(outputs=output_1),
          "train":
              export_output.TrainOutput(loss=output_1),
      }

      signature_defs = export_utils.build_all_signature_defs(
          receiver_tensor, export_outputs)

      # A bare tensor passed to PredictOutput is expected under key "output",
      # and the "train" output is expected to be dropped by default.
      expected_signature_defs = {
          "serving_default":
              signature_def_utils.predict_signature_def(receiver_tensor,
                                                        {"output": output_1})
      }

      self.assertDictEqual(expected_signature_defs, signature_defs)

      signature_defs = export_utils.build_all_signature_defs(
          receiver_tensor, export_outputs, serving_only=False)

      expected_signature_defs.update({
          "train":
              signature_def_utils.supervised_train_signature_def(
                  receiver_tensor, loss={"loss": output_1})
      })

      self.assertDictEqual(expected_signature_defs, signature_defs)

  def test_export_outputs_for_mode(self):
    """export_outputs_for_mode builds the right output type for each mode."""
    predictions = {"predictions": constant_op.constant([1.])}
    loss = {"loss": constant_op.constant([2.])}
    metrics = {
        "metrics": (constant_op.constant([3.]), constant_op.constant([4.]))}
    # Metric (value, update_op) tuples are expected to be flattened into
    # "<name>/value" and "<name>/update_op" entries.
    expected_metrics = {
        "metrics/value": metrics["metrics"][0],
        "metrics/update_op": metrics["metrics"][1]
    }

    def _build_export_output(mode):
      return export_utils.export_outputs_for_mode(
          mode, None, predictions, loss, metrics)

    ret = _build_export_output(KerasModeKeys.TRAIN)
    self.assertIn(signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY, ret)
    export_out = ret[signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY]
    self.assertIsInstance(export_out, export_output.TrainOutput)
    self.assertEqual(export_out.predictions, predictions)
    self.assertEqual(export_out.loss, loss)
    self.assertEqual(export_out.metrics, expected_metrics)

    ret = _build_export_output(KerasModeKeys.TEST)
    self.assertIn(signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY, ret)
    export_out = ret[signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY]
    self.assertIsInstance(export_out, export_output.EvalOutput)
    self.assertEqual(export_out.predictions, predictions)
    self.assertEqual(export_out.loss, loss)
    self.assertEqual(export_out.metrics, expected_metrics)

    ret = _build_export_output(KerasModeKeys.PREDICT)
    self.assertIn(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, ret)
    export_out = ret[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    self.assertIsInstance(export_out, export_output.PredictOutput)
    self.assertEqual(export_out.outputs, predictions)

    # An explicitly supplied serving-output dict is passed through as-is.
    classes = constant_op.constant(["class5"])
    ret = export_utils.export_outputs_for_mode(
        KerasModeKeys.PREDICT,
        {"classify": export_output.ClassificationOutput(
            classes=classes)})
    self.assertIn("classify", ret)
    export_out = ret["classify"]
    self.assertIsInstance(export_out, export_output.ClassificationOutput)
    self.assertEqual(export_out.classes, classes)
307
if "__main__" == __name__:
  test.main()  # Hand control to the TensorFlow test runner for this module.
310