• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
4
from __future__ import annotations

import collections
import collections.abc
import re
from typing import (Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type,
                    TypeVar, Union)
11
12
class FrozenFlagsError(RuntimeError):
  """Raised by Freezable.assert_not_frozen() when a frozen object is modified."""
15
16
# Self-type for Freezable.freeze(), so that calling it on a subclass instance
# is typed as returning that subclass rather than plain Freezable.
FreezableT = TypeVar("FreezableT", bound="Freezable")
18
19
class Freezable:
  """Mixin that gives an object a one-way "frozen" (read-only) state.

  Code that mutates state is expected to call assert_not_frozen() first.
  Hashing an instance freezes it.
  """

  def __init__(self, *args, **kwargs) -> None:
    # Initialise the flag before delegating, so cooperative base classes that
    # modify state inside their own __init__ see an unfrozen object.
    self._frozen = False
    super().__init__(*args, **kwargs)

  def __hash__(self):
    # Freeze before hashing so the state the hash is derived from (str(self))
    # can no longer be modified through the freeze-checked paths.
    self.freeze()
    return hash(str(self))

  @property
  def is_frozen(self) -> bool:
    """Whether this object has been frozen."""
    return self._frozen

  def freeze(self: FreezableT) -> FreezableT:
    """Mark this object as frozen and return it, allowing call chaining."""
    self._frozen = True
    return self

  def assert_not_frozen(self, msg: Optional[str] = None) -> None:
    """Raise FrozenFlagsError if frozen.

    Args:
      msg: Optional error text; when empty or None a default message naming
        the concrete class is used.
    """
    if self._frozen:
      default_msg = f"Cannot modify frozen {type(self).__name__}"
      raise FrozenFlagsError(msg or default_msg)
38  def assert_not_frozen(self, msg: Optional[str] = None) -> None:
39    if not self._frozen:
40      return
41    if not msg:
42      msg = f"Cannot modify frozen {type(self).__name__}"
43    raise FrozenFlagsError(msg)
44
45
# Self-type for BasicFlags classmethods/methods such as parse() and copy(),
# so they are typed as returning the concrete (sub)class they are called on.
BasicFlagsT = TypeVar("BasicFlagsT", bound="BasicFlags")


# Input accepted by the BasicFlags constructor and update(): None, a dict or
# Flags instance mapping flag names to values, or an iterable whose items are
# either bare flag names or (name, value) pairs.
FlagsData = Optional[Union[Dict[str, str], "Flags",
                           Iterable[Union[Tuple[str, Optional[str]], str]]]]
