• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (c) 2024 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import re
19from dataclasses import dataclass
20from pathlib import Path
21from typing import Any, Protocol
22
23import yaml
24from pytest import fixture
25
26from arkdb.compiler import AstParser
27from arkdb.compiler_verification.ast import AstComparator
28from arkdb.compiler_verification.bytecode import BytecodeComparator
29from arkdb.disassembler import ScriptFileDisassembler
30from arkdb.logs import RichLogger
31from arkdb.runnable_module import ScriptFile
32
33
class MetadataLoader(Protocol):
    """Structural type of a callable that maps a file path to a metadata dictionary."""

    def __call__(self, path: Path) -> dict[str, Any]:
        """Return the metadata dictionary obtained for ``path``."""
        ...
40
41
@fixture
def metadata_loader() -> MetadataLoader:
    """Provide a loader for the YAML metadata embedded in a source file.

    The returned callable reads the file, collects the text of every
    ``/*--- ... ---*/`` section, joins the sections with newlines and parses
    the result with ``yaml.safe_load``. When parsing yields ``None`` (for
    example, the file contains no such section) an empty dictionary is
    returned instead.
    """
    # Compiled once per fixture instance rather than on every load;
    # DOTALL lets a metadata section span several lines.
    metadata_pattern = re.compile(r"(?<=\/\*---)(.*?)(?=---\*\/)", flags=re.DOTALL)

    def load(path: Path) -> dict[str, Any]:
        # Decode explicitly as UTF-8 so the parsed metadata does not depend
        # on the locale's default encoding.
        data = path.read_text(encoding="utf-8")
        yaml_text = "\n".join(metadata_pattern.findall(data))
        metadata = yaml.safe_load(yaml_text)
        # Normalise the "nothing parsed" result to an empty mapping.
        return {} if metadata is None else metadata

    return load
54
55
@dataclass
class ExpressionComparatorOptions:
    """Switches controlling how an evaluated expression is compared with its expected form.

    Every option defaults to ``False``.
    """

    # When set, the AST comparison step is skipped.
    disable_ast_comparison: bool = False
    # When set, the bytecode comparison step is skipped.
    disable_bytecode_comparison: bool = False
    # Forwarded as-is to the AST comparator.
    # NOTE(review): exact meaning is defined by AstComparator - confirm there.
    expected_imports_base: bool = False
61
62
class OptionsLoader(Protocol):
    """Structural type of a callable that maps a file path to comparator options."""

    def __call__(self, path: Path) -> ExpressionComparatorOptions:
        """Return the ``ExpressionComparatorOptions`` obtained for ``path``."""
        ...
69
70
@fixture
def eval_options_loader(
    metadata_loader: MetadataLoader,
) -> OptionsLoader:
    """Provide a loader that builds comparator options from a file's metadata.

    The returned callable obtains the metadata of the given path through
    ``metadata_loader`` and passes the entries of its ``evaluation`` section
    as keyword arguments to ``ExpressionComparatorOptions``. A missing
    section results in the default options.

    The returned callable raises ``RuntimeError`` when the ``evaluation``
    entry is present but is not a mapping.
    """

    def load(path: Path) -> ExpressionComparatorOptions:
        all_options = metadata_loader(path)
        eval_options = all_options.get("evaluation", {})
        if not isinstance(eval_options, dict):
            # Name the file and the offending type so a malformed section
            # can be located without a debugger.
            raise RuntimeError(
                f"'evaluation' metadata in {path} must be a mapping, "
                f"got {type(eval_options).__name__}"
            )
        return ExpressionComparatorOptions(**eval_options)

    return load
83
84
class ExpressionFileComparator(Protocol):
    """Structural type of a callable that takes base, expression and expected script files."""

    def __call__(self, base: ScriptFile, expression: ScriptFile, expected: ScriptFile):
        """Run the comparison for the given ``base``, ``expression`` and ``expected`` files."""
        ...
93
94
@fixture
def expression_file_comparator(
    script_disassembler: ScriptFileDisassembler,
    eval_options_loader: OptionsLoader,
    log: RichLogger,
) -> ExpressionFileComparator:
    """Provide a callable that compares an expression with its expected counterpart.

    The returned callable loads the comparison options from the source file
    of ``base``, disassembles ``expression`` and ``expected``, and then runs
    the AST and bytecode comparisons unless the options disable them.
    Any exception raised by the AST comparison is logged as a warning and
    does not stop the bytecode comparison; exceptions raised by the bytecode
    comparison propagate to the caller.
    """

    def verify(
        base: ScriptFile,
        expression: ScriptFile,
        expected: ScriptFile,
    ):
        options = eval_options_loader(base.source_file)
        actual_disasm = script_disassembler.disassemble(expression)
        expected_disasm = script_disassembler.disassemble(expected)

        # The AST comparator is always created: the bytecode step below asks
        # it for the eval method name even when the AST check is disabled.
        ast_comparator = AstComparator(
            actual_disasm,
            expected_disasm,
            options.expected_imports_base,
        )
        if not options.disable_ast_comparison:
            try:
                ast_comparator.compare()
            except Exception as error:
                log.warning("Expression AST mismatch: %s", error)

        if options.disable_bytecode_comparison:
            return

        base_stem = base.source_file.stem
        eval_method_name = ast_comparator.get_eval_method_name()
        BytecodeComparator(
            actual_disasm,
            expected_disasm,
            base_stem,
            eval_method_name,
        ).compare()

    return verify
126
127
class ExpressionVerifier(Protocol):
    """Structural type of a callable that accepts an expression and exposes an AST parser."""

    def __call__(self, expression: ScriptFile) -> None:
        """Process the given ``expression`` script file; nothing is returned."""
        ...

    # Read-only field
    @property
    def ast_parser(self) -> AstParser:
        """The ``AstParser`` exposed by the verifier."""
        ...
140
141
def get_expression_verifier(
    verifier: ExpressionFileComparator,
    base: ScriptFile,
    expected: ScriptFile,
    ast_parser: AstParser,
) -> ExpressionVerifier:
    """Bind ``base`` and ``expected`` to ``verifier`` and attach an AST parser.

    The returned object is callable with a single expression script file;
    each call invokes ``verifier(base, expression, expected)``. The given
    ``ast_parser`` is available through the read-only ``ast_parser``
    property of the returned object.
    """

    class WithExpectedComparator:
        """Callable wrapper closing over ``verifier``, ``base`` and ``expected``."""

        def __init__(self, parser: AstParser):
            self._ast_parser = parser

        @property
        def ast_parser(self):
            """Parser supplied at construction time."""
            return self._ast_parser

        # CC-OFFNXT(G.CLS.07) followed required interface
        def __call__(self, expression: ScriptFile):
            verifier(base, expression, expected)

    bound_verifier = WithExpectedComparator(ast_parser)
    return bound_verifier
161