# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains base classes used to parse and convert arguments.

Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""

import collections
import csv
import io
import string

from absl.flags import _helpers


def _is_integer_type(instance):
  """Returns True if instance is an integer, and not a bool."""
  # bool is a subclass of int, so it has to be excluded explicitly.
  return (isinstance(instance, int) and
          not isinstance(instance, bool))


class _ArgumentParserCache(type):
  """Metaclass used to cache and share argument parsers among flags."""

  # Maps (cls,) + positional initializer args to the shared parser instance.
  _instances = {}

  def __call__(cls, *args, **kwargs):
    """Returns an instance of the argument parser cls.

    This method overrides behavior of the __new__ methods in
    all subclasses of ArgumentParser (inclusive). If an instance
    for cls with the same set of arguments exists, this instance is
    returned, otherwise a new instance is created.

    If any keyword arguments are defined, or the values in args
    are not hashable, this method always returns a new instance of
    cls.

    Args:
      *args: Positional initializer arguments.
      **kwargs: Initializer keyword arguments.

    Returns:
      An instance of cls, shared or new.
    """
    if kwargs:
      return type.__call__(cls, *args, **kwargs)

    instances = cls._instances
    key = (cls,) + tuple(args)
    try:
      return instances[key]
    except KeyError:
      # No cache entry for key exists, create a new one below.
      cacheable = True
    except TypeError:
      # An object in args cannot be hashed, always return a new instance.
      cacheable = False

    # The instance is built outside of the handlers above so that an error
    # raised by the initializer is not reported as having occurred "during
    # handling" of the unrelated lookup KeyError/TypeError.
    instance = type.__call__(cls, *args)
    if cacheable:
      return instances.setdefault(key, instance)
    return instance


