• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for flagsaver."""
15
16from absl import flags
17from absl.testing import absltest
18from absl.testing import flagsaver
19
# Plain string flags that the tests mutate and expect to see restored.
flags.DEFINE_string(
    'flagsaver_test_flag0', default='unchanged0', help='flag to test with')
flags.DEFINE_string(
    'flagsaver_test_flag1', default='unchanged1', help='flag to test with')

# The validator accepts only falsy values, so any non-empty override of this
# flag fails validation.
flags.DEFINE_string(
    'flagsaver_test_validated_flag', default=None, help='flag to test with')
flags.register_validator(
    'flagsaver_test_validated_flag', lambda value: not value)

# A pair of flags tied together by the cross-flag validator
# `validate_test_flags`.
flags.DEFINE_string(
    'flagsaver_test_validated_flag1', default=None, help='flag to test with')
flags.DEFINE_string(
    'flagsaver_test_validated_flag2', default=None, help='flag to test with')

# Flag holders returned by DEFINE_*, used to exercise flagsaver's
# (holder, value) override form.
INT_FLAG = flags.DEFINE_integer(
    'flagsaver_test_int_flag', default=1, help='help')
STR_FLAG = flags.DEFINE_string(
    'flagsaver_test_str_flag', default='str default', help='help')
33
34
@flags.multi_flags_validator(
    ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
def validate_test_flags(flag_dict):
  """Cross-flag check: both validated test flags must hold equal values.

  Args:
    flag_dict: mapping from flag name to value, indexed here by the two flag
      names listed in the decorator.

  Returns:
    True when the two flags compare equal, False otherwise.
  """
  first_value = flag_dict['flagsaver_test_validated_flag1']
  second_value = flag_dict['flagsaver_test_validated_flag2']
  return first_value == second_value
40
41
# Module-level alias through which the tests below read and assign flags.
FLAGS = flags.FLAGS


@flags.validator('flagsaver_test_flag0')
def check_no_upper_case(value):
  """Validator for --flagsaver_test_flag0 that rejects uppercase characters.

  Args:
    value: the candidate value of the flag.

  Returns:
    True when lower-casing `value` leaves it unchanged.
  """
  lowered = value.lower()
  return lowered == value
48
49
class _TestError(Exception):
  """Exception the tests raise on purpose.

  It is thrown from inside a flagsaver context or a flagsaver-decorated
  function so the tests can check that flags are restored even when the
  guarded code exits with an exception.
  """
52
53
class FlagSaverTest(absltest.TestCase):
  """Tests flagsaver as a context manager, as a decorator and via save/restore.

  Each test checks that flag values changed inside a flagsaver scope are
  reverted when the scope exits. This includes exits through an exception
  or a validation failure.
  """

  def test_context_manager_without_parameters(self):
    with flagsaver.flagsaver():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Confirm the mutation took effect, so the restore below is meaningful.
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_overrides(self):
    with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
      FLAGS.flagsaver_test_flag1 = 'another value'
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_flagholders(self):
    with flagsaver.flagsaver((INT_FLAG, 3), (STR_FLAG, 'new value')):
      self.assertEqual('new value', STR_FLAG.value)
      self.assertEqual(3, INT_FLAG.value)
      FLAGS.flagsaver_test_flag1 = 'another value'
    self.assertEqual(INT_FLAG.default, INT_FLAG.value)
    self.assertEqual(STR_FLAG.default, STR_FLAG.value)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_overrides_and_flagholders(self):
    with flagsaver.flagsaver((INT_FLAG, 3), flagsaver_test_flag0='new value'):
      self.assertEqual(STR_FLAG.default, STR_FLAG.value)
      self.assertEqual(3, INT_FLAG.value)
      # The keyword override must be applied alongside the holder override.
      # Assert it rather than re-assigning it, so the override is verified.
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
    self.assertEqual(INT_FLAG.default, INT_FLAG.value)
    self.assertEqual(STR_FLAG.default, STR_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_cross_validated_overrides_set_together(self):
    # When the flags are set in the same flagsaver call, their validators are
    # triggered only after all of the overrides have been applied.
    with flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='new_value'):
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_cross_validated_overrides_set_badly(self):

    # Different values should violate the validator.
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      with flagsaver.flagsaver(
          flagsaver_test_validated_flag1='new_value',
          flagsaver_test_validated_flag2='other_value'):
        pass

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_cross_validated_overrides_set_separately(self):

    # Setting just one flag will trip the validator as well.
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      with flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value'):
        pass

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_exception(self):
    with self.assertRaises(_TestError):
      with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
        self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
        FLAGS.flagsaver_test_flag1 = 'another value'
        raise _TestError('oops')
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_validation_exception(self):
    with self.assertRaises(flags.IllegalFlagValueError):
      with flagsaver.flagsaver(
          flagsaver_test_flag0='new value',
          flagsaver_test_validated_flag='new value'):
        pass
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag)

  def test_decorator_without_call(self):

    @flagsaver.flagsaver
    def mutate_flags(value):
      """Test function that mutates a flag."""
      # The undecorated method mutates --flagsaver_test_flag0 to the given value
      # and then returns the value of that flag. If the @flagsaver.flagsaver
      # decorator works as designed, then this mutation will be reverted after
      # this method returns.
      FLAGS.flagsaver_test_flag0 = value
      return FLAGS.flagsaver_test_flag0

    # mutate_flags returns the flag value before it gets restored by
    # the flagsaver decorator. So we check that flag value was
    # actually changed in the method's scope.
    self.assertEqual('new value', mutate_flags('new value'))
    # But... notice that the flag is now unchanged0.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_without_parameters(self):

    @flagsaver.flagsaver()
    def mutate_flags(value):
      FLAGS.flagsaver_test_flag0 = value
      return FLAGS.flagsaver_test_flag0

    self.assertEqual('new value', mutate_flags('new value'))
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_overrides(self):

    @flagsaver.flagsaver(flagsaver_test_flag0='new value')
    def read_overridden_flag():
      """Test function expecting new value."""
      # If the @flagsaver.flagsaver decorator works as designed,
      # then the value of the flag should be changed in the scope of
      # the method, but the change will be reverted after this method
      # returns.
      return FLAGS.flagsaver_test_flag0

    # read_overridden_flag returns the flag value before it gets restored by
    # the flagsaver decorator. So we check that flag value was
    # actually changed in the method's scope.
    self.assertEqual('new value', read_overridden_flag())
    # But... notice that the flag is now unchanged0.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_cross_validated_overrides_set_together(self):

    # When the flags are set in the same flagsaver call, their validators are
    # triggered only after all of the overrides have been applied.
    @flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='new_value')
    def mutate_flags_together():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    self.assertEqual(('new_value', 'new_value'), mutate_flags_together())

    # The flags have not changed outside the context of the function.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_decorator_with_cross_validated_overrides_set_badly(self):

    # Different values should violate the validator.
    @flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='other_value')
    def mutate_flags_together_badly():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      mutate_flags_together_badly()

    # The flags have not changed outside the context of the exception.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_decorator_with_cross_validated_overrides_set_separately(self):

    # When the overrides come from separate flagsaver calls, the validator
    # runs after each call applies its own override. The first one to apply
    # therefore sees only one of the two flags changed, and validation fails.
    @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
    @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
    def mutate_flags_separately():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      mutate_flags_separately()

    # The flags have not changed outside the context of the exception.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_save_flag_value(self):
    # First save the flag values.
    saved_flag_values = flagsaver.save_flag_values()

    # Now mutate the flag's value field and check that it changed.
    FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    # Now restore the flag to its original value.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_save_flag_default(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Now mutate the flag's default field and check that it changed.
    FLAGS.set_default('flagsaver_test_flag0', 'new_default')
    self.assertEqual('new_default', FLAGS['flagsaver_test_flag0'].default)

    # Now restore the flag's default field.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default)

  def test_restore_after_parse(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Sanity check (would fail if called with --flagsaver_test_flag0).
    self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
    # Now populate the flag and check that it changed.
    FLAGS['flagsaver_test_flag0'].parse('new value')
    self.assertEqual('new value', FLAGS['flagsaver_test_flag0'].value)
    self.assertEqual(1, FLAGS['flagsaver_test_flag0'].present)

    # Now restore the flag to its original value.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value)
    self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)

  def test_decorator_with_exception(self):

    @flagsaver.flagsaver
    def raise_exception():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Simulate a failed test.
      raise _TestError('something happened')

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    with self.assertRaises(_TestError):
      raise_exception()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_validator_list_is_restored(self):

    self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
    original_validators = list(FLAGS['flagsaver_test_flag0'].validators)

    @flagsaver.flagsaver
    def modify_validators():

      def no_space(value):
        return ' ' not in value

      flags.register_validator('flagsaver_test_flag0', no_space)
      self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)

    modify_validators()
    self.assertEqual(original_validators,
                     FLAGS['flagsaver_test_flag0'].validators)
310
311
class FlagSaverDecoratorUsageTest(absltest.TestCase):
  """Decorates individual test methods with @flagsaver.flagsaver.

  Each test expects --flagsaver_test_flag0 to start at 'unchanged0' and then
  overwrites it. Whatever order the tests run in, a later test only sees the
  original value if the decorator reverted the earlier test's change.
  """

  def _expect_unchanged_then_mutate(self):
    """Asserts the flag still holds 'unchanged0', then overwrites it."""
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  @flagsaver.flagsaver
  def test_mutate1(self):
    self._expect_unchanged_then_mutate()

  @flagsaver.flagsaver
  def test_mutate2(self):
    self._expect_unchanged_then_mutate()

  @flagsaver.flagsaver
  def test_mutate3(self):
    self._expect_unchanged_then_mutate()

  @flagsaver.flagsaver
  def test_mutate4(self):
    self._expect_unchanged_then_mutate()
341
342
class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
  """Saves flag values in setUp and restores them in tearDown.

  This is the supported alternative to decorating a whole TestCase class
  with flagsaver, which FlagSaverBadUsageTest shows raises TypeError.
  """

  def setUp(self):
    # Snapshot the flag values before each test runs.
    self.saved_flag_values = flagsaver.save_flag_values()

  def tearDown(self):
    # Roll the flags back to the snapshot taken in setUp.
    flagsaver.restore_flag_values(self.saved_flag_values)

  def _expect_unchanged_then_mutate(self):
    """Asserts the flag still holds 'unchanged0', then overwrites it."""
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate1(self):
    self._expect_unchanged_then_mutate()

  def test_mutate2(self):
    self._expect_unchanged_then_mutate()

  def test_mutate3(self):
    self._expect_unchanged_then_mutate()

  def test_mutate4(self):
    self._expect_unchanged_then_mutate()
374
375
class FlagSaverBadUsageTest(absltest.TestCase):
  """Tests that certain kinds of improper usages raise errors."""

  def test_flag_saver_on_class(self):
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_call_on_class(self):
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver()
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_with_overrides_on_class(self):
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver(foo='bar')
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_multiple_positional_parameters(self):
    func_a = lambda: None
    func_b = lambda: None
    # Only one function can be decorated per flagsaver call.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, func_b)

  def test_both_positional_and_keyword_parameters(self):
    func_a = lambda: None
    # A function to decorate cannot be combined with keyword overrides.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, flagsaver_test_flag0='new value')

  def test_duplicate_holder_parameters(self):
    # The same flag cannot be overridden twice in one call.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))

  def test_duplicate_holder_and_kw_parameter(self):
    # Overriding a flag both via its holder and via its name is a duplicate.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 45), **{INT_FLAG.name: 45})

  def test_both_positional_and_holder_parameters(self):
    func_a = lambda: None
    # A function to decorate cannot be combined with holder overrides.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, (INT_FLAG, 45))

  def test_holder_parameters_wrong_shape(self):
    # A bare holder, without a value, is rejected; holder overrides must be
    # passed as (holder, value) tuples.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(INT_FLAG)

  def test_holder_parameters_tuple_too_long(self):
    # A holder override must be exactly a (holder, value) pair; a tuple with
    # extra elements is rejected.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 4, 5))

  def test_holder_parameters_tuple_wrong_type(self):
    # The holder must be the first element of the pair, not the value.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((4, INT_FLAG))

  def test_both_wrong_positional_parameters(self):
    func_a = lambda: None
    # The holder and value are passed as separate positional arguments
    # rather than as a single (holder, value) tuple.
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(func_a, STR_FLAG, '45')
460
461
if __name__ == '__main__':
  # Run the test cases in this module through absltest's entry point.
  absltest.main()
464