• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Test case implementation"""
2
3import sys
4import functools
5import difflib
6import pprint
7import re
8import warnings
9import collections
10import contextlib
11import traceback
12import types
13
14from . import result
15from .util import (strclass, safe_repr, _count_diff_all_purpose,
16                   _count_diff_hashable, _common_shorten_repr)
17
# Module marker flag; presumably looked up by the result-reporting code to
# recognise (and trim) unittest's own frames in tracebacks -- confirm there.
__unittest = True

# Unique default for ``TestCase.subTest(msg=...)`` so that an explicit
# ``msg=None`` can presumably be told apart from "no message given".
_subtest_msg_sentinel = object()

# Failure-message template for an over-long diff; ``%s`` is the diff length.
DIFF_OMITTED = '\nDiff is %s characters long. Set self.maxDiff to None to see it.'
24
class SkipTest(Exception):
    """Signal that the running test should be reported as skipped.

    Raising it inside a test marks the test as skipped rather than failed.
    ``TestCase.skipTest()`` or one of the skip decorators is usually the
    more convenient way to trigger it.
    """
32
class _ShouldStop(Exception):
    """Internal signal: stop the current test part without recording an error."""
37
class _UnexpectedSuccess(Exception):
    """Internal marker: a test decorated as an expected failure passed.

    Raised and caught only to obtain a real exc_info for results that
    lack ``addUnexpectedSuccess``.
    """
42
43
class _Outcome(object):
    """Mutable record of how the parts of one test run turned out.

    ``TestCase.run`` creates one per run and executes setUp, the test
    method, tearDown and each cleanup under :meth:`testPartExecutor`;
    ``subTest`` also runs its block through it.
    """

    def __init__(self, result=None):
        # While True, an exception from the part is stored in
        # ``expectedFailure`` instead of being recorded in ``errors``.
        self.expecting_failure = False
        # The TestResult being reported to; None when doCleanups() builds a
        # throwaway outcome outside of run().
        self.result = result
        # Passing parts are only recorded when the result can take subtests.
        self.result_supports_subtests = hasattr(result, "addSubTest")
        # Set to False once any part raises SkipTest or an unexpected error.
        self.success = True
        # (test_case, reason) pairs collected from SkipTest.
        self.skipped = []
        # exc_info of the exception seen while ``expecting_failure`` was set.
        self.expectedFailure = None
        # (test_case, exc_info) pairs; exc_info is None for a part that
        # passed (recorded only when subtests are supported).
        self.errors = []

    @contextlib.contextmanager
    def testPartExecutor(self, test_case, isTest=False):
        """Run the enclosed block as one part of a test and record the outcome.

        ``self.success`` is reset to True for the block and afterwards AND-ed
        with its previous value, so it stays True only if every part so far
        succeeded.  KeyboardInterrupt propagates; SkipTest is recorded in
        ``skipped``; _ShouldStop is silently swallowed; any other exception
        (the bare ``except`` also catches BaseException subclasses) is
        recorded instead of propagating.

        Args:
            test_case: the TestCase or subtest the outcome is attributed to.
            isTest: accepted for call-site compatibility; not read here.
        """
        old_success = self.success
        self.success = True
        try:
            yield
        except KeyboardInterrupt:
            # Never swallow Ctrl-C.
            raise
        except SkipTest as e:
            self.success = False
            self.skipped.append((test_case, str(e)))
        except _ShouldStop:
            # Deliberate early exit (e.g. failfast in subTest); nothing to record.
            pass
        except:
            exc_info = sys.exc_info()
            if self.expecting_failure:
                self.expectedFailure = exc_info
            else:
                self.success = False
                self.errors.append((test_case, exc_info))
            # explicitly break a reference cycle:
            # exc_info -> frame -> exc_info
            exc_info = None
        else:
            # Record a clean pass (exc_info None); TestCase._feedErrorsToResult
            # only forwards such entries for subtests.
            if self.result_supports_subtests and self.success:
                self.errors.append((test_case, None))
        finally:
            self.success = self.success and old_success
82
83
def _id(obj):
    """Return *obj* unchanged -- the no-op decorator used by skipIf/skipUnless."""
    return obj
86
87
# LIFO stack of (function, args, kwargs) entries from addModuleCleanup().
_module_cleanups = []


def addModuleCleanup(function, /, *args, **kwargs):
    """Register ``function(*args, **kwargs)`` as a module-level cleanup.

    Module-scoped counterpart of ``TestCase.addCleanup``: the registered
    callables are run even if setUpModule fails (unlike tearDownModule).
    """
    entry = (function, args, kwargs)
    _module_cleanups.append(entry)


def doModuleCleanups():
    """Run every registered module cleanup, most recently added first.

    Normally called for you after tearDownModule.  Every cleanup is
    attempted even if an earlier one raised an Exception; afterwards the
    first such exception is re-raised and any later ones are dropped.
    """
    first_error = None
    while _module_cleanups:
        func, call_args, call_kwargs = _module_cleanups.pop()
        try:
            func(*call_args, **call_kwargs)
        except Exception as exc:
            if first_error is None:
                first_error = exc
    if first_error is not None:
        # Only the first failure is surfaced; a multi-exception report
        # would be the natural replacement if one becomes available.
        raise first_error
109
110
def skip(reason):
    """
    Unconditionally skip a test.

    Works both as ``@skip("why")`` and as a bare ``@skip``.  Functions are
    replaced by a wrapper that raises SkipTest when called; classes are
    returned as-is.  Either way the item is tagged with
    ``__unittest_skip__`` and ``__unittest_skip_why__``.
    """
    def mark_skipped(item):
        if isinstance(item, type):
            target = item
        else:
            @functools.wraps(item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            target = skip_wrapper
        target.__unittest_skip__ = True
        target.__unittest_skip_why__ = reason
        return target

    if not isinstance(reason, types.FunctionType):
        return mark_skipped
    # Used bare as ``@skip``: *reason* is actually the decorated function,
    # and the recorded reason becomes the empty string.
    func, reason = reason, ''
    return mark_skipped(func)
130
def skipIf(condition, reason):
    """
    Skip a test if the condition is true.

    Returns ``skip(reason)`` for a truthy *condition* and a no-op
    decorator otherwise.
    """
    return skip(reason) if condition else _id
138
def skipUnless(condition, reason):
    """
    Skip a test unless the condition is true.

    Returns a no-op decorator for a truthy *condition* and
    ``skip(reason)`` otherwise.
    """
    return _id if condition else skip(reason)
146
def expectedFailure(test_item):
    """Mark *test_item* (a test method or TestCase class) as expected to fail.

    The item is tagged in place with ``__unittest_expecting_failure__`` and
    returned.
    """
    setattr(test_item, '__unittest_expecting_failure__', True)
    return test_item
150
def _is_subtype(expected, basetype):
    """Return True if *expected* is a class derived from *basetype*.

    *expected* may also be a (possibly nested) tuple of classes; then every
    element must qualify, so an empty tuple qualifies trivially.
    """
    if not isinstance(expected, tuple):
        return isinstance(expected, type) and issubclass(expected, basetype)
    for member in expected:
        if not _is_subtype(member, basetype):
            return False
    return True
155
class _BaseTestCaseContext:
    """Base for helper context managers that report through a TestCase."""

    def __init__(self, test_case):
        self.test_case = test_case

    def _raiseFailure(self, standardMsg):
        """Raise the test case's failureException.

        The message is *standardMsg* combined with ``self.msg`` via the test
        case's ``_formatMessage``; ``self.msg`` must be set by the subclass.
        """
        text = self.test_case._formatMessage(self.msg, standardMsg)
        exc_class = self.test_case.failureException
        raise exc_class(text)
164
class _AssertRaisesBaseContext(_BaseTestCaseContext):
    """Shared state and argument handling for assertRaises/assertWarns.

    Subclasses define ``_base_type``/``_base_type_str`` (what *expected*
    must derive from) and the ``__enter__``/``__exit__`` pair.
    """

    def __init__(self, expected, test_case, expected_regex=None):
        _BaseTestCaseContext.__init__(self, test_case)
        self.expected = expected
        self.test_case = test_case
        # Accept a pattern string or an already-compiled pattern.
        if expected_regex is None:
            self.expected_regex = None
        else:
            self.expected_regex = re.compile(expected_regex)
        self.obj_name = None
        self.msg = None

    def handle(self, name, args, kwargs):
        """Dispatch between context-manager use and callable use.

        With no positional *args*, only a ``msg`` keyword is accepted and
        ``self`` is returned for a ``with`` statement.  Otherwise the first
        positional argument is called, inside ``with self``, with the
        remaining positional and all keyword arguments, and None is returned.

        Raises:
            TypeError: if ``self.expected`` is not a ``_base_type`` subclass
                (or tuple of them), or if an unknown keyword is passed in
                context-manager form.
        """
        try:
            if not _is_subtype(self.expected, self._base_type):
                raise TypeError('%s() arg 1 must be %s' %
                                (name, self._base_type_str))
            if args:
                func, *call_args = args
                try:
                    self.obj_name = func.__name__
                except AttributeError:
                    self.obj_name = str(func)
                with self:
                    func(*call_args, **kwargs)
                return None
            self.msg = kwargs.pop('msg', None)
            if kwargs:
                bad_keyword = next(iter(kwargs))
                raise TypeError('%r is an invalid keyword argument for '
                                'this function' % (bad_keyword,))
            return self
        finally:
            # bpo-23890: manually break a reference cycle
            self = None
205
206
class _AssertRaisesContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertRaises* methods."""

    _base_type = BaseException
    _base_type_str = 'an exception type or tuple of exception types'

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Validate whatever the ``with`` block raised.

        Fails the test if nothing was raised, or if ``expected_regex`` is set
        and does not match the exception text.  Returns False so exceptions
        of other types propagate, True to swallow a matching one (stored,
        without traceback, as ``self.exception``).
        """
        if exc_type is not None:
            traceback.clear_frames(tb)
        else:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            if self.obj_name:
                failure = "{} not raised by {}".format(exc_name, self.obj_name)
            else:
                failure = "{} not raised".format(exc_name)
            self._raiseFailure(failure)
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        # keep the exception (minus its traceback) for later inspection
        self.exception = exc_value.with_traceback(None)
        pattern = self.expected_regex
        if pattern is not None and not pattern.search(str(exc_value)):
            self._raiseFailure('"{}" does not match "{}"'.format(
                     pattern.pattern, str(exc_value)))
        return True

    __class_getitem__ = classmethod(types.GenericAlias)
244
245
class _AssertWarnsContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertWarns* methods."""

    _base_type = Warning
    _base_type_str = 'a warning type or tuple of warning types'

    def __enter__(self):
        """Start recording warnings, with ``expected`` forced to "always".

        Returns self so ``with ... as cm`` can inspect the captured warning.
        """
        # The __warningregistry__'s need to be in a pristine state for tests
        # to work properly.
        # Iterate over a snapshot: getattr() on a module may run a module
        # ``__getattr__`` (PEP 562) that imports further modules, which adds
        # entries to sys.modules; mutating a dict while iterating over it
        # raises "RuntimeError: dictionary changed size during iteration".
        for v in list(sys.modules.values()):
            if getattr(v, '__warningregistry__', None):
                v.__warningregistry__ = {}
        self.warnings_manager = warnings.catch_warnings(record=True)
        self.warnings = self.warnings_manager.__enter__()
        warnings.simplefilter("always", self.expected)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Stop recording and check that a matching warning was issued.

        If the block raised, returns None so the exception propagates.
        Otherwise the first recorded warning that is an instance of
        ``expected`` (and, when ``expected_regex`` is set, whose text matches
        it) is stored as ``warning`` together with its ``filename`` and
        ``lineno``; if there is no such warning the test fails.
        """
        self.warnings_manager.__exit__(exc_type, exc_value, tb)
        if exc_type is not None:
            # let unexpected exceptions pass through
            return
        try:
            exc_name = self.expected.__name__
        except AttributeError:
            exc_name = str(self.expected)
        first_matching = None
        for m in self.warnings:
            w = m.message
            if not isinstance(w, self.expected):
                continue
            if first_matching is None:
                first_matching = w
            if (self.expected_regex is not None and
                not self.expected_regex.search(str(w))):
                continue
            # store warning for later retrieval
            self.warning = w
            self.filename = m.filename
            self.lineno = m.lineno
            return
        # Now we simply try to choose a helpful failure message:
        # a warning of the right type whose text did not match the regex...
        if first_matching is not None:
            self._raiseFailure('"{}" does not match "{}"'.format(
                     self.expected_regex.pattern, str(first_matching)))
        # ...or no warning of the expected type at all.
        if self.obj_name:
            self._raiseFailure("{} not triggered by {}".format(exc_name,
                                                               self.obj_name))
        else:
            self._raiseFailure("{} not triggered".format(exc_name))
296
297
298
class _OrderedChainMap(collections.ChainMap):
    """ChainMap that iterates keys in map order, first occurrence winning."""

    def __iter__(self):
        already_yielded = set()
        for key in (k for layer in self.maps for k in layer):
            if key in already_yielded:
                continue
            already_yielded.add(key)
            yield key
307
308
309class TestCase(object):
310    """A class whose instances are single test cases.
311
312    By default, the test code itself should be placed in a method named
313    'runTest'.
314
315    If the fixture may be used for many test cases, create as
316    many test methods as are needed. When instantiating such a TestCase
317    subclass, specify in the constructor arguments the name of the test method
318    that the instance is to execute.
319
320    Test authors should subclass TestCase for their own tests. Construction
321    and deconstruction of the test's environment ('fixture') can be
322    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
323
324    If it is necessary to override the __init__ method, the base class
325    __init__ method must always be called. It is important that subclasses
326    should not change the signature of their __init__ method, since instances
327    of the classes are instantiated automatically by parts of the framework
328    in order to be run.
329
330    When subclassing TestCase, you can set these attributes:
331    * failureException: determines which exception will be raised when
332        the instance's assertion methods fail; test methods raising this
333        exception will be deemed to have 'failed' rather than 'errored'.
334    * longMessage: determines whether long messages (including repr of
335        objects used in assert methods) will be printed on failure in *addition*
336        to any explicit message passed.
337    * maxDiff: sets the maximum length of a diff in failure messages
338        by assert methods using difflib. It is looked up as an instance
339        attribute so can be configured by individual tests if required.
340    """
341
342    failureException = AssertionError
343
344    longMessage = True
345
346    maxDiff = 80*8
347
348    # If a string is longer than _diffThreshold, use normal comparison instead
349    # of difflib.  See #11763.
350    _diffThreshold = 2**16
351
352    # Attribute used by TestSuite for classSetUp
353
354    _classSetupFailed = False
355
356    _class_cleanups = []
357
358    def __init__(self, methodName='runTest'):
359        """Create an instance of the class that will use the named test
360           method when executed. Raises a ValueError if the instance does
361           not have a method with the specified name.
362        """
363        self._testMethodName = methodName
364        self._outcome = None
365        self._testMethodDoc = 'No test'
366        try:
367            testMethod = getattr(self, methodName)
368        except AttributeError:
369            if methodName != 'runTest':
370                # we allow instantiation with no explicit method name
371                # but not an *incorrect* or missing method name
372                raise ValueError("no such test method in %s: %s" %
373                      (self.__class__, methodName))
374        else:
375            self._testMethodDoc = testMethod.__doc__
376        self._cleanups = []
377        self._subtest = None
378
379        # Map types to custom assertEqual functions that will compare
380        # instances of said type in more detail to generate a more useful
381        # error message.
382        self._type_equality_funcs = {}
383        self.addTypeEqualityFunc(dict, 'assertDictEqual')
384        self.addTypeEqualityFunc(list, 'assertListEqual')
385        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
386        self.addTypeEqualityFunc(set, 'assertSetEqual')
387        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
388        self.addTypeEqualityFunc(str, 'assertMultiLineEqual')
389
390    def addTypeEqualityFunc(self, typeobj, function):
391        """Add a type specific assertEqual style function to compare a type.
392
393        This method is for use by TestCase subclasses that need to register
394        their own type equality functions to provide nicer error messages.
395
396        Args:
397            typeobj: The data type to call this function on when both values
398                    are of the same type in assertEqual().
399            function: The callable taking two arguments and an optional
400                    msg= argument that raises self.failureException with a
401                    useful error message when the two arguments are not equal.
402        """
403        self._type_equality_funcs[typeobj] = function
404
405    def addCleanup(self, function, /, *args, **kwargs):
406        """Add a function, with arguments, to be called when the test is
407        completed. Functions added are called on a LIFO basis and are
408        called after tearDown on test failure or success.
409
410        Cleanup items are called even if setUp fails (unlike tearDown)."""
411        self._cleanups.append((function, args, kwargs))
412
413    @classmethod
414    def addClassCleanup(cls, function, /, *args, **kwargs):
415        """Same as addCleanup, except the cleanup items are called even if
416        setUpClass fails (unlike tearDownClass)."""
417        cls._class_cleanups.append((function, args, kwargs))
418
419    def setUp(self):
420        "Hook method for setting up the test fixture before exercising it."
421        pass
422
423    def tearDown(self):
424        "Hook method for deconstructing the test fixture after testing it."
425        pass
426
427    @classmethod
428    def setUpClass(cls):
429        "Hook method for setting up class fixture before running tests in the class."
430
431    @classmethod
432    def tearDownClass(cls):
433        "Hook method for deconstructing the class fixture after running all tests in the class."
434
435    def countTestCases(self):
436        return 1
437
438    def defaultTestResult(self):
439        return result.TestResult()
440
441    def shortDescription(self):
442        """Returns a one-line description of the test, or None if no
443        description has been provided.
444
445        The default implementation of this method returns the first line of
446        the specified test method's docstring.
447        """
448        doc = self._testMethodDoc
449        return doc.strip().split("\n")[0].strip() if doc else None
450
451
452    def id(self):
453        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
454
455    def __eq__(self, other):
456        if type(self) is not type(other):
457            return NotImplemented
458
459        return self._testMethodName == other._testMethodName
460
461    def __hash__(self):
462        return hash((type(self), self._testMethodName))
463
464    def __str__(self):
465        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
466
467    def __repr__(self):
468        return "<%s testMethod=%s>" % \
469               (strclass(self.__class__), self._testMethodName)
470
471    def _addSkip(self, result, test_case, reason):
472        addSkip = getattr(result, 'addSkip', None)
473        if addSkip is not None:
474            addSkip(test_case, reason)
475        else:
476            warnings.warn("TestResult has no addSkip method, skips not reported",
477                          RuntimeWarning, 2)
478            result.addSuccess(test_case)
479
    @contextlib.contextmanager
    def subTest(self, msg=_subtest_msg_sentinel, **params):
        """Return a context manager that will run the enclosed block
        of code in a subtest identified by the optional message and
        keyword parameters.  A failure in the subtest marks the test
        case as failed but resumes execution at the end of the enclosed
        block, allowing further test code to be executed.

        If the test is not currently being executed by run() (no outcome),
        or its result has no ``addSubTest``, the block simply runs inline
        and any exception propagates normally.
        """
        # Subtests need an active outcome whose result can report them.
        if self._outcome is None or not self._outcome.result_supports_subtests:
            yield
            return
        parent = self._subtest
        # Nested subtests layer their params over the parent's via
        # ChainMap.new_child, so inner keys shadow outer ones.
        if parent is None:
            params_map = _OrderedChainMap(params)
        else:
            params_map = parent.params.new_child(params)
        self._subtest = _SubTest(self, msg, params_map)
        try:
            # Exceptions from the block (other than KeyboardInterrupt) are
            # recorded against the subtest and swallowed here, so execution
            # resumes after the block.
            with self._outcome.testPartExecutor(self._subtest, isTest=True):
                yield
            if not self._outcome.success:
                result = self._outcome.result
                # failfast: abandon the rest of the test after the first
                # failing subtest.
                if result is not None and result.failfast:
                    raise _ShouldStop
            elif self._outcome.expectedFailure:
                # If the test is expecting a failure, we really want to
                # stop now and register the expected failure.
                raise _ShouldStop
        finally:
            # Restore the enclosing subtest (or None) on every exit path.
            self._subtest = parent
510
511    def _feedErrorsToResult(self, result, errors):
512        for test, exc_info in errors:
513            if isinstance(test, _SubTest):
514                result.addSubTest(test.test_case, test, exc_info)
515            elif exc_info is not None:
516                if issubclass(exc_info[0], self.failureException):
517                    result.addFailure(test, exc_info)
518                else:
519                    result.addError(test, exc_info)
520
521    def _addExpectedFailure(self, result, exc_info):
522        try:
523            addExpectedFailure = result.addExpectedFailure
524        except AttributeError:
525            warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
526                          RuntimeWarning)
527            result.addSuccess(self)
528        else:
529            addExpectedFailure(self, exc_info)
530
531    def _addUnexpectedSuccess(self, result):
532        try:
533            addUnexpectedSuccess = result.addUnexpectedSuccess
534        except AttributeError:
535            warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure",
536                          RuntimeWarning)
537            # We need to pass an actual exception and traceback to addFailure,
538            # otherwise the legacy result can choke.
539            try:
540                raise _UnexpectedSuccess from None
541            except _UnexpectedSuccess:
542                result.addFailure(self, sys.exc_info())
543        else:
544            addUnexpectedSuccess(self)
545
546    def _callSetUp(self):
547        self.setUp()
548
549    def _callTestMethod(self, method):
550        method()
551
552    def _callTearDown(self):
553        self.tearDown()
554
555    def _callCleanup(self, function, /, *args, **kwargs):
556        function(*args, **kwargs)
557
558    def run(self, result=None):
559        orig_result = result
560        if result is None:
561            result = self.defaultTestResult()
562            startTestRun = getattr(result, 'startTestRun', None)
563            if startTestRun is not None:
564                startTestRun()
565
566        result.startTest(self)
567
568        testMethod = getattr(self, self._testMethodName)
569        if (getattr(self.__class__, "__unittest_skip__", False) or
570            getattr(testMethod, "__unittest_skip__", False)):
571            # If the class or method was skipped.
572            try:
573                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
574                            or getattr(testMethod, '__unittest_skip_why__', ''))
575                self._addSkip(result, self, skip_why)
576            finally:
577                result.stopTest(self)
578            return
579        expecting_failure_method = getattr(testMethod,
580                                           "__unittest_expecting_failure__", False)
581        expecting_failure_class = getattr(self,
582                                          "__unittest_expecting_failure__", False)
583        expecting_failure = expecting_failure_class or expecting_failure_method
584        outcome = _Outcome(result)
585        try:
586            self._outcome = outcome
587
588            with outcome.testPartExecutor(self):
589                self._callSetUp()
590            if outcome.success:
591                outcome.expecting_failure = expecting_failure
592                with outcome.testPartExecutor(self, isTest=True):
593                    self._callTestMethod(testMethod)
594                outcome.expecting_failure = False
595                with outcome.testPartExecutor(self):
596                    self._callTearDown()
597
598            self.doCleanups()
599            for test, reason in outcome.skipped:
600                self._addSkip(result, test, reason)
601            self._feedErrorsToResult(result, outcome.errors)
602            if outcome.success:
603                if expecting_failure:
604                    if outcome.expectedFailure:
605                        self._addExpectedFailure(result, outcome.expectedFailure)
606                    else:
607                        self._addUnexpectedSuccess(result)
608                else:
609                    result.addSuccess(self)
610            return result
611        finally:
612            result.stopTest(self)
613            if orig_result is None:
614                stopTestRun = getattr(result, 'stopTestRun', None)
615                if stopTestRun is not None:
616                    stopTestRun()
617
618            # explicitly break reference cycles:
619            # outcome.errors -> frame -> outcome -> outcome.errors
620            # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
621            outcome.errors.clear()
622            outcome.expectedFailure = None
623
624            # clear the outcome, no more needed
625            self._outcome = None
626
627    def doCleanups(self):
628        """Execute all cleanup functions. Normally called for you after
629        tearDown."""
630        outcome = self._outcome or _Outcome()
631        while self._cleanups:
632            function, args, kwargs = self._cleanups.pop()
633            with outcome.testPartExecutor(self):
634                self._callCleanup(function, *args, **kwargs)
635
636        # return this for backwards compatibility
637        # even though we no longer use it internally
638        return outcome.success
639
640    @classmethod
641    def doClassCleanups(cls):
642        """Execute all class cleanup functions. Normally called for you after
643        tearDownClass."""
644        cls.tearDown_exceptions = []
645        while cls._class_cleanups:
646            function, args, kwargs = cls._class_cleanups.pop()
647            try:
648                function(*args, **kwargs)
649            except Exception:
650                cls.tearDown_exceptions.append(sys.exc_info())
651
652    def __call__(self, *args, **kwds):
653        return self.run(*args, **kwds)
654
655    def debug(self):
656        """Run the test without collecting errors in a TestResult"""
657        self.setUp()
658        getattr(self, self._testMethodName)()
659        self.tearDown()
660        while self._cleanups:
661            function, args, kwargs = self._cleanups.pop(-1)
662            function(*args, **kwargs)
663
664    def skipTest(self, reason):
665        """Skip this test."""
666        raise SkipTest(reason)
667
668    def fail(self, msg=None):
669        """Fail immediately, with the given message."""
670        raise self.failureException(msg)
671
672    def assertFalse(self, expr, msg=None):
673        """Check that the expression is false."""
674        if expr:
675            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
676            raise self.failureException(msg)
677
678    def assertTrue(self, expr, msg=None):
679        """Check that the expression is true."""
680        if not expr:
681            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
682            raise self.failureException(msg)
683
684    def _formatMessage(self, msg, standardMsg):
685        """Honour the longMessage attribute when generating failure messages.
686        If longMessage is False this means:
687        * Use only an explicit message if it is provided
688        * Otherwise use the standard message for the assert
689
690        If longMessage is True:
691        * Use the standard message
692        * If an explicit message is provided, plus ' : ' and the explicit message
693        """
694        if not self.longMessage:
695            return msg or standardMsg
696        if msg is None:
697            return standardMsg
698        try:
699            # don't switch to '{}' formatting in Python 2.X
700            # it changes the way unicode input is handled
701            return '%s : %s' % (standardMsg, msg)
702        except UnicodeDecodeError:
703            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
704
705    def assertRaises(self, expected_exception, *args, **kwargs):
706        """Fail unless an exception of class expected_exception is raised
707           by the callable when invoked with specified positional and
708           keyword arguments. If a different type of exception is
709           raised, it will not be caught, and the test case will be
710           deemed to have suffered an error, exactly as for an
711           unexpected exception.
712
713           If called with the callable and arguments omitted, will return a
714           context object used like this::
715
716                with self.assertRaises(SomeException):
717                    do_something()
718
719           An optional keyword argument 'msg' can be provided when assertRaises
720           is used as a context object.
721
722           The context manager keeps a reference to the exception as
723           the 'exception' attribute. This allows you to inspect the
724           exception after the assertion::
725
726               with self.assertRaises(SomeException) as cm:
727                   do_something()
728               the_exception = cm.exception
729               self.assertEqual(the_exception.error_code, 3)
730        """
731        context = _AssertRaisesContext(expected_exception, self)
732        try:
733            return context.handle('assertRaises', args, kwargs)
734        finally:
735            # bpo-23890: manually break a reference cycle
736            context = None
737
738    def assertWarns(self, expected_warning, *args, **kwargs):
739        """Fail unless a warning of class warnClass is triggered
740           by the callable when invoked with specified positional and
741           keyword arguments.  If a different type of warning is
742           triggered, it will not be handled: depending on the other
743           warning filtering rules in effect, it might be silenced, printed
744           out, or raised as an exception.
745
746           If called with the callable and arguments omitted, will return a
747           context object used like this::
748
749                with self.assertWarns(SomeWarning):
750                    do_something()
751
752           An optional keyword argument 'msg' can be provided when assertWarns
753           is used as a context object.
754
755           The context manager keeps a reference to the first matching
756           warning as the 'warning' attribute; similarly, the 'filename'
757           and 'lineno' attributes give you information about the line
758           of Python code from which the warning was triggered.
759           This allows you to inspect the warning after the assertion::
760
761               with self.assertWarns(SomeWarning) as cm:
762                   do_something()
763               the_warning = cm.warning
764               self.assertEqual(the_warning.some_attribute, 147)
765        """
766        context = _AssertWarnsContext(expected_warning, self)
767        return context.handle('assertWarns', args, kwargs)
768
769    def assertLogs(self, logger=None, level=None):
770        """Fail unless a log message of level *level* or higher is emitted
771        on *logger_name* or its children.  If omitted, *level* defaults to
772        INFO and *logger* defaults to the root logger.
773
774        This method must be used as a context manager, and will yield
775        a recording object with two attributes: `output` and `records`.
776        At the end of the context manager, the `output` attribute will
777        be a list of the matching formatted log messages and the
778        `records` attribute will be a list of the corresponding LogRecord
779        objects.
780
781        Example::
782
783            with self.assertLogs('foo', level='INFO') as cm:
784                logging.getLogger('foo').info('first message')
785                logging.getLogger('foo.bar').error('second message')
786            self.assertEqual(cm.output, ['INFO:foo:first message',
787                                         'ERROR:foo.bar:second message'])
788        """
789        # Lazy import to avoid importing logging if it is not needed.
790        from ._log import _AssertLogsContext
791        return _AssertLogsContext(self, logger, level)
792
793    def _getAssertEqualityFunc(self, first, second):
794        """Get a detailed comparison function for the types of the two args.
795
796        Returns: A callable accepting (first, second, msg=None) that will
797        raise a failure exception if first != second with a useful human
798        readable error message for those types.
799        """
800        #
801        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
802        # and vice versa.  I opted for the conservative approach in case
803        # subclasses are not intended to be compared in detail to their super
804        # class instances using a type equality func.  This means testing
805        # subtypes won't automagically use the detailed comparison.  Callers
806        # should use their type specific assertSpamEqual method to compare
807        # subclasses if the detailed comparison is desired and appropriate.
808        # See the discussion in http://bugs.python.org/issue2578.
809        #
810        if type(first) is type(second):
811            asserter = self._type_equality_funcs.get(type(first))
812            if asserter is not None:
813                if isinstance(asserter, str):
814                    asserter = getattr(self, asserter)
815                return asserter
816
817        return self._baseAssertEqual
818
819    def _baseAssertEqual(self, first, second, msg=None):
820        """The default assertEqual implementation, not type specific."""
821        if not first == second:
822            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
823            msg = self._formatMessage(msg, standardMsg)
824            raise self.failureException(msg)
825
826    def assertEqual(self, first, second, msg=None):
827        """Fail if the two objects are unequal as determined by the '=='
828           operator.
829        """
830        assertion_func = self._getAssertEqualityFunc(first, second)
831        assertion_func(first, second, msg=msg)
832
833    def assertNotEqual(self, first, second, msg=None):
834        """Fail if the two objects are equal as determined by the '!='
835           operator.
836        """
837        if not first != second:
838            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
839                                                          safe_repr(second)))
840            raise self.failureException(msg)
841
842    def assertAlmostEqual(self, first, second, places=None, msg=None,
843                          delta=None):
844        """Fail if the two objects are unequal as determined by their
845           difference rounded to the given number of decimal places
846           (default 7) and comparing to zero, or by comparing that the
847           difference between the two objects is more than the given
848           delta.
849
850           Note that decimal places (from zero) are usually not the same
851           as significant digits (measured from the most significant digit).
852
853           If the two objects compare equal then they will automatically
854           compare almost equal.
855        """
856        if first == second:
857            # shortcut
858            return
859        if delta is not None and places is not None:
860            raise TypeError("specify delta or places not both")
861
862        diff = abs(first - second)
863        if delta is not None:
864            if diff <= delta:
865                return
866
867            standardMsg = '%s != %s within %s delta (%s difference)' % (
868                safe_repr(first),
869                safe_repr(second),
870                safe_repr(delta),
871                safe_repr(diff))
872        else:
873            if places is None:
874                places = 7
875
876            if round(diff, places) == 0:
877                return
878
879            standardMsg = '%s != %s within %r places (%s difference)' % (
880                safe_repr(first),
881                safe_repr(second),
882                places,
883                safe_repr(diff))
884        msg = self._formatMessage(msg, standardMsg)
885        raise self.failureException(msg)
886
887    def assertNotAlmostEqual(self, first, second, places=None, msg=None,
888                             delta=None):
889        """Fail if the two objects are equal as determined by their
890           difference rounded to the given number of decimal places
891           (default 7) and comparing to zero, or by comparing that the
892           difference between the two objects is less than the given delta.
893
894           Note that decimal places (from zero) are usually not the same
895           as significant digits (measured from the most significant digit).
896
897           Objects that are equal automatically fail.
898        """
899        if delta is not None and places is not None:
900            raise TypeError("specify delta or places not both")
901        diff = abs(first - second)
902        if delta is not None:
903            if not (first == second) and diff > delta:
904                return
905            standardMsg = '%s == %s within %s delta (%s difference)' % (
906                safe_repr(first),
907                safe_repr(second),
908                safe_repr(delta),
909                safe_repr(diff))
910        else:
911            if places is None:
912                places = 7
913            if not (first == second) and round(diff, places) != 0:
914                return
915            standardMsg = '%s == %s within %r places' % (safe_repr(first),
916                                                         safe_repr(second),
917                                                         places)
918
919        msg = self._formatMessage(msg, standardMsg)
920        raise self.failureException(msg)
921
922    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
923        """An equality assertion for ordered sequences (like lists and tuples).
924
925        For the purposes of this function, a valid ordered sequence type is one
926        which can be indexed, has a length, and has an equality operator.
927
928        Args:
929            seq1: The first sequence to compare.
930            seq2: The second sequence to compare.
931            seq_type: The expected datatype of the sequences, or None if no
932                    datatype should be enforced.
933            msg: Optional message to use on failure instead of a list of
934                    differences.
935        """
936        if seq_type is not None:
937            seq_type_name = seq_type.__name__
938            if not isinstance(seq1, seq_type):
939                raise self.failureException('First sequence is not a %s: %s'
940                                        % (seq_type_name, safe_repr(seq1)))
941            if not isinstance(seq2, seq_type):
942                raise self.failureException('Second sequence is not a %s: %s'
943                                        % (seq_type_name, safe_repr(seq2)))
944        else:
945            seq_type_name = "sequence"
946
947        differing = None
948        try:
949            len1 = len(seq1)
950        except (TypeError, NotImplementedError):
951            differing = 'First %s has no length.    Non-sequence?' % (
952                    seq_type_name)
953
954        if differing is None:
955            try:
956                len2 = len(seq2)
957            except (TypeError, NotImplementedError):
958                differing = 'Second %s has no length.    Non-sequence?' % (
959                        seq_type_name)
960
961        if differing is None:
962            if seq1 == seq2:
963                return
964
965            differing = '%ss differ: %s != %s\n' % (
966                    (seq_type_name.capitalize(),) +
967                    _common_shorten_repr(seq1, seq2))
968
969            for i in range(min(len1, len2)):
970                try:
971                    item1 = seq1[i]
972                except (TypeError, IndexError, NotImplementedError):
973                    differing += ('\nUnable to index element %d of first %s\n' %
974                                 (i, seq_type_name))
975                    break
976
977                try:
978                    item2 = seq2[i]
979                except (TypeError, IndexError, NotImplementedError):
980                    differing += ('\nUnable to index element %d of second %s\n' %
981                                 (i, seq_type_name))
982                    break
983
984                if item1 != item2:
985                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
986                                 ((i,) + _common_shorten_repr(item1, item2)))
987                    break
988            else:
989                if (len1 == len2 and seq_type is None and
990                    type(seq1) != type(seq2)):
991                    # The sequences are the same, but have differing types.
992                    return
993
994            if len1 > len2:
995                differing += ('\nFirst %s contains %d additional '
996                             'elements.\n' % (seq_type_name, len1 - len2))
997                try:
998                    differing += ('First extra element %d:\n%s\n' %
999                                  (len2, safe_repr(seq1[len2])))
1000                except (TypeError, IndexError, NotImplementedError):
1001                    differing += ('Unable to index element %d '
1002                                  'of first %s\n' % (len2, seq_type_name))
1003            elif len1 < len2:
1004                differing += ('\nSecond %s contains %d additional '
1005                             'elements.\n' % (seq_type_name, len2 - len1))
1006                try:
1007                    differing += ('First extra element %d:\n%s\n' %
1008                                  (len1, safe_repr(seq2[len1])))
1009                except (TypeError, IndexError, NotImplementedError):
1010                    differing += ('Unable to index element %d '
1011                                  'of second %s\n' % (len1, seq_type_name))
1012        standardMsg = differing
1013        diffMsg = '\n' + '\n'.join(
1014            difflib.ndiff(pprint.pformat(seq1).splitlines(),
1015                          pprint.pformat(seq2).splitlines()))
1016
1017        standardMsg = self._truncateMessage(standardMsg, diffMsg)
1018        msg = self._formatMessage(msg, standardMsg)
1019        self.fail(msg)
1020
1021    def _truncateMessage(self, message, diff):
1022        max_diff = self.maxDiff
1023        if max_diff is None or len(diff) <= max_diff:
1024            return message + diff
1025        return message + (DIFF_OMITTED % len(diff))
1026
1027    def assertListEqual(self, list1, list2, msg=None):
1028        """A list-specific equality assertion.
1029
1030        Args:
1031            list1: The first list to compare.
1032            list2: The second list to compare.
1033            msg: Optional message to use on failure instead of a list of
1034                    differences.
1035
1036        """
1037        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
1038
1039    def assertTupleEqual(self, tuple1, tuple2, msg=None):
1040        """A tuple-specific equality assertion.
1041
1042        Args:
1043            tuple1: The first tuple to compare.
1044            tuple2: The second tuple to compare.
1045            msg: Optional message to use on failure instead of a list of
1046                    differences.
1047        """
1048        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
1049
1050    def assertSetEqual(self, set1, set2, msg=None):
1051        """A set-specific equality assertion.
1052
1053        Args:
1054            set1: The first set to compare.
1055            set2: The second set to compare.
1056            msg: Optional message to use on failure instead of a list of
1057                    differences.
1058
1059        assertSetEqual uses ducktyping to support different types of sets, and
1060        is optimized for sets specifically (parameters must support a
1061        difference method).
1062        """
1063        try:
1064            difference1 = set1.difference(set2)
1065        except TypeError as e:
1066            self.fail('invalid type when attempting set difference: %s' % e)
1067        except AttributeError as e:
1068            self.fail('first argument does not support set difference: %s' % e)
1069
1070        try:
1071            difference2 = set2.difference(set1)
1072        except TypeError as e:
1073            self.fail('invalid type when attempting set difference: %s' % e)
1074        except AttributeError as e:
1075            self.fail('second argument does not support set difference: %s' % e)
1076
1077        if not (difference1 or difference2):
1078            return
1079
1080        lines = []
1081        if difference1:
1082            lines.append('Items in the first set but not the second:')
1083            for item in difference1:
1084                lines.append(repr(item))
1085        if difference2:
1086            lines.append('Items in the second set but not the first:')
1087            for item in difference2:
1088                lines.append(repr(item))
1089
1090        standardMsg = '\n'.join(lines)
1091        self.fail(self._formatMessage(msg, standardMsg))
1092
1093    def assertIn(self, member, container, msg=None):
1094        """Just like self.assertTrue(a in b), but with a nicer default message."""
1095        if member not in container:
1096            standardMsg = '%s not found in %s' % (safe_repr(member),
1097                                                  safe_repr(container))
1098            self.fail(self._formatMessage(msg, standardMsg))
1099
1100    def assertNotIn(self, member, container, msg=None):
1101        """Just like self.assertTrue(a not in b), but with a nicer default message."""
1102        if member in container:
1103            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
1104                                                        safe_repr(container))
1105            self.fail(self._formatMessage(msg, standardMsg))
1106
1107    def assertIs(self, expr1, expr2, msg=None):
1108        """Just like self.assertTrue(a is b), but with a nicer default message."""
1109        if expr1 is not expr2:
1110            standardMsg = '%s is not %s' % (safe_repr(expr1),
1111                                             safe_repr(expr2))
1112            self.fail(self._formatMessage(msg, standardMsg))
1113
1114    def assertIsNot(self, expr1, expr2, msg=None):
1115        """Just like self.assertTrue(a is not b), but with a nicer default message."""
1116        if expr1 is expr2:
1117            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
1118            self.fail(self._formatMessage(msg, standardMsg))
1119
1120    def assertDictEqual(self, d1, d2, msg=None):
1121        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
1122        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
1123
1124        if d1 != d2:
1125            standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
1126            diff = ('\n' + '\n'.join(difflib.ndiff(
1127                           pprint.pformat(d1).splitlines(),
1128                           pprint.pformat(d2).splitlines())))
1129            standardMsg = self._truncateMessage(standardMsg, diff)
1130            self.fail(self._formatMessage(msg, standardMsg))
1131
1132    def assertDictContainsSubset(self, subset, dictionary, msg=None):
1133        """Checks whether dictionary is a superset of subset."""
1134        warnings.warn('assertDictContainsSubset is deprecated',
1135                      DeprecationWarning)
1136        missing = []
1137        mismatched = []
1138        for key, value in subset.items():
1139            if key not in dictionary:
1140                missing.append(key)
1141            elif value != dictionary[key]:
1142                mismatched.append('%s, expected: %s, actual: %s' %
1143                                  (safe_repr(key), safe_repr(value),
1144                                   safe_repr(dictionary[key])))
1145
1146        if not (missing or mismatched):
1147            return
1148
1149        standardMsg = ''
1150        if missing:
1151            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
1152                                                    missing)
1153        if mismatched:
1154            if standardMsg:
1155                standardMsg += '; '
1156            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
1157
1158        self.fail(self._formatMessage(msg, standardMsg))
1159
1160
1161    def assertCountEqual(self, first, second, msg=None):
1162        """Asserts that two iterables have the same elements, the same number of
1163        times, without regard to order.
1164
1165            self.assertEqual(Counter(list(first)),
1166                             Counter(list(second)))
1167
1168         Example:
1169            - [0, 1, 1] and [1, 0, 1] compare equal.
1170            - [0, 0, 1] and [0, 1] compare unequal.
1171
1172        """
1173        first_seq, second_seq = list(first), list(second)
1174        try:
1175            first = collections.Counter(first_seq)
1176            second = collections.Counter(second_seq)
1177        except TypeError:
1178            # Handle case with unhashable elements
1179            differences = _count_diff_all_purpose(first_seq, second_seq)
1180        else:
1181            if first == second:
1182                return
1183            differences = _count_diff_hashable(first_seq, second_seq)
1184
1185        if differences:
1186            standardMsg = 'Element counts were not equal:\n'
1187            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
1188            diffMsg = '\n'.join(lines)
1189            standardMsg = self._truncateMessage(standardMsg, diffMsg)
1190            msg = self._formatMessage(msg, standardMsg)
1191            self.fail(msg)
1192
1193    def assertMultiLineEqual(self, first, second, msg=None):
1194        """Assert that two multi-line strings are equal."""
1195        self.assertIsInstance(first, str, 'First argument is not a string')
1196        self.assertIsInstance(second, str, 'Second argument is not a string')
1197
1198        if first != second:
1199            # don't use difflib if the strings are too long
1200            if (len(first) > self._diffThreshold or
1201                len(second) > self._diffThreshold):
1202                self._baseAssertEqual(first, second, msg)
1203            firstlines = first.splitlines(keepends=True)
1204            secondlines = second.splitlines(keepends=True)
1205            if len(firstlines) == 1 and first.strip('\r\n') == first:
1206                firstlines = [first + '\n']
1207                secondlines = [second + '\n']
1208            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
1209            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
1210            standardMsg = self._truncateMessage(standardMsg, diff)
1211            self.fail(self._formatMessage(msg, standardMsg))
1212
1213    def assertLess(self, a, b, msg=None):
1214        """Just like self.assertTrue(a < b), but with a nicer default message."""
1215        if not a < b:
1216            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
1217            self.fail(self._formatMessage(msg, standardMsg))
1218
1219    def assertLessEqual(self, a, b, msg=None):
1220        """Just like self.assertTrue(a <= b), but with a nicer default message."""
1221        if not a <= b:
1222            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
1223            self.fail(self._formatMessage(msg, standardMsg))
1224
1225    def assertGreater(self, a, b, msg=None):
1226        """Just like self.assertTrue(a > b), but with a nicer default message."""
1227        if not a > b:
1228            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
1229            self.fail(self._formatMessage(msg, standardMsg))
1230
1231    def assertGreaterEqual(self, a, b, msg=None):
1232        """Just like self.assertTrue(a >= b), but with a nicer default message."""
1233        if not a >= b:
1234            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
1235            self.fail(self._formatMessage(msg, standardMsg))
1236
1237    def assertIsNone(self, obj, msg=None):
1238        """Same as self.assertTrue(obj is None), with a nicer default message."""
1239        if obj is not None:
1240            standardMsg = '%s is not None' % (safe_repr(obj),)
1241            self.fail(self._formatMessage(msg, standardMsg))
1242
1243    def assertIsNotNone(self, obj, msg=None):
1244        """Included for symmetry with assertIsNone."""
1245        if obj is None:
1246            standardMsg = 'unexpectedly None'
1247            self.fail(self._formatMessage(msg, standardMsg))
1248
1249    def assertIsInstance(self, obj, cls, msg=None):
1250        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
1251        default message."""
1252        if not isinstance(obj, cls):
1253            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
1254            self.fail(self._formatMessage(msg, standardMsg))
1255
1256    def assertNotIsInstance(self, obj, cls, msg=None):
1257        """Included for symmetry with assertIsInstance."""
1258        if isinstance(obj, cls):
1259            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
1260            self.fail(self._formatMessage(msg, standardMsg))
1261
1262    def assertRaisesRegex(self, expected_exception, expected_regex,
1263                          *args, **kwargs):
1264        """Asserts that the message in a raised exception matches a regex.
1265
1266        Args:
1267            expected_exception: Exception class expected to be raised.
1268            expected_regex: Regex (re.Pattern object or string) expected
1269                    to be found in error message.
1270            args: Function to be called and extra positional args.
1271            kwargs: Extra kwargs.
1272            msg: Optional message used in case of failure. Can only be used
1273                    when assertRaisesRegex is used as a context manager.
1274        """
1275        context = _AssertRaisesContext(expected_exception, self, expected_regex)
1276        return context.handle('assertRaisesRegex', args, kwargs)
1277
1278    def assertWarnsRegex(self, expected_warning, expected_regex,
1279                         *args, **kwargs):
1280        """Asserts that the message in a triggered warning matches a regexp.
1281        Basic functioning is similar to assertWarns() with the addition
1282        that only warnings whose messages also match the regular expression
1283        are considered successful matches.
1284
1285        Args:
1286            expected_warning: Warning class expected to be triggered.
1287            expected_regex: Regex (re.Pattern object or string) expected
1288                    to be found in error message.
1289            args: Function to be called and extra positional args.
1290            kwargs: Extra kwargs.
1291            msg: Optional message used in case of failure. Can only be used
1292                    when assertWarnsRegex is used as a context manager.
1293        """
1294        context = _AssertWarnsContext(expected_warning, self, expected_regex)
1295        return context.handle('assertWarnsRegex', args, kwargs)
1296
1297    def assertRegex(self, text, expected_regex, msg=None):
1298        """Fail the test unless the text matches the regular expression."""
1299        if isinstance(expected_regex, (str, bytes)):
1300            assert expected_regex, "expected_regex must not be empty."
1301            expected_regex = re.compile(expected_regex)
1302        if not expected_regex.search(text):
1303            standardMsg = "Regex didn't match: %r not found in %r" % (
1304                expected_regex.pattern, text)
1305            # _formatMessage ensures the longMessage option is respected
1306            msg = self._formatMessage(msg, standardMsg)
1307            raise self.failureException(msg)
1308
1309    def assertNotRegex(self, text, unexpected_regex, msg=None):
1310        """Fail the test if the text matches the regular expression."""
1311        if isinstance(unexpected_regex, (str, bytes)):
1312            unexpected_regex = re.compile(unexpected_regex)
1313        match = unexpected_regex.search(text)
1314        if match:
1315            standardMsg = 'Regex matched: %r matches %r in %r' % (
1316                text[match.start() : match.end()],
1317                unexpected_regex.pattern,
1318                text)
1319            # _formatMessage ensures the longMessage option is respected
1320            msg = self._formatMessage(msg, standardMsg)
1321            raise self.failureException(msg)
1322
1323
1324    def _deprecate(original_func):
1325        def deprecated_func(*args, **kwargs):
1326            warnings.warn(
1327                'Please use {0} instead.'.format(original_func.__name__),
1328                DeprecationWarning, 2)
1329            return original_func(*args, **kwargs)
1330        return deprecated_func
1331
1332    # see #9424
1333    failUnlessEqual = assertEquals = _deprecate(assertEqual)
1334    failIfEqual = assertNotEquals = _deprecate(assertNotEqual)
1335    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(assertAlmostEqual)
1336    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(assertNotAlmostEqual)
1337    failUnless = assert_ = _deprecate(assertTrue)
1338    failUnlessRaises = _deprecate(assertRaises)
1339    failIf = _deprecate(assertFalse)
1340    assertRaisesRegexp = _deprecate(assertRaisesRegex)
1341    assertRegexpMatches = _deprecate(assertRegex)
1342    assertNotRegexpMatches = _deprecate(assertNotRegex)
1343
1344
1345
class FunctionTestCase(TestCase):
    """A TestCase that runs one plain test function.

    This is useful for slipping pre-existing test functions into the
    unittest framework.  Optional zero-argument *setUp* and *tearDown*
    callables are invoked from the corresponding TestCase hooks.  As with
    TestCase, the tidy-up ('tearDown') function will always be called if the
    set-up ('setUp') function ran successfully.  When *description* is given,
    shortDescription() returns it in place of the function's docstring.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super().__init__()
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        """Call the user-supplied set-up function, if one was given."""
        set_up = self._setUpFunc
        if set_up is not None:
            set_up()

    def tearDown(self):
        """Call the user-supplied tidy-up function, if one was given."""
        tear_down = self._tearDownFunc
        if tear_down is not None:
            tear_down()

    # No docstring on runTest: TestCase.__init__ records the test method's
    # __doc__, so adding one here would change that recorded value.
    def runTest(self):
        self._testFunc()

    def id(self):
        """Identify the test by the wrapped function's __name__."""
        return self._testFunc.__name__

    def __eq__(self, other):
        # Defer to the other operand unless it is an instance of our class.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __hash__(self):
        # Hash the concrete type plus exactly the fields __eq__ compares.
        identity = (type(self), self._setUpFunc, self._tearDownFunc,
                    self._testFunc, self._description)
        return hash(identity)

    def __str__(self):
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        class_name = strclass(self.__class__)
        return "<%s tec=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description if one was given.

        Otherwise return the stripped first line of the test function's
        docstring, or None when that is missing or empty.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        return doc.split("\n")[0].strip() or None
