• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright (C) 2023 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import concurrent.futures
17import datetime
18import difflib
19import os
20import subprocess
21import sys
22import tempfile
23from binascii import unhexlify
24from dataclasses import dataclass
25from typing import List, Tuple, Optional
26
27from google.protobuf import text_format
28from python.generators.diff_tests.testing import Metric, MetricV2SpecTextproto, Path, TestCase, TestType, BinaryProto, TextProto
29from python.generators.diff_tests.utils import (
30    ColorFormatter, create_message_factory, get_env, get_trace_descriptor_path,
31    read_all_tests, serialize_python_trace, serialize_textproto_trace,
32    modify_trace)
33
34ROOT_DIR = os.path.dirname(
35    os.path.dirname(
36        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
37
38
39# Performance result of running the test.
40@dataclass
41class PerfResult:
42  test: TestCase
43  ingest_time_ns: int
44  real_time_ns: int
45
46  def __init__(self, test: TestCase, perf_lines: List[str]):
47    self.test = test
48
49    assert len(perf_lines) == 1
50    perf_numbers = perf_lines[0].split(',')
51
52    assert len(perf_numbers) == 2
53    self.ingest_time_ns = int(perf_numbers[0])
54    self.real_time_ns = int(perf_numbers[1])
55
56
57# Data gathered from running the test.
58@dataclass
59class TestResult:
60  test: TestCase
61  trace: str
62  cmd: List[str]
63  expected: str
64  actual: str
65  passed: bool
66  stderr: str
67  exit_code: int
68  perf_result: Optional[PerfResult]
69
70  def __init__(
71      self,
72      test: TestCase,
73      gen_trace_path: str,
74      cmd: List[str],
75      expected_text: str,
76      actual_text: str,
77      stderr: str,
78      exit_code: int,
79      perf_lines: List[str],
80  ) -> None:
81    self.test = test
82    self.trace = gen_trace_path
83    self.cmd = cmd
84    self.stderr = stderr
85    self.exit_code = exit_code
86
87    # For better string formatting we often add whitespaces, which has to now
88    # be removed.
89    def strip_whitespaces(text: str):
90      no_front_new_line_text = text.lstrip('\n')
91      return '\n'.join(s.strip() for s in no_front_new_line_text.split('\n'))
92
93    self.expected = strip_whitespaces(expected_text)
94    self.actual = strip_whitespaces(actual_text)
95
96    expected_content = self.expected.replace('\r\n', '\n')
97
98    actual_content = self.actual.replace('\r\n', '\n')
99    self.passed = (expected_content == actual_content)
100
101    if self.exit_code == 0:
102      self.perf_result = PerfResult(self.test, perf_lines)
103    else:
104      self.perf_result = None
105
106  def write_diff(self):
107    expected_lines = self.expected.splitlines(True)
108    actual_lines = self.actual.splitlines(True)
109    diff = difflib.unified_diff(
110        expected_lines, actual_lines, fromfile='expected', tofile='actual')
111    return "".join(list(diff))
112
113
114# Results of running the test suite. Mostly used for printing aggregated
115# results.
116@dataclass
117class TestResults:
118  test_failures: List[str]
119  perf_data: List[PerfResult]
120  test_time_ms: int
121
122  def str(self, no_colors: bool, tests_no: int):
123    c = ColorFormatter(no_colors)
124    res = (
125        f"[==========] {tests_no} tests ran. ({self.test_time_ms} ms total)\n"
126        f"{c.green('[  PASSED  ]')} "
127        f"{tests_no - len(self.test_failures)} tests.\n")
128    if len(self.test_failures) > 0:
129      res += (f"{c.red('[  FAILED  ]')} "
130              f"{len(self.test_failures)} tests.\n")
131      for failure in self.test_failures:
132        res += f"{c.red('[  FAILED  ]')} {failure}\n"
133    return res
134
135
136# Responsible for executing singular diff test.
137@dataclass
138class TestCaseRunner:
139  test: TestCase
140  trace_processor_path: str
141  trace_descriptor_path: str
142  colors: ColorFormatter
143  override_sql_module_paths: List[str]
144
145  def __output_to_text_proto(self, actual: str, out: BinaryProto) -> str:
146    """Deserializes a binary proto and returns its text representation.
147
148    Args:
149      actual: (string) HEX encoded serialized proto message
150      message_type: (string) Message type
151
152    Returns:
153      Text proto
154    """
155    try:
156      protos_dir = os.path.join(
157          ROOT_DIR,
158          os.path.dirname(self.trace_processor_path),
159          'gen',
160          'protos',
161      )
162      raw_data = unhexlify(actual.splitlines()[-1][1:-1])
163      descriptor_paths = [
164          f.path
165          for f in os.scandir(
166              os.path.join(protos_dir, 'perfetto', 'trace_processor'))
167          if f.is_file() and os.path.splitext(f.name)[1] == '.descriptor'
168      ]
169      descriptor_paths.append(
170          os.path.join(protos_dir, 'third_party', 'pprof',
171                       'profile.descriptor'))
172      proto = create_message_factory(descriptor_paths, out.message_type)()
173      proto.ParseFromString(raw_data)
174      try:
175        return out.post_processing(proto)
176      except:
177        return '<Proto post processing failed>'
178    except:
179      return '<Invalid input for proto deserializaiton>'
180
181  def __run_metrics_test(self, trace_path: str,
182                         metrics_message_factory) -> TestResult:
183    with tempfile.NamedTemporaryFile(delete=False) as tmp_perf_file:
184      assert isinstance(self.test.blueprint.query, Metric)
185
186      is_json_output_file = self.test.blueprint.is_out_file(
187      ) and os.path.basename(self.test.expected_path).endswith('.json.out')
188      is_json_output = is_json_output_file or self.test.blueprint.is_out_json()
189      cmd = [
190          self.trace_processor_path,
191          '--analyze-trace-proto-content',
192          '--crop-track-events',
193          '--extra-checks',
194          '--run-metrics',
195          self.test.blueprint.query.name,
196          '--metrics-output=%s' % ('json' if is_json_output else 'binary'),
197          '--perf-file',
198          tmp_perf_file.name,
199          trace_path,
200      ]
201      if self.test.register_files_dir:
202        cmd += ['--register-files-dir', self.test.register_files_dir]
203      for sql_module_path in self.override_sql_module_paths:
204        cmd += ['--override-sql-module', sql_module_path]
205      tp = subprocess.Popen(
206          cmd,
207          stdout=subprocess.PIPE,
208          stderr=subprocess.PIPE,
209          env=get_env(ROOT_DIR))
210      (stdout, stderr) = tp.communicate()
211
212      if is_json_output:
213        expected_text = self.test.expected_str
214        actual_text = stdout.decode('utf8')
215      else:
216        # Expected will be in text proto format and we'll need to parse it to
217        # a real proto.
218        expected_message = metrics_message_factory()
219        text_format.Merge(self.test.expected_str, expected_message)
220
221        # Actual will be the raw bytes of the proto and we'll need to parse it
222        # into a message.
223        actual_message = metrics_message_factory()
224        actual_message.ParseFromString(stdout)
225
226        # Convert both back to text format.
227        expected_text = text_format.MessageToString(expected_message)
228        actual_text = text_format.MessageToString(actual_message)
229
230      os.remove(tmp_perf_file.name)
231
232      return TestResult(
233          self.test,
234          trace_path,
235          cmd,
236          expected_text,
237          actual_text,
238          stderr.decode('utf8'),
239          tp.returncode,
240          [line.decode('utf8') for line in tmp_perf_file.readlines()],
241      )
242
243  def __run_metrics_v2_test(
244      self,
245      trace_path: str,
246      keep_input: bool,
247      summary_spec_message_factory,
248      summary_message_factory,
249  ) -> TestResult:
250    with tempfile.NamedTemporaryFile(delete=False) as tmp_perf_file, \
251         tempfile.NamedTemporaryFile(delete=False) as tmp_spec_file:
252      assert isinstance(self.test.blueprint.query, MetricV2SpecTextproto)
253
254      spec_message = summary_spec_message_factory()
255      text_format.Merge(self.test.blueprint.query.contents,
256                        spec_message.metric_spec.add())
257
258      tmp_spec_file.write(spec_message.SerializeToString())
259      tmp_spec_file.flush()
260
261      cmd = [
262          self.trace_processor_path,
263          '--analyze-trace-proto-content',
264          '--crop-track-events',
265          '--extra-checks',
266          '--perf-file',
267          tmp_perf_file.name,
268          '--summary',
269          '--summary-spec',
270          tmp_spec_file.name,
271          '--summary-metrics-v2',
272          spec_message.metric_spec[0].id,
273          '--summary-format',
274          'binary',
275          trace_path,
276      ]
277      for sql_module_path in self.override_sql_module_paths:
278        cmd += ['--override-sql-module', sql_module_path]
279      tp = subprocess.Popen(
280          cmd,
281          stdout=subprocess.PIPE,
282          stderr=subprocess.PIPE,
283          env=get_env(ROOT_DIR),
284      )
285      (stdout, stderr) = tp.communicate()
286
287      # Expected will be in text proto format and we'll need to parse it to
288      # a real proto.
289      expected_summary = summary_message_factory()
290      text_format.Merge(self.test.expected_str, expected_summary.metric.add())
291
292      # Actual will be the raw bytes of the proto and we'll need to parse it
293      # into a message.
294      actual_summary = summary_message_factory()
295      actual_summary.ParseFromString(stdout)
296
297      os.remove(tmp_perf_file.name)
298      if not keep_input:
299        os.remove(tmp_spec_file.name)
300
301
302      return TestResult(
303          self.test,
304          trace_path,
305          cmd,
306          text_format.MessageToString(expected_summary.metric[0]),
307          text_format.MessageToString(actual_summary.metric[0]),
308          stderr.decode('utf8'),
309          tp.returncode,
310          [line.decode('utf8') for line in tmp_perf_file.readlines()],
311      )
312
313  # Run a query based Diff Test.
314  def __run_query_test(self, trace_path: str) -> TestResult:
315    with tempfile.NamedTemporaryFile(delete=False) as tmp_perf_file:
316      cmd = [
317          self.trace_processor_path,
318          '--analyze-trace-proto-content',
319          '--crop-track-events',
320          '--extra-checks',
321          '--perf-file',
322          tmp_perf_file.name,
323          trace_path,
324      ]
325      if self.test.blueprint.is_query_file():
326        cmd += ['-q', self.test.query_path]
327      else:
328        assert isinstance(self.test.blueprint.query, str)
329        cmd += ['-Q', self.test.blueprint.query]
330      if self.test.register_files_dir:
331        cmd += ['--register-files-dir', self.test.register_files_dir]
332      for sql_module_path in self.override_sql_module_paths:
333        cmd += ['--override-sql-module', sql_module_path]
334      tp = subprocess.Popen(
335          cmd,
336          stdout=subprocess.PIPE,
337          stderr=subprocess.PIPE,
338          env=get_env(ROOT_DIR))
339      (stdout, stderr) = tp.communicate()
340
341      actual = stdout.decode('utf8')
342      if self.test.blueprint.is_out_binaryproto():
343        assert isinstance(self.test.blueprint.out, BinaryProto)
344        actual = self.__output_to_text_proto(actual, self.test.blueprint.out)
345
346      os.remove(tmp_perf_file.name)
347
348      return TestResult(
349          self.test,
350          trace_path,
351          cmd,
352          self.test.expected_str,
353          actual,
354          stderr.decode('utf8'),
355          tp.returncode,
356          [line.decode('utf8') for line in tmp_perf_file.readlines()],
357      )
358
359  def __run(
360      self,
361      summary_descriptor_path: str,
362      metrics_descriptor_paths: List[str],
363      extension_descriptor_paths: List[str],
364      keep_input,
365  ) -> Tuple[TestResult, str]:
366    # We can't use delete=True here. When using that on Windows, the
367    # resulting file is opened in exclusive mode (in turn that's a subtle
368    # side-effect of the underlying CreateFile(FILE_ATTRIBUTE_TEMPORARY))
369    # and TP fails to open the passed path.
370    gen_trace_file = None
371    if self.test.blueprint.is_trace_file():
372      assert self.test.trace_path
373      if self.test.trace_path.endswith('.py'):
374        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
375        serialize_python_trace(ROOT_DIR, self.trace_descriptor_path,
376                               self.test.trace_path, gen_trace_file)
377
378      elif self.test.trace_path.endswith('.textproto'):
379        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
380        serialize_textproto_trace(self.trace_descriptor_path,
381                                  extension_descriptor_paths,
382                                  self.test.trace_path, gen_trace_file)
383
384    elif self.test.blueprint.is_trace_textproto():
385      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
386      proto = create_message_factory([self.trace_descriptor_path] +
387                                     extension_descriptor_paths,
388                                     'perfetto.protos.Trace')()
389      assert isinstance(self.test.blueprint.trace, TextProto)
390      text_format.Merge(self.test.blueprint.trace.contents, proto)
391      gen_trace_file.write(proto.SerializeToString())
392      gen_trace_file.flush()
393
394    else:
395      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
396      with open(gen_trace_file.name, 'w') as trace_file:
397        trace_file.write(self.test.blueprint.trace.contents)
398
399    if self.test.blueprint.trace_modifier is not None:
400      if gen_trace_file:
401        # Overwrite |gen_trace_file|.
402        modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
403                     gen_trace_file.name, gen_trace_file.name,
404                     self.test.blueprint.trace_modifier)
405      else:
406        # Create |gen_trace_file| to save the modified trace.
407        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
408        modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
409                     self.test.trace_path, gen_trace_file.name,
410                     self.test.blueprint.trace_modifier)
411
412    if gen_trace_file:
413      trace_path = os.path.realpath(gen_trace_file.name)
414    else:
415      trace_path = self.test.trace_path
416    assert trace_path
417
418    str = f"{self.colors.yellow('[ RUN      ]')} {self.test.name}\n"
419
420    if self.test.type == TestType.QUERY:
421      result = self.__run_query_test(trace_path)
422    elif self.test.type == TestType.METRIC:
423      result = self.__run_metrics_test(
424          trace_path,
425          create_message_factory(metrics_descriptor_paths,
426                                 'perfetto.protos.TraceMetrics'),
427      )
428    elif self.test.type == TestType.METRIC_V2:
429      result = self.__run_metrics_v2_test(
430          trace_path,
431          keep_input,
432          create_message_factory([summary_descriptor_path],
433                                 'perfetto.protos.TraceSummarySpec'),
434          create_message_factory([summary_descriptor_path],
435                                 'perfetto.protos.TraceSummary'),
436      )
437    else:
438      assert False
439
440    if gen_trace_file:
441      if not keep_input:
442        gen_trace_file.close()
443        os.remove(trace_path)
444
445    def write_cmdlines():
446      res = ""
447      if self.test.trace_path and (self.test.trace_path.endswith('.textproto')
448                                   or self.test.trace_path.endswith('.py')):
449        res += 'Command to generate trace:\n'
450        res += 'tools/serialize_test_trace.py '
451        res += '--descriptor {} {} > {}\n'.format(
452            os.path.relpath(self.trace_descriptor_path, ROOT_DIR),
453            os.path.relpath(self.test.trace_path, ROOT_DIR),
454            os.path.relpath(trace_path, ROOT_DIR))
455      res += f"Command line:\n{' '.join(result.cmd)}\n"
456      return res
457
458    if result.exit_code != 0 or not result.passed:
459      result.passed = False
460      str += result.stderr
461
462      if result.exit_code == 0:
463        str += f"Expected did not match actual for test {self.test.name}.\n"
464        str += write_cmdlines()
465        str += result.write_diff()
466      else:
467        str += write_cmdlines()
468
469      str += (f"{self.colors.red('[  FAILED  ]')} {self.test.name}\n")
470
471      return result, str
472    else:
473      assert result.perf_result
474      str += (f"{self.colors.green('[       OK ]')} {self.test.name} "
475              f"(ingest: {result.perf_result.ingest_time_ns / 1000000:.2f} ms "
476              f"query: {result.perf_result.real_time_ns / 1000000:.2f} ms)\n")
477    return result, str
478
479  # Run a TestCase.
480  def execute(
481      self,
482      summary_descriptor_path: str,
483      metrics_descriptor_paths: List[str],
484      extension_descriptor_paths: List[str],
485      keep_input: bool,
486  ) -> Tuple[str, str, TestResult]:
487    if not metrics_descriptor_paths:
488      out_path = os.path.dirname(self.trace_processor_path)
489      metrics_protos_path = os.path.join(
490          out_path,
491          'gen',
492          'protos',
493          'perfetto',
494          'metrics',
495      )
496      metrics_descriptor_paths = [
497          os.path.join(metrics_protos_path, 'metrics.descriptor'),
498          os.path.join(metrics_protos_path, 'chrome',
499                       'all_chrome_metrics.descriptor'),
500          os.path.join(metrics_protos_path, 'webview',
501                       'all_webview_metrics.descriptor')
502      ]
503    result, run_str = self.__run(
504        summary_descriptor_path,
505        metrics_descriptor_paths,
506        extension_descriptor_paths,
507        keep_input,
508    )
509    return self.test.name, run_str, result
510
511
512# Fetches and executes all diff viable tests.
513@dataclass
514class DiffTestsRunner:
515  tests: List[TestCase]
516  trace_processor_path: str
517  trace_descriptor_path: str
518  test_runners: List[TestCaseRunner]
519  quiet: bool
520
521  def __init__(
522      self,
523      name_filter: str,
524      trace_processor_path: str,
525      trace_descriptor: str,
526      no_colors: bool,
527      override_sql_module_paths: List[str],
528      test_dir: str,
529      quiet: bool,
530  ):
531    self.tests = read_all_tests(name_filter, test_dir)
532    self.trace_processor_path = trace_processor_path
533    self.quiet = quiet
534
535    out_path = os.path.dirname(self.trace_processor_path)
536    self.trace_descriptor_path = get_trace_descriptor_path(
537        out_path,
538        trace_descriptor,
539    )
540    self.test_runners = []
541    color_formatter = ColorFormatter(no_colors)
542    for test in self.tests:
543      self.test_runners.append(
544          TestCaseRunner(
545              test,
546              self.trace_processor_path,
547              self.trace_descriptor_path,
548              color_formatter,
549              override_sql_module_paths,
550          ))
551
552  def run_all_tests(
553      self,
554      summary_descriptor: str,
555      metrics_descriptor_paths: List[str],
556      chrome_extensions: str,
557      test_extensions: str,
558      winscope_extensions: str,
559      keep_input: bool,
560  ) -> TestResults:
561    perf_results = []
562    failures = []
563    test_run_start = datetime.datetime.now()
564    completed_tests = 0
565
566    with concurrent.futures.ProcessPoolExecutor() as e:
567      fut = [
568          e.submit(
569              test.execute,
570              summary_descriptor,
571              metrics_descriptor_paths,
572              [chrome_extensions, test_extensions, winscope_extensions],
573              keep_input,
574          ) for test in self.test_runners
575      ]
576      for res in concurrent.futures.as_completed(fut):
577        test_name, res_str, result = res.result()
578
579        if self.quiet:
580          completed_tests += 1
581          sys.stderr.write(f"\rRan {completed_tests} tests")
582          if not result.passed:
583            sys.stderr.write(f"\r")
584            sys.stderr.write(res_str)
585        else:
586          sys.stderr.write(res_str)
587
588        if not result or not result.passed:
589          failures.append(test_name)
590        else:
591          perf_results.append(result.perf_result)
592    test_time_ms = int(
593        (datetime.datetime.now() - test_run_start).total_seconds() * 1000)
594    if self.quiet:
595      sys.stderr.write(f"\r")
596    return TestResults(failures, perf_results, test_time_ms)
597