# NOTE about Genericity and Metaclass of ArgumentParser.
# (1) In the .py source (this file)
#     - is not declared as Generic
#     - has _ArgumentParserCache as a metaclass
# (2) In the .pyi source (type stub)
#     - is declared as Generic
#     - doesn't have a metaclass
# The reason we need this is due to Generic having a different metaclass
# (for python versions <= 3.7) and a class can have only one metaclass.
#
# * Lack of metaclass in .pyi is not a deal breaker, since the metaclass
#   doesn't affect any type information. Also type checkers can check the type
#   parameters.
# * However, not declaring ArgumentParser as Generic in the source affects
#   runtime annotation processing. In particular this means, subclasses should
#   inherit from `ArgumentParser` and not `ArgumentParser[SomeType]`.
#   The corresponding DEFINE_someType method (the public API) can be annotated
#   to return FlagHolder[SomeType].
class ArgumentParser(metaclass=_ArgumentParserCache):
  """Base class used to parse and convert arguments.

  The :meth:`parse` method checks to make sure that the string argument is a
  legal value and convert it to a native type.  If the value cannot be
  converted, it should throw a ``ValueError`` exception with a human
  readable explanation of why the value is illegal.

  Subclasses should also define a syntactic_help string which may be
  presented to the user to describe the form of the legal values.

  Argument parser classes must be stateless, since instances are cached
  and shared between flags. Initializer arguments are allowed, but all
  member variables must be derived from initializer arguments only.
  """

  syntactic_help = ''

  def parse(self, argument):
    """Parses the string argument and returns the native value.

    By default it returns its argument unmodified.

    Args:
      argument: string argument passed in the commandline.

    Raises:
      ValueError: Raised when it fails to parse the argument.
      TypeError: Raised when the argument has the wrong type.

    Returns:
      The parsed value in native type.
    """
    if not isinstance(argument, str):
      raise TypeError('flag value must be a string, found "{}"'.format(
          type(argument)))
    return argument

  def flag_type(self):
    """Returns a string representing the type of the flag."""
    return 'string'

  def _custom_xml_dom_elements(self, doc):
    """Returns a list of minidom.Element to add additional flag information.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    del doc  # Unused.
    return []


class ArgumentSerializer:
  """Base class for generating string representations of a flag value."""

  def serialize(self, value):
    """Returns a serialized string of the value."""
    return str(value)


class NumericParser(ArgumentParser):
  """Parser of numeric values.

  Parsed value may be bounded to a given upper and lower bound.

  This class reads ``self.lower_bound`` and ``self.upper_bound`` but does not
  set them itself; the subclasses in this module assign both in ``__init__``.
  """

  def is_outside_bounds(self, val):
    """Returns whether the value is outside the bounds or not."""
    return ((self.lower_bound is not None and val < self.lower_bound) or
            (self.upper_bound is not None and val > self.upper_bound))

  def parse(self, argument):
    """See base class."""
    val = self.convert(argument)
    if self.is_outside_bounds(val):
      raise ValueError('%s is not %s' % (val, self.syntactic_help))
    return val

  def _custom_xml_dom_elements(self, doc):
    """See base class. Emits an element for each bound that is set."""
    elements = []
    if self.lower_bound is not None:
      elements.append(_helpers.create_xml_dom_element(
          doc, 'lower_bound', self.lower_bound))
    if self.upper_bound is not None:
      elements.append(_helpers.create_xml_dom_element(
          doc, 'upper_bound', self.upper_bound))
    return elements

  def convert(self, argument):
    """Returns the correct numeric value of argument.

    Subclass must implement this method, and raise TypeError if argument is
    neither a string nor of the right numeric type.

    Args:
      argument: string argument passed in the commandline, or the numeric type.

    Raises:
      TypeError: Raised when argument is not a string or the right numeric type.
      ValueError: Raised when failed to convert argument to the numeric value.
    """
    raise NotImplementedError


class FloatParser(NumericParser):
  """Parser of floating point values.

  Parsed value may be bounded to a given upper and lower bound.
  """
  number_article = 'a'
  number_name = 'number'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(self, lower_bound=None, upper_bound=None):
    """Initializes FloatParser.

    Args:
      lower_bound: the smallest accepted value, or None for no lower bound.
      upper_bound: the largest accepted value, or None for no upper bound.
    """
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    # Start from the class-level help text and specialize it for the bounds.
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument):
    """Returns the float value of argument.

    Args:
      argument: a string, a non-bool int, or a float.

    Raises:
      TypeError: Raised when argument is of any other type.
      ValueError: Raised by ``float()`` when a string cannot be converted.
    """
    if (_is_integer_type(argument) or isinstance(argument, float) or
        isinstance(argument, str)):
      return float(argument)
    else:
      raise TypeError(
          'Expect argument to be a string, int, or float, found {}'.format(
              type(argument)))

  def flag_type(self):
    """See base class."""
    return 'float'


class IntegerParser(NumericParser):
  """Parser of an integer value.

  Parsed value may be bounded to a given upper and lower bound.
  """
  number_article = 'an'
  number_name = 'integer'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(self, lower_bound=None, upper_bound=None):
    """Initializes IntegerParser.

    Args:
      lower_bound: the smallest accepted value, or None for no lower bound.
      upper_bound: the largest accepted value, or None for no upper bound.
    """
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    # Start from the class-level help text and specialize it for the bounds.
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 1:
      sh = 'a positive %s' % self.number_name
    elif upper_bound == -1:
      sh = 'a negative %s' % self.number_name
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument):
    """Returns the int value of argument.

    Args:
      argument: a non-bool int (returned unchanged) or a string. Strings
        starting with ``0o`` or ``0x`` are read as octal or hexadecimal; all
        other strings are read as base 10.

    Raises:
      TypeError: Raised when argument is neither a string nor an int.
      ValueError: Raised by ``int()`` when a string cannot be converted.
    """
    if _is_integer_type(argument):
      return argument
    elif isinstance(argument, str):
      base = 10
      if len(argument) > 2 and argument[0] == '0':
        if argument[1] == 'o':
          base = 8
        elif argument[1] == 'x':
          base = 16
      return int(argument, base)
    else:
      raise TypeError('Expect argument to be a string or int, found {}'.format(
          type(argument)))

  def flag_type(self):
    """See base class."""
    return 'int'


class BooleanParser(ArgumentParser):
  """Parser of boolean values."""

  def parse(self, argument):
    """See base class."""
    if isinstance(argument, str):
      lowered = argument.lower()
      if lowered in ('true', 't', '1'):
        return True
      elif lowered in ('false', 'f', '0'):
        return False
      else:
        raise ValueError('Non-boolean argument to boolean flag', argument)
    elif isinstance(argument, int):
      # Only allow bool or integer 0, 1.
      # Note that float 1.0 == True, 0.0 == False.
      bool_value = bool(argument)
      if argument == bool_value:
        return bool_value
      else:
        raise ValueError('Non-boolean argument to boolean flag', argument)

    raise TypeError('Non-boolean argument to boolean flag', argument)

  def flag_type(self):
    """See base class."""
    return 'bool'


class EnumParser(ArgumentParser):
  """Parser of a string enum value (a string value from a given set)."""

  def __init__(self, enum_values, case_sensitive=True):
    """Initializes EnumParser.

    Args:
      enum_values: [str], a non-empty list of string values in the enum.
      case_sensitive: bool, whether or not the enum is to be case-sensitive.

    Raises:
      ValueError: When enum_values is empty.
    """
    if not enum_values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(enum_values))
    super().__init__()
    self.enum_values = enum_values
    self.case_sensitive = case_sensitive

  def parse(self, argument):
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str, the supplied flag value.

    Returns:
      The first matching element from enum_values.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if self.case_sensitive:
      if argument in self.enum_values:
        return argument
    else:
      # Single pass: return the first value that matches ignoring case.
      wanted = argument.upper()
      for value in self.enum_values:
        if value.upper() == wanted:
          return value
    raise ValueError('value should be one of <%s>' %
                     '|'.join(self.enum_values))

  def flag_type(self):
    """See base class."""
    return 'string enum'


class EnumClassParser(ArgumentParser):
  """Parser of an Enum class member."""

  def __init__(self, enum_class, case_sensitive=True):
    """Initializes EnumClassParser.

    Args:
      enum_class: class, the Enum class with all possible flag values.
      case_sensitive: bool, whether or not the enum is to be case-sensitive. If
        False, all member names must be unique when case is ignored.

    Raises:
      TypeError: When enum_class is not a subclass of Enum.
      ValueError: When enum_class is empty.
      ValueError: When case_sensitive is False and two member names differ
        only by case.
    """
    # Users must have an Enum class defined before using EnumClass flag.
    # Therefore this dependency is guaranteed.
    import enum

    if not issubclass(enum_class, enum.Enum):
      raise TypeError('{} is not a subclass of Enum.'.format(enum_class))
    if not enum_class.__members__:
      raise ValueError('enum_class cannot be empty, but "{}" is empty.'
                       .format(enum_class))
    if not case_sensitive:
      members = collections.Counter(
          name.lower() for name in enum_class.__members__)
      duplicate_keys = {
          member for member, count in members.items() if count > 1
      }
      if duplicate_keys:
        raise ValueError(
            'Duplicate enum values for {} using case_sensitive=False'.format(
                duplicate_keys))

    super().__init__()
    self.enum_class = enum_class
    self._case_sensitive = case_sensitive
    if case_sensitive:
      self._member_names = tuple(enum_class.__members__)
    else:
      self._member_names = tuple(
          name.lower() for name in enum_class.__members__)

  @property
  def member_names(self):
    """The accepted enum names, in lowercase if not case sensitive."""
    return self._member_names

  def parse(self, argument):
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str or Enum class member, the supplied flag value.

    Returns:
      The first matching Enum class member in Enum class.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if isinstance(argument, self.enum_class):
      return argument
    elif not isinstance(argument, str):
      raise ValueError(
          '{} is not an enum member or a name of a member in {}'.format(
              argument, self.enum_class))
    key = EnumParser(
        self._member_names, case_sensitive=self._case_sensitive).parse(argument)
    if self._case_sensitive:
      return self.enum_class[key]
    else:
      # If EnumParser.parse() return a value, we're guaranteed to find it
      # as a member of the class
      return next(value for name, value in self.enum_class.__members__.items()
                  if name.lower() == key.lower())

  def flag_type(self):
    """See base class."""
    return 'enum class'


class ListSerializer(ArgumentSerializer):
  """Serializer that joins the str() of each list element with a separator."""

  def __init__(self, list_sep):
    """Initializes ListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing.
    """
    self.list_sep = list_sep

  def serialize(self, value):
    """See base class."""
    return self.list_sep.join([str(x) for x in value])


class EnumClassListSerializer(ListSerializer):
  """A serializer for :class:`MultiEnumClass` flags.

  This serializer simply joins the output of `EnumClassSerializer` using a
  provided separator.
  """

  def __init__(self, list_sep, **kwargs):
    """Initializes EnumClassListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing
      **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
        individual values.
    """
    super().__init__(list_sep)
    self._element_serializer = EnumClassSerializer(**kwargs)

  def serialize(self, value):
    """See base class."""
    if isinstance(value, list):
      return self.list_sep.join(
          self._element_serializer.serialize(x) for x in value)
    else:
      # A non-list value is serialized as a single enum member.
      return self._element_serializer.serialize(value)


class CsvListSerializer(ArgumentSerializer):
  """Serializer that writes a list as one CSV row using a given delimiter."""

  def __init__(self, list_sep):
    """Initializes CsvListSerializer.

    Args:
      list_sep: String passed to ``csv.writer`` as the delimiter.
    """
    self.list_sep = list_sep

  def serialize(self, value):
    """Serializes a list as a CSV string or unicode."""
    output = io.StringIO()
    writer = csv.writer(output, delimiter=self.list_sep)
    writer.writerow([str(x) for x in value])
    serialized_value = output.getvalue().strip()

    # We need the returned value to be pure ascii or Unicodes so that
    # when the xml help is generated they are usefully encodable.
    return str(serialized_value)


class EnumClassSerializer(ArgumentSerializer):
  """Class for generating string representations of an enum class flag value."""

  def __init__(self, lowercase):
    """Initializes EnumClassSerializer.

    Args:
      lowercase: If True, enum member names are lowercased during serialization.
    """
    self._lowercase = lowercase

  def serialize(self, value):
    """Returns a serialized string of the Enum class value."""
    as_string = str(value.name)
    return as_string.lower() if self._lowercase else as_string


class BaseListParser(ArgumentParser):
  """Base class for a parser of lists of strings.

  To extend, inherit from this class; from the subclass ``__init__``, call::

      super().__init__(token, name)

  where token is a character used to tokenize, and name is a description
  of the separator.
  """

  def __init__(self, token=None, name=None):
    """Initializes BaseListParser.

    Args:
      token: separator passed to ``str.split`` when parsing.
      name: description of the separator, used in the help text and flag type.
        Must be truthy.
    """
    assert name
    super().__init__()
    self._token = token
    self._name = name
    self.syntactic_help = 'a %s separated list' % self._name

  def parse(self, argument):
    """See base class."""
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      return [s.strip() for s in argument.split(self._token)]

  def flag_type(self):
    """See base class."""
    return '%s separated list of strings' % self._name


class ListParser(BaseListParser):
  """Parser for a comma-separated list of strings."""

  def __init__(self):
    """Initializes ListParser with ',' as the separator."""
    super().__init__(',', 'comma')

  def parse(self, argument):
    """Parses argument as comma-separated list of strings.

    Args:
      argument: a list (returned unchanged) or a string parsed as one CSV row.

    Returns:
      [str], the parsed values with surrounding whitespace stripped.

    Raises:
      ValueError: Raised when the csv module rejects the string.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      try:
        return [s.strip() for s in list(csv.reader([argument], strict=True))[0]]
      except csv.Error as e:
        # Provide a helpful report for case like
        #   --listflag="$(printf 'hello,\nworld')"
        # IOW, list flag values containing naked newlines.  This error
        # was previously "reported" by allowing csv.Error to
        # propagate.
        raise ValueError('Unable to parse the value %r as a %s: %s'
                         % (argument, self.flag_type(), e)) from e

  def _custom_xml_dom_elements(self, doc):
    """See base class. Adds a 'list_separator' element for the comma."""
    elements = super()._custom_xml_dom_elements(doc)
    elements.append(_helpers.create_xml_dom_element(
        doc, 'list_separator', repr(',')))
    return elements


