#!/usr/bin/env python3
# Copyright (C) 2023 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import datetime
import difflib
import os
import subprocess
import sys
import tempfile
from binascii import unhexlify
from dataclasses import dataclass
from typing import List, Tuple, Optional

from google.protobuf import text_format
from python.generators.diff_tests.testing import TestCase, TestType, BinaryProto
from python.generators.diff_tests.utils import (
    ColorFormatter, create_message_factory, get_env, get_trace_descriptor_path,
    read_all_tests, serialize_python_trace, serialize_textproto_trace,
    modify_trace)

ROOT_DIR = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


# Performance result of running the test.
@dataclass
class PerfResult:
  test: TestCase
  ingest_time_ns: int
  real_time_ns: int

  def __init__(self, test: TestCase, perf_lines: List[str]):
    self.test = test

    assert len(perf_lines) == 1
    perf_numbers = perf_lines[0].split(',')

    assert len(perf_numbers) == 2
    self.ingest_time_ns = int(perf_numbers[0])
    self.real_time_ns = int(perf_numbers[1])
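
    # Illustrative (assumed) perf-file contents matched by the parsing above:
    # a single line holding two comma-separated nanosecond counts, e.g.
    #   "1200000,3400000"  ->  ingest_time_ns=1200000, real_time_ns=3400000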


# Data gathered from running the test.
@dataclass
class TestResult:
  test: TestCase
  trace: str
  cmd: List[str]
  expected: str
  actual: str
  passed: bool
  stderr: str
  exit_code: int
  perf_result: Optional[PerfResult]

  def __init__(self, test: TestCase, gen_trace_path: str, cmd: List[str],
               expected_text: str, actual_text: str, stderr: str,
               exit_code: int, perf_lines: List[str]) -> None:
    self.test = test
    self.trace = gen_trace_path
    self.cmd = cmd
    self.stderr = stderr
    self.exit_code = exit_code

    # For better string formatting we often add whitespace, which now has to
    # be removed.
    def strip_whitespaces(text: str):
      no_front_new_line_text = text.lstrip('\n')
      return '\n'.join(s.strip() for s in no_front_new_line_text.split('\n'))
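    # For example (illustrative): strip_whitespaces('\n  "a"\n  "b"\n')
    # returns '"a"\n"b"\n': the leading newline and per-line padding go away.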

    self.expected = strip_whitespaces(expected_text)
    self.actual = strip_whitespaces(actual_text)

    expected_content = self.expected.replace('\r\n', '\n')
    actual_content = self.actual.replace('\r\n', '\n')
    self.passed = (expected_content == actual_content)

    if self.exit_code == 0:
      self.perf_result = PerfResult(self.test, perf_lines)
    else:
      self.perf_result = None

  def write_diff(self):
    expected_lines = self.expected.splitlines(True)
    actual_lines = self.actual.splitlines(True)
    diff = difflib.unified_diff(
        expected_lines, actual_lines, fromfile='expected', tofile='actual')
    return "".join(diff)

  def rebase(self, rebase) -> str:
    if not rebase or self.passed:
      return ""
    if not self.test.blueprint.is_out_file():
      return "Can't rebase expected results passed as strings.\n"
    if self.exit_code != 0:
      return f"Rebase failed for {self.test.name} as the query failed\n"

    with open(self.test.expected_path, 'w') as f:
      f.write(self.actual)
    return f"Rebasing {self.test.name}\n"


# Results of running the test suite. Mostly used for printing aggregated
# results.
@dataclass
class TestResults:
  test_failures: List[str]
  perf_data: List[PerfResult]
  rebased: List[str]
  test_time_ms: int

  def str(self, no_colors: bool, tests_no: int):
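    # Renders a gtest-style summary, e.g. (illustrative):
    #   [==========] 42 tests ran. (1234 ms total)
    #   [  PASSED  ] 41 tests.
    #   [  FAILED  ] 1 tests.
    #   [  FAILED  ] parsing:example_test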
    c = ColorFormatter(no_colors)
    res = (
        f"[==========] {tests_no} tests ran. ({self.test_time_ms} ms total)\n"
        f"{c.green('[  PASSED  ]')} "
        f"{tests_no - len(self.test_failures)} tests.\n")
    if len(self.test_failures) > 0:
      res += (f"{c.red('[  FAILED  ]')} "
              f"{len(self.test_failures)} tests.\n")
      for failure in self.test_failures:
        res += f"{c.red('[  FAILED  ]')} {failure}\n"
    return res

  def rebase_str(self):
    res = f"\n[  REBASED  ] {len(self.rebased)} tests.\n"
    for name in self.rebased:
      res += f"[  REBASED  ] {name}\n"
    return res


# Responsible for executing a single diff test.
@dataclass
class TestCaseRunner:
  test: TestCase
  trace_processor_path: str
  trace_descriptor_path: str
  colors: ColorFormatter
  override_sql_module_paths: List[str]

  def __output_to_text_proto(self, actual: str, out: BinaryProto) -> str:
    """Deserializes a binary proto and returns its text representation.

    Args:
      actual: (string) HEX encoded serialized proto message
      out: (BinaryProto) Expected output, carrying the message type and an
        optional post-processing callable.

    Returns:
      Text proto
    """
    try:
      raw_data = unhexlify(actual.splitlines()[-1][1:-1])
      out_path = os.path.dirname(self.trace_processor_path)
      descriptor_paths = [
          f.path
          for f in os.scandir(
              os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'perfetto',
                           'trace_processor'))
          if f.is_file() and os.path.splitext(f.name)[1] == '.descriptor'
      ]
      descriptor_paths.append(
          os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'third_party',
                       'pprof', 'profile.descriptor'))
      proto = create_message_factory(descriptor_paths, out.message_type)()
      proto.ParseFromString(raw_data)
      try:
        return out.post_processing(proto)
      except Exception:
        return '<Proto post processing failed>'
    except Exception:
      return '<Invalid input for proto deserialization>'

  # Run a metrics-based diff test.
  def __run_metrics_test(self, trace_path: str,
                         metrics_message_factory) -> TestResult:
    if self.test.blueprint.is_out_file():
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = self.test.blueprint.out.contents

    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    is_json_output_file = self.test.blueprint.is_out_file(
    ) and os.path.basename(self.test.expected_path).endswith('.json.out')
    is_json_output = is_json_output_file or self.test.blueprint.is_out_json()
    cmd = [
        self.trace_processor_path,
        '--analyze-trace-proto-content',
        '--crop-track-events',
        '--run-metrics',
        self.test.blueprint.query.name,
        '--metrics-output=%s' % ('json' if is_json_output else 'binary'),
        '--perf-file',
        tmp_perf_file.name,
        trace_path,
    ]
    for sql_module_path in self.override_sql_module_paths:
      cmd += ['--override-sql-module', sql_module_path]
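    # The assembled command line looks like (paths and metric illustrative):
    #   trace_processor_shell --analyze-trace-proto-content \
    #     --crop-track-events --run-metrics android_cpu \
    #     --metrics-output=binary --perf-file /tmp/tmpXXXX trace.pb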
    tp = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=get_env(ROOT_DIR))
    (stdout, stderr) = tp.communicate()

    if is_json_output:
      expected_text = expected
      actual_text = stdout.decode('utf8')
    else:
      # Expected will be in text proto format and we'll need to parse it to
      # a real proto.
      expected_message = metrics_message_factory()
      text_format.Merge(expected, expected_message)

      # Actual will be the raw bytes of the proto and we'll need to parse it
      # into a message.
      actual_message = metrics_message_factory()
      actual_message.ParseFromString(stdout)

      # Convert both back to text format.
      expected_text = text_format.MessageToString(expected_message)
      actual_text = text_format.MessageToString(actual_message)

    perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    tmp_perf_file.close()
    os.remove(tmp_perf_file.name)
    return TestResult(self.test, trace_path, cmd, expected_text, actual_text,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  # Run a query-based diff test.
  def __run_query_test(self, trace_path: str, keep_query: bool) -> TestResult:
    # Fetch expected text.
    if self.test.expected_path:
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = self.test.blueprint.out.contents

    # Fetch query.
    if self.test.blueprint.is_query_file():
      query = self.test.query_path
    else:
      tmp_query_file = tempfile.NamedTemporaryFile(delete=False)
      with open(tmp_query_file.name, 'w') as query_file:
        query_file.write(self.test.blueprint.query)
      query = tmp_query_file.name

    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    cmd = [
        self.trace_processor_path,
        '--analyze-trace-proto-content',
        '--crop-track-events',
        '-q',
        query,
        '--perf-file',
        tmp_perf_file.name,
        trace_path,
    ]
    for sql_module_path in self.override_sql_module_paths:
      cmd += ['--override-sql-module', sql_module_path]
    tp = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=get_env(ROOT_DIR))
    (stdout, stderr) = tp.communicate()

    if not self.test.blueprint.is_query_file() and not keep_query:
      tmp_query_file.close()
      os.remove(tmp_query_file.name)
    perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    tmp_perf_file.close()
    os.remove(tmp_perf_file.name)

    actual = stdout.decode('utf8')
    if self.test.blueprint.is_out_binaryproto():
      actual = self.__output_to_text_proto(actual, self.test.blueprint.out)

    return TestResult(self.test, trace_path, cmd, expected, actual,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  def __run(self, metrics_descriptor_paths: List[str],
            extension_descriptor_paths: List[str], keep_input,
            rebase) -> Tuple[TestResult, str]:
    # We can't use delete=True here. When using that on Windows, the
    # resulting file is opened in exclusive mode (a subtle side-effect of the
    # underlying CreateFile(FILE_ATTRIBUTE_TEMPORARY)) and TP fails to open
    # the passed path.
    gen_trace_file = None
    if self.test.blueprint.is_trace_file():
      if self.test.trace_path.endswith('.py'):
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        serialize_python_trace(ROOT_DIR, self.trace_descriptor_path,
                               self.test.trace_path, gen_trace_file)

      elif self.test.trace_path.endswith('.textproto'):
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        serialize_textproto_trace(self.trace_descriptor_path,
                                  extension_descriptor_paths,
                                  self.test.trace_path, gen_trace_file)

    elif self.test.blueprint.is_trace_textproto():
      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
      proto = create_message_factory([self.trace_descriptor_path] +
                                     extension_descriptor_paths,
                                     'perfetto.protos.Trace')()
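      # blueprint.trace.contents holds a textproto Trace, e.g. (illustrative):
      #   packet { timestamp: 1000 process_tree { processes { pid: 1 } } }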
      text_format.Merge(self.test.blueprint.trace.contents, proto)
      gen_trace_file.write(proto.SerializeToString())
      gen_trace_file.flush()

    else:
      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
      with open(gen_trace_file.name, 'w') as trace_file:
        trace_file.write(self.test.blueprint.trace.contents)

    if self.test.blueprint.trace_modifier is not None:
      if gen_trace_file:
        # Overwrite |gen_trace_file|.
        modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
                     gen_trace_file.name, gen_trace_file.name,
                     self.test.blueprint.trace_modifier)
      else:
        # Create |gen_trace_file| to save the modified trace.
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
                     self.test.trace_path, gen_trace_file.name,
                     self.test.blueprint.trace_modifier)

    if gen_trace_file:
      trace_path = os.path.realpath(gen_trace_file.name)
    else:
      trace_path = self.test.trace_path

    res_str = f"{self.colors.yellow('[ RUN      ]')} {self.test.name}\n"

    if self.test.type == TestType.QUERY:
      result = self.__run_query_test(trace_path, keep_input)
    elif self.test.type == TestType.METRIC:
      result = self.__run_metrics_test(
          trace_path,
          create_message_factory(metrics_descriptor_paths,
                                 'perfetto.protos.TraceMetrics'))
    else:
      assert False

    if gen_trace_file:
      if keep_input:
        res_str += f"Saving generated input trace: {trace_path}\n"
      else:
        gen_trace_file.close()
        os.remove(trace_path)

    def write_cmdlines():
      res = ""
      if self.test.trace_path and (self.test.trace_path.endswith('.textproto')
                                   or self.test.trace_path.endswith('.py')):
        res += 'Command to generate trace:\n'
        res += 'tools/serialize_test_trace.py '
        res += '--descriptor {} {} > {}\n'.format(
            os.path.relpath(self.trace_descriptor_path, ROOT_DIR),
            os.path.relpath(self.test.trace_path, ROOT_DIR),
            os.path.relpath(trace_path, ROOT_DIR))
      res += f"Command line:\n{' '.join(result.cmd)}\n"
      return res

    if result.exit_code != 0 or not result.passed:
      result.passed = False
      res_str += result.stderr

      if result.exit_code == 0:
        res_str += f"Expected did not match actual for test {self.test.name}.\n"
        res_str += write_cmdlines()
        res_str += result.write_diff()
      else:
        res_str += write_cmdlines()

      res_str += f"{self.colors.red('[  FAILED  ]')} {self.test.name}\n"
      res_str += result.rebase(rebase)

      return result, res_str
    else:
      res_str += (f"{self.colors.green('[       OK ]')} {self.test.name} "
                  f"(ingest: {result.perf_result.ingest_time_ns / 1000000:.2f} ms "
                  f"query: {result.perf_result.real_time_ns / 1000000:.2f} ms)\n")
    return result, res_str

  # Run a TestCase.
  def execute(self, extension_descriptor_paths: List[str],
              metrics_descriptor_paths: List[str], keep_input: bool,
              rebase: bool) -> Tuple[str, str, TestResult]:
    if not metrics_descriptor_paths:
      out_path = os.path.dirname(self.trace_processor_path)
      metrics_protos_path = os.path.join(out_path, 'gen', 'protos', 'perfetto',
                                         'metrics')
      metrics_descriptor_paths = [
          os.path.join(metrics_protos_path, 'metrics.descriptor'),
          os.path.join(metrics_protos_path, 'chrome',
                       'all_chrome_metrics.descriptor'),
          os.path.join(metrics_protos_path, 'webview',
                       'all_webview_metrics.descriptor')
      ]
    result_str = ""

    result, run_str = self.__run(metrics_descriptor_paths,
                                 extension_descriptor_paths, keep_input, rebase)
    result_str += run_str
    if not result:
      return self.test.name, result_str, None

    return self.test.name, result_str, result


# Fetches and executes all viable diff tests.
@dataclass
class DiffTestsRunner:
  tests: List[TestCase]
  trace_processor_path: str
  trace_descriptor_path: str
  test_runners: List[TestCaseRunner]

  def __init__(self, name_filter: str, trace_processor_path: str,
               trace_descriptor: str, no_colors: bool,
               override_sql_module_paths: List[str], test_dir: str):
    self.tests = read_all_tests(name_filter, test_dir)
    self.trace_processor_path = trace_processor_path

    out_path = os.path.dirname(self.trace_processor_path)
    self.trace_descriptor_path = get_trace_descriptor_path(
        out_path, trace_descriptor)
    self.test_runners = []
    color_formatter = ColorFormatter(no_colors)
    for test in self.tests:
      self.test_runners.append(
          TestCaseRunner(test, self.trace_processor_path,
                         self.trace_descriptor_path, color_formatter,
                         override_sql_module_paths))

  def run_all_tests(self, metrics_descriptor_paths: List[str],
                    chrome_extensions: str, test_extensions: str,
                    winscope_extensions: str, keep_input: bool,
                    rebase: bool) -> TestResults:
    perf_results = []
    failures = []
    rebased = []
    test_run_start = datetime.datetime.now()

    with concurrent.futures.ProcessPoolExecutor() as e:
      fut = [
          e.submit(test.execute,
                   [chrome_extensions, test_extensions, winscope_extensions],
                   metrics_descriptor_paths, keep_input, rebase)
          for test in self.test_runners
      ]
      for res in concurrent.futures.as_completed(fut):
        test_name, res_str, result = res.result()
        sys.stderr.write(res_str)
        if not result or not result.passed:
          if rebase:
            rebased.append(test_name)
          failures.append(test_name)
        else:
          perf_results.append(result.perf_result)
    test_time_ms = int(
        (datetime.datetime.now() - test_run_start).total_seconds() * 1000)
    return TestResults(failures, perf_results, rebased, test_time_ms)
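
# Minimal usage sketch (illustrative; the real entry point is a separate
# driver script that wires these classes up from command-line flags):
#
#   runner = DiffTestsRunner(
#       name_filter='.*',
#       trace_processor_path='out/linux/trace_processor_shell',  # assumed path
#       trace_descriptor=None,
#       no_colors=False,
#       override_sql_module_paths=[],
#       test_dir='test/trace_processor')  # assumed path
#   results = runner.run_all_tests([], chrome_ext, test_ext, winscope_ext,
#                                  keep_input=False, rebase=False)
#   sys.stderr.write(results.str(no_colors=False, tests_no=len(runner.tests)))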