• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests of utilities supporting export to SavedModel."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21import tempfile
22import time
23
24from tensorflow.contrib.layers.python.layers import feature_column as fc
25from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib
26from tensorflow.contrib.learn.python.learn.estimators import constants
27from tensorflow.contrib.learn.python.learn.estimators import estimator
28from tensorflow.contrib.learn.python.learn.estimators import model_fn
29from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
30from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
31from tensorflow.core.framework import tensor_shape_pb2
32from tensorflow.core.framework import types_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.python.estimator import estimator as core_estimator
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.ops import array_ops
38from tensorflow.python.platform import gfile
39from tensorflow.python.platform import test
40from tensorflow.python.saved_model import signature_constants
41from tensorflow.python.saved_model import signature_def_utils
42from tensorflow.python.util import compat
43
44
class TestEstimator(estimator.Estimator):
  """Estimator stub whose export only touches the filesystem.

  Instead of writing a SavedModel, `export_savedmodel` creates the requested
  directory with a placeholder file and remembers what it was asked to export.
  """

  def __init__(self, *args, **kwargs):
    super(TestEstimator, self).__init__(*args, **kwargs)
    # Both stay "" until export_savedmodel() has been called.
    self.last_exported_checkpoint = ""
    self.last_exported_dir = ""

  # @Override
  def export_savedmodel(self,
                        export_dir,
                        serving_input_fn,
                        default_output_alternative_key=None,
                        assets_extra=None,
                        as_text=False,
                        checkpoint_path=None,
                        strip_default_attrs=False):
    """Creates `export_dir` with a placeholder file and records the request.

    Args:
      export_dir: Directory to create (if missing) and return.
      serving_input_fn: Ignored.
      default_output_alternative_key: Ignored.
      assets_extra: Ignored.
      as_text: Ignored.
      checkpoint_path: Stored in `last_exported_checkpoint`.
      strip_default_attrs: Ignored.

    Returns:
      `export_dir`, unchanged.
    """
    directory_missing = not os.path.exists(export_dir)
    if directory_missing:
      os.makedirs(export_dir)

    # Append mode creates the file when absent and leaves any content intact.
    placeholder_path = os.path.join(export_dir, "placeholder.txt")
    with open(placeholder_path, "a"):
      pass

    self.last_exported_dir = export_dir
    self.last_exported_checkpoint = checkpoint_path
    return export_dir
