#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Helpers for comparing the AST of an evaluated expression script with the AST of an expected script."""

import difflib
import json
from copy import deepcopy
from typing import Any, Callable, Final, TypeVar

from arkdb.compiler_verification import ExpressionEvaluationNames
from arkdb.runnable_module import ScriptFile

T = TypeVar("T")
Transformer = Callable[[T], T]

AstNode = dict[str, Any]
AstNodeOrList = AstNode | list[AstNode]

# NOTE(dslynko): remove this utility when es2panda-generated cctor calls are removed
CCTOR_PREFIX: Final[str] = "_$trigger_cctor$_for_$"


class AstComparisonError(AssertionError):
    """Raised when the AST comparison fails or the evaluation method name cannot be determined."""


class AstComparator:
    """Compare the AST of an expression script with the AST of an expected script.

    Both ASTs are normalized on deep copies before comparison: ``loc`` entries are removed,
    import declarations are split off and compared per specifier, generated evaluation
    names in the expression AST are replaced with canonical ones, and cctor-trigger
    artifacts are stripped from the expected AST.
    """

    def __init__(self, expression: ScriptFile, expected: ScriptFile, expected_imports_base: bool = False):
        self.expression = expression
        self.expected = expected
        # When True, expected-AST imports whose last path component equals the expected
        # file's base name get their source path replaced with "" before comparison.
        self.expected_imports_base = expected_imports_base
        # Set by `_prepare_expression_statements`.
        self._eval_method_name: str | None = None

    def get_eval_method_name(self) -> str:
        """Return the generated evaluation method name found in the expression AST.

        The name is computed lazily by traversing the expression AST on first call.

        Raises:
            RuntimeError: if the expression AST has no ``statements`` list.
            AstComparisonError: if not exactly one non-empty generated method name is found.
        """
        if self._eval_method_name is None:
            # Initialize `_eval_method_name` by traversing expression AST.
            self._prepare_expression_statements()
        if self._eval_method_name is None:
            raise RuntimeError()
        return self._eval_method_name

    def compare(self):
        """Compare the normalized expression AST against the normalized expected AST.

        Raises:
            AstComparisonError: on any mismatch in statements or import specifiers.
        """
        expected_stmts, expected_imports = self._prepare_expected_statements()
        expression_stmts, expression_imports = self._prepare_expression_statements()
        _compare_ast_statements(expression_stmts, expected_stmts)
        _compare_ast_import_decls(expression_imports, expected_imports)

    def _prepare_expected_statements(self) -> tuple[list[AstNode], list[AstNode]]:
        """Return ``(statements, import_declarations)`` from a normalized copy of the expected AST.

        Normalization removes ``loc`` entries, statements whose subtree mentions a name
        starting with ``CCTOR_PREFIX``, and import specifiers binding such a name.
        The original AST is left untouched.

        Raises:
            RuntimeError: if the expected AST has no ``statements`` list.
        """
        statements_list = self.expected.ast.get("statements")
        if not isinstance(statements_list, list):
            raise RuntimeError()
        statements_list = deepcopy(statements_list)

        if self.expected_imports_base:
            base_test_file_name = self.expected.source_file.name
            # Keep only the part before the first "." (e.g. "a.b.c" -> "a"), unless the dot
            # is the first or the last character of the name.
            if (idx := base_test_file_name.find(".")) > 0 and idx < len(base_test_file_name) - 1:
                base_test_file_name = base_test_file_name[:idx]

            statements_filter = _get_import_statements_sources_filter({base_test_file_name: ""})
            statements_list = list(map(statements_filter, statements_list))

        def _find_prefix_recursively(ast_node: Any, prefix: str) -> bool:
            # True if any dict in the subtree has a string "name" starting with `prefix`.
            if isinstance(ast_node, dict):
                if (name := ast_node.get("name")) and isinstance(name, str) and name.startswith(prefix):
                    return True
                return any(_find_prefix_recursively(x, prefix) for x in ast_node.values())
            if isinstance(ast_node, list):
                return any(_find_prefix_recursively(x, prefix) for x in ast_node)
            return False

        def _imports_trigger_cctor(x: dict) -> bool:
            return x.get("type") == "ImportSpecifier" and _find_prefix_recursively(x.get("local"), CCTOR_PREFIX)

        def _remove_cctor_call(ast_node: Any) -> Any:
            # Drop cctor-trigger statements from "statements" lists and cctor-trigger
            # import specifiers from any list in the tree.
            if isinstance(ast_node, dict):
                if (stmts := ast_node.get("statements")) and isinstance(stmts, list):
                    ast_node["statements"] = [x for x in stmts if not _find_prefix_recursively(x, CCTOR_PREFIX)]
                    # Surviving statements contain no CCTOR_PREFIX names at all,
                    # so there is nothing left to remove deeper in this subtree.
                    return ast_node
                for key, value in ast_node.items():
                    ast_node[key] = _remove_cctor_call(value)
            elif isinstance(ast_node, list):
                return [
                    _remove_cctor_call(x) for x in ast_node if not (isinstance(x, dict) and _imports_trigger_cctor(x))
                ]
            return ast_node

        return _split_statements(statements_list, additional_filters=[_remove_cctor_call])

    def _prepare_expression_statements(self) -> tuple[list[AstNode], list[AstNode]]:
        """Return ``(statements, import_declarations)`` from a normalized copy of the expression AST.

        String ``name`` fields matching the generated evaluation-method pattern are replaced
        with ``ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME`` (and recorded), and those
        matching the return-value pattern with ``EVAL_PATCH_RETURN_VALUE``. The single
        recorded method name is stored in ``self._eval_method_name``.

        Raises:
            RuntimeError: if the expression AST has no ``statements`` list.
            AstComparisonError: if not exactly one non-empty generated method name is found.
        """
        statements_list = self.expression.ast.get("statements")
        if not isinstance(statements_list, list):
            raise RuntimeError()
        statements_list = deepcopy(statements_list)

        method_names_candidates: set[str] = set()

        def _replace_in_ast_node(ast_node: AstNode):
            for key, value in ast_node.items():
                if key == "name" and isinstance(value, str):
                    if ExpressionEvaluationNames.EVAL_METHOD_GENERATED_NAME_RE.match(value):
                        method_names_candidates.add(value)
                        ast_node[key] = ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME
                    elif ExpressionEvaluationNames.EVAL_METHOD_RETURN_VALUE_RE.match(value):
                        ast_node[key] = ExpressionEvaluationNames.EVAL_PATCH_RETURN_VALUE
                else:
                    ast_node[key] = replace_generated_names(value)

        def replace_generated_names(ast_node: Any) -> Any:
            if isinstance(ast_node, dict):
                _replace_in_ast_node(ast_node)
            elif isinstance(ast_node, list):
                return list(map(replace_generated_names, ast_node))
            return ast_node

        stmts, imports = _split_statements(
            statements_list,
            additional_filters=[replace_generated_names],
        )

        method_names_list = list(method_names_candidates)
        if len(method_names_list) != 1 or method_names_list[0] == "":
            error_message = f"Failed to find expected evaluation method name; candidates: {method_names_list}"
            raise AstComparisonError(error_message)
        self._eval_method_name = method_names_list[0]

        return stmts, imports