1396
1397    def shortDescription(self):
1398        if self._description is not None:
1399            return self._description
1400        doc = self._testFunc.__doc__
1401        return doc and doc.split("\n")[0].strip() or None
1402
1403
class _SubTest(TestCase):
    """TestCase stand-in representing one subtest of *test_case*.

    It is labelled by *message* (``_subtest_msg_sentinel`` meaning no
    message) and the *params* mapping.  It reuses the parent's
    failureException and cannot be run on its own.
    """

    def __init__(self, test_case, message, params):
        super().__init__()
        self._message = message
        self.test_case = test_case
        self.params = params
        self.failureException = test_case.failureException

    # No docstring on runTest: TestCase.__init__ records the test method's
    # __doc__, so adding one here would change that recorded value.
    def runTest(self):
        raise NotImplementedError("subtests cannot be run directly")

    def _subDescription(self):
        """Build the '[message] (key=value, ...)' label for this subtest.

        Falls back to '(<subtest>)' when there is neither a message nor
        params.
        """
        labels = []
        if self._message is not _subtest_msg_sentinel:
            labels.append("[{}]".format(self._message))
        if self.params:
            rendered = ', '.join("{}={!r}".format(name, value)
                                 for name, value in self.params.items())
            labels.append("({})".format(rendered))
        return " ".join(labels) or '(<subtest>)'

    def id(self):
        """Return the parent test's id followed by this subtest's label."""
        return "{} {}".format(self.test_case.id(), self._subDescription())

    def shortDescription(self):
        """Returns a one-line description of the subtest, or None if no
        description has been provided.

        The description is taken from the parent test case.
        """
        return self.test_case.shortDescription()

    def __str__(self):
        return "{} {}".format(self.test_case, self._subDescription())
1432        description has been provided.
1433        """
1434        return self.test_case.shortDescription()
1435
1436    def __str__(self):
1437        return "{} {}".format(self.test_case, self._subDescription())
1438