# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for flagsaver."""

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver

# Plain string flags that the tests mutate and then expect to be restored.
flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')

# The validator accepts only falsy values (None or ''), so any non-empty
# override of this flag fails validation.
flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with')
flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)

# Two flags tied together by the multi-flag validator `validate_test_flags`
# below: they are only valid while they hold equal values.
flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with')
flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with')

# FlagHolders used to exercise the `(holder, value)` override syntax.
INT_FLAG = flags.DEFINE_integer(
    'flagsaver_test_int_flag', default=1, help='help')
STR_FLAG = flags.DEFINE_string(
    'flagsaver_test_str_flag', default='str default', help='help')


@flags.multi_flags_validator(
    ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
def validate_test_flags(flag_dict):
  """Returns True iff the two cross-validated flags hold equal values."""
  return (flag_dict['flagsaver_test_validated_flag1'] ==
          flag_dict['flagsaver_test_validated_flag2'])


FLAGS = flags.FLAGS


@flags.validator('flagsaver_test_flag0')
def check_no_upper_case(value):
  """Returns True iff `value` is unchanged by lower-casing."""
  return value == value.lower()


class _TestError(Exception):
  """Exception class for use in these tests."""


class FlagSaverTest(absltest.TestCase):
  """Tests flagsaver as a context manager, decorator and save/restore API."""

  def test_context_manager_without_parameters(self):
    with flagsaver.flagsaver():
      FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_overrides(self):
    with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
      FLAGS.flagsaver_test_flag1 = 'another value'
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_flagholders(self):
    with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')):
      self.assertEqual('new value', STR_FLAG.value)
      self.assertEqual(3, INT_FLAG.value)
      FLAGS.flagsaver_test_flag1 = 'another value'
    self.assertEqual(INT_FLAG.default, INT_FLAG.value)
    self.assertEqual(STR_FLAG.default, STR_FLAG.value)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_overrides_and_flagholders(self):
    with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'):
      # Both the holder override and the keyword override are applied...
      self.assertEqual(3, INT_FLAG.value)
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
      # ...while flags that were not overridden keep their values.
      self.assertEqual(STR_FLAG.default, STR_FLAG.value)
      # A mutation of a non-overridden flag must also be undone on exit.
      FLAGS.flagsaver_test_flag1 = 'another value'
    self.assertEqual(INT_FLAG.default, INT_FLAG.value)
    self.assertEqual(STR_FLAG.default, STR_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_cross_validated_overrides_set_together(self):
    # When the flags are set in the same flagsaver call their validators will
    # be triggered only once the setting is done.
    with flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='new_value'):
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_cross_validated_overrides_set_badly(self):

    # Different values should violate the validator.
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      with flagsaver.flagsaver(
          flagsaver_test_validated_flag1='new_value',
          flagsaver_test_validated_flag2='other_value'):
        pass

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_cross_validated_overrides_set_separately(self):

    # Setting just one flag will trip the validator as well.
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'):
        pass

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_exception(self):
    with self.assertRaises(_TestError):
      with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
        self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
        FLAGS.flagsaver_test_flag1 = 'another value'
        raise _TestError('oops')
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_validation_exception(self):
    with self.assertRaises(flags.IllegalFlagValueError):
      with flagsaver.flagsaver(
          flagsaver_test_flag0='new value',
          flagsaver_test_validated_flag='new value'):
        pass
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag)

  def test_decorator_without_call(self):

    @flagsaver.flagsaver
    def mutate_flags(value):
      """Test function that mutates a flag."""
      # The undecorated method mutates --flagsaver_test_flag0 to the given value
      # and then returns the value of that flag. If the @flagsaver.flagsaver
      # decorator works as designed, then this mutation will be reverted after
      # this method returns.
      FLAGS.flagsaver_test_flag0 = value
      return FLAGS.flagsaver_test_flag0

    # mutate_flags returns the flag value before it gets restored by
    # the flagsaver decorator. So we check that flag value was
    # actually changed in the method's scope.
    self.assertEqual('new value', mutate_flags('new value'))
    # But... notice that the flag is now unchanged0.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_without_parameters(self):

    @flagsaver.flagsaver()
    def mutate_flags(value):
      FLAGS.flagsaver_test_flag0 = value
      return FLAGS.flagsaver_test_flag0

    self.assertEqual('new value', mutate_flags('new value'))
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_overrides(self):

    @flagsaver.flagsaver(flagsaver_test_flag0='new value')
    def mutate_flags():
      """Test function expecting new value."""
      # If the @flagsaver.flagsaver decorator works as designed,
      # then the value of the flag should be changed in the scope of
      # the method but the change will be reverted after this method
      # returns.
      return FLAGS.flagsaver_test_flag0

    # mutate_flags returns the flag value before it gets restored by
    # the flagsaver decorator. So we check that flag value was
    # actually changed in the method's scope.
    self.assertEqual('new value', mutate_flags())
    # But... notice that the flag is now unchanged0.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_cross_validated_overrides_set_together(self):

    # When the flags are set in the same flagsaver call their validators will
    # be triggered only once the setting is done.
    @flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='new_value')
    def mutate_flags_together():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    self.assertEqual(('new_value', 'new_value'), mutate_flags_together())

    # The flags have not changed outside the context of the function.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_decorator_with_cross_validated_overrides_set_badly(self):

    # Different values should violate the validator.
    @flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='other_value')
    def mutate_flags_together_badly():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      mutate_flags_together_badly()

    # The flags have not changed outside the context of the exception.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_decorator_with_cross_validated_overrides_set_separately(self):

    # Setting the flags sequentially and not together will trip the validator,
    # because it will be called at the end of each flagsaver call.
    @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
    @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
    def mutate_flags_separately():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      mutate_flags_separately()

    # The flags have not changed outside the context of the exception.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_save_flag_value(self):
    # First save the flag values.
    saved_flag_values = flagsaver.save_flag_values()

    # Now mutate the flag's value field and check that it changed.
    FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    # Now restore the flag to its original value.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_save_flag_default(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Now mutate the flag's default field and check that it changed.
    FLAGS.set_default('flagsaver_test_flag0', 'new_default')
    self.assertEqual('new_default', FLAGS['flagsaver_test_flag0'].default)

    # Now restore the flag's default field. `set_default` can also update the
    # flag's current value, so check that it is restored as well.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_restore_after_parse(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Sanity check (would fail if called with --flagsaver_test_flag0).
    self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
    # Now populate the flag and check that it changed.
    FLAGS['flagsaver_test_flag0'].parse('new value')
    self.assertEqual('new value', FLAGS['flagsaver_test_flag0'].value)
    self.assertEqual(1, FLAGS['flagsaver_test_flag0'].present)

    # Now restore the flag to its original value.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value)
    self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)

  def test_decorator_with_exception(self):

    @flagsaver.flagsaver
    def raise_exception():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Simulate a failed test.
      raise _TestError('something happened')

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertRaises(_TestError, raise_exception)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_validator_list_is_restored(self):

    self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
    original_validators = list(FLAGS['flagsaver_test_flag0'].validators)

    @flagsaver.flagsaver
    def modify_validators():

      def no_space(value):
        return ' ' not in value

      flags.register_validator('flagsaver_test_flag0', no_space)
      self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)

    modify_validators()
    self.assertEqual(original_validators,
                     FLAGS['flagsaver_test_flag0'].validators)