def _del_locations(ast: Any) -> Any:
    """Recursively remove ``loc`` entries; dicts are modified in place, lists are rebuilt."""
    if isinstance(ast, dict):
        # Iterate over a copy because keys are deleted during the loop.
        for key, value in ast.copy().items():
            if key == "loc":
                del ast[key]
            else:
                ast[key] = _del_locations(value)
    elif isinstance(ast, list):
        return list(map(_del_locations, ast))
    return ast


def _split_statements(
    statements_list: list[AstNode], additional_filters: list[Transformer] | None = None
) -> tuple[list[AstNode], list[AstNode]]:
    """Apply transformers to ``statements_list`` and split it into statements and imports.

    Args:
        statements_list: top-level AST statements.
        additional_filters: transformers applied in order before ``_del_locations``.
            The list itself is not modified.

    Returns:
        ``(non_import_statements, import_declarations)``, each in original order.
    """
    # Build a new list so the caller's `additional_filters` is never mutated.
    all_filters: list[Transformer] = [*(additional_filters or []), _del_locations]
    for transform in all_filters:
        statements_list = transform(statements_list)

    statements: list[AstNode] = []
    imports_list: list[AstNode] = []
    for stmt_node in statements_list:
        if stmt_node.get("type") == "ImportDeclaration":
            imports_list.append(stmt_node)
        else:
            statements.append(stmt_node)
    return statements, imports_list


