• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# Copyright 2023, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Module to facilitate integration test within the build and test environment.
18
19This module provides utilities for running tests in both build and test
20environments, managing environment variables, and snapshotting the workspace for
21restoration later.
22"""
23
24import argparse
25import atexit
26import concurrent.futures
27import copy
28import datetime
29import functools
30import itertools
31import logging
32import multiprocessing
33import os
34import pathlib
35import shutil
36import subprocess
37import sys
38import tarfile
39import tempfile
40import time
41import traceback
42from typing import Any, Callable, Iterator
43import unittest
44import zipfile
45
46from snapshot import Snapshot
47
# Env key for the storage tar path. When set, main() uses this path for the
# snapshot tar (creating its parent directory) instead of the default
# location under the system temp directory.
SNAPSHOT_STORAGE_TAR_KEY = 'SNAPSHOT_STORAGE_TAR_PATH'

# Env key for the repo root. main() requires it in build mode; in test mode
# it is read back from the environment restored from a snapshot.
ANDROID_BUILD_TOP_KEY = 'ANDROID_BUILD_TOP'
53
54
class IntegrationTestConfiguration:
  """Internal class to store integration test configuration.

  Every field is a class-level default; main() assigns real values on an
  instance before the tests run.
  """

  # Which phases this process runs (both are set for a local run).
  is_build_env: bool = False
  is_test_env: bool = False

  # Serial of the target device, or None when not provided.
  device_serial: str = None

  # Directory holding the snapshots, the optional single-file tar of that
  # directory, and whether the tar is produced/consumed at all.
  snapshot_storage_path: pathlib.Path = None
  snapshot_storage_tar_path: pathlib.Path = None
  is_tar_snapshot: bool = False

  # Directory into which snapshots are restored for test steps.
  workspace_path: pathlib.Path = None
65
66
class StepInput:
  """Bundle of inputs handed to a build or test step function.

  Holds the environment mapping, the repo root, the integration test
  configuration and the objects restored from the previous snapshot.
  """

  def __init__(self, env, repo_root, config, objs):
    self._env = env
    self._repo_root = repo_root
    self._config = config
    self._objs = objs

  def get_device_serial_args_or_empty(self) -> str:
    """Returns ' -s <serial>' for use in a command line, or ''.

    Raises:
        RuntimeError: No serial is configured while running in the test env
          without ANDROID_BUILD_TOP set.
    """
    # TODO: b/336839543 - Remove this method when we deprecate the support to
    # run the integration test directly through 'python **.py' command.
    serial = self._config.device_serial
    if serial:
      return ' -s ' + serial
    build_top_missing = ANDROID_BUILD_TOP_KEY not in os.environ
    if build_top_missing and self._config.is_test_env:
      # Likely a test lab environment, where connected devices can be
      # allocated to other tests, so any atest call must name its device.
      raise RuntimeError('Device serial is required but not set')
    # An empty value lets tradefed pick the device in a local run.
    return ''

  def get_device_serial(self) -> str:
    """Returns the configured device serial; raises RuntimeError if unset."""
    serial = self._config.device_serial
    if serial:
      return serial
    raise RuntimeError('Device serial is not set')

  def get_env(self):
    """Returns the environment mapping for this step."""
    return self._env

  def get_repo_root(self) -> str:
    """Returns the repo root directory."""
    return self._repo_root

  def get_obj(self, name: str) -> Any:
    """Returns the snapshot object stored under `name`, or None."""
    return self._objs.get(name)

  def get_config(self) -> IntegrationTestConfiguration:
    """Returns the integration test configuration."""
    return self._config
112
113
class StepOutput:
  """Output information generated from a build step.

  Collects what goes into the snapshot handed to the following test step:
  paths to include/exclude, environment variable keys, and named objects.
  """

  def __init__(self):
    self._snapshot_include_paths: list[str] = []
    self._snapshot_exclude_paths: list[str] = []
    self._snapshot_env_keys: list[str] = []
    self._snapshot_objs: dict[str, Any] = {}

  def add_snapshot_include_paths(self, paths: list[str]) -> None:
    """Add paths to include in snapshot artifacts."""
    self._snapshot_include_paths.extend(paths)

  def set_snapshot_include_paths(self, paths: list[str]) -> None:
    """Set the snapshot include paths.

    Note that any include paths added before (including defaults) will be
    removed. Use add_snapshot_include_paths if that's not intended.

    Args:
        paths: The new list of paths to include for snapshot.
    """
    # Copy first: `paths` may be the very list returned by
    # get_snapshot_include_paths(), which clear()+extend() would empty.
    # Slice assignment keeps the same list object, so references obtained
    # earlier from the getter still see the update.
    self._snapshot_include_paths[:] = list(paths)

  def add_snapshot_exclude_paths(self, paths: list[str]) -> None:
    """Add paths to exclude from snapshot artifacts."""
    self._snapshot_exclude_paths.extend(paths)

  def add_snapshot_env_keys(self, keys: list[str]) -> None:
    """Add environment variable keys for snapshot."""
    self._snapshot_env_keys.extend(keys)

  def add_snapshot_obj(self, name: str, obj: Any) -> None:
    """Add an object to save in snapshot under `name` (overwrites)."""
    self._snapshot_objs[name] = obj

  def get_snapshot_include_paths(self) -> list[str]:
    """Returns the stored snapshot include path list (not a copy)."""
    return self._snapshot_include_paths

  def get_snapshot_exclude_paths(self) -> list[str]:
    """Returns the stored snapshot exclude path list (not a copy)."""
    return self._snapshot_exclude_paths

  def get_snapshot_env_keys(self) -> list[str]:
    """Returns the stored snapshot env key list (not a copy)."""
    return self._snapshot_env_keys

  def get_snapshot_objs(self) -> dict[str, Any]:
    """Returns the stored snapshot object dictionary (not a copy)."""
    return self._snapshot_objs
166
167
class SplitBuildTestScript:
  """Utility for running integration test in build and test environment.

  Steps are registered with add_build_step/add_test_step and must alternate
  build, test, build, test, ... (a trailing build step is allowed). In the
  build env each build step runs and its StepOutput is stored as a snapshot;
  in the test env the matching snapshot is restored and the test step runs.
  """

  def __init__(self, name: str, config: IntegrationTestConfiguration) -> None:
    """Creates the script.

    Args:
        name: Script id; used as the prefix of every snapshot name.
        config: The integration test configuration.
    """
    self._config = config
    self._id: str = name
    self._snapshot: Snapshot = Snapshot(self._config.snapshot_storage_path)
    self._has_already_run: bool = False
    self._steps: list[SplitBuildTestScript._Step] = []
    self._snapshot_restore_exclude_paths: list[str] = []

  def get_config(self) -> IntegrationTestConfiguration:
    """Returns the integration test configuration."""
    return self._config

  def add_build_step(self, step_func: Callable[[StepInput], StepOutput]):
    """Add a build step.

    Args:
        step_func: A function that takes a StepInput object and returns a
          StepOutput object.

    Raises:
        RuntimeError: Unexpected step orders detected.
    """
    if self._steps and isinstance(self._steps[-1], self._BuildStep):
      raise RuntimeError(
          'Two adjacent build steps are unnecessary. Combine them.'
      )
    self._steps.append(self._BuildStep(step_func))

  def add_test_step(self, step_func: Callable[[StepInput], None]):
    """Add a test step.

    Args:
        step_func: A function that takes a StepInput object.

    Raises:
        RuntimeError: Unexpected step orders detected.
    """
    if not self._steps or isinstance(self._steps[-1], self._TestStep):
      raise RuntimeError('A build step is required before a test step.')
    self._steps.append(self._TestStep(step_func))

  def _exception_to_dict(self, exception: Exception) -> dict[str, str]:
    """Converts an exception object to a dictionary to be saved by json."""
    return {
        'type': exception.__class__.__name__,
        'message': str(exception),
        'traceback': ''.join(traceback.format_tb(exception.__traceback__)),
    }

  def _dict_to_exception(self, exception_dict: dict[str, str]) -> RuntimeError:
    """Converts a dictionary from _exception_to_dict to an exception object."""
    return RuntimeError(
        'The last build step raised an exception:\n'
        f'{exception_dict["type"]}: {exception_dict["message"]}\n'
        'Traceback (from saved snapshot):\n'
        f'{exception_dict["traceback"]}'
    )

  def run(self) -> None:
    """Run the steps added previously.

    This function cannot be executed more than once.

    Raises:
        RuntimeError: When attempted to run the script multiple times, or in
          the test env when the paired build step had raised.
    """
    if self._has_already_run:
      raise RuntimeError(f'Script {self._id} has already run.')
    self._has_already_run = True

    build_step_exception_key = '_internal_build_step_exception'

    for index, step in enumerate(self._steps):
      # Steps strictly alternate build/test (enforced by the add_* methods),
      # so a build step and the test step after it share one snapshot name.
      snapshot_name = self._id + '_' + str(index // 2)

      if isinstance(step, self._BuildStep) and self.get_config().is_build_env:
        env = os.environ
        step_in = StepInput(
            env,
            self._get_repo_root(os.environ),
            self.get_config(),
            {},
        )
        last_exception = None
        try:
          step_out = step.get_step_func()(step_in)
        # pylint: disable=broad-exception-caught
        except Exception as e:
          # Save the failure into the snapshot so the test env can report
          # it; it is re-raised below once the snapshot has been taken.
          last_exception = e
          step_out = StepOutput()
          step_out.add_snapshot_obj(
              build_step_exception_key, self._exception_to_dict(e)
          )

        self._take_snapshot(
            self._get_repo_root(os.environ),
            snapshot_name,
            step_out,
            env,
        )

        if last_exception is not None:
          raise last_exception

      if isinstance(step, self._TestStep) and self.get_config().is_test_env:
        env, objs = self._restore_snapshot(snapshot_name)

        # The paired build step failed; surface its error instead of
        # running the test step.
        if build_step_exception_key in objs:
          raise self._dict_to_exception(objs[build_step_exception_key])

        step_in = StepInput(
            env,
            self._get_repo_root(env),
            self.get_config(),
            objs,
        )
        step.get_step_func()(step_in)

  def add_snapshot_restore_exclude_paths(self, paths: list[str]) -> None:
    """Add paths to ignore during snapshot directory restore."""
    self._snapshot_restore_exclude_paths.extend(paths)

  def _take_snapshot(
      self,
      repo_root: str,
      name: str,
      step_out: StepOutput,
      env: dict[str, str],
  ) -> None:
    """Take a snapshot of the repository and environment."""
    self._snapshot.take_snapshot(
        name,
        repo_root,
        include_paths=step_out.get_snapshot_include_paths(),
        exclude_paths=step_out.get_snapshot_exclude_paths(),
        env_keys=step_out.get_snapshot_env_keys(),
        env=env,
        objs=step_out.get_snapshot_objs(),
    )

  def _restore_snapshot(
      self, name: str
  ) -> tuple[dict[str, str], dict[str, Any]]:
    """Restore the repository and environment from a snapshot.

    Returns:
        The (env, objs) pair produced by Snapshot.restore_snapshot; presumably
        the saved environment mapping and the saved objects dictionary.
    """
    return self._snapshot.restore_snapshot(
        name,
        self.get_config().workspace_path.as_posix(),
        exclude_paths=self._snapshot_restore_exclude_paths,
    )

  def _get_repo_root(self, env) -> str:
    """Get repo root directory (from os.environ in the build env)."""
    if self.get_config().is_build_env:
      return os.environ[ANDROID_BUILD_TOP_KEY]
    return env[ANDROID_BUILD_TOP_KEY]

  class _Step:
    """Parent class to build step and test step for typing declaration."""

  class _BuildStep(_Step):
    """A registered build step."""

    def __init__(self, step_func: Callable[[StepInput], StepOutput]):
      self._step_func = step_func

    def get_step_func(self) -> Callable[[StepInput], StepOutput]:
      """Returns the stored step function for build."""
      return self._step_func

  class _TestStep(_Step):
    """A registered test step."""

    def __init__(self, step_func: Callable[[StepInput], None]):
      self._step_func = step_func

    def get_step_func(self) -> Callable[[StepInput], None]:
      """Returns the stored step function for test."""
      return self._step_func
341
342
class SplitBuildTestTestCase(unittest.TestCase):
  """Base test case class for split build-test scripting tests."""

  # Internal config to be injected to the test case from main.
  _config: IntegrationTestConfiguration = None

  @classmethod
  def set_config(cls, config: IntegrationTestConfiguration) -> None:
    """Sets the configuration shared by this test class."""
    cls._config = config

  @classmethod
  def get_config(cls) -> IntegrationTestConfiguration:
    """Returns the injected configuration (None until set_config is called)."""
    return cls._config

  def create_split_build_test_script(
      self, name: str = None
  ) -> SplitBuildTestScript:
    """Return an instance of SplitBuildTestScript with the given name.

    Args:
        name: The name of the script. The name will be used to store snapshots
          and it's recommended to set the name to test id such as self.id().
          Defaults to the test id if not set. When the test module runs as
          `__main__`, that leading module name is replaced by the script's
          file stem so snapshot names stay stable across invocation styles.
    """
    if not name:
      name = self.id()
      main_module_name = '__main__'
      main_module_prefix = main_module_name + '.'
      # Only rewrite the leading module component; str.replace would also
      # rewrite any '__main__' occurring in class or method names.
      if name.startswith(main_module_prefix):
        script_name = pathlib.Path(sys.modules[main_module_name].__file__).stem
        name = script_name + name[len(main_module_name):]
    return SplitBuildTestScript(name, self.get_config())
374
375
class _FileCompressor:
  """Class for compressing and decompressing files in place.

  Each regular file `<name>` is replaced by a single-member zip archive
  `<name>.zip`. Decompression extracts every `*.zip` file into its own
  directory and deletes the archive. Archives named by the older scheme,
  which replaced the suffix, also decompress correctly because the member
  keeps the original file name.
  """

  def compress_all_sub_files(self, root_path: pathlib.Path) -> None:
    """Compresses all files in the given directory and subdirectories.

    Args:
        root_path: The path to the root directory.

    Raises:
        Exception: The first error raised by any per-file compression.
    """
    # Snapshot the file list before any worker starts: workers create and
    # delete files in the tree being walked, and directory scans give no
    # guarantee about observing such concurrent changes.
    file_paths = [path for path in root_path.rglob('*') if path.is_file()]
    self._run_in_parallel(self.compress_file, file_paths)

  def compress_file(self, file_path: pathlib.Path) -> None:
    """Compresses a single file to `<file name>.zip` and deletes the file.

    Args:
        file_path: The path to the file to compress.
    """
    # Append '.zip' rather than replacing the suffix, so `a.txt` and `a.json`
    # (or an existing `x.zip`) never map to the same archive path.
    zip_path = file_path.with_name(file_path.name + '.zip')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
      zip_file.write(file_path, arcname=file_path.name)
    file_path.unlink()

  def decompress_all_sub_files(self, root_path: pathlib.Path) -> None:
    """Decompresses all compressed sub files in the given directory.

    Args:
        root_path: The path to the root directory.

    Raises:
        Exception: The first error raised by any per-file decompression.
    """
    # Listed up front so files extracted during the run (which may
    # themselves end in '.zip') are not decompressed a second time.
    zip_paths = list(root_path.rglob('*.zip'))
    self._run_in_parallel(self.decompress_file, zip_paths)

  def decompress_file(self, file_path: pathlib.Path) -> None:
    """Decompresses a single zip file next to itself and deletes the zip.

    Args:
        file_path: The path to the compressed file.
    """
    with zipfile.ZipFile(file_path, 'r') as zip_file:
      zip_file.extractall(file_path.parent)
    file_path.unlink()

  @staticmethod
  def _run_in_parallel(
      func: Callable[[pathlib.Path], None], paths: list[pathlib.Path]
  ) -> None:
    """Applies func to every path on a thread pool and re-raises errors."""
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=multiprocessing.cpu_count()
    ) as executor:
      futures = [executor.submit(func, path) for path in paths]
    # Every task has finished once the executor exits; result() surfaces
    # failures that would otherwise be silently discarded.
    for future in futures:
      future.result()
427
428
class ParallelTestRunner(unittest.TextTestRunner):
  """A class that holds the logic of parallel test execution.

  When this class is the test runner, test methods wrapped by its
  decorators are pre-executed in parallel at the start of the run. Their
  results are cached and replayed by the regular sequential run that follows.
  Available decorators:
  - `run_in_parallel`: parallel run in both the build and test env.
  - `run_in_parallel_in_build_env`: parallel run in the build env only.
  - `run_in_parallel_in_test_env`: parallel run in the test env only.

  The `setup_parallel*` decorators mark class attributes that are called
  with no arguments before the parallel pre-execution of their class.
  """

  _RUN_IN_PARALLEL = 'run_in_parallel'
  _RUN_IN_PARALLEL_IN_BUILD_ENV = 'run_in_parallel_in_build_env'
  _RUN_IN_PARALLEL_IN_TEST_ENV = 'run_in_parallel_in_test_env'
  _DECORATOR_NAME = 'decorator_name'

  @classmethod
  def _cache_first(
      cls, func: Callable[..., Any], decorator_name: str
  ) -> Callable[..., Any]:
    """Cache a function's first call result and consumes it in the next call.

    This decorator is similar to the built-in `functools.cache` decorator except
    that this decorator caches the first call's run result and emit it in the
    next run of the function, regardless of the function's input argument value
    changes. Caching only the first call of the test ensures test retries emit
    fresh results.

    Caching is armed by calling the wrapper with
    `only_set_next_run_caching=True`. The following call runs `func` and
    stores its outcome (return value or raised Exception). The call after
    that replays the stored outcome instead of running `func`.

    Args:
        func: The function to cache. It is tagged with the decorator name.
        decorator_name: The name of the decorator.

    Returns:
        The wrapped function with queue caching ability.
    """
    setattr(func, cls._DECORATOR_NAME, decorator_name)

    class _ResultCache:
      result = None
      is_to_be_cached = False

    result_cache = _ResultCache()

    @functools.wraps(func)
    def _wrapped(*args, only_set_next_run_caching=False, **kwargs):
      if only_set_next_run_caching:
        result_cache.is_to_be_cached = True
        return

      def _get_fresh_call_result():
        try:
          return (func(*args, **kwargs), None)
        # pylint: disable-next=broad-exception-caught
        except Exception as e:
          return (None, e)

      if result_cache.is_to_be_cached:
        result = _get_fresh_call_result()
        result_cache.result = result
        result_cache.is_to_be_cached = False
      elif result_cache.result is not None:
        result = result_cache.result
        result_cache.result = None
      else:
        result = _get_fresh_call_result()
      value, error = result
      if error is not None:
        raise error
      return value

    return _wrapped

  @classmethod
  def run_in_parallel(cls, func: Callable[..., Any]) -> Callable[..., Any]:
    """Hint that a test method can run in parallel."""
    return cls._cache_first(func, cls.run_in_parallel.__name__)

  @classmethod
  def run_in_parallel_in_build_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a test method can run in parallel in build env only."""
    return cls._cache_first(func, cls.run_in_parallel_in_build_env.__name__)

  @classmethod
  def run_in_parallel_in_test_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a test method can run in parallel in test env only."""
    return cls._cache_first(func, cls.run_in_parallel_in_test_env.__name__)

  @classmethod
  def setup_parallel(cls, func: Callable[..., Any]) -> Callable[..., Any]:
    """Hint that a method is for setting up a parallel run."""
    return cls._cache_first(func, cls.setup_parallel.__name__)

  @classmethod
  def setup_parallel_in_build_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a method is for setting up a parallel run in build env only."""
    return cls._cache_first(func, cls.setup_parallel_in_build_env.__name__)

  @classmethod
  def setup_parallel_in_test_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a method is for setting up a parallel run in test env only."""
    return cls._cache_first(func, cls.setup_parallel_in_test_env.__name__)

  def run(self, test):
    """Executes parallel tests first and then all tests sequentially."""
    if isinstance(test, unittest.TestSuite):
      for child in test:
        # Only nested suites are scanned. A bare TestCase child is not
        # iterable and simply runs in the regular pass below.
        if isinstance(child, unittest.TestSuite):
          self._pre_execute_parallel_tests(child)
    return super().run(test)

  @staticmethod
  def _get_test_function(test: unittest.TestCase) -> Callable[..., Any]:
    """Gets the test function from a TestCase class wrapped by unittest."""
    return getattr(test, test.id().split('.')[-1])

  @classmethod
  def _get_parallel_setups(
      cls, test_suite: unittest.TestSuite
  ) -> set[Callable[[], Any]]:
    """Returns a set of functions to be executed as setup for parallel run.

    The test class is taken from the first element of the suite, and its
    `get_config()` decides which env-specific setups apply.
    """
    test_cls = None
    for test_case in test_suite:
      test_cls = test_case.__class__
      break
    if not test_cls:
      return set()

    wanted_names = {cls.setup_parallel.__name__}
    if test_cls.get_config().is_build_env:
      wanted_names.add(cls.setup_parallel_in_build_env.__name__)
    if test_cls.get_config().is_test_env:
      wanted_names.add(cls.setup_parallel_in_test_env.__name__)

    result = set()
    for attr_name in dir(test_cls):
      attr = getattr(test_cls, attr_name)
      if callable(attr) and (
          getattr(attr, cls._DECORATOR_NAME, None) in wanted_names
      ):
        result.add(attr)
    return result

  @classmethod
  def _get_parallel_tests(
      cls, test_suite: unittest.TestSuite
  ) -> Iterator[unittest.TestCase]:
    """Returns an iterator of the test cases to be run in parallel."""

    def _is_parallel(test: unittest.TestCase) -> bool:
      decorator_name = getattr(
          cls._get_test_function(test), cls._DECORATOR_NAME, None
      )
      if decorator_name == cls.run_in_parallel.__name__:
        return True
      if decorator_name == cls.run_in_parallel_in_build_env.__name__:
        return test.get_config().is_build_env
      if decorator_name == cls.run_in_parallel_in_test_env.__name__:
        return test.get_config().is_test_env
      return False

    return filter(_is_parallel, test_suite)

  @classmethod
  def _pre_execute_parallel_tests(cls, test_suite: unittest.TestSuite) -> None:
    """Pre-execute parallel tests in the test suite and cache their results."""
    for setup_func in cls._get_parallel_setups(test_suite):
      logging.info('Setting up parallel tests with function %s', setup_func)
      setup_func()
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=multiprocessing.cpu_count()
    ) as executor:

      def _execute_test(test):
        # We can't directly call test.run because the function would either not
        # know that it's being pre-executed or not know whether it's being
        # executed by this test runner. We can't call the test function directly
        # because setup and teardown would be missed. We can't set properties
        # of the test function here because the test function has already been
        # wrapped by unittest. The only way we can let the test function know
        # that it needs to cache the next run is to call the function with a
        # parameter first before calling the run method.
        cls._get_test_function(test).__func__(only_set_next_run_caching=True)
        return executor.submit(test.run)

      for class_name, class_group in itertools.groupby(
          cls._get_parallel_tests(test_suite),
          lambda obj: (
              f'{obj.__class__.__module__}.{obj.__class__.__qualname__}'
          ),
      ):
        test_group = list(class_group)
        logging.info(
            'Pre-executing %s of %s tests in parallel...',
            len(test_group),
            class_name,
        )
        concurrent.futures.wait([_execute_test(test) for test in test_group])
642
643
def _configure_logging(verbose: bool, log_file_dir_path: pathlib.Path):
  """Sends all logs to a file and INFO-or-higher (DEBUG if verbose) to console.

  Args:
      verbose: If true display DEBUG level logs on console.
      log_file_dir_path: Directory for the log file; created when missing.
  """
  log_file = log_file_dir_path / 'asuite_integration_tests.log'
  if log_file.exists():
    # Keep the earlier log by switching to a timestamped file name.
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    log_file = log_file_dir_path / f'asuite_integration_tests_{stamp}.log'
  log_file.parent.mkdir(parents=True, exist_ok=True)

  atexit.register(lambda: print('Logs are saved to %s' % log_file))

  line_format = '%(asctime)s %(filename)s:%(lineno)s:%(levelname)s: %(message)s'
  logging.basicConfig(
      filename=log_file.as_posix(),
      level=logging.DEBUG,
      format=line_format,
      datefmt='%Y-%m-%d %H:%M:%S',
  )

  console_handler = logging.StreamHandler()
  console_handler.name = 'console'
  console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
  console_handler.setFormatter(logging.Formatter(line_format))
  logging.getLogger('').addHandler(console_handler)
676
677
def _parse_known_args(
    argv: list[str],
    argparser_update_func: Callable[[argparse.ArgumentParser], None] = None,
) -> tuple[argparse.Namespace, list[str]]:
  """Parses the integration test flags and returns the remaining args.

  Args:
      argv: The full argument list. It is parsed as-is, without skipping
        argv[0], so a leading program name ends up first in the returned list
        of unknown args.
      argparser_update_func: Optional callback that can register extra
        arguments on the parser before parsing.

  Returns:
      The (namespace, unknown_args) tuple from
      ArgumentParser.parse_known_args.
  """

  description = """A script to build and/or run the Asuite integration tests.
