• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Base functionality for Abseil Python tests.
16
17This module contains base classes and high-level functions for Abseil-style
18tests.
19"""
20
21from collections import abc
22import contextlib
23import difflib
24import enum
25import errno
26import getpass
27import inspect
28import io
29import itertools
30import json
31import os
32import random
33import re
34import shlex
35import shutil
36import signal
37import stat
38import subprocess
39import sys
40import tempfile
41import textwrap
42import unittest
43from unittest import mock  # pylint: disable=unused-import Allow absltest.mock.
44from urllib import parse
45
46try:
47  # The faulthandler module isn't always available, and pytype doesn't
48  # understand that we're catching ImportError, so suppress the error.
49  # pytype: disable=import-error
50  import faulthandler
51  # pytype: enable=import-error
52except ImportError:
53  # We use faulthandler if it is available.
54  faulthandler = None
55
56from absl import app
57from absl import flags
58from absl import logging
59from absl.testing import _pretty_print_reporter
60from absl.testing import xml_reporter
61
62# Make typing an optional import to avoid it being a required dependency
63# in Python 2. Type checkers will still understand the imports.
64try:
65  # pylint: disable=unused-import
66  import typing
67  from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Text, TextIO, Tuple, Type, Union
68  # pylint: enable=unused-import
69except ImportError:
70  pass
71else:
72  # Use an if-type-checking block to prevent leakage of type-checking only
73  # symbols. We don't want people relying on these at runtime.
74  if typing.TYPE_CHECKING:
75    # Unbounded TypeVar for general usage
76    _T = typing.TypeVar('_T')
77
78    import unittest.case
79    _OutcomeType = unittest.case._Outcome  # pytype: disable=module-attr
80
81
82
# Re-export the unittest decorators and exceptions that absltest supports, so
# test authors can reach them through absltest without importing unittest.
# pylint: disable=invalid-name
skip, skipIf, skipUnless = unittest.skip, unittest.skipIf, unittest.skipUnless
SkipTest, expectedFailure = unittest.SkipTest, unittest.expectedFailure
# pylint: enable=invalid-name

# End unittest re-exports
94
# absl's flag registry, re-exposed so tests can refer to absltest.FLAGS.
FLAGS = flags.FLAGS

# The two types a "text or binary" value may have; presumably used for
# isinstance() checks elsewhere in this module (not visible here).
_TEXT_OR_BINARY_TYPES = (str, bytes)

# Marks this module so unittest suppresses its frames (surplus entries) in
# AssertionError stack traces.
__unittest = True  # pylint: disable=invalid-name
101
102
def expectedFailureIf(condition, reason):  # pylint: disable=invalid-name
  """Expects the test to fail if the run condition is True.

  Example usage::

      @expectedFailureIf(sys.platform == 'win32', 'Not yet working on Windows')
      def test_foo(self):
        ...

  Args:
    condition: bool, whether to expect failure or not.
    reason: Text, the reason to expect failure. It is not used at runtime; it
      only documents at the call site why the failure is expected.

  Returns:
    Decorator function: `unittest.expectedFailure` when `condition` is truthy,
    otherwise a decorator that returns the test function unchanged.
  """
  del reason  # Unused; kept so call sites document why failure is expected.
  if condition:
    return unittest.expectedFailure
  # Identity decorator: the test runs normally.
  return lambda f: f
123
124
class TempFileCleanup(enum.Enum):
  """Policy for removing files/dirs made by create_tempfile()/create_tempdir().

  The member values are the strings 'always', 'success' and 'never'.
  """

  # Remove temp files once the test completes, whatever its outcome.
  ALWAYS = 'always'
  # Remove temp files only when the test passes, so that a failing test's
  # files are left behind for inspection. absltest.TEST_TMPDIR.value
  # determines where they are created.
  SUCCESS = 'success'
  # Never remove temp files.
  OFF = 'never'
134
135
136# Many of the methods in this module have names like assertSameElements.
137# This kind of name does not comply with PEP8 style,
138# but it is consistent with the naming of methods in unittest.py.
139# pylint: disable=invalid-name
140
141
def _get_default_test_random_seed():
  # type: () -> int
  """Returns the seed from $TEST_RANDOM_SEED, or 301 if unset or not an int."""
  raw_seed = os.environ.get('TEST_RANDOM_SEED', '')
  try:
    return int(raw_seed)
  except ValueError:
    # Unset (empty string) and unparsable values fall back to the default.
    return 301
151
152
def get_default_test_srcdir():
  # type: () -> Text
  """Returns default test source dir.

  This is the TEST_SRCDIR environment variable, or '' when it is unset.
  """
  srcdir = os.getenv('TEST_SRCDIR', '')
  return srcdir
157
158
def get_default_test_tmpdir():
  # type: () -> Text
  """Returns default test temp dir.

  Uses the TEST_TMPDIR environment variable when it is set and non-empty;
  otherwise falls back to an 'absl_testing' directory under the system temp
  directory.
  """
  return (os.environ.get('TEST_TMPDIR', '') or
          os.path.join(tempfile.gettempdir(), 'absl_testing'))
167
168
def _get_default_randomize_ordering_seed():
  # type: () -> int
  """Returns the default seed for randomizing the order in which tests run.

  The setting comes from the --test_randomize_ordering_seed flag when that
  flag is marked present, and otherwise from the TEST_RANDOMIZE_ORDERING_SEED
  environment variable. It is interpreted as follows:
    * (not set) or empty: disable test randomization
    * 0: disable test randomization
    * 'random': choose a random seed in [1, 4294967295] for test order
      randomization
    * positive integer: use this seed for test order randomization

  (These values mirror those of
  https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED).

  Returning None for "no override" would be simpler, but the python random
  module has no `get_seed()`, only `getstate()`, which returns far more data
  than we want to pass via an environment variable or flag.

  Returns:
    A default value for test case randomization (int). 0 means do not randomize.

  Raises:
    ValueError: Raised when the flag or env value is not one of the options
        above.
  """
  if FLAGS['test_randomize_ordering_seed'].present:
    setting = FLAGS.test_randomize_ordering_seed
  else:
    setting = os.environ.get('TEST_RANDOMIZE_ORDERING_SEED', '')

  if not setting or setting == '0':
    return 0
  if setting == 'random':
    return random.Random().randint(1, 4294967295)

  try:
    seed = int(setting)
  except ValueError:
    seed = 0  # Unparsable: rejected below, like any non-positive value.
  if seed > 0:
    return seed
  raise ValueError(
      'Unknown test randomization seed value: {}'.format(setting))
217
218
# Directory flags. Their defaults come from the TEST_SRCDIR / TEST_TMPDIR
# environment variables, read when this module is imported.
TEST_SRCDIR = flags.DEFINE_string(
    'test_srcdir',
    get_default_test_srcdir(),
    'Root of directory tree where source files live',
    allow_override_cpp=True,
)
TEST_TMPDIR = flags.DEFINE_string(
    'test_tmpdir',
    get_default_test_tmpdir(),
    'Directory for temporary testing files',
    allow_override_cpp=True,
)

# Seed flags. The default for test_random_seed is read from the environment
# when this module is imported.
flags.DEFINE_integer(
    'test_random_seed',
    _get_default_test_random_seed(),
    ('Random seed for testing. Some test frameworks may change the default '
     'value of this flag between runs, so it is not appropriate for seeding '
     'probabilistic tests.'),
    allow_override_cpp=True,
)
flags.DEFINE_string(
    'test_randomize_ordering_seed',
    '',
    ('If positive, use this as a seed to randomize the execution order for '
     'test cases. If "random", pick a random seed to use. If 0 or not set, do '
     'not randomize test case execution order. This flag also overrides the '
     'TEST_RANDOMIZE_ORDERING_SEED environment variable.'),
    allow_override_cpp=True,
)

# Where XML test results are written.
flags.DEFINE_string(
    'xml_output_file', '', 'File to store XML test results')
247
248
# Some unittest versions count an unexpected pass as a "successful result"
# (see http://bugs.python.org/issue20165). When the running interpreter still
# behaves that way, TestResult is monkey-patched below so that unexpected
# passes count as failures.
def _monkey_patch_test_result_for_unexpected_passes():
  # type: () -> None
  """Workaround for <http://bugs.python.org/issue20165>."""

  def wasSuccessful(self):
    # type: () -> bool
    """Reports whether this result should be treated as a success.

    Failures, errors and unexpected passes all make the result a non-success.

    Args:
      self: The TestResult instance.

    Returns:
      True only when there are no failures, errors or unexpected successes.
    """
    return all(
        len(bucket) == 0
        for bucket in (self.failures, self.errors, self.unexpectedSuccesses))

  probe = unittest.TestResult()
  probe.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None))
  if not probe.wasSuccessful():
    return  # This interpreter already treats unexpected passes as failures.
  unittest.TestResult.wasSuccessful = wasSuccessful
  if probe.wasSuccessful():  # The patch did not take effect; tell the user.
    sys.stderr.write('unittest.result.TestResult monkey patch to report'
                     ' unexpected passes as failures did not work.\n')


_monkey_patch_test_result_for_unexpected_passes()
281
282
def _open(filepath, mode, _open_func=open):
  # type: (Text, Text, Callable[..., IO]) -> IO
  """Opens a file.

  Like open(), but ensure that we can open real files even if tests stub out
  open() (the real builtin is captured as the default of `_open_func`).

  Text-mode files are opened with UTF-8 encoding. Binary-mode files (`mode`
  contains 'b') are opened without an encoding, because open() rejects an
  `encoding` argument in binary mode.

  Args:
    filepath: A filepath.
    mode: A mode, as accepted by open().
    _open_func: A built-in open() function.

  Returns:
    The opened file object.
  """
  if 'b' in mode:
    # open(..., 'rb', encoding=...) raises ValueError, so omit the encoding.
    return _open_func(filepath, mode)
  return _open_func(filepath, mode, encoding='utf-8')
299
300
class _TempDir(object):
  """A temporary directory on disk, isolated to a test.

  Instances are created by this module; calling their public methods from
  tests is fine.

  Implements `os.PathLike[str]` through `__fspath__()`, so an instance can be
  passed directly wherever a path is expected, e.g. `os.path.join()`.
  """

  def __init__(self, path):
    # type: (Text) -> None
    """Module-private: do not instantiate outside module."""
    self._path = path

  @property
  def full_path(self):
    # type: () -> Text
    """The directory's path, as a string.

    TIP: Rather than `os.path.join(temp_dir.full_path)`, you can simply write
    `os.path.join(temp_dir)`, since `__fspath__()` is implemented.
    """
    return self._path

  def __fspath__(self):
    # type: () -> Text
    """See os.PathLike."""
    return self.full_path

  def create_file(self, file_path=None, content=None, mode='w', encoding='utf8',
                  errors='strict'):
    # type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile
    """Creates a file inside this directory.

    NOTE: An existing file at the same path is made writable and overwritten.

    Args:
      file_path: Optional file path for the temp file. When omitted, a unique
        file name is generated. Slashes are allowed; missing intermediate
        directories are created. NOTE: This is the path that gets cleaned up,
        including its directories, e.g. 'foo/bar/baz.txt' will `rm -r foo`.
      content: Optional str or bytes to write to the file initially. When
        omitted, the file is created empty.
      mode: Mode string used when writing `content`; only relevant when
        `content` is non-empty.
      encoding: Encoding used when `content` is text.
      errors: Text-to-bytes error handling, used when `content` is text.

    Returns:
      A _TempFile for the newly created file.
    """
    created = _TempFile._create(
        self._path, file_path, content, mode, encoding, errors)
    return created[0]

  def mkdir(self, dir_path=None):
    # type: (Optional[Text]) -> _TempDir
    """Creates a subdirectory of this directory.

    Args:
      dir_path: Optional path of the subdirectory. When omitted, a unique name
        is generated.

    Returns:
      A _TempDir for the subdirectory.
    """
    if not dir_path:
      path = tempfile.mkdtemp(dir=self._path)
    else:
      path = os.path.join(self._path, dir_path)
    # No clearing needed here: TestCase.create_tempdir() already cleared the
    # containing temp dir tree when it was created.
    os.makedirs(path, exist_ok=True)
    return _TempDir(path)
