1#!/usr/bin/env python3 2# -*- coding: utf-8 -*- 3 4# Copyright (c) 2024 Huawei Device Co., Ltd. 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18from __future__ import annotations 19 20import logging 21import json 22from dataclasses import dataclass 23from typing import TextIO, Tuple, Optional, List, Dict, Any 24import re 25from enum import Enum 26import os 27 28from runner.logger import Log 29 30_LOGGER = logging.getLogger('runner.astchecker.util_astchecker') 31 32 33class UtilASTChecker: 34 skip_options = {'SkipErrors': False, 'SkipWarnings': False} 35 36 class _TestType(Enum): 37 NODE = 'Node' 38 ERROR = 'Error' 39 WARNING = 'Warning' 40 41 @dataclass 42 class _Pattern: 43 pattern_type: UtilASTChecker._TestType 44 pattern: str 45 line: int 46 col: int 47 error_file: str = '' 48 49 class _TestCase: 50 """ 51 Class for storing test case parsed from test file 52 """ 53 54 def __init__(self, name: Optional[str], pattern: UtilASTChecker._Pattern, checks: dict) -> None: 55 self.name = name 56 self.line = pattern.line 57 self.col = pattern.col 58 self.test_type = pattern.pattern_type 59 self.checks = checks 60 self.error_file = pattern.error_file 61 62 def __repr__(self) -> str: 63 return f'TestCase({self.name}, {self.line}:{self.col}, {self.test_type}, {self.checks}, {self.error_file})' 64 65 def __eq__(self, other: Any) -> bool: 66 return bool(self.name == other.name and self.line == other.line and self.col == other.col 67 and self.test_type == other.test_type and self.checks == other.checks 68 and self.error_file == other.error_file) 69 70 def __hash__(self) -> int: 71 return hash(self.__repr__()) 72 73 class TestCasesList: 74 """ 75 Class for storing test cases parsed from one test file 76 """ 77 78 def __init__(self, tests_list: set[UtilASTChecker._TestCase]): 79 self.tests_list = tests_list 80 has_error_tests = False 81 has_warning_tests = False 82 for test in tests_list: 83 if test.test_type == UtilASTChecker._TestType.ERROR: 84 has_error_tests = True 85 if test.test_type == UtilASTChecker._TestType.WARNING: 86 has_warning_tests = True 87 self.has_error_tests = has_error_tests 88 self.has_warning_tests = has_warning_tests 89 self.skip_errors = False 90 self.skip_warnings = False 91 92 def __init__(self) -> None: 93 self.regex = re.compile(r'/\*\s*@@\s*(?P<pattern>.*?)\s*\*/', re.DOTALL) 94 self.reset_skips() 95 96 @staticmethod 97 def create_test_case(name: Optional[str], pattern: UtilASTChecker._Pattern) -> UtilASTChecker._TestCase: 98 pattern_parsed = {'error': pattern.pattern} 99 if pattern.pattern_type == UtilASTChecker._TestType.NODE: 100 try: 101 pattern_parsed = json.loads(pattern.pattern) 102 except json.JSONDecodeError as ex: 103 Log.exception_and_raise( 104 _LOGGER, 105 f'TestCase: {name}.\nThrows JSON error: {ex}.\nJSON data: {pattern.pattern}') 106 return UtilASTChecker._TestCase(name, pattern, pattern_parsed) 107 108 @staticmethod 109 def get_match_location(match: re.Match, start: bool = False) -> Tuple[int, int]: 110 """ 111 Returns match location in file: line and column (counting from 1) 112 """ 113 match_idx = match.start() if start else match.end() 114 line_start_index = match.string[:match_idx].rfind('\n') + 1 115 col = match_idx - line_start_index + 1 116 line = match.string[:match_idx].count('\n') + 1 117 return line, col 118 119 @staticmethod 120 def check_properties(node: dict, properties: dict) -> bool: 121 """ 122 Checks if `node` contains all key:value pairs specified in `properties` dict argument 123 """ 124 for key, value in properties.items(): 125 if node.get(key) != value: 126 return False 127 return True 128 129 @staticmethod 130 def run_error_test(test_file: str, test: _TestCase, actual_errors: set) -> bool: 131 file_name = os.path.basename(test_file) if test.error_file == '' else test.error_file 132 expected_error = f'{test.checks["error"]}', f'[{file_name}:{test.line}:{test.col}]' 133 if expected_error in actual_errors: 134 actual_errors.remove(expected_error) 135 return True 136 Log.all(_LOGGER, f'No Expected error {expected_error}') 137 return False 138 139 @staticmethod 140 def get_actual_errors(error: str) -> set: 141 actual_errors = set() 142 for error_str in error.splitlines(): 143 if error_str.strip(): 144 error_text, error_loc = error_str.rsplit(' ', 1) 145 actual_errors.add((error_text.strip(), error_loc)) 146 return actual_errors 147 148 def reset_skips(self) -> None: 149 self.skip_options['SkipErrors'] = False 150 self.skip_options['SkipWarnings'] = False 151 152 def check_skip_error(self) -> bool: 153 return self.skip_options["SkipErrors"] 154 155 def check_skip_warning(self) -> bool: 156 return self.skip_options["SkipWarnings"] 157 158 def parse_define_statement(self, match: re.Match[str], 159 link_defs_map: Dict[str, Tuple[UtilASTChecker._TestType, str]], 160 link_sources_map: Dict[str, re.Match[str]]) -> Optional[UtilASTChecker._TestCase]: 161 """ 162 Parses `@<id> <pattern-type> <pattern>` 163 """ 164 match_str = re.sub(r'\s+', ' ', match.group('pattern'))[1:].strip() 165 sep1 = match_str.find(' ') 166 sep2 = match_str.find(' ', sep1 + 1) 167 if sep1 == -1 or sep2 == -1: 168 Log.exception_and_raise(_LOGGER, 'Wrong definition format: expected ' 169 f'`/* @@@ <id> <pattern-type> <pattern> */`, got /* @@@ {match_str} */') 170 name = match_str[:sep1] 171 pattern_type = UtilASTChecker._TestType(match_str[sep1 + 1:sep2]) 172 pattern = match_str[sep2 + 1:] 173 174 if name in link_defs_map: 175 line, col = self.get_match_location(match) 176 Log.exception_and_raise(_LOGGER, f'Link {name} (at location {line}:{col}) is already defined') 177 178 if name in link_sources_map: 179 match = link_sources_map[name] 180 del link_sources_map[name] 181 line, col = self.get_match_location(match) 182 return self.create_test_case(name, UtilASTChecker._Pattern(pattern_type, pattern, line, col)) 183 184 link_defs_map[name] = (pattern_type, pattern) 185 return None 186 187 def parse_match_statement(self, match: re.Match[str], 188 link_defs_map: Dict[str, Tuple[UtilASTChecker._TestType, str]], 189 link_sources_map: Dict[str, re.Match[str]]) -> Optional[UtilASTChecker._TestCase]: 190 """ 191 Parses `<pattern-type> <pattern>` and `<id>` 192 """ 193 str_match = match.group('pattern') 194 sep = str_match.find(' ') 195 if sep == -1: 196 # parse `<id>` 197 name = str_match 198 if any(not char.isalnum() and char != '_' for char in name): 199 Log.exception_and_raise(_LOGGER, f'Bad `<id>` value, expected value from `[a-zA-Z0-9_]+`, got {name}') 200 201 if name in link_sources_map: 202 line, col = self.get_match_location(match) 203 Log.exception_and_raise(_LOGGER, f'Link {name} (at location {line}:{col}) is already defined') 204 205 if name in link_defs_map: 206 pattern_type, pattern = link_defs_map[name] 207 del link_defs_map[name] 208 line, col = self.get_match_location(match) 209 return self.create_test_case(name, UtilASTChecker._Pattern(pattern_type, pattern, line, col)) 210 211 link_sources_map[name] = match 212 return None 213 214 # parse `<pattern-type> <pattern>` 215 pattern_type = UtilASTChecker._TestType(str_match[:sep]) 216 pattern = str_match[sep + 1:] 217 line, col = self.get_match_location(match) 218 return self.create_test_case(None, UtilASTChecker._Pattern(pattern_type, pattern, line, col)) 219 220 def parse_skip_statement(self, match: re.Match[str]) -> None: 221 """ 222 Parses `# <skip-option>*` 223 """ 224 match_str = re.sub(r'\s+', ' ', match.group('pattern'))[1:].strip() 225 pattern = r"(\w+)\s*=\s*(\w+)" 226 matches = re.findall(pattern, match_str) 227 for opt, val in matches: 228 value = val.lower() 229 if opt not in self.skip_options or value not in ['true', 'false']: 230 Log.exception_and_raise(_LOGGER, 'Wrong match_at_location format: expected ' 231 f'`/* @@# <skip-option> */`, got /* @@? {match_str} */') 232 self.skip_options[opt] = value == 'true' 233 234 def parse_match_at_loc_statement(self, match: re.Match[str]) -> UtilASTChecker._TestCase: 235 """ 236 Parses `? <line>:<col> <pattern-type> <pattern>` 237 and `? <file-name>:<line>:<col> <pattern-type> <pattern>` 238 """ 239 match_str = re.sub(r'\s+', ' ', match.group('pattern'))[1:].strip() 240 sep1 = match_str.find(' ') 241 sep2 = match_str.find(' ', sep1 + 1) 242 if sep1 == -1 or sep2 == -1: 243 Log.exception_and_raise(_LOGGER, 'Wrong match_at_location format: expected ' 244 '`/* @@? <line>:<col> <pattern-type> <pattern> */`, ' 245 f'got /* @@? {match_str} */') 246 location = match_str[:sep1] 247 loc_sep1 = location.find(':') 248 loc_sep2 = location.find(':', loc_sep1 + 1) 249 if loc_sep1 == -1: 250 Log.exception_and_raise(_LOGGER, 'Wrong match_at_location format: expected ' 251 '`/* @@? <line>:<col> <pattern-type> <pattern> */`, ' 252 f'got /* @@? {match_str} */') 253 line_str = location[:loc_sep1] if loc_sep2 == -1 else location[loc_sep1 + 1:loc_sep2] 254 col_str = location[loc_sep1 + 1:] if loc_sep2 == -1 else location[loc_sep2 + 1:] 255 error_file = '' if loc_sep2 == -1 else location[:loc_sep1] 256 if loc_sep2 != -1: 257 pass 258 if not line_str.isdigit(): 259 Log.exception_and_raise(_LOGGER, f'Expected line number, got {line_str}') 260 if not col_str.isdigit(): 261 Log.exception_and_raise(_LOGGER, f'Expected column number, got {col_str}') 262 263 line, col = int(line_str), int(col_str) 264 pattern_type = UtilASTChecker._TestType(match_str[sep1 + 1:sep2]) 265 pattern = match_str[sep2 + 1:] 266 267 return self.create_test_case( 268 name=None, 269 pattern=UtilASTChecker._Pattern(pattern_type, pattern, line, col, error_file) 270 ) 271 272 def parse_tests(self, file: TextIO) -> UtilASTChecker.TestCasesList: 273 """ 274 Takes .ets file with tests and parses them into a list of TestCases. 275 """ 276 self.reset_skips() 277 test_text = file.read() 278 link_defs_map: Dict[str, Tuple[UtilASTChecker._TestType, str]] = {} 279 link_sources_map: Dict[str, re.Match[str]] = {} 280 test_cases = set() 281 matches = list(re.finditer(self.regex, test_text)) 282 for match in matches: 283 pattern = match.group('pattern') 284 if pattern.startswith('@'): 285 test = self.parse_define_statement(match, link_defs_map, link_sources_map) 286 elif pattern.startswith('?'): 287 test = self.parse_match_at_loc_statement(match) 288 elif pattern.startswith('#'): 289 test = None 290 self.parse_skip_statement(match) 291 else: 292 test = self.parse_match_statement(match, link_defs_map, link_sources_map) 293 if test is not None: 294 test_cases.add(test) 295 296 if len(link_defs_map) or len(link_sources_map): 297 Log.exception_and_raise(_LOGGER, 'link defined twice') 298 299 test_case_list = UtilASTChecker.TestCasesList(test_cases) 300 test_case_list.skip_errors = self.check_skip_error() 301 test_case_list.skip_warnings = self.check_skip_warning() 302 return test_case_list 303 304 def find_nodes_by_start_location(self, root: dict, line: int, col: int) -> List[dict]: 305 """ 306 Finds all descendants of `root` with location starting at `loc` 307 """ 308 nodes = [] 309 start = root.get('loc', {}).get('start', {}) 310 if start.get('line', None) == line and start.get('column', None) == col: 311 nodes.append(root) 312 313 for child in root.values(): 314 if isinstance(child, dict): 315 nodes.extend(self.find_nodes_by_start_location(child, line, col)) 316 if isinstance(child, list): 317 for item in child: 318 nodes.extend(self.find_nodes_by_start_location(item, line, col)) 319 return nodes 320 321 def run_node_test(self, test: _TestCase, ast: dict) -> bool: 322 nodes_by_loc = self.find_nodes_by_start_location(ast, test.line, test.col) 323 test_passed = False 324 for node in nodes_by_loc: 325 if self.check_properties(node, test.checks): 326 test_passed = True 327 break 328 return test_passed 329 330 def run_tests(self, test_file: str, test_cases: TestCasesList, ast: dict, error: str = '') -> bool: 331 """ 332 Takes AST and runs tests on it, returns True if all tests passed 333 """ 334 Log.all(_LOGGER, f'Running {len(test_cases.tests_list)} tests...') 335 failed_tests = 0 336 actual_errors = self.get_actual_errors(error) 337 node_test_passed = True 338 error_test_passed = True 339 warning_test_passed = True 340 tests_set = set(test_cases.tests_list) 341 342 for i, test in enumerate(tests_set): 343 if test.test_type == UtilASTChecker._TestType.NODE: 344 node_test_passed = self.run_node_test(test, ast) 345 elif test.test_type == UtilASTChecker._TestType.ERROR: 346 error_test_passed = self.run_error_test(test_file, test, actual_errors) 347 if test_cases.skip_errors: 348 error_test_passed = True 349 elif test.test_type == UtilASTChecker._TestType.WARNING: 350 warning_test_passed = self.run_error_test(test_file, test, actual_errors) 351 if test_cases.skip_warnings: 352 warning_test_passed = True 353 test_name = f'Test {i + 1}' + ('' if test.name is None else f': {test.name}') 354 if bool(node_test_passed and error_test_passed and warning_test_passed): 355 Log.all(_LOGGER, f'PASS: {test_name}') 356 else: 357 Log.all(_LOGGER, f'FAIL: {test_name} in {test_file}') 358 failed_tests += 1 359 360 for actual_error in actual_errors: 361 if actual_error[0].split()[0] == "Warning:" and not self.check_skip_warning(): 362 Log.all(_LOGGER, f'Unexpected warning {actual_error}') 363 failed_tests += 1 364 if ((actual_error[0].split()[0] == "TypeError:" or actual_error[0].split()[0] == "SyntaxError:") 365 and not self.check_skip_error()): 366 Log.all(_LOGGER, f'Unexpected error {actual_error}') 367 failed_tests += 1 368 369 if failed_tests == 0: 370 Log.all(_LOGGER, 'All tests passed') 371 return True 372 373 Log.all(_LOGGER, f'Failed {failed_tests} tests') 374 return False 375