51
52
class BasicFlags(Freezable, collections.UserDict):
  """Basic implementation for command-line flags (similar to Dict[str, str]).

  Maps flag names such as "--foo" to a string value, or to None for flags
  without a value. This class is mostly used to make sure command-line flags
  for browsers don't end up having contradicting values: setting an existing
  flag to a different value raises ValueError unless override=True is used.

  Note: iterating an instance yields rendered flags ("--foo=1", "--bar"), not
  bare flag names; use items() or to_dict() for (name, value) access.
  """

  _WHITE_SPACE_RE = re.compile(r"\s+")
  # One or two leading dashes, then a non-dash char; no whitespace or "=".
  _BASIC_FLAG_NAME_RE = re.compile(r"(--?)[^\s=-][^\s=]*")
  # Handles whitespace-separated flags: --foo="1" --bar  --baz='2'  --boo=3
  # Unquoted values end at any whitespace character (tabs and newlines
  # included), consistent with _END_OR_SEPARATOR_PATTERN; otherwise
  # "--foo=1\n--bar" would be read as a single flag with value "1\n--bar".
  _VALUE_PATTERN = (r"('(?P<value_single_quotes>[^']*)')|"
                    r"(\"(?P<value_double_quotes>[^\"]*)\")|"
                    r"(?P<value_no_quotes>[^'\"\s]+)")
  _END_OR_SEPARATOR_PATTERN = r"(\s*\s\s*|$)"
  _PARSE_RE = re.compile(fr"(?P<name>{_BASIC_FLAG_NAME_RE.pattern})"
                         fr"((?P<equal>=)({_VALUE_PATTERN})?)?"
                         fr"{_END_OR_SEPARATOR_PATTERN}")

  @classmethod
  def split(cls, flag_str: str) -> Tuple[str, Optional[str]]:
    """Split "--name=value" into ("--name", "value").

    Only the first "=" separates name and value. A string without "="
    yields (flag_str, None).
    """
    if "=" in flag_str:
      flag_name, flag_value = flag_str.split("=", maxsplit=1)
      return (flag_name, flag_value)
    return (flag_str, None)

  @classmethod
  def parse(cls: Type[BasicFlagsT], data: Any) -> BasicFlagsT:
    """Create flags from an instance of cls, a flag string or FlagsData.

    Instances of cls are returned as-is (not copied), strings are parsed
    with parse_str() and anything else is passed to the constructor.
    """
    if isinstance(data, cls):
      return data
    if isinstance(data, str):
      return cls.parse_str(data)
    return cls(data)

  @classmethod
  def parse_str(cls: Type[BasicFlagsT], raw_flags: str) -> BasicFlagsT:
    """Parse a whitespace-separated flag string such as '--foo=1 --bar'."""
    return cls._parse_str(raw_flags)

  @classmethod
  def _parse_str(cls: Type[BasicFlagsT],
                 raw_flags: str,
                 msg: str = "flag") -> BasicFlagsT:
    """Parse raw_flags with cls._PARSE_RE, requiring all input be consumed.

    Args:
      raw_flags: Whitespace-separated flags; values may be single- or
        double-quoted to contain whitespace.
      msg: Noun used in error messages.

    Raises:
      ValueError: if part of raw_flags is not a valid flag, or a flag has
        an "=" without a non-empty value.
    """
    raw_flags = raw_flags.strip()
    if not raw_flags:
      return cls()
    flag_parts: List[Tuple[str, Optional[str]]] = []
    current_end: Optional[int] = None
    for match in cls._PARSE_RE.finditer(raw_flags):
      # finditer() silently skips unmatched text, so require each match to
      # start exactly where the previous one ended.
      if current_end is None:
        if match.start() != 0:
          part = raw_flags[:match.start()]
          raise ValueError(f"Invalid {msg} part at pos=0: {repr(part)}")
      else:
        if current_end != match.start():
          raise ValueError(f"Invalid {msg}: could not consume all data")
      current_end = match.end()

      groups = match.groupdict()
      maybe_flag_name: Optional[str] = groups.get("name")
      if not maybe_flag_name:
        raise ValueError(f"Invalid {msg}: {repr(raw_flags)}")
      # Re-assign since pytype doesn't remove the Optional.
      flag_name: str = maybe_flag_name
      flag_value: Optional[str] = (
          groups.get("value_single_quotes") or
          groups.get("value_double_quotes") or groups.get("value_no_quotes"))
      if groups.get("equal") and not flag_value:
        raise ValueError(f"Invalid {msg}: missing value for {repr(flag_name)}")
      assert flag_name
      flag_parts.append((flag_name, flag_value))

    # Also catches the case where nothing matched at all (current_end None).
    if current_end != len(raw_flags):
      part = raw_flags[current_end:]
      raise ValueError(
          f"Invalid {msg} part at pos={current_end or 0}: {repr(part)}")
    return cls(flag_parts)

  def __init__(self, initial_data: FlagsData = None) -> None:
    """Create flags, adding initial_data via update() (see its docs)."""
    super().__init__(initial_data)

  def __setitem__(self, flag_name: str, flag_value: Optional[str]) -> None:
    """flags[name] = value behaves like flags.set(name, value)."""
    return self.set(flag_name, flag_value)

  def __delitem__(self, flag_name: str) -> None:
    """Remove flag_name; raises FrozenFlagsError if frozen."""
    # UserDict's default deletes from self.data without any frozen check,
    # which would let `del flags[name]` and pop() mutate frozen flags.
    self.assert_not_frozen()
    del self.data[flag_name]

  def clear(self) -> None:
    """Remove all flags; raises FrozenFlagsError if frozen."""
    # MutableMapping.clear() pops keys obtained via __iter__, which yields
    # rendered "name=value" strings here, so it would stop at the first flag
    # that has a value. Clear the backing dict directly instead.
    self.assert_not_frozen()
    self.data.clear()

  def set(self,
          flag_name: str,
          flag_value: Optional[str] = None,
          override: bool = False) -> None:
    """Set flag_name to flag_value (None for a flag without a value).

    Raises:
      FrozenFlagsError: if these flags are frozen.
      ValueError: for an invalid flag name, or if the flag is already set
        to a different value and override is False.
      TypeError: if flag_value is neither None nor a string.
    """
    self._set(flag_name, flag_value, override)

  def _set(self,
           flag_name: str,
           flag_value: Optional[str] = None,
           override: bool = False) -> None:
    self.assert_not_frozen()
    self._validate_flag_name(flag_name)
    # Empty-string values are stored without value validation (kept for
    # backwards compatibility), but every other non-None value is checked so
    # that falsy non-strings such as 0 or False are rejected like 1 is.
    if flag_value is not None and flag_value != "":
      self._validate_flag_value(flag_name, flag_value)
    self._validate_override(flag_name, flag_value, override)
    self.data[flag_name] = flag_value

  def _validate_flag_name(self, flag_name: str) -> None:
    """Raise ValueError unless flag_name is a well-formed flag name."""
    if not flag_name:
      raise ValueError("Cannot set empty flag")
    if self._WHITE_SPACE_RE.search(flag_name):
      raise ValueError(
          f"Flag name cannot contain whitespaces: {repr(flag_name)}")
    if "=" in flag_name:
      raise ValueError(
          f"Flag name contains '=': {repr(flag_name)}, please split")
    if flag_name[0] != "-":
      raise ValueError(
          f"Flag name must begin with a '-', but got {repr(flag_name)}")
    if not self._BASIC_FLAG_NAME_RE.fullmatch(flag_name):
      raise ValueError(
          f"Flag name contains invalid characters: {repr(flag_name)}")

  def _validate_flag_value(self, flag_name: str, flag_value: str) -> None:
    """Raise TypeError unless flag_value is a string."""
    # Check the type first so falsy non-strings get a TypeError rather than
    # tripping the non-empty assertion below.
    if not isinstance(flag_value, str):
      raise TypeError(
          f"Expected None or string flag-value for flag {flag_name}, "
          f"but got: {repr(flag_value)}")
    assert flag_value, "Got invalid empty flag_value."

  def _validate_override(self, flag_name: str, flag_value: Optional[str],
                         override: bool) -> None:
    """Raise ValueError if flag_name is set to a different value.

    Does nothing when override is True or the flag is not yet present.
    """
    if override or flag_name not in self:
      return
    old_value = self[flag_name]
    if flag_value != old_value:
      raise ValueError(f"Flag {flag_name}={repr(flag_value)} was already set "
                       f"with a different previous value: {repr(old_value)}")

  # pylint: disable=arguments-differ
  def update(self,
             initial_data: FlagsData = None,
             override: bool = False) -> None:
    """Set every flag from initial_data.

    Args:
      initial_data: None (no-op), a mapping of flag name to value (dict,
        BasicFlags, Flags, ...), or an iterable whose items are bare flag
        names or (name, value) pairs.
      override: if True, conflicting existing values are replaced instead
        of raising ValueError.
    """
    if initial_data is None:
      return
    # Every Mapping, including BasicFlags itself, must be read via items():
    # BasicFlags.__iter__ yields rendered "name=value" strings and iterating
    # any other mapping would yield keys only, dropping the values.
    if isinstance(initial_data, collections.abc.Mapping):
      for flag_name, flag_value in initial_data.items():
        self.set(flag_name, flag_value, override)
      return
    for entry in initial_data:
      if isinstance(entry, str):
        self.set(entry, None, override)
      else:
        flag_name, flag_value = entry
        self.set(flag_name, flag_value, override)

  def merge(self, other: FlagsData) -> None:
    """Add all flags from other; conflicting values raise ValueError."""
    self.update(other)

  def copy(self: BasicFlagsT) -> BasicFlagsT:
    """Return a new, unfrozen instance of the same class with equal flags."""
    return self.__class__(self)

  def merge_copy(self: BasicFlagsT, other: FlagsData) -> BasicFlagsT:
    """Return a copy of self with other merged in; self is unchanged."""
    ret = self.copy()
    ret.merge(other)
    return ret

  def _describe(self, flag_name: str) -> str:
    """Render flag_name as "name" (no/unknown value) or "name=value"."""
    value = self.get(flag_name)
    if value is None:
      return flag_name
    return f"{flag_name}={value}"

  def items(self) -> Iterable[Tuple[str, Optional[str]]]:
    """Return a view of (flag name, value) pairs."""
    return self.data.items()

  def to_dict(self) -> Dict[str, Optional[str]]:
    """Return a plain dict copy of the flag name -> value mapping."""
    return dict(self.items())

  def __iter__(self) -> Iterator[str]:
    """Yield each flag rendered as "name" or "name=value"."""
    for k, v in self.items():
      if v is None:
        yield k
      else:
        yield f"{k}={v}"

  def __bool__(self) -> bool:
    return bool(self.data)

  def __hash__(self) -> int:
    # Freeze so the hashed contents can no longer change through the
    # freeze-checked mutators. Hash the unordered set of rendered flags:
    # Mapping.__eq__ ignores insertion order, so equal instances must hash
    # equally regardless of flag order (hash(str(self)) would not).
    self.freeze()
    return hash(frozenset(self))

  def __repr__(self) -> str:
    dict_repr = repr(self.to_dict())
    return f"{type(self).__name__}({dict_repr})"

  def __str__(self) -> str:
    """Space-separated rendered flags, e.g. "--foo=1 --bar"."""
    return " ".join(self)
241
242
class Flags(BasicFlags):
  """BasicFlags variant with stricter flag-name validation.

  After one or two leading dashes, a name must start with an ASCII letter or
  digit and may continue with ASCII letters, digits, "_" and "-". Most
  command-line programs adhere to this.
  """
  _FLAG_NAME_RE = re.compile(r"(--?)[a-zA-Z0-9][a-zA-Z0-9_-]*")
  _PARSE_RE = re.compile(fr"(?P<name>{_FLAG_NAME_RE.pattern})"
                         fr"((?P<equal>=)({BasicFlags._VALUE_PATTERN})?)?"
                         fr"{BasicFlags._END_OR_SEPARATOR_PATTERN}")

  def _validate_flag_name(self, flag_name: str) -> None:
    # Apply the generic BasicFlags checks first, then the stricter charset.
    super()._validate_flag_name(flag_name)
    if self._FLAG_NAME_RE.fullmatch(flag_name) is not None:
      return
    raise ValueError(f"Flag name contains invalid characters: {flag_name!r}")
258
259
# TypeVar bound to Flags (and its subclasses). It is not referenced in the
# code shown here; presumably used by generic helpers elsewhere — verify.
FlagsT = TypeVar("FlagsT", bound=Flags)
261