380
381
class _TempFile(object):
  """A temporary file on disk, isolated to a test.

  Instances are created internally (through `_create`); calling their public
  methods from tests is fine.

  Implements `os.PathLike[str]` through `__fspath__()`, so an instance can be
  passed directly wherever a path is expected, e.g. `os.path.join()`.
  """

  def __init__(self, path):
    # type: (Text) -> None
    """Private: use _create instead."""
    self._path = path

  # pylint: disable=line-too-long
  @classmethod
  def _create(cls, base_path, file_path, content, mode, encoding, errors):
    # type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text]
    # pylint: enable=line-too-long
    """Module-private: creates the file and returns (tempfile, cleanup_path).

    When `file_path` is falsy, a uniquely named file is created directly in
    `base_path` and that file is the cleanup path. Otherwise the file lives at
    `base_path/file_path` and the cleanup path is derived from the first part
    of `file_path` (via `_get_first_part`).

    The file is then filled with `content` (text or bytes, using `mode`), or
    truncated to empty when `content` is falsy.
    """
    if not file_path:
      os.makedirs(base_path, exist_ok=True)
      handle, path = tempfile.mkstemp(dir=str(base_path))
      os.close(handle)
      cleanup_path = path
    else:
      cleanup_path = os.path.join(base_path, _get_first_part(file_path))
      path = os.path.join(base_path, file_path)
      os.makedirs(os.path.dirname(path), exist_ok=True)
      # An existing read-only file could not be truncated below, so grant
      # the owner write permission first.
      if os.path.exists(path) and not os.access(path, os.W_OK):
        os.chmod(path, os.stat(path).st_mode | stat.S_IWUSR)

    new_file = cls(path)
    if not content:
      new_file.write_bytes(b'')
    elif isinstance(content, str):
      new_file.write_text(content, mode=mode, encoding=encoding, errors=errors)
    else:
      new_file.write_bytes(content, mode)
    return new_file, cleanup_path

  @property
  def full_path(self):
    # type: () -> Text
    """The file's path, as a string.

    TIP: Rather than `os.path.join(temp_file.full_path)`, you can simply write
    `os.path.join(temp_file)`, since `__fspath__()` is implemented.
    """
    return self._path

  def __fspath__(self):
    # type: () -> Text
    """See os.PathLike."""
    return self.full_path

  def read_text(self, encoding='utf8', errors='strict'):
    # type: (Text, Text) -> Text
    """Returns the file's contents decoded as text."""
    with self.open_text(encoding=encoding, errors=errors) as fp:
      contents = fp.read()
    return contents

  def read_bytes(self):
    # type: () -> bytes
    """Returns the file's contents as bytes."""
    with self.open_bytes() as fp:
      contents = fp.read()
    return contents

  def write_text(self, text, mode='w', encoding='utf8', errors='strict'):
    # type: (Text, Text, Text, Text) -> None
    """Writes text to the file.

    Args:
      text: The str to write.
      mode: Mode used to open the file for writing; "t" is implied.
      encoding: Encoding used to convert `text` to bytes.
      errors: How to handle errors while encoding `text`.
    """
    with self.open_text(mode=mode, encoding=encoding, errors=errors) as fp:
      fp.write(text)

  def write_bytes(self, data, mode='wb'):
    # type: (bytes, Text) -> None
    """Writes bytes to the file.

    Args:
      data: The bytes to write.
      mode: Mode used to open the file for writing. "b" is implied when
        missing; "t" is not allowed.
    """
    with self.open_bytes(mode=mode) as fp:
      fp.write(data)

  def open_text(self, mode='rt', encoding='utf8', errors='strict'):
    # type: (Text, Text, Text) -> ContextManager[TextIO]
    """Returns a context manager that opens the file in text mode.

    Args:
      mode: Mode to open the file in. "t" is implied when missing; "b" is not
        allowed.
      encoding: Encoding used to open the file.
      errors: How to handle decoding/encoding errors.

    Returns:
      Context manager yielding the open file.

    Raises:
      ValueError: if `mode` contains "b" (raised immediately, not on entry).
    """
    if 'b' in mode:
      raise ValueError('Invalid mode {!r}: "b" flag not allowed when opening '
                       'file in text mode'.format(mode))
    text_mode = mode if 't' in mode else mode + 't'
    return self._open(text_mode, encoding, errors)

  def open_bytes(self, mode='rb'):
    # type: (Text) -> ContextManager[BinaryIO]
    """Returns a context manager that opens the file in binary mode.

    Args:
      mode: Mode to open the file in. "b" is implied when missing; "t" is not
        allowed.

    Returns:
      Context manager yielding the open file.

    Raises:
      ValueError: if `mode` contains "t" (raised immediately, not on entry).
    """
    if 't' in mode:
      raise ValueError('Invalid mode {!r}: "t" flag not allowed when opening '
                       'file in binary mode'.format(mode))
    binary_mode = mode if 'b' in mode else mode + 'b'
    return self._open(binary_mode, encoding=None, errors=None)

  # TODO(b/123775699): Once pytype supports typing.Literal, use overload and
  # Literal to express more precise return types. The contained type is
  # currently `Any` to avoid [bad-return-type] errors in the open_* methods.
  @contextlib.contextmanager
  def _open(
      self,
      mode: str,
      encoding: Optional[str] = 'utf8',
      errors: Optional[str] = 'strict',
  ) -> Iterator[Any]:
    # io.open is used rather than the open builtin, so this keeps working
    # even if a test stubs out builtins.open.
    file_obj = io.open(
        self.full_path, mode=mode, encoding=encoding, errors=errors)
    with file_obj:
      yield file_obj
544
545
class _method(object):
  """A decorator that supports both instance and classmethod invocations.

  Using similar semantics to the @property builtin, this decorator can augment
  an instance method to support conditional logic when invoked on a class
  object. Once a classmethod variant is registered with `.classmethod`, this
  breaks support for invoking the instance method via the class (e.g.
  Cls.method(self, ...)), but is still situationally useful. Without a
  registered classmethod variant, class access returns the plain instance
  function, as for an ordinary method.
  """

  def __init__(self, finstancemethod):
    # type: (Callable[..., Any]) -> None
    self._finstancemethod = finstancemethod
    self._fclassmethod = None

  def classmethod(self, fclassmethod):
    # type: (Callable[..., Any]) -> _method
    """Registers `fclassmethod` as the variant used on class access."""
    self._fclassmethod = classmethod(fclassmethod)
    return self

  # A property rather than a plain method, so that reading `__doc__` on a
  # _method instance yields a string. NOTE: because it is bound to the name
  # `__doc__` in the class namespace, it replaces the class docstring above.
  @property
  def __doc__(self):
    # type: () -> str
    """The instance method's docstring, else the classmethod's, else ''."""
    candidates = [self._finstancemethod]
    if self._fclassmethod is not None:
      # Read the docstring from the wrapped function. Before Python 3.10 a
      # classmethod object's own __doc__ may be the builtin classmethod's.
      candidates.append(self._fclassmethod.__func__)
    for func in candidates:
      doc = getattr(func, '__doc__', None)
      if doc:
        return doc
    return ''

  def __get__(self, obj, type_):
    # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any]
    if obj is None and self._fclassmethod is not None:
      # Accessed on the class: bind the classmethod variant to `type_`.
      return self._fclassmethod.__get__(obj, type_)
    # Accessed on an instance, or no classmethod variant registered: behave
    # like an ordinary function attribute.
    return self._finstancemethod.__get__(obj, type_)
