• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Contains base classes used to parse and convert arguments.
16
17Do NOT import this module directly. Import the flags package and use the
18aliases defined at the package level instead.
19"""
20
21import collections
22import csv
23import io
24import string
25
26from absl.flags import _helpers
27
28
29def _is_integer_type(instance):
30  """Returns True if instance is an integer, and not a bool."""
31  return (isinstance(instance, int) and
32          not isinstance(instance, bool))
33
34
35class _ArgumentParserCache(type):
36  """Metaclass used to cache and share argument parsers among flags."""
37
38  _instances = {}
39
40  def __call__(cls, *args, **kwargs):
41    """Returns an instance of the argument parser cls.
42
43    This method overrides behavior of the __new__ methods in
44    all subclasses of ArgumentParser (inclusive). If an instance
45    for cls with the same set of arguments exists, this instance is
46    returned, otherwise a new instance is created.
47
48    If any keyword arguments are defined, or the values in args
49    are not hashable, this method always returns a new instance of
50    cls.
51
52    Args:
53      *args: Positional initializer arguments.
54      **kwargs: Initializer keyword arguments.
55
56    Returns:
57      An instance of cls, shared or new.
58    """
59    if kwargs:
60      return type.__call__(cls, *args, **kwargs)
61    else:
62      instances = cls._instances
63      key = (cls,) + tuple(args)
64      try:
65        return instances[key]
66      except KeyError:
67        # No cache entry for key exists, create a new one.
68        return instances.setdefault(key, type.__call__(cls, *args))
69      except TypeError:
70        # An object in args cannot be hashed, always return
71        # a new instance.
72        return type.__call__(cls, *args)
73
74
75# NOTE about Genericity and Metaclass of ArgumentParser.
76# (1) In the .py source (this file)
77#     - is not declared as Generic
78#     - has _ArgumentParserCache as a metaclass
79# (2) In the .pyi source (type stub)
80#     - is declared as Generic
81#     - doesn't have a metaclass
# This split is needed because Generic has a different metaclass (for Python
# versions <= 3.7) and a class can have only one metaclass.
84#
85# * Lack of metaclass in .pyi is not a deal breaker, since the metaclass
86#   doesn't affect any type information. Also type checkers can check the type
87#   parameters.
88# * However, not declaring ArgumentParser as Generic in the source affects
89#   runtime annotation processing. In particular this means, subclasses should
90#   inherit from `ArgumentParser` and not `ArgumentParser[SomeType]`.
91#   The corresponding DEFINE_someType method (the public API) can be annotated
92#   to return FlagHolder[SomeType].
class ArgumentParser(metaclass=_ArgumentParserCache):
  """Base class for parsing and converting flag arguments.

  The :meth:`parse` method checks that a string argument is a legal value and
  converts it to a native type. When the value cannot be converted, it should
  raise ``ValueError`` with a human-readable explanation of why the value is
  illegal.

  Subclasses should also define a ``syntactic_help`` string, which may be
  shown to users to describe the form of legal values.

  Instances are cached and shared between flags (see the metaclass), so
  parser classes must be stateless: initializer arguments are allowed, but
  every member variable must be derived from initializer arguments only.
  """

  syntactic_help = ''

  def parse(self, argument):
    """Parses the string argument and returns the native value.

    The base implementation performs no conversion and returns a string
    argument unchanged.

    Args:
      argument: string argument passed in the commandline.

    Raises:
      ValueError: Raised by subclasses when they fail to parse the argument.
      TypeError: Raised when the argument is not a string.

    Returns:
      The parsed value in native type.
    """
    if isinstance(argument, str):
      return argument
    raise TypeError('flag value must be a string, found "{}"'.format(
        type(argument)))

  def flag_type(self):
    """Returns a string representing the type of the flag."""
    return 'string'

  def _custom_xml_dom_elements(self, doc):
    """Returns a list of minidom.Element with extra flag information.

    The base parser contributes no extra elements.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    del doc  # Not needed by the base implementation.
    return []
