• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18from collections.abc import Sequence
19from contextlib import contextmanager
20from typing import Callable
21
22import trio
23from pytest import Config, Parser
24from typing_extensions import override
25
26from .ark_config import get_config, str2bool
27from .logs import TRIO, logger
28
29LOG = logger(__name__)
30
31
@contextmanager
def task_instrumentation(**kwargs):
    """Install a ``TrioTracer`` for the duration of the ``with`` block.

    All keyword arguments are forwarded unchanged to ``TrioTracer``. The
    tracer is registered with ``trio.lowlevel.add_instrument`` on entry and
    unregistered with ``trio.lowlevel.remove_instrument`` on exit, whether the
    body finishes normally or raises.

    Yields:
        None. The tracer instance itself is not exposed to the caller.
    """
    # Create and register the tracer *before* entering the ``try``. If either
    # step fails there is nothing to clean up, and the original error must
    # propagate as-is rather than being masked by an ``UnboundLocalError`` (or
    # by a failed removal of an instrument that was never added) raised from
    # the ``finally`` clause.
    tracer = TrioTracer(**kwargs)
    trio.lowlevel.add_instrument(tracer)
    try:
        yield
    finally:
        trio.lowlevel.remove_instrument(tracer)
40
41
class TrioTracer(trio.abc.Instrument):
    """Trio instrument that writes scheduler events to the module logger.

    Every hook emits one record through ``LOG.log(TRIO, ...)``. Task hooks
    include the ``repr`` of the task concerned; the I/O hooks can be silenced
    with ``show_io=False``.
    """

    def __init__(self, show_io: bool = True, show_run: bool = True) -> None:
        """Create a tracer.

        Args:
            show_io: When false, ``before_io_wait`` and ``after_io_wait`` log
                nothing.
            show_run: Stored on the instance as ``self.show_run``.
        """
        super().__init__()
        self.show_io = show_io
        # NOTE(review): ``show_run`` is stored but none of the hooks in this
        # class consult it -- confirm what it is meant to gate.
        self.show_run = show_run
        # Timestamp (trio clock) of the most recent ``before_io_wait``.
        # It is deliberately NOT initialised with ``trio.current_time()``:
        # that call raises ``RuntimeError`` when no trio run is active, which
        # would make it impossible to build a tracer ahead of ``trio.run`` in
        # order to pass it via ``instruments=``.
        self._sleep_time = None

    @override
    def before_run(self):
        """Log that the run has started."""
        self._log("!!! STARTED")

    @override
    def after_run(self):
        """Log that the run has finished."""
        self._log("!!! FINISHED")

    @override
    def task_spawned(self, task: trio.lowlevel.Task):
        """Log that ``task`` was spawned."""
        self._task_log("+++ SPAWNED     ", task)

    @override
    def task_scheduled(self, task: trio.lowlevel.Task):
        """Log that ``task`` was scheduled."""
        self._task_log("    SCHEDULED   ", task)

    @override
    def before_task_step(self, task: trio.lowlevel.Task):
        """Log that a step of ``task`` is about to run."""
        self._task_log("* { BEFORE STEP ", task)

    @override
    def after_task_step(self, task: trio.lowlevel.Task):
        """Log that a step of ``task`` has just finished."""
        self._task_log("  } AFTER STEP  ", task)

    @override
    def task_exited(self, task: trio.lowlevel.Task):
        """Log that ``task`` has exited."""
        self._task_log("--- EXITED      ", task)

    @override
    def before_io_wait(self, timeout):
        """Log the upcoming I/O wait and remember when it started.

        A falsy ``timeout`` (for example ``0``) is reported as a quick check;
        anything else is reported together with the timeout value. Nothing is
        logged or recorded when ``show_io`` is false.
        """
        if not self.show_io:
            return
        if timeout:
            self._log(f". waiting for I/O for up to {timeout} seconds")
        else:
            self._log(". doing a quick check for I/O")
        self._sleep_time = trio.current_time()

    @override
    def after_io_wait(self, timeout):
        """Log how long the I/O wait that just completed took.

        Nothing is logged when ``show_io`` is false.
        """
        if not self.show_io:
            return
        if self._sleep_time is None:
            # No start time was recorded (e.g. ``show_io`` was enabled between
            # the paired hooks); report a zero duration instead of failing.
            duration = 0.0
        else:
            duration = trio.current_time() - self._sleep_time
        self._log(f". finished I/O check (took {duration} seconds)")

    def _log(self, msg: str, *args, **kwargs):
        """Forward ``msg`` and any extra arguments to ``LOG.log`` at ``TRIO``."""
        LOG.log(TRIO, msg, *args, **kwargs)

    def _task_log(self, msg: str, task: trio.lowlevel.Task):
        """Log ``msg`` followed by ``": "`` and the ``repr`` of ``task``."""
        # ``%r`` is left in the format string and ``task`` is passed as a log
        # argument, so the repr is produced by the logging call itself.
        self._log(f"{msg}: %r", task)
100
101
def pytest_addoption(parser: Parser) -> None:
    """Register the trio task-trace setting with pytest.

    Two entries are declared, both using the name ``task_trace``:

    * an ini setting of type ``bool`` that defaults to ``False``;
    * a ``--task-trace`` / ``-T`` command-line option in the ``arkdb`` option
      group, whose value is converted with ``str2bool``.
    """
    setting = "task_trace"
    arkdb_options = parser.getgroup("arkdb")

    # Ini-file entry.
    parser.addini(
        setting,
        help="Default value for --task-trace",
        type="bool",
        default=False,
    )

    # Command-line entry.
    arkdb_options.addoption(
        "--task-trace",
        "-T",
        type=str2bool,
        dest=setting,
        help="Enable trio task traces.",
    )
119
120
def choose_run(config: Config) -> Callable[..., None]:
    """Return a ``trio.run`` wrapper bound to the given pytest ``config``.

    The returned callable forwards all of its arguments to ``trio.run``. When
    ``get_config(config, "task_trace")`` is truthy and none of the supplied
    instruments is a ``TrioTracer``, a new ``TrioTracer()`` is appended to the
    instruments first. The result of ``trio.run`` is discarded.
    """

    def run(
        *args,
        instruments: Sequence[trio.abc.Instrument] = (),
        strict_exception_groups: bool = False,
        **kwargs,
    ) -> None:
        if get_config(config, "task_trace"):
            # Collect the tracers the caller already supplied, if any.
            present = [i for i in instruments if isinstance(i, TrioTracer)]
            if not present:
                with_tracer = list(instruments)
                with_tracer.append(TrioTracer())
                instruments = with_tracer
        trio.run(
            *args,
            instruments=instruments,
            strict_exception_groups=strict_exception_groups,
            **kwargs,
        )

    return run
140