• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5from __future__ import annotations
6
7import argparse
8import contextlib
9import dataclasses
10import functools
11import logging
12from typing import (TYPE_CHECKING, Any, Dict, Final, Iterator, List, Optional,
13                    Sequence, Set, TextIO, Tuple, Type, Union, cast)
14
15import hjson
16from immutabledict import immutabledict
17from ordered_set import OrderedSet
18
19import crossbench.browsers.all as browsers
20from crossbench import exception
21from crossbench import path as pth
22from crossbench import plt
23from crossbench.browsers.browser_helper import convert_flags_to_label
24from crossbench.browsers.chrome.downloader import ChromeDownloader
25from crossbench.browsers.firefox.downloader import FirefoxDownloader
26from crossbench.browsers.settings import Settings
27from crossbench.cli.config.browser import BrowserConfig
28from crossbench.cli.config.driver import BrowserDriverType
29from crossbench.cli.config.network import NetworkConfig
30from crossbench.config import ConfigError, ConfigObject
31from crossbench.flags.base import Flags
32from crossbench.flags.chrome import ChromeFlags
33from crossbench.network.base import Network
34from crossbench.parse import LateArgumentError, ObjectParser
35
if TYPE_CHECKING:
  # Only needed for static type annotations.
  from crossbench.browsers.browser import Browser

  # Maps a browser path/identifier to the Browser class and BrowserConfig
  # that should be used for it instead of the regular lookup.
  BrowserLookupTableT = Dict[str, Tuple[Type[Browser], "BrowserConfig"]]
  # An optional (flag name, optional flag value) pair.
  FlagGroupItemT = Optional[Tuple[str, Optional[str]]]
40
41
@contextlib.contextmanager
def late_argument_type_error_wrapper(flag: str) -> Iterator[None]:
  """Re-raises any Exception from the wrapped block as a LateArgumentError.

  The LateArgumentError is associated with `flag`, carries the message of the
  original exception and chains the original exception as its cause.
  Note: every Exception subclass is converted (not only ValueError and
  argparse.ArgumentTypeError), so e.g. file errors are also reported against
  `flag`.

  Args:
    flag: The command-line flag (e.g. "--browser") to attribute errors to.
  """
  try:
    yield
  # Intentionally broad: all parsing errors are reported against `flag`.
  except Exception as e:  # pylint: disable=broad-except
    raise LateArgumentError(flag, str(e)) from e
51
52
def _flags_to_label(flags: Flags) -> str:
  """Builds a label from all entries produced by iterating `flags`."""
  flag_entries = tuple(flags)
  return convert_flags_to_label(*flag_entries)
55
56
# Label used for flag variants that are not given an explicit name.
DEFAULT_LABEL: Final[str] = "default"

# A single (flag name, optional flag value) pair.
FlagItemT = Tuple[str, Optional[str]]
# Maps a flag name to a list of str values.
FlagVariantsDictT = Dict[str, List[str]]
61
62
@dataclasses.dataclass(frozen=True)
class FlagsVariantConfig:
  """A single labeled set of flags within a FlagsGroupConfig.

  Equality and hashing only consider `flags`: `label` and `index` are
  ignored, so two variants with the same flags compare equal. This is what
  the group parsers rely on to detect duplicate variants.
  """
  label: str
  # Position of this variant within its group.
  index: int = 0
  flags: Flags = dataclasses.field(default_factory=lambda: Flags().freeze())

  @classmethod
  def parse(cls, name: str, index: int, data: Any) -> FlagsVariantConfig:
    """Creates a variant labeled `name` from any input Flags.parse accepts.

    The parsed flags are frozen.
    """
    return cls(name, index, Flags.parse(data).freeze())

  def merge_copy(self,
                 other: FlagsVariantConfig,
                 label: Optional[str] = None,
                 index: int = -1) -> FlagsVariantConfig:
    """Returns a new variant with the flags of `self` merged with `other`.

    Args:
      other: Variant whose flags are merged into a copy of self's flags.
      label: Label for the new variant. If empty or None, defaults to
        "<self.label>_<other.label>".
      index: Index for the new variant. A negative value keeps self.index.
    """
    index = self.index if index < 0 else index
    new_label = label or f"{self.label}_{other.label}"
    return FlagsVariantConfig(new_label, index,
                              self.flags.merge_copy(other.flags).freeze())

  def __hash__(self) -> int:
    return hash(self.flags)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, FlagsVariantConfig):
      # Let Python try the reflected comparison instead of forcing False.
      return NotImplemented
    return self.flags == other.flags
89
90
try:
  FlagsGroupConfigTuple = tuple[FlagsVariantConfig, ...]
except TypeError:
  # Python 3.8 fallback: the builtin `tuple` is not subscriptable before 3.9
  # and raises TypeError. Other errors are real bugs and must propagate.
  FlagsGroupConfigTuple = tuple
