# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ExportStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.python.platform import test


class ExportStrategyTest(test.TestCase):
  """Exercises `ExportStrategy.export` with export_fns of varying signatures.

  Each test builds a strategy named 'foo' around a local export function,
  checks that the strategy compares equal to the expected 3-tuple, and then
  checks what `export` passes to (and returns from) that function.
  """

  def test_no_optional_args_export(self):
    """An export_fn taking only (estimator, export_path) gets just those."""
    expected_path = '/path/to/model'

    def _basic_export_fn(estimator, export_path):
      self.assertTupleEqual((estimator, export_path), (None, None))
      return expected_path

    strategy = export_strategy.ExportStrategy('foo', _basic_export_fn)
    expected_fields = ('foo', _basic_export_fn, None)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None)
    self.assertIs(exported, expected_path)

  def test_checkpoint_export(self):
    """A checkpoint path given to `export` reaches the export_fn."""
    expected_path = '/path/to/checkpoint_model'

    def _checkpoint_aware_fn(estimator, export_path, checkpoint_path):
      self.assertTupleEqual((estimator, export_path), (None, None))
      self.assertEqual(checkpoint_path, 'checkpoint')
      return expected_path

    strategy = export_strategy.ExportStrategy('foo', _checkpoint_aware_fn)
    expected_fields = ('foo', _checkpoint_aware_fn, None)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None, 'checkpoint')
    self.assertIs(exported, expected_path)

  def test_checkpoint_eval_export(self):
    """Checkpoint path and eval result are both handed to the export_fn."""
    expected_path = '/path/to/checkpoint_eval_model'

    def _checkpoint_eval_aware_fn(estimator, export_path, checkpoint_path,
                                  eval_result):
      self.assertTupleEqual((estimator, export_path), (None, None))
      self.assertEqual(checkpoint_path, 'checkpoint')
      self.assertEqual(eval_result, 'eval')
      return expected_path

    strategy = export_strategy.ExportStrategy('foo', _checkpoint_eval_aware_fn)
    expected_fields = ('foo', _checkpoint_eval_aware_fn, None)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None, 'checkpoint', 'eval')
    self.assertIs(exported, expected_path)

  def test_eval_only_export(self):
    """An export_fn with eval_result but no checkpoint_path is rejected."""

    def _eval_only_fn(estimator, export_path, eval_result):
      # The arguments are intentionally unused; only the signature matters.
      del estimator, export_path, eval_result

    strategy = export_strategy.ExportStrategy('foo', _eval_only_fn)
    expected_fields = ('foo', _eval_only_fn, None)
    self.assertTupleEqual(strategy, expected_fields)

    expected_error = ('An export_fn accepting eval_result must also accept '
                      'checkpoint_path')
    with self.assertRaisesRegexp(ValueError, expected_error):
      strategy.export(None, None, eval_result='eval')

  def test_strip_default_attr_export(self):
    """A strategy built with a third field of True passes it to the export_fn."""
    expected_path = '/path/to/strip_default_attrs_model'

    def _strip_attrs_aware_fn(estimator, export_path, strip_default_attrs):
      self.assertTupleEqual((estimator, export_path), (None, None))
      self.assertTrue(strip_default_attrs)
      return expected_path

    strategy = export_strategy.ExportStrategy(
        'foo', _strip_attrs_aware_fn, True)
    expected_fields = ('foo', _strip_attrs_aware_fn, True)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None)
    self.assertIs(exported, expected_path)


if __name__ == '__main__':
  test.main()