#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import difflib
import re
from typing import Iterable

from arkdb.compiler_verification import ExpressionEvaluationNames
from arkdb.runnable_module import ScriptFile


def _build_func_decl_pattern(method_name: str) -> str:
    """
    Build a regex matching a disasm line ``.function <token> <name>.<name>.<name>()``.

    ``<token>`` is any run of letters, digits, ``_``, ``$``, ``-`` or ``.``.
    The method name is passed through ``re.escape`` so characters such as ``$``
    or ``.`` in it are matched literally rather than as regex syntax; for plain
    identifiers the resulting pattern is unchanged.
    """
    name = re.escape(method_name)
    return r"^\.function [a-zA-Z_\d\$\-\.]+ " + rf"{name}\.{name}\.{name}\(\)"


class BytecodeComparisonError(AssertionError):
    """Raised when the evaluation patch bytecode does not match the expected bytecode."""


class BytecodeComparator:
    """
    Compare the evaluation patch function in one disassembly against an expected function in another.
    """

    # Declaration regex for the expected function, named
    # EVAL_PATCH_FUNCTION_NAME.EVAL_PATCH_FUNCTION_NAME.EVAL_PATCH_FUNCTION_NAME().
    EXPECTED_FUNC_DECL_PATTERN = _build_func_decl_pattern(ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME)

    def __init__(self, expression: ScriptFile, expected: ScriptFile, base_module_name: str, eval_method_name: str):
        """
        :param expression: script whose ``disasm_file`` contains the evaluation patch function.
        :param expected: script whose ``disasm_file`` contains the expected function.
        :param base_module_name: name stripped (together with a trailing ``.``) from all compared lines.
        :param eval_method_name: name of the patch function, searched as
            ``<eval_method_name>.<eval_method_name>.<eval_method_name>()``.
        """
        self.expression = expression
        self.expected = expected
        self.base_module_name = base_module_name
        self.eval_method_name = eval_method_name
        self.patch_func_pattern = _build_func_decl_pattern(eval_method_name)

    def compare(self) -> None:
        """
        Compares two bytecode files according to disasm output.

        :raises BytecodeComparisonError: if either function is not found, their line
            counts differ, or ``difflib.ndiff`` yields any line whose first character is
            not in ``ExpressionEvaluationNames.NON_DIFF_MARKS``.
        """
        with self.expression.disasm_file.open() as expression_file:
            expression_lines = expression_file.readlines()
        with self.expected.disasm_file.open() as expected_file:
            expected_lines = expected_file.readlines()

        eval_name = ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME
        error_report = ""

        expected_func_body = _fetch_bytecode_function(
            expected_lines,
            BytecodeComparator.EXPECTED_FUNC_DECL_PATTERN,
            [f"{eval_name}.", f"{self.base_module_name}."],
        )
        if not expected_func_body:
            error_report = "Expected bytecode function was not found or empty."
        else:
            # Removing the "<eval_name>." prefix also collapsed the declared name
            # "<eval_name>.<eval_name>.<eval_name>()" to "<eval_name>()"; restore it.
            expected_func_body[0] = expected_func_body[0].replace(
                f" {eval_name}()",
                f" {eval_name}.{eval_name}.{eval_name}()",
            )

        patch_func_body = _fetch_bytecode_function(
            expression_lines,
            self.patch_func_pattern,
            [f"{self.base_module_name}."],
        )
        if not patch_func_body:
            error_report += "\tEvaluation patch bytecode function was not found"

        if not error_report and len(expected_func_body) != len(patch_func_body):
            error_report = "Expected and patch bytecode differ in count"

        if error_report:
            raise BytecodeComparisonError(error_report)

        # Rename the patch function in its declaration line so it can be compared
        # with the expected declaration.
        patch_func_body[0] = patch_func_body[0].replace(self.eval_method_name, eval_name)

        diff = difflib.ndiff(expected_func_body, patch_func_body)
        error_list = [line for line in diff if line[0] not in ExpressionEvaluationNames.NON_DIFF_MARKS]
        if error_list:
            raise BytecodeComparisonError("Bytecode comparison failed:\n" + "\n".join(error_list))


def _fetch_bytecode_function(
    bytecode: list[str],
    function_decl_pattern: str,
    prefixes: Iterable[str],
) -> list[str]:
    """
    Extract one function from disassembly lines, removing the given prefixes.

    :param bytecode: disassembly lines, e.g. as returned by ``readlines()``.
    :param function_decl_pattern: regex applied with ``re.match`` to each line to
        locate the function declaration; the first matching line is used.
    :param prefixes: substrings removed from every returned line (see ``_remove_prefix``).
    :return: the lines from the declaration up to and including the first line that
        is exactly ``}`` (optionally followed by a line terminator), with prefixes
        removed; an empty list if no declaration matches or no closing line follows it.
    """
    # Materialize once: a one-shot iterator would otherwise be exhausted after the first line.
    prefixes = tuple(prefixes)
    decl_re = re.compile(function_decl_pattern)

    start_idx = next((idx for idx, line in enumerate(bytecode) if decl_re.match(line)), None)
    if start_idx is None:
        return []

    func_body: list[str] = []
    for line in bytecode[start_idx:]:
        func_body.append(_remove_prefix(line, prefixes))
        # Also accept a closing "}" on the file's last line with no trailing newline.
        if line.rstrip("\r\n") == "}":
            return func_body

    return []


def _remove_prefix(line: str, prefixes: Iterable[str]) -> str:
    """
    Remove every occurrence of each string in *prefixes* from *line*, in the given order.

    Despite the name, occurrences are removed anywhere in the line, not only at its start.
    """
    for prefix in prefixes:
        line = line.replace(prefix, "")
    return line