• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Test case implementation"""
2
3import collections
4import sys
5import functools
6import difflib
7import pprint
8import re
9import types
10import warnings
11
12from . import result
13from .util import (
14    strclass, safe_repr, unorderable_list_difference,
15    _count_diff_all_purpose, _count_diff_hashable
16)
17
18
# Module-level marker flag.  NOTE(review): nothing in this chunk reads it;
# presumably the result module checks frame globals for it to hide this
# module's frames from failure tracebacks -- confirm in result.py.
__unittest = True


# Template appended instead of an over-long diff; %s is the diff length.
DIFF_OMITTED = (
    '\nDiff is %s characters long.'
    ' Set self.maxDiff to None to see it.'
)
24
class SkipTest(Exception):
    """Signal that the currently running test should be skipped.

    TestCase.skipTest() and the skip/skipIf/skipUnless decorators raise
    this for you; raising it directly is the low-level form.  The
    exception's str() is used as the skip reason.
    """
33
34class _ExpectedFailure(Exception):
35    """
36    Raise this when a test is expected to fail.
37
38    This is an implementation detail.
39    """
40
41    def __init__(self, exc_info):
42        super(_ExpectedFailure, self).__init__()
43        self.exc_info = exc_info
44
45class _UnexpectedSuccess(Exception):
46    """
47    The test was supposed to fail, but it didn't!
48    """
49    pass
50
51def _id(obj):
52    return obj
53
def skip(reason):
    """
    Unconditionally skip a test.

    Returns a decorator.  A function or method is replaced by a wrapper
    (metadata copied via functools.wraps) that raises SkipTest(reason) when
    called.  A class (new- or old-style) is returned as-is.  Either way the
    returned object gets __unittest_skip__ = True and
    __unittest_skip_why__ = reason, which TestCase.run() inspects.
    """
    def decorator(test_item):
        if isinstance(test_item, (type, types.ClassType)):
            target = test_item
        else:
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            target = skip_wrapper

        target.__unittest_skip__ = True
        target.__unittest_skip_why__ = reason
        return target
    return decorator
69
def skipIf(condition, reason):
    """
    Skip a test if the condition is true.

    Returns skip(reason) when *condition* is truthy, otherwise the no-op
    decorator _id.
    """
    return skip(reason) if condition else _id
77
def skipUnless(condition, reason):
    """
    Skip a test unless the condition is true.

    Returns the no-op decorator _id when *condition* is truthy, otherwise
    skip(reason).
    """
    return _id if condition else skip(reason)
85
86
def expectedFailure(func):
    """Mark a test method as expected to fail.

    The returned wrapper calls *func*.  Any Exception it raises is turned
    into _ExpectedFailure carrying the captured sys.exc_info().  A normal
    return raises _UnexpectedSuccess instead.  TestCase.run() maps these to
    addExpectedFailure() / addUnexpectedSuccess() on the result.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
