#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Trio instrumentation that logs scheduler activity, plus pytest wiring to enable it."""

from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from typing import Any, Callable

import trio
from pytest import Config, Parser
from typing_extensions import override

from .ark_config import get_config, str2bool
from .logs import TRIO, logger

LOG = logger(__name__)


@contextmanager
def task_instrumentation(**kwargs: bool) -> Iterator["TrioTracer"]:
    """Attach a ``TrioTracer`` to the current trio run for the duration of a ``with`` block.

    Must be entered from inside a running trio loop, because
    ``trio.lowlevel.add_instrument`` operates on the current run.

    Args:
        **kwargs: forwarded unchanged to ``TrioTracer`` (``show_io``, ``show_run``).

    Yields:
        The installed tracer instance.
    """
    # Construct and install the tracer *before* entering try/finally. If either step
    # fails there is nothing to remove. Doing them inside ``try`` would let the
    # ``finally`` raise UnboundLocalError/KeyError and mask the original error.
    tracer = TrioTracer(**kwargs)
    trio.lowlevel.add_instrument(tracer)
    try:
        yield tracer
    finally:
        # trio deactivates an instrument whose hook raised. Removing it again raises
        # KeyError, which would otherwise hide any exception already propagating.
        with suppress(KeyError):
            trio.lowlevel.remove_instrument(tracer)


class TrioTracer(trio.abc.Instrument):
    """Trio instrument that logs run, task and I/O-wait events at the ``TRIO`` log level.

    Args:
        show_io: when false, the ``before_io_wait``/``after_io_wait`` hooks log nothing.
        show_run: stored on the instance. NOTE(review): no hook in this class reads it.
    """

    def __init__(self, show_io: bool = True, show_run: bool = True) -> None:
        super().__init__()
        self.show_io = show_io
        self.show_run = show_run
        # Set in before_io_wait and read in after_io_wait to compute the wait
        # duration. Do not call trio.current_time() here: it raises RuntimeError
        # outside a running trio loop, and choose_run() constructs a tracer
        # before trio.run() starts.
        self._sleep_time: float = 0.0

    @override
    def before_run(self) -> None:
        self._log("!!! STARTED")

    @override
    def after_run(self) -> None:
        self._log("!!! FINISHED")

    @override
    def task_spawned(self, task: trio.lowlevel.Task) -> None:
        self._task_log("+++ SPAWNED ", task)

    @override
    def task_scheduled(self, task: trio.lowlevel.Task) -> None:
        self._task_log(" SCHEDULED ", task)

    @override
    def before_task_step(self, task: trio.lowlevel.Task) -> None:
        self._task_log("* { BEFORE STEP ", task)

    @override
    def after_task_step(self, task: trio.lowlevel.Task) -> None:
        self._task_log(" } AFTER STEP ", task)

    @override
    def task_exited(self, task: trio.lowlevel.Task) -> None:
        self._task_log("--- EXITED ", task)

    @override
    def before_io_wait(self, timeout: float) -> None:
        if not self.show_io:
            return
        # A zero timeout means trio only polls for ready I/O without blocking.
        if timeout:
            self._log(f". waiting for I/O for up to {timeout} seconds")
        else:
            self._log(". doing a quick check for I/O")
        self._sleep_time = trio.current_time()

    @override
    def after_io_wait(self, timeout: float) -> None:
        if not self.show_io:
            return
        duration = trio.current_time() - self._sleep_time
        self._log(f". finished I/O check (took {duration} seconds)")

    def _log(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``msg`` on the module logger at the ``TRIO`` level (lazy %-formatting)."""
        LOG.log(TRIO, msg, *args, **kwargs)

    def _task_log(self, msg: str, task: trio.lowlevel.Task) -> None:
        """Log ``msg`` followed by ``repr(task)``."""
        # Pass msg as an argument instead of splicing it into the format string,
        # so a '%' in msg can never break logging's %-formatting.
        self._log("%s: %r", msg, task)


def pytest_addoption(parser: Parser) -> None:
    """Register the ``task_trace`` ini option and the ``--task-trace``/``-T`` flag."""
    group = parser.getgroup("arkdb")

    parser.addini(
        name="task_trace",
        default=False,
        type="bool",
        help="Default value for --task-trace",
    )
    group.addoption(
        "--task-trace",
        "-T",
        dest="task_trace",
        help="Enable trio task traces.",
        type=str2bool,
    )


def choose_run(config: Config) -> Callable[..., Any]:
    """Build a drop-in replacement for ``trio.run`` that honours ``task_trace``.

    Args:
        config: pytest config; ``task_trace`` is read from it via ``get_config``
            each time the returned callable is invoked.

    Returns:
        A callable taking the same arguments as ``trio.run`` and returning its
        result. When ``task_trace`` is enabled and ``instruments`` holds no
        ``TrioTracer`` yet, a fresh one is appended. ``strict_exception_groups``
        defaults to ``False`` here.
    """

    def run(
        *args: Any,
        instruments: Sequence[trio.abc.Instrument] = (),
        strict_exception_groups: bool = False,
        **kwargs: Any,
    ) -> Any:
        if get_config(config, "task_trace") and not any(
            isinstance(instrument, TrioTracer) for instrument in instruments
        ):
            instruments = [*instruments, TrioTracer()]
        return trio.run(
            *args,
            instruments=instruments,
            strict_exception_groups=strict_exception_groups,
            **kwargs,
        )

    return run