#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Protocol

import yaml
from pytest import fixture

from arkdb.compiler import AstParser
from arkdb.compiler_verification.ast import AstComparator
from arkdb.compiler_verification.bytecode import BytecodeComparator
from arkdb.disassembler import ScriptFileDisassembler
from arkdb.logs import RichLogger
from arkdb.runnable_module import ScriptFile

# Matches the text between `/*---` and `---*/` markers (non-greedy, across newlines).
# The lookarounds keep the markers themselves out of each match.
# Compiled once at import time rather than on every metadata load.
_METADATA_PATTERN = re.compile(r"(?<=\/\*---)(.*?)(?=---\*\/)", flags=re.DOTALL)


class MetadataLoader(Protocol):
    """Callable that reads the YAML metadata embedded in a source file."""

    def __call__(
        self,
        path: Path,
    ) -> dict[str, Any]:
        ...


@fixture
def metadata_loader() -> MetadataLoader:
    """Provide a function that parses the YAML metadata blocks of a file.

    Every ``/*--- ... ---*/`` block in the file is extracted; the blocks are
    joined with newlines and parsed as one YAML document.
    """

    def load(path: Path) -> dict[str, Any]:
        """Return the metadata mapping of *path*, or ``{}`` when it has none.

        Raises:
            RuntimeError: if the metadata parses to something other than a mapping.
        """
        # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
        data = path.read_text(encoding="utf-8")
        yaml_text = "\n".join(_METADATA_PATTERN.findall(data))
        # safe_load: metadata is plain data, never arbitrary Python objects.
        metadata = yaml.safe_load(yaml_text)
        if metadata is None:
            # No metadata blocks, or only empty ones.
            return {}
        if not isinstance(metadata, dict):
            raise RuntimeError(f"Metadata in {path} must be a YAML mapping, got {type(metadata).__name__}")
        return metadata

    return load


@dataclass
class ExpressionComparatorOptions:
    """Per-test switches read from the ``evaluation`` metadata section."""

    # Skip the AST comparison of the expression against the expected file.
    disable_ast_comparison: bool = False
    # Skip the bytecode comparison of the expression against the expected file.
    disable_bytecode_comparison: bool = False
    # Passed through as the third argument of ``AstComparator``.
    expected_imports_base: bool = False


class OptionsLoader(Protocol):
    """Callable that builds ``ExpressionComparatorOptions`` for a source file."""

    def __call__(
        self,
        path: Path,
    ) -> ExpressionComparatorOptions:
        ...


@fixture
def eval_options_loader(
    metadata_loader: MetadataLoader,
) -> OptionsLoader:
    """Provide a function that builds comparator options from file metadata."""

    def load(path: Path) -> ExpressionComparatorOptions:
        """Return the options from the ``evaluation`` metadata section of *path*.

        A missing or empty ``evaluation`` section yields the default options.

        Raises:
            RuntimeError: if ``evaluation`` is not a mapping or contains unknown keys.
        """
        all_options = metadata_loader(path)
        eval_options = all_options.get("evaluation")
        if eval_options is None:
            # Key absent, or a bare `evaluation:` line that YAML parses as null.
            eval_options = {}
        if not isinstance(eval_options, dict):
            raise RuntimeError(
                f"'evaluation' metadata in {path} must be a mapping, got {type(eval_options).__name__}"
            )
        known = {option.name for option in fields(ExpressionComparatorOptions)}
        unknown = sorted(str(key) for key in eval_options if key not in known)
        if unknown:
            raise RuntimeError(f"Unknown 'evaluation' options in {path}: {', '.join(unknown)}")
        return ExpressionComparatorOptions(**eval_options)

    return load


class ExpressionFileComparator(Protocol):
    """Callable that checks a compiled expression against the expected file."""

    def __call__(
        self,
        base: ScriptFile,
        expression: ScriptFile,
        expected: ScriptFile,
    ) -> None:
        ...


@fixture
def expression_file_comparator(
    script_disassembler: ScriptFileDisassembler,
    eval_options_loader: OptionsLoader,
    log: RichLogger,
) -> ExpressionFileComparator:
    """Provide a function comparing an evaluated expression with its expected file.

    Options come from the metadata of ``base.source_file``. Both the expression
    and the expected file are disassembled and then, unless disabled by the
    options, compared by AST (a failure is only logged as a warning) and by
    bytecode (any exception from ``BytecodeComparator.compare`` propagates).
    """

    def verify(
        base: ScriptFile,
        expression: ScriptFile,
        expected: ScriptFile,
    ) -> None:
        options = eval_options_loader(base.source_file)
        # Separate names: the disassembly results are not ScriptFile objects.
        disasm_expression = script_disassembler.disassemble(expression)
        disasm_expected = script_disassembler.disassemble(expected)

        # Built unconditionally: the bytecode check below needs its eval method name
        # even when the AST comparison itself is disabled.
        ast_comparator = AstComparator(disasm_expression, disasm_expected, options.expected_imports_base)
        if not options.disable_ast_comparison:
            try:
                ast_comparator.compare()
            except Exception as e:
                # An AST mismatch is reported but intentionally does not fail the check.
                log.warning("Expression AST mismatch: %s", e)
        if not options.disable_bytecode_comparison:
            bytecode_comparator = BytecodeComparator(
                disasm_expression,
                disasm_expected,
                base.source_file.stem,
                ast_comparator.get_eval_method_name(),
            )
            bytecode_comparator.compare()

    return verify


class ExpressionVerifier(Protocol):
    """Callable verifying one expression file, exposing its AST parser."""

    def __call__(
        self,
        expression: ScriptFile,
    ) -> None:
        ...

    # Read-only field
    @property
    def ast_parser(self) -> AstParser:
        ...


def get_expression_verifier(
    verifier: ExpressionFileComparator,
    base: ScriptFile,
    expected: ScriptFile,
    ast_parser: AstParser,
) -> ExpressionVerifier:
    """Bind *base* and *expected* to *verifier*, producing an ``ExpressionVerifier``.

    The returned object forwards each call as ``verifier(base, expression, expected)``
    and exposes *ast_parser* through a read-only ``ast_parser`` property.
    """

    class WithExpectedComparator:
        def __init__(self, parser: AstParser):
            self._ast_parser = parser

        # CC-OFFNXT(G.CLS.07) followed required interface
        def __call__(self, expression: ScriptFile) -> None:
            verifier(base, expression, expected)

        @property
        def ast_parser(self) -> AstParser:
            return self._ast_parser

    return WithExpectedComparator(ast_parser)