1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Sanity tests for tf.flags.""" 16import sys 17import unittest 18 19from absl import flags as absl_flags 20 21from tensorflow.python.platform import flags 22from tensorflow.python.platform import test 23 24 25flags.DEFINE_string( # pylint: disable=no-value-for-parameter 26 flag_name='old_string', 27 default_value='default', 28 docstring='docstring') 29flags.DEFINE_string( 30 name='new_string', default='default', help='docstring') 31flags.DEFINE_integer( # pylint: disable=no-value-for-parameter 32 flag_name='old_integer', 33 default_value=1, 34 docstring='docstring') 35flags.DEFINE_integer( 36 name='new_integer', default=1, help='docstring') 37flags.DEFINE_float( # pylint: disable=no-value-for-parameter 38 flag_name='old_float', 39 default_value=1.5, 40 docstring='docstring') 41flags.DEFINE_float( 42 name='new_float', default=1.5, help='docstring') 43flags.DEFINE_bool( # pylint: disable=no-value-for-parameter 44 flag_name='old_bool', 45 default_value=True, 46 docstring='docstring') 47flags.DEFINE_bool( 48 name='new_bool', default=True, help='docstring') 49flags.DEFINE_boolean( # pylint: disable=no-value-for-parameter 50 flag_name='old_boolean', 51 default_value=False, 52 docstring='docstring') 53flags.DEFINE_boolean( 54 name='new_boolean', default=False, help='docstring') 
class FlagsTest(unittest.TestCase):
  """Tests for flags._FlagValuesWrapper and DEFINE_* keyword arguments."""

  def setUp(self):
    """Builds a fresh FlagValues, wraps it, and registers a `test` flag."""
    self.original_flags = flags.FlagValues()
    self.wrapped_flags = flags._FlagValuesWrapper(self.original_flags)
    flags.DEFINE_string(
        'test', 'default', 'test flag', flag_values=self.wrapped_flags)

  def test_attribute_overrides(self):
    """Checks `is_parsed` on tf FLAGS equals the one on absl FLAGS."""
    # Test that methods defined in absl.flags.FlagValues are the same as the
    # wrapped ones.
    self.assertEqual(flags.FLAGS.is_parsed, absl_flags.FLAGS.is_parsed)

  def test_getattr(self):
    """Checks that reading a flag attribute parses the patched sys.argv."""
    self.assertFalse(self.wrapped_flags.is_parsed())
    # sys.argv is replaced only for the duration of the attribute read; the
    # test expects that read to pick up `--test=new` and mark flags as parsed.
    with test.mock.patch.object(sys, 'argv', new=['program', '--test=new']):
      self.assertEqual('new', self.wrapped_flags.test)
    self.assertTrue(self.wrapped_flags.is_parsed())

  def test_setattr(self):
    """Checks that assigning a flag attribute changes the value read back."""
    self.assertEqual('default', self.wrapped_flags.test)
    self.wrapped_flags.test = 'new'
    self.assertEqual('new', self.wrapped_flags.test)

  def test_delattr(self):
    """Checks that deleting a flag removes it from the wrapper."""
    del self.wrapped_flags.test
    self.assertNotIn('test', self.wrapped_flags)
    with self.assertRaises(AttributeError):
      _ = self.wrapped_flags.test

  def test_dir(self):
    """Checks that dir() on the wrapper lists exactly the defined flag."""
    self.assertEqual(['test'], dir(self.wrapped_flags))

  def test_getitem(self):
    """Checks that item lookup returns the same object as the wrapped one."""
    self.assertIs(self.original_flags['test'], self.wrapped_flags['test'])

  def test_setitem(self):
    """Checks that item assignment registers a flag on the wrapped object."""
    flag = flags.Flag(flags.ArgumentParser(), flags.ArgumentSerializer(),
                      'fruit', 'apple', 'the fruit type')
    self.wrapped_flags['fruit'] = flag
    self.assertIs(self.original_flags['fruit'], self.wrapped_flags['fruit'])
    self.assertEqual('apple', self.wrapped_flags.fruit)

  def test_len(self):
    """Checks that len() counts the single flag defined in setUp."""
    self.assertEqual(1, len(self.wrapped_flags))

  def test_iter(self):
    """Checks that iterating the wrapper yields the defined flag name."""
    self.assertEqual(['test'], list(self.wrapped_flags))

  def test_str(self):
    """Checks that str() of the wrapper equals str() of the wrapped object."""
    self.assertEqual(str(self.wrapped_flags), str(self.original_flags))

  def test_call(self):
    """Checks that calling the wrapper with an argv list parses the flags."""
    self.wrapped_flags(['program', '--test=new'])
    self.assertEqual('new', self.wrapped_flags.test)

  def test_keyword_arguments(self):
    """Checks default and help text of the module-level old_*/new_* flags."""
    test_cases = (
        ('old_string', 'default'),
        ('new_string', 'default'),
        ('old_integer', 1),
        ('new_integer', 1),
        ('old_float', 1.5),
        ('new_float', 1.5),
        ('old_bool', True),
        ('new_bool', True),
        ('old_boolean', False),
        ('new_boolean', False),
    )
    for flag_name, default_value in test_cases:
      # subTest labels any failure with the flag name and lets the remaining
      # cases run instead of stopping at the first failing flag.
      with self.subTest(flag_name=flag_name):
        flag = absl_flags.FLAGS[flag_name]
        self.assertEqual(default_value, flag.default)
        self.assertEqual('docstring', flag.help)


if __name__ == '__main__':
  unittest.main()