• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright (C) 2023 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import inspect
17import os
18from dataclasses import dataclass
19from typing import List, Union, Callable
20from enum import Enum
21import re
22
23from google.protobuf import text_format
24
# Type alias: a test name is a plain `str`.
TestName = str
26
27
# Reference to a file by name. When used in a DiffTestBlueprint, TestCase
# resolves a plain Path relative to the test's index directory.
@dataclass
class Path:
  filename: str
31
32
# Path subclass for files kept in the shared data directory: TestCase resolves
# it under `<test_dir>/data` instead of the index directory. Because it
# subclasses Path, `isinstance(x, Path)` is also true for a DataPath.
@dataclass
class DataPath(Path):
  filename: str
36
37
# Identifies a metric by name. A DiffTestBlueprint whose `query` is a Metric
# makes the resulting TestCase a TestType.METRIC test.
@dataclass
class Metric:
  name: str
41
42
# JSON text held in `contents`. Accepted by DiffTestBlueprint as either the
# `trace` or the `out`.
@dataclass
class Json:
  contents: str
46
47
# CSV text held in `contents`. Accepted by DiffTestBlueprint as the `out`.
@dataclass
class Csv:
  contents: str
51
52
# Text-format proto held in `contents`. Accepted by DiffTestBlueprint as either
# the `trace` or the `out`.
@dataclass
class TextProto:
  contents: str
56
57
# Binary-proto expectation; accepted by DiffTestBlueprint only as the `out`.
# NOTE(review): `message_type` presumably names the proto message used to
# decode the actual output and `contents` is its expected text representation
# -- confirm against the diff test runner.
@dataclass
class BinaryProto:
  message_type: str
  contents: str
  # Comparing protos is tricky. For example, repeated fields might be written in
  # any order. To help with that you can specify a `post_processing` function
  # that will be called with the actual proto message object before converting
  # it to text representation and doing the comparison with `contents`. This
  # gives us a chance to e.g. sort messages in a repeated field.
  # Defaults to `text_format.MessageToString`.
  post_processing: Callable = text_format.MessageToString
68
69
# Systrace text held in `contents`. Accepted by DiffTestBlueprint as the
# `trace`.
@dataclass
class Systrace:
  contents: str
73
74
# Kind of diff test. TestCase assigns METRIC when the blueprint's query is a
# Metric and QUERY otherwise.
class TestType(Enum):
  QUERY = 1
  METRIC = 2
78
79
# Blueprint for running the diff test. 'query' is run over the data from
# 'trace' and the result is compared to 'out'. Each test (function in a class
# inheriting from TestSuite) returns a DiffTestBlueprint.
@dataclass
class DiffTestBlueprint:

  # What to run the test on: a file reference (Path/DataPath) or an inline
  # Json, Systrace or TextProto value.
  trace: Union[Path, DataPath, Json, Systrace, TextProto]
  # What to run: a plain string, a file reference (Path/DataPath) or a Metric.
  query: Union[str, Path, DataPath, Metric]
  # What to compare the result against: a file reference (Path/DataPath) or an
  # inline Json, Csv, TextProto or BinaryProto value.
  out: Union[Path, DataPath, Json, Csv, TextProto, BinaryProto]

  # True when the trace is a file reference. DataPath subclasses Path, so this
  # holds for both.
  def is_trace_file(self):
    return isinstance(self.trace, Path)

  # True when the trace is given as a TextProto.
  def is_trace_textproto(self):
    return isinstance(self.trace, TextProto)

  # True when the trace is given as Json.
  def is_trace_json(self):
    return isinstance(self.trace, Json)

  # True when the trace is given as a Systrace.
  def is_trace_systrace(self):
    return isinstance(self.trace, Systrace)

  # True when the query is a file reference (Path or DataPath).
  def is_query_file(self):
    return isinstance(self.query, Path)

  # True when the query is a Metric rather than a query string or file.
  def is_metric(self):
    return isinstance(self.query, Metric)

  # True when the expected output is a file reference (Path or DataPath).
  def is_out_file(self):
    return isinstance(self.out, Path)

  # True when the expected output is given as Json.
  def is_out_json(self):
    return isinstance(self.out, Json)

  # True when the expected output is given as a TextProto.
  def is_out_textproto(self):
    return isinstance(self.out, TextProto)

  # Misspelled historical name of `is_out_textproto`; kept so existing callers
  # continue to work.
  def is_out_texproto(self):
    return self.is_out_textproto()

  # True when the expected output is given as a BinaryProto.
  def is_out_binaryproto(self):
    return isinstance(self.out, BinaryProto)

  # True when the expected output is given as Csv.
  def is_out_csv(self):
    return isinstance(self.out, Csv)
