• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Testing that flags validators framework does work.
16
This file tests that each flag validator is called when it should be, and that
a failed validator raises an exception, etc.
19"""
20
21import warnings
22
23from absl.flags import _defines
24from absl.flags import _exceptions
25from absl.flags import _flagvalues
26from absl.flags import _validators
27from absl.testing import absltest
28
29
class SingleFlagValidatorTest(absltest.TestCase):
  """Tests for single-flag validators.

  Covers _validators.register_validator() (given a flag name or a FlagHolder)
  and the _validators.validator() decorator.
  """

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Every value a recording checker has been called with, in call order.
    self.call_args = []

  def _record_and_accept(self, value):
    """Checker that logs `value` into self.call_args and always passes."""
    self.call_args.append(value)
    return True

  def _define_test_flag(self):
    """Defines integer --test_flag (default None); returns its holder."""
    return _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)

  def _check_parse_then_set_to_two(self):
    """Parses an empty command line, then assigns 2 to --test_flag.

    Asserts the resulting flag values and that the validator saw the default
    (None) followed by the assigned value (2).
    """
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(self.flag_values.test_flag, 2)
    self.assertEqual(self.call_args, [None, 2])

  def test_success(self):
    self._define_test_flag()
    _validators.register_validator(
        'test_flag',
        self._record_and_accept,
        message='Errors happen',
        flag_values=self.flag_values)
    self._check_parse_then_set_to_two()

  def test_success_holder(self):
    holder = self._define_test_flag()
    _validators.register_validator(
        holder,
        self._record_and_accept,
        message='Errors happen',
        flag_values=self.flag_values)
    self._check_parse_then_set_to_two()

  def test_success_holder_infer_flagvalues(self):
    holder = self._define_test_flag()
    # No flag_values argument: it has to be inferred from the holder.
    _validators.register_validator(
        holder, self._record_and_accept, message='Errors happen')
    self._check_parse_then_set_to_two()

  def test_default_value_not_used_success(self):
    self._define_test_flag()
    _validators.register_validator(
        'test_flag',
        self._record_and_accept,
        message='Errors happen',
        flag_values=self.flag_values)
    self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual(self.flag_values.test_flag, 1)
    # Only the command-line value reaches the checker, not the None default.
    self.assertEqual(self.call_args, [1])

  def test_validator_not_called_when_other_flag_is_changed(self):
    _defines.DEFINE_integer(
        'test_flag', 1, 'Usual integer flag', flag_values=self.flag_values)
    _defines.DEFINE_integer(
        'other_flag', 2, 'Other integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        self._record_and_accept,
        message='Errors happen',
        flag_values=self.flag_values)
    self.flag_values(('./program',))
    self.assertEqual(self.flag_values.test_flag, 1)
    self.flag_values.other_flag = 3
    self.assertEqual(self.call_args, [1])

  def test_exception_raised_if_checker_fails(self):

    def accepts_only_one(value):
      self.call_args.append(value)
      return value == 1

    self._define_test_flag()
    _validators.register_validator(
        'test_flag',
        accepts_only_one,
        message='Errors happen',
        flag_values=self.flag_values)
    self.flag_values(('./program', '--test_flag=1'))
    expected = 'flag --test_flag=2: Errors happen'
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values.test_flag = 2
    self.assertEqual(self.call_args, [1, 2])

  def test_exception_raised_if_checker_raises_exception(self):

    def accepts_only_one(value):
      self.call_args.append(value)
      if value != 1:
        raise _exceptions.ValidationError('Specific message')
      return True

    self._define_test_flag()
    _validators.register_validator(
        'test_flag',
        accepts_only_one,
        message='Errors happen',
        flag_values=self.flag_values)
    self.flag_values(('./program', '--test_flag=1'))
    # The checker's own ValidationError text replaces `message`.
    expected = 'flag --test_flag=2: Specific message'
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values.test_flag = 2
    self.assertEqual(self.call_args, [1, 2])

  def test_error_message_when_checker_returns_false_on_start(self):

    def always_rejects(value):
      self.call_args.append(value)
      return False

    self._define_test_flag()
    _validators.register_validator(
        'test_flag',
        always_rejects,
        message='Errors happen',
        flag_values=self.flag_values)
    expected = 'flag --test_flag=1: Errors happen'
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual(self.call_args, [1])

  def test_error_message_when_checker_raises_exception_on_start(self):

    def always_raises(value):
      self.call_args.append(value)
      raise _exceptions.ValidationError('Specific message')

    self._define_test_flag()
    _validators.register_validator(
        'test_flag',
        always_raises,
        message='Errors happen',
        flag_values=self.flag_values)
    expected = 'flag --test_flag=1: Specific message'
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual(self.call_args, [1])

  def test_validators_checked_in_order(self):

    def required(value):
      self.calls.append('required')
      return value is not None

    def even(value):
      self.calls.append('even')
      return value % 2 == 0

    # Validators must run in the order in which they were registered.
    orderings = (
        (required, even, ['required', 'even']),
        (even, required, ['even', 'required']),
    )
    for first, second, expected_calls in orderings:
      self.calls = []
      self._define_flag_and_validators(first, second)
      self.assertEqual(self.calls, expected_calls)

  def _define_flag_and_validators(self, first_validator, second_validator):
    """Registers two validators on a fresh --test_flag (default 2), parses."""
    local_flags = _flagvalues.FlagValues()
    _defines.DEFINE_integer(
        'test_flag', 2, 'test flag', flag_values=local_flags)
    for validator_fn in (first_validator, second_validator):
      _validators.register_validator(
          'test_flag', validator_fn, message='', flag_values=local_flags)
    local_flags(('./program',))

  def test_validator_as_decorator(self):
    _defines.DEFINE_integer(
        'test_flag', None, 'Simple integer flag', flag_values=self.flag_values)

    @_validators.validator('test_flag', flag_values=self.flag_values)
    def checker(value):
      self.call_args.append(value)
      return True

    self._check_parse_then_set_to_two()
    # The decorated name must still be the original function: calling it
    # runs its body (recording 3) and returns True.
    self.assertTrue(checker(3))
    self.assertEqual(self.call_args, [None, 2, 3])

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_integer(
        'test_flag',
        None,
        'Usual integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.register_validator(
          foreign_holder,
          self._record_and_accept,
          message='Errors happen',
          flag_values=self.flag_values)
278
279
class MultiFlagsValidatorTest(absltest.TestCase):
  """Tests for validators that inspect several flags together."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Flag-name -> value dicts received by the recording checkers, in order.
    self.call_args = []
    self.foo_holder = _defines.DEFINE_integer(
        'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
    self.bar_holder = _defines.DEFINE_integer(
        'bar', 2, 'Usual integer flag', flag_values=self.flag_values)

  def _record_and_accept(self, flags_dict):
    """Checker that logs `flags_dict` and always passes."""
    self.call_args.append(flags_dict)
    return True

  def _record_and_require_distinct(self, flags_dict):
    """Checker that logs `flags_dict`; passes iff no two values are equal."""
    self.call_args.append(flags_dict)
    seen = list(flags_dict.values())
    return len(seen) == len(set(seen))

  def _check_parse_then_set_foo(self):
    """Parses --bar=2, then assigns foo=3; checks both validator calls."""
    initial = {'foo': 1, 'bar': 2}
    self.flag_values(('./program', '--bar=2'))
    self.assertEqual(self.flag_values.foo, 1)
    self.assertEqual(self.flag_values.bar, 2)
    self.assertEqual(self.call_args, [initial])
    self.flag_values.foo = 3
    self.assertEqual(self.flag_values.foo, 3)
    self.assertEqual(self.call_args, [initial, {'foo': 3, 'bar': 2}])

  def _check_setting_bar_to_one_fails(self, expected_message):
    """Parses no args, then checks that bar=1 (same as foo) is rejected."""
    self.flag_values(('./program',))
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected_message):
      self.flag_values.bar = 1
    self.assertEqual(
        self.call_args, [{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}])

  def test_success(self):
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], self._record_and_accept, flag_values=self.flag_values)
    self._check_parse_then_set_foo()

  def test_success_holder(self):
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder],
        self._record_and_accept,
        flag_values=self.flag_values)
    self._check_parse_then_set_foo()

  def test_success_holder_infer_flagvalues(self):
    # No flag_values argument: it has to be inferred from the holders.
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder], self._record_and_accept)
    self._check_parse_then_set_foo()

  def test_validator_not_called_when_other_flag_is_changed(self):
    _defines.DEFINE_integer(
        'other_flag', 3, 'Other integer flag', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], self._record_and_accept, flag_values=self.flag_values)
    self.flag_values(('./program',))
    self.flag_values.other_flag = 3
    self.assertEqual(self.call_args, [{'foo': 1, 'bar': 2}])

  def test_exception_raised_if_checker_fails(self):
    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        self._record_and_require_distinct,
        message='Errors happen',
        flag_values=self.flag_values)
    self._check_setting_bar_to_one_fails('flags foo=1, bar=1: Errors happen')

  def test_exception_raised_if_checker_raises_exception(self):

    def require_distinct_or_raise(flags_dict):
      if not self._record_and_require_distinct(flags_dict):
        raise _exceptions.ValidationError('Specific message')
      return True

    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        require_distinct_or_raise,
        message='Errors happen',
        flag_values=self.flag_values)
    # The checker's own ValidationError text replaces `message`.
    self._check_setting_bar_to_one_fails(
        'flags foo=1, bar=1: Specific message')

  def test_decorator(self):

    @_validators.multi_flags_validator(
        ['foo', 'bar'], message='Errors happen', flag_values=self.flag_values)
    def checker(flags_dict):  # pylint: disable=unused-variable
      return self._record_and_require_distinct(flags_dict)

    self._check_setting_bar_to_one_fails('flags foo=1, bar=1: Errors happen')

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_integer(
        'other_flag',
        3,
        'Other integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.register_multi_flags_validator(
          [self.foo_holder, self.bar_holder, foreign_holder],
          self._record_and_require_distinct,
          message='Errors happen',
          flag_values=self.flag_values)
