• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import difflib
19import json
20from copy import deepcopy
21from typing import Any, Callable, Final, TypeVar
22
23from arkdb.compiler_verification import ExpressionEvaluationNames
24from arkdb.runnable_module import ScriptFile
25
26T = TypeVar("T")
27Transformer = Callable[[T], T]
28
29AstNode = dict[str, Any]
30AstNodeOrList = AstNode | list[AstNode]
31
32# NOTE(dslynko): remove this utility when es2panda-generated cctor calls are removed
33CCTOR_PREFIX: Final[str] = "_$trigger_cctor$_for_$"
34
35
class AstComparisonError(AssertionError):
    """Raised when a patched (expression) AST does not match the expected AST."""
38
39
class AstComparator:
    """Compare the AST of a compiled evaluation expression against an expected AST.

    Both ``expression`` and ``expected`` are ``ScriptFile`` objects whose ``ast``
    attribute is read as a dict containing a ``"statements"`` list. The statement
    lists are deep-copied before being normalized, so the scripts' ASTs are left
    untouched.
    """

    def __init__(self, expression: ScriptFile, expected: ScriptFile, expected_imports_base: bool = False):
        """Store the two scripts to compare.

        Args:
            expression: script holding the AST produced for the evaluated expression.
            expected: script holding the reference AST.
            expected_imports_base: if True, import declarations in ``expected`` whose
                path's last "/"-separated component equals the expected source file
                name (the part before its first ".", when that dot is neither the
                first nor the last character) get their source path replaced by "".
        """
        self.expression = expression
        self.expected = expected
        self.expected_imports_base = expected_imports_base
        # Filled lazily by `_prepare_expression_statements`.
        self._eval_method_name: str | None = None

    def get_eval_method_name(self) -> str:
        """Return the generated evaluation method name found in the expression AST.

        The name is computed on first use by traversing the expression AST.

        Raises:
            RuntimeError: if the expression AST has no "statements" list.
            AstComparisonError: if exactly one non-empty generated name is not found.
        """
        if self._eval_method_name is None:
            # Initialize `_eval_method_name` by traversing expression AST.
            self._prepare_expression_statements()
        if self._eval_method_name is None:
            # The traversal above either sets the name or raises; this only narrows the type.
            raise RuntimeError("Evaluation method name was not initialized after traversing expression AST")
        return self._eval_method_name

    def compare(self):
        """Compare statements, then import specifiers, of both ASTs.

        Raises:
            AstComparisonError: on any mismatch.
            RuntimeError: if either AST is malformed.
        """
        expected_stmts, expected_imports = self._prepare_expected_statements()
        expression_stmts, expression_imports = self._prepare_expression_statements()
        _compare_ast_statements(expression_stmts, expected_stmts)
        _compare_ast_import_decls(expression_imports, expected_imports)

    def _prepare_expected_statements(self):
        """Return ``(statements, imports)`` of the expected AST, normalized for comparison.

        The statements are deep-copied. If ``expected_imports_base`` is set, imports of
        the base test file get an empty source path. Generated cctor trigger calls and
        their import specifiers are removed, and the result is split by
        ``_split_statements``.

        Raises:
            RuntimeError: if the expected AST has no "statements" list.
        """
        statements_list = self.expected.ast.get("statements")
        if not isinstance(statements_list, list):
            raise RuntimeError("Expected AST 'statements' field is not a list")
        statements_list = deepcopy(statements_list)

        if self.expected_imports_base:
            base_test_file_name = self.expected.source_file.name
            # Strip everything from the first "." on, unless the dot is at either end.
            if (idx := base_test_file_name.find(".")) > 0 and idx < len(base_test_file_name) - 1:
                base_test_file_name = base_test_file_name[:idx]

            statements_filter = _get_import_statements_sources_filter({base_test_file_name: ""})
            statements_list = list(map(statements_filter, statements_list))

        def _find_prefix_recursively(ast_node: Any, prefix: str):
            # True if any dict in the subtree has a str "name" starting with `prefix`.
            if isinstance(ast_node, dict):
                if (name := ast_node.get("name")) and isinstance(name, str) and name.startswith(prefix):
                    return True
                return any(_find_prefix_recursively(x, prefix) for x in ast_node.values())
            if isinstance(ast_node, list):
                return any(_find_prefix_recursively(x, prefix) for x in ast_node)
            return False

        def _imports_trigger_cctor(x: dict) -> bool:
            # True for an ImportSpecifier whose local identifier is a cctor trigger.
            return x.get("type") == "ImportSpecifier" and _find_prefix_recursively(x.get("local"), CCTOR_PREFIX)

        def _remove_cctor_call(ast_node: Any) -> Any:
            if isinstance(ast_node, dict):
                if (stmts := ast_node.get("statements")) and isinstance(stmts, list):
                    # Drop every statement that mentions a cctor name anywhere inside it.
                    # NOTE(review): the kept statements of this block are not traversed
                    # further because of the early return — confirm this is intended.
                    ast_node["statements"] = [x for x in stmts if not _find_prefix_recursively(x, CCTOR_PREFIX)]
                    return ast_node
                for key, value in ast_node.items():
                    ast_node[key] = _remove_cctor_call(value)
            elif isinstance(ast_node, list):
                # Drop cctor import specifiers and recurse into the remaining items.
                return [
                    _remove_cctor_call(x) for x in ast_node if not (isinstance(x, dict) and _imports_trigger_cctor(x))
                ]
            return ast_node

        return _split_statements(statements_list, additional_filters=[_remove_cctor_call])

    def _prepare_expression_statements(self):
        """Return ``(statements, imports)`` of the expression AST, normalized for comparison.

        Names matching ``EVAL_METHOD_GENERATED_NAME_RE`` are replaced by
        ``EVAL_PATCH_FUNCTION_NAME``, and names matching ``EVAL_METHOD_RETURN_VALUE_RE``
        by ``EVAL_PATCH_RETURN_VALUE``. The single replaced method name is stored in
        ``self._eval_method_name``.

        Raises:
            RuntimeError: if the expression AST has no "statements" list.
            AstComparisonError: if exactly one non-empty generated method name is not found.
        """
        statements_list = self.expression.ast.get("statements")
        if not isinstance(statements_list, list):
            raise RuntimeError("Expression AST 'statements' field is not a list")
        statements_list = deepcopy(statements_list)

        method_names_candidates: set[str] = set()

        def _replace_in_ast_node(ast_node: AstNode):
            # Reassigning existing keys while iterating is safe: the dict size is unchanged.
            for key, value in ast_node.items():
                if key == "name" and isinstance(value, str):
                    if ExpressionEvaluationNames.EVAL_METHOD_GENERATED_NAME_RE.match(value):
                        method_names_candidates.add(value)
                        ast_node[key] = ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME
                    elif ExpressionEvaluationNames.EVAL_METHOD_RETURN_VALUE_RE.match(value):
                        ast_node[key] = ExpressionEvaluationNames.EVAL_PATCH_RETURN_VALUE
                else:
                    ast_node[key] = replace_generated_names(value)

        def replace_generated_names(ast_node: Any) -> Any:
            if isinstance(ast_node, dict):
                _replace_in_ast_node(ast_node)
            elif isinstance(ast_node, list):
                return list(map(replace_generated_names, ast_node))
            return ast_node

        stmts, imports = _split_statements(
            statements_list,
            additional_filters=[replace_generated_names],
        )

        method_names_list = list(method_names_candidates)
        if len(method_names_list) != 1 or method_names_list[0] == "":
            error_message = f"Failed to find expected evaluation method name; candidates: {method_names_list}"
            raise AstComparisonError(error_message)
        self._eval_method_name = method_names_list[0]

        return stmts, imports