class WhitespaceSeparatedListParser(BaseListParser):
  """Parser for a whitespace-separated list of strings."""

  def __init__(self, comma_compat=False):
    """Initializer.

    Args:
      comma_compat: bool, whether to support comma as an additional separator.
        If False then only whitespace is supported.  This is intended only for
        backwards compatibility with flags that used to be comma-separated.
    """
    self._comma_compat = comma_compat
    name = 'whitespace or comma' if self._comma_compat else 'whitespace'
    super().__init__(None, name)

  def parse(self, argument):
    """Parses argument as whitespace-separated list of strings.

    It also parses argument as comma-separated list of strings if requested.

    Args:
      argument: string argument passed in the commandline.

    Returns:
      [str], the parsed flag value.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      if self._comma_compat:
        # Treat commas as whitespace so one split() handles both separators.
        argument = argument.replace(',', ' ')
      return argument.split()

  def _custom_xml_dom_elements(self, doc):
    """See base class. Adds a 'list_separator' element per separator."""
    elements = super()._custom_xml_dom_elements(doc)
    separators = list(string.whitespace)
    if self._comma_compat:
      separators.append(',')
    separators.sort()
    for sep_char in separators:
      elements.append(_helpers.create_xml_dom_element(
          doc, 'list_separator', repr(sep_char)))
    return elements