96
97
class FlagsGroupConfig(FlagsGroupConfigTuple):
  """Immutable tuple of FlagsVariantConfig items that form one flag group.

  Example (the variants are passed as a single tuple, like for the builtin
  tuple constructor):

    FlagsGroupConfig((
        FlagsVariantConfig("default"),
        FlagsVariantConfig("max_opt_1", 1, <flags: --js-flags=--max-opt=1>),
        FlagsVariantConfig("max_opt_2", 2, <flags: --js-flags=--max-opt=2>),
    ))

  Groups are combined with product(), which creates one variant for every
  combination of the variants of the combined groups.
  """

  @classmethod
  def parse(cls, data: Any) -> FlagsGroupConfig:
    """Parses a flag group from None, a str, a dict, a list or a tuple.

    Raises:
      ConfigError: if `data` has any other type.
    """
    if data is None:
      return FlagsGroupConfig()
    if isinstance(data, str):
      return cls.parse_str(data)
    if isinstance(data, dict):
      return cls.parse_dict(data)
    if isinstance(data, (list, tuple)):
      return cls.parse_sequence(data)
    raise ConfigError(f"Invalid type {type(data)}: {repr(data)}")

  @classmethod
  def parse_dict(cls, config: Dict) -> FlagsGroupConfig:
    """Parses a dict in one of three shapes.

    - Some key does not start with "-": {label: flags-data, ...}, one
      labeled variant per entry (see parse_dict_with_labels).
    - All keys start with "-" and all values are str: a single default
      variant containing all flags (see parse_dict_simple).
    - All keys start with "-" and some value is not a str: every flag
      contributes its own variants, combined via product() (see
      _parse_variants_dict).
    An empty dict yields an empty group.
    """
    if not config:
      return FlagsGroupConfig()
    all_flag_keys = all(key.startswith("-") for key in config.keys())
    all_str_values = all(isinstance(value, str) for value in config.values())
    if not all_flag_keys:
      return cls.parse_dict_with_labels(config)
    if all_str_values:
      return cls.parse_dict_simple(config)
    return cls._parse_variants_dict(config)

  @classmethod
  def parse_dict_with_labels(cls, config: Dict) -> FlagsGroupConfig:
    """Creates one variant per {label: flags-data} entry.

    Raises:
      ConfigError: if two entries have equal flags.
    """
    variants: OrderedSet[FlagsVariantConfig] = OrderedSet()
    logging.debug("Using custom flag group labels")
    for label, value in config.items():
      with exception.annotate_argparsing(
          f"Parsing flag variant ...[{repr(label)}]:"):
        variant = FlagsVariantConfig.parse(label, len(variants), value)
        # Variants compare equal iff their flags are equal.
        if variant in variants:
          raise ConfigError(f"Duplicate flag variant: {value}")
        variants.add(variant)
    return FlagsGroupConfig(tuple(variants))

  @classmethod
  def parse_dict_simple(cls, config: Dict) -> FlagsGroupConfig:
    """Wraps a {"--flag": "value", ...} dict into a single default variant."""
    logging.debug("Using single flag group dict")
    variants = (FlagsVariantConfig.parse(DEFAULT_LABEL, 0, config),)
    return FlagsGroupConfig(variants)

  @classmethod
  def _parse_variants_dict(cls, data: Dict[str, Any]) -> FlagsGroupConfig:
    """Expands per-flag variant values into the product of all variants.

    Expects a non-empty dict (parse_dict handles the empty case), e.g.:
      {
        "--flag": None,
        "--flag-b": "custom flag value",
        "--flag-c": (None, "value 2", "value 3"),
      }
    """
    cls._validate_variants_dict(data)
    per_flag_groups: List[FlagsGroupConfig] = [
        cls._dict_variant_to_group(flag_name, flag_data)
        for flag_name, flag_data in data.items()
    ]
    variants = per_flag_groups[0]
    for next_variant in per_flag_groups[1:]:
      variants = variants.product(next_variant)
    return variants

  @classmethod
  def _validate_variants_dict(cls, data: Dict[str, Any]) -> None:
    """Checks the flag names and the types/uniqueness of their values.

    Raises:
      ConfigError: for values that are not None, str or a list/tuple, or
        for list/tuple values with duplicate entries.
    """
    flags = Flags()
    for flag_name, flag_value in data.items():
      with exception.annotate_argparsing(
          f"Parsing flag variant ...[{flag_name}]:"):
        # The collected flags are unused; presumably Flags.set() validates
        # the flag name.
        flags.set(flag_name)
        if flag_value is None:
          continue
        if not isinstance(flag_value, (str, list, tuple)):
          raise ConfigError(
              f"Invalid flag variant value (None, str or sequence): "
              f"{flag_name}={repr(flag_value)}")
        if isinstance(flag_value, (list, tuple)):
          ObjectParser.unique_sequence(
              flag_value, f"flag {repr(flag_name)} variant values", ConfigError)

  @classmethod
  def _dict_variant_to_group(cls, flag_name: str,
                             data: Any) -> FlagsGroupConfig:
    """Converts one `flag_name: data` entry into a FlagsGroupConfig.

    - None or a blank str: a single default variant with `flag_name` set.
    - A non-blank str: a single variant with `flag_name` set to it.
    - A list/tuple: one variant per entry, where None means "flag not set",
      a blank str means "flag set without a value" and any other str is
      used as the flag value.

    Raises:
      ConfigError: for non-str (and non-None) entries or duplicate entries.
    """
    if data is None:
      return cls.parse_str(flag_name)
    if isinstance(data, str):
      data_str: str = data.strip()
      if not data_str:
        return cls.parse_str(flag_name)
      data = (data_str,)
    assert isinstance(data, (list, tuple)), "Invalid flag variant type"
    flags: OrderedSet[Optional[Flags]] = OrderedSet()
    for variant in data:
      if variant is None:
        flag = None
      elif not isinstance(variant, str):
        # Previously crashed with an AttributeError on `.strip()`.
        raise ConfigError("Invalid flag variant value (None or str): "
                          f"{flag_name}={repr(variant)}")
      elif not variant.strip():
        flag = Flags((flag_name,))
      else:
        cls._validate_variant_flag(flag_name, variant)
        flag = Flags({flag_name: variant})
      if flag in flags:
        raise ConfigError("Same flag variant was specified more than once: "
                          f"{repr(flag)} for entry {repr(flag_name)}")
      flags.add(flag)
    return cls.parse_sequence(flags)

  @classmethod
  def _validate_variant_flag(cls, flag_name: str, flag_value: Any) -> None:
    """Rejects a python-style `None` used as a flag value.

    Raises:
      ConfigError: if `flag_value` is the str "None,".
    """
    # NOTE(review): presumably catches an unquoted `None,` in hjson input,
    # which is read as the literal string "None," — confirm.
    if flag_value == "None,":
      raise ConfigError("Please use null (from json) instead of "
                        f"None (from python) for flag {repr(flag_name)}")

  @classmethod
  def parse_sequence(cls, data: Sequence) -> FlagsGroupConfig:
    """Creates one variant per entry of `data`, labeled by its flags.

    Falsy entries (None, "", empty flags, ...) become empty flags; all other
    entries are parsed with Flags.parse.

    Raises:
      ConfigError: if two entries result in equal flags.
    """
    variants: List[FlagsVariantConfig] = []
    for flag_data in data:
      flags = Flags.parse(flag_data) if flag_data else Flags()
      variant = FlagsVariantConfig(_flags_to_label(flags), len(variants), flags)
      # Compare parsed variants (equality uses only the flags) instead of the
      # raw entries: raw entries may be unhashable (e.g. dicts), and distinct
      # raw entries such as None and "" can map to the same flags.
      if variant in variants:
        raise ConfigError(f"Duplicate variant: {flags}")
      variants.append(variant)
    return FlagsGroupConfig(tuple(variants))

  @classmethod
  def parse_str(cls, value: str) -> FlagsGroupConfig:
    """Parses a flag str into a single default variant.

    A blank str yields an empty group.
    """
    if not value.strip():
      return FlagsGroupConfig()
    variants = (FlagsVariantConfig.parse(DEFAULT_LABEL, 0, value),)
    return FlagsGroupConfig(variants)

  def product(self, *args: FlagsGroupConfig) -> FlagsGroupConfig:
    """Returns the inner_product of this group with all `args`, in order."""
    return functools.reduce(lambda a, b: a.inner_product(b), args, self)

  def inner_product(self, other: FlagsGroupConfig) -> FlagsGroupConfig:
    """Creates a new FlagsGroupConfig as the combination of
    self.variants x other.variants.

    If either group is empty, the other group is returned unchanged.
    """
    new_variants: List[FlagsVariantConfig] = []
    new_labels: Set[str] = set()
    if not other:
      return self
    if not self:
      return other
    for variant in self:
      for variant_other in other:
        new_label = self._unique_product_label(new_labels, variant,
                                               variant_other)
        new_labels.add(new_label)
        new_variant: FlagsVariantConfig = variant.merge_copy(
            variant_other, index=len(new_variants), label=new_label)
        new_variants.append(new_variant)

    return FlagsGroupConfig(tuple(new_variants))

  def _unique_product_label(self, label_set: Set[str],
                            variant_a: FlagsVariantConfig,
                            variant_b: FlagsVariantConfig) -> str:
    """Picks a label for the merge of `variant_a` and `variant_b`.

    Prefers the label of the side with flags when the other side has none,
    falls back to "<a>_<b>" (dropping DEFAULT_LABEL parts), and finally
    appends the current size of `label_set` to avoid collisions.
    """
    default = f"{variant_a.label}_{variant_b.label}"
    if variant_a.label == DEFAULT_LABEL:
      default = variant_b.label
    if variant_b.label == DEFAULT_LABEL:
      default = variant_a.label
    label = default
    if not variant_a.flags:
      label = variant_b.label
    if not variant_b.flags:
      label = variant_a.label
    if label not in label_set:
      return label
    if default not in label_set:
      return default
    return f"{default}_{len(label_set)}"