577
578
579class TestCase(unittest.TestCase):
580  """Extension of unittest.TestCase providing more power."""
581
582  # When to cleanup files/directories created by our `create_tempfile()` and
583  # `create_tempdir()` methods after each test case completes. This does *not*
584  # affect e.g., files created outside of those methods, e.g., using the stdlib
585  # tempfile module. This can be overridden at the class level, instance level,
586  # or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See
587  # `TempFileCleanup` for details on the different values.
588  # TODO(b/70517332): Remove the type comment and the disable once pytype has
589  # better support for enums.
590  tempfile_cleanup = TempFileCleanup.ALWAYS  # type: TempFileCleanup  # pytype: disable=annotation-type-mismatch
591
592  maxDiff = 80 * 20
593  longMessage = True
594
595  # Exit stacks for per-test and per-class scopes.
596  _exit_stack = None
597  _cls_exit_stack = None
598
599  def __init__(self, *args, **kwargs):
600    super(TestCase, self).__init__(*args, **kwargs)
601    # This is to work around missing type stubs in unittest.pyi
602    self._outcome = getattr(self, '_outcome')  # type: Optional[_OutcomeType]
603
604  def setUp(self):
605    super(TestCase, self).setUp()
606    # NOTE: Only Python 3 contextlib has ExitStack
607    if hasattr(contextlib, 'ExitStack'):
608      self._exit_stack = contextlib.ExitStack()
609      self.addCleanup(self._exit_stack.close)
610
611  @classmethod
612  def setUpClass(cls):
613    super(TestCase, cls).setUpClass()
614    # NOTE: Only Python 3 contextlib has ExitStack and only Python 3.8+ has
615    # addClassCleanup.
616    if hasattr(contextlib, 'ExitStack') and hasattr(cls, 'addClassCleanup'):
617      cls._cls_exit_stack = contextlib.ExitStack()
618      cls.addClassCleanup(cls._cls_exit_stack.close)
619
620  def create_tempdir(self, name=None, cleanup=None):
621    # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir
622    """Create a temporary directory specific to the test.
623
624    NOTE: The directory and its contents will be recursively cleared before
625    creation. This ensures that there is no pre-existing state.
626
627    This creates a named directory on disk that is isolated to this test, and
628    will be properly cleaned up by the test. This avoids several pitfalls of
629    creating temporary directories for test purposes, as well as makes it easier
630    to setup directories and verify their contents. For example::
631
632        def test_foo(self):
633          out_dir = self.create_tempdir()
634          out_log = out_dir.create_file('output.log')
635          expected_outputs = [
636              os.path.join(out_dir, 'data-0.txt'),
637              os.path.join(out_dir, 'data-1.txt'),
638          ]
639          code_under_test(out_dir)
640          self.assertTrue(os.path.exists(expected_paths[0]))
641          self.assertTrue(os.path.exists(expected_paths[1]))
642          self.assertEqual('foo', out_log.read_text())
643
644    See also: :meth:`create_tempfile` for creating temporary files.
645
646    Args:
647      name: Optional name of the directory. If not given, a unique
648        name will be generated and used.
649      cleanup: Optional cleanup policy on when/if to remove the directory (and
650        all its contents) at the end of the test. If None, then uses
651        :attr:`tempfile_cleanup`.
652
653    Returns:
654      A _TempDir representing the created directory; see _TempDir class docs
655      for usage.
656    """
657    test_path = self._get_tempdir_path_test()
658
659    if name:
660      path = os.path.join(test_path, name)
661      cleanup_path = os.path.join(test_path, _get_first_part(name))
662    else:
663      os.makedirs(test_path, exist_ok=True)
664      path = tempfile.mkdtemp(dir=test_path)
665      cleanup_path = path
666
667    _rmtree_ignore_errors(cleanup_path)
668    os.makedirs(path, exist_ok=True)
669
670    self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
671
672    return _TempDir(path)
673
674  # pylint: disable=line-too-long
675  def create_tempfile(self, file_path=None, content=None, mode='w',
676                      encoding='utf8', errors='strict', cleanup=None):
677    # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile
678    # pylint: enable=line-too-long
679    """Create a temporary file specific to the test.
680
681    This creates a named file on disk that is isolated to this test, and will
682    be properly cleaned up by the test. This avoids several pitfalls of
683    creating temporary files for test purposes, as well as makes it easier
684    to setup files, their data, read them back, and inspect them when
685    a test fails. For example::
686
687        def test_foo(self):
688          output = self.create_tempfile()
689          code_under_test(output)
690          self.assertGreater(os.path.getsize(output), 0)
691          self.assertEqual('foo', output.read_text())
692
693    NOTE: This will zero-out the file. This ensures there is no pre-existing
694    state.
695    NOTE: If the file already exists, it will be made writable and overwritten.
696
697    See also: :meth:`create_tempdir` for creating temporary directories, and
698    ``_TempDir.create_file`` for creating files within a temporary directory.
699
700    Args:
701      file_path: Optional file path for the temp file. If not given, a unique
702        file name will be generated and used. Slashes are allowed in the name;
703        any missing intermediate directories will be created. NOTE: This path is
704        the path that will be cleaned up, including any directories in the path,
705        e.g., ``'foo/bar/baz.txt'`` will ``rm -r foo``.
706      content: Optional string or
707        bytes to initially write to the file. If not
708        specified, then an empty file is created.
709      mode: Mode string to use when writing content. Only used if `content` is
710        non-empty.
711      encoding: Encoding to use when writing string content. Only used if
712        `content` is text.
713      errors: How to handle text to bytes encoding errors. Only used if
714        `content` is text.
715      cleanup: Optional cleanup policy on when/if to remove the directory (and
716        all its contents) at the end of the test. If None, then uses
717        :attr:`tempfile_cleanup`.
718
719    Returns:
720      A _TempFile representing the created file; see _TempFile class docs for
721      usage.
722    """
723    test_path = self._get_tempdir_path_test()
724    tf, cleanup_path = _TempFile._create(test_path, file_path, content=content,
725                                         mode=mode, encoding=encoding,
726                                         errors=errors)
727    self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
728    return tf
729
730  @_method
731  def enter_context(self, manager):
732    # type: (ContextManager[_T]) -> _T
733    """Returns the CM's value after registering it with the exit stack.
734
735    Entering a context pushes it onto a stack of contexts. When `enter_context`
736    is called on the test instance (e.g. `self.enter_context`), the context is
737    exited after the test case's tearDown call. When called on the test class
738    (e.g. `TestCase.enter_context`), the context is exited after the test
739    class's tearDownClass call.
740
741    Contexts are exited in the reverse order of entering. They will always
742    be exited, regardless of test failure/success.
743
744    This is useful to eliminate per-test boilerplate when context managers
745    are used. For example, instead of decorating every test with `@mock.patch`,
746    simply do `self.foo = self.enter_context(mock.patch(...))' in `setUp()`.
747
748    NOTE: The context managers will always be exited without any error
749    information. This is an unfortunate implementation detail due to some
750    internals of how unittest runs tests.
751
752    Args:
753      manager: The context manager to enter.
754    """
755    if not self._exit_stack:
756      raise AssertionError(
757          'self._exit_stack is not set: enter_context is Py3-only; also make '
758          'sure that AbslTest.setUp() is called.')
759    return self._exit_stack.enter_context(manager)
760
761  @enter_context.classmethod
762  def enter_context(cls, manager):  # pylint: disable=no-self-argument
763    # type: (ContextManager[_T]) -> _T
764    if not cls._cls_exit_stack:
765      raise AssertionError(
766          'cls._cls_exit_stack is not set: cls.enter_context requires '
767          'Python 3.8+; also make sure that AbslTest.setUpClass() is called.')
768    return cls._cls_exit_stack.enter_context(manager)
769
770  @classmethod
771  def _get_tempdir_path_cls(cls):
772    # type: () -> Text
773    return os.path.join(TEST_TMPDIR.value,
774                        cls.__qualname__.replace('__main__.', ''))
775
776  def _get_tempdir_path_test(self):
777    # type: () -> Text
778    return os.path.join(self._get_tempdir_path_cls(), self._testMethodName)
779
780  def _get_tempfile_cleanup(self, override):
781    # type: (Optional[TempFileCleanup]) -> TempFileCleanup
782    if override is not None:
783      return override
784    return self.tempfile_cleanup
785
786  def _maybe_add_temp_path_cleanup(self, path, cleanup):
787    # type: (Text, Optional[TempFileCleanup]) -> None
788    cleanup = self._get_tempfile_cleanup(cleanup)
789    if cleanup == TempFileCleanup.OFF:
790      return
791    elif cleanup == TempFileCleanup.ALWAYS:
792      self.addCleanup(_rmtree_ignore_errors, path)
793    elif cleanup == TempFileCleanup.SUCCESS:
794      self._internal_add_cleanup_on_success(_rmtree_ignore_errors, path)
795    else:
796      raise AssertionError('Unexpected cleanup value: {}'.format(cleanup))
797
798  def _internal_add_cleanup_on_success(
799      self,
800      function: Callable[..., Any],
801      *args: Any,
802      **kwargs: Any,
803  ) -> None:
804    """Adds `function` as cleanup when the test case succeeds."""
805    outcome = self._outcome
806    previous_failure_count = (
807        len(outcome.result.failures)
808        + len(outcome.result.errors)
809        + len(outcome.result.unexpectedSuccesses)
810    )
811    def _call_cleaner_on_success(*args, **kwargs):
812      if not self._internal_ran_and_passed_when_called_during_cleanup(
813          previous_failure_count):
814        return
815      function(*args, **kwargs)
816    self.addCleanup(_call_cleaner_on_success, *args, **kwargs)
817
818  def _internal_ran_and_passed_when_called_during_cleanup(
819      self,
820      previous_failure_count: int,
821  ) -> bool:
822    """Returns whether test is passed. Expected to be called during cleanup."""
823    outcome = self._outcome
824    if sys.version_info[:2] >= (3, 11):
825      current_failure_count = (
826          len(outcome.result.failures)
827          + len(outcome.result.errors)
828          + len(outcome.result.unexpectedSuccesses)
829      )
830      return current_failure_count == previous_failure_count
831    else:
832      # Before Python 3.11 https://github.com/python/cpython/pull/28180, errors
833      # were bufferred in _Outcome before calling cleanup.
834      result = self.defaultTestResult()
835      self._feedErrorsToResult(result, outcome.errors)  # pytype: disable=attribute-error
836      return result.wasSuccessful()
837
838  def shortDescription(self):
839    # type: () -> Text
840    """Formats both the test method name and the first line of its docstring.
841
842    If no docstring is given, only returns the method name.
843
844    This method overrides unittest.TestCase.shortDescription(), which
845    only returns the first line of the docstring, obscuring the name
846    of the test upon failure.
847
848    Returns:
849      desc: A short description of a test method.
850    """
851    desc = self.id()
852
853    # Omit the main name so that test name can be directly copy/pasted to
854    # the command line.
855    if desc.startswith('__main__.'):
856      desc = desc[len('__main__.'):]
857
858    # NOTE: super() is used here instead of directly invoking
859    # unittest.TestCase.shortDescription(self), because of the
860    # following line that occurs later on:
861    #       unittest.TestCase = TestCase
862    # Because of this, direct invocation of what we think is the
863    # superclass will actually cause infinite recursion.
864    doc_first_line = super(TestCase, self).shortDescription()
865    if doc_first_line is not None:
866      desc = '\n'.join((desc, doc_first_line))
867    return desc
868
869  def assertStartsWith(self, actual, expected_start, msg=None):
870    """Asserts that actual.startswith(expected_start) is True.
871
872    Args:
873      actual: str
874      expected_start: str
875      msg: Optional message to report on failure.
876    """
877    if not actual.startswith(expected_start):
878      self.fail('%r does not start with %r' % (actual, expected_start), msg)
879
880  def assertNotStartsWith(self, actual, unexpected_start, msg=None):
881    """Asserts that actual.startswith(unexpected_start) is False.
882
883    Args:
884      actual: str
885      unexpected_start: str
886      msg: Optional message to report on failure.
887    """
888    if actual.startswith(unexpected_start):
889      self.fail('%r does start with %r' % (actual, unexpected_start), msg)
890
891  def assertEndsWith(self, actual, expected_end, msg=None):
892    """Asserts that actual.endswith(expected_end) is True.
893
894    Args:
895      actual: str
896      expected_end: str
897      msg: Optional message to report on failure.
898    """
899    if not actual.endswith(expected_end):
900      self.fail('%r does not end with %r' % (actual, expected_end), msg)
901
902  def assertNotEndsWith(self, actual, unexpected_end, msg=None):
903    """Asserts that actual.endswith(unexpected_end) is False.
904
905    Args:
906      actual: str
907      unexpected_end: str
908      msg: Optional message to report on failure.
909    """
910    if actual.endswith(unexpected_end):
911      self.fail('%r does end with %r' % (actual, unexpected_end), msg)
912
913  def assertSequenceStartsWith(self, prefix, whole, msg=None):
914    """An equality assertion for the beginning of ordered sequences.
915
916    If prefix is an empty sequence, it will raise an error unless whole is also
917    an empty sequence.
918
919    If prefix is not a sequence, it will raise an error if the first element of
920    whole does not match.
921
922    Args:
923      prefix: A sequence expected at the beginning of the whole parameter.
924      whole: The sequence in which to look for prefix.
925      msg: Optional message to report on failure.
926    """
927    try:
928      prefix_len = len(prefix)
929    except (TypeError, NotImplementedError):
930      prefix = [prefix]
931      prefix_len = 1
932
933    try:
934      whole_len = len(whole)
935    except (TypeError, NotImplementedError):
936      self.fail('For whole: len(%s) is not supported, it appears to be type: '
937                '%s' % (whole, type(whole)), msg)
938
939    assert prefix_len <= whole_len, self._formatMessage(
940        msg,
941        'Prefix length (%d) is longer than whole length (%d).' %
942        (prefix_len, whole_len)
943    )
944
945    if not prefix_len and whole_len:
946      self.fail('Prefix length is 0 but whole length is %d: %s' %
947                (len(whole), whole), msg)
948
949    try:
950      self.assertSequenceEqual(prefix, whole[:prefix_len], msg)
951    except AssertionError:
952      self.fail('prefix: %s not found at start of whole: %s.' %
953                (prefix, whole), msg)
954
955  def assertEmpty(self, container, msg=None):
956    """Asserts that an object has zero length.
957
958    Args:
959      container: Anything that implements the collections.abc.Sized interface.
960      msg: Optional message to report on failure.
961    """
962    if not isinstance(container, abc.Sized):
963      self.fail('Expected a Sized object, got: '
964                '{!r}'.format(type(container).__name__), msg)
965
966    # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
967    # have strange __nonzero__/__bool__ behavior.
968    if len(container):  # pylint: disable=g-explicit-length-test
969      self.fail('{!r} has length of {}.'.format(container, len(container)), msg)
970
971  def assertNotEmpty(self, container, msg=None):
972    """Asserts that an object has non-zero length.
973
974    Args:
975      container: Anything that implements the collections.abc.Sized interface.
976      msg: Optional message to report on failure.
977    """
978    if not isinstance(container, abc.Sized):
979      self.fail('Expected a Sized object, got: '
980                '{!r}'.format(type(container).__name__), msg)
981
982    # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
983    # have strange __nonzero__/__bool__ behavior.
984    if not len(container):  # pylint: disable=g-explicit-length-test
985      self.fail('{!r} has length of 0.'.format(container), msg)
986
987  def assertLen(self, container, expected_len, msg=None):
988    """Asserts that an object has the expected length.
989
990    Args:
991      container: Anything that implements the collections.abc.Sized interface.
992      expected_len: The expected length of the container.
993      msg: Optional message to report on failure.
994    """
995    if not isinstance(container, abc.Sized):
996      self.fail('Expected a Sized object, got: '
997                '{!r}'.format(type(container).__name__), msg)
998    if len(container) != expected_len:
999      container_repr = unittest.util.safe_repr(container)  # pytype: disable=module-attr
1000      self.fail('{} has length of {}, expected {}.'.format(
1001          container_repr, len(container), expected_len), msg)
1002
1003  def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None,
1004                                msg=None, delta=None):
1005    """An approximate equality assertion for ordered sequences.
1006
1007    Fail if the two sequences are unequal as determined by their value
1008    differences rounded to the given number of decimal places (default 7) and
1009    comparing to zero, or by comparing that the difference between each value
1010    in the two sequences is more than the given delta.
1011
1012    Note that decimal places (from zero) are usually not the same as significant
1013    digits (measured from the most significant digit).
1014
1015    If the two sequences compare equal then they will automatically compare
1016    almost equal.
1017
1018    Args:
1019      expected_seq: A sequence containing elements we are expecting.
1020      actual_seq: The sequence that we are testing.
1021      places: The number of decimal places to compare.
1022      msg: The message to be printed if the test fails.
1023      delta: The OK difference between compared values.
1024    """
1025    if len(expected_seq) != len(actual_seq):
1026      self.fail('Sequence size mismatch: {} vs {}'.format(
1027          len(expected_seq), len(actual_seq)), msg)
1028
1029    err_list = []
1030    for idx, (exp_elem, act_elem) in enumerate(zip(expected_seq, actual_seq)):
1031      try:
1032        # assertAlmostEqual should be called with at most one of `places` and
1033        # `delta`. However, it's okay for assertSequenceAlmostEqual to pass
1034        # both because we want the latter to fail if the former does.
1035        # pytype: disable=wrong-keyword-args
1036        self.assertAlmostEqual(exp_elem, act_elem, places=places, msg=msg,
1037                               delta=delta)
1038        # pytype: enable=wrong-keyword-args
1039      except self.failureException as err:
1040        err_list.append('At index {}: {}'.format(idx, err))
1041
1042    if err_list:
1043      if len(err_list) > 30:
1044        err_list = err_list[:30] + ['...']
1045      msg = self._formatMessage(msg, '\n'.join(err_list))
1046      self.fail(msg)
1047
1048  def assertContainsSubset(self, expected_subset, actual_set, msg=None):
1049    """Checks whether actual iterable is a superset of expected iterable."""
1050    missing = set(expected_subset) - set(actual_set)
1051    if not missing:
1052      return
1053
1054    self.fail('Missing elements %s\nExpected: %s\nActual: %s' % (
1055        missing, expected_subset, actual_set), msg)
1056
1057  def assertNoCommonElements(self, expected_seq, actual_seq, msg=None):
1058    """Checks whether actual iterable and expected iterable are disjoint."""
1059    common = set(expected_seq) & set(actual_seq)
1060    if not common:
1061      return
1062
1063    self.fail('Common elements %s\nExpected: %s\nActual: %s' % (
1064        common, expected_seq, actual_seq), msg)
1065
1066  def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
1067    """Deprecated, please use assertCountEqual instead.
1068
1069    This is equivalent to assertCountEqual.
1070
1071    Args:
1072      expected_seq: A sequence containing elements we are expecting.
1073      actual_seq: The sequence that we are testing.
1074      msg: The message to be printed if the test fails.
1075    """
1076    super().assertCountEqual(expected_seq, actual_seq, msg)
1077
1078  def assertSameElements(self, expected_seq, actual_seq, msg=None):
1079    """Asserts that two sequences have the same elements (in any order).
1080
1081    This method, unlike assertCountEqual, doesn't care about any
1082    duplicates in the expected and actual sequences::
1083
1084        # Doesn't raise an AssertionError
1085        assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
1086
1087    If possible, you should use assertCountEqual instead of
1088    assertSameElements.
1089
1090    Args:
1091      expected_seq: A sequence containing elements we are expecting.
1092      actual_seq: The sequence that we are testing.
1093      msg: The message to be printed if the test fails.
1094    """
1095    # `unittest2.TestCase` used to have assertSameElements, but it was
1096    # removed in favor of assertItemsEqual. As there's a unit test
1097    # that explicitly checks this behavior, I am leaving this method
1098    # alone.
1099    # Fail on strings: empirically, passing strings to this test method
1100    # is almost always a bug. If comparing the character sets of two strings
1101    # is desired, cast the inputs to sets or lists explicitly.
1102    if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or
1103        isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)):
1104      self.fail('Passing string/bytes to assertSameElements is usually a bug. '
1105                'Did you mean to use assertEqual?\n'
1106                'Expected: %s\nActual: %s' % (expected_seq, actual_seq))
1107    try:
1108      expected = dict([(element, None) for element in expected_seq])
1109      actual = dict([(element, None) for element in actual_seq])
1110      missing = [element for element in expected if element not in actual]
1111      unexpected = [element for element in actual if element not in expected]
1112      missing.sort()
1113      unexpected.sort()
1114    except TypeError:
1115      # Fall back to slower list-compare if any of the objects are
1116      # not hashable.
1117      expected = list(expected_seq)
1118      actual = list(actual_seq)
1119      expected.sort()
1120      actual.sort()
1121      missing, unexpected = _sorted_list_difference(expected, actual)
1122    errors = []
1123    if msg:
1124      errors.extend((msg, ':\n'))
1125    if missing:
1126      errors.append('Expected, but missing:\n  %r\n' % missing)
1127    if unexpected:
1128      errors.append('Unexpected, but present:\n  %r\n' % unexpected)
1129    if missing or unexpected:
1130      self.fail(''.join(errors))
1131
1132  # unittest.TestCase.assertMultiLineEqual works very similarly, but it
1133  # has a different error format. However, I find this slightly more readable.
1134  def assertMultiLineEqual(self, first, second, msg=None, **kwargs):
1135    """Asserts that two multi-line strings are equal."""
1136    assert isinstance(first,
1137                      str), ('First argument is not a string: %r' % (first,))
1138    assert isinstance(second,
1139                      str), ('Second argument is not a string: %r' % (second,))
1140    line_limit = kwargs.pop('line_limit', 0)
1141    if kwargs:
1142      raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs)))
1143
1144    if first == second:
1145      return
1146    if msg:
1147      failure_message = [msg + ':\n']
1148    else:
1149      failure_message = ['\n']
1150    if line_limit:
1151      line_limit += len(failure_message)
1152    for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)):
1153      failure_message.append(line)
1154      if not line.endswith('\n'):
1155        failure_message.append('\n')
1156    if line_limit and len(failure_message) > line_limit:
1157      n_omitted = len(failure_message) - line_limit
1158      failure_message = failure_message[:line_limit]
1159      failure_message.append(
1160          '(... and {} more delta lines omitted for brevity.)\n'.format(
1161              n_omitted))
1162
1163    raise self.failureException(''.join(failure_message))
1164
1165  def assertBetween(self, value, minv, maxv, msg=None):
1166    """Asserts that value is between minv and maxv (inclusive)."""
1167    msg = self._formatMessage(msg,
1168                              '"%r" unexpectedly not between "%r" and "%r"' %
1169                              (value, minv, maxv))
1170    self.assertTrue(minv <= value, msg)
1171    self.assertTrue(maxv >= value, msg)
1172
1173  def assertRegexMatch(self, actual_str, regexes, message=None):
1174    r"""Asserts that at least one regex in regexes matches str.
1175
1176    If possible you should use `assertRegex`, which is a simpler
1177    version of this method. `assertRegex` takes a single regular
1178    expression (a string or re compiled object) instead of a list.
1179
1180    Notes:
1181
1182    1. This function uses substring matching, i.e. the matching
1183       succeeds if *any* substring of the error message matches *any*
1184       regex in the list.  This is more convenient for the user than
1185       full-string matching.
1186
1187    2. If regexes is the empty list, the matching will always fail.
1188
1189    3. Use regexes=[''] for a regex that will always pass.
1190
1191    4. '.' matches any single character *except* the newline.  To
1192       match any character, use '(.|\n)'.
1193
1194    5. '^' matches the beginning of each line, not just the beginning
1195       of the string.  Similarly, '$' matches the end of each line.
1196
1197    6. An exception will be thrown if regexes contains an invalid
1198       regex.
1199
1200    Args:
1201      actual_str:  The string we try to match with the items in regexes.
1202      regexes:  The regular expressions we want to match against str.
1203          See "Notes" above for detailed notes on how this is interpreted.
1204      message:  The message to be printed if the test fails.
1205    """
1206    if isinstance(regexes, _TEXT_OR_BINARY_TYPES):
1207      self.fail('regexes is string or bytes; use assertRegex instead.',
1208                message)
1209    if not regexes:
1210      self.fail('No regexes specified.', message)
1211
1212    regex_type = type(regexes[0])
1213    for regex in regexes[1:]:
1214      if type(regex) is not regex_type:  # pylint: disable=unidiomatic-typecheck
1215        self.fail('regexes list must all be the same type.', message)
1216
1217    if regex_type is bytes and isinstance(actual_str, str):
1218      regexes = [regex.decode('utf-8') for regex in regexes]
1219      regex_type = str
1220    elif regex_type is str and isinstance(actual_str, bytes):
1221      regexes = [regex.encode('utf-8') for regex in regexes]
1222      regex_type = bytes
1223
1224    if regex_type is str:
1225      regex = u'(?:%s)' % u')|(?:'.join(regexes)
1226    elif regex_type is bytes:
1227      regex = b'(?:' + (b')|(?:'.join(regexes)) + b')'
1228    else:
1229      self.fail('Only know how to deal with unicode str or bytes regexes.',
1230                message)
1231
1232    if not re.search(regex, actual_str, re.MULTILINE):
1233      self.fail('"%s" does not contain any of these regexes: %s.' %
1234                (actual_str, regexes), message)
1235
1236  def assertCommandSucceeds(self, command, regexes=(b'',), env=None,
1237                            close_fds=True, msg=None):
1238    """Asserts that a shell command succeeds (i.e. exits with code 0).
1239
1240    Args:
1241      command: List or string representing the command to run.
1242      regexes: List of regular expression byte strings that match success.
1243      env: Dictionary of environment variable settings. If None, no environment
1244          variables will be set for the child process. This is to make tests
1245          more hermetic. NOTE: this behavior is different than the standard
1246          subprocess module.
1247      close_fds: Whether or not to close all open fd's in the child after
1248          forking.
1249      msg: Optional message to report on failure.
1250    """
1251    (ret_code, err) = get_command_stderr(command, env, close_fds)
1252
1253    # We need bytes regexes here because `err` is bytes.
1254    # Accommodate code which listed their output regexes w/o the b'' prefix by
1255    # converting them to bytes for the user.
1256    if isinstance(regexes[0], str):
1257      regexes = [regex.encode('utf-8') for regex in regexes]
1258
1259    command_string = get_command_string(command)
1260    self.assertEqual(
1261        ret_code, 0,
1262        self._formatMessage(msg,
1263                            'Running command\n'
1264                            '%s failed with error code %s and message\n'
1265                            '%s' % (_quote_long_string(command_string),
1266                                    ret_code,
1267                                    _quote_long_string(err)))
1268    )
1269    self.assertRegexMatch(
1270        err,
1271        regexes,
1272        message=self._formatMessage(
1273            msg,
1274            'Running command\n'
1275            '%s failed with error code %s and message\n'
1276            '%s which matches no regex in %s' % (
1277                _quote_long_string(command_string),
1278                ret_code,
1279                _quote_long_string(err),
1280                regexes)))
1281
1282  def assertCommandFails(self, command, regexes, env=None, close_fds=True,
1283                         msg=None):
1284    """Asserts a shell command fails and the error matches a regex in a list.
1285
1286    Args:
1287      command: List or string representing the command to run.
1288      regexes: the list of regular expression strings.
1289      env: Dictionary of environment variable settings. If None, no environment
1290          variables will be set for the child process. This is to make tests
1291          more hermetic. NOTE: this behavior is different than the standard
1292          subprocess module.
1293      close_fds: Whether or not to close all open fd's in the child after
1294          forking.
1295      msg: Optional message to report on failure.
1296    """
1297    (ret_code, err) = get_command_stderr(command, env, close_fds)
1298
1299    # We need bytes regexes here because `err` is bytes.
1300    # Accommodate code which listed their output regexes w/o the b'' prefix by
1301    # converting them to bytes for the user.
1302    if isinstance(regexes[0], str):
1303      regexes = [regex.encode('utf-8') for regex in regexes]
1304
1305    command_string = get_command_string(command)
1306    self.assertNotEqual(
1307        ret_code, 0,
1308        self._formatMessage(msg, 'The following command succeeded '
1309                            'while expected to fail:\n%s' %
1310                            _quote_long_string(command_string)))
1311    self.assertRegexMatch(
1312        err,
1313        regexes,
1314        message=self._formatMessage(
1315            msg,
1316            'Running command\n'
1317            '%s failed with error code %s and message\n'
1318            '%s which matches no regex in %s' % (
1319                _quote_long_string(command_string),
1320                ret_code,
1321                _quote_long_string(err),
1322                regexes)))
1323
1324  class _AssertRaisesContext(object):
1325
1326    def __init__(self, expected_exception, test_case, test_func, msg=None):
1327      self.expected_exception = expected_exception
1328      self.test_case = test_case
1329      self.test_func = test_func
1330      self.msg = msg
1331
1332    def __enter__(self):
1333      return self
1334
1335    def __exit__(self, exc_type, exc_value, tb):
1336      if exc_type is None:
1337        self.test_case.fail(self.expected_exception.__name__ + ' not raised',
1338                            self.msg)
1339      if not issubclass(exc_type, self.expected_exception):
1340        return False
1341      self.test_func(exc_value)
1342      if exc_value:
1343        self.exception = exc_value.with_traceback(None)
1344      return True
1345
  @typing.overload
  def assertRaisesWithPredicateMatch(
      self, expected_exception, predicate) -> _AssertRaisesContext:
    # Typing-only signature for the context-manager form (no callable_obj).
    # The implementation defined later in this class under the same name
    # replaces this stub, so this body is not reached through the class.
    # The purpose of this return statement is to work around
    # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
    return self._AssertRaisesContext(None, None, None)