143
144
class ArgumentSerializer:
  """Base class that renders a flag value back into a string."""

  def serialize(self, value):
    """Returns ``value`` converted with :func:`str`."""
    serialized = str(value)
    return serialized
151
152
class NumericParser(ArgumentParser):
  """Base parser for numeric values.

  The parsed value may be constrained by an optional lower bound and an
  optional upper bound. Subclasses set ``lower_bound`` / ``upper_bound``
  (either may be None) and implement :meth:`convert`.
  """

  def is_outside_bounds(self, val):
    """Returns True if ``val`` falls below the lower or above the upper bound."""
    if self.lower_bound is not None and val < self.lower_bound:
      return True
    return self.upper_bound is not None and val > self.upper_bound

  def parse(self, argument):
    """See base class."""
    value = self.convert(argument)
    if not self.is_outside_bounds(value):
      return value
    raise ValueError('%s is not %s' % (value, self.syntactic_help))

  def _custom_xml_dom_elements(self, doc):
    """Returns XML elements for whichever bounds are configured."""
    bounds = (('lower_bound', self.lower_bound),
              ('upper_bound', self.upper_bound))
    return [
        _helpers.create_xml_dom_element(doc, tag, bound)
        for tag, bound in bounds
        if bound is not None
    ]

  def convert(self, argument):
    """Returns the correct numeric value of argument.

    Subclasses must implement this method and raise TypeError when argument
    is neither a string nor of the right numeric type.

    Args:
      argument: string argument passed in the commandline, or the numeric type.

    Raises:
      TypeError: Raised when argument is not a string or the right numeric type.
      ValueError: Raised when failed to convert argument to the numeric value.
    """
    raise NotImplementedError
195
196
class FloatParser(NumericParser):
  """Parser of floating point values.

  The parsed value may be constrained by an optional lower and upper bound.
  """
  number_article = 'a'
  number_name = 'number'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(self, lower_bound=None, upper_bound=None):
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    # Build a human-readable description of the allowed range; with no
    # bounds the class-level default help text is kept.
    if lower_bound is not None and upper_bound is not None:
      description = '%s in the range [%s, %s]' % (
          self.syntactic_help, lower_bound, upper_bound)
    elif lower_bound == 0:
      description = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      description = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      description = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      description = '%s >= %s' % (self.number_name, lower_bound)
    else:
      description = self.syntactic_help
    self.syntactic_help = description

  def convert(self, argument):
    """Returns ``argument`` converted to a float.

    Raises:
      TypeError: argument is not a string, int (bool excluded), or float.
      ValueError: argument is a string that float() cannot parse.
    """
    accepted = (_is_integer_type(argument) or
                isinstance(argument, (float, str)))
    if not accepted:
      raise TypeError(
          'Expect argument to be a string, int, or float, found {}'.format(
              type(argument)))
    return float(argument)

  def flag_type(self):
    """See base class."""
    return 'float'
236
237
class IntegerParser(NumericParser):
  """Parser of an integer value.

  The parsed value may be constrained by an optional lower and upper bound.
  String arguments are decimal unless they carry a ``0x`` (hex), ``0o``
  (octal) or ``0b`` (binary) radix prefix.
  """
  number_article = 'an'
  number_name = 'integer'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(self, lower_bound=None, upper_bound=None):
    super(IntegerParser, self).__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 1:
      sh = 'a positive %s' % self.number_name
    elif upper_bound == -1:
      sh = 'a negative %s' % self.number_name
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument):
    """Returns the int value of argument.

    Args:
      argument: an int (bools are rejected) or a string holding an integer
        literal. Strings are read as decimal unless they start with a
        ``0x``/``0o``/``0b`` prefix. The prefix is case-insensitive and may
        follow surrounding whitespace and a ``+``/``-`` sign. A bare leading
        zero (e.g. ``'010'``) still means decimal.

    Returns:
      The integer value.

    Raises:
      TypeError: argument is neither a string nor an int.
      ValueError: argument is a string that is not a valid integer literal.
    """
    if _is_integer_type(argument):
      return argument
    elif isinstance(argument, str):
      # int() itself tolerates surrounding whitespace, a sign, and a prefix
      # that matches the given base; only the base needs to be chosen here.
      # Previously only lowercase prefixes at index 0 were recognised, so
      # e.g. '0X1F' or '-0x10' were rejected.
      digits = argument.strip().lstrip('+-')
      base = 10
      if len(digits) > 2:
        base = {'0x': 16, '0o': 8, '0b': 2}.get(digits[:2].lower(), 10)
      return int(argument, base)
    else:
      raise TypeError('Expect argument to be a string or int, found {}'.format(
          type(argument)))

  def flag_type(self):
    """See base class."""
    return 'int'
