1# Copyright 2023 The Chromium Authors 2# Use of this source code is governed by a BSD-style license that can be 3# found in the LICENSE file. 4 5from __future__ import annotations 6 7import abc 8import datetime as dt 9import logging 10import threading 11import time 12from typing import TYPE_CHECKING, Iterable 13 14from crossbench.parse import DurationParser, ObjectParser 15from crossbench.probes.probe import Probe, ProbeConfigParser, ProbeKeyT 16from crossbench.probes.probe_context import ProbeContext 17from crossbench.probes.results import LocalProbeResult, ProbeResult 18 19if TYPE_CHECKING: 20 from crossbench import plt 21 from crossbench.env import HostEnvironment 22 from crossbench.path import LocalPath 23 from crossbench.plt.base import CmdArg, TupleCmdArgs 24 from crossbench.runner.run import Run 25 26class PollingProbe(Probe, metaclass=abc.ABCMeta): 27 """ 28 Abstract probe to periodically collect the results of any bash cmd. 29 """ 30 NAME = "polling" 31 IS_GENERAL_PURPOSE = False 32 33 @classmethod 34 def config_parser(cls) -> ProbeConfigParser: 35 parser = super().config_parser() 36 parser.add_argument( 37 "interval", 38 type=DurationParser.positive_duration, 39 default=dt.timedelta(seconds=1), 40 help="Run the cmd at this interval and produce separate results.") 41 return parser 42 43 def __init__( 44 self, 45 cmd: Iterable[CmdArg], 46 interval: dt.timedelta = dt.timedelta(seconds=1) 47 ) -> None: 48 super().__init__() 49 self._cmd: TupleCmdArgs = tuple(cmd) 50 self._interval = interval 51 if interval.total_seconds() < 0.1: 52 raise ValueError(f"Polling interval must be >= 0.1s, but got: {interval}") 53 54 @property 55 def key(self) -> ProbeKeyT: 56 return super().key + (("cmd", tuple(self.cmd)), 57 ("interval", self.interval.total_seconds())) 58 59 @property 60 def interval(self) -> dt.timedelta: 61 return self._interval 62 63 @property 64 def cmd(self) -> TupleCmdArgs: 65 return self._cmd 66 67 def validate_env(self, env: HostEnvironment) -> None: 68 super().validate_env(env) 69 if env.repetitions != 1: 70 env.handle_warning(f"Probe={self.NAME} cannot merge data over multiple " 71 f"repetitions={env.repetitions}.") 72 73 def get_context(self, run: Run) -> PollingProbeContext: 74 return PollingProbeContext(self, run) 75 76 77class ShellPollingProbe(PollingProbe): 78 """ 79 General-purpose probe to periodically collect the stdout of a given bash cmd. 80 """ 81 82 IS_GENERAL_PURPOSE = True 83 NAME = "poll" 84 85 @classmethod 86 def config_parser(cls) -> ProbeConfigParser: 87 parser = super().config_parser() 88 parser.add_argument( 89 "cmd", 90 type=ObjectParser.sh_cmd, 91 required=True, 92 help="Write stdout of this CMD as a result.") 93 return parser 94 95 96class PollingProbeContext(ProbeContext[PollingProbe]): 97 _poller: CMDPoller 98 99 def __init__(self, probe: PollingProbe, run: Run) -> None: 100 super().__init__(probe, run) 101 self._poller = CMDPoller(self.browser_platform, self.probe.cmd, 102 self.probe.interval, self.local_result_path) 103 104 def setup(self) -> None: 105 self.local_result_path.mkdir() 106 107 def start(self) -> None: 108 self._poller.start() 109 110 def stop(self) -> None: 111 self._poller.stop() 112 113 def teardown(self) -> ProbeResult: 114 return LocalProbeResult(file=(self.local_result_path,)) 115 116 117class CMDPoller(threading.Thread): 118 119 def __init__(self, platform: plt.Platform, cmd: Iterable[CmdArg], 120 interval: dt.timedelta, path: LocalPath): 121 super().__init__() 122 self._platform = platform 123 self._cmd: TupleCmdArgs = tuple(cmd) 124 self._path: LocalPath = path 125 if interval < dt.timedelta(seconds=0.1): 126 raise ValueError("Poller interval should be >= 0.1s for accuracy, " 127 f"but got {interval}s") 128 self._interval_seconds = interval.total_seconds() 129 self._event = threading.Event() 130 131 def stop(self) -> None: 132 self._event.set() 133 self.join() 134 135 def run(self) -> None: 136 start_time = time.monotonic_ns() 137 while not self._event.is_set(): 138 poll_start = dt.datetime.now() 139 140 data = self._platform.sh_stdout(*self._cmd) 141 datetime_str = poll_start.strftime("%Y-%m-%d_%H%M%S_%f") 142 out_file = self._path / f"{datetime_str}.txt" 143 with out_file.open("w", encoding="utf-8") as f: 144 f.write(data) 145 146 poll_end = dt.datetime.now() 147 diff = (poll_end - poll_start).total_seconds() 148 if diff > self._interval_seconds: 149 logging.warning("Poller command took longer than expected %fs: %s", 150 self._interval_seconds, self._cmd) 151 152 # Calculate wait_time against fixed start time to avoid drifting. 153 total_time = (time.monotonic_ns() - start_time) / 10.0**9 154 wait_time = self._interval_seconds - (total_time % self._interval_seconds) 155 self._event.wait(wait_time) 156