1#!/usr/bin/env python3 2# -*- coding: utf-8 -*- 3# 4# Copyright (c) 2024 Huawei Device Co., Ltd. 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18import copy 19import logging 20import re 21import sys 22import warnings 23from collections.abc import Iterator 24from contextlib import contextmanager 25from traceback import StackSummary, extract_tb 26from types import TracebackType 27from typing import Optional, Tuple, Type 28 29import pluggy 30from _pytest.nodes import Item 31from pytest import FixtureRequest, Function, StashKey, fixture, hookimpl, mark 32from rich.traceback import Traceback 33from typing_extensions import TypeVar 34 35from .logs import logger 36from .rich_logging import remove_ansi_escape_sequences 37 38LOG = logger(__name__) 39 40 41class ExpectWarning(UserWarning): 42 pass 43 44 45class ExpectErrorWarning(UserWarning): 46 pass 47 48 49class ExpectError(Exception): 50 pass 51 52 53def tb_next(tb, step: int = 1) -> Optional[TracebackType]: 54 for _ in range(step): 55 if tb is None: 56 return None 57 tb = tb.tb_next 58 return tb 59 60 61def _extract_traceback(level: int = 4) -> Optional[Tuple[StackSummary, Traceback]]: 62 match sys.exc_info(): 63 case (None, _, _) | (_, None, _) | (_, _, None): 64 return None 65 case (exc_type, exc_value, traceback): 66 assert isinstance(exc_value, BaseException) 67 new_exc_value = copy.copy(exc_value) 68 new_exc_value.args = tuple(remove_ansi_escape_sequences(a) for a in new_exc_value.args) 69 70 tb = tb_next(traceback, 2) 71 return ( 72 extract_tb(tb), 73 Traceback.from_exception( 74 exc_type=exc_type, # type: ignore[arg-type] 75 exc_value=new_exc_value, 76 traceback=tb, 77 suppress=(__file__,), 78 ), 79 ) 80 return None 81 82 83E = TypeVar("E", bound=BaseException) 84 85 86class Expect: 87 fail = False 88 89 @contextmanager 90 def _check(self, level: int, waring_type: Type[Warning]) -> Iterator[None]: 91 try: 92 yield 93 except AssertionError as e: 94 tb = _extract_traceback() 95 if tb is None: 96 raise e 97 stacks, rich_tb = tb 98 s = stacks[0] 99 m = re.match(".*", str(e)) 100 msg = m.group() if m else str(e) 101 LOG.log(level, "%s", msg, rich=rich_tb) 102 warnings.warn_explicit( 103 waring_type(e), 104 category=waring_type, 105 filename=s.filename, 106 lineno=(s.lineno or 0), 107 module=s.name, 108 ) 109 110 @contextmanager 111 def warning(self) -> Iterator[None]: 112 """ 113 Return a context that registers AssertError as a warning and continues execution. 114 """ 115 with self._check(logging.WARNING, ExpectWarning) as c: 116 yield c 117 118 @contextmanager 119 def error(self) -> Iterator[None]: 120 """ 121 Return a context that registers AssertError as a error and continues execution. 122 """ 123 with self._check(logging.ERROR, ExpectErrorWarning) as c: 124 try: 125 yield c 126 except BaseException: 127 self.fail = True 128 raise 129 130 131def expect_xfail(**kwargs): 132 """ 133 Pytest ``xfail`` mark for ``expect`` fixture. 134 """ 135 136 raises: None | type[BaseException] | tuple[type[BaseException], ...] = kwargs.get("raises", tuple()) 137 if isinstance(raises, BaseException): 138 raises = (raises,) 139 elif isinstance(raises, tuple): 140 pass 141 elif raises is None: 142 raises = tuple() 143 else: 144 raise ValueError(raises) 145 kwargs = { 146 "raises": (ExpectError, *raises), 147 **kwargs, 148 } 149 return mark.xfail(**kwargs) 150 151 152context_key = StashKey[Expect]() 153 154 155@fixture(scope="function") 156def expect(request: FixtureRequest) -> Expect: 157 """ 158 Return a :class:`Expect` instance that implements the expect feature. 159 """ 160 node: Function = request.node 161 return node.stash.setdefault(context_key, Expect()) 162 163 164@hookimpl(hookwrapper=True) 165def pytest_runtest_call(item: Item): 166 __tracebackhide__ = True 167 context: Expect = item.stash.setdefault(context_key, Expect()) 168 outcome: pluggy.Result = yield 169 if context.fail: 170 outcome.force_exception(ExpectError()) 171