# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModeKey Tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
from tensorflow.python.saved_model.model_utils import mode_keys


class ModeKeyMapTest(test.TestCase):
  """Tests for `mode_keys.ModeKeyMap`.

  The assertions below exercise lookups that mix Keras and Estimator mode
  keys (PREDICT <-> PREDICT, TEST <-> EVAL), read-only dictionary behavior,
  and rejection of conflicting constructor arguments.
  """

  def test_map(self):
    """Checks lookup, missing/invalid keys, dict views and immutability."""
    mode_map = mode_keys.ModeKeyMap(**{
        mode_keys.KerasModeKeys.PREDICT: 3,
        mode_keys.KerasModeKeys.TEST: 1
    })

    # Test dictionary __getitem__. A value stored under a Keras key is also
    # expected to be reachable through the corresponding Estimator key.
    self.assertEqual(3, mode_map[mode_keys.KerasModeKeys.PREDICT])
    self.assertEqual(3, mode_map[mode_keys.EstimatorModeKeys.PREDICT])
    self.assertEqual(1, mode_map[mode_keys.KerasModeKeys.TEST])
    self.assertEqual(1, mode_map[mode_keys.EstimatorModeKeys.EVAL])
    # A valid mode that was not supplied raises KeyError, in either naming.
    with self.assertRaises(KeyError):
      _ = mode_map[mode_keys.KerasModeKeys.TRAIN]
    with self.assertRaises(KeyError):
      _ = mode_map[mode_keys.EstimatorModeKeys.TRAIN]
    # A string that is not a recognized mode raises ValueError, not KeyError.
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Invalid mode'):
      _ = mode_map['serve']

    # Test common dictionary methods.
    self.assertLen(mode_map, 2)
    self.assertEqual({1, 3}, set(mode_map.values()))
    self.assertEqual(
        {mode_keys.KerasModeKeys.TEST, mode_keys.KerasModeKeys.PREDICT},
        set(mode_map.keys()))

    # Map is immutable: item assignment must raise TypeError.
    with self.assertRaises(TypeError):
      mode_map[mode_keys.KerasModeKeys.TEST] = 1

  def test_invalid_init(self):
    """Passing both a Keras and an Estimator key for one mode is rejected."""
    with self.assertRaisesRegex(ValueError, 'Multiple keys/values found'):
      _ = mode_keys.ModeKeyMap(**{
          mode_keys.KerasModeKeys.PREDICT: 3,
          mode_keys.EstimatorModeKeys.PREDICT: 1
      })


if __name__ == '__main__':
  test.main()