Usage examples:
   python <script_path>: Runs both the build and test steps.
   python <script_path> -b -t: Runs both the build and test steps.
   python <script_path> -b: Runs only the build steps.
   python <script_path> -t: Runs only the test steps.
"""

  parser = argparse.ArgumentParser(
      add_help=True,
      description=description,
      formatter_class=argparse.RawDescriptionHelpFormatter,
  )

  parser.add_argument(
      '-b',
      '--build',
      action='store_true',
      default=False,
      help=(
          'Run build steps. Can be set to true together with the test option.'
          ' If both build and test are unset, will run both steps.'
      ),
  )
  parser.add_argument(
      '-t',
      '--test',
      action='store_true',
      default=False,
      help=(
          'Run test steps. Can be set to true together with the build option.'
          ' If both build and test are unset, will run both steps.'
      ),
  )
  parser.add_argument(
      '--tar_snapshot',
      action='store_true',
      default=False,
      help=(
          'Whether to tar and untar the snapshot storage into/from a single'
          ' file.'
      ),
  )
  parser.add_argument(
      '-v',
      '--verbose',
      action='store_true',
      default=False,
      help='Whether to set log level to verbose.',
  )

  # The below flags are passed in by the TF Python test runner.
  parser.add_argument(
      '-s',
      '--serial',
      help=(
          'The device serial. Required in test mode when ANDROID_BUILD_TOP is'
          ' not set.'
      ),
  )
  parser.add_argument(
      '--test-output-file',
      help=(
          'The file in which to store the unit test results. This option is'
          ' usually set by TradeFed when running the script with python and'
          ' is optional during manual script execution.'
      ),
  )

  if argparser_update_func:
    argparser_update_func(parser)

  return parser.parse_known_args(argv)
757
758
def _cleanup_snapshot_dirs(config: IntegrationTestConfiguration) -> None:
  """Deletes the workspace and snapshot storage directories if present."""
  for path in (config.workspace_path, config.snapshot_storage_path):
    if path.exists():
      shutil.rmtree(path)


def _unpack_snapshot_tar(
    config: IntegrationTestConfiguration, compressor: _FileCompressor
) -> None:
  """Restores the snapshot storage directory from the snapshot tar.

  Raises:
      EnvironmentError: The snapshot tar does not exist.
  """
  if not config.snapshot_storage_tar_path.exists():
    raise EnvironmentError(
        f'Snapshot tar {config.snapshot_storage_tar_path} does not'
        ' exist. Have you run the build mode with --tar_snapshot'
        ' option enabled?'
    )
  with tarfile.open(config.snapshot_storage_tar_path, 'r') as tar:
    # NOTE(review): no extraction `filter` is passed. Python 3.12+ warns
    # about this and 3.14 defaults to the 'data' filter, which rejects some
    # links; confirm the snapshot contents still extract as expected there.
    tar.extractall(config.snapshot_storage_path.parent.as_posix())

  logging.info(
      'Decompressing the snapshot storage with %s threads...',
      multiprocessing.cpu_count(),
  )
  start_time = time.time()
  compressor.decompress_all_sub_files(config.snapshot_storage_path)
  logging.info(
      'Decompression finished in {:.2f} seconds'.format(
          time.time() - start_time
      )
  )


def _pack_snapshot_tar(
    config: IntegrationTestConfiguration, compressor: _FileCompressor
) -> None:
  """Compresses and tars the snapshot storage, then deletes the directories."""
  logging.info(
      'Compressing the snapshot storage with %s threads...',
      multiprocessing.cpu_count(),
  )
  start_time = time.time()
  compressor.compress_all_sub_files(config.snapshot_storage_path)
  logging.info(
      'Compression finished in {:.2f} seconds'.format(
          time.time() - start_time
      )
  )

  with tarfile.open(config.snapshot_storage_tar_path, 'w') as tar:
    tar.add(
        config.snapshot_storage_path,
        arcname=config.snapshot_storage_path.name,
    )
  _cleanup_snapshot_dirs(config)


def _run_test(
    config: IntegrationTestConfiguration,
    argv: list[str],
    test_output_file_path: str = None,
) -> None:
  """Execute integration tests with given test configuration.

  Args:
      config: The configuration; its is_build_env/is_test_env flags select
        the phase(s) and is_tar_snapshot enables tar packing/unpacking.
      argv: Arguments passed to unittest.main as its `argv`.
      test_output_file_path: Optional file that receives the unittest
        runner output; its parent directories are created as needed.
  """
  compressor = _FileCompressor()

  if config.is_test_env and config.is_tar_snapshot:
    _unpack_snapshot_tar(config, compressor)
    atexit.register(_cleanup_snapshot_dirs, config)

  def unittest_main(stream=None):
    # Note that we use a type and not an instance for 'testRunner'
    # since TestProgram forwards its constructor arguments when creating
    # an instance of the runner type. Not doing so would require us to
    # make sure that the parameters passed to TestProgram are aligned
    # with those for creating a runner instance.
    class TestRunner(ParallelTestRunner):
      """Writes test results to the TF-provided file."""

      def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, stream=stream, **kwargs)

    class TestLoader(unittest.TestLoader):
      """Injects the test configuration to the test classes."""

      def loadTestsFromTestCase(self, *args, **kwargs):
        test_suite = super().loadTestsFromTestCase(*args, **kwargs)
        for test in test_suite:
          test.__class__.set_config(config)
          break
        return test_suite

    # Setting verbosity is required to generate output that the TradeFed
    # test runner can parse. In the build env unittest.main returns instead
    # of exiting, so the snapshot below is still packed; in the test env
    # the process exits with the test status.
    unittest.main(
        testRunner=TestRunner,
        verbosity=3,
        argv=argv,
        testLoader=TestLoader(),
        exit=config.is_test_env,
    )

  if test_output_file_path:
    # parents=True: the output file may sit several missing levels deep.
    pathlib.Path(test_output_file_path).parent.mkdir(
        parents=True, exist_ok=True
    )

    with open(test_output_file_path, 'w', encoding='utf-8') as test_output_file:
      unittest_main(stream=test_output_file)
  else:
    unittest_main(stream=None)

  if config.is_build_env and config.is_tar_snapshot:
    _pack_snapshot_tar(config, compressor)
857
858
def _prepare_build_env(make_before_build: list[str]) -> None:
  """Checks the build env, normalizes $OUT_DIR and pre-builds make targets.

  Raises:
      EnvironmentError: ANDROID_BUILD_TOP is missing, or $OUT_DIR is an
        absolute path outside the repo root.
  """
  if ANDROID_BUILD_TOP_KEY not in os.environ:
    raise EnvironmentError(
        f'Environment variable {ANDROID_BUILD_TOP_KEY} is required to'
        ' build the integration test.'
    )
  repo_root = os.environ[ANDROID_BUILD_TOP_KEY]

  total, used, free = shutil.disk_usage(repo_root)
  gib = 1024**3
  logging.debug(
      'Disk usage: Total: {:.2f} GB, Used: {:.2f} GB, Free: {:.2f} GB'.format(
          total / gib, used / gib, free / gib
      )
  )

  # Resolve OUT_DIR from an explicit value, else derive it from HOST_OUT's
  # first path component under the repo root, else default to 'out'.
  out_dir = os.environ.get('OUT_DIR')
  if out_dir is not None:
    outside_repo = os.path.isabs(out_dir) and not pathlib.Path(
        out_dir
    ).is_relative_to(repo_root)
    if outside_repo:
      raise EnvironmentError(
          f'$OUT_DIR {out_dir} not relative to the repo root'
          f' {repo_root} is not supported yet.'
      )
  elif 'HOST_OUT' in os.environ:
    host_out = pathlib.Path(os.environ['HOST_OUT'])
    out_dir = host_out.relative_to(repo_root).parts[0]
  else:
    out_dir = 'out'
  os.environ['OUT_DIR'] = out_dir

  for target in make_before_build:
    logging.info(
        'Building the %s target before integration test run.', target
    )
    command = ['build/soong/soong_ui.bash', '--make-mode']
    command.extend(str(target).split())
    subprocess.check_call(command, cwd=repo_root)


def main(
    argv: list[str] = None,
    make_before_build: list[str] = None,
    argparser_update_func: Callable[argparse.ArgumentParser, None] = None,
    config_update_function: Callable[
        [IntegrationTestConfiguration, argparse.Namespace], None
    ] = None,
) -> None:
  """Entry point that configures and runs the integration tests.

  Args:
      argv: Full argument list including the program name; defaults to
        sys.argv when empty or None.
      make_before_build: Targets to build with soong before the build steps.
      argparser_update_func: A function that takes an ArgumentParser object and
        updates it.
      config_update_function: A function that takes a
        IntegrationTestConfiguration config and the parsed args to updates the
        config.

  Raises:
      EnvironmentError: When some environment variables are missing.
  """
  argv = argv or sys.argv
  make_before_build = [] if make_before_build is None else make_before_build

  args, unittest_argv = _parse_known_args(argv, argparser_update_func)

  out_root = pathlib.Path(
      tempfile.gettempdir(),
      'asuite_integration_tests_%s'
      % pathlib.Path('~').expanduser().name.replace(' ', '_'),
  )

  tar_path_from_env = os.environ.get(SNAPSHOT_STORAGE_TAR_KEY)
  if tar_path_from_env is None:
    snapshot_storage_tar_path = out_root.joinpath('snapshot.tar')
  else:
    snapshot_storage_tar_path = pathlib.Path(tar_path_from_env)
    snapshot_storage_tar_path.parent.mkdir(parents=True, exist_ok=True)

  _configure_logging(args.verbose, snapshot_storage_tar_path.parent)
  logging.debug('The os environ is: %s', os.environ)

  # Neither -b nor -t means a local run of both phases.
  run_both_phases = not (args.build or args.test)
  config = IntegrationTestConfiguration()
  config.is_build_env = args.build or run_both_phases
  config.is_test_env = args.test or run_both_phases
  config.device_serial = args.serial
  config.snapshot_storage_path = out_root.joinpath('snapshot_storage')
  config.snapshot_storage_tar_path = snapshot_storage_tar_path
  config.workspace_path = out_root.joinpath('workspace')
  config.is_tar_snapshot = args.tar_snapshot

  if config_update_function:
    config_update_function(config, args)

  if config.is_build_env:
    _prepare_build_env(make_before_build)

  if config.is_build_env ^ config.is_test_env:
    _run_test(config, unittest_argv, args.test_output_file)
    return

  # Both phases: run build then test, each with a config for one phase.
  build_config = copy.deepcopy(config)
  build_config.is_test_env = False
  test_config = copy.deepcopy(config)
  test_config.is_build_env = False
  for phase_config in (build_config, test_config):
    _run_test(phase_config, unittest_argv, args.test_output_file)
982