• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""Tests that the flags validators framework works.

This file tests that each flag validator is called when it should be, that a
failed validator raises an exception, etc.
"""
20
21import warnings
22
23
24from absl.flags import _defines
25from absl.flags import _exceptions
26from absl.flags import _flagvalues
27from absl.flags import _validators
28from absl.testing import absltest
29
30
class SingleFlagValidatorTest(absltest.TestCase):
  """Testing _validators.register_validator() method."""

  def setUp(self):
    super(SingleFlagValidatorTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Every value a checker is invoked with, in call order.
    self.call_args = []

  def test_success(self):
    def checker(x):
      self.call_args.append(x)
      return True
    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(2, self.flag_values.test_flag)
    # One call at parse time (default None), one for the assignment.
    self.assertEqual([None, 2], self.call_args)

  def test_success_holder(self):
    def checker(x):
      self.call_args.append(x)
      return True

    flag_holder = _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        flag_holder,
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(2, self.flag_values.test_flag)
    self.assertEqual([None, 2], self.call_args)

  def test_success_holder_infer_flagvalues(self):
    def checker(x):
      self.call_args.append(x)
      return True

    flag_holder = _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    # flag_values is deliberately omitted; the test expects it to be inferred
    # from the holder.
    _validators.register_validator(
        flag_holder,
        checker,
        message='Errors happen')

    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(2, self.flag_values.test_flag)
    self.assertEqual([None, 2], self.call_args)

  def test_default_value_not_used_success(self):
    def checker(x):
      self.call_args.append(x)
      return True
    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program', '--test_flag=1')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.test_flag)
    # Only the command-line value is validated, never the None default.
    self.assertEqual([1], self.call_args)

  def test_validator_not_called_when_other_flag_is_changed(self):
    def checker(x):
      self.call_args.append(x)
      return True
    _defines.DEFINE_integer(
        'test_flag', 1, 'Usual integer flag', flag_values=self.flag_values)
    _defines.DEFINE_integer(
        'other_flag', 2, 'Other integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program',)
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.test_flag)
    self.flag_values.other_flag = 3
    # Make sure the assignment really happened before checking the checker
    # was not re-run for it.
    self.assertEqual(3, self.flag_values.other_flag)
    self.assertEqual([1], self.call_args)

  def test_exception_raised_if_checker_fails(self):
    def checker(x):
      self.call_args.append(x)
      return x == 1
    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program', '--test_flag=1')
    self.flag_values(argv)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.test_flag = 2
    self.assertEqual('flag --test_flag=2: Errors happen', str(cm.exception))
    self.assertEqual([1, 2], self.call_args)

  def test_exception_raised_if_checker_raises_exception(self):
    def checker(x):
      self.call_args.append(x)
      if x == 1:
        return True
      raise _exceptions.ValidationError('Specific message')

    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program', '--test_flag=1')
    self.flag_values(argv)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.test_flag = 2
    # The ValidationError's text replaces the registered message.
    self.assertEqual('flag --test_flag=2: Specific message', str(cm.exception))
    self.assertEqual([1, 2], self.call_args)

  def test_error_message_when_checker_returns_false_on_start(self):
    def checker(x):
      self.call_args.append(x)
      return False
    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program', '--test_flag=1')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual('flag --test_flag=1: Errors happen', str(cm.exception))
    self.assertEqual([1], self.call_args)

  def test_error_message_when_checker_raises_exception_on_start(self):
    def checker(x):
      self.call_args.append(x)
      raise _exceptions.ValidationError('Specific message')

    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program', '--test_flag=1')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual('flag --test_flag=1: Specific message', str(cm.exception))
    self.assertEqual([1], self.call_args)

  def test_validators_checked_in_order(self):
    # A local log instead of an attribute created outside setUp(); it is
    # cleared between the two registration orders below.
    calls = []

    def required(x):
      calls.append('required')
      return x is not None

    def even(x):
      calls.append('even')
      return x % 2 == 0

    self._define_flag_and_validators(required, even)
    self.assertEqual(['required', 'even'], calls)

    del calls[:]
    self._define_flag_and_validators(even, required)
    self.assertEqual(['even', 'required'], calls)

  def _define_flag_and_validators(self, first_validator, second_validator):
    """Registers two validators on a fresh `test_flag` and parses.

    Args:
      first_validator: checker registered first on `test_flag`.
      second_validator: checker registered second on `test_flag`.
    """
    local_flags = _flagvalues.FlagValues()
    _defines.DEFINE_integer(
        'test_flag', 2, 'test flag', flag_values=local_flags)
    _validators.register_validator(
        'test_flag', first_validator, message='', flag_values=local_flags)
    _validators.register_validator(
        'test_flag', second_validator, message='', flag_values=local_flags)
    argv = ('./program',)
    local_flags(argv)

  def test_validator_as_decorator(self):
    _defines.DEFINE_integer(
        'test_flag', None, 'Simple integer flag', flag_values=self.flag_values)

    @_validators.validator('test_flag', flag_values=self.flag_values)
    def checker(x):
      self.call_args.append(x)
      return True

    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(2, self.flag_values.test_flag)
    self.assertEqual([None, 2], self.call_args)
    # Check that `checker` is still the original function, i.e. the
    # decorator did not replace it.
    self.assertTrue(checker(3))
    self.assertEqual([None, 2, 3], self.call_args)

  def test_mismatching_flagvalues(self):

    def checker(x):
      self.call_args.append(x)
      return True

    # The holder belongs to a different FlagValues than self.flag_values.
    flag_holder = _defines.DEFINE_integer(
        'test_flag',
        None,
        'Usual integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.register_validator(
          flag_holder,
          checker,
          message='Errors happen',
          flag_values=self.flag_values)
279
280
class MultiFlagsValidatorTest(absltest.TestCase):
  """Test flags multi-flag validators."""

  def setUp(self):
    super(MultiFlagsValidatorTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Every flags dict a checker is invoked with, in call order.
    self.call_args = []
    self.foo_holder = _defines.DEFINE_integer(
        'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
    self.bar_holder = _defines.DEFINE_integer(
        'bar', 2, 'Usual integer flag', flag_values=self.flag_values)

  def test_success(self):
    def checker(flags_dict):
      self.call_args.append(flags_dict)
      return True
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], checker, flag_values=self.flag_values)

    argv = ('./program', '--bar=2')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.foo)
    self.assertEqual(2, self.flag_values.bar)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
    self.flag_values.foo = 3
    self.assertEqual(3, self.flag_values.foo)
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
                     self.call_args)

  def test_success_holder(self):

    def checker(flags_dict):
      self.call_args.append(flags_dict)
      return True

    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder],
        checker,
        flag_values=self.flag_values)

    argv = ('./program', '--bar=2')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.foo)
    self.assertEqual(2, self.flag_values.bar)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
    self.flag_values.foo = 3
    self.assertEqual(3, self.flag_values.foo)
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
                     self.call_args)

  def test_success_holder_infer_flagvalues(self):
    def checker(flags_dict):
      self.call_args.append(flags_dict)
      return True

    # flag_values is deliberately omitted; the test expects it to be inferred
    # from the holders.
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder], checker)

    argv = ('./program', '--bar=2')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.foo)
    self.assertEqual(2, self.flag_values.bar)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
    self.flag_values.foo = 3
    self.assertEqual(3, self.flag_values.foo)
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
                     self.call_args)

  def test_validator_not_called_when_other_flag_is_changed(self):
    def checker(flags_dict):
      self.call_args.append(flags_dict)
      return True
    _defines.DEFINE_integer(
        'other_flag', 3, 'Other integer flag', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], checker, flag_values=self.flag_values)

    argv = ('./program',)
    self.flag_values(argv)
    # Assign a value different from the default (3) so that the flag really
    # changes; assigning 3 would be a no-op and prove nothing.
    self.flag_values.other_flag = 4
    self.assertEqual(4, self.flag_values.other_flag)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)

  def test_exception_raised_if_checker_fails(self):
    def checker(flags_dict):
      self.call_args.append(flags_dict)
      values = flags_dict.values()
      # Make sure all the flags have different values.
      return len(set(values)) == len(values)
    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program',)
    self.flag_values(argv)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.bar = 1
    self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
                     self.call_args)

  def test_exception_raised_if_checker_raises_exception(self):
    def checker(flags_dict):
      self.call_args.append(flags_dict)
      values = flags_dict.values()
      # Make sure all the flags have different values.
      if len(set(values)) != len(values):
        raise _exceptions.ValidationError('Specific message')
      return True

    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

    argv = ('./program',)
    self.flag_values(argv)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.bar = 1
    # The ValidationError's text replaces the registered message.
    self.assertEqual('flags foo=1, bar=1: Specific message', str(cm.exception))
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
                     self.call_args)

  def test_decorator(self):
    @_validators.multi_flags_validator(
        ['foo', 'bar'], message='Errors happen', flag_values=self.flag_values)
    def checker(flags_dict):  # pylint: disable=unused-variable
      self.call_args.append(flags_dict)
      values = flags_dict.values()
      # Make sure all the flags have different values.
      return len(set(values)) == len(values)

    argv = ('./program',)
    self.flag_values(argv)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.bar = 1
    self.assertEqual('flags foo=1, bar=1: Errors happen', str(cm.exception))
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
                     self.call_args)

  def test_mismatching_flagvalues(self):

    def checker(flags_dict):
      self.call_args.append(flags_dict)
      values = flags_dict.values()
      # Make sure all the flags have different values.
      return len(set(values)) == len(values)

    # Registered to a different FlagValues than foo/bar.
    other_holder = _defines.DEFINE_integer(
        'other_flag',
        3,
        'Other integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.register_multi_flags_validator(
          [self.foo_holder, self.bar_holder, other_holder],
          checker,
          message='Errors happen',
          flag_values=self.flag_values)
456
457
class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests for _validators.mark_flags_as_mutual_exclusive()."""

  def setUp(self):
    super(MarkFlagsAsMutualExclusiveTest, self).setUp()
    fv = _flagvalues.FlagValues()
    self.flag_values = fv

    # Holders are kept for the two flags used by the FlagHolder-based tests.
    self.flag_one_holder = _defines.DEFINE_string(
        'flag_one', None, 'flag one', flag_values=fv)
    self.flag_two_holder = _defines.DEFINE_string(
        'flag_two', None, 'flag two', flag_values=fv)
    _defines.DEFINE_string('flag_three', None, 'flag three', flag_values=fv)
    _defines.DEFINE_integer(
        'int_flag_one', None, 'int flag one', flag_values=fv)
    _defines.DEFINE_integer(
        'int_flag_two', None, 'int flag two', flag_values=fv)
    _defines.DEFINE_multi_string(
        'multi_flag_one', None, 'multi flag one', flag_values=fv)
    _defines.DEFINE_multi_string(
        'multi_flag_two', None, 'multi flag two', flag_values=fv)
    # The only flag here whose default is not None.
    _defines.DEFINE_boolean(
        'flag_not_none', False, 'false default', flag_values=fv)

  def _mark_flags_as_mutually_exclusive(self, names, is_required):
    """Registers `names` as mutually exclusive on self.flag_values."""
    _validators.mark_flags_as_mutual_exclusive(
        names, required=is_required, flag_values=self.flag_values)

  def _assert_parse_rejected(self, argv, expected_message):
    """Parses `argv`, expecting IllegalFlagValueError with an exact message."""
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected_message):
      self.flag_values(argv)

  def _assert_unset_after_empty_parse(self, *names):
    """Parses a bare command line and checks that each named flag is None."""
    self.flag_values(('./program',))
    for name in names:
      self.assertIsNone(getattr(self.flag_values, name))

  def test_no_flags_present(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False)
    self._assert_unset_after_empty_parse('flag_one', 'flag_two')

  def test_no_flags_present_holder(self):
    self._mark_flags_as_mutually_exclusive(
        [self.flag_one_holder, self.flag_two_holder], False)
    self._assert_unset_after_empty_parse('flag_one', 'flag_two')

  def test_no_flags_present_mixed(self):
    self._mark_flags_as_mutually_exclusive(
        [self.flag_one_holder, 'flag_two'], False)
    self._assert_unset_after_empty_parse('flag_one', 'flag_two')

  def test_no_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self._assert_parse_rejected(
        ('./program',),
        'flags flag_one=None, flag_two=None: '
        'Exactly one of (flag_one, flag_two) must have a value other than '
        'None.')

  def test_one_flag_present(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False)
    self.flag_values(('./program', '--flag_one=1'))
    self.assertEqual('1', self.flag_values.flag_one)

  def test_one_flag_present_required(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self.flag_values(('./program', '--flag_two=2'))
    self.assertEqual('2', self.flag_values.flag_two)

  def test_one_flag_zero_required(self):
    # A value of 0 satisfies the requirement; only None counts as absent.
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], True)
    self.flag_values(('./program', '--int_flag_one=0'))
    self.assertEqual(0, self.flag_values.int_flag_one)

  def test_mutual_exclusion_with_extra_flags(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self.flag_values(('./program', '--flag_two=2', '--flag_three=3'))
    self.assertEqual('2', self.flag_values.flag_two)
    self.assertEqual('3', self.flag_values.flag_three)

  def test_mutual_exclusion_with_zero(self):
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], False)
    self._assert_parse_rejected(
        ('./program', '--int_flag_one=0', '--int_flag_two=0'),
        'flags int_flag_one=0, int_flag_two=0: '
        'At most one of (int_flag_one, int_flag_two) must have a value other '
        'than None.')

  def test_multiple_flags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], False)
    self._assert_parse_rejected(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'At most one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')

  def test_multiple_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], True)
    self._assert_parse_rejected(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'Exactly one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')

  def test_no_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], False)
    self._assert_unset_after_empty_parse('multi_flag_one', 'multi_flag_two')

  def test_no_multistring_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self._assert_parse_rejected(
        ('./program',),
        'flags multi_flag_one=None, multi_flag_two=None: '
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_one_multiflag_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self.flag_values(('./program', '--multi_flag_one=1'))
    self.assertEqual(['1'], self.flag_values.multi_flag_one)

  def test_one_multiflag_present_repeated(self):
    # Repeating the same multi flag is still only one flag being set.
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self.flag_values(('./program', '--multi_flag_one=1', '--multi_flag_one=1b'))
    self.assertEqual(['1', '1b'], self.flag_values.multi_flag_one)

  def test_multiple_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], False)
    self._assert_parse_rejected(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'),
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'At most one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_multiple_multiflags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self._assert_parse_rejected(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'),
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_flag_default_not_none_warning(self):
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter('always')
      self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_not_none'],
                                             False)
    self.assertLen(caught, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(caught[0].message))

  def test_multiple_flagvalues(self):
    # This holder lives on its own FlagValues, unlike the flags from setUp().
    other_holder = _defines.DEFINE_boolean(
        'other_flagvalues', False, 'other ',
        flag_values=_flagvalues.FlagValues())
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.'):
      self._mark_flags_as_mutually_exclusive(
          [self.flag_one_holder, other_holder], False)
