• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import copy
19import logging
20import re
21import sys
22import warnings
23from collections.abc import Iterator
24from contextlib import contextmanager
25from traceback import StackSummary, extract_tb
26from types import TracebackType
27from typing import Optional, Tuple, Type
28
29import pluggy
30from _pytest.nodes import Item
31from pytest import FixtureRequest, Function, StashKey, fixture, hookimpl, mark
32from rich.traceback import Traceback
33from typing_extensions import TypeVar
34
35from .logs import logger
36from .rich_logging import remove_ansi_escape_sequences
37
# Module-level logger from the project's ``logs`` helper; ``Expect`` uses it to
# report soft-assertion failures.
LOG = logger(__name__)
39
40
class ExpectWarning(UserWarning):
    """Warning emitted for an assertion that failed inside ``Expect.warning()``."""
43
44
class ExpectErrorWarning(UserWarning):
    """Warning emitted for an assertion that failed inside ``Expect.error()``."""
47
48
class ExpectError(Exception):
    """Exception used by ``pytest_runtest_call`` to fail a test whose ``Expect`` context has ``fail`` set."""
51
52
def tb_next(tb, step: int = 1) -> Optional[TracebackType]:
    """
    Follow the ``tb_next`` chain of ``tb`` for ``step`` links.

    Stops early and returns ``None`` once the chain runs out; a ``step`` of
    zero or less returns ``tb`` itself.
    """
    node = tb
    for _ in range(step):
        if node is None:
            break
        node = node.tb_next
    return node
