• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5from __future__ import annotations
6
7import abc
8import datetime as dt
9import logging
10import threading
11import time
12from typing import TYPE_CHECKING, Iterable
13
14from crossbench.parse import DurationParser, ObjectParser
15from crossbench.probes.probe import Probe, ProbeConfigParser, ProbeKeyT
16from crossbench.probes.probe_context import ProbeContext
17from crossbench.probes.results import LocalProbeResult, ProbeResult
18
19if TYPE_CHECKING:
20  from crossbench import plt
21  from crossbench.env import HostEnvironment
22  from crossbench.path import LocalPath
23  from crossbench.plt.base import CmdArg, TupleCmdArgs
24  from crossbench.runner.run import Run
25
class PollingProbe(Probe, metaclass=abc.ABCMeta):
  """Abstract base for probes that repeatedly run a shell command.

  The command is executed every `interval` while a run is active (see
  get_context), and every execution yields its own separate result.
  """
  NAME = "polling"
  IS_GENERAL_PURPOSE = False

  @classmethod
  def config_parser(cls) -> ProbeConfigParser:
    config_parser = super().config_parser()
    config_parser.add_argument(
        "interval",
        type=DurationParser.positive_duration,
        default=dt.timedelta(seconds=1),
        help="Run the cmd at this interval and produce separate results.")
    return config_parser

  def __init__(self,
               cmd: Iterable[CmdArg],
               interval: dt.timedelta = dt.timedelta(seconds=1)) -> None:
    """Store the command and polling interval.

    Args:
      cmd: Command and arguments to execute on every poll.
      interval: Time between polls; must be at least 0.1s.

    Raises:
      ValueError: if `interval` is shorter than 0.1s.
    """
    super().__init__()
    self._cmd: TupleCmdArgs = tuple(cmd)
    self._interval = interval
    too_short = interval.total_seconds() < 0.1
    if too_short:
      raise ValueError(f"Polling interval must be >= 0.1s, but got: {interval}")

  @property
  def key(self) -> ProbeKeyT:
    # Extend the base key so probes with different cmd/interval differ.
    base_key = super().key
    polling_key = (
        ("cmd", tuple(self.cmd)),
        ("interval", self.interval.total_seconds()),
    )
    return base_key + polling_key

  @property
  def interval(self) -> dt.timedelta:
    return self._interval

  @property
  def cmd(self) -> TupleCmdArgs:
    return self._cmd

  def validate_env(self, env: HostEnvironment) -> None:
    super().validate_env(env)
    repetitions = env.repetitions
    if repetitions == 1:
      return
    # Results from several repetitions are not merged by this probe.
    env.handle_warning(f"Probe={self.NAME} cannot merge data over multiple "
                       f"repetitions={repetitions}.")

  def get_context(self, run: Run) -> PollingProbeContext:
    return PollingProbeContext(self, run)
75
76
class ShellPollingProbe(PollingProbe):
  """General-purpose polling probe for a user-supplied shell command.

  On top of the inherited `interval` option, it requires a `cmd` option whose
  stdout is recorded on every poll.
  """

  NAME = "poll"
  IS_GENERAL_PURPOSE = True

  @classmethod
  def config_parser(cls) -> ProbeConfigParser:
    config_parser = super().config_parser()
    config_parser.add_argument(
        "cmd",
        type=ObjectParser.sh_cmd,
        required=True,
        help="Write stdout of this CMD as a result.")
    return config_parser
94
95
class PollingProbeContext(ProbeContext[PollingProbe]):
  """Per-run context that drives a CMDPoller for a PollingProbe.

  The poller writes one file per poll into this context's local_result_path
  directory, and that directory is returned as the probe result.
  """
  _poller: CMDPoller

  def __init__(self, probe: PollingProbe, run: Run) -> None:
    super().__init__(probe, run)
    self._poller = CMDPoller(
        platform=self.browser_platform,
        cmd=self.probe.cmd,
        interval=self.probe.interval,
        path=self.local_result_path)

  def setup(self) -> None:
    # Directory that collects the per-poll output files.
    result_dir = self.local_result_path
    result_dir.mkdir()

  def start(self) -> None:
    self._poller.start()

  def stop(self) -> None:
    self._poller.stop()

  def teardown(self) -> ProbeResult:
    result_dir = self.local_result_path
    return LocalProbeResult(file=(result_dir,))
115
116
class CMDPoller(threading.Thread):
  """Background thread that periodically runs a command and stores its stdout.

  Each poll executes the command via `platform.sh_stdout()` and writes the
  output to a new file `<path>/<timestamp>.txt`. The timestamp is the local
  wall-clock time at which that poll started. The first poll happens right
  after the thread starts, and polling continues until `stop()` is called.
  """

  def __init__(self, platform: plt.Platform, cmd: Iterable[CmdArg],
               interval: dt.timedelta, path: LocalPath) -> None:
    """Configure the poller; polling begins only once start() is called.

    Args:
      platform: Platform used to execute the command.
      cmd: Command and arguments to run on every poll.
      interval: Target time between poll starts; must be >= 0.1s.
      path: Existing directory that receives one output file per poll.

    Raises:
      ValueError: if `interval` is shorter than 0.1s.
    """
    super().__init__()
    self._platform = platform
    self._cmd: TupleCmdArgs = tuple(cmd)
    self._path: LocalPath = path
    if interval < dt.timedelta(seconds=0.1):
      # Report seconds explicitly: str(timedelta) renders as "0:00:00.050000",
      # which reads wrongly when followed by an "s" suffix.
      raise ValueError("Poller interval should be >= 0.1s for accuracy, "
                       f"but got {interval.total_seconds()}s")
    self._interval_seconds = interval.total_seconds()
    self._event = threading.Event()

  def stop(self) -> None:
    """Signal the polling loop to exit and wait for the thread to finish.

    Calling this on a poller that was never started is a no-op apart from
    setting the stop flag.
    """
    self._event.set()
    # Thread.join() raises RuntimeError for a thread that was never started,
    # e.g. when probe startup failed before start() was reached.
    if self.ident is not None:
      self.join()

  def run(self) -> None:
    start_ns = time.monotonic_ns()
    while not self._event.is_set():
      self._poll_once()
      # Schedule the next poll on a fixed grid anchored at start_ns so that
      # per-poll overhead does not accumulate into drift. If a poll overran,
      # this skips ahead to the next grid point.
      elapsed_s = (time.monotonic_ns() - start_ns) / 1e9
      wait_time = self._interval_seconds - (elapsed_s % self._interval_seconds)
      self._event.wait(wait_time)

  def _poll_once(self) -> None:
    """Run the command once and write its stdout to a timestamped file."""
    # Wall-clock time only names the output file; the poll duration is
    # measured with the monotonic clock so system clock changes can't skew it.
    poll_start = dt.datetime.now()
    poll_start_ns = time.monotonic_ns()

    data = self._platform.sh_stdout(*self._cmd)
    datetime_str = poll_start.strftime("%Y-%m-%d_%H%M%S_%f")
    out_file = self._path / f"{datetime_str}.txt"
    with out_file.open("w", encoding="utf-8") as f:
      f.write(data)

    duration_s = (time.monotonic_ns() - poll_start_ns) / 1e9
    if duration_s > self._interval_seconds:
      logging.warning(
          "Poller command took %.3fs, longer than the %.3fs interval: %s",
          duration_s, self._interval_seconds, self._cmd)
156