#!/usr/bin/env python3
# Copyright (C) 2023 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
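
"""Runner for Trace Processor diff tests.

Discovers diff tests, runs each one against a trace_processor binary (as
either a query test or a metrics test) and aggregates pass/fail and
performance results.
"""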

import concurrent.futures
import datetime
import difflib
import os
import subprocess
import sys
import tempfile
from binascii import unhexlify
from dataclasses import dataclass
from typing import List, Tuple, Optional

from google.protobuf import text_format, message_factory, descriptor_pool
from python.generators.diff_tests.testing import TestCase, TestType, BinaryProto
from python.generators.diff_tests.utils import (
    ColorFormatter, create_message_factory, get_env, get_trace_descriptor_path,
    read_all_tests, serialize_python_trace, serialize_textproto_trace)

ROOT_DIR = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


# Performance result of running the test.
@dataclass
class PerfResult:
  test: TestCase
  ingest_time_ns: int
  real_time_ns: int

  def __init__(self, test: TestCase, perf_lines: List[str]):
    self.test = test

    # The perf file is expected to contain exactly one line of the form
    # "<ingest_time_ns>,<real_time_ns>".
    assert len(perf_lines) == 1
    perf_numbers = perf_lines[0].split(',')

    assert len(perf_numbers) == 2
    self.ingest_time_ns = int(perf_numbers[0])
    self.real_time_ns = int(perf_numbers[1])


# Data gathered from running the test.
@dataclass
class TestResult:
  test: TestCase
  trace: str
  cmd: List[str]
  expected: str
  actual: str
  passed: bool
  stderr: str
  exit_code: int
  perf_result: Optional[PerfResult]

  def __init__(self, test: TestCase, gen_trace_path: str, cmd: List[str],
               expected_text: str, actual_text: str, stderr: str,
               exit_code: int, perf_lines: List[str]) -> None:
    self.test = test
    self.trace = gen_trace_path
    self.cmd = cmd
    self.stderr = stderr
    self.exit_code = exit_code

    # For better string formatting we often add whitespace, which now has to
    # be removed.
    def strip_whitespaces(text: str):
      no_front_new_line_text = text.lstrip('\n')
      return '\n'.join(s.strip() for s in no_front_new_line_text.split('\n'))

    self.expected = strip_whitespaces(expected_text)
    self.actual = strip_whitespaces(actual_text)

    # Compare ignoring line-ending differences.
    expected_content = self.expected.replace('\r\n', '\n')
    actual_content = self.actual.replace('\r\n', '\n')
    self.passed = (expected_content == actual_content)

    # Performance numbers are only meaningful if the run succeeded.
    if self.exit_code == 0:
      self.perf_result = PerfResult(self.test, perf_lines)
    else:
      self.perf_result = None

  def write_diff(self):
    expected_lines = self.expected.splitlines(True)
    actual_lines = self.actual.splitlines(True)
    diff = difflib.unified_diff(
        expected_lines, actual_lines, fromfile='expected', tofile='actual')
    return "".join(diff)

  def rebase(self, rebase: bool) -> str:
    if not rebase or self.passed:
      return ""
    if not self.test.blueprint.is_out_file():
      return "Can't rebase expected results passed as strings.\n"
    if self.exit_code != 0:
      return f"Rebase failed for {self.test.name} as the query failed\n"

    with open(self.test.expected_path, 'w') as f:
      f.write(self.actual)
    return f"Rebasing {self.test.name}\n"


# Results of running the test suite. Mostly used for printing aggregated
# results.
@dataclass
class TestResults:
  test_failures: List[str]
  perf_data: List[PerfResult]
  rebased: List[str]
  test_time_ms: int

  def str(self, no_colors: bool, tests_no: int):
    c = ColorFormatter(no_colors)
    res = (
        f"[==========] {tests_no} tests ran. ({self.test_time_ms} ms total)\n"
        f"{c.green('[  PASSED  ]')} "
        f"{tests_no - len(self.test_failures)} tests.\n")
    if len(self.test_failures) > 0:
      res += f"{c.red('[  FAILED  ]')} {len(self.test_failures)} tests.\n"
      for failure in self.test_failures:
        res += f"{c.red('[  FAILED  ]')} {failure}\n"
    return res

  def rebase_str(self):
    res = f"\n[  REBASED  ] {len(self.rebased)} tests.\n"
    for name in self.rebased:
      res += f"[  REBASED  ] {name}\n"
    return res
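
# Example rendering of TestResults.str() (illustrative values):
#   [==========] 42 tests ran. (1337 ms total)
#   [  PASSED  ] 41 tests.
#   [  FAILED  ] 1 tests.
#   [  FAILED  ] MyFailingTest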


# Responsible for executing a single diff test.
@dataclass
class TestCaseRunner:
  test: TestCase
  trace_processor_path: str
  trace_descriptor_path: str
  colors: ColorFormatter

  def __output_to_text_proto(self, actual: str, out: BinaryProto) -> str:
    """Deserializes a binary proto and returns its text representation.

    Args:
      actual: (string) HEX encoded serialized proto message
      out: (BinaryProto) expected output holding the proto message type and
        an optional post-processing function

    Returns:
      Text proto
    """
    try:
      # The last line of the query output holds the hex-encoded proto,
      # wrapped in quotes.
      raw_data = unhexlify(actual.splitlines()[-1][1:-1])
      out_path = os.path.dirname(self.trace_processor_path)
      descriptor_paths = [
          f.path
          for f in os.scandir(
              os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'perfetto',
                           'trace_processor'))
          if f.is_file() and os.path.splitext(f.name)[1] == '.descriptor'
      ]
      descriptor_paths.append(
          os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'third_party',
                       'pprof', 'profile.descriptor'))
      proto = create_message_factory(descriptor_paths, out.message_type)()
      proto.ParseFromString(raw_data)
      try:
        return out.post_processing(proto)
      except Exception:
        return '<Proto post processing failed>'
    except Exception:
      return '<Invalid input for proto deserialization>'

  def __run_metrics_test(self, trace_path: str,
                         metrics_message_factory) -> TestResult:
    if self.test.blueprint.is_out_file():
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = self.test.blueprint.out.contents

    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    is_json_output_file = self.test.blueprint.is_out_file(
    ) and os.path.basename(self.test.expected_path).endswith('.json.out')
    is_json_output = is_json_output_file or self.test.blueprint.is_out_json()
    cmd = [
        self.trace_processor_path,
        '--analyze-trace-proto-content',
        '--crop-track-events',
        '--run-metrics',
        self.test.blueprint.query.name,
        '--metrics-output=%s' % ('json' if is_json_output else 'binary'),
        '--perf-file',
        tmp_perf_file.name,
        trace_path,
    ]
    tp = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=get_env(ROOT_DIR))
    (stdout, stderr) = tp.communicate()

    if is_json_output:
      expected_text = expected
      actual_text = stdout.decode('utf8')
    else:
      # Expected will be in text proto format and we'll need to parse it into
      # a real proto.
      expected_message = metrics_message_factory()
      text_format.Merge(expected, expected_message)

      # Actual will be the raw bytes of the proto and we'll need to parse it
      # into a message.
      actual_message = metrics_message_factory()
      actual_message.ParseFromString(stdout)

      # Convert both back to text format.
      expected_text = text_format.MessageToString(expected_message)
      actual_text = text_format.MessageToString(actual_message)

    perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    tmp_perf_file.close()
    os.remove(tmp_perf_file.name)
    return TestResult(self.test, trace_path, cmd, expected_text, actual_text,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  # Run a query-based diff test.
  def __run_query_test(self, trace_path: str) -> TestResult:
    # Fetch expected text.
    if self.test.expected_path:
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = self.test.blueprint.out.contents

    # Fetch query.
    if self.test.blueprint.is_query_file():
      query = self.test.query_path
    else:
      tmp_query_file = tempfile.NamedTemporaryFile(delete=False)
      with open(tmp_query_file.name, 'w') as query_file:
        query_file.write(self.test.blueprint.query)
      query = tmp_query_file.name

    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    cmd = [
        self.trace_processor_path,
        '--analyze-trace-proto-content',
        '--crop-track-events',
        '-q',
        query,
        '--perf-file',
        tmp_perf_file.name,
        trace_path,
    ]
    tp = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=get_env(ROOT_DIR))
    (stdout, stderr) = tp.communicate()

    if not self.test.blueprint.is_query_file():
      tmp_query_file.close()
      os.remove(tmp_query_file.name)
    perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    tmp_perf_file.close()
    os.remove(tmp_perf_file.name)

    actual = stdout.decode('utf8')
    if self.test.blueprint.is_out_binaryproto():
      actual = self.__output_to_text_proto(actual, self.test.blueprint.out)

    return TestResult(self.test, trace_path, cmd, expected, actual,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  def __run(self, metrics_descriptor_paths: List[str],
            extension_descriptor_paths: List[str], keep_input,
            rebase) -> Tuple[TestResult, str]:
    # We can't use delete=True here. When using that on Windows, the
    # resulting file is opened in exclusive mode (in turn that's a subtle
    # side-effect of the underlying CreateFile(FILE_ATTRIBUTE_TEMPORARY))
    # and TP fails to open the passed path.
    gen_trace_file = None
    if self.test.blueprint.is_trace_file():
      if self.test.trace_path.endswith('.py'):
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        serialize_python_trace(ROOT_DIR, self.trace_descriptor_path,
                               self.test.trace_path, gen_trace_file)

      elif self.test.trace_path.endswith('.textproto'):
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        serialize_textproto_trace(self.trace_descriptor_path,
                                  extension_descriptor_paths,
                                  self.test.trace_path, gen_trace_file)

    elif self.test.blueprint.is_trace_textproto():
      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
      proto = create_message_factory([self.trace_descriptor_path] +
                                     extension_descriptor_paths,
                                     'perfetto.protos.Trace')()
      text_format.Merge(self.test.blueprint.trace.contents, proto)
      gen_trace_file.write(proto.SerializeToString())
      gen_trace_file.flush()

    else:
      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
      with open(gen_trace_file.name, 'w') as trace_file:
        trace_file.write(self.test.blueprint.trace.contents)

    if gen_trace_file:
      trace_path = os.path.realpath(gen_trace_file.name)
    else:
      trace_path = self.test.trace_path

    res_str = f"{self.colors.yellow('[ RUN      ]')} {self.test.name}\n"

    if self.test.type == TestType.QUERY:
      result = self.__run_query_test(trace_path)
    elif self.test.type == TestType.METRIC:
      result = self.__run_metrics_test(
          trace_path,
          create_message_factory(metrics_descriptor_paths,
                                 'perfetto.protos.TraceMetrics'))
    else:
      assert False

    if gen_trace_file:
      if keep_input:
        res_str += f"Saving generated input trace: {trace_path}\n"
      else:
        gen_trace_file.close()
        os.remove(trace_path)

    def write_cmdlines():
      res = ""
      # The serialization hint only makes sense when the trace was generated
      # from a .py or .textproto source file.
      if gen_trace_file and self.test.blueprint.is_trace_file():
        res += 'Command to generate trace:\n'
        res += 'tools/serialize_test_trace.py '
        res += '--descriptor {} {} > {}\n'.format(
            os.path.relpath(self.trace_descriptor_path, ROOT_DIR),
            os.path.relpath(self.test.trace_path, ROOT_DIR),
            os.path.relpath(trace_path, ROOT_DIR))
      res += f"Command line:\n{' '.join(result.cmd)}\n"
      return res

    if result.exit_code != 0 or not result.passed:
      res_str += result.stderr

      if result.exit_code == 0:
        res_str += f"Expected did not match actual for test {self.test.name}.\n"
        res_str += write_cmdlines()
        res_str += result.write_diff()
      else:
        res_str += write_cmdlines()

      res_str += f"{self.colors.red('[  FAILED  ]')} {self.test.name}\n"
      res_str += result.rebase(rebase)

      return result, res_str
    else:
      res_str += (
          f"{self.colors.green('[       OK ]')} {self.test.name} "
          f"(ingest: {result.perf_result.ingest_time_ns / 1000000:.2f} ms "
          f"query: {result.perf_result.real_time_ns / 1000000:.2f} ms)\n")
    return result, res_str

  # Run a TestCase.
  def execute(self, extension_descriptor_paths: List[str],
              metrics_descriptor: str, keep_input: bool,
              rebase: bool) -> Tuple[str, str, TestResult]:
    if metrics_descriptor:
      metrics_descriptor_paths = [metrics_descriptor]
    else:
      out_path = os.path.dirname(self.trace_processor_path)
      metrics_protos_path = os.path.join(out_path, 'gen', 'protos', 'perfetto',
                                         'metrics')
      metrics_descriptor_paths = [
          os.path.join(metrics_protos_path, 'metrics.descriptor'),
          os.path.join(metrics_protos_path, 'chrome',
                       'all_chrome_metrics.descriptor'),
          os.path.join(metrics_protos_path, 'webview',
                       'all_webview_metrics.descriptor')
      ]
    result_str = ""

    result, run_str = self.__run(metrics_descriptor_paths,
                                 extension_descriptor_paths, keep_input, rebase)
    result_str += run_str
    if not result:
      return self.test.name, result_str, None

    return self.test.name, result_str, result


# Fetches and executes all viable diff tests.
@dataclass
class DiffTestsRunner:
  tests: List[TestCase]
  trace_processor_path: str
  trace_descriptor_path: str
  test_runners: List[TestCaseRunner]

  def __init__(self, name_filter: str, trace_processor_path: str,
               trace_descriptor: str, no_colors: bool):
    self.tests = read_all_tests(name_filter, ROOT_DIR)
    self.trace_processor_path = trace_processor_path

    out_path = os.path.dirname(self.trace_processor_path)
    self.trace_descriptor_path = get_trace_descriptor_path(
        out_path, trace_descriptor)
    self.test_runners = []
    color_formatter = ColorFormatter(no_colors)
    for test in self.tests:
      self.test_runners.append(
          TestCaseRunner(test, self.trace_processor_path,
                         self.trace_descriptor_path, color_formatter))

  def run_all_tests(self, metrics_descriptor: str, keep_input: bool,
                    rebase: bool) -> TestResults:
    perf_results = []
    failures = []
    rebased = []
    test_run_start = datetime.datetime.now()

    out_path = os.path.dirname(self.trace_processor_path)
    chrome_extensions = os.path.join(out_path, 'gen', 'protos', 'third_party',
                                     'chromium',
                                     'chrome_track_event.descriptor')
    test_extensions = os.path.join(out_path, 'gen', 'protos', 'perfetto',
                                   'trace', 'test_extensions.descriptor')

    with concurrent.futures.ProcessPoolExecutor() as e:
      fut = [
          e.submit(runner.execute, [chrome_extensions, test_extensions],
                   metrics_descriptor, keep_input, rebase)
          for runner in self.test_runners
      ]
      for res in concurrent.futures.as_completed(fut):
        test_name, res_str, result = res.result()
        sys.stderr.write(res_str)
        if not result or not result.passed:
          if rebase:
            rebased.append(test_name)
          failures.append(test_name)
        else:
          perf_results.append(result.perf_result)
    test_time_ms = int(
        (datetime.datetime.now() - test_run_start).total_seconds() * 1000)
    return TestResults(failures, perf_results, rebased, test_time_ms)
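

# Minimal usage sketch (illustrative only): assumes a locally built
# trace_processor_shell at the path below and that this module is run from a
# Perfetto checkout with PYTHONPATH configured for the python/ package.
if __name__ == '__main__':
  runner = DiffTestsRunner(
      name_filter='.*',
      trace_processor_path=os.path.join(ROOT_DIR, 'out', 'linux_clang_release',
                                        'trace_processor_shell'),
      trace_descriptor=None,
      no_colors=False)
  results = runner.run_all_tests(
      metrics_descriptor=None, keep_input=False, rebase=False)
  sys.stderr.write(results.str(False, len(runner.tests)))
  sys.exit(1 if results.test_failures else 0)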