455
456
class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests for _validators.mark_flags_as_mutual_exclusive()."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()

    self.flag_one_holder = _defines.DEFINE_string(
        'flag_one', None, 'flag one', flag_values=self.flag_values)
    self.flag_two_holder = _defines.DEFINE_string(
        'flag_two', None, 'flag two', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'flag_three', None, 'flag three', flag_values=self.flag_values)
    for name, help_text in (('int_flag_one', 'int flag one'),
                            ('int_flag_two', 'int flag two')):
      _defines.DEFINE_integer(
          name, None, help_text, flag_values=self.flag_values)
    for name, help_text in (('multi_flag_one', 'multi flag one'),
                            ('multi_flag_two', 'multi flag two')):
      _defines.DEFINE_multi_string(
          name, None, help_text, flag_values=self.flag_values)
    # The only flag defined here whose default is not None.
    _defines.DEFINE_boolean(
        'flag_not_none', False, 'false default', flag_values=self.flag_values)

  def _mark_flags_as_mutually_exclusive(self, flag_names, required):
    """Marks `flag_names` as mutually exclusive on self.flag_values."""
    _validators.mark_flags_as_mutual_exclusive(
        flag_names, required=required, flag_values=self.flag_values)

  def _assert_parse_fails(self, argv, expected):
    """Asserts parsing `argv` raises IllegalFlagValueError(`expected`)."""
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def _check_empty_parse_leaves_one_two_none(self):
    """Parses no args and checks --flag_one and --flag_two stay None."""
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.flag_one)
    self.assertIsNone(self.flag_values.flag_two)

  def test_no_flags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two'], required=False)
    self._check_empty_parse_leaves_one_two_none()

  def test_no_flags_present_holder(self):
    self._mark_flags_as_mutually_exclusive(
        [self.flag_one_holder, self.flag_two_holder], required=False)
    self._check_empty_parse_leaves_one_two_none()

  def test_no_flags_present_mixed(self):
    self._mark_flags_as_mutually_exclusive(
        [self.flag_one_holder, 'flag_two'], required=False)
    self._check_empty_parse_leaves_one_two_none()

  def test_no_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two'], required=True)
    expected = (
        'flags flag_one=None, flag_two=None: '
        'Exactly one of (flag_one, flag_two) must have a value other than '
        'None.')
    self._assert_parse_fails(('./program',), expected)

  def test_one_flag_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two'], required=False)
    self.flag_values(('./program', '--flag_one=1'))
    self.assertEqual(self.flag_values.flag_one, '1')

  def test_one_flag_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two'], required=True)
    self.flag_values(('./program', '--flag_two=2'))
    self.assertEqual(self.flag_values.flag_two, '2')

  def test_one_flag_zero_required(self):
    # 0 is a value other than None, so it satisfies the requirement.
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], required=True)
    self.flag_values(('./program', '--int_flag_one=0'))
    self.assertEqual(self.flag_values.int_flag_one, 0)

  def test_mutual_exclusion_with_extra_flags(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two'], required=True)
    self.flag_values(('./program', '--flag_two=2', '--flag_three=3'))
    self.assertEqual(self.flag_values.flag_two, '2')
    self.assertEqual(self.flag_values.flag_three, '3')

  def test_mutual_exclusion_with_zero(self):
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], required=False)
    expected = (
        'flags int_flag_one=0, int_flag_two=0: '
        'At most one of (int_flag_one, int_flag_two) must have a value other '
        'than None.')
    self._assert_parse_fails(
        ('./program', '--int_flag_one=0', '--int_flag_two=0'), expected)

  def test_multiple_flags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], required=False)
    expected = (
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'At most one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')
    self._assert_parse_fails(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        expected)

  def test_multiple_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], required=True)
    expected = (
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'Exactly one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')
    self._assert_parse_fails(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        expected)

  def test_no_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], required=False)
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.multi_flag_one)
    self.assertIsNone(self.flag_values.multi_flag_two)

  def test_no_multistring_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], required=True)
    expected = (
        'flags multi_flag_one=None, multi_flag_two=None: '
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')
    self._assert_parse_fails(('./program',), expected)

  def test_one_multiflag_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], required=True)
    self.flag_values(('./program', '--multi_flag_one=1'))
    self.assertEqual(self.flag_values.multi_flag_one, ['1'])

  def test_one_multiflag_present_repeated(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], required=True)
    argv = ('./program', '--multi_flag_one=1', '--multi_flag_one=1b')
    self.flag_values(argv)
    self.assertEqual(self.flag_values.multi_flag_one, ['1', '1b'])

  def test_multiple_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], required=False)
    expected = (
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'At most one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')
    self._assert_parse_fails(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'), expected)

  def test_multiple_multiflags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], required=True)
    expected = (
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')
    self._assert_parse_fails(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'), expected)

  def test_flag_default_not_none_warning(self):
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter('always')
      self._mark_flags_as_mutually_exclusive(
          ['flag_one', 'flag_not_none'], required=False)
    self.assertLen(caught, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(caught[0].message))

  def test_multiple_flagvalues(self):
    foreign_holder = _defines.DEFINE_boolean(
        'other_flagvalues',
        False,
        'other ',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      self._mark_flags_as_mutually_exclusive(
          [self.flag_one_holder, foreign_holder], required=False)
659
660
class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests for _validators.mark_bool_flags_as_mutual_exclusive()."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()

    self.false_1_holder = _defines.DEFINE_boolean(
        'false_1', False, 'default false 1', flag_values=self.flag_values)
    self.false_2_holder = _defines.DEFINE_boolean(
        'false_2', False, 'default false 2', flag_values=self.flag_values)
    self.true_1_holder = _defines.DEFINE_boolean(
        'true_1', True, 'default true 1', flag_values=self.flag_values)
    # A non-boolean flag, used to check that such flags are rejected.
    self.non_bool_holder = _defines.DEFINE_integer(
        'non_bool', None, 'non bool', flag_values=self.flag_values)

  def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
    """Marks boolean `flag_names` as mutually exclusive."""
    _validators.mark_bool_flags_as_mutual_exclusive(
        flag_names, required=required, flag_values=self.flag_values)

  def _assert_parse_fails(self, argv, expected):
    """Asserts parsing `argv` raises IllegalFlagValueError(`expected`)."""
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def _check_empty_parse_leaves_both_false(self):
    """Parses no args and checks --false_1 and --false_2 stay False."""
    self.flag_values(('./program',))
    self.assertEqual(self.flag_values.false_1, False)
    self.assertEqual(self.flag_values.false_2, False)

  def test_no_flags_present(self):
    self._mark_bool_flags_as_mutually_exclusive(
        ['false_1', 'false_2'], required=False)
    self._check_empty_parse_leaves_both_false()

  def test_no_flags_present_holder(self):
    self._mark_bool_flags_as_mutually_exclusive(
        [self.false_1_holder, self.false_2_holder], required=False)
    self._check_empty_parse_leaves_both_false()

  def test_no_flags_present_mixed(self):
    self._mark_bool_flags_as_mutually_exclusive(
        [self.false_1_holder, 'false_2'], required=False)
    self._check_empty_parse_leaves_both_false()

  def test_no_flags_present_required(self):
    self._mark_bool_flags_as_mutually_exclusive(
        ['false_1', 'false_2'], required=True)
    expected = (
        'flags false_1=False, false_2=False: '
        'Exactly one of (false_1, false_2) must be True.')
    self._assert_parse_fails(('./program',), expected)

  def test_no_flags_present_with_default_true_required(self):
    # --true_1 defaults to True, so the requirement holds with no arguments.
    self._mark_bool_flags_as_mutually_exclusive(
        ['false_1', 'true_1'], required=True)
    self.flag_values(('./program',))
    self.assertEqual(self.flag_values.false_1, False)
    self.assertEqual(self.flag_values.true_1, True)

  def test_two_flags_true(self):
    self._mark_bool_flags_as_mutually_exclusive(
        ['false_1', 'false_2'], required=False)
    expected = (
        'flags false_1=True, false_2=True: At most one of (false_1, '
        'false_2) must be True.')
    self._assert_parse_fails(('./program', '--false_1', '--false_2'), expected)

  def test_non_bool_flag(self):
    expected = ('Flag --non_bool is not Boolean, which is required for flags '
                'used in mark_bool_flags_as_mutual_exclusive.')
    with self.assertRaisesWithLiteralMatch(_exceptions.ValidationError,
                                           expected):
      self._mark_bool_flags_as_mutually_exclusive(
          ['false_1', 'non_bool'], required=False)

  def test_multiple_flagvalues(self):
    foreign_holder = _defines.DEFINE_boolean(
        'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      self._mark_bool_flags_as_mutually_exclusive(
          [self.false_1_holder, foreign_holder], required=False)
744
745
class MarkFlagAsRequiredTest(absltest.TestCase):
  """Tests for _validators.mark_flag_as_required()."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()

  def _define_string_flag(self, default):
    """Defines --string_flag on self.flag_values and returns its holder."""
    return _defines.DEFINE_string(
        'string_flag', default, 'string flag', flag_values=self.flag_values)

  def _check_value_from_command_line(self):
    """Parses --string_flag=value and checks the stored value."""
    self.flag_values(('./program', '--string_flag=value'))
    self.assertEqual(self.flag_values.string_flag, 'value')

  def test_success(self):
    self._define_string_flag(None)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    self._check_value_from_command_line()

  def test_success_holder(self):
    holder = self._define_string_flag(None)
    _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
    self._check_value_from_command_line()

  def test_success_holder_infer_flagvalues(self):
    holder = self._define_string_flag(None)
    # No flag_values argument: it has to be inferred from the holder.
    _validators.mark_flag_as_required(holder)
    self._check_value_from_command_line()

  def test_catch_none_as_default(self):
    self._define_string_flag(None)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    pattern = (
        r'flag --string_flag=None: Flag --string_flag must have a value other '
        r'than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, pattern):
      self.flag_values(('./program',))

  def test_catch_setting_none_after_program_start(self):
    self._define_string_flag('value')
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    self.flag_values(('./program',))
    self.assertEqual(self.flag_values.string_flag, 'value')
    expected = ('flag --string_flag=None: Flag --string_flag must have a value '
                'other than None.')
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected):
      self.flag_values.string_flag = None

  def test_flag_default_not_none_warning(self):
    _defines.DEFINE_string(
        'flag_not_none', '', 'empty default', flag_values=self.flag_values)
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter('always')
      _validators.mark_flag_as_required(
          'flag_not_none', flag_values=self.flag_values)

    self.assertLen(caught, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(caught[0].message))

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_string(
        'string_flag',
        'value',
        'string flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.mark_flag_as_required(
          foreign_holder, flag_values=self.flag_values)
826
827
class MarkFlagsAsRequiredTest(absltest.TestCase):
  """Testing _validators.mark_flags_as_required() and fail-fast validation.

  Covers marking several flags as required at once (by name or by holder),
  how missing required flags are reported, and that once one validator fails
  the validators registered after it are not run.
  """

  def setUp(self):
    super(MarkFlagsAsRequiredTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

  def test_success(self):
    """Required flags (given by name) supplied on the command line parse."""
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    flag_names = ['string_flag_1', 'string_flag_2']
    _validators.mark_flags_as_required(flag_names, flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_success_holders(self):
    """Required flags may be identified by their holders."""
    flag_1_holder = _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    flag_2_holder = _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
                                       flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_catch_none_as_default(self):
    """Omitting one required flag (default None) fails parsing for it."""
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    # Only --string_flag_1 is given, so --string_flag_2 stays None.
    argv = ('./program', '--string_flag_1=value_1')
    expected = (
        r'flag --string_flag_2=None: Flag --string_flag_2 must have a value '
        r'other than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def test_catch_setting_none_after_program_start(self):
    """Assigning None to a required flag after parsing is rejected."""
    _defines.DEFINE_string(
        'string_flag_1',
        'value_1',
        'string flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2',
        'value_2',
        'string flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    expected = (
        'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
        'other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.string_flag_1 = None
    self.assertEqual(expected, str(cm.exception))

  def test_catch_multiple_flags_as_none_at_program_start(self):
    """All missing required flags are reported in one error, one per line."""
    _defines.DEFINE_float(
        'float_flag_1',
        None,
        'float flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_float(
        'float_flag_2',
        None,
        'float flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '')
    expected = (
        'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
        'other than None.\n'
        'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
        'other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_single_flag_and_skip_remaining_validators(self):
    """A failing required-flag check skips the validators registered later."""
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Should not be raised.')
    _defines.DEFINE_float(
        'flag_1', None, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
    # The required check on flag_1 is registered first. Every validator
    # registered after it would fail if it ran, so the exact error message
    # below proves those validators were skipped.
    _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    argv = ('./program', '')
    expected = (
        'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
    """A failing multi-flag validator skips the validators registered later."""
    def raise_expected_error(x):
      del x
      raise _exceptions.ValidationError('Expected error.')
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Got unexpected error.')
    # Both flags have non-None defaults, so only the validators below can fail.
    _defines.DEFINE_float(
        'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
    # Registered first: its error must be the only one reported.
    _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
                                               raise_expected_error,
                                               flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_validator(
        'flag_2', raise_unexpected_error, flag_values=self.flag_values)
    argv = ('./program', '')
    expected = 'flags flag_1=5.1, flag_2=10.0: Expected error.'
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))
965
966
if __name__ == '__main__':
  # Only when this file is executed directly (not imported): hand control to
  # absltest.main().
  absltest.main()
969