59
60
def _extract_traceback(level: int = 4) -> Optional[Tuple[StackSummary, Traceback]]:
    """
    Capture the exception currently being handled for soft-assertion reporting.

    :param level: Currently unused; kept so existing callers keep working.
        NOTE(review): the number of skipped traceback entries is fixed at 2
        below — confirm whether ``level`` was meant to control it.
    :return: ``None`` when no exception is being handled; otherwise a pair of
        the extracted stack summary and a rich ``Traceback`` of a copy of the
        exception whose string arguments have ANSI escape sequences removed.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    if exc_type is None or exc_value is None or exc_tb is None:
        return None

    # Work on a copy so the original exception object keeps its args.
    cleaned = copy.copy(exc_value)
    # Exception args are not guaranteed to be strings (e.g. ``AssertionError(42)``);
    # only strip escape sequences from the ones that are.
    cleaned.args = tuple(
        remove_ansi_escape_sequences(arg) if isinstance(arg, str) else arg
        for arg in cleaned.args
    )

    # Skip the first two traceback entries — presumably the ``Expect._check`` and
    # ``warning()``/``error()`` generator frames — so reporting starts at the caller.
    tb = tb_next(exc_tb, 2)
    return (
        extract_tb(tb),
        Traceback.from_exception(
            exc_type=exc_type,  # type: ignore[arg-type]
            exc_value=cleaned,
            traceback=tb,
            suppress=(__file__,),
        ),
    )
81
82
# Type variable bounded to exception types.
# NOTE(review): not referenced in the visible part of this module — confirm it is still needed.
E = TypeVar("E", bound=BaseException)
84
85
class Expect:
    """
    Soft-assertion helper handed to tests by the ``expect`` fixture.

    Inside ``with expect.warning():`` or ``with expect.error():`` a failing
    ``assert`` is logged and re-emitted as a Python warning instead of aborting
    the test, so execution continues after the block.  Any exception passing
    through ``error()`` sets :attr:`fail`, which the ``pytest_runtest_call``
    hook in this module checks after the test body runs.
    """

    # Set to True (on the instance) by ``error()``; read by ``pytest_runtest_call``.
    fail = False

    @contextmanager
    def _check(self, level: int, warning_type: Type[Warning]) -> Iterator[None]:
        """
        Swallow an ``AssertionError`` raised in the block, logging and warning about it.

        :param level: ``logging`` level used for the log record.
        :param warning_type: ``Warning`` subclass instantiated with the assertion
            error and emitted via :func:`warnings.warn_explicit`.

        Exceptions other than ``AssertionError`` propagate unchanged.  If no
        traceback (or no caller frame) can be extracted, the assertion error is
        re-raised.
        """
        try:
            yield
        except AssertionError as e:
            extracted = _extract_traceback()
            if extracted is None:
                raise
            stacks, rich_tb = extracted
            if not stacks:
                # Nothing left after trimming the helper frames: no location to
                # attribute the warning to, so let the assertion fail normally.
                raise
            origin = stacks[0]
            # Log only the first line of the message; the rich traceback is
            # passed separately via ``rich=``.
            msg = str(e).partition("\n")[0]
            LOG.log(level, "%s", msg, rich=rich_tb)
            # NOTE(review): ``module`` receives the frame's function name rather
            # than a module name, which affects ``warnings`` filters that match
            # on module — confirm this is intended.
            warnings.warn_explicit(
                warning_type(e),
                category=warning_type,
                filename=origin.filename,
                lineno=(origin.lineno or 0),
                module=origin.name,
            )

    @contextmanager
    def warning(self) -> Iterator[None]:
        """
        Return a context that registers AssertionError as a warning and continues execution.
        """
        with self._check(logging.WARNING, ExpectWarning) as c:
            yield c

    @contextmanager
    def error(self) -> Iterator[None]:
        """
        Return a context that registers AssertionError as an error and continues execution.

        Any exception raised in the block (not only ``AssertionError``) sets
        :attr:`fail`; exceptions other than ``AssertionError`` still propagate.
        """
        with self._check(logging.ERROR, ExpectErrorWarning) as c:
            try:
                yield c
            except BaseException:
                self.fail = True
                raise
123        with self._check(logging.ERROR, ExpectErrorWarning) as c:
124            try:
125                yield c
126            except BaseException:
127                self.fail = True
128                raise
129
130
def expect_xfail(*args, **kwargs):
    """
    Pytest ``xfail`` mark for the ``expect`` fixture.

    Behaves like ``pytest.mark.xfail(*args, **kwargs)`` except that
    :class:`ExpectError` is always added to ``raises``.  A test failed by the
    ``pytest_runtest_call`` hook because of ``expect.error()`` therefore counts
    as an expected failure.

    :param args: Positional arguments forwarded to ``mark.xfail`` (e.g. a condition).
    :param kwargs: Keyword arguments forwarded to ``mark.xfail``.  ``raises`` may
        be omitted, ``None``, a single exception class, or a tuple of exception
        classes.
    :raises ValueError: If ``raises`` is of any other kind.
    """
    raises = kwargs.get("raises")
    if raises is None:
        extra: tuple[type[BaseException], ...] = ()
    elif isinstance(raises, type) and issubclass(raises, BaseException):
        # A single exception class.
        extra = (raises,)
    elif isinstance(raises, tuple):
        extra = raises
    else:
        raise ValueError(raises)
    # The merged ``raises`` goes last so a caller-supplied value cannot drop ExpectError.
    return mark.xfail(*args, **{**kwargs, "raises": (ExpectError, *extra)})
150
151
# Stash key under which each test item's ``Expect`` instance is stored; shared by
# the ``expect`` fixture and the ``pytest_runtest_call`` hook so both see the same object.
context_key = StashKey[Expect]()
153
154
@fixture(scope="function")
def expect(request: FixtureRequest) -> Expect:
    """
    Provide the :class:`Expect` instance bound to the current test item.

    The instance lives in the item's stash under ``context_key``.  The
    ``pytest_runtest_call`` hook in this module therefore reads the same object.
    """
    item_stash = request.node.stash
    return item_stash.setdefault(context_key, Expect())
162
163
@hookimpl(hookwrapper=True)
def pytest_runtest_call(item: Item):
    """
    Fail a test whose ``Expect`` context recorded an error.

    Wraps the test call.  If ``Expect.fail`` was set and the call itself
    finished without raising, the outcome is replaced by :class:`ExpectError`.
    An exception already raised by the test is left in place, so the real cause
    is reported.  Under ``expect_xfail`` such an exception is therefore not
    mistaken for an expected ``ExpectError``.
    """
    __tracebackhide__ = True
    # Fetch the Expect shared with the ``expect`` fixture (same stash key),
    # creating one if the test never requested the fixture.
    context: Expect = item.stash.setdefault(context_key, Expect())
    outcome: pluggy.Result = yield
    if context.fail and outcome.excinfo is None:
        outcome.force_exception(ExpectError())
171