140
141
142def _del_locations(ast: Any) -> Any:
143    if isinstance(ast, dict):
144        for key, value in ast.copy().items():
145            if key == "loc":
146                del ast[key]
147            else:
148                ast[key] = _del_locations(value)
149    elif isinstance(ast, list):
150        return list(map(_del_locations, ast))
151    return ast
152
153
def _split_statements(statements_list: list[AstNode], additional_filters: list[Transformer] | None = None):
    """Apply transformers to a statement list and separate out import declarations.

    The ``additional_filters`` run first, in order, each receiving the previous
    result. Then every "loc" field is removed.

    Args:
        statements_list: top-level statement nodes.
        additional_filters: optional transformers applied to the whole list. The
            caller's list is not modified.

    Returns:
        A tuple ``(statements, imports)``, where ``imports`` holds the nodes whose
        "type" is "ImportDeclaration" and ``statements`` holds everything else. Both
        keep the original order.
    """
    # Build a fresh list instead of appending to the caller's argument.
    all_filters: list[Transformer] = [*(additional_filters or []), _del_locations]
    for transform in all_filters:
        statements_list = transform(statements_list)

    statements: list[AstNode] = []
    imports_list: list[AstNode] = []
    for stmt_node in statements_list:
        if stmt_node.get("type") == "ImportDeclaration":
            imports_list.append(stmt_node)
        else:
            statements.append(stmt_node)
    return statements, imports_list
