# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from __future__ import annotations

import collections
import re
from typing import (Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type,
                    TypeVar, Union)


class FrozenFlagsError(RuntimeError):
  """Raised when a frozen Freezable object is modified."""


FreezableT = TypeVar("FreezableT", bound="Freezable")


class Freezable:
  """Mixin that lets an object be switched into a read-only state.

  Subclasses call assert_not_frozen() at the start of every mutating method.
  """

  def __init__(self, *args, **kwargs) -> None:
    # Set the flag before delegating: initializers further down the MRO
    # (e.g. collections.UserDict) may already call mutating methods that
    # check the frozen state.
    self._frozen = False
    super().__init__(*args, **kwargs)

  def __hash__(self) -> int:
    # Hashing freezes the object so that the hash (derived from the string
    # representation) cannot change afterwards.
    self.freeze()
    return hash(str(self))

  @property
  def is_frozen(self) -> bool:
    """True once freeze() has been called."""
    return self._frozen

  def freeze(self: FreezableT) -> FreezableT:
    """Marks this object as frozen and returns it to allow chaining."""
    self._frozen = True
    return self

  def assert_not_frozen(self, msg: Optional[str] = None) -> None:
    """Raises FrozenFlagsError if this object is frozen.

    Args:
      msg: Optional custom error message; a default message containing the
        class name is used when omitted or empty.
    """
    if not self._frozen:
      return
    if not msg:
      msg = f"Cannot modify frozen {type(self).__name__}"
    raise FrozenFlagsError(msg)


BasicFlagsT = TypeVar("BasicFlagsT", bound="BasicFlags")


FlagsData = Optional[Union[Dict[str, str], "Flags",
                           Iterable[Union[Tuple[str, Optional[str]], str]]]]