1352
  @typing.overload
  def assertRaisesWithPredicateMatch(
      self, expected_exception, predicate, callable_obj: Callable[..., Any],
      *args, **kwargs) -> None:
    # Typing-only signature for the direct-call form, which returns None.
    # The implementation defined later in this class under the same name
    # replaces this stub, so this body is not reached through the class.
    # The purpose of this return statement is to work around
    # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
    return self._AssertRaisesContext(None, None, None)
1360
1361  def assertRaisesWithPredicateMatch(self, expected_exception, predicate,
1362                                     callable_obj=None, *args, **kwargs):
1363    """Asserts that exception is thrown and predicate(exception) is true.
1364
1365    Args:
1366      expected_exception: Exception class expected to be raised.
1367      predicate: Function of one argument that inspects the passed-in exception
1368          and returns True (success) or False (please fail the test).
1369      callable_obj: Function to be called.
1370      *args: Extra args.
1371      **kwargs: Extra keyword args.
1372
1373    Returns:
1374      A context manager if callable_obj is None. Otherwise, None.
1375
1376    Raises:
1377      self.failureException if callable_obj does not raise a matching exception.
1378    """
1379    def Check(err):
1380      self.assertTrue(predicate(err),
1381                      '%r does not match predicate %r' % (err, predicate))
1382
1383    context = self._AssertRaisesContext(expected_exception, self, Check)
1384    if callable_obj is None:
1385      return context
1386    with context:
1387      callable_obj(*args, **kwargs)
1388
  @typing.overload
  def assertRaisesWithLiteralMatch(
      self, expected_exception, expected_exception_message
  ) -> _AssertRaisesContext:
    # Typing-only signature for the context-manager form (no callable_obj).
    # The implementation defined later in this class under the same name
    # replaces this stub, so this body is not reached through the class.
    # The purpose of this return statement is to work around
    # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
    return self._AssertRaisesContext(None, None, None)