71
72
73class SavedModelExportUtilsTest(test.TestCase):
74
75  def test_build_standardized_signature_def_regression(self):
76    input_tensors = {
77        "input-1":
78            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
79    }
80    output_tensors = {
81        "output-1":
82            array_ops.placeholder(dtypes.float32, 1, name="output-tensor-1")
83    }
84    problem_type = constants.ProblemType.LINEAR_REGRESSION
85    actual_signature_def = (
86        saved_model_export_utils.build_standardized_signature_def(
87            input_tensors, output_tensors, problem_type))
88    expected_signature_def = meta_graph_pb2.SignatureDef()
89    shape = tensor_shape_pb2.TensorShapeProto(
90        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
91    dtype_float = types_pb2.DataType.Value("DT_FLOAT")
92    dtype_string = types_pb2.DataType.Value("DT_STRING")
93    expected_signature_def.inputs[signature_constants.REGRESS_INPUTS].CopyFrom(
94        meta_graph_pb2.TensorInfo(
95            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
96    expected_signature_def.outputs[
97        signature_constants.REGRESS_OUTPUTS].CopyFrom(
98            meta_graph_pb2.TensorInfo(
99                name="output-tensor-1:0", dtype=dtype_float,
100                tensor_shape=shape))
101
102    expected_signature_def.method_name = signature_constants.REGRESS_METHOD_NAME
103    self.assertEqual(actual_signature_def, expected_signature_def)
104
105  def test_build_standardized_signature_def_classification(self):
106    """Tests classification with one output tensor."""
107    input_tensors = {
108        "input-1":
109            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
110    }
111    output_tensors = {
112        "output-1":
113            array_ops.placeholder(dtypes.string, 1, name="output-tensor-1")
114    }
115    problem_type = constants.ProblemType.CLASSIFICATION
116    actual_signature_def = (
117        saved_model_export_utils.build_standardized_signature_def(
118            input_tensors, output_tensors, problem_type))
119    expected_signature_def = meta_graph_pb2.SignatureDef()
120    shape = tensor_shape_pb2.TensorShapeProto(
121        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
122    dtype_string = types_pb2.DataType.Value("DT_STRING")
123    expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
124        meta_graph_pb2.TensorInfo(
125            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
126    expected_signature_def.outputs[
127        signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
128            meta_graph_pb2.TensorInfo(
129                name="output-tensor-1:0",
130                dtype=dtype_string,
131                tensor_shape=shape))
132
133    expected_signature_def.method_name = (
134        signature_constants.CLASSIFY_METHOD_NAME)
135    self.assertEqual(actual_signature_def, expected_signature_def)
136
137  def test_build_standardized_signature_def_classification2(self):
138    """Tests multiple output tensors that include classes and probabilities."""
139    input_tensors = {
140        "input-1":
141            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
142    }
143    output_tensors = {
144        "classes":
145            array_ops.placeholder(
146                dtypes.string, 1, name="output-tensor-classes"),
147        # Will be used for CLASSIFY_OUTPUT_SCORES.
148        "probabilities":
149            array_ops.placeholder(
150                dtypes.float32, 1, name="output-tensor-proba"),
151        "logits":
152            array_ops.placeholder(
153                dtypes.float32, 1, name="output-tensor-logits-unused"),
154    }
155    problem_type = constants.ProblemType.CLASSIFICATION
156    actual_signature_def = (
157        saved_model_export_utils.build_standardized_signature_def(
158            input_tensors, output_tensors, problem_type))
159    expected_signature_def = meta_graph_pb2.SignatureDef()
160    shape = tensor_shape_pb2.TensorShapeProto(
161        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
162    dtype_float = types_pb2.DataType.Value("DT_FLOAT")
163    dtype_string = types_pb2.DataType.Value("DT_STRING")
164    expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
165        meta_graph_pb2.TensorInfo(
166            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
167    expected_signature_def.outputs[
168        signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
169            meta_graph_pb2.TensorInfo(
170                name="output-tensor-classes:0",
171                dtype=dtype_string,
172                tensor_shape=shape))
173    expected_signature_def.outputs[
174        signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
175            meta_graph_pb2.TensorInfo(
176                name="output-tensor-proba:0",
177                dtype=dtype_float,
178                tensor_shape=shape))
179
180    expected_signature_def.method_name = (
181        signature_constants.CLASSIFY_METHOD_NAME)
182    self.assertEqual(actual_signature_def, expected_signature_def)
183
184  def test_build_standardized_signature_def_classification3(self):
185    """Tests multiple output tensors that include classes and scores."""
186    input_tensors = {
187        "input-1":
188            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
189    }
190    output_tensors = {
191        "classes":
192            array_ops.placeholder(
193                dtypes.string, 1, name="output-tensor-classes"),
194        "scores":
195            array_ops.placeholder(
196                dtypes.float32, 1, name="output-tensor-scores"),
197        "logits":
198            array_ops.placeholder(
199                dtypes.float32, 1, name="output-tensor-logits-unused"),
200    }
201    problem_type = constants.ProblemType.CLASSIFICATION
202    actual_signature_def = (
203        saved_model_export_utils.build_standardized_signature_def(
204            input_tensors, output_tensors, problem_type))
205    expected_signature_def = meta_graph_pb2.SignatureDef()
206    shape = tensor_shape_pb2.TensorShapeProto(
207        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
208    dtype_float = types_pb2.DataType.Value("DT_FLOAT")
209    dtype_string = types_pb2.DataType.Value("DT_STRING")
210    expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
211        meta_graph_pb2.TensorInfo(
212            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
213    expected_signature_def.outputs[
214        signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
215            meta_graph_pb2.TensorInfo(
216                name="output-tensor-classes:0",
217                dtype=dtype_string,
218                tensor_shape=shape))
219    expected_signature_def.outputs[
220        signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
221            meta_graph_pb2.TensorInfo(
222                name="output-tensor-scores:0",
223                dtype=dtype_float,
224                tensor_shape=shape))
225
226    expected_signature_def.method_name = (
227        signature_constants.CLASSIFY_METHOD_NAME)
228    self.assertEqual(actual_signature_def, expected_signature_def)
229
230  def test_build_standardized_signature_def_classification4(self):
231    """Tests classification without classes tensor."""
232    input_tensors = {
233        "input-1":
234            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
235    }
236    output_tensors = {
237        "probabilities":
238            array_ops.placeholder(
239                dtypes.float32, 1, name="output-tensor-proba"),
240        "logits":
241            array_ops.placeholder(
242                dtypes.float32, 1, name="output-tensor-logits-unused"),
243    }
244    problem_type = constants.ProblemType.CLASSIFICATION
245    actual_signature_def = (
246        saved_model_export_utils.build_standardized_signature_def(
247            input_tensors, output_tensors, problem_type))
248    expected_signature_def = meta_graph_pb2.SignatureDef()
249    shape = tensor_shape_pb2.TensorShapeProto(
250        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
251    dtype_float = types_pb2.DataType.Value("DT_FLOAT")
252    dtype_string = types_pb2.DataType.Value("DT_STRING")
253    expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
254        meta_graph_pb2.TensorInfo(
255            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
256    expected_signature_def.outputs[
257        signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
258            meta_graph_pb2.TensorInfo(
259                name="output-tensor-proba:0",
260                dtype=dtype_float,
261                tensor_shape=shape))
262
263    expected_signature_def.method_name = (
264        signature_constants.CLASSIFY_METHOD_NAME)
265    self.assertEqual(actual_signature_def, expected_signature_def)
266
267  def test_build_standardized_signature_def_classification5(self):
268    """Tests multiple output tensors that include integer classes and scores.
269
270    Integer classes are dropped out, because Servo classification can only serve
271    string classes. So, only scores are present in the signature.
272    """
273    input_tensors = {
274        "input-1":
275            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
276    }
277    output_tensors = {
278        "classes":
279            array_ops.placeholder(
280                dtypes.int64, 1, name="output-tensor-classes"),
281        "scores":
282            array_ops.placeholder(
283                dtypes.float32, 1, name="output-tensor-scores"),
284        "logits":
285            array_ops.placeholder(
286                dtypes.float32, 1, name="output-tensor-logits-unused"),
287    }
288    problem_type = constants.ProblemType.CLASSIFICATION
289    actual_signature_def = (
290        saved_model_export_utils.build_standardized_signature_def(
291            input_tensors, output_tensors, problem_type))
292    expected_signature_def = meta_graph_pb2.SignatureDef()
293    shape = tensor_shape_pb2.TensorShapeProto(
294        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
295    dtype_float = types_pb2.DataType.Value("DT_FLOAT")
296    dtype_string = types_pb2.DataType.Value("DT_STRING")
297    expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
298        meta_graph_pb2.TensorInfo(
299            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
300    expected_signature_def.outputs[
301        signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
302            meta_graph_pb2.TensorInfo(
303                name="output-tensor-scores:0",
304                dtype=dtype_float,
305                tensor_shape=shape))
306
307    expected_signature_def.method_name = (
308        signature_constants.CLASSIFY_METHOD_NAME)
309    self.assertEqual(actual_signature_def, expected_signature_def)
310
311  def test_build_standardized_signature_def_classification6(self):
312    """Tests multiple output tensors that with integer classes and no scores.
313
314    Servo classification cannot serve integer classes, but no scores are
315    available. So, we fall back to predict signature.
316    """
317    input_tensors = {
318        "input-1":
319            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
320    }
321    output_tensors = {
322        "classes":
323            array_ops.placeholder(
324                dtypes.int64, 1, name="output-tensor-classes"),
325        "logits":
326            array_ops.placeholder(
327                dtypes.float32, 1, name="output-tensor-logits"),
328    }
329    problem_type = constants.ProblemType.CLASSIFICATION
330    actual_signature_def = (
331        saved_model_export_utils.build_standardized_signature_def(
332            input_tensors, output_tensors, problem_type))
333    expected_signature_def = meta_graph_pb2.SignatureDef()
334    shape = tensor_shape_pb2.TensorShapeProto(
335        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
336    dtype_int64 = types_pb2.DataType.Value("DT_INT64")
337    dtype_float = types_pb2.DataType.Value("DT_FLOAT")
338    dtype_string = types_pb2.DataType.Value("DT_STRING")
339    expected_signature_def.inputs["input-1"].CopyFrom(
340        meta_graph_pb2.TensorInfo(
341            name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape))
342    expected_signature_def.outputs["classes"].CopyFrom(
343        meta_graph_pb2.TensorInfo(
344            name="output-tensor-classes:0",
345            dtype=dtype_int64,
346            tensor_shape=shape))
347    expected_signature_def.outputs["logits"].CopyFrom(
348        meta_graph_pb2.TensorInfo(
349            name="output-tensor-logits:0",
350            dtype=dtype_float,
351            tensor_shape=shape))
352
353    expected_signature_def.method_name = (
354        signature_constants.PREDICT_METHOD_NAME)
355    self.assertEqual(actual_signature_def, expected_signature_def)
356
357  def test_get_input_alternatives(self):
358    input_ops = input_fn_utils.InputFnOps("bogus features dict", None,
359                                          "bogus default input dict")
360
361    input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
362        input_ops)
363    self.assertEqual(input_alternatives[
364        saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY],
365                     "bogus default input dict")
366    # self.assertEqual(input_alternatives[
367    #     saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY],
368    #                  "bogus features dict")
369
370  def test_get_output_alternatives_explicit_default(self):
371    provided_output_alternatives = {
372        "head-1": (constants.ProblemType.LINEAR_REGRESSION,
373                   "bogus output dict"),
374        "head-2": (constants.ProblemType.CLASSIFICATION, "bogus output dict 2"),
375        "head-3": (constants.ProblemType.UNSPECIFIED, "bogus output dict 3"),
376    }
377    model_fn_ops = model_fn.ModelFnOps(
378        model_fn.ModeKeys.INFER,
379        predictions={"some_output": "bogus_tensor"},
380        output_alternatives=provided_output_alternatives)
381
382    output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
383        model_fn_ops, "head-1")
384
385    self.assertEqual(provided_output_alternatives, output_alternatives)
386
387  def test_get_output_alternatives_wrong_default(self):
388    provided_output_alternatives = {
389        "head-1": (constants.ProblemType.LINEAR_REGRESSION,
390                   "bogus output dict"),
391        "head-2": (constants.ProblemType.CLASSIFICATION, "bogus output dict 2"),
392        "head-3": (constants.ProblemType.UNSPECIFIED, "bogus output dict 3"),
393    }
394    model_fn_ops = model_fn.ModelFnOps(
395        model_fn.ModeKeys.INFER,
396        predictions={"some_output": "bogus_tensor"},
397        output_alternatives=provided_output_alternatives)
398
399    with self.assertRaises(ValueError) as e:
400      saved_model_export_utils.get_output_alternatives(model_fn_ops, "WRONG")
401
402    self.assertEqual("Requested default_output_alternative: WRONG, but "
403                     "available output_alternatives are: ['head-1', 'head-2', "
404                     "'head-3']", str(e.exception))
405
406  def test_get_output_alternatives_single_no_default(self):
407    prediction_tensor = constant_op.constant(["bogus"])
408    provided_output_alternatives = {
409        "head-1": (constants.ProblemType.LINEAR_REGRESSION, {
410            "output": prediction_tensor
411        }),
412    }
413    model_fn_ops = model_fn.ModelFnOps(
414        model_fn.ModeKeys.INFER,
415        predictions=prediction_tensor,
416        output_alternatives=provided_output_alternatives)
417
418    output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
419        model_fn_ops)
420
421    self.assertEqual({
422        "head-1": (constants.ProblemType.LINEAR_REGRESSION, {
423            "output": prediction_tensor
424        })
425    }, output_alternatives)
426
427  def test_get_output_alternatives_multi_no_default(self):
428    provided_output_alternatives = {
429        "head-1": (constants.ProblemType.LINEAR_REGRESSION,
430                   "bogus output dict"),
431        "head-2": (constants.ProblemType.CLASSIFICATION, "bogus output dict 2"),
432        "head-3": (constants.ProblemType.UNSPECIFIED, "bogus output dict 3"),
433    }
434    model_fn_ops = model_fn.ModelFnOps(
435        model_fn.ModeKeys.INFER,
436        predictions={"some_output": "bogus_tensor"},
437        output_alternatives=provided_output_alternatives)
438
439    with self.assertRaises(ValueError) as e:
440      saved_model_export_utils.get_output_alternatives(model_fn_ops)
441
442    self.assertEqual("Please specify a default_output_alternative.  Available "
443                     "output_alternatives are: ['head-1', 'head-2', 'head-3']",
444                     str(e.exception))
445
446  def test_get_output_alternatives_none_provided(self):
447    prediction_tensor = constant_op.constant(["bogus"])
448    model_fn_ops = model_fn.ModelFnOps(
449        model_fn.ModeKeys.INFER,
450        predictions={"some_output": prediction_tensor},
451        output_alternatives=None)
452
453    output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
454        model_fn_ops)
455
456    self.assertEqual({
457        "default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
458            "some_output": prediction_tensor
459        })
460    }, output_alternatives)
461
462  def test_get_output_alternatives_empty_provided_with_default(self):
463    prediction_tensor = constant_op.constant(["bogus"])
464    model_fn_ops = model_fn.ModelFnOps(
465        model_fn.ModeKeys.INFER,
466        predictions={"some_output": prediction_tensor},
467        output_alternatives={})
468
469    with self.assertRaises(ValueError) as e:
470      saved_model_export_utils.get_output_alternatives(model_fn_ops, "WRONG")
471
472    self.assertEqual("Requested default_output_alternative: WRONG, but "
473                     "available output_alternatives are: []", str(e.exception))
474
475  def test_get_output_alternatives_empty_provided_no_default(self):
476    prediction_tensor = constant_op.constant(["bogus"])
477    model_fn_ops = model_fn.ModelFnOps(
478        model_fn.ModeKeys.INFER,
479        predictions={"some_output": prediction_tensor},
480        output_alternatives={})
481
482    output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
483        model_fn_ops)
484
485    self.assertEqual({
486        "default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
487            "some_output": prediction_tensor
488        })
489    }, output_alternatives)
490
491  def test_get_output_alternatives_implicit_single(self):
492    prediction_tensor = constant_op.constant(["bogus"])
493    model_fn_ops = model_fn.ModelFnOps(
494        model_fn.ModeKeys.INFER,
495        predictions=prediction_tensor,
496        output_alternatives=None)
497
498    output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
499        model_fn_ops)
500    self.assertEqual({
501        "default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
502            "output": prediction_tensor
503        })
504    }, output_alternatives)
505
506  def test_build_all_signature_defs(self):
507    input_features = constant_op.constant(["10"])
508    input_example = constant_op.constant(["input string"])
509    input_ops = input_fn_utils.InputFnOps({
510        "features": input_features
511    }, None, {
512        "default input": input_example
513    })
514    input_alternatives, _ = (
515        saved_model_export_utils.get_input_alternatives(input_ops))
516    output_1 = constant_op.constant([1.0])
517    output_2 = constant_op.constant(["2"])
518    output_3 = constant_op.constant(["3"])
519    provided_output_alternatives = {
520        "head-1": (constants.ProblemType.LINEAR_REGRESSION, {
521            "some_output_1": output_1
522        }),
523        "head-2": (constants.ProblemType.CLASSIFICATION, {
524            "some_output_2": output_2
525        }),
526        "head-3": (constants.ProblemType.UNSPECIFIED, {
527            "some_output_3": output_3
528        }),
529    }
530    model_fn_ops = model_fn.ModelFnOps(
531        model_fn.ModeKeys.INFER,
532        predictions={"some_output": constant_op.constant(["4"])},
533        output_alternatives=provided_output_alternatives)
534    output_alternatives, _ = (
535        saved_model_export_utils.get_output_alternatives(
536            model_fn_ops, "head-1"))
537
538    signature_defs = saved_model_export_utils.build_all_signature_defs(
539        input_alternatives, output_alternatives, "head-1")
540
541    expected_signature_defs = {
542        "serving_default":
543            signature_def_utils.regression_signature_def(
544                input_example, output_1),
545        "default_input_alternative:head-1":
546            signature_def_utils.regression_signature_def(
547                input_example, output_1),
548        "default_input_alternative:head-2":
549            signature_def_utils.classification_signature_def(
550                input_example, output_2, None),
551        "default_input_alternative:head-3":
552            signature_def_utils.predict_signature_def({
553                "default input": input_example
554            }, {
555                "some_output_3": output_3
556            }),
557        # "features_input_alternative:head-1":
558        #     signature_def_utils.regression_signature_def(input_features,
559        #                                                  output_1),
560        # "features_input_alternative:head-2":
561        #     signature_def_utils.classification_signature_def(input_features,
562        #                                                      output_2, None),
563        # "features_input_alternative:head-3":
564        #     signature_def_utils.predict_signature_def({
565        #         "input": input_features
566        #     }, {"output": output_3}),
567    }
568
569    self.assertDictEqual(expected_signature_defs, signature_defs)
570
571  def test_build_all_signature_defs_legacy_input_fn_not_supported(self):
572    """Tests that legacy input_fn returning (features, labels) raises error.
573
574    serving_input_fn must return InputFnOps including a default input
575    alternative.
576    """
577    input_features = constant_op.constant(["10"])
578    input_ops = ({"features": input_features}, None)
579    input_alternatives, _ = (
580        saved_model_export_utils.get_input_alternatives(input_ops))
581    output_1 = constant_op.constant(["1"])
582    output_2 = constant_op.constant(["2"])
583    output_3 = constant_op.constant(["3"])
584    provided_output_alternatives = {
585        "head-1": (constants.ProblemType.LINEAR_REGRESSION, {
586            "some_output_1": output_1
587        }),
588        "head-2": (constants.ProblemType.CLASSIFICATION, {
589            "some_output_2": output_2
590        }),
591        "head-3": (constants.ProblemType.UNSPECIFIED, {
592            "some_output_3": output_3
593        }),
594    }
595    model_fn_ops = model_fn.ModelFnOps(
596        model_fn.ModeKeys.INFER,
597        predictions={"some_output": constant_op.constant(["4"])},
598        output_alternatives=provided_output_alternatives)
599    output_alternatives, _ = (
600        saved_model_export_utils.get_output_alternatives(
601            model_fn_ops, "head-1"))
602
603    with self.assertRaisesRegexp(
604        ValueError, "A default input_alternative must be provided"):
605      saved_model_export_utils.build_all_signature_defs(
606          input_alternatives, output_alternatives, "head-1")
607
608  def test_get_timestamped_export_dir(self):
609    export_dir_base = tempfile.mkdtemp() + "export/"
610    export_dir_1 = saved_model_export_utils.get_timestamped_export_dir(
611        export_dir_base)
612    time.sleep(2)
613    export_dir_2 = saved_model_export_utils.get_timestamped_export_dir(
614        export_dir_base)
615    time.sleep(2)
616    export_dir_3 = saved_model_export_utils.get_timestamped_export_dir(
617        export_dir_base)
618
619    # Export directories should be named using a timestamp that is seconds
620    # since epoch.  Such a timestamp is 10 digits long.
621    time_1 = os.path.basename(export_dir_1)
622    self.assertEqual(10, len(time_1))
623    time_2 = os.path.basename(export_dir_2)
624    self.assertEqual(10, len(time_2))
625    time_3 = os.path.basename(export_dir_3)
626    self.assertEqual(10, len(time_3))
627
628    self.assertTrue(int(time_1) < int(time_2))
629    self.assertTrue(int(time_2) < int(time_3))
630
631  def test_garbage_collect_exports(self):
632    export_dir_base = tempfile.mkdtemp() + "export/"
633    gfile.MkDir(export_dir_base)
634    export_dir_1 = _create_test_export_dir(export_dir_base)
635    export_dir_2 = _create_test_export_dir(export_dir_base)
636    export_dir_3 = _create_test_export_dir(export_dir_base)
637    export_dir_4 = _create_test_export_dir(export_dir_base)
638
639    self.assertTrue(gfile.Exists(export_dir_1))
640    self.assertTrue(gfile.Exists(export_dir_2))
641    self.assertTrue(gfile.Exists(export_dir_3))
642    self.assertTrue(gfile.Exists(export_dir_4))
643
644    # Garbage collect all but the most recent 2 exports,
645    # where recency is determined based on the timestamp directory names.
646    saved_model_export_utils.garbage_collect_exports(export_dir_base, 2)
647
648    self.assertFalse(gfile.Exists(export_dir_1))
649    self.assertFalse(gfile.Exists(export_dir_2))
650    self.assertTrue(gfile.Exists(export_dir_3))
651    self.assertTrue(gfile.Exists(export_dir_4))
652
653  def test_get_most_recent_export(self):
654    export_dir_base = tempfile.mkdtemp() + "export/"
655    gfile.MkDir(export_dir_base)
656    _create_test_export_dir(export_dir_base)
657    _create_test_export_dir(export_dir_base)
658    _create_test_export_dir(export_dir_base)
659    export_dir_4 = _create_test_export_dir(export_dir_base)
660
661    (most_recent_export_dir, most_recent_export_version) = (
662        saved_model_export_utils.get_most_recent_export(export_dir_base))
663
664    self.assertEqual(
665        compat.as_bytes(export_dir_4), compat.as_bytes(most_recent_export_dir))
666    self.assertEqual(
667        compat.as_bytes(export_dir_4),
668        os.path.join(
669            compat.as_bytes(export_dir_base),
670            compat.as_bytes(str(most_recent_export_version))))
671
672  def test_make_export_strategy(self):
673    """Only tests that an ExportStrategy instance is created."""
674
675    def _serving_input_fn():
676      return array_ops.constant([1]), None
677
678    export_strategy = saved_model_export_utils.make_export_strategy(
679        serving_input_fn=_serving_input_fn,
680        default_output_alternative_key="default",
681        assets_extra={"from/path": "to/path"},
682        as_text=False,
683        exports_to_keep=5)
684    self.assertTrue(
685        isinstance(export_strategy, export_strategy_lib.ExportStrategy))
686
687  def test_make_parsing_export_strategy(self):
688    """Only tests that an ExportStrategy instance is created."""
689    sparse_col = fc.sparse_column_with_hash_bucket(
690        "sparse_column", hash_bucket_size=100)
691    embedding_col = fc.embedding_column(
692        fc.sparse_column_with_hash_bucket(
693            "sparse_column_for_embedding", hash_bucket_size=10),
694        dimension=4)
695    real_valued_col1 = fc.real_valued_column("real_valued_column1")
696    bucketized_col1 = fc.bucketized_column(
697        fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4])
698    feature_columns = [
699        sparse_col, embedding_col, real_valued_col1, bucketized_col1
700    ]
701
702    export_strategy = saved_model_export_utils.make_parsing_export_strategy(
703        feature_columns=feature_columns)
704    self.assertTrue(
705        isinstance(export_strategy, export_strategy_lib.ExportStrategy))
706
707  def test_make_best_model_export_strategy(self):
708    export_dir_base = tempfile.mkdtemp() + "export/"
709    gfile.MkDir(export_dir_base)
710
711    test_estimator = TestEstimator()
712    export_strategy = saved_model_export_utils.make_best_model_export_strategy(
713        serving_input_fn=None, exports_to_keep=3, compare_fn=None)
714
715    self.assertNotEqual("",
716                        export_strategy.export(test_estimator, export_dir_base,
717                                               "fake_ckpt_0", {
718                                                   "loss": 100
719                                               }))
720    self.assertNotEqual("", test_estimator.last_exported_dir)
721    self.assertNotEqual("", test_estimator.last_exported_checkpoint)
722
723    self.assertEqual("",
724                     export_strategy.export(test_estimator, export_dir_base,
725                                            "fake_ckpt_1", {
726                                                "loss": 101
727                                            }))
728    self.assertEqual(test_estimator.last_exported_dir,
729                     os.path.join(export_dir_base, "fake_ckpt_0"))
730
731    self.assertNotEqual("",
732                        export_strategy.export(test_estimator, export_dir_base,
733                                               "fake_ckpt_2", {
734                                                   "loss": 10
735                                               }))
736    self.assertEqual(test_estimator.last_exported_dir,
737                     os.path.join(export_dir_base, "fake_ckpt_2"))
738
739    self.assertEqual("",
740                     export_strategy.export(test_estimator, export_dir_base,
741                                            "fake_ckpt_3", {
742                                                "loss": 20
743                                            }))
744    self.assertEqual(test_estimator.last_exported_dir,
745                     os.path.join(export_dir_base, "fake_ckpt_2"))
746
747  def test_make_best_model_export_strategy_with_preemption(self):
748    model_dir = self.get_temp_dir()
749    eval_dir_base = os.path.join(model_dir, "eval_continuous")
750    core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 50}, 1)
751    core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2)
752
753    test_estimator = TestEstimator()
754    export_strategy = saved_model_export_utils.make_best_model_export_strategy(
755        serving_input_fn=None,
756        exports_to_keep=3,
757        model_dir=model_dir,
758        event_file_pattern="eval_continuous/*.tfevents.*",
759        compare_fn=None)
760
761    export_dir_base = os.path.join(self.get_temp_dir(), "export")
762    self.assertEqual("",
763                     export_strategy.export(test_estimator, export_dir_base,
764                                            "fake_ckpt_0", {
765                                                "loss": 100
766                                            }))
767    self.assertEqual("", test_estimator.last_exported_dir)
768    self.assertEqual("", test_estimator.last_exported_checkpoint)
769
770    self.assertNotEqual("",
771                        export_strategy.export(test_estimator, export_dir_base,
772                                               "fake_ckpt_2", {
773                                                   "loss": 10
774                                               }))
775    self.assertEqual(test_estimator.last_exported_dir,
776                     os.path.join(export_dir_base, "fake_ckpt_2"))
777
778    self.assertEqual("",
779                     export_strategy.export(test_estimator, export_dir_base,
780                                            "fake_ckpt_3", {
781                                                "loss": 20
782                                            }))
783    self.assertEqual(test_estimator.last_exported_dir,
784                     os.path.join(export_dir_base, "fake_ckpt_2"))
785
786  def test_make_best_model_export_strategy_exceptions(self):
787    export_dir_base = tempfile.mkdtemp() + "export/"
788
789    test_estimator = TestEstimator()
790    export_strategy = saved_model_export_utils.make_best_model_export_strategy(
791        serving_input_fn=None, exports_to_keep=3, compare_fn=None)
792
793    with self.assertRaises(ValueError):
794      export_strategy.export(test_estimator, export_dir_base, "", {"loss": 200})
795
796    with self.assertRaises(ValueError):
797      export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1",
798                             None)
799
800  def test_extend_export_strategy(self):
801
802    def _base_export_fn(unused_estimator,
803                        export_dir_base,
804                        unused_checkpoint_path=None):
805      base_path = os.path.join(export_dir_base, "e1")
806      gfile.MkDir(base_path)
807      return base_path
808
809    def _post_export_fn(orig_path, new_path):
810      assert orig_path.endswith("/e1")
811      post_export_path = os.path.join(new_path, "rewrite")
812      gfile.MkDir(post_export_path)
813      return post_export_path
814
815    base_export_strategy = export_strategy_lib.ExportStrategy(
816        "Servo", _base_export_fn)
817
818    final_export_strategy = saved_model_export_utils.extend_export_strategy(
819        base_export_strategy, _post_export_fn, "Servo2")
820    self.assertEqual(final_export_strategy.name, "Servo2")
821
822    test_estimator = TestEstimator()
823    tmpdir = tempfile.mkdtemp()
824    export_model_dir = os.path.join(tmpdir, "model")
825    checkpoint_path = os.path.join(tmpdir, "checkpoint")
826    final_path = final_export_strategy.export(test_estimator, export_model_dir,
827                                              checkpoint_path)
828    self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path)
829
830  def test_extend_export_strategy_same_name(self):
831
832    def _base_export_fn(unused_estimator,
833                        export_dir_base,
834                        unused_checkpoint_path=None):
835      base_path = os.path.join(export_dir_base, "e1")
836      gfile.MkDir(base_path)
837      return base_path
838
839    def _post_export_fn(orig_path, new_path):
840      assert orig_path.endswith("/e1")
841      post_export_path = os.path.join(new_path, "rewrite")
842      gfile.MkDir(post_export_path)
843      return post_export_path
844
845    base_export_strategy = export_strategy_lib.ExportStrategy(
846        "Servo", _base_export_fn)
847
848    final_export_strategy = saved_model_export_utils.extend_export_strategy(
849        base_export_strategy, _post_export_fn)
850    self.assertEqual(final_export_strategy.name, "Servo")
851
852    test_estimator = TestEstimator()
853    tmpdir = tempfile.mkdtemp()
854    export_model_dir = os.path.join(tmpdir, "model")
855    checkpoint_path = os.path.join(tmpdir, "checkpoint")
856    final_path = final_export_strategy.export(test_estimator, export_model_dir,
857                                              checkpoint_path)
858    self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path)
859
860  def test_extend_export_strategy_raises_error(self):
861
862    def _base_export_fn(unused_estimator,
863                        export_dir_base,
864                        unused_checkpoint_path=None):
865      base_path = os.path.join(export_dir_base, "e1")
866      gfile.MkDir(base_path)
867      return base_path
868
869    def _post_export_fn(unused_orig_path, unused_new_path):
870      return tempfile.mkdtemp()
871
872    base_export_strategy = export_strategy_lib.ExportStrategy(
873        "Servo", _base_export_fn)
874
875    final_export_strategy = saved_model_export_utils.extend_export_strategy(
876        base_export_strategy, _post_export_fn)
877
878    test_estimator = TestEstimator()
879    tmpdir = tempfile.mkdtemp()
880    with self.assertRaises(ValueError) as ve:
881      final_export_strategy.export(test_estimator, tmpdir,
882                                   os.path.join(tmpdir, "checkpoint"))
883
884    self.assertTrue(
885        "post_export_fn must return a sub-directory" in str(ve.exception))
886
887
def _create_test_export_dir(export_dir_base):
  """Creates a timestamped export directory under the base and returns it."""
  timestamped_dir = saved_model_export_utils.get_timestamped_export_dir(
      export_dir_base)
  gfile.MkDir(timestamped_dir)
  # NOTE(review): presumably this pause keeps successive calls from landing on
  # the same timestamp -- confirm against get_timestamped_export_dir.
  time.sleep(2)
  return timestamped_dir
894
895
# Run the tests via the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
898