#!/usr/bin/env python3
# Copyright (C) 2023 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from dataclasses import dataclass
from typing import List, Union, Callable
from enum import Enum
import re

from google.protobuf import text_format

TestName = str


@dataclass
class Path:
  """File reference by name.

  TestCase resolves a plain Path against the directory of the test index.
  """
  filename: str


@dataclass
class DataPath(Path):
  """File reference that TestCase resolves inside the test `data` directory.

  The `filename` field is inherited from Path.
  """


@dataclass
class Metric:
  """Metric reference; a blueprint whose query is a Metric is a METRIC test."""
  name: str


@dataclass
class Json:
  """Inline contents tagged as JSON."""
  contents: str


@dataclass
class Csv:
  """Inline contents tagged as CSV."""
  contents: str


@dataclass
class TextProto:
  """Inline contents tagged as a text-format proto."""
  contents: str


@dataclass
class BinaryProto:
  """Expected output given as text, to be compared against a proto message."""
  message_type: str
  contents: str
  # Comparing protos is tricky. For example, repeated fields might be written in
  # any order. To help with that you can specify a `post_processing` function
  # that will be called with the actual proto message object before converting
  # it to text representation and doing the comparison with `contents`. This
  # gives us a chance to e.g. sort messages in a repeated field.
  post_processing: Callable = text_format.MessageToString


@dataclass
class Systrace:
  """Inline contents tagged as a systrace."""
  contents: str


class TestType(Enum):
  """Kind of diff test: a query, or a metric."""
  QUERY = 1
  METRIC = 2


# Blueprint for running the diff test. 'query' is being run over data from the
# 'trace' and result will be compared to the 'out'.
# Each test (function in class inheriting from TestSuite) returns a
# DiffTestBlueprint.
@dataclass
class DiffTestBlueprint:
  """Declarative description of one diff test.

  Holds the trace to load, the query (or metric) to run over it and the
  expected output. The `is_*` predicates report which concrete wrapper type
  each of the three fields holds.
  """

  trace: Union[Path, DataPath, Json, Systrace, TextProto]
  query: Union[str, Path, DataPath, Metric]
  out: Union[Path, DataPath, Json, Csv, TextProto, BinaryProto]

  def is_trace_file(self) -> bool:
    """True when the trace is a Path (DataPath included, as a subclass)."""
    return isinstance(self.trace, Path)

  def is_trace_textproto(self) -> bool:
    """True when the trace is an inline TextProto."""
    return isinstance(self.trace, TextProto)

  def is_trace_json(self) -> bool:
    """True when the trace is inline Json."""
    return isinstance(self.trace, Json)

  def is_trace_systrace(self) -> bool:
    """True when the trace is an inline Systrace."""
    return isinstance(self.trace, Systrace)

  def is_query_file(self) -> bool:
    """True when the query is a Path (DataPath included, as a subclass)."""
    return isinstance(self.query, Path)

  def is_metric(self) -> bool:
    """True when the query is a Metric rather than a query string or file."""
    return isinstance(self.query, Metric)

  def is_out_file(self) -> bool:
    """True when the expected output is a Path (DataPath included)."""
    return isinstance(self.out, Path)

  def is_out_json(self) -> bool:
    """True when the expected output is inline Json."""
    return isinstance(self.out, Json)

  def is_out_texproto(self) -> bool:
    """True when the expected output is an inline TextProto.

    The misspelled name is kept because existing callers may use it.
    """
    return isinstance(self.out, TextProto)

  def is_out_textproto(self) -> bool:
    """Correctly spelled alias of `is_out_texproto`."""
    return self.is_out_texproto()

  def is_out_binaryproto(self) -> bool:
    """True when the expected output is a BinaryProto."""
    return isinstance(self.out, BinaryProto)

  def is_out_csv(self) -> bool:
    """True when the expected output is inline Csv."""
    return isinstance(self.out, Csv)


