• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import difflib
19import re
20from typing import Iterable
21
22from arkdb.compiler_verification import ExpressionEvaluationNames
23from arkdb.runnable_module import ScriptFile
24
25
class BytecodeComparisonError(AssertionError):
    """Raised when the evaluation patch bytecode does not match the expected bytecode."""
28
29
def _function_decl_pattern(name: str) -> str:
    """Build a regex matching a disasm declaration of ``<name>.<name>.<name>()``.

    The name is passed through ``re.escape`` so regex metacharacters in it (for
    example ``$``, which the return-type character class explicitly allows) are
    matched literally instead of being interpreted as regex syntax.
    """
    escaped = re.escape(name)
    return rf"^\.function [a-zA-Z_\d\$\-\.]+ {escaped}\.{escaped}\.{escaped}\(\)"


class BytecodeComparator:
    """Compares the disassembled evaluation patch function against an expected function."""

    # Declaration pattern of the expected function, which is always named after
    # the fixed evaluation patch function name.
    EXPECTED_FUNC_DECL_PATTERN = _function_decl_pattern(ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME)

    def __init__(self, expression: ScriptFile, expected: ScriptFile, base_module_name: str, eval_method_name: str):
        """
        :param expression: script whose ``disasm_file`` contains the evaluation patch function.
        :param expected: script whose ``disasm_file`` contains the expected function.
        :param base_module_name: module name whose ``"<name>."`` qualifier is stripped from
            both function bodies before comparison.
        :param eval_method_name: name used for the patch function declaration
            (``<name>.<name>.<name>()``) in the expression disassembly.
        """
        self.expression = expression
        self.expected = expected
        self.base_module_name = base_module_name
        self.eval_method_name = eval_method_name
        self.patch_func_pattern = _function_decl_pattern(self.eval_method_name)

    def compare(self):
        """
        Compares two bytecode files according to disasm output.

        The expected function and the evaluation patch function are extracted from
        the respective disassembly files, normalized (qualifiers stripped, patch
        function renamed), and compared line by line with ``difflib.ndiff``.

        :raises BytecodeComparisonError: if either function is not found (or is empty),
            if the two bodies have a different number of lines, or if ``ndiff`` reports
            any line whose marker is not in ``ExpressionEvaluationNames.NON_DIFF_MARKS``.
        """
        method_name = ExpressionEvaluationNames.EVAL_PATCH_FUNCTION_NAME
        with self.expression.disasm_file.open() as expression_file:
            with self.expected.disasm_file.open() as expected_file:
                error_report = ""

                expected_func_body = _fetch_bytecode_function(
                    expected_file.readlines(),
                    BytecodeComparator.EXPECTED_FUNC_DECL_PATTERN,
                    [f"{method_name}.", f"{self.base_module_name}."],
                )
                if not expected_func_body:
                    error_report = "Expected bytecode function was not found or empty."
                else:
                    # Stripping the "<method_name>." prefix also collapsed the declaration
                    # "<m>.<m>.<m>()" into "<m>()"; restore the fully qualified name so it
                    # lines up with the patch declaration.
                    expected_func_body[0] = expected_func_body[0].replace(
                        f" {method_name}()",
                        f" {method_name}.{method_name}.{method_name}()",
                    )

                patch_func_body = _fetch_bytecode_function(
                    expression_file.readlines(),
                    self.patch_func_pattern,
                    [f"{self.base_module_name}."],
                )
                if not patch_func_body:
                    error_report += "\tEvaluation patch bytecode function was not found"

                # Compare line counts only when both bodies were found.
                if not error_report and len(expected_func_body) != len(patch_func_body):
                    error_report = "Expected and patch bytecode differ in count"

                if error_report:
                    raise BytecodeComparisonError(error_report)

                # Rename the patch function so its declaration matches the expected one.
                patch_func_body[0] = patch_func_body[0].replace(
                    self.eval_method_name,
                    method_name,
                )
                diff = difflib.ndiff(expected_func_body, patch_func_body)
                error_list = [x for x in diff if x[0] not in ExpressionEvaluationNames.NON_DIFF_MARKS]
                if error_list:
                    raise BytecodeComparisonError("Bytecode comparison failed:\n" + "\n".join(error_list))
101
102
103def _fetch_bytecode_function(
104    bytecode: list[str],
105    function_decl_pattern: str,
106    prefixes: Iterable[str],
107) -> list[str]:
108    func_body: list[str] = []
109    start_idx: int | None = None
110
111    for idx, line in enumerate(bytecode):
112        if re.match(function_decl_pattern, line):
113            start_idx = idx
114            break
115
116    if start_idx is not None:
117        for line in bytecode[start_idx:]:
118            func_body.append(_remove_prefix(line, prefixes))
119            if line == "}\n":
120                return func_body
121
122    return []
123
124
125def _remove_prefix(line: str, prefixes: Iterable[str]) -> str:
126    for p in prefixes:
127        line = line.replace(p, "")
128    return line
129