660
661
class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests for _validators.mark_bool_flags_as_mutual_exclusive()."""

  def setUp(self):
    super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp()
    fv = _flagvalues.FlagValues()
    self.flag_values = fv

    self.false_1_holder = _defines.DEFINE_boolean(
        'false_1', False, 'default false 1', flag_values=fv)
    self.false_2_holder = _defines.DEFINE_boolean(
        'false_2', False, 'default false 2', flag_values=fv)
    self.true_1_holder = _defines.DEFINE_boolean(
        'true_1', True, 'default true 1', flag_values=fv)
    # Not boolean: marking it must be rejected (see test_non_bool_flag).
    self.non_bool_holder = _defines.DEFINE_integer(
        'non_bool', None, 'non bool', flag_values=fv)

  def _mark_bool_flags_as_mutually_exclusive(self, names, is_required):
    """Registers boolean `names` as mutually exclusive on self.flag_values."""
    _validators.mark_bool_flags_as_mutual_exclusive(
        names, required=is_required, flag_values=self.flag_values)

  def _assert_parse_rejected(self, argv, expected_message):
    """Parses `argv`, expecting IllegalFlagValueError with an exact message."""
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected_message):
      self.flag_values(argv)

  def _assert_values_after_empty_parse(self, **expected):
    """Parses a bare command line and compares each named flag's value."""
    self.flag_values(('./program',))
    for name, value in expected.items():
      self.assertEqual(value, getattr(self.flag_values, name))

  def test_no_flags_present(self):
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
    self._assert_values_after_empty_parse(false_1=False, false_2=False)

  def test_no_flags_present_holder(self):
    self._mark_bool_flags_as_mutually_exclusive(
        [self.false_1_holder, self.false_2_holder], False)
    self._assert_values_after_empty_parse(false_1=False, false_2=False)

  def test_no_flags_present_mixed(self):
    self._mark_bool_flags_as_mutually_exclusive(
        [self.false_1_holder, 'false_2'], False)
    self._assert_values_after_empty_parse(false_1=False, false_2=False)

  def test_no_flags_present_required(self):
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
    self._assert_parse_rejected(
        ('./program',),
        'flags false_1=False, false_2=False: '
        'Exactly one of (false_1, false_2) must be True.')

  def test_no_flags_present_with_default_true_required(self):
    # true_1 defaults to True, which alone satisfies "exactly one".
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'true_1'], True)
    self._assert_values_after_empty_parse(false_1=False, true_1=True)

  def test_two_flags_true(self):
    self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
    self._assert_parse_rejected(
        ('./program', '--false_1', '--false_2'),
        'flags false_1=True, false_2=True: At most one of (false_1, '
        'false_2) must be True.')

  def test_non_bool_flag(self):
    # Rejected at registration time, before any parsing.
    with self.assertRaisesWithLiteralMatch(
        _exceptions.ValidationError,
        'Flag --non_bool is not Boolean, which is required for flags '
        'used in mark_bool_flags_as_mutual_exclusive.'):
      self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
                                                  False)

  def test_multiple_flagvalues(self):
    # This holder lives on its own FlagValues, unlike the flags from setUp().
    other_bool_holder = _defines.DEFINE_boolean(
        'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.'):
      self._mark_bool_flags_as_mutually_exclusive(
          [self.false_1_holder, other_bool_holder], False)