287
288
class BooleanParser(ArgumentParser):
  """Parser of boolean values."""

  def parse(self, argument):
    """Converts ``argument`` to a bool.

    Strings are matched case-insensitively against 'true'/'t'/'1' and
    'false'/'f'/'0'. Ints (including bools) are accepted only when their
    value is 0 or 1. Any other type is rejected.

    Raises:
      ValueError: a string or int that does not denote a boolean.
      TypeError: argument is neither a string nor an int.
    """
    if isinstance(argument, str):
      lowered = argument.lower()
      if lowered in ('true', 't', '1'):
        return True
      if lowered in ('false', 'f', '0'):
        return False
      raise ValueError('Non-boolean argument to boolean flag', argument)

    if not isinstance(argument, int):
      # Floats and every other non-int type end up here.
      raise TypeError('Non-boolean argument to boolean flag', argument)

    # Only 0 and 1 (and False/True) compare equal to their bool() value.
    as_bool = bool(argument)
    if argument == as_bool:
      return as_bool
    raise ValueError('Non-boolean argument to boolean flag', argument)

  def flag_type(self):
    """See base class."""
    return 'bool'
315
316
class EnumParser(ArgumentParser):
  """Parser of a string enum value (a string value from a given set)."""

  def __init__(self, enum_values, case_sensitive=True):
    """Initializes EnumParser.

    Args:
      enum_values: [str], a non-empty list of string values in the enum.
      case_sensitive: bool, whether or not the enum is to be case-sensitive.

    Raises:
      ValueError: When enum_values is empty.
    """
    if not enum_values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(enum_values))
    super(EnumParser, self).__init__()
    self.enum_values = enum_values
    self.case_sensitive = case_sensitive

  def parse(self, argument):
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str, the supplied flag value.

    Returns:
      The first matching element from enum_values. When case-insensitive,
      that is the first element equal to argument ignoring case, returned
      with its original casing.

    Raises:
      ValueError: Raised when argument didn't match anything in enum
        (including when argument is not a string).
    """
    if self.case_sensitive:
      if argument not in self.enum_values:
        raise ValueError('value should be one of <%s>' %
                         '|'.join(self.enum_values))
      return argument

    # A non-string can never match case-insensitively. Reject it with the same
    # ValueError used for any other non-member, rather than failing with an
    # AttributeError on .upper() that callers catching
    # (TypeError, ValueError) would not expect.
    if isinstance(argument, str):
      wanted = argument.upper()
      for value in self.enum_values:
        if value.upper() == wanted:
          return value
    raise ValueError('value should be one of <%s>' %
                     '|'.join(self.enum_values))

  def flag_type(self):
    """See base class."""
    return 'string enum'
366
367
class EnumClassParser(ArgumentParser):
  """Parser of an Enum class member."""

  def __init__(self, enum_class, case_sensitive=True):
    """Initializes EnumClassParser.

    Args:
      enum_class: class, the Enum class with all possible flag values.
      case_sensitive: bool, whether or not the enum is to be case-sensitive. If
        False, all member names must be unique when case is ignored.

    Raises:
      TypeError: When enum_class is not a subclass of Enum.
      ValueError: When enum_class is empty, or when case_sensitive is False
        and two member names are equal ignoring case.
    """
    # Users must have an Enum class defined before using EnumClass flag.
    # Therefore this dependency is guaranteed.
    import enum

    if not issubclass(enum_class, enum.Enum):
      raise TypeError('{} is not a subclass of Enum.'.format(enum_class))
    if not enum_class.__members__:
      raise ValueError('enum_class cannot be empty, but "{}" is empty.'
                       .format(enum_class))
    if not case_sensitive:
      # Lowercased names must stay unique, otherwise a case-insensitive
      # lookup would be ambiguous.
      members = collections.Counter(
          name.lower() for name in enum_class.__members__)
      duplicate_keys = {
          member for member, count in members.items() if count > 1
      }
      if duplicate_keys:
        raise ValueError(
            'Duplicate enum values for {} using case_sensitive=False'.format(
                duplicate_keys))

    super(EnumClassParser, self).__init__()
    self.enum_class = enum_class
    self._case_sensitive = case_sensitive
    if case_sensitive:
      self._member_names = tuple(enum_class.__members__)
    else:
      self._member_names = tuple(
          name.lower() for name in enum_class.__members__)

  @property
  def member_names(self):
    """The accepted enum names, in lowercase if not case sensitive."""
    return self._member_names

  def parse(self, argument):
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str or Enum class member, the supplied flag value.

    Returns:
      The first matching Enum class member in Enum class.

    Raises:
      ValueError: Raised when argument is neither a member of enum_class nor
        a string naming one of its members.
    """
    if isinstance(argument, self.enum_class):
      return argument
    elif not isinstance(argument, str):
      raise ValueError(
          '{} is not an enum member or a name of a member in {}'.format(
              argument, self.enum_class))
    # Delegate name validation (and its error message) to EnumParser.
    key = EnumParser(
        self._member_names, case_sensitive=self._case_sensitive).parse(argument)
    if self._case_sensitive:
      return self.enum_class[key]
    else:
      # If EnumParser.parse() returns a value, we're guaranteed to find it
      # as a member of the class.
      return next(value for name, value in self.enum_class.__members__.items()
                  if name.lower() == key.lower())

  def flag_type(self):
    """See base class."""
    return 'enum class'