96
97
98class _AssertRaisesContext(object):
99    """A context manager used to implement TestCase.assertRaises* methods."""
100
101    def __init__(self, expected, test_case, expected_regexp=None):
102        self.expected = expected
103        self.failureException = test_case.failureException
104        self.expected_regexp = expected_regexp
105
106    def __enter__(self):
107        return self
108
109    def __exit__(self, exc_type, exc_value, tb):
110        if exc_type is None:
111            try:
112                exc_name = self.expected.__name__
113            except AttributeError:
114                exc_name = str(self.expected)
115            raise self.failureException(
116                "{0} not raised".format(exc_name))
117        if not issubclass(exc_type, self.expected):
118            # let unexpected exceptions pass through
119            return False
120        self.exception = exc_value # store for later retrieval
121        if self.expected_regexp is None:
122            return True
123
124        expected_regexp = self.expected_regexp
125        if not expected_regexp.search(str(exc_value)):
126            raise self.failureException('"%s" does not match "%s"' %
127                     (expected_regexp.pattern, str(exc_value)))
128        return True
129
130
class TestCase(object):
    """A class whose instances are single test cases.

    By default, the test code itself should be placed in a method named
    'runTest'.

    If the fixture may be used for many test cases, create as
    many test methods as are needed. When instantiating such a TestCase
    subclass, specify in the constructor arguments the name of the test method
    that the instance is to execute.

    Test authors should subclass TestCase for their own tests. Construction
    and teardown of the test's environment ('fixture') can be
    implemented by overriding the 'setUp' and 'tearDown' methods respectively.

    If it is necessary to override the __init__ method, the base class
    __init__ method must always be called. It is important that subclasses
    should not change the signature of their __init__ method, since instances
    of the classes are instantiated automatically by parts of the framework
    in order to be run.

    When subclassing TestCase, you can set these attributes:
    * failureException: determines which exception will be raised when
        the instance's assertion methods fail; test methods raising this
        exception will be deemed to have 'failed' rather than 'errored'.
    * longMessage: determines whether long messages (including repr of
        objects used in assert methods) will be printed on failure in *addition*
        to any explicit message passed.
    * maxDiff: sets the maximum length of a diff in failure messages
        by assert methods using difflib. It is looked up as an instance
        attribute so can be configured by individual tests if required.
        Set it to None to never truncate diffs.
    """

    # Raised by the assert*/fail helpers; run() reports it via addFailure()
    # rather than addError().
    failureException = AssertionError

    # When True, _formatMessage() appends an explicit msg to the standard
    # message instead of letting it replace the standard message.
    longMessage = False

    # Longest diff (in characters) kept in failure messages by
    # _truncateMessage(); longer diffs are replaced by DIFF_OMITTED.
    maxDiff = 80*8

    # If a string is longer than _diffThreshold, use normal comparison instead
    # of difflib.  See #11763.
    _diffThreshold = 2**16

    # Attribute used by TestSuite for classSetUp

    _classSetupFailed = False