def _get_import_statements_sources_filter(imports_replacement_map: dict[str, str]) -> Callable[[AstNode], AstNode]:
    """Build a transformer that rewrites import source paths by their last path component.

    Args:
        imports_replacement_map: maps the last "/"-separated component of an import path
            to the new full import path.

    Returns:
        A function that updates ``source.value`` of an ``ImportDeclaration`` in place when
        its last path component is in the map, and returns the node. Other nodes are
        returned unchanged.
    """

    def imports_filter(stmt_node: AstNode) -> AstNode:
        if stmt_node.get("type") != "ImportDeclaration":
            return stmt_node

        source_node = stmt_node.get("source")
        if not isinstance(source_node, dict):
            raise RuntimeError()
        import_path = source_node.get("value")
        if not isinstance(import_path, str):
            raise RuntimeError()

        # ArkTS paths always have "/" as delimiter. When there is no "/", `rfind` returns -1
        # and the slice starts at 0, i.e. the whole path is the last component.
        import_path_name = import_path[import_path.rfind("/") + 1 :]
        if (new_import_path := imports_replacement_map.get(import_path_name)) is not None:
            source_node["value"] = new_import_path
        return stmt_node

    return imports_filter


def dump_ast(ast: AstNodeOrList):
    """Return ``ast`` serialized as 4-space-indented JSON, split into lines."""
    return json.dumps(ast, indent=4).splitlines()


def _compare_ast_statements(patched_output: AstNodeOrList, expected_output: AstNodeOrList):
    """Raise ``AstComparisonError`` listing the differing diff lines if the two ASTs differ.

    Lines of the ``difflib.ndiff`` output whose first character is in
    ``ExpressionEvaluationNames.NON_DIFF_MARKS`` are ignored.
    """
    diff = difflib.ndiff(dump_ast(patched_output), dump_ast(expected_output))
    error_list = [x for x in diff if x[0] not in ExpressionEvaluationNames.NON_DIFF_MARKS]
    if error_list:
        raise AstComparisonError("AST comparison failed:\n" + "\n".join(error_list))


def _collect_import_specifiers(import_decls: list[AstNode]) -> list[AstNode]:
    """Return the specifiers of all ``import_decls`` in order, validating each ``specifiers`` field."""
    specifiers: list[AstNode] = []
    for import_decl in import_decls:
        specifiers_list = import_decl.get("specifiers")
        if not isinstance(specifiers_list, list):
            raise RuntimeError()
        specifiers.extend(specifiers_list)
    return specifiers


def _get_specifier_local_name(specifier: AstNode) -> Any:
    """Return ``specifier.local.name``, raising ``RuntimeError`` if ``local`` is not a node."""
    local_identifier_node = specifier.get("local")
    if not isinstance(local_identifier_node, dict):
        raise RuntimeError()
    return local_identifier_node.get("name")


def _compare_ast_import_decls(expression_imports: list[AstNode], expected_imports: list[AstNode]):
    """Compare import specifiers of the expression and expected ASTs, matched by local name.

    Raises:
        RuntimeError: if an import declaration or specifier is malformed.
        AstComparisonError: if the specifier counts differ, an expression specifier has no
            expected counterpart, or matched specifiers differ.
    """
    expression_specifiers_list = _collect_import_specifiers(expression_imports)
    # Expected specifiers are keyed by local name; a later duplicate overrides an earlier one.
    expected_local_name_to_specifier: dict[str, AstNode] = {
        _get_specifier_local_name(specifier): specifier for specifier in _collect_import_specifiers(expected_imports)
    }

    if len(expected_local_name_to_specifier) != len(expression_specifiers_list):
        error_report = (
            f"Imports expected size {len(expected_local_name_to_specifier)}"
            f" do not match with patch imports size {len(expression_specifiers_list)}"
        )
        raise AstComparisonError(error_report)

    for specifier in expression_specifiers_list:
        local_import_name = _get_specifier_local_name(specifier)
        expected_specifier = expected_local_name_to_specifier.get(local_import_name, None)
        if expected_specifier is None:
            error_report = f"Patch import specifier {local_import_name} do not contained in expected specifiers"
            raise AstComparisonError(error_report)

        _compare_ast_statements(specifier, expected_specifier)