# Description of a diff test. Created in `TestSuite.fetch()`: each test
# (function starting with `test_`) returns DiffTestBlueprint and function name
# is a TestCase name. Used by diff test script.
class TestCase:
  """A named DiffTestBlueprint together with its resolved on-disk paths.

  Attributes set by the constructor:
    name: test name.
    blueprint: the blueprint this test case was built from.
    index_dir: directory used to resolve plain `Path` references.
    test_dir: the directory three levels above `index_dir`; `DataPath`
      references are resolved inside its `data` subdirectory.
    type: TestType.METRIC if the blueprint's query is a metric, otherwise
      TestType.QUERY.
    query_path, trace_path, expected_path: resolved file paths, or None when
      the corresponding blueprint field is not a file reference.
  """

  def __init__(self, name: str, blueprint: 'DiffTestBlueprint',
               index_dir: str) -> None:
    """Stores the inputs and eagerly resolves every file reference.

    Raises:
      AssertionError: if a referenced query, trace or out file is missing.
    """
    self.name = name
    self.blueprint = blueprint
    self.index_dir = index_dir
    self.test_dir = os.path.dirname(os.path.dirname(os.path.dirname(index_dir)))

    if blueprint.is_metric():
      self.type = TestType.METRIC
    else:
      self.type = TestType.QUERY

    self.query_path = self.__get_query_path()
    self.trace_path = self.__get_trace_path()
    self.expected_path = self.__get_out_path()

  def __resolve_path(self, spec: 'Path', kind: str) -> str:
    """Turns a Path/DataPath into a filesystem path and checks it exists.

    Args:
      spec: the file reference taken from the blueprint.
      kind: label used in the error message ('Query', 'Trace' or 'Out').

    Raises:
      AssertionError: if the resolved path does not exist.
    """
    if isinstance(spec, DataPath):
      # Data files live under <test_dir>/data.
      path = os.path.join(self.test_dir, 'data', spec.filename)
    else:
      # Plain paths are relative to the directory of the test index.
      path = os.path.abspath(os.path.join(self.index_dir, spec.filename))

    if not os.path.exists(path):
      raise AssertionError(
          f"{kind} file ({path}) for test '{self.name}' does not exist.")
    return path

  def __get_query_path(self) -> Union[str, None]:
    """Returns the query file path, or None if the query is not a file."""
    if not self.blueprint.is_query_file():
      return None
    return self.__resolve_path(self.blueprint.query, 'Query')

  def __get_trace_path(self) -> Union[str, None]:
    """Returns the trace file path, or None if the trace is not a file."""
    if not self.blueprint.is_trace_file():
      return None
    return self.__resolve_path(self.blueprint.trace, 'Trace')

  def __get_out_path(self) -> Union[str, None]:
    """Returns the expected-output file path, or None if it is not a file."""
    if not self.blueprint.is_out_file():
      return None
    return self.__resolve_path(self.blueprint.out, 'Out')

  # Verifies that the test should be in test suite. If False, test will not be
  # executed.
  def validate(self, name_filter: str) -> bool:
    """Returns True when `name_filter` matches the test name.

    `name_filter` is a regular expression applied with `re.match`, so it is
    anchored at the start of the basename of `self.name`.
    """
    query_metric_pattern = re.compile(name_filter)
    return bool(query_metric_pattern.match(os.path.basename(self.name)))


# Virtual class responsible for fetching diff tests.
# All functions with name starting with `test_` have to return
# DiffTestBlueprint and function name is a test name. All DiffTestModules have
# to be included in `test/diff_tests/trace_processor/include_index.py`.
# `fetch` function should not be overwritten.
class TestSuite:
  """Base class for a group of diff tests defined as `test_*` methods."""

  def __init__(self, include_index_dir: str, dir_name: str,
               class_name: str) -> None:
    """Records the suite location and the prefix used in test names.

    Args:
      include_index_dir: directory that `dir_name` is joined onto.
      dir_name: subdirectory of this suite; `index_dir` is the join of both.
      class_name: prefix placed before ':' in every generated test name.
    """
    self.dir_name = dir_name
    self.index_dir = os.path.join(include_index_dir, dir_name)
    self.class_name = class_name

  def __test_name(self, method_name):
    """Builds '<class_name>:<suffix>' from a `test_<suffix>` method name."""
    return f"{self.class_name}:{method_name.split('test_',1)[1]}"

  def fetch(self) -> List['TestCase']:
    """Calls every bound method named `test_*` and wraps each result.

    Returns:
      One TestCase per `test_*` method, built from the blueprint it returns.
    """
    attrs = (getattr(self, name) for name in dir(self))
    methods = [attr for attr in attrs if inspect.ismethod(attr)]
    return [
        TestCase(self.__test_name(method.__name__), method(), self.index_dir)
        for method in methods
        if method.__name__.startswith('test_')
    ]