122
123
# Description of a diff test. Created in `fetch()` in TestSuite: each test
# (function starting with `test_`) returns a DiffTestBlueprint, and the
# function name is used for the TestCase name. Used by the diff test
# script.
class TestCase:

  # Shared implementation of the three path getters below.
  #
  # `source` is a blueprint member already known to be a Path (or DataPath) and
  # `kind` is the capitalised label ('Query', 'Trace', 'Out') used in the error
  # message. A DataPath is looked up under `<test_dir>/data`; any other Path is
  # resolved against `index_dir` and made absolute.
  #
  # Raises AssertionError if the resolved path does not exist.
  def __resolve_path(self, source: 'Path', kind: str) -> str:
    if isinstance(source, DataPath):
      path = os.path.join(self.test_dir, 'data', source.filename)
    else:
      path = os.path.abspath(os.path.join(self.index_dir, source.filename))

    if not os.path.exists(path):
      raise AssertionError(
          f"{kind} file ({path}) for test '{self.name}' does not exist.")
    return path

  # Path of the query file, or None when the query is not a file reference.
  def __get_query_path(self) -> Union[str, None]:
    if not self.blueprint.is_query_file():
      return None
    return self.__resolve_path(self.blueprint.query, 'Query')

  # Path of the trace file, or None when the trace is not a file reference.
  def __get_trace_path(self) -> Union[str, None]:
    if not self.blueprint.is_trace_file():
      return None
    return self.__resolve_path(self.blueprint.trace, 'Trace')

  # Path of the expected-output file, or None when the expected output is not
  # a file reference.
  def __get_out_path(self) -> Union[str, None]:
    if not self.blueprint.is_out_file():
      return None
    return self.__resolve_path(self.blueprint.out, 'Out')

  # Builds a test case from its name, its blueprint and the directory of the
  # index that declared it. Any file referenced by the blueprint is resolved
  # immediately, so a missing file raises AssertionError at construction time.
  def __init__(self, name: str, blueprint: 'DiffTestBlueprint',
               index_dir: str) -> None:
    self.name = name
    self.blueprint = blueprint
    self.index_dir = index_dir
    # The test directory is three levels above the index directory.
    self.test_dir = os.path.dirname(os.path.dirname(os.path.dirname(index_dir)))

    if blueprint.is_metric():
      self.type = TestType.METRIC
    else:
      self.type = TestType.QUERY

    self.query_path = self.__get_query_path()
    self.trace_path = self.__get_trace_path()
    self.expected_path = self.__get_out_path()

  # Verifies that the test should be in test suite. If False, test will not be
  # executed. `name_filter` is a regular expression applied with `match`, so it
  # is anchored at the start of the basename of the test name.
  def validate(self, name_filter: str) -> bool:
    query_metric_pattern = re.compile(name_filter)
    return bool(query_metric_pattern.match(os.path.basename(self.name)))
196
197
# Virtual class responsible for fetching diff tests.
# All functions with a name starting with `test_` have to return a
# DiffTestBlueprint; the function name is the test name. All DiffTestModules
# have to be included in `test/diff_tests/trace_processor/include_index.py`.
# The `fetch` function should not be overridden.
class TestSuite:

  # Records the suite's directory name and class name, and derives the index
  # directory by joining `include_index_dir` with `dir_name`.
  def __init__(self, include_index_dir: str, dir_name: str,
               class_name: str) -> None:
    self.class_name = class_name
    self.dir_name = dir_name
    self.index_dir = os.path.join(include_index_dir, dir_name)

  # Builds the "<class_name>:<suffix>" test name, where the suffix is whatever
  # follows the first 'test_' in `method_name`.
  def __test_name(self, method_name):
    suffix = method_name.split('test_', 1)[1]
    return f"{self.class_name}:{suffix}"

  # Returns one TestCase per bound method whose name starts with 'test_'. Each
  # such method is called with no arguments to obtain its blueprint.
  def fetch(self) -> List['TestCase']:
    # First pass: collect every bound method before any test method is called.
    bound_methods = []
    for attr_name in dir(self):
      candidate = getattr(self, attr_name)
      if inspect.ismethod(candidate):
        bound_methods.append(candidate)

    # Second pass: turn the test methods into test cases.
    cases = []
    for method in bound_methods:
      if not method.__name__.startswith('test_'):
        continue
      cases.append(
          TestCase(self.__test_name(method.__name__), method(), self.index_dir))
    return cases
216    methods = [attr for attr in attrs if inspect.ismethod(attr)]
217    return [
218        TestCase(self.__test_name(method.__name__), method(), self.index_dir)
219        for method in methods
220        if method.__name__.startswith('test_')
221    ]
222