class BasicFlags(Freezable, collections.UserDict):
  """Basic implementation for command line flags (similar to Dict[str, str]).

  This class is mostly used to make sure command-line flags for browsers
  don't end up having contradicting values.

  Note that iterating over an instance yields "name" or "name=value" strings
  (see __iter__), not the flag names; use items() for (name, value) pairs.
  """

  _WHITE_SPACE_RE = re.compile(r"\s+")
  _BASIC_FLAG_NAME_RE = re.compile(r"(--?)[^\s=-][^\s=]*")
  # Handles space-separated flags: --foo="1" --bar --baz='2' --boo=3
  _VALUE_PATTERN = (r"('(?P<value_single_quotes>[^']*)')|"
                    r"(\"(?P<value_double_quotes>[^\"]*)\")|"
                    r"(?P<value_no_quotes>[^'\" ]+)")
  _END_OR_SEPARATOR_PATTERN = r"(\s*\s\s*|$)"
  _PARSE_RE = re.compile(fr"(?P<name>{_BASIC_FLAG_NAME_RE.pattern})"
                         fr"((?P<equal>=)({_VALUE_PATTERN})?)?"
                         fr"{_END_OR_SEPARATOR_PATTERN}")

  @classmethod
  def split(cls, flag_str: str) -> Tuple[str, Optional[str]]:
    """Splits "name=value" at the first "=" into (name, value).

    Returns (flag_str, None) if flag_str contains no "=".
    """
    if "=" in flag_str:
      flag_name, flag_value = flag_str.split("=", maxsplit=1)
      return (flag_name, flag_value)
    return (flag_str, None)

  @classmethod
  def parse(cls: Type[BasicFlagsT], data: Any) -> BasicFlagsT:
    """Converts data to an instance of cls.

    Instances of cls are returned as-is (no copy), strings are parsed with
    parse_str() and everything else is passed to the constructor.
    """
    if isinstance(data, cls):
      return data
    if isinstance(data, str):
      return cls.parse_str(data)
    return cls(data)

  @classmethod
  def parse_str(cls: Type[BasicFlagsT], raw_flags: str) -> BasicFlagsT:
    """Parses a whitespace-separated flag string, e.g. "--foo=1 --bar"."""
    return cls._parse_str(raw_flags)

  @classmethod
  def _parse_str(cls: Type[BasicFlagsT],
                 raw_flags: str,
                 msg: str = "flag") -> BasicFlagsT:
    """Parses raw_flags into a new instance of cls.

    Args:
      raw_flags: Whitespace-separated flags; values may be unquoted or
        wrapped in single or double quotes.
      msg: Noun used in error messages to describe the parsed data.

    Raises:
      ValueError: If parts of the input cannot be parsed or if a flag has an
        "=" without a (non-empty) value.
    """
    raw_flags = raw_flags.strip()
    if not raw_flags:
      return cls()
    flag_parts: List[Tuple[str, Optional[str]]] = []
    current_end: Optional[int] = None
    for match in cls._PARSE_RE.finditer(raw_flags):
      # Matches have to be contiguous and start at position 0, otherwise
      # there is unparseable data in between.
      if current_end is None:
        if match.start() != 0:
          part = raw_flags[:match.start()]
          raise ValueError(f"Invalid {msg} part at pos=0: {repr(part)}")
      else:
        if current_end != match.start():
          raise ValueError(f"Invalid {msg}: could not consume all data")
      current_end = match.end()

      groups = match.groupdict()
      maybe_flag_name: Optional[str] = groups.get("name")
      if not maybe_flag_name:
        raise ValueError(f"Invalid {msg}: {repr(raw_flags)}")
      # Re-assign since pytype doesn't remove the Optional.
      flag_name: str = maybe_flag_name
      flag_value: Optional[str] = (
          groups.get("value_single_quotes") or
          groups.get("value_double_quotes") or groups.get("value_no_quotes"))
      # An empty quoted value ("" or '') is falsy and thus also rejected here.
      if groups.get("equal") and not flag_value:
        raise ValueError(f"Invalid {msg}: missing value for {repr(flag_name)}")
      assert flag_name
      flag_parts.append((flag_name, flag_value))

    # Trailing data that was not matched (or no match at all).
    if current_end != len(raw_flags):
      part = raw_flags[current_end:]
      raise ValueError(
          f"Invalid {msg} part at pos={current_end or 0}: {repr(part)}")
    return cls(flag_parts)

  def __init__(self, initial_data: FlagsData = None) -> None:
    """Creates the flags and populates them via update(initial_data)."""
    super().__init__(initial_data)

  def __setitem__(self, flag_name: str, flag_value: Optional[str]) -> None:
    return self.set(flag_name, flag_value)

  def set(self,
          flag_name: str,
          flag_value: Optional[str] = None,
          override: bool = False) -> None:
    """Sets a single flag.

    Args:
      flag_name: Name of the flag including the leading dash(es).
      flag_value: Optional value; None for value-less flags.
      override: If False, setting an existing flag to a different value
        raises a ValueError.

    Raises:
      FrozenFlagsError: If the flags are frozen.
      ValueError: On invalid flag names or conflicting values.
      TypeError: If a truthy flag_value is not a string.
    """
    self._set(flag_name, flag_value, override)

  def _set(self,
           flag_name: str,
           flag_value: Optional[str] = None,
           override: bool = False) -> None:
    """Validates the inputs and stores the flag; see set()."""
    self.assert_not_frozen()
    self._validate_flag_name(flag_name)
    if flag_value:
      self._validate_flag_value(flag_name, flag_value)
    self._validate_override(flag_name, flag_value, override)
    self.data[flag_name] = flag_value

  def _validate_flag_name(self, flag_name: str) -> None:
    """Raises ValueError if flag_name is not a valid flag name."""
    if not flag_name:
      raise ValueError("Cannot set empty flag")
    if self._WHITE_SPACE_RE.search(flag_name):
      raise ValueError(
          f"Flag name cannot contain whitespaces: {repr(flag_name)}")
    if "=" in flag_name:
      raise ValueError(
          f"Flag name contains '=': {repr(flag_name)}, please split")
    if flag_name[0] != "-":
      raise ValueError(
          f"Flag name must begin with a '-', but got {repr(flag_name)}")
    if not self._BASIC_FLAG_NAME_RE.fullmatch(flag_name):
      raise ValueError(
          f"Flag name contains invalid characters: {repr(flag_name)}")

  def _validate_flag_value(self, flag_name: str, flag_value: str) -> None:
    """Raises TypeError if the (non-empty) flag_value is not a string."""
    assert flag_value, "Got invalid empty flag_value."
    if not isinstance(flag_value, str):
      raise TypeError(
          f"Expected None or string flag-value for flag {flag_name}, "
          f"but got: {repr(flag_value)}")

  def _validate_override(self, flag_name: str, flag_value: Optional[str],
                         override: bool) -> None:
    """Raises ValueError if an existing flag would get a different value.

    Nothing is checked if override is True or if the flag is not set yet.
    """
    if override or flag_name not in self:
      return
    old_value = self[flag_name]
    if flag_value != old_value:
      raise ValueError(f"Flag {flag_name}={repr(flag_value)} was already set "
                       f"with a different previous value: {repr(old_value)}")

  # pylint: disable=arguments-differ
  def update(self,
             initial_data: FlagsData = None,
             override: bool = False) -> None:
    """Sets all flags from initial_data.

    Args:
      initial_data: None (no-op), a dict or flags instance, or an iterable of
        flag names and/or (name, value) tuples.
      override: Forwarded to set() for every flag.
    """
    # pylint: disable=arguments-differ
    if initial_data is None:
      return
    # Any BasicFlags instance (not only Flags) has to use items(): iterating
    # over flags yields "name=value" strings which are rejected as flag names.
    if isinstance(initial_data, (BasicFlags, dict)):
      for flag_name, flag_value in initial_data.items():
        self.set(flag_name, flag_value, override)
    else:
      for flag_name_or_items in initial_data:
        if isinstance(flag_name_or_items, str):
          self.set(flag_name_or_items, None, override)
        else:
          flag_name, flag_value = flag_name_or_items
          self.set(flag_name, flag_value, override)

  def merge(self, other: FlagsData) -> None:
    """Adds all flags from other; conflicting values raise a ValueError."""
    self.update(other)

  def copy(self: BasicFlagsT) -> BasicFlagsT:
    """Returns a new, unfrozen instance of the same class and flags."""
    return self.__class__(self)

  def merge_copy(self: BasicFlagsT, other: FlagsData) -> BasicFlagsT:
    """Returns a copy of these flags with other merged in."""
    ret = self.copy()
    ret.merge(other)
    return ret

  def _describe(self, flag_name: str) -> str:
    """Returns "name" for value-less (or unset) flags, else "name=value"."""
    value = self.get(flag_name)
    if value is None:
      return flag_name
    return f"{flag_name}={value}"

  def items(self) -> Iterable[Tuple[str, Optional[str]]]:
    """Returns the (name, value) pairs of the underlying dict."""
    return self.data.items()

  def to_dict(self) -> Dict[str, Optional[str]]:
    """Returns a plain dict copy mapping flag names to their values."""
    return dict(self.items())

  def __iter__(self) -> Iterator[str]:
    # Unlike a regular mapping this yields command-line ready strings.
    for k, v in self.items():
      if v is None:
        yield k
      else:
        yield f"{k}={v}"

  def __bool__(self) -> bool:
    return bool(self.data)

  def __repr__(self) -> str:
    dict_repr = repr(self.to_dict())
    return f"{type(self).__name__}({dict_repr})"

  def __str__(self) -> str:
    return " ".join(self)


class Flags(BasicFlags):
  """
  Subclass with slightly stricter flag name checking.
  Most command-line programs adhere to this.
  """
  _FLAG_NAME_RE = re.compile(r"(--?)[a-zA-Z0-9][a-zA-Z0-9_-]*")
  _PARSE_RE = re.compile(fr"(?P<name>{_FLAG_NAME_RE.pattern})"
                         fr"((?P<equal>=)({BasicFlags._VALUE_PATTERN})?)?"
                         fr"{BasicFlags._END_OR_SEPARATOR_PATTERN}")

  def _validate_flag_name(self, flag_name: str) -> None:
    """Applies the basic checks plus the stricter _FLAG_NAME_RE pattern."""
    super()._validate_flag_name(flag_name)
    if not self._FLAG_NAME_RE.fullmatch(flag_name):
      raise ValueError(
          f"Flag name contains invalid characters: {repr(flag_name)}")


FlagsT = TypeVar("FlagsT", bound=Flags)