282
283
class FlagsConfig(ConfigObject, immutabledict[str, FlagsGroupConfig]):
  """Immutable mapping from flag-group names to their FlagsGroupConfig."""

  @classmethod
  def parse_str(cls, value: str) -> FlagsConfig:
    """Parses a flag str into a config with a single "default" group.

    Raises:
      ConfigError: if `value` is empty.
    """
    if value:
      return cls({"default": FlagsGroupConfig.parse_str(value)})
    raise ConfigError("Cannot parse empty string")

  @classmethod
  def parse_dict(cls, config: Dict[str, Any]) -> FlagsConfig:
    """Parses every entry of `config` into a FlagsGroupConfig, keyed by name."""
    parsed_groups: Dict[str, FlagsGroupConfig] = {}
    for name, raw_group in config.items():
      with exception.annotate(f"Parsing flag-group: flags[{repr(name)}]"):
        parsed_groups[name] = FlagsGroupConfig.parse(raw_group)
    return cls(parsed_groups)
299
300
class BrowserVariantsConfig:
  """Creates the Browser instances (variants) for a benchmark run.

  Variants come either from a --browser-config hjson file, which contains a
  required "browsers" dict and optional named "flags" groups, or from the
  --browser and related command-line arguments.
  """

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace) -> BrowserVariantsConfig:
    """Creates browser variants from parsed command-line arguments.

    --browser-config takes precedence over --browser. Errors raised while
    parsing are re-raised as LateArgumentError for the flag that was used.
    """
    browser_config = BrowserVariantsConfig()
    if args.browser_config:
      with late_argument_type_error_wrapper("--browser-config"):
        path = args.browser_config.expanduser()
        with path.open(encoding="utf-8") as f:
          browser_config.parse_text_io(f, args)
    else:
      with late_argument_type_error_wrapper("--browser"):
        browser_config.parse_args(args)
    return browser_config

  def __init__(self,
               raw_config_data: Optional[Dict[str, Any]] = None,
               browser_lookup_override: Optional[BrowserLookupTableT] = None,
               args: Optional[argparse.Namespace] = None):
    """
    Args:
      raw_config_data: Optional config dict, parsed right away via
        parse_dict(). Requires `args`.
      browser_lookup_override: Maps a browser path/identifier from the config
        to the (Browser class, BrowserConfig) pair to use for it.
      args: Parsed command-line arguments.
    """
    self.flags_config: FlagsConfig = FlagsConfig()
    self._variants: List[Browser] = []
    # All labels handed out so far; used to enforce unique labels.
    self._unique_names: Set[str] = set()
    self._browser_lookup_override = browser_lookup_override or {}
    if raw_config_data:
      assert args, "args object needed when loading from dict."
      self.parse_dict(raw_config_data, args)

  @property
  def variants(self) -> List[Browser]:
    """All created browser instances (asserted to be non-empty)."""
    assert self._variants
    return self._variants

  def parse_text_io(self, f: TextIO, args: argparse.Namespace) -> None:
    """Loads an hjson browser config from `f` and parses it."""
    with exception.annotate(f"Loading browser config file: {f.name}"):
      config = {}
      with exception.annotate("Parsing hjson"):
        config = hjson.load(f)
      with exception.annotate(f"Parsing config file: {f.name}"):
        self.parse_dict(config, args)

  def parse_dict(self, config: Dict[str, Any],
                 args: argparse.Namespace) -> None:
    """Parses a browser config dict.

    config["flags"] (optional) defines named flag groups and
    config["browsers"] (required, non-empty) defines the browsers.

    Raises:
      ConfigError: if "browsers" is missing or empty. Other errors are
        annotated via exception.annotate(throw_cls=ConfigError).
    """
    with exception.annotate(
        f"Parsing {type(self).__name__} dict", throw_cls=ConfigError):
      if "flags" in config:
        with exception.annotate("Parsing config['flags']"):
          self.flags_config = FlagsConfig.parse(config["flags"])
      if "browsers" not in config:
        raise ConfigError("Config does not provide a 'browsers' dict.")
      if not config["browsers"]:
        raise ConfigError("Config contains empty 'browsers' dict.")
      with exception.annotate("Parsing config['browsers']"):
        self._parse_browsers(config["browsers"], args)

  def parse_args(self, args: argparse.Namespace) -> None:
    """Creates browser variants from --browser and related arguments.

    Uses BrowserConfig.default() if no --browser was given.
    """
    browser_list: List[BrowserConfig] = args.browser or [
        BrowserConfig.default()
    ]
    assert isinstance(browser_list, list)
    browser_list = ObjectParser.unique_sequence(browser_list,
                                                "--browser arguments")
    for i, browser in enumerate(browser_list):
      with exception.annotate(f"Append browser {i}"):
        self._append_browser(args, browser)
    self._verify_browser_flags(args)
    self._ensure_unique_browser_names()

  def _parse_browsers(self, data: Dict[str, Any],
                      args: argparse.Namespace) -> None:
    """Parses every entry of config["browsers"]."""
    for name, browser_config in data.items():
      with exception.annotate(f"Parsing browsers[{repr(name)}]"):
        self._parse_browser(name, browser_config, args)
    self._ensure_unique_browser_names()

  def _parse_browser(self, name: str, raw_browser_data: Any,
                     args: argparse.Namespace) -> None:
    """Parses a single browsers[name] entry, which must be a str or dict.

    Raises:
      argparse.ArgumentTypeError: for any other entry type.
    """
    if isinstance(raw_browser_data, (dict, str)):
      return self._parse_browser_dict(name, raw_browser_data, args)
    raise argparse.ArgumentTypeError(
        f"Expected str or dict, got {type(raw_browser_data).__name__}: "
        f"{repr(raw_browser_data)}")

  def _parse_browser_dict(self, name: str,
                          raw_browser_data: Union[str, Dict[str, Any]],
                          args: argparse.Namespace) -> None:
    """Creates one browser instance per flag variant of browsers[name].

    `raw_browser_data` is either a path/identifier str or a dict that may
    contain "path", "flags" and "label" entries.

    Raises:
      ConfigError: if a local browser path does not exist or a label is not
        unique.
    """
    path_or_identifier: Optional[str] = None
    if isinstance(raw_browser_data, dict):
      path_or_identifier = raw_browser_data.get("path")
    else:
      path_or_identifier = raw_browser_data
    browser_cls: Type[Browser]
    if path_or_identifier and (path_or_identifier
                               in self._browser_lookup_override):
      browser_cls, browser_config = self._browser_lookup_override[
          path_or_identifier]
    else:
      browser_config = self._maybe_downloaded_binary(
          cast(BrowserConfig, BrowserConfig.parse(raw_browser_data)))
      browser_cls = self.get_browser_cls(browser_config)
    if not browser_config.driver.type.is_remote and (not pth.LocalPath(
        browser_config.path).exists()):
      raise ConfigError(
          f"browsers[{repr(name)}].path='{browser_config.path}' does not exist."
      )
    flag_variants: FlagsGroupConfig = self._get_browser_variants(
        name, raw_browser_data)
    self._log_browser_variants(name, flag_variants)
    browser_platform = self._get_browser_platform(browser_config)
    labels_lookup = self._create_unique_variant_labels(name, raw_browser_data,
                                                       flag_variants)
    for variant in flag_variants:
      label = labels_lookup[variant]
      browser_flags = browser_cls.default_flags(variant.flags)
      with exception.annotate_argparsing("Creating network config"):
        network_config = browser_config.network or args.network
        network = self._get_browser_network(network_config, browser_platform)
      # TODO: move the browser instantiation to a separate step and only
      # create BrowserConfig objects first.
      # pytype: disable=not-instantiable
      settings = Settings(
          flags=browser_flags,
          network=network,
          driver_path=args.driver_path or browser_config.driver.path,
          # TODO: support all args in the browser.config file
          viewport=args.viewport,
          splash_screen=args.splash_screen,
          platform=browser_platform,
          secrets=args.secrets.as_dict(),
          driver_logging=args.driver_logging,
          wipe_system_user_data=args.wipe_system_user_data,
          http_request_timeout=args.http_request_timeout)
      browser_instance = browser_cls(
          label=label, path=browser_config.path, settings=settings)
      # pytype: enable=not-instantiable
      self._variants.append(browser_instance)

  def _flags_to_label(self, name: str, flags: Flags) -> str:
    """Returns "<name>_<label derived from flags>"."""
    # Resolves to the module-level _flags_to_label helper.
    return f"{name}_{_flags_to_label(flags)}"

  def _create_unique_variant_labels(
      self, name: str, raw_browser_data: Union[str, Dict[str, Any]],
      flag_variants: FlagsGroupConfig) -> Dict[FlagsVariantConfig, str]:
    """Returns a unique browser label for each variant in `flag_variants`.

    A single variant is labeled with the custom "label" of a dict
    `raw_browser_data` (falling back to `name`). Multiple variants are
    labeled "<name>_<variant label>", or "<name>_<flags label>" if the
    variant labels are not unique.

    Raises:
      ConfigError: if a label was already used by another browser.
    """
    labels_lookup: Dict[FlagsVariantConfig, str] = {}
    base_label = name
    if isinstance(raw_browser_data, dict):
      base_label = raw_browser_data.get("label", name)
    group_labels = set(variant.label for variant in flag_variants)
    use_unique_variant_label = len(group_labels) == len(flag_variants)

    for variant in flag_variants:
      label = base_label
      if len(flag_variants) > 1:
        if use_unique_variant_label:
          label = f"{name}_{variant.label}"
        else:
          # TODO: This case might not happen anymore
          label = self._flags_to_label(name, variant.flags)
      if not self._check_unique_label(label):
        raise ConfigError(f"browsers[{repr(name)}] has non-unique label: "
                          f"{repr(label)}")
      labels_lookup[variant] = label
    return labels_lookup

  def _check_unique_label(self, label: str) -> bool:
    """Registers `label`; returns False if it was already registered."""
    if label in self._unique_names:
      return False
    self._unique_names.add(label)
    return True

  def _get_browser_variants(
      self, browser_name: str,
      raw_browser_data: Union[str, Dict[str, Any]]) -> FlagsGroupConfig:
    """Returns all flag variants for a browser entry.

    A str entry has a single empty default variant. For a dict entry, the
    default variant is combined with all groups from its "flags" entry.
    """
    default_variant = FlagsVariantConfig(DEFAULT_LABEL)
    flag_variants = FlagsGroupConfig((default_variant,))
    if not isinstance(raw_browser_data, dict):
      return flag_variants
    flag_groups: List[FlagsGroupConfig] = []
    with exception.annotate(f"Parsing browsers[{repr(browser_name)}].flags"):
      flag_groups = self._parse_browser_flags(browser_name, raw_browser_data)
    with exception.annotate(
        f"Expand browsers[{repr(browser_name)}].flags into full variants"):
      flag_variants = flag_variants.product(*flag_groups)
    return flag_variants

  def _parse_browser_flags(self, browser_name: str,
                           data: Dict[str, Any]) -> List[FlagsGroupConfig]:
    """Resolves data["flags"] into flag groups.

    Entries starting with "--" are inline flags, collected into one extra
    "inline" group; all other entries name a group from self.flags_config.

    Raises:
      ConfigError: for invalid entries or unknown group names.
    """
    flag_group_names = data.get("flags", [])
    if isinstance(flag_group_names, str):
      flag_group_names = [flag_group_names]
    self._validate_flags(browser_name, flag_group_names)
    inline_flags = Flags()
    flag_groups: List[FlagsGroupConfig] = []
    for flag_group_name in flag_group_names:
      if flag_group_name.startswith("--"):
        inline_flags.update(Flags.parse(flag_group_name))
      else:
        maybe_flag_group = self.flags_config.get(flag_group_name, None)
        if maybe_flag_group is None:
          raise ConfigError(
              f"group={repr(flag_group_name)} "
              f"for browser={repr(browser_name)} does not exist.\n"
              f"Choices are: {list(self.flags_config.keys())}")
        flag_groups.append(maybe_flag_group)
    if inline_flags:
      flag_data = {"inline": inline_flags}
      flag_groups.append(FlagsGroupConfig.parse_dict(flag_data))
    return flag_groups

  def _validate_flags(self, browser_name: str,
                      flag_group_names: List[str]) -> None:
    """Checks that 'flags' is a str or a list of unique str entries.

    Raises:
      ConfigError: if 'flags' is not a list, contains a non-str entry or
        contains the same entry twice.
    """
    if isinstance(flag_group_names, str):
      flag_group_names = [flag_group_names]
    if not isinstance(flag_group_names, list):
      raise ConfigError(
          f"'flags' is not a list for browser={repr(browser_name)}")
    seen_flag_group_names: Set[str] = set()
    for flag_group_name in flag_group_names:
      # Callers use str methods on the entries; reject other types early.
      if not isinstance(flag_group_name, str):
        raise ConfigError(
            "Expected a flag group name or inline flag str, got "
            f"{repr(flag_group_name)} for browser={repr(browser_name)}")
      if flag_group_name in seen_flag_group_names:
        raise ConfigError(f"Duplicate group name {repr(flag_group_name)} "
                          f"for browser={repr(browser_name)}")
      # Without recording the name the duplicate check above never fires.
      seen_flag_group_names.add(flag_group_name)

  def _log_browser_variants(self, name: str,
                            flag_variants: FlagsGroupConfig) -> None:
    """Logs the selected browser and the flags of each of its variants."""
    logging.info("SELECTED BROWSER: '%s' with %s flag variants:", name,
                 len(flag_variants))
    for i, variant in enumerate(flag_variants):
      logging.info("   %s: %s", i, variant.flags)

  def get_browser_cls(self, browser_config: BrowserConfig) -> Type[Browser]:
    """Returns the Browser class for the path and driver of browser_config.

    The browser family is detected from substrings of the lower-cased path,
    checked in this order: safari, chrome, chromium, firefox, edge.

    Raises:
      argparse.ArgumentTypeError: for unsupported paths or drivers.
    """
    path: pth.AnyPath = browser_config.path
    assert not isinstance(path, str), "Invalid path"
    if not BrowserConfig.is_supported_browser_path(path):
      raise argparse.ArgumentTypeError(f"Unsupported browser path='{path}'")
    path_str = str(browser_config.path).lower()
    if "safari" in path_str:
      return self._get_safari_browser_cls(browser_config)
    if "chrome" in path_str:
      return self._get_chrome_browser_cls(browser_config)
    if "chromium" in path_str:
      return self._get_chromium_browser_cls(browser_config)
    if "firefox" in path_str:
      return self._get_firefox_browser_cls(browser_config)
    if "edge" in path_str:
      return browsers.EdgeWebDriver
    raise argparse.ArgumentTypeError(f"Unsupported browser path='{path}'")

  def _get_firefox_browser_cls(self,
                               browser_config: BrowserConfig) -> Type[Browser]:
    """Returns the Firefox class for the configured driver."""
    driver = browser_config.driver.type
    if driver == BrowserDriverType.WEB_DRIVER:
      return browsers.FirefoxWebDriver
    raise argparse.ArgumentTypeError(f"Unsupported Firefox driver: {driver}")

  def _get_safari_browser_cls(self,
                              browser_config: BrowserConfig) -> Type[Browser]:
    """Returns the Safari class for the configured driver."""
    driver = browser_config.driver.type
    if driver == BrowserDriverType.IOS:
      return browsers.SafariWebdriverIOS
    if driver == BrowserDriverType.WEB_DRIVER:
      return browsers.SafariWebDriver
    if driver == BrowserDriverType.APPLE_SCRIPT:
      return browsers.SafariAppleScript
    raise argparse.ArgumentTypeError(f"Unsupported Safari driver: {driver}")

  def _get_chrome_browser_cls(self,
                              browser_config: BrowserConfig) -> Type[Browser]:
    """Returns the Chrome class for the configured driver."""
    driver = browser_config.driver.type
    if driver == BrowserDriverType.WEB_DRIVER:
      return browsers.ChromeWebDriver
    if driver == BrowserDriverType.APPLE_SCRIPT:
      return browsers.ChromeAppleScript
    if driver == BrowserDriverType.ANDROID:
      if browsers.LocalChromeWebDriverAndroid.is_apk_helper(
          browser_config.path):
        return browsers.LocalChromeWebDriverAndroid
      return browsers.ChromeWebDriverAndroid
    if driver == BrowserDriverType.LINUX_SSH:
      return browsers.ChromeWebDriverSsh
    if driver == BrowserDriverType.CHROMEOS_SSH:
      return browsers.ChromeWebDriverChromeOsSsh
    raise argparse.ArgumentTypeError(f"Unsupported Chrome driver: {driver}")

  def _get_chromium_browser_cls(self,
                                browser_config: BrowserConfig) -> Type[Browser]:
    """Returns the Chromium class for the configured driver."""
    driver = browser_config.driver.type
    if driver == BrowserDriverType.WEB_DRIVER:
      return browsers.ChromiumWebDriver
    if driver == BrowserDriverType.APPLE_SCRIPT:
      return browsers.ChromiumAppleScript
    if driver == BrowserDriverType.ANDROID:
      if browsers.LocalChromiumWebDriverAndroid.is_apk_helper(
          browser_config.path):
        return browsers.LocalChromiumWebDriverAndroid
      return browsers.ChromiumWebDriverAndroid
    if driver == BrowserDriverType.LINUX_SSH:
      return browsers.ChromiumWebDriverSsh
    if driver == BrowserDriverType.CHROMEOS_SSH:
      return browsers.ChromiumWebDriverChromeOsSsh
    raise argparse.ArgumentTypeError(f"Unsupported chromium driver: {driver}")

  def _get_browser_platform(self,
                            browser_config: BrowserConfig) -> plt.Platform:
    return browser_config.get_platform()

  def _ensure_unique_browser_names(self) -> None:
    """Makes browser unique_names unique.

    First expands them to "<type>_<version>_<label>", and as a last resort
    appends the browser's index.
    """
    if self._has_unique_variant_names():
      return
    # Expand to full version names
    for browser in self._variants:
      browser.unique_name = (
          f"{browser.type_name}_{browser.version}_{browser.label}")
    if self._has_unique_variant_names():
      return
    logging.info("Got non-unique browser names and versions, "
                 "please use --browser-config for more meaningful names")
    # Last resort, add index
    for index, browser in enumerate(self._variants):
      browser.unique_name += f"_{index}"
    assert self._has_unique_variant_names()

  def _has_unique_variant_names(self) -> bool:
    names = [browser.unique_name for browser in self._variants]
    unique_names = set(names)
    return len(unique_names) == len(names)

  def _extract_chrome_flags(self,
                            args: argparse.Namespace) -> List[ChromeFlags]:
    """Returns the chrome flag sets requested via command-line arguments.

    Produces one flag set per --js-flags entry (or a single set without
    --js-flags), each containing the feature/field-trial flags.
    """
    initial_flags = ChromeFlags()

    if args.enable_features:
      initial_flags["--enable-features"] = args.enable_features
    if args.disable_features:
      initial_flags["--disable-features"] = args.disable_features
    if args.enable_field_trial_config is True:
      initial_flags.set("--enable-field-trial-config")
    if args.enable_field_trial_config is False:
      initial_flags.set("--disable-field-trial-config")

    flags_sets = [initial_flags]
    if not args.js_flags:
      return flags_sets

    def copy_and_set_js_flags(flags: ChromeFlags,
                              js_flags_str: str) -> ChromeFlags:
      flags = flags.copy()
      for js_flag in js_flags_str.split(","):
        js_flag_name, js_flag_value = Flags.split(js_flag.lstrip())
        flags.js_flags.set(js_flag_name, js_flag_value)
      return flags

    flags_sets = [
        copy_and_set_js_flags(flags, js_flags_str)
        for flags in flags_sets
        for js_flags_str in args.js_flags
    ]
    return flags_sets

  def _verify_browser_flags(self, args: argparse.Namespace) -> None:
    """Rejects command-line flags that don't fit the selected browsers.

    Raises:
      argparse.ArgumentTypeError: if chrome-specific flags with a value are
        used for a non-chromium browser, or if --driver-path or extra
        browser args are used with multiple browser types.
    """
    for chrome_flags in self._extract_chrome_flags(args):
      for flag_name, value in chrome_flags.items():
        if not value:
          continue
        for browser in self._variants:
          if not browser.attributes.is_chromium_based:
            raise argparse.ArgumentTypeError(
                f"Used chrome/chromium-specific flags {flag_name} "
                f"for non-chrome {browser.unique_name}.\n"
                "Use --browser-config for complex variants.")
    browser_types = set(browser.type_name for browser in self._variants)
    if len(browser_types) == 1:
      return
    if args.driver_path:
      raise argparse.ArgumentTypeError(
          f"Cannot use custom --driver-path='{args.driver_path}' "
          f"for multiple browser {browser_types}.")
    if args.other_browser_args:
      raise argparse.ArgumentTypeError(
          f"Multiple browser types {browser_types} "
          "cannot be used with common extra browser flags: "
          f"{args.other_browser_args}.\n"
          "Use --browser-config for complex variants.")

  def _maybe_downloaded_binary(self,
                               browser_config: BrowserConfig) -> BrowserConfig:
    """Downloads the browser if browser_config names a version, not a path.

    Raises:
      ValueError: if neither the Chrome nor the Firefox downloader supports
        the identifier.
    """
    path_or_identifier = browser_config.browser
    if isinstance(path_or_identifier, pth.AnyPath):
      return browser_config
    browser_platform = self._get_browser_platform(browser_config)
    if ChromeDownloader.is_valid(path_or_identifier, browser_platform):
      downloaded = ChromeDownloader.load(path_or_identifier, browser_platform)
    elif FirefoxDownloader.is_valid(path_or_identifier, browser_platform):
      downloaded = FirefoxDownloader.load(path_or_identifier, browser_platform)
    else:
      raise ValueError(
          f"No version-download support for browser: {path_or_identifier}")
    # NOTE(review): only the downloaded path and the driver are carried over;
    # other BrowserConfig fields (e.g. `network`, which callers read via
    # `browser_config.network`) may be dropped here — confirm intended.
    return BrowserConfig(downloaded, browser_config.driver)

  def _append_browser(self, args: argparse.Namespace,
                      browser_config: BrowserConfig) -> None:
    """Creates the browser instances for a single --browser entry.

    Chromium-based browsers get one instance per flag set from
    _extract_chrome_flags; other browsers get a single instance.

    Raises:
      argparse.ArgumentTypeError: if a local browser binary does not exist.
    """
    assert browser_config, "Expected non-empty BrowserConfig."
    browser_config = self._maybe_downloaded_binary(browser_config)
    browser_cls: Type[Browser] = self.get_browser_cls(browser_config)
    path: pth.AnyPath = browser_config.path
    flags_sets = [browser_cls.default_flags()]

    if browser_config.driver.is_local and not pth.LocalPath(path).exists():
      raise argparse.ArgumentTypeError(f"Browser binary does not exist: {path}")

    if issubclass(browser_cls, browsers.Chromium):
      assert all(isinstance(flags, ChromeFlags) for flags in flags_sets)

      extra_flag_sets = self._extract_chrome_flags(args)
      flags_sets = [
          flags.merge_copy(extra_flags)
          for flags in flags_sets
          for extra_flags in extra_flag_sets
      ]

    for flag_str in args.other_browser_args:
      flag_name, flag_value = Flags.split(flag_str)
      for flags in flags_sets:
        flags.set(flag_name, flag_value)

    browser_platform = self._get_browser_platform(browser_config)
    with exception.annotate_argparsing("Creating network config"):
      network_config = browser_config.network or args.network
      network = self._get_browser_network(network_config, browser_platform)

    name = f"{browser_platform}_{len(self._unique_names)}"
    for flags in flags_sets:
      label = name
      if len(flags_sets) > 1:
        label = self._flags_to_label(label, flags)
      # Keep the registration side effect out of the assert: asserts are
      # stripped under `python -O`, which would skip _check_unique_label and
      # leave all labels unregistered.
      is_unique_label = self._check_unique_label(label)
      assert is_unique_label, f"Non-unique label: {label}"
      settings = Settings(
          flags=flags,
          network=network,
          driver_path=args.driver_path or browser_config.driver.path,
          viewport=args.viewport,
          splash_screen=args.splash_screen,
          platform=browser_platform,
          secrets=args.secrets.as_dict(),
          driver_logging=args.driver_logging,
          wipe_system_user_data=args.wipe_system_user_data,
          http_request_timeout=args.http_request_timeout)

      browser_instance = browser_cls(  # pytype: disable=not-instantiable # pylint: disable=abstract-class-instantiated
          label=label,
          path=path,
          settings=settings)
      logging.info("SELECTED BROWSER: name=%s path='%s' ",
                   browser_instance.unique_name, path)
      self._variants.append(browser_instance)

  def _get_browser_network(self, network_config: Union[pth.LocalPath,
                                                       NetworkConfig],
                           browser_platform: plt.Platform) -> Network:
    """Creates the Network for `browser_platform`.

    A non-NetworkConfig `network_config` is parsed first via
    NetworkConfig.parse.
    """
    if not isinstance(network_config, NetworkConfig):
      network_config = NetworkConfig.parse(network_config)
    return network_config.create(browser_platform)
756