• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18from __future__ import annotations
19
20import logging
21import json
22from dataclasses import dataclass
23from typing import TextIO, Tuple, Optional, List, Dict, Any
24import re
25from enum import Enum
26import os
27
28from runner.logger import Log
29
# Logger handed to every ``Log`` call made by UtilASTChecker.
_LOGGER: logging.Logger = logging.getLogger(name='runner.astchecker.util_astchecker')
31
32
class UtilASTChecker:
    """
    Parses AST-checker annotations embedded in test sources and checks them against
    an AST dump and the compiler's diagnostic output.

    Recognised annotations (block comments starting with ``@@``):

    * ``/* @@ <pattern-type> <pattern> */`` -- check located right after the comment;
    * ``/* @@ <id> */`` -- location of a named check defined elsewhere;
    * ``/* @@@ <id> <pattern-type> <pattern> */`` -- definition of a named check;
    * ``/* @@? [<file-name>:]<line>:<col> <pattern-type> <pattern> */`` -- check at an
      explicit location;
    * ``/* @@# <skip-option>=<true|false> ... */`` -- skip options
      (``SkipErrors``, ``SkipWarnings``).

    ``<pattern-type>`` is one of ``Node``, ``Error`` or ``Warning``. ``Node`` patterns are
    JSON objects whose key/value pairs are compared with AST nodes starting at the check
    location; ``Error``/``Warning`` patterns are expected diagnostic texts.
    """

    # Default skip options. ``reset_skips`` gives every instance its own copy, so
    # parsing a file with one checker never changes the options seen by another one.
    skip_options = {'SkipErrors': False, 'SkipWarnings': False}

    class _TestType(Enum):
        """Kind of check performed by a test case."""

        NODE = 'Node'
        ERROR = 'Error'
        WARNING = 'Warning'

    @dataclass
    class _Pattern:
        """Annotation data collected while parsing, before it becomes a ``_TestCase``."""

        pattern_type: UtilASTChecker._TestType
        # JSON object text for Node checks, expected diagnostic text otherwise
        pattern: str
        # 1-based location (line and column) the check applies to
        line: int
        col: int
        # File name expected in the diagnostic location; '' means the test file itself
        error_file: str = ''

    class _TestCase:
        """
        Class for storing test case parsed from test file
        """

        def __init__(self, name: Optional[str], pattern: UtilASTChecker._Pattern, checks: dict) -> None:
            # Link id for named checks, None for anonymous ones
            self.name = name
            self.line = pattern.line
            self.col = pattern.col
            self.test_type = pattern.pattern_type
            # Parsed JSON object for Node checks, {'error': <text>} for Error/Warning checks
            self.checks = checks
            self.error_file = pattern.error_file

        def __repr__(self) -> str:
            return f'TestCase({self.name}, {self.line}:{self.col}, {self.test_type}, {self.checks}, {self.error_file})'

        def __eq__(self, other: Any) -> bool:
            # Let Python fall back to its default handling for foreign types
            # instead of failing with AttributeError.
            if not isinstance(other, UtilASTChecker._TestCase):
                return NotImplemented
            return bool(self.name == other.name and self.line == other.line and self.col == other.col
                        and self.test_type == other.test_type and self.checks == other.checks
                        and self.error_file == other.error_file)

        def __hash__(self) -> int:
            # The repr contains every field compared by __eq__
            return hash(self.__repr__())

    class TestCasesList:
        """
        Class for storing test cases parsed from one test file
        """

        def __init__(self, tests_list: set[UtilASTChecker._TestCase]):
            self.tests_list = tests_list
            self.has_error_tests = any(
                test.test_type == UtilASTChecker._TestType.ERROR for test in tests_list)
            self.has_warning_tests = any(
                test.test_type == UtilASTChecker._TestType.WARNING for test in tests_list)
            # Filled in by ``parse_tests`` from the file's ``/* @@# ... */`` statements
            self.skip_errors = False
            self.skip_warnings = False

    def __init__(self) -> None:
        # Matches `/* @@ <pattern> */`; DOTALL lets a pattern span several lines
        self.regex = re.compile(r'/\*\s*@@\s*(?P<pattern>.*?)\s*\*/', re.DOTALL)
        self.reset_skips()

    @staticmethod
    def create_test_case(name: Optional[str], pattern: UtilASTChecker._Pattern) -> UtilASTChecker._TestCase:
        """
        Builds a test case from `pattern`.

        Node patterns are parsed as JSON and must be JSON objects; Error/Warning patterns
        are stored as ``{'error': <text>}``. Malformed Node patterns are reported through
        ``Log.exception_and_raise``.
        """
        pattern_parsed = {'error': pattern.pattern}
        if pattern.pattern_type == UtilASTChecker._TestType.NODE:
            try:
                pattern_parsed = json.loads(pattern.pattern)
            except json.JSONDecodeError as ex:
                Log.exception_and_raise(
                    _LOGGER,
                    f'TestCase: {name}.\nThrows JSON error: {ex}.\nJSON data: {pattern.pattern}')
            # check_properties needs key/value pairs; reject arrays, numbers, etc. early
            if not isinstance(pattern_parsed, dict):
                Log.exception_and_raise(
                    _LOGGER,
                    f'TestCase: {name}.\nNode pattern must be a JSON object.\nJSON data: {pattern.pattern}')
        return UtilASTChecker._TestCase(name, pattern, pattern_parsed)

    @staticmethod
    def get_match_location(match: re.Match, start: bool = False) -> Tuple[int, int]:
        """
        Returns match location in file: line and column (counting from 1).

        :param match: regex match over the whole file text
        :param start: use the start of the match instead of its end
        """
        text = match.string
        match_idx = match.start() if start else match.end()
        # Bounded count/rfind avoid copying the text prefix for every match
        line = text.count('\n', 0, match_idx) + 1
        line_start_index = text.rfind('\n', 0, match_idx) + 1
        return line, match_idx - line_start_index + 1

    @staticmethod
    def check_properties(node: dict, properties: dict) -> bool:
        """
        Checks if `node` contains all key:value pairs specified in `properties` dict argument
        """
        return all(node.get(key) == value for key, value in properties.items())

    @staticmethod
    def run_error_test(test_file: str, test: _TestCase, actual_errors: set) -> bool:
        """
        Checks that the diagnostic expected by `test` was reported, and consumes it.

        The matched entry is removed from `actual_errors`, so whatever remains afterwards
        is the set of diagnostics no test expected.
        """
        file_name = test.error_file or os.path.basename(test_file)
        expected_error = (str(test.checks['error']), f'[{file_name}:{test.line}:{test.col}]')
        if expected_error in actual_errors:
            actual_errors.remove(expected_error)
            return True
        Log.all(_LOGGER, f'No Expected error {expected_error}')
        return False

    @staticmethod
    def get_actual_errors(error: str) -> set:
        """
        Splits compiler diagnostic output into a set of ``(text, location)`` pairs.

        The location is the last space-separated token of a line (e.g. ``[file.ets:1:2]``).
        A line without any space yields an empty location. Blank lines are skipped.
        """
        actual_errors = set()
        for error_str in error.splitlines():
            stripped = error_str.strip()
            if not stripped:
                continue
            # Strip first so trailing whitespace is never mistaken for the location token
            error_text, sep, error_loc = stripped.rpartition(' ')
            if not sep:
                # No location token at all: keep the whole line as the diagnostic text
                error_text, error_loc = stripped, ''
            actual_errors.add((error_text.strip(), error_loc))
        return actual_errors

    def reset_skips(self) -> None:
        """Restores the default skip options for this instance."""
        # Assign a fresh dict instead of mutating the class-level defaults,
        # which would be shared by every UtilASTChecker instance.
        self.skip_options = {'SkipErrors': False, 'SkipWarnings': False}

    def check_skip_error(self) -> bool:
        return self.skip_options["SkipErrors"]

    def check_skip_warning(self) -> bool:
        return self.skip_options["SkipWarnings"]

    def parse_define_statement(self, match: re.Match[str],
                               link_defs_map: Dict[str, Tuple[UtilASTChecker._TestType, str]],
                               link_sources_map: Dict[str, re.Match[str]]) -> Optional[UtilASTChecker._TestCase]:
        """
        Parses `@<id> <pattern-type> <pattern>`.

        If `<id>` was already used by a `/* @@ <id> */` comment, the test case is returned
        with that comment's location; otherwise the definition is stored in
        `link_defs_map` until the use is found and None is returned.
        """
        match_str = re.sub(r'\s+', ' ', match.group('pattern'))[1:].strip()
        sep1 = match_str.find(' ')
        sep2 = match_str.find(' ', sep1 + 1)
        if sep1 == -1 or sep2 == -1:
            Log.exception_and_raise(_LOGGER, 'Wrong definition format: expected '
                                             f'`/* @@@ <id> <pattern-type> <pattern> */`, got /* @@@ {match_str} */')
        name = match_str[:sep1]
        pattern_type = UtilASTChecker._TestType(match_str[sep1 + 1:sep2])
        pattern = match_str[sep2 + 1:]

        if name in link_defs_map:
            line, col = self.get_match_location(match)
            Log.exception_and_raise(_LOGGER, f'Link {name} (at location {line}:{col}) is already defined')

        if name in link_sources_map:
            match = link_sources_map[name]
            del link_sources_map[name]
            line, col = self.get_match_location(match)
            return self.create_test_case(name, UtilASTChecker._Pattern(pattern_type, pattern, line, col))

        link_defs_map[name] = (pattern_type, pattern)
        return None

    def parse_match_statement(self, match: re.Match[str],
                              link_defs_map: Dict[str, Tuple[UtilASTChecker._TestType, str]],
                              link_sources_map: Dict[str, re.Match[str]]) -> Optional[UtilASTChecker._TestCase]:
        """
        Parses `<pattern-type> <pattern>` and `<id>`.

        `<pattern-type>` and `<pattern>` may be separated by any whitespace, including a
        newline. The check location is the position right after the comment.
        """
        parts = match.group('pattern').split(None, 1)
        if len(parts) < 2:
            # parse `<id>` (an empty annotation is rejected as a bad id)
            name = parts[0] if parts else ''
            if not name or any(not char.isalnum() and char != '_' for char in name):
                Log.exception_and_raise(_LOGGER, f'Bad `<id>` value, expected value from `[a-zA-Z0-9_]+`, got {name}')

            if name in link_sources_map:
                line, col = self.get_match_location(match)
                Log.exception_and_raise(_LOGGER, f'Link {name} (at location {line}:{col}) is already defined')

            if name in link_defs_map:
                pattern_type, pattern = link_defs_map[name]
                del link_defs_map[name]
                line, col = self.get_match_location(match)
                return self.create_test_case(name, UtilASTChecker._Pattern(pattern_type, pattern, line, col))

            link_sources_map[name] = match
            return None

        # parse `<pattern-type> <pattern>`
        pattern_type = UtilASTChecker._TestType(parts[0])
        pattern = parts[1]
        line, col = self.get_match_location(match)
        return self.create_test_case(None, UtilASTChecker._Pattern(pattern_type, pattern, line, col))

    def parse_skip_statement(self, match: re.Match[str]) -> None:
        """
        Parses `# <skip-option>*`, where each option is `<name>=<true|false>`.

        Unknown option names or values other than true/false (case-insensitive) are
        reported through ``Log.exception_and_raise``.
        """
        match_str = re.sub(r'\s+', ' ', match.group('pattern'))[1:].strip()
        pattern = r"(\w+)\s*=\s*(\w+)"
        matches = re.findall(pattern, match_str)
        for opt, val in matches:
            value = val.lower()
            if opt not in self.skip_options or value not in ['true', 'false']:
                Log.exception_and_raise(_LOGGER, 'Wrong skip statement format: expected '
                                                 f'`/* @@# <skip-option>=<true|false> */`, got /* @@# {match_str} */')
            self.skip_options[opt] = value == 'true'

    def parse_match_at_loc_statement(self, match: re.Match[str]) -> UtilASTChecker._TestCase:
        """
        Parses `? <line>:<col> <pattern-type> <pattern>`
        and  `? <file-name>:<line>:<col> <pattern-type> <pattern>`.

        The location is split from the right, so `<file-name>` may itself contain ':'.
        """
        match_str = re.sub(r'\s+', ' ', match.group('pattern'))[1:].strip()
        sep1 = match_str.find(' ')
        sep2 = match_str.find(' ', sep1 + 1)
        if sep1 == -1 or sep2 == -1:
            Log.exception_and_raise(_LOGGER, 'Wrong match_at_location format: expected '
                                             '`/* @@? <line>:<col> <pattern-type> <pattern> */`, '
                                             f'got /* @@? {match_str} */')
        location_parts = match_str[:sep1].rsplit(':', 2)
        if len(location_parts) < 2:
            Log.exception_and_raise(_LOGGER, 'Wrong match_at_location format: expected '
                                             '`/* @@? <line>:<col> <pattern-type> <pattern> */`, '
                                             f'got /* @@? {match_str} */')
        error_file = location_parts[0] if len(location_parts) == 3 else ''
        line_str, col_str = location_parts[-2], location_parts[-1]
        # isdecimal (not isdigit): every decimal character is accepted by int()
        if not line_str.isdecimal():
            Log.exception_and_raise(_LOGGER, f'Expected line number, got {line_str}')
        if not col_str.isdecimal():
            Log.exception_and_raise(_LOGGER, f'Expected column number, got {col_str}')

        line, col = int(line_str), int(col_str)
        pattern_type = UtilASTChecker._TestType(match_str[sep1 + 1:sep2])
        pattern = match_str[sep2 + 1:]

        return self.create_test_case(
            name=None,
            pattern=UtilASTChecker._Pattern(pattern_type, pattern, line, col, error_file)
        )

    def parse_tests(self, file: TextIO) -> UtilASTChecker.TestCasesList:
        """
        Reads a test file and parses its `/* @@ ... */` annotations into a TestCasesList.

        Skip options are reset before parsing and copied onto the returned list.
        Links that are defined but never used, or used but never defined, are reported
        through ``Log.exception_and_raise``.
        """
        self.reset_skips()
        test_text = file.read()
        link_defs_map: Dict[str, Tuple[UtilASTChecker._TestType, str]] = {}
        link_sources_map: Dict[str, re.Match[str]] = {}
        test_cases = set()
        for match in self.regex.finditer(test_text):
            pattern = match.group('pattern')
            if pattern.startswith('@'):
                test = self.parse_define_statement(match, link_defs_map, link_sources_map)
            elif pattern.startswith('?'):
                test = self.parse_match_at_loc_statement(match)
            elif pattern.startswith('#'):
                test = None
                self.parse_skip_statement(match)
            else:
                test = self.parse_match_statement(match, link_defs_map, link_sources_map)
            if test is not None:
                test_cases.add(test)

        # Leftovers are links whose definition and use never met
        if link_defs_map or link_sources_map:
            Log.exception_and_raise(_LOGGER, 'Unresolved links: '
                                             f'defined but never used {sorted(link_defs_map)}, '
                                             f'used but never defined {sorted(link_sources_map)}')

        test_case_list = UtilASTChecker.TestCasesList(test_cases)
        test_case_list.skip_errors = self.check_skip_error()
        test_case_list.skip_warnings = self.check_skip_warning()
        return test_case_list

    def find_nodes_by_start_location(self, root: dict, line: int, col: int) -> List[dict]:
        """
        Finds `root` and all its descendants whose `loc.start` is at `line`:`col`.

        Nested dicts and lists are walked in pre-order; values of any other type
        (strings, numbers, None) are skipped. The walk is iterative, so deep ASTs
        cannot exceed the recursion limit.
        """
        nodes: List[dict] = []
        stack: List[Any] = [root]
        while stack:
            current = stack.pop()
            if isinstance(current, list):
                # Reversed so that items are visited in their original order
                stack.extend(reversed(current))
                continue
            if not isinstance(current, dict):
                continue
            loc = current.get('loc')
            start = loc.get('start') if isinstance(loc, dict) else None
            if isinstance(start, dict) and start.get('line') == line and start.get('column') == col:
                nodes.append(current)
            stack.extend(reversed(list(current.values())))
        return nodes

    def run_node_test(self, test: _TestCase, ast: dict) -> bool:
        """Returns True if some AST node starting at the test location matches all its checks."""
        return any(self.check_properties(node, test.checks)
                   for node in self.find_nodes_by_start_location(ast, test.line, test.col))

    def _run_single_test(self, test_file: str, test: _TestCase, test_cases: TestCasesList,
                         ast: dict, actual_errors: set) -> bool:
        """Runs one test case and returns whether it passed (honouring skip options)."""
        if test.test_type == UtilASTChecker._TestType.NODE:
            return self.run_node_test(test, ast)
        # Always run the check first so a matching diagnostic is consumed from
        # `actual_errors` even when the corresponding skip option is set.
        matched = self.run_error_test(test_file, test, actual_errors)
        if test.test_type == UtilASTChecker._TestType.ERROR:
            return matched or test_cases.skip_errors
        return matched or test_cases.skip_warnings

    def run_tests(self, test_file: str, test_cases: TestCasesList, ast: dict, error: str = '') -> bool:
        """
        Takes AST and runs tests on it, returns True if all tests passed.

        :param test_file: path of the test source; its base name is the default file
            expected in diagnostic locations
        :param test_cases: checks produced by ``parse_tests``
        :param ast: AST dump (nested dicts/lists) searched by Node checks
        :param error: compiler diagnostic output, one diagnostic per line
        """
        Log.all(_LOGGER, f'Running {len(test_cases.tests_list)} tests...')
        failed_tests = 0
        actual_errors = self.get_actual_errors(error)
        # Sort so that test numbering in the log does not depend on set iteration
        # order (string hashes are randomized between interpreter runs).
        ordered_tests = sorted(set(test_cases.tests_list), key=lambda case: (case.line, case.col, repr(case)))

        for i, test in enumerate(ordered_tests):
            test_name = f'Test {i + 1}' + ('' if test.name is None else f': {test.name}')
            # Decided per test: one failing check must not mark later tests as failed
            if self._run_single_test(test_file, test, test_cases, ast, actual_errors):
                Log.all(_LOGGER, f'PASS: {test_name}')
            else:
                Log.all(_LOGGER, f'FAIL: {test_name} in {test_file}')
                failed_tests += 1

        # Diagnostics not consumed by Error/Warning checks: unexpected warnings and
        # TypeError/SyntaxError lines count as failures unless skipped for this file.
        for actual_error in sorted(actual_errors):
            words = actual_error[0].split()
            kind = words[0] if words else ''
            if kind == "Warning:" and not test_cases.skip_warnings:
                Log.all(_LOGGER, f'Unexpected warning {actual_error}')
                failed_tests += 1
            if kind in ("TypeError:", "SyntaxError:") and not test_cases.skip_errors:
                Log.all(_LOGGER, f'Unexpected error {actual_error}')
                failed_tests += 1

        if failed_tests == 0:
            Log.all(_LOGGER, 'All tests passed')
            return True

        Log.all(_LOGGER, f'Failed {failed_tests} tests')
        return False
375