448
449
class ListSerializer(ArgumentSerializer):
  """Serializes a list by joining the str() of each element with a separator."""

  def __init__(self, list_sep):
    """Initializes ListSerializer.

    Args:
      list_sep: str placed between consecutive serialized elements.
    """
    self.list_sep = list_sep

  def serialize(self, value):
    """See base class."""
    return self.list_sep.join(map(str, value))
458
459
class EnumClassListSerializer(ListSerializer):
  """A serializer for :class:`MultiEnumClass` flags.

  Each element is rendered by an `EnumClassSerializer`, and the results are
  joined with the configured separator.
  """

  def __init__(self, list_sep, **kwargs):
    """Initializes EnumClassListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing
      **kwargs: Keyword arguments forwarded to the `EnumClassSerializer` that
        serializes individual values.
    """
    super().__init__(list_sep)
    self._element_serializer = EnumClassSerializer(**kwargs)

  def serialize(self, value):
    """See base class."""
    if not isinstance(value, list):
      # A single member is serialized on its own, without any separator.
      return self._element_serializer.serialize(value)
    rendered = [self._element_serializer.serialize(item) for item in value]
    return self.list_sep.join(rendered)
485
486
class CsvListSerializer(ArgumentSerializer):
  """Serializes a list as a single CSV record."""

  def __init__(self, list_sep):
    """Initializes CsvListSerializer.

    Args:
      list_sep: str used as the CSV field delimiter.
    """
    self.list_sep = list_sep

  def serialize(self, value):
    """Returns the elements of ``value`` written as one CSV row string."""
    buffer = io.StringIO()
    csv.writer(buffer, delimiter=self.list_sep).writerow(
        [str(item) for item in value])
    # strip() drops the row terminator csv.writer appends (along with any
    # whitespace at either end of the record).
    return buffer.getvalue().strip()