class FlagSaverDecoratorUsageTest(absltest.TestCase):
  """Each test mutates a flag under @flagsaver and expects a clean start."""

  @flagsaver.flagsaver
  def test_mutate1(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  @flagsaver.flagsaver
  def test_mutate2(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  @flagsaver.flagsaver
  def test_mutate3(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  @flagsaver.flagsaver
  def test_mutate4(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'


class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
  """Each test mutates a flag; setUp/tearDown save and restore all flags."""

  def setUp(self):
    self.saved_flag_values = flagsaver.save_flag_values()

  def tearDown(self):
    flagsaver.restore_flag_values(self.saved_flag_values)

  def test_mutate1(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate2(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate3(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate4(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'


class FlagSaverBadUsageTest(absltest.TestCase):
  """Tests that certain kinds of improper usages raise errors."""

  def test_flag_saver_on_class(self):
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_call_on_class(self):
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver()
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_with_overrides_on_class(self):
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver(foo='bar')
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_multiple_positional_parameters(self):
    func_a = lambda: None
    func_b = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, func_b)

  def test_both_positional_and_keyword_parameters(self):
    func_a = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value')

  def test_duplicate_holder_parameters(self):
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))

  def test_duplicate_holder_and_kw_parameter(self):
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45})

  def test_both_positional_and_holder_parameters(self):
    func_a = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, (INT_FLAG, 45))

  def test_holder_parameters_wrong_shape(self):
    with self.assertRaises(ValueError):
      # A bare FlagHolder is rejected; it must be wrapped in a
      # (holder, value) tuple.
      flagsaver.flagsaver(INT_FLAG)

  def test_holder_parameters_tuple_too_long(self):
    with self.assertRaises(ValueError):
      # Holder tuples must have exactly two elements: (holder, value).
      flagsaver.flagsaver((INT_FLAG, 4, 5))

  def test_holder_parameters_tuple_wrong_type(self):
    with self.assertRaises(ValueError):
      # The FlagHolder must be the first element of the tuple.
      flagsaver.flagsaver((4, INT_FLAG))

  def test_both_wrong_positional_parameters(self):
    func_a = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, STR_FLAG, '45')


if __name__ == '__main__':
  absltest.main()