171
172
def _get_import_statements_sources_filter(imports_replacement_map: dict[str, str]) -> Callable[[AstNode], AstNode]:
    """Build a transformer that rewrites the source path of matching imports.

    The returned function leaves non-"ImportDeclaration" nodes unchanged. For an
    import whose path contains "/", the text after the last "/" is looked up in
    ``imports_replacement_map``. On a hit, the node's "source"/"value" is replaced in
    place. Paths without "/" are never rewritten.

    Raises (from the returned function):
        RuntimeError: if the import has no dict "source" or its "value" is not a str.
    """

    def _rewrite_import_source(stmt_node: AstNode) -> AstNode:
        if stmt_node.get("type") != "ImportDeclaration":
            return stmt_node

        source_node = stmt_node.get("source")
        if not isinstance(source_node, dict):
            raise RuntimeError()
        import_path = source_node.get("value")
        if not isinstance(import_path, str):
            raise RuntimeError()

        # ArkTS paths always have "/" as delimiter.
        _, delimiter, last_component = import_path.rpartition("/")
        if not delimiter:
            return stmt_node
        replacement = imports_replacement_map.get(last_component)
        if replacement is not None:
            source_node["value"] = replacement
        return stmt_node

    return _rewrite_import_source
194
195
def dump_ast(ast: AstNodeOrList):
    """Serialize ``ast`` as JSON indented by 4 spaces and return the list of its lines."""
    serialized = json.dumps(ast, indent=4)
    return serialized.splitlines()
198
199
def _compare_ast_statements(patched_output: AstNodeOrList, expected_output: AstNodeOrList):
    """Diff the JSON dumps of two ASTs line by line.

    Every ``difflib.ndiff`` line whose first character is not in
    ``ExpressionEvaluationNames.NON_DIFF_MARKS`` counts as a mismatch.

    Raises:
        AstComparisonError: listing all mismatching diff lines, if there are any.
    """
    patched_lines = dump_ast(patched_output)
    expected_lines = dump_ast(expected_output)
    mismatches = []
    for diff_line in difflib.ndiff(patched_lines, expected_lines):
        if diff_line[0] in ExpressionEvaluationNames.NON_DIFF_MARKS:
            continue
        mismatches.append(diff_line)
    if len(mismatches) > 0:
        raise AstComparisonError("AST comparison failed:\n" + "\n".join(mismatches))
205
206
def _collect_import_specifiers(import_decls: list[AstNode]) -> list[AstNode]:
    """Flatten the "specifiers" lists of the given import declarations, keeping order.

    Raises:
        RuntimeError: if a declaration's "specifiers" field is not a list.
    """
    specifiers: list[AstNode] = []
    for import_decl in import_decls:
        specifiers_list = import_decl.get("specifiers")
        if not isinstance(specifiers_list, list):
            raise RuntimeError("Import declaration has no 'specifiers' list")
        specifiers.extend(specifiers_list)
    return specifiers


def _local_import_name(specifier: AstNode) -> Any:
    """Return the "name" of the specifier's "local" identifier node.

    Raises:
        RuntimeError: if the specifier's "local" field is not a dict.
    """
    local_identifier_node = specifier.get("local")
    if not isinstance(local_identifier_node, dict):
        raise RuntimeError("Import specifier has no 'local' identifier node")
    return local_identifier_node.get("name")


def _compare_ast_import_decls(expression_imports: list[AstNode], expected_imports: list[AstNode]):
    """Check that the patch's import specifiers match the expected ones.

    Specifiers are paired by local identifier name. Each expected specifier can be
    matched by at most one patch specifier, so a duplicated patch import cannot hide
    a missing expected one. Matched pairs are compared with
    ``_compare_ast_statements``.

    Raises:
        RuntimeError: if an import declaration or specifier is malformed.
        AstComparisonError: on a count mismatch, an unmatched patch specifier, or
            differing specifier contents.
    """
    expression_specifiers_list = _collect_import_specifiers(expression_imports)

    expected_local_name_to_specifier: dict[str, AstNode] = {}
    for specifier in _collect_import_specifiers(expected_imports):
        expected_local_name_to_specifier[_local_import_name(specifier)] = specifier

    if len(expected_local_name_to_specifier) != len(expression_specifiers_list):
        error_report = (
            f"Imports expected size {len(expected_local_name_to_specifier)}"
            f" do not match with patch imports size {len(expression_specifiers_list)}"
        )
        raise AstComparisonError(error_report)

    for specifier in expression_specifiers_list:
        local_import_name = _local_import_name(specifier)
        # Pop so an expected specifier cannot be consumed twice.
        expected_specifier = expected_local_name_to_specifier.pop(local_import_name, None)
        if expected_specifier is None:
            error_report = f"Patch import specifier {local_import_name} do not contained in expected specifiers"
            raise AstComparisonError(error_report)

        _compare_ast_statements(specifier, expected_specifier)
243