745
746
class MarkFlagAsRequiredTest(absltest.TestCase):
  """Tests for _validators.mark_flag_as_required()."""

  def setUp(self):
    super(MarkFlagAsRequiredTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

  def test_success(self):
    _defines.DEFINE_string(
        'string_flag', None, 'string flag', flag_values=self.flag_values)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    argv = ('./program', '--string_flag=value')
    self.flag_values(argv)
    self.assertEqual('value', self.flag_values.string_flag)

  def test_success_holder(self):
    holder = _defines.DEFINE_string(
        'string_flag', None, 'string flag', flag_values=self.flag_values)
    _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
    argv = ('./program', '--string_flag=value')
    self.flag_values(argv)
    self.assertEqual('value', self.flag_values.string_flag)

  def test_success_holder_infer_flagvalues(self):
    holder = _defines.DEFINE_string(
        'string_flag', None, 'string flag', flag_values=self.flag_values)
    # flag_values is deliberately omitted; the test expects it to be inferred
    # from the holder.
    _validators.mark_flag_as_required(holder)
    argv = ('./program', '--string_flag=value')
    self.flag_values(argv)
    self.assertEqual('value', self.flag_values.string_flag)

  def test_catch_none_as_default(self):
    _defines.DEFINE_string(
        'string_flag', None, 'string flag', flag_values=self.flag_values)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    argv = ('./program',)
    expected = ('flag --string_flag=None: Flag --string_flag must have a value '
                'other than None.')
    # Exact match rather than a regex search, so unexpected extra text in the
    # message is also caught (same message as when set to None later on).
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def test_catch_setting_none_after_program_start(self):
    _defines.DEFINE_string(
        'string_flag', 'value', 'string flag', flag_values=self.flag_values)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    argv = ('./program',)
    self.flag_values(argv)
    self.assertEqual('value', self.flag_values.string_flag)
    expected = ('flag --string_flag=None: Flag --string_flag must have a value '
                'other than None.')
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values.string_flag = None

  def test_flag_default_not_none_warning(self):
    # '' is not None, so marking this flag as required should warn.
    _defines.DEFINE_string(
        'flag_not_none', '', 'empty default', flag_values=self.flag_values)
    with warnings.catch_warnings(record=True) as caught_warnings:
      warnings.simplefilter('always')
      _validators.mark_flag_as_required(
          'flag_not_none', flag_values=self.flag_values)

    self.assertLen(caught_warnings, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(caught_warnings[0].message))

  def test_mismatching_flagvalues(self):
    # The holder belongs to a different FlagValues than self.flag_values.
    flag_holder = _defines.DEFINE_string(
        'string_flag',
        'value',
        'string flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.mark_flag_as_required(
          flag_holder, flag_values=self.flag_values)
827
828
class MarkFlagsAsRequiredTest(absltest.TestCase):
  """Tests for _validators.mark_flags_as_required() and validator ordering."""

  def setUp(self):
    super(MarkFlagsAsRequiredTest, self).setUp()
    # Isolated registry so each test can reuse the same flag names.
    self.flag_values = _flagvalues.FlagValues()

  def test_success(self):
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    flag_names = ['string_flag_1', 'string_flag_2']
    _validators.mark_flags_as_required(flag_names, flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_success_holders(self):
    flag_1_holder = _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    flag_2_holder = _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
                                       flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_catch_none_as_default(self):
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    # Only the first flag is supplied; the second keeps its None default.
    argv = ('./program', '--string_flag_1=value_1')
    expected = (
        r'flag --string_flag_2=None: Flag --string_flag_2 must have a value '
        r'other than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def test_catch_setting_none_after_program_start(self):
    _defines.DEFINE_string(
        'string_flag_1',
        'value_1',
        'string flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2',
        'value_2',
        'string flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    expected = (
        'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
        'other than None.')
    # Assigning None after parsing must be rejected as well.
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.string_flag_1 = None
    self.assertEqual(expected, str(cm.exception))

  def test_catch_multiple_flags_as_none_at_program_start(self):
    _defines.DEFINE_float(
        'float_flag_1',
        None,
        'float flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_float(
        'float_flag_2',
        None,
        'float flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '')
    # Both missing required flags are reported in a single error, one per line.
    expected = (
        'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
        'other than None.\n'
        'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
        'other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_single_flag_and_skip_remaining_validators(self):
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Should not be raised.')
    _defines.DEFINE_float(
        'flag_1', None, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
    _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    argv = ('./program', '')
    # Only the required-flag failure may appear; the other validators on
    # flag_1 must not contribute their 'Should not be raised.' message.
    expected = (
        'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
    def raise_expected_error(x):
      del x
      raise _exceptions.ValidationError('Expected error.')
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Got unexpected error.')
    _defines.DEFINE_float(
        'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
                                               raise_expected_error,
                                               flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_validator(
        'flag_2', raise_unexpected_error, flag_values=self.flag_values)
    argv = ('./program', '')
    # Only the first registered multi-flag validator's error is expected.
    expected = 'flags flag_1=5.1, flag_2=10.0: Expected error.'
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))
966
967
if __name__ == '__main__':
  # Run the test cases in this module through the absl test runner.
  absltest.main()
970