• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for ExportStrategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.learn.python.learn import export_strategy
22from tensorflow.python.platform import test
23
24
class ExportStrategyTest(test.TestCase):
  """Checks ExportStrategy.export against export_fns of different signatures.

  Each test builds a strategy named 'foo' around a local export function,
  verifies the strategy compares equal to the expected 3-tuple, and then
  checks what `export` hands to (or refuses to hand to) that function.
  """

  def test_no_optional_args_export(self):
    # An export_fn that takes only (estimator, export_path) is expected to
    # receive exactly those two values and have its return value passed back.
    expected_path = '/path/to/model'

    def _plain_export(estimator, export_path):
      self.assertTupleEqual((estimator, export_path), (None, None))
      return expected_path

    plain_strategy = export_strategy.ExportStrategy('foo', _plain_export)
    self.assertTupleEqual(plain_strategy, ('foo', _plain_export, None))

    exported = plain_strategy.export(None, None)
    self.assertIs(exported, expected_path)

  def test_checkpoint_export(self):
    # An export_fn that declares `checkpoint_path` is expected to see the
    # third argument given to `export`.
    expected_path = '/path/to/checkpoint_model'

    def _export_with_ckpt(estimator, export_path, checkpoint_path):
      self.assertTupleEqual((estimator, export_path), (None, None))
      self.assertEqual(checkpoint_path, 'checkpoint')
      return expected_path

    ckpt_strategy = export_strategy.ExportStrategy('foo', _export_with_ckpt)
    self.assertTupleEqual(ckpt_strategy, ('foo', _export_with_ckpt, None))

    exported = ckpt_strategy.export(None, None, 'checkpoint')
    self.assertIs(exported, expected_path)

  def test_checkpoint_eval_export(self):
    # An export_fn that declares both `checkpoint_path` and `eval_result` is
    # expected to see the third and fourth arguments given to `export`.
    expected_path = '/path/to/checkpoint_eval_model'

    def _export_with_ckpt_and_eval(estimator, export_path, checkpoint_path,
                                   eval_result):
      self.assertTupleEqual((estimator, export_path), (None, None))
      self.assertEqual(checkpoint_path, 'checkpoint')
      self.assertEqual(eval_result, 'eval')
      return expected_path

    ckpt_eval_strategy = export_strategy.ExportStrategy(
        'foo', _export_with_ckpt_and_eval)
    self.assertTupleEqual(ckpt_eval_strategy,
                          ('foo', _export_with_ckpt_and_eval, None))

    exported = ckpt_eval_strategy.export(None, None, 'checkpoint', 'eval')
    self.assertIs(exported, expected_path)

  def test_eval_only_export(self):
    # An export_fn that declares `eval_result` without `checkpoint_path` is
    # expected to be rejected by `export` with a ValueError.
    def _eval_only_export(estimator, export_path, eval_result):
      del estimator, export_path, eval_result  # Unused.

    eval_strategy = export_strategy.ExportStrategy('foo', _eval_only_export)
    self.assertTupleEqual(eval_strategy, ('foo', _eval_only_export, None))

    expected_error = ('An export_fn accepting eval_result must also accept '
                      'checkpoint_path')
    with self.assertRaisesRegexp(ValueError, expected_error):
      eval_strategy.export(None, None, eval_result='eval')

  def test_strip_default_attr_export(self):
    # A strategy constructed with a third argument of True is expected to call
    # an export_fn that declares `strip_default_attrs` with a truthy value,
    # even though `export` itself is given only two arguments.
    expected_path = '/path/to/strip_default_attrs_model'

    def _export_stripping_attrs(estimator, export_path, strip_default_attrs):
      self.assertTupleEqual((estimator, export_path), (None, None))
      self.assertTrue(strip_default_attrs)
      return expected_path

    strip_strategy = export_strategy.ExportStrategy(
        'foo', _export_stripping_attrs, True)
    self.assertTupleEqual(strip_strategy,
                          ('foo', _export_stripping_attrs, True))

    exported = strip_strategy.export(None, None)
    self.assertIs(exported, expected_path)
87
# Run the test suite only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
  test.main()
90