1396
  @typing.overload
  def assertRaisesWithLiteralMatch(
      self, expected_exception, expected_exception_message,
      callable_obj: Callable[..., Any], *args, **kwargs) -> None:
    # Typing-only signature for the direct-call form, which returns None.
    # The implementation defined later in this class under the same name
    # replaces this stub, so this body is not reached through the class.
    # The purpose of this return statement is to work around
    # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
    return self._AssertRaisesContext(None, None, None)
1404
1405  def assertRaisesWithLiteralMatch(self, expected_exception,
1406                                   expected_exception_message,
1407                                   callable_obj=None, *args, **kwargs):
1408    """Asserts that the message in a raised exception equals the given string.
1409
1410    Unlike assertRaisesRegex, this method takes a literal string, not
1411    a regular expression.
1412
1413    with self.assertRaisesWithLiteralMatch(ExType, 'message'):
1414      DoSomething()
1415
1416    Args:
1417      expected_exception: Exception class expected to be raised.
1418      expected_exception_message: String message expected in the raised
1419          exception.  For a raise exception e, expected_exception_message must
1420          equal str(e).
1421      callable_obj: Function to be called, or None to return a context.
1422      *args: Extra args.
1423      **kwargs: Extra kwargs.
1424
1425    Returns:
1426      A context manager if callable_obj is None. Otherwise, None.
1427
1428    Raises:
1429      self.failureException if callable_obj does not raise a matching exception.
1430    """
1431    def Check(err):
1432      actual_exception_message = str(err)
1433      self.assertTrue(expected_exception_message == actual_exception_message,
1434                      'Exception message does not match.\n'
1435                      'Expected: %r\n'
1436                      'Actual: %r' % (expected_exception_message,
1437                                      actual_exception_message))
1438
1439    context = self._AssertRaisesContext(expected_exception, self, Check)
1440    if callable_obj is None:
1441      return context
1442    with context:
1443      callable_obj(*args, **kwargs)
1444
1445  def assertContainsInOrder(self, strings, target, msg=None):
1446    """Asserts that the strings provided are found in the target in order.
1447
1448    This may be useful for checking HTML output.
1449
1450    Args:
1451      strings: A list of strings, such as [ 'fox', 'dog' ]
1452      target: A target string in which to look for the strings, such as
1453          'The quick brown fox jumped over the lazy dog'.
1454      msg: Optional message to report on failure.
1455    """
1456    if isinstance(strings, (bytes, unicode if str is bytes else str)):
1457      strings = (strings,)
1458
1459    current_index = 0
1460    last_string = None
1461    for string in strings:
1462      index = target.find(str(string), current_index)
1463      if index == -1 and current_index == 0:
1464        self.fail("Did not find '%s' in '%s'" %
1465                  (string, target), msg)
1466      elif index == -1:
1467        self.fail("Did not find '%s' after '%s' in '%s'" %
1468                  (string, last_string, target), msg)
1469      last_string = string
1470      current_index = index
1471
1472  def assertContainsSubsequence(self, container, subsequence, msg=None):
1473    """Asserts that "container" contains "subsequence" as a subsequence.
1474
1475    Asserts that "container" contains all the elements of "subsequence", in
1476    order, but possibly with other elements interspersed. For example, [1, 2, 3]
1477    is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
1478
1479    Args:
1480      container: the list we're testing for subsequence inclusion.
1481      subsequence: the list we hope will be a subsequence of container.
1482      msg: Optional message to report on failure.
1483    """
1484    first_nonmatching = None
1485    reversed_container = list(reversed(container))
1486    subsequence = list(subsequence)
1487
1488    for e in subsequence:
1489      if e not in reversed_container:
1490        first_nonmatching = e
1491        break
1492      while e != reversed_container.pop():
1493        pass
1494
1495    if first_nonmatching is not None:
1496      self.fail('%s not a subsequence of %s. First non-matching element: %s' %
1497                (subsequence, container, first_nonmatching), msg)
1498
1499  def assertContainsExactSubsequence(self, container, subsequence, msg=None):
1500    """Asserts that "container" contains "subsequence" as an exact subsequence.
1501
1502    Asserts that "container" contains all the elements of "subsequence", in
1503    order, and without other elements interspersed. For example, [1, 2, 3] is an
1504    exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
1505
1506    Args:
1507      container: the list we're testing for subsequence inclusion.
1508      subsequence: the list we hope will be an exact subsequence of container.
1509      msg: Optional message to report on failure.
1510    """
1511    container = list(container)
1512    subsequence = list(subsequence)
1513    longest_match = 0
1514
1515    for start in range(1 + len(container) - len(subsequence)):
1516      if longest_match == len(subsequence):
1517        break
1518      index = 0
1519      while (index < len(subsequence) and
1520             subsequence[index] == container[start + index]):
1521        index += 1
1522      longest_match = max(longest_match, index)
1523
1524    if longest_match < len(subsequence):
1525      self.fail('%s not an exact subsequence of %s. '
1526                'Longest matching prefix: %s' %
1527                (subsequence, container, subsequence[:longest_match]), msg)
1528
1529  def assertTotallyOrdered(self, *groups, **kwargs):
1530    """Asserts that total ordering has been implemented correctly.
1531
1532    For example, say you have a class A that compares only on its attribute x.
1533    Comparators other than ``__lt__`` are omitted for brevity::
1534
1535        class A(object):
1536          def __init__(self, x, y):
1537            self.x = x
1538            self.y = y
1539
1540          def __hash__(self):
1541            return hash(self.x)
1542
1543          def __lt__(self, other):
1544            try:
1545              return self.x < other.x
1546            except AttributeError:
1547              return NotImplemented
1548
1549    assertTotallyOrdered will check that instances can be ordered correctly.
1550    For example::
1551
1552        self.assertTotallyOrdered(
1553            [None],  # None should come before everything else.
1554            [1],  # Integers sort earlier.
1555            [A(1, 'a')],
1556            [A(2, 'b')],  # 2 is after 1.
1557            [A(3, 'c'), A(3, 'd')],  # The second argument is irrelevant.
1558            [A(4, 'z')],
1559            ['foo'])  # Strings sort last.
1560
1561    Args:
1562      *groups: A list of groups of elements.  Each group of elements is a list
1563        of objects that are equal.  The elements in each group must be less
1564        than the elements in the group after it.  For example, these groups are
1565        totally ordered: ``[None]``, ``[1]``, ``[2, 2]``, ``[3]``.
1566      **kwargs: optional msg keyword argument can be passed.
1567    """
1568
1569    def CheckOrder(small, big):
1570      """Ensures small is ordered before big."""
1571      self.assertFalse(small == big,
1572                       self._formatMessage(msg, '%r unexpectedly equals %r' %
1573                                           (small, big)))
1574      self.assertTrue(small != big,
1575                      self._formatMessage(msg, '%r unexpectedly equals %r' %
1576                                          (small, big)))
1577      self.assertLess(small, big, msg)
1578      self.assertFalse(big < small,
1579                       self._formatMessage(msg,
1580                                           '%r unexpectedly less than %r' %
1581                                           (big, small)))
1582      self.assertLessEqual(small, big, msg)
1583      self.assertFalse(big <= small, self._formatMessage(
1584          '%r unexpectedly less than or equal to %r' % (big, small), msg
1585      ))
1586      self.assertGreater(big, small, msg)
1587      self.assertFalse(small > big,
1588                       self._formatMessage(msg,
1589                                           '%r unexpectedly greater than %r' %
1590                                           (small, big)))
1591      self.assertGreaterEqual(big, small)
1592      self.assertFalse(small >= big, self._formatMessage(
1593          msg,
1594          '%r unexpectedly greater than or equal to %r' % (small, big)))
1595
1596    def CheckEqual(a, b):
1597      """Ensures that a and b are equal."""
1598      self.assertEqual(a, b, msg)
1599      self.assertFalse(a != b,
1600                       self._formatMessage(msg, '%r unexpectedly unequals %r' %
1601                                           (a, b)))
1602
1603      # Objects that compare equal must hash to the same value, but this only
1604      # applies if both objects are hashable.
1605      if (isinstance(a, abc.Hashable) and
1606          isinstance(b, abc.Hashable)):
1607        self.assertEqual(
1608            hash(a), hash(b),
1609            self._formatMessage(
1610                msg, 'hash %d of %r unexpectedly not equal to hash %d of %r' %
1611                (hash(a), a, hash(b), b)))
1612
1613      self.assertFalse(a < b,
1614                       self._formatMessage(msg,
1615                                           '%r unexpectedly less than %r' %
1616                                           (a, b)))
1617      self.assertFalse(b < a,
1618                       self._formatMessage(msg,
1619                                           '%r unexpectedly less than %r' %
1620                                           (b, a)))
1621      self.assertLessEqual(a, b, msg)
1622      self.assertLessEqual(b, a, msg)  # pylint: disable=arguments-out-of-order
1623      self.assertFalse(a > b,
1624                       self._formatMessage(msg,
1625                                           '%r unexpectedly greater than %r' %
1626                                           (a, b)))
1627      self.assertFalse(b > a,
1628                       self._formatMessage(msg,
1629                                           '%r unexpectedly greater than %r' %
1630                                           (b, a)))
1631      self.assertGreaterEqual(a, b, msg)
1632      self.assertGreaterEqual(b, a, msg)  # pylint: disable=arguments-out-of-order
1633
1634    msg = kwargs.get('msg')
1635
1636    # For every combination of elements, check the order of every pair of
1637    # elements.
1638    for elements in itertools.product(*groups):
1639      elements = list(elements)
1640      for index, small in enumerate(elements[:-1]):
1641        for big in elements[index + 1:]:
1642          CheckOrder(small, big)
1643
1644    # Check that every element in each group is equal.
1645    for group in groups:
1646      for a in group:
1647        CheckEqual(a, a)
1648      for a, b in itertools.product(group, group):
1649        CheckEqual(a, b)
1650
1651  def assertDictEqual(self, a, b, msg=None):
1652    """Raises AssertionError if a and b are not equal dictionaries.
1653
1654    Args:
1655      a: A dict, the expected value.
1656      b: A dict, the actual value.
1657      msg: An optional str, the associated message.
1658
1659    Raises:
1660      AssertionError: if the dictionaries are not equal.
1661    """
1662    self.assertIsInstance(a, dict, self._formatMessage(
1663        msg,
1664        'First argument is not a dictionary'
1665    ))
1666    self.assertIsInstance(b, dict, self._formatMessage(
1667        msg,
1668        'Second argument is not a dictionary'
1669    ))
1670
1671    def Sorted(list_of_items):
1672      try:
1673        return sorted(list_of_items)  # In 3.3, unordered are possible.
1674      except TypeError:
1675        return list_of_items
1676
1677    if a == b:
1678      return
1679    a_items = Sorted(list(a.items()))
1680    b_items = Sorted(list(b.items()))
1681
1682    unexpected = []
1683    missing = []
1684    different = []
1685
1686    safe_repr = unittest.util.safe_repr  # pytype: disable=module-attr
1687
1688    def Repr(dikt):
1689      """Deterministic repr for dict."""
1690      # Sort the entries based on their repr, not based on their sort order,
1691      # which will be non-deterministic across executions, for many types.
1692      entries = sorted((safe_repr(k), safe_repr(v)) for k, v in dikt.items())
1693      return '{%s}' % (', '.join('%s: %s' % pair for pair in entries))
1694
1695    message = ['%s != %s%s' % (Repr(a), Repr(b), ' (%s)' % msg if msg else '')]
1696
1697    # The standard library default output confounds lexical difference with
1698    # value difference; treat them separately.
1699    for a_key, a_value in a_items:
1700      if a_key not in b:
1701        missing.append((a_key, a_value))
1702      elif a_value != b[a_key]:
1703        different.append((a_key, a_value, b[a_key]))
1704
1705    for b_key, b_value in b_items:
1706      if b_key not in a:
1707        unexpected.append((b_key, b_value))
1708
1709    if unexpected:
1710      message.append(
1711          'Unexpected, but present entries:\n%s' % ''.join(
1712              '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in unexpected))
1713
1714    if different:
1715      message.append(
1716          'repr() of differing entries:\n%s' % ''.join(
1717              '%s: %s != %s\n' % (safe_repr(k), safe_repr(a_value),
1718                                  safe_repr(b_value))
1719              for k, a_value, b_value in different))
1720
1721    if missing:
1722      message.append(
1723          'Missing entries:\n%s' % ''.join(
1724              ('%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in missing)))
1725
1726    raise self.failureException('\n'.join(message))
1727
1728  def assertUrlEqual(self, a, b, msg=None):
1729    """Asserts that urls are equal, ignoring ordering of query params."""
1730    parsed_a = parse.urlparse(a)
1731    parsed_b = parse.urlparse(b)
1732    self.assertEqual(parsed_a.scheme, parsed_b.scheme, msg)
1733    self.assertEqual(parsed_a.netloc, parsed_b.netloc, msg)
1734    self.assertEqual(parsed_a.path, parsed_b.path, msg)
1735    self.assertEqual(parsed_a.fragment, parsed_b.fragment, msg)
1736    self.assertEqual(sorted(parsed_a.params.split(';')),
1737                     sorted(parsed_b.params.split(';')), msg)
1738    self.assertDictEqual(
1739        parse.parse_qs(parsed_a.query, keep_blank_values=True),
1740        parse.parse_qs(parsed_b.query, keep_blank_values=True), msg)
1741
1742  def assertSameStructure(self, a, b, aname='a', bname='b', msg=None):
1743    """Asserts that two values contain the same structural content.
1744
1745    The two arguments should be data trees consisting of trees of dicts and
1746    lists. They will be deeply compared by walking into the contents of dicts
1747    and lists; other items will be compared using the == operator.
1748    If the two structures differ in content, the failure message will indicate
1749    the location within the structures where the first difference is found.
1750    This may be helpful when comparing large structures.
1751
1752    Mixed Sequence and Set types are supported. Mixed Mapping types are
1753    supported, but the order of the keys will not be considered in the
1754    comparison.
1755
1756    Args:
1757      a: The first structure to compare.
1758      b: The second structure to compare.
1759      aname: Variable name to use for the first structure in assertion messages.
1760      bname: Variable name to use for the second structure.
1761      msg: Additional text to include in the failure message.
1762    """
1763
1764    # Accumulate all the problems found so we can report all of them at once
1765    # rather than just stopping at the first
1766    problems = []
1767
1768    _walk_structure_for_problems(a, b, aname, bname, problems)
1769
1770    # Avoid spamming the user toooo much
1771    if self.maxDiff is not None:
1772      max_problems_to_show = self.maxDiff // 80
1773      if len(problems) > max_problems_to_show:
1774        problems = problems[0:max_problems_to_show-1] + ['...']
1775
1776    if problems:
1777      self.fail('; '.join(problems), msg)
1778
1779  def assertJsonEqual(self, first, second, msg=None):
1780    """Asserts that the JSON objects defined in two strings are equal.
1781
1782    A summary of the differences will be included in the failure message
1783    using assertSameStructure.
1784
1785    Args:
1786      first: A string containing JSON to decode and compare to second.
1787      second: A string containing JSON to decode and compare to first.
1788      msg: Additional text to include in the failure message.
1789    """
1790    try:
1791      first_structured = json.loads(first)
1792    except ValueError as e:
1793      raise ValueError(self._formatMessage(
1794          msg,
1795          'could not decode first JSON value %s: %s' % (first, e)))
1796
1797    try:
1798      second_structured = json.loads(second)
1799    except ValueError as e:
1800      raise ValueError(self._formatMessage(
1801          msg,
1802          'could not decode second JSON value %s: %s' % (second, e)))
1803
1804    self.assertSameStructure(first_structured, second_structured,
1805                             aname='first', bname='second', msg=msg)
1806
1807  def _getAssertEqualityFunc(self, first, second):
1808    # type: (Any, Any) -> Callable[..., None]
1809    try:
1810      return super(TestCase, self)._getAssertEqualityFunc(first, second)
1811    except AttributeError:
1812      # This is a workaround if unittest.TestCase.__init__ was never run.
1813      # It usually means that somebody created a subclass just for the
1814      # assertions and has overridden __init__. "assertTrue" is a safe
1815      # value that will not make __init__ raise a ValueError.
1816      test_method = getattr(self, '_testMethodName', 'assertTrue')
1817      super(TestCase, self).__init__(test_method)
1818
1819    return super(TestCase, self)._getAssertEqualityFunc(first, second)
1820
1821  def fail(self, msg=None, prefix=None):
1822    """Fail immediately with the given message, optionally prefixed."""
1823    return super(TestCase, self).fail(self._formatMessage(prefix, msg))
1824
1825
1826def _sorted_list_difference(expected, actual):
1827  # type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]]
1828  """Finds elements in only one or the other of two, sorted input lists.
1829
1830  Returns a two-element tuple of lists.  The first list contains those
1831  elements in the "expected" list but not in the "actual" list, and the
1832  second contains those elements in the "actual" list but not in the
1833  "expected" list.  Duplicate elements in either input list are ignored.
1834
1835  Args:
1836    expected:  The list we expected.
1837    actual:  The list we actually got.
1838  Returns:
1839    (missing, unexpected)
1840    missing: items in expected that are not in actual.
1841    unexpected: items in actual that are not in expected.
1842  """
1843  i = j = 0
1844  missing = []
1845  unexpected = []
1846  while True:
1847    try:
1848      e = expected[i]
1849      a = actual[j]
1850      if e < a:
1851        missing.append(e)
1852        i += 1
1853        while expected[i] == e:
1854          i += 1
1855      elif e > a:
1856        unexpected.append(a)
1857        j += 1
1858        while actual[j] == a:
1859          j += 1
1860      else:
1861        i += 1
1862        try:
1863          while expected[i] == e:
1864            i += 1
1865        finally:
1866          j += 1
1867          while actual[j] == a:
1868            j += 1
1869    except IndexError:
1870      missing.extend(expected[i:])
1871      unexpected.extend(actual[j:])
1872      break
1873  return missing, unexpected
1874
1875
1876def _are_both_of_integer_type(a, b):
1877  # type: (object, object) -> bool
1878  return isinstance(a, int) and isinstance(b, int)
1879
1880
1881def _are_both_of_sequence_type(a, b):
1882  # type: (object, object) -> bool
1883  return isinstance(a, abc.Sequence) and isinstance(
1884      b, abc.Sequence) and not isinstance(
1885          a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES)
1886
1887
1888def _are_both_of_set_type(a, b):
1889  # type: (object, object) -> bool
1890  return isinstance(a, abc.Set) and isinstance(b, abc.Set)
1891
1892
1893def _are_both_of_mapping_type(a, b):
1894  # type: (object, object) -> bool
1895  return isinstance(a, abc.Mapping) and isinstance(
1896      b, abc.Mapping)
1897
1898
def _walk_structure_for_problems(a, b, aname, bname, problem_list):
  """The recursive comparison behind assertSameStructure.

  Appends a human-readable description of every difference between `a` and
  `b` to `problem_list`, which is mutated in place (nothing is returned).
  Sets, mappings and non-text sequences are walked into. Everything else is
  compared with ``!=``.

  Args:
    a: The first value to compare.
    b: The second value to compare.
    aname: Expression naming `a`; nested paths are built on top of it.
    bname: Expression naming `b`; nested paths are built on top of it.
    problem_list: The list that problem descriptions are appended to.
  """
  same_kind = (
      type(a) == type(b) or  # pylint: disable=unidiomatic-typecheck
      _are_both_of_integer_type(a, b) or
      _are_both_of_sequence_type(a, b) or
      _are_both_of_set_type(a, b) or
      _are_both_of_mapping_type(a, b))
  if not same_kind:
    # Values of different types are only compared further when both are
    # ints, non-text sequences, sets or mappings. For any other mismatch
    # there is no point in continuing.
    problem_list.append('%s is a %r but %s is a %r' %
                        (aname, type(a), bname, type(b)))
    return

  if isinstance(a, abc.Set):
    problem_list.extend('%s has %r but %s does not' % (aname, k, bname)
                        for k in a if k not in b)
    problem_list.extend('%s lacks %r but %s has it' % (aname, k, bname)
                        for k in b if k not in a)
    return

  if isinstance(a, abc.Mapping):
    # a or b could be a defaultdict: membership is always checked before
    # indexing so that the traversal never inserts keys.
    for key in a:
      if key not in b:
        problem_list.append(
            "%s has [%r] with value %r but it's missing in %s" %
            (aname, key, a[key], bname))
        continue
      _walk_structure_for_problems(
          a[key], b[key], '%s[%r]' % (aname, key), '%s[%r]' % (bname, key),
          problem_list)
    problem_list.extend(
        '%s lacks [%r] but %s has it with value %r' %
        (aname, key, bname, b[key])
        for key in b if key not in a)
    return

  # Strings/bytes are Sequences too, but they fall through to plain !=.
  if (isinstance(a, abc.Sequence) and
      not isinstance(a, _TEXT_OR_BINARY_TYPES)):
    shared = min(len(a), len(b))
    for index in range(shared):
      _walk_structure_for_problems(
          a[index], b[index], '%s[%d]' % (aname, index),
          '%s[%d]' % (bname, index), problem_list)
    problem_list.extend(
        '%s has [%i] with value %r but %s does not' %
        (aname, index, a[index], bname)
        for index in range(shared, len(a)))
    problem_list.extend(
        '%s lacks [%i] but %s has it with value %r' %
        (aname, index, bname, b[index])
        for index in range(shared, len(b)))
    return

  if a != b:
    problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b))
1956
1957
def get_command_string(command):
  """Returns an escaped string that can be used as a shell command.

  A string command is returned unchanged. A list is joined with spaces. On
  Windows ('nt') no escaping is applied. Elsewhere every word is wrapped in
  single quotes, and each embedded single quote is written as '"'"'. Unlike
  ``shlex.quote``, this quotes every word, even words made only of safe
  characters.

  Args:
    command: List or string representing the command to run.
  Returns:
    A string suitable for use as a shell command.
  """
  if isinstance(command, str):
    return command
  if os.name == 'nt':
    return ' '.join(command)
  quoted_words = []
  for word in command:
    # Close the quote, emit a double-quoted "'", then reopen the quote.
    quoted_words.append("'" + word.replace("'", "'\"'\"'") + "'")
  return ' '.join(quoted_words)
1978
1979
def get_command_stderr(command, env=None, close_fds=True):
  """Runs the given shell command and returns a tuple.

  A string command is run through the shell; a list is executed directly.
  The child's stderr is redirected into its stdout pipe, so both streams are
  captured together in a single value.

  Args:
    command: List or string representing the command to run.
    env: Dictionary of environment variable settings. If None, no environment
        variables will be set for the child process. This is to make tests
        more hermetic. NOTE: this behavior is different than the standard
        subprocess module.
    close_fds: Whether or not to close all open fd's in the child after forking.
        On Windows, this is ignored and close_fds is always False.

  Returns:
    Tuple of (exit status, bytes printed to stdout and stderr by the command).
  """
  if env is None:
    env = {}
  if os.name == 'nt':
    # Windows does not support setting close_fds to True while also redirecting
    # standard handles.
    close_fds = False

  use_shell = isinstance(command, str)
  # The context manager closes the pipe and reaps the child even when
  # communicate() raises (e.g. on KeyboardInterrupt), so nothing leaks.
  with subprocess.Popen(
      command,
      close_fds=close_fds,
      env=env,
      shell=use_shell,
      stderr=subprocess.STDOUT,
      stdout=subprocess.PIPE) as process:
    output = process.communicate()[0]
  # communicate() waits for the child, so returncode is already populated.
  return (process.returncode, output)
2012
2013
2014def _quote_long_string(s):
2015  # type: (Union[Text, bytes, bytearray]) -> Text
2016  """Quotes a potentially multi-line string to make the start and end obvious.
2017
2018  Args:
2019    s: A string.
2020
2021  Returns:
2022    The quoted string.
2023  """
2024  if isinstance(s, (bytes, bytearray)):
2025    try:
2026      s = s.decode('utf-8')
2027    except UnicodeDecodeError:
2028      s = str(s)
2029  return ('8<-----------\n' +
2030          s + '\n' +
2031          '----------->8\n')
2032
2033
def print_python_version():
  # type: () -> None
  """Writes the interpreter version and executable path to stderr."""
  # Having this in the test output logs by default helps debugging when all
  # you've got is the log and no other idea of which Python was used.
  interpreter = sys.executable if sys.executable else 'embedded.'
  banner = ('Running tests under Python {0[0]}.{0[1]}.{0[2]}: '
            '{1}\n'.format(sys.version_info, interpreter))
  sys.stderr.write(banner)
2042
2043
def main(*args, **kwargs):
  # type: (Text, Any) -> None
  """Executes a set of Python unit tests.

  This is normally called with no arguments. A default
  unittest.TestProgram is then created, which runs every test method of
  every TestCase class in the ``__main__`` module. The Python version is
  written to stderr first. The tests are then run via _run_in_app, which
  makes sure they execute inside app.run().

  Args:
    *args: Positional arguments passed through to
        ``unittest.TestProgram.__init__``.
    **kwargs: Keyword arguments passed through to
        ``unittest.TestProgram.__init__``.
  """
  print_python_version()
  _run_in_app(run_tests, args, kwargs)
2061
2062
def _is_in_app_main():
  # type: () -> bool
  """Returns True iff app.run is active.

  Walks outward from the caller's frame and reports whether any enclosing
  frame is executing the code object of app.run.
  """
  run_code = app.run.__code__
  frame = sys._getframe().f_back  # pylint: disable=protected-access
  while frame:
    if frame.f_code == run_code:
      return True
    frame = frame.f_back
  return False
2072
2073
def _register_sigterm_with_faulthandler():
  # type: () -> None
  """Have faulthandler dump stacks on SIGTERM.  Useful to diagnose timeouts.

  Does nothing when faulthandler could not be imported or has no
  ``register`` function (faulthandler.register is not available on
  Windows). A failing registration is reported on stderr and otherwise
  ignored.
  """
  register = getattr(faulthandler, 'register', None) if faulthandler else None
  if not register:
    return
  # faulthandler.enable() is already called by app.run.
  try:
    register(signal.SIGTERM, chain=True)  # pytype: disable=module-attr
  except Exception as e:  # pylint: disable=broad-except
    sys.stderr.write('faulthandler.register(SIGTERM) failed '
                     '%r; ignoring.\n' % e)
2085
2086
def _run_in_app(function, args, kwargs):
  # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None
  """Executes a set of Python unit tests, ensuring app.run.

  This is a private function, users should call absltest.main().

  _run_in_app calculates argv to be the command-line arguments of this program
  (without the flags), sets the default of FLAGS.alsologtostderr to True,
  then it calls function(argv, args, kwargs), making sure that ``function``
  will get called within app.run(). _run_in_app does this by checking whether
  it is called by app.run(), or by calling app.run() explicitly.

  The reason why app.run has to be ensured is to make sure that
  flags are parsed and stripped properly, and other initializations done by
  the app module are also carried out, no matter if absltest.run() is called
  from within or outside app.run().

  If _run_in_app is called from within app.run(), then it will reparse
  sys.argv and pass the result without command-line flags into the argv
  argument of ``function``. The reason why this parsing is needed is that
  __main__.main() calls absltest.main() without passing its argv. So the
  only way _run_in_app could get to know the argv without the flags is that
  it reparses sys.argv. During this reparse, every flag's ``parse`` method
  is temporarily replaced with a no-op, to avoid the side effects of
  parsing. The original methods are restored in a ``finally`` clause, so
  they come back even if the reparse raises.

  _run_in_app changes the default of FLAGS.alsologtostderr to True so that the
  test program's stderr will contain all the log messages unless otherwise
  specified on the command-line. This overrides any explicit assignment to
  FLAGS.alsologtostderr by the test program prior to the call to _run_in_app()
  (e.g. in __main__.main).

  NOTE(review): the inline comment in the app.run branch below says a prior
  ``FLAGS.alsologtostderr = False`` is *kept*, which contradicts the previous
  paragraph. Confirm which one holds against absl's FlagValues.set_default.

  Please note that _run_in_app (and the function it calls) is allowed to make
  changes to kwargs.

  Args:
    function: absltest.run_tests or a similar function. It will be called as
        function(argv, args, kwargs) where argv is a list containing the
        elements of sys.argv without the command-line flags.
    args: Positional arguments passed through to unittest.TestProgram.__init__.
    kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
  """
  if _is_in_app_main():
    _register_sigterm_with_faulthandler()

    # Change the default of alsologtostderr from False to True, so the test
    # program's stderr will contain all the log messages.
    # If --alsologtostderr=false is specified in the command-line, or user
    # has called FLAGS.alsologtostderr = False before, then the value is kept
    # False.
    FLAGS.set_default('alsologtostderr', True)

    # Here we only want to get the `argv` without the flags. To avoid any
    # side effects of parsing flags, we temporarily stub out the `parse` method
    # of every flag with a no-op and restore the originals afterwards.
    stored_parse_methods = {}
    noop_parse = lambda _: None
    for name in FLAGS:
      # Avoid any side effects of parsing flags.
      stored_parse_methods[name] = FLAGS[name].parse
    # This must be a separate loop since multiple flag names (short_name=) can
    # point to the same flag object.
    for name in FLAGS:
      FLAGS[name].parse = noop_parse
    try:
      argv = FLAGS(sys.argv)
    finally:
      # Restore the saved parse methods even if the reparse above raised.
      for name in FLAGS:
        FLAGS[name].parse = stored_parse_methods[name]
      sys.stdout.flush()

    function(argv, args, kwargs)
  else:
    # Send logging to stderr. Use --alsologtostderr instead of --logtostderr
    # in case tests are reading their own logs.
    FLAGS.set_default('alsologtostderr', True)

    def main_function(argv):
      # Runs inside app.run(), which passes in the flag-stripped argv.
      _register_sigterm_with_faulthandler()
      function(argv, args, kwargs)

    app.run(main=main_function)
2166
2167
2168def _is_suspicious_attribute(testCaseClass, name):
2169  # type: (Type, Text) -> bool
2170  """Returns True if an attribute is a method named like a test method."""
2171  if name.startswith('Test') and len(name) > 4 and name[4].isupper():
2172    attr = getattr(testCaseClass, name)
2173    if inspect.isfunction(attr) or inspect.ismethod(attr):
2174      args = inspect.getfullargspec(attr)
2175      return (len(args.args) == 1 and args.args[0] == 'self' and
2176              args.varargs is None and args.varkw is None and
2177              not args.kwonlyargs)
2178  return False
2179
2180
def skipThisClass(reason):
  # type: (Text) -> Callable[[_T], _T]
  """Skip tests in the decorated TestCase, but not any of its subclasses.

  The decorated class skips all of its own tests, while classes derived from
  it run normally. This is handy for sharing test methods or setUp
  implementations between several concrete TestCase classes.

  Example usage, showing how you can share some common test methods between
  subclasses. In this example, only ``BaseTest`` will be marked as skipped, and
  not RealTest or SecondRealTest::

      @absltest.skipThisClass("Shared functionality")
      class BaseTest(absltest.TestCase):
        def test_simple_functionality(self):
          self.assertEqual(self.system_under_test.method(), 1)

      class RealTest(BaseTest):
        def setUp(self):
          super().setUp()
          self.system_under_test = MakeSystem(argument)

        def test_specific_behavior(self):
          ...

      class SecondRealTest(BaseTest):
        def setUp(self):
          super().setUp()
          self.system_under_test = MakeSystem(other_arguments)

        def test_other_behavior(self):
          ...

  Args:
    reason: The reason we have a skip in place. For instance: 'shared test
      methods' or 'shared assertion methods'.

  Returns:
    Decorator function that will cause a class to be skipped.

  Raises:
    TypeError: if `reason` is a class (i.e. the decorator was applied without
      a reason), or if the decorated object is not a TestCase subclass.
  """
  if isinstance(reason, type):
    raise TypeError('Got {!r}, expected reason as string'.format(reason))

  def _skip_class(test_case_class):
    if not issubclass(test_case_class, unittest.TestCase):
      raise TypeError(
          'Decorating {!r}, expected TestCase subclass'.format(test_case_class))

    # Capture setUpClass only when the decorated class defines it itself; an
    # inherited one is reached through super() at call time instead of being
    # held by reference.
    own_setupclass = vars(test_case_class).get('setUpClass', None)

    def replacement_setupclass(cls, *args, **kwargs):
      # Only the exact class that was decorated is skipped.
      if cls is test_case_class:
        raise SkipTest(reason)
      if not own_setupclass:
        # Nothing defined directly on the decorated class: continue the
        # inheritance chain ourselves.
        return super(test_case_class, cls).setUpClass(*args, **kwargs)
      # The stored classmethod descriptor is not directly callable; call its
      # underlying function with `cls` so the MRO chain stays intact.
      return own_setupclass.__func__(cls, *args, **kwargs)

    test_case_class.setUpClass = classmethod(replacement_setupclass)
    return test_case_class

  return _skip_class
2254
2255
class TestLoader(unittest.TestLoader):
  """A test loader which supports common test features.

  Supported features include:
   * Banning untested methods with test-like names: methods attached to this
     testCase with names starting with `Test` are ignored by the test runner,
     and often represent mistakenly-omitted test cases. This loader will raise
     a TypeError when attempting to load a TestCase with such methods.
   * Randomization of test case execution order (optional).
  """

  _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but
  is not one. This is often a bug. If you want it to be a test method,
  name it with 'test' in lowercase. If not, rename the method to not begin
  with 'Test'.""")

  def __init__(self, *args, **kwds):
    super(TestLoader, self).__init__(*args, **kwds)
    seed = _get_default_randomize_ordering_seed()
    # A falsy seed disables shuffling entirely.
    self._randomize_ordering_seed = seed if seed else None
    self._random = random.Random(seed) if seed else None

  def getTestCaseNames(self, testCaseClass):  # pylint:disable=invalid-name
    """Validates and returns a (possibly randomized) list of test case names.

    Raises:
      TypeError: if testCaseClass has a method named like a test (e.g.
        ``TestFoo``) that the runner would silently ignore.
    """
    suspicious = next(
        (name for name in dir(testCaseClass)
         if _is_suspicious_attribute(testCaseClass, name)),
        None)
    if suspicious is not None:
      raise TypeError(TestLoader._ERROR_MSG % suspicious)

    names = super(TestLoader, self).getTestCaseNames(testCaseClass)
    seed = self._randomize_ordering_seed
    if seed is None:
      return names
    logging.info(
        'Randomizing test order with seed: %d', seed)
    logging.info(
        'To reproduce this order, re-run with '
        '--test_randomize_ordering_seed=%d', seed)
    self._random.shuffle(names)
    return names
2296
2297
def get_default_xml_output_filename():
  # type: () -> Optional[Text]
  """Returns where the XML test report should be written, or None.

  The sources are checked in this order (empty values count as unset):
    1. $XML_OUTPUT_FILE, used verbatim.
    2. If $RUNNING_UNDER_TEST_DAEMON is set: ``test_detail.xml`` in the
       parent directory of ``TEST_TMPDIR.value``.
    3. $TEST_XMLOUTPUTDIR joined with ``<program>.xml``, where ``<program>``
       is sys.argv[0] without its directory and extension.
  """
  explicit_path = os.environ.get('XML_OUTPUT_FILE')
  if explicit_path:
    return explicit_path
  if os.environ.get('RUNNING_UNDER_TEST_DAEMON'):
    return os.path.join(os.path.dirname(TEST_TMPDIR.value), 'test_detail.xml')
  output_dir = os.environ.get('TEST_XMLOUTPUTDIR')
  if output_dir:
    program = os.path.splitext(os.path.basename(sys.argv[0]))[0]
    return os.path.join(output_dir, program + '.xml')
  return None
2309
2310def _setup_filtering(argv):
2311  # type: (MutableSequence[Text]) -> None
2312  """Implements the bazel test filtering protocol.
2313
2314  The following environment variable is used in this method:
2315
2316    TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest
2317      framework to use as a test filter. Its value is split with shlex, then:
2318      1. On Python 3.6 and before, split values are passed as positional
2319         arguments on argv.
2320      2. On Python 3.7+, split values are passed to unittest's `-k` flag. Tests
2321         are matched by glob patterns or substring. See
2322         https://docs.python.org/3/library/unittest.html#cmdoption-unittest-k
2323
2324  Args:
2325    argv: the argv to mutate in-place.
2326  """
2327  test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY')
2328  if argv is None or not test_filter:
2329    return
2330
2331  filters = shlex.split(test_filter)
2332  if sys.version_info[:2] >= (3, 7):
2333    filters = ['-k=' + test_filter for test_filter in filters]
2334
2335  argv[1:1] = filters
2336
2337
2338def _setup_test_runner_fail_fast(argv):
2339  # type: (MutableSequence[Text]) -> None
2340  """Implements the bazel test fail fast protocol.
2341
2342  The following environment variable is used in this method:
2343
2344    TESTBRIDGE_TEST_RUNNER_FAIL_FAST=<1|0>
2345
2346  If set to 1, --failfast is passed to the unittest framework to return upon
2347  first failure.
2348
2349  Args:
2350    argv: the argv to mutate in-place.
2351  """
2352
2353  if argv is None:
2354    return
2355
2356  if os.environ.get('TESTBRIDGE_TEST_RUNNER_FAIL_FAST') != '1':
2357    return
2358
2359  argv[1:1] = ['--failfast']
2360
2361
2362def _setup_sharding(custom_loader=None):
2363  # type: (Optional[unittest.TestLoader]) -> unittest.TestLoader
2364  """Implements the bazel sharding protocol.
2365
2366  The following environment variables are used in this method:
2367
2368    TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank
2369      file to tell the test runner that this test implements the test sharding
2370      protocol.
2371
2372    TEST_TOTAL_SHARDS: int, if set, sharding is requested.
2373
2374    TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies
2375      the shard index for this instance of the test process. Must satisfy:
2376      0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS.
2377
2378  Args:
2379    custom_loader: A TestLoader to be made sharded.
2380
2381  Returns:
2382    The test loader for shard-filtering or the standard test loader, depending
2383    on the sharding environment variables.
2384  """
2385
2386  # It may be useful to write the shard file even if the other sharding
2387  # environment variables are not set. Test runners may use this functionality
2388  # to query whether a test binary implements the test sharding protocol.
2389  if 'TEST_SHARD_STATUS_FILE' in os.environ:
2390    try:
2391      with open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') as f:
2392        f.write('')
2393    except IOError:
2394      sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
2395                       % os.environ['TEST_SHARD_STATUS_FILE'])
2396      sys.exit(1)
2397
2398  base_loader = custom_loader or TestLoader()
2399  if 'TEST_TOTAL_SHARDS' not in os.environ:
2400    # Not using sharding, use the expected test loader.
2401    return base_loader
2402
2403  total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
2404  shard_index = int(os.environ['TEST_SHARD_INDEX'])
2405
2406  if shard_index < 0 or shard_index >= total_shards:
2407    sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' %
2408                     (shard_index, total_shards))
2409    sys.exit(1)
2410
2411  # Replace the original getTestCaseNames with one that returns
2412  # the test case names for this shard.
2413  delegate_get_names = base_loader.getTestCaseNames
2414
2415  bucket_iterator = itertools.cycle(range(total_shards))
2416
2417  def getShardedTestCaseNames(testCaseClass):
2418    filtered_names = []
2419    # We need to sort the list of tests in order to determine which tests this
2420    # shard is responsible for; however, it's important to preserve the order
2421    # returned by the base loader, e.g. in the case of randomized test ordering.
2422    ordered_names = delegate_get_names(testCaseClass)
2423    for testcase in sorted(ordered_names):
2424      bucket = next(bucket_iterator)
2425      if bucket == shard_index:
2426        filtered_names.append(testcase)
2427    return [x for x in ordered_names if x in filtered_names]
2428
2429  base_loader.getTestCaseNames = getShardedTestCaseNames
2430  return base_loader
2431
2432
# pylint: disable=line-too-long
def _run_and_get_tests_result(argv, args, kwargs, xml_test_runner_class):
  # type: (MutableSequence[Text], Sequence[Any], MutableMapping[Text, Any], Type) -> unittest.TestResult
  # pylint: enable=line-too-long
  """Same as run_tests, except it returns the result instead of exiting.

  Applies the bazel test protocols (filtering, fail-fast, sharding) and picks
  the test runner:
    * an XML-capable runner when an XML output file is configured;
    * otherwise the pretty-print runner, if no runner was supplied.
  It then makes sure the test tmpdir exists and hands control to
  ``unittest.TestProgram``.

  NOTE: ``unittest.TestProgram`` itself calls ``sys.exit()`` after running
  unless ``kwargs`` contains ``exit=False``. This function does not override
  that, so the result is only returned when the caller opts out of exiting.
  Either way, the buffered XML report is written to the output file in the
  ``finally`` clause.

  Args:
    argv: the argv passed on to ``unittest.TestProgram``. An ``argv`` entry in
      kwargs takes precedence. It is mutated in place by the filtering and
      fail-fast protocols.
    args: positional arguments passed through to ``unittest.TestProgram``.
    kwargs: keyword arguments passed through to ``unittest.TestProgram``.
      Mutated in place: ``argv``, ``testLoader`` and ``testRunner`` entries are
      (re)assigned.
    xml_test_runner_class: runner class used when XML output is requested and
      the supplied testRunner cannot produce XML. It must provide
      ``set_default_xml_stream``.

  Returns:
    The ``unittest.TestResult`` of the run (see the NOTE about exiting).
  """

  # The entry from kwargs overrides argv.
  argv = kwargs.pop('argv', argv)

  # Set up test filtering if requested in environment.
  _setup_filtering(argv)
  # Set up --failfast as requested in environment.
  _setup_test_runner_fail_fast(argv)

  # Shard the (default or custom) loader if sharding is turned on.
  kwargs['testLoader'] = _setup_sharding(kwargs.get('testLoader', None))

  # XML file name is based upon (sorted by priority):
  # --xml_output_file flag, XML_OUTPUT_FILE variable,
  # TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable.
  if not FLAGS.xml_output_file:
    FLAGS.xml_output_file = get_default_xml_output_filename()
  xml_output_file = FLAGS.xml_output_file

  xml_buffer = None
  if xml_output_file:
    xml_output_dir = os.path.dirname(xml_output_file)
    if xml_output_dir and not os.path.isdir(xml_output_dir):
      try:
        os.makedirs(xml_output_dir)
      except OSError as e:
        # File exists error can occur with concurrent tests
        if e.errno != errno.EEXIST:
          raise
    # Fail early if we can't write to the XML output file. This is so that we
    # don't waste people's time running tests that will just fail anyways.
    with _open(xml_output_file, 'w'):
      pass

    # We can reuse testRunner if it supports XML output (e. g. by inheriting
    # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use
    # xml_reporter.TextAndXMLTestRunner.
    if (kwargs.get('testRunner') is not None
        and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')):
      # The one-element tuple keeps %-formatting safe for any runner value,
      # and the trailing newline keeps the warning on its own line.
      sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting '
                       'overrides testRunner=%r setting (possibly from --pdb)\n'
                       % (kwargs['testRunner'],))
      # Passing a class object here allows TestProgram to initialize
      # instances based on its kwargs and/or parsed command-line args.
      kwargs['testRunner'] = xml_test_runner_class
    if kwargs.get('testRunner') is None:
      kwargs['testRunner'] = xml_test_runner_class
    # Use an in-memory buffer (not backed by the actual file) to store the XML
    # report, because some tools modify the file (e.g., create a placeholder
    # with partial information, in case the test process crashes).
    xml_buffer = io.StringIO()
    kwargs['testRunner'].set_default_xml_stream(xml_buffer)  # pytype: disable=attribute-error

    # If we've used a seed to randomize test case ordering, we want to record it
    # as a top-level attribute in the `testsuites` section of the XML output.
    randomize_ordering_seed = getattr(
        kwargs['testLoader'], '_randomize_ordering_seed', None)
    setter = getattr(kwargs['testRunner'], 'set_testsuites_property', None)
    if randomize_ordering_seed and setter:
      setter('test_randomize_ordering_seed', randomize_ordering_seed)
  elif kwargs.get('testRunner') is None:
    kwargs['testRunner'] = _pretty_print_reporter.TextTestRunner

  if FLAGS.pdb_post_mortem:
    runner = kwargs['testRunner']
    # testRunner can be a class or an instance, which must be tested for
    # differently.
    # Overriding testRunner isn't uncommon, so only enable the debugging
    # integration if the runner claims it does; we don't want to accidentally
    # clobber something on the runner.
    if ((isinstance(runner, type) and
         issubclass(runner, _pretty_print_reporter.TextTestRunner)) or
        isinstance(runner, _pretty_print_reporter.TextTestRunner)):
      runner.run_for_debugging = True

  # Make sure tmpdir exists.
  if not os.path.isdir(TEST_TMPDIR.value):
    try:
      os.makedirs(TEST_TMPDIR.value)
    except OSError as e:
      # Concurrent test might have created the directory.
      if e.errno != errno.EEXIST:
        raise

  # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v',
  # on argv, which is sys.argv without the command-line flags.
  kwargs['argv'] = argv

  try:
    test_program = unittest.TestProgram(*args, **kwargs)
    return test_program.result
  finally:
    # Runs even when TestProgram raises SystemExit, so the report is always
    # flushed from the in-memory buffer to the real file.
    if xml_buffer is not None:
      try:
        with _open(xml_output_file, 'w') as f:
          f.write(xml_buffer.getvalue())
      finally:
        xml_buffer.close()
2536
2537
def run_tests(argv, args, kwargs):  # pylint: disable=line-too-long
  # type: (MutableSequence[Text], Sequence[Any], MutableMapping[Text, Any]) -> None
  # pylint: enable=line-too-long
  """Executes a set of Python unit tests and terminates the process.

  Most users should call absltest.main() instead of run_tests.

  run_tests is meant to be called from app.run; going through absltest.main()
  guarantees that.

  Note that run_tests may modify kwargs.

  Args:
    argv: sys.argv with the command-line flags removed from the front, i.e. the
      argv with which :func:`app.run()<absl.app.run>` has called
      ``__main__.main``. It is passed to
      ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing.
      It is ignored if kwargs contains an argv entry.
    args: Positional arguments passed through to
      ``unittest.TestProgram.__init__``.
    kwargs: Keyword arguments passed through to
      ``unittest.TestProgram.__init__``.
  """
  runner_class = xml_reporter.TextAndXMLTestRunner
  result = _run_and_get_tests_result(argv, args, kwargs, runner_class)
  # A falsy exit status (success) only when every test passed.
  any_failures = not result.wasSuccessful()
  sys.exit(any_failures)
2564
2565
2566def _rmtree_ignore_errors(path):
2567  # type: (Text) -> None
2568  if os.path.isfile(path):
2569    try:
2570      os.unlink(path)
2571    except OSError:
2572      pass
2573  else:
2574    shutil.rmtree(path, ignore_errors=True)
2575
2576
2577def _get_first_part(path):
2578  # type: (Text) -> Text
2579  parts = path.split(os.sep, 1)
2580  return parts[0]
2581