502
503
class EnumClassSerializer(ArgumentSerializer):
  """Renders an enum class flag value as the member's name."""

  def __init__(self, lowercase):
    """Initializes EnumClassSerializer.

    Args:
      lowercase: If True, member names are lowercased when serialized.
    """
    self._lowercase = lowercase

  def serialize(self, value):
    """Returns the name of the Enum member ``value``, optionally lowercased."""
    name = str(value.name)
    if self._lowercase:
      return name.lower()
    return name
519
520
class BaseListParser(ArgumentParser):
  """Base class for a parser of lists of strings.

  To extend, inherit from this class; from the subclass ``__init__``, call::

      super().__init__(token, name)

  where token is the string used to split the argument (passed to
  ``str.split``), and name is a description of the separator.
  """

  def __init__(self, token=None, name=None):
    assert name
    super().__init__()
    self._token = token
    self._name = name
    self.syntactic_help = 'a %s separated list' % name

  def parse(self, argument):
    """See base class."""
    if isinstance(argument, list):
      # Already-parsed values pass straight through.
      return argument
    if not argument:
      return []
    pieces = argument.split(self._token)
    return [piece.strip() for piece in pieces]

  def flag_type(self):
    """See base class."""
    return '%s separated list of strings' % self._name
551
552
class ListParser(BaseListParser):
  """Parser for a comma-separated list of strings."""

  def __init__(self):
    super().__init__(',', 'comma')

  def parse(self, argument):
    """Parses argument as a comma-separated list of strings.

    The argument is read as one CSV record (so quoted fields may contain
    commas), and every resulting field is stripped of surrounding whitespace.

    Raises:
      ValueError: The csv module cannot parse the argument, e.g. because it
        contains a naked newline.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    try:
      rows = list(csv.reader([argument], strict=True))
    except csv.Error as e:
      # Turn csv.Error into a readable report for values such as
      #   --listflag="$(printf 'hello,\nworld')"
      # i.e. list flag values containing naked newlines. The csv.Error used
      # to propagate unchanged.
      raise ValueError('Unable to parse the value %r as a %s: %s'
                       % (argument, self.flag_type(), e))
    return [field.strip() for field in rows[0]]

  def _custom_xml_dom_elements(self, doc):
    inherited = super()._custom_xml_dom_elements(doc)
    separator = _helpers.create_xml_dom_element(
        doc, 'list_separator', repr(','))
    return inherited + [separator]
582
583
class WhitespaceSeparatedListParser(BaseListParser):
  """Parser for a whitespace-separated list of strings."""

  def __init__(self, comma_compat=False):
    """Initializer.

    Args:
      comma_compat: bool, whether to also treat commas as separators. If
          False, only whitespace separates items. This exists solely for
          backwards compatibility with flags that used to be comma-separated.
    """
    self._comma_compat = comma_compat
    if comma_compat:
      name = 'whitespace or comma'
    else:
      name = 'whitespace'
    super().__init__(None, name)

  def parse(self, argument):
    """Parses argument as a whitespace-separated list of strings.

    When comma compatibility is enabled, commas also act as separators.

    Args:
      argument: string argument passed in the commandline.

    Returns:
      [str], the parsed flag value.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    text = argument.replace(',', ' ') if self._comma_compat else argument
    return text.split()

  def _custom_xml_dom_elements(self, doc):
    elements = super()._custom_xml_dom_elements(doc)
    separator_chars = string.whitespace + (',' if self._comma_compat else '')
    elements.extend(
        _helpers.create_xml_dom_element(doc, 'list_separator', repr(char))
        for char in sorted(separator_chars))
    return elements
630