177
178    def __init__(self, methodName='runTest'):
179        """Create an instance of the class that will use the named test
180           method when executed. Raises a ValueError if the instance does
181           not have a method with the specified name.
182        """
183        self._testMethodName = methodName
184        self._resultForDoCleanups = None
185        try:
186            testMethod = getattr(self, methodName)
187        except AttributeError:
188            raise ValueError("no such test method in %s: %s" %
189                  (self.__class__, methodName))
190        self._testMethodDoc = testMethod.__doc__
191        self._cleanups = []
192
193        # Map types to custom assertEqual functions that will compare
194        # instances of said type in more detail to generate a more useful
195        # error message.
196        self._type_equality_funcs = {}
197        self.addTypeEqualityFunc(dict, 'assertDictEqual')
198        self.addTypeEqualityFunc(list, 'assertListEqual')
199        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
200        self.addTypeEqualityFunc(set, 'assertSetEqual')
201        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
202        try:
203            self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
204        except NameError:
205            # No unicode support in this build
206            pass
207
208    def addTypeEqualityFunc(self, typeobj, function):
209        """Add a type specific assertEqual style function to compare a type.
210
211        This method is for use by TestCase subclasses that need to register
212        their own type equality functions to provide nicer error messages.
213
214        Args:
215            typeobj: The data type to call this function on when both values
216                    are of the same type in assertEqual().
217            function: The callable taking two arguments and an optional
218                    msg= argument that raises self.failureException with a
219                    useful error message when the two arguments are not equal.
220        """
221        self._type_equality_funcs[typeobj] = function
222
223    def addCleanup(self, function, *args, **kwargs):
224        """Add a function, with arguments, to be called when the test is
225        completed. Functions added are called on a LIFO basis and are
226        called after tearDown on test failure or success.
227
228        Cleanup items are called even if setUp fails (unlike tearDown)."""
229        self._cleanups.append((function, args, kwargs))
230
231    def setUp(self):
232        "Hook method for setting up the test fixture before exercising it."
233        pass
234
235    def tearDown(self):
236        "Hook method for deconstructing the test fixture after testing it."
237        pass
238
239    @classmethod
240    def setUpClass(cls):
241        "Hook method for setting up class fixture before running tests in the class."
242
243    @classmethod
244    def tearDownClass(cls):
245        "Hook method for deconstructing the class fixture after running all tests in the class."
246
247    def countTestCases(self):
248        return 1
249
250    def defaultTestResult(self):
251        return result.TestResult()
252
253    def shortDescription(self):
254        """Returns a one-line description of the test, or None if no
255        description has been provided.
256
257        The default implementation of this method returns the first line of
258        the specified test method's docstring.
259        """
260        doc = self._testMethodDoc
261        return doc and doc.split("\n")[0].strip() or None
262
263
264    def id(self):
265        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
266
267    def __eq__(self, other):
268        if type(self) is not type(other):
269            return NotImplemented
270
271        return self._testMethodName == other._testMethodName
272
273    def __ne__(self, other):
274        return not self == other
275
276    def __hash__(self):
277        return hash((type(self), self._testMethodName))
278
279    def __str__(self):
280        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
281
282    def __repr__(self):
283        return "<%s testMethod=%s>" % \
284               (strclass(self.__class__), self._testMethodName)
285
286    def _addSkip(self, result, reason):
287        addSkip = getattr(result, 'addSkip', None)
288        if addSkip is not None:
289            addSkip(self, reason)
290        else:
291            warnings.warn("TestResult has no addSkip method, skips not reported",
292                          RuntimeWarning, 2)
293            result.addSuccess(self)
294
295    def run(self, result=None):
296        orig_result = result
297        if result is None:
298            result = self.defaultTestResult()
299            startTestRun = getattr(result, 'startTestRun', None)
300            if startTestRun is not None:
301                startTestRun()
302
303        self._resultForDoCleanups = result
304        result.startTest(self)
305
306        testMethod = getattr(self, self._testMethodName)
307        if (getattr(self.__class__, "__unittest_skip__", False) or
308            getattr(testMethod, "__unittest_skip__", False)):
309            # If the class or method was skipped.
310            try:
311                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
312                            or getattr(testMethod, '__unittest_skip_why__', ''))
313                self._addSkip(result, skip_why)
314            finally:
315                result.stopTest(self)
316            return
317        try:
318            success = False
319            try:
320                self.setUp()
321            except SkipTest as e:
322                self._addSkip(result, str(e))
323            except KeyboardInterrupt:
324                raise
325            except:
326                result.addError(self, sys.exc_info())
327            else:
328                try:
329                    testMethod()
330                except KeyboardInterrupt:
331                    raise
332                except self.failureException:
333                    result.addFailure(self, sys.exc_info())
334                except _ExpectedFailure as e:
335                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
336                    if addExpectedFailure is not None:
337                        addExpectedFailure(self, e.exc_info)
338                    else:
339                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
340                                      RuntimeWarning)
341                        result.addSuccess(self)
342                except _UnexpectedSuccess:
343                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
344                    if addUnexpectedSuccess is not None:
345                        addUnexpectedSuccess(self)
346                    else:
347                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
348                                      RuntimeWarning)
349                        result.addFailure(self, sys.exc_info())
350                except SkipTest as e:
351                    self._addSkip(result, str(e))
352                except:
353                    result.addError(self, sys.exc_info())
354                else:
355                    success = True
356
357                try:
358                    self.tearDown()
359                except KeyboardInterrupt:
360                    raise
361                except:
362                    result.addError(self, sys.exc_info())
363                    success = False
364
365            cleanUpSuccess = self.doCleanups()
366            success = success and cleanUpSuccess
367            if success:
368                result.addSuccess(self)
369        finally:
370            result.stopTest(self)
371            if orig_result is None:
372                stopTestRun = getattr(result, 'stopTestRun', None)
373                if stopTestRun is not None:
374                    stopTestRun()
375
376    def doCleanups(self):
377        """Execute all cleanup functions. Normally called for you after
378        tearDown."""
379        result = self._resultForDoCleanups
380        ok = True
381        while self._cleanups:
382            function, args, kwargs = self._cleanups.pop(-1)
383            try:
384                function(*args, **kwargs)
385            except KeyboardInterrupt:
386                raise
387            except:
388                ok = False
389                result.addError(self, sys.exc_info())
390        return ok
391
392    def __call__(self, *args, **kwds):
393        return self.run(*args, **kwds)
394
395    def debug(self):
396        """Run the test without collecting errors in a TestResult"""
397        self.setUp()
398        getattr(self, self._testMethodName)()
399        self.tearDown()
400        while self._cleanups:
401            function, args, kwargs = self._cleanups.pop(-1)
402            function(*args, **kwargs)
403
404    def skipTest(self, reason):
405        """Skip this test."""
406        raise SkipTest(reason)
407
408    def fail(self, msg=None):
409        """Fail immediately, with the given message."""
410        raise self.failureException(msg)
411
412    def assertFalse(self, expr, msg=None):
413        """Check that the expression is false."""
414        if expr:
415            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
416            raise self.failureException(msg)
417
418    def assertTrue(self, expr, msg=None):
419        """Check that the expression is true."""
420        if not expr:
421            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
422            raise self.failureException(msg)
423
424    def _formatMessage(self, msg, standardMsg):
425        """Honour the longMessage attribute when generating failure messages.
426        If longMessage is False this means:
427        * Use only an explicit message if it is provided
428        * Otherwise use the standard message for the assert
429
430        If longMessage is True:
431        * Use the standard message
432        * If an explicit message is provided, plus ' : ' and the explicit message
433        """
434        if not self.longMessage:
435            return msg or standardMsg
436        if msg is None:
437            return standardMsg
438        try:
439            # don't switch to '{}' formatting in Python 2.X
440            # it changes the way unicode input is handled
441            return '%s : %s' % (standardMsg, msg)
442        except UnicodeDecodeError:
443            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
444
445
446    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
447        """Fail unless an exception of class excClass is raised
448           by callableObj when invoked with arguments args and keyword
449           arguments kwargs. If a different type of exception is
450           raised, it will not be caught, and the test case will be
451           deemed to have suffered an error, exactly as for an
452           unexpected exception.
453
454           If called with callableObj omitted or None, will return a
455           context object used like this::
456
457                with self.assertRaises(SomeException):
458                    do_something()
459
460           The context manager keeps a reference to the exception as
461           the 'exception' attribute. This allows you to inspect the
462           exception after the assertion::
463
464               with self.assertRaises(SomeException) as cm:
465                   do_something()
466               the_exception = cm.exception
467               self.assertEqual(the_exception.error_code, 3)
468        """
469        context = _AssertRaisesContext(excClass, self)
470        if callableObj is None:
471            return context
472        with context:
473            callableObj(*args, **kwargs)
474
475    def _getAssertEqualityFunc(self, first, second):
476        """Get a detailed comparison function for the types of the two args.
477
478        Returns: A callable accepting (first, second, msg=None) that will
479        raise a failure exception if first != second with a useful human
480        readable error message for those types.
481        """
482        #
483        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
484        # and vice versa.  I opted for the conservative approach in case
485        # subclasses are not intended to be compared in detail to their super
486        # class instances using a type equality func.  This means testing
487        # subtypes won't automagically use the detailed comparison.  Callers
488        # should use their type specific assertSpamEqual method to compare
489        # subclasses if the detailed comparison is desired and appropriate.
490        # See the discussion in http://bugs.python.org/issue2578.
491        #
492        if type(first) is type(second):
493            asserter = self._type_equality_funcs.get(type(first))
494            if asserter is not None:
495                if isinstance(asserter, basestring):
496                    asserter = getattr(self, asserter)
497                return asserter
498
499        return self._baseAssertEqual
500
501    def _baseAssertEqual(self, first, second, msg=None):
502        """The default assertEqual implementation, not type specific."""
503        if not first == second:
504            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
505            msg = self._formatMessage(msg, standardMsg)
506            raise self.failureException(msg)
507
508    def assertEqual(self, first, second, msg=None):
509        """Fail if the two objects are unequal as determined by the '=='
510           operator.
511        """
512        assertion_func = self._getAssertEqualityFunc(first, second)
513        assertion_func(first, second, msg=msg)
514
515    def assertNotEqual(self, first, second, msg=None):
516        """Fail if the two objects are equal as determined by the '!='
517           operator.
518        """
519        if not first != second:
520            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
521                                                          safe_repr(second)))
522            raise self.failureException(msg)
523
524
525    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
526        """Fail if the two objects are unequal as determined by their
527           difference rounded to the given number of decimal places
528           (default 7) and comparing to zero, or by comparing that the
529           between the two objects is more than the given delta.
530
531           Note that decimal places (from zero) are usually not the same
532           as significant digits (measured from the most significant digit).
533
534           If the two objects compare equal then they will automatically
535           compare almost equal.
536        """
537        if first == second:
538            # shortcut
539            return
540        if delta is not None and places is not None:
541            raise TypeError("specify delta or places not both")
542
543        if delta is not None:
544            if abs(first - second) <= delta:
545                return
546
547            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
548                                                        safe_repr(second),
549                                                        safe_repr(delta))
550        else:
551            if places is None:
552                places = 7
553
554            if round(abs(second-first), places) == 0:
555                return
556
557            standardMsg = '%s != %s within %r places' % (safe_repr(first),
558                                                          safe_repr(second),
559                                                          places)
560        msg = self._formatMessage(msg, standardMsg)
561        raise self.failureException(msg)
562
563    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
564        """Fail if the two objects are equal as determined by their
565           difference rounded to the given number of decimal places
566           (default 7) and comparing to zero, or by comparing that the
567           between the two objects is less than the given delta.
568
569           Note that decimal places (from zero) are usually not the same
570           as significant digits (measured from the most significant digit).
571
572           Objects that are equal automatically fail.
573        """
574        if delta is not None and places is not None:
575            raise TypeError("specify delta or places not both")
576        if delta is not None:
577            if not (first == second) and abs(first - second) > delta:
578                return
579            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
580                                                        safe_repr(second),
581                                                        safe_repr(delta))
582        else:
583            if places is None:
584                places = 7
585            if not (first == second) and round(abs(second-first), places) != 0:
586                return
587            standardMsg = '%s == %s within %r places' % (safe_repr(first),
588                                                         safe_repr(second),
589                                                         places)
590
591        msg = self._formatMessage(msg, standardMsg)
592        raise self.failureException(msg)
593
594    # Synonyms for assertion methods
595
596    # The plurals are undocumented.  Keep them that way to discourage use.
597    # Do not add more.  Do not remove.
598    # Going through a deprecation cycle on these would annoy many people.
599    assertEquals = assertEqual
600    assertNotEquals = assertNotEqual
601    assertAlmostEquals = assertAlmostEqual
602    assertNotAlmostEquals = assertNotAlmostEqual
603    assert_ = assertTrue
604
605    # These fail* assertion method names are pending deprecation and will
606    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
607    def _deprecate(original_func):
608        def deprecated_func(*args, **kwargs):
609            warnings.warn(
610                'Please use {0} instead.'.format(original_func.__name__),
611                PendingDeprecationWarning, 2)
612            return original_func(*args, **kwargs)
613        return deprecated_func
614
615    failUnlessEqual = _deprecate(assertEqual)
616    failIfEqual = _deprecate(assertNotEqual)
617    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
618    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
619    failUnless = _deprecate(assertTrue)
620    failUnlessRaises = _deprecate(assertRaises)
621    failIf = _deprecate(assertFalse)
622
    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message to use on failure instead of a list of
                    differences (combined per longMessage by _formatMessage).

        Raises:
            self.failureException: if seq_type is given and either argument
                is not an instance of it; if either argument has no len();
                or if the sequences compare unequal.  The message describes
                the first differing index and any surplus elements, and is
                followed by an ndiff of the pprint.pformat() output, subject
                to maxDiff truncation (_truncateMessage).

        Note: when seq_type is None, sequences of different types whose
        elements are equal one-for-one and whose lengths match are accepted
        as equal, even if '==' between them is False.
        """
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # `differing` is both the failure message under construction and a
        # flag: once it is set, the later checks are skipped and control
        # falls through to the failure at the bottom.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            # Headline uses reprs clipped to 30 characters; the full contents
            # appear in the ndiff appended below.
            seq1_repr = safe_repr(seq1)
            seq2_repr = safe_repr(seq2)
            if len(seq1_repr) > 30:
                seq1_repr = seq1_repr[:30] + '...'
            if len(seq2_repr) > 30:
                seq2_repr = seq2_repr[:30] + '...'
            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
            differing = '%ss differ: %s != %s\n' % elements

            # Walk the common prefix and describe the first index at which
            # indexing fails or the elements differ.
            for i in xrange(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 (i, safe_repr(item1), safe_repr(item2)))
                    break
            else:
                # Loop finished without a break: the common prefix matched.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe surplus elements of the longer sequence.  "First extra
            # element" means the first of the surplus elements, for either side.
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, safe_repr(seq1[len2])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, safe_repr(seq2[len1])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        # Full line-by-line diff of the pretty-printed sequences, possibly
        # replaced by a length notice when it exceeds maxDiff.
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))
        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
725
726    def _truncateMessage(self, message, diff):
727        max_diff = self.maxDiff
728        if max_diff is None or len(diff) <= max_diff:
729            return message + diff
730        return message + (DIFF_OMITTED % len(diff))
731
732    def assertListEqual(self, list1, list2, msg=None):
733        """A list-specific equality assertion.
734
735        Args:
736            list1: The first list to compare.
737            list2: The second list to compare.
738            msg: Optional message to use on failure instead of a list of
739                    differences.
740
741        """
742        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
743
744    def assertTupleEqual(self, tuple1, tuple2, msg=None):
745        """A tuple-specific equality assertion.
746
747        Args:
748            tuple1: The first tuple to compare.
749            tuple2: The second tuple to compare.
750            msg: Optional message to use on failure instead of a list of
751                    differences.
752        """
753        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
754
755    def assertSetEqual(self, set1, set2, msg=None):
756        """A set-specific equality assertion.
757
758        Args:
759            set1: The first set to compare.
760            set2: The second set to compare.
761            msg: Optional message to use on failure instead of a list of
762                    differences.
763
764        assertSetEqual uses ducktyping to support different types of sets, and
765        is optimized for sets specifically (parameters must support a
766        difference method).
767        """
768        try:
769            difference1 = set1.difference(set2)
770        except TypeError, e:
771            self.fail('invalid type when attempting set difference: %s' % e)
772        except AttributeError, e:
773            self.fail('first argument does not support set difference: %s' % e)
774
775        try:
776            difference2 = set2.difference(set1)
777        except TypeError, e:
778            self.fail('invalid type when attempting set difference: %s' % e)
779        except AttributeError, e:
780            self.fail('second argument does not support set difference: %s' % e)
781
782        if not (difference1 or difference2):
783            return
784
785        lines = []
786        if difference1:
787            lines.append('Items in the first set but not the second:')
788            for item in difference1:
789                lines.append(repr(item))
790        if difference2:
791            lines.append('Items in the second set but not the first:')
792            for item in difference2:
793                lines.append(repr(item))
794
795        standardMsg = '\n'.join(lines)
796        self.fail(self._formatMessage(msg, standardMsg))
797
798    def assertIn(self, member, container, msg=None):
799        """Just like self.assertTrue(a in b), but with a nicer default message."""
800        if member not in container:
801            standardMsg = '%s not found in %s' % (safe_repr(member),
802                                                  safe_repr(container))
803            self.fail(self._formatMessage(msg, standardMsg))
804
805    def assertNotIn(self, member, container, msg=None):
806        """Just like self.assertTrue(a not in b), but with a nicer default message."""
807        if member in container:
808            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
809                                                        safe_repr(container))
810            self.fail(self._formatMessage(msg, standardMsg))
811
812    def assertIs(self, expr1, expr2, msg=None):
813        """Just like self.assertTrue(a is b), but with a nicer default message."""
814        if expr1 is not expr2:
815            standardMsg = '%s is not %s' % (safe_repr(expr1),
816                                             safe_repr(expr2))
817            self.fail(self._formatMessage(msg, standardMsg))
818
819    def assertIsNot(self, expr1, expr2, msg=None):
820        """Just like self.assertTrue(a is not b), but with a nicer default message."""
821        if expr1 is expr2:
822            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
823            self.fail(self._formatMessage(msg, standardMsg))
824
825    def assertDictEqual(self, d1, d2, msg=None):
826        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
827        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
828
829        if d1 != d2:
830            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
831            diff = ('\n' + '\n'.join(difflib.ndiff(
832                           pprint.pformat(d1).splitlines(),
833                           pprint.pformat(d2).splitlines())))
834            standardMsg = self._truncateMessage(standardMsg, diff)
835            self.fail(self._formatMessage(msg, standardMsg))
836
837    def assertDictContainsSubset(self, expected, actual, msg=None):
838        """Checks whether actual is a superset of expected."""
839        missing = []
840        mismatched = []
841        for key, value in expected.iteritems():
842            if key not in actual:
843                missing.append(key)
844            elif value != actual[key]:
845                mismatched.append('%s, expected: %s, actual: %s' %
846                                  (safe_repr(key), safe_repr(value),
847                                   safe_repr(actual[key])))
848
849        if not (missing or mismatched):
850            return
851
852        standardMsg = ''
853        if missing:
854            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
855                                                    missing)
856        if mismatched:
857            if standardMsg:
858                standardMsg += '; '
859            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
860
861        self.fail(self._formatMessage(msg, standardMsg))
862
863    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
864        """An unordered sequence specific comparison. It asserts that
865        actual_seq and expected_seq have the same element counts.
866        Equivalent to::
867
868            self.assertEqual(Counter(iter(actual_seq)),
869                             Counter(iter(expected_seq)))
870
871        Asserts that each element has the same count in both sequences.
872        Example:
873            - [0, 1, 1] and [1, 0, 1] compare equal.
874            - [0, 0, 1] and [0, 1] compare unequal.
875        """
876        first_seq, second_seq = list(expected_seq), list(actual_seq)
877        with warnings.catch_warnings():
878            if sys.py3kwarning:
879                # Silence Py3k warning raised during the sorting
880                for _msg in ["(code|dict|type) inequality comparisons",
881                             "builtin_function_or_method order comparisons",
882                             "comparing unequal types"]:
883                    warnings.filterwarnings("ignore", _msg, DeprecationWarning)
884            try:
885                first = collections.Counter(first_seq)
886                second = collections.Counter(second_seq)
887            except TypeError:
888                # Handle case with unhashable elements
889                differences = _count_diff_all_purpose(first_seq, second_seq)
890            else:
891                if first == second:
892                    return
893                differences = _count_diff_hashable(first_seq, second_seq)
894
895        if differences:
896            standardMsg = 'Element counts were not equal:\n'
897            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
898            diffMsg = '\n'.join(lines)
899            standardMsg = self._truncateMessage(standardMsg, diffMsg)
900            msg = self._formatMessage(msg, standardMsg)
901            self.fail(msg)
902
903    def assertMultiLineEqual(self, first, second, msg=None):
904        """Assert that two multi-line strings are equal."""
905        self.assertIsInstance(first, basestring,
906                'First argument is not a string')
907        self.assertIsInstance(second, basestring,
908                'Second argument is not a string')
909
910        if first != second:
911            # don't use difflib if the strings are too long
912            if (len(first) > self._diffThreshold or
913                len(second) > self._diffThreshold):
914                self._baseAssertEqual(first, second, msg)
915            firstlines = first.splitlines(True)
916            secondlines = second.splitlines(True)
917            if len(firstlines) == 1 and first.strip('\r\n') == first:
918                firstlines = [first + '\n']
919                secondlines = [second + '\n']
920            standardMsg = '%s != %s' % (safe_repr(first, True),
921                                        safe_repr(second, True))
922            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
923            standardMsg = self._truncateMessage(standardMsg, diff)
924            self.fail(self._formatMessage(msg, standardMsg))
925
926    def assertLess(self, a, b, msg=None):
927        """Just like self.assertTrue(a < b), but with a nicer default message."""
928        if not a < b:
929            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
930            self.fail(self._formatMessage(msg, standardMsg))
931
932    def assertLessEqual(self, a, b, msg=None):
933        """Just like self.assertTrue(a <= b), but with a nicer default message."""
934        if not a <= b:
935            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
936            self.fail(self._formatMessage(msg, standardMsg))
937
938    def assertGreater(self, a, b, msg=None):
939        """Just like self.assertTrue(a > b), but with a nicer default message."""
940        if not a > b:
941            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
942            self.fail(self._formatMessage(msg, standardMsg))
943
944    def assertGreaterEqual(self, a, b, msg=None):
945        """Just like self.assertTrue(a >= b), but with a nicer default message."""
946        if not a >= b:
947            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
948            self.fail(self._formatMessage(msg, standardMsg))
949
950    def assertIsNone(self, obj, msg=None):
951        """Same as self.assertTrue(obj is None), with a nicer default message."""
952        if obj is not None:
953            standardMsg = '%s is not None' % (safe_repr(obj),)
954            self.fail(self._formatMessage(msg, standardMsg))
955
956    def assertIsNotNone(self, obj, msg=None):
957        """Included for symmetry with assertIsNone."""
958        if obj is None:
959            standardMsg = 'unexpectedly None'
960            self.fail(self._formatMessage(msg, standardMsg))
961
962    def assertIsInstance(self, obj, cls, msg=None):
963        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
964        default message."""
965        if not isinstance(obj, cls):
966            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
967            self.fail(self._formatMessage(msg, standardMsg))
968
969    def assertNotIsInstance(self, obj, cls, msg=None):
970        """Included for symmetry with assertIsInstance."""
971        if isinstance(obj, cls):
972            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
973            self.fail(self._formatMessage(msg, standardMsg))
974
975    def assertRaisesRegexp(self, expected_exception, expected_regexp,
976                           callable_obj=None, *args, **kwargs):
977        """Asserts that the message in a raised exception matches a regexp.
978
979        Args:
980            expected_exception: Exception class expected to be raised.
981            expected_regexp: Regexp (re pattern object or string) expected
982                    to be found in error message.
983            callable_obj: Function to be called.
984            args: Extra args.
985            kwargs: Extra kwargs.
986        """
987        if expected_regexp is not None:
988            expected_regexp = re.compile(expected_regexp)
989        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
990        if callable_obj is None:
991            return context
992        with context:
993            callable_obj(*args, **kwargs)
994
995    def assertRegexpMatches(self, text, expected_regexp, msg=None):
996        """Fail the test unless the text matches the regular expression."""
997        if isinstance(expected_regexp, basestring):
998            expected_regexp = re.compile(expected_regexp)
999        if not expected_regexp.search(text):
1000            msg = msg or "Regexp didn't match"
1001            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1002            raise self.failureException(msg)
1003
1004    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1005        """Fail the test if the text matches the regular expression."""
1006        if isinstance(unexpected_regexp, basestring):
1007            unexpected_regexp = re.compile(unexpected_regexp)
1008        match = unexpected_regexp.search(text)
1009        if match:
1010            msg = msg or "Regexp matched"
1011            msg = '%s: %r matches %r in %r' % (msg,
1012                                               text[match.start():match.end()],
1013                                               unexpected_regexp.pattern,
1014                                               text)
1015            raise self.failureException(msg)
1016
1017
class FunctionTestCase(TestCase):
    """Adapt a plain test function to the TestCase interface.

    This lets pre-existing test functions run inside the unittest framework.
    Optional set-up and tidy-up callables may be supplied. As with TestCase,
    the tidy-up ('tearDown') function will always be called if the set-up
    ('setUp') function ran successfully.

    Args:
        testFunc: Callable invoked with no arguments as the test body.
        setUp: Optional callable invoked with no arguments by setUp().
        tearDown: Optional callable invoked with no arguments by tearDown().
        description: Optional text returned by shortDescription().
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        """Run the user-supplied set-up callable, if one was given."""
        hook = self._setUpFunc
        if hook is None:
            return
        hook()

    def tearDown(self):
        """Run the user-supplied tidy-up callable, if one was given."""
        hook = self._tearDownFunc
        if hook is None:
            return
        hook()

    def runTest(self):
        """Invoke the wrapped test function."""
        self._testFunc()

    def id(self):
        """Identify the test by the wrapped function's __name__."""
        return self._testFunc.__name__

    def __eq__(self, other):
        # Instances of other classes defer to the other operand.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Hash the same fields __eq__ compares, plus the concrete type.
        key = (type(self), self._setUpFunc, self._tearDownFunc,
               self._testFunc, self._description)
        return hash(key)

    def __str__(self):
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        class_name = strclass(self.__class__)
        return "<%s tec=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the docstring's first line.

        Returns None when no description was given and the wrapped function
        has no docstring, or its first line is blank.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        summary = doc.split("\n")[0].strip()
        return summary or None
1077