• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Test case implementation"""
2
3import collections
4import sys
5import functools
6import difflib
7import pprint
8import re
9import warnings
10
11from . import result
12from .util import (
13    strclass, safe_repr, unorderable_list_difference,
14    _count_diff_all_purpose, _count_diff_hashable
15)
16
17
# Module-level marker flag.  It is not read anywhere in this chunk; presumably
# other parts of the package look for it on frames -- TODO confirm.
__unittest = True


# Note appended by TestCase._truncateMessage in place of an over-long diff;
# the single %s receives the diff length in characters.
DIFF_OMITTED = (
    '\nDiff is %s characters long. Set self.maxDiff to None to see it.'
)
23
class SkipTest(Exception):
    """Signal that the current test should be skipped.

    When raised from setUp or from the test method, TestCase.run reports the
    test as skipped, using str() of the exception as the reason.  The skip
    decorators in this module, or TestCase.skipTest(), are usually more
    convenient than raising it directly.
    """
32
class _ExpectedFailure(Exception):
    """Internal signal: a test marked with expectedFailure did fail.

    ``exc_info`` holds the sys.exc_info() triple of the original error so
    TestCase.run can pass it on to the result.  Implementation detail.
    """

    def __init__(self, exc_info):
        self.exc_info = exc_info
        super(_ExpectedFailure, self).__init__()
43
class _UnexpectedSuccess(Exception):
    """Internal signal: a test marked with expectedFailure passed anyway."""
49
def _id(obj):
    """No-op decorator: hand *obj* back unchanged."""
    return obj
52
def skip(reason):
    """Return a decorator that unconditionally skips a test.

    A TestCase subclass is only tagged with the skip markers.  Anything else
    (normally a test method) is replaced by a wrapper that raises
    SkipTest(reason) when called, and the wrapper is tagged instead.
    """
    def decorator(test_item):
        is_test_class = (isinstance(test_item, type) and
                         issubclass(test_item, TestCase))
        if is_test_class:
            target = test_item
        else:
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            target = skip_wrapper
        target.__unittest_skip__ = True
        target.__unittest_skip_why__ = reason
        return target
    return decorator
68
def skipIf(condition, reason):
    """Skip the decorated test when *condition* is true; otherwise a no-op."""
    return skip(reason) if condition else _id
76
def skipUnless(condition, reason):
    """Skip the decorated test unless *condition* is true."""
    return _id if condition else skip(reason)
84
85
def expectedFailure(func):
    """Mark a test as expected to fail.

    The wrapper turns any Exception raised by *func* into _ExpectedFailure
    (carrying the original exc_info).  A normal return becomes
    _UnexpectedSuccess.  TestCase.run reports each case accordingly.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
95
96
class _AssertRaisesContext(object):
    """Context manager behind TestCase.assertRaises.

    On exit it checks that an exception of type ``expected`` was raised.
    If ``expected_regexp`` is given, it also checks that the pattern is found
    in str() of the exception.  A matching exception is stored on
    ``self.exception`` and suppressed.  Exceptions of other types propagate.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        self.expected = expected
        self.expected_regexp = expected_regexp
        self.failureException = test_case.failureException

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            # Nothing raised: name what we were waiting for, then fail.
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                # e.g. a tuple of exception classes has no __name__
                exc_name = str(self.expected)
            raise self.failureException("{0} not raised".format(exc_name))

        if not issubclass(exc_type, self.expected):
            return False  # not ours: let it propagate

        self.exception = exc_value  # kept for inspection after the with-block

        pattern = self.expected_regexp
        if pattern is None:
            return True
        if isinstance(pattern, basestring):
            pattern = re.compile(pattern)
        text = str(exc_value)
        if pattern.search(text):
            return True
        raise self.failureException('"%s" does not match "%s"' %
                                    (pattern.pattern, text))
130
131
132class TestCase(object):
    """A class whose instances are single test cases.

    By default, the test code itself should be placed in a method named
    'runTest'.

    If the fixture may be used for many test cases, create as
    many test methods as are needed. When instantiating such a TestCase
    subclass, specify in the constructor arguments the name of the test method
    that the instance is to execute.

    Test authors should subclass TestCase for their own tests. Construction
    and deconstruction of the test's environment ('fixture') can be
    implemented by overriding the 'setUp' and 'tearDown' methods respectively.

    If it is necessary to override the __init__ method, the base class
    __init__ method must always be called. It is important that subclasses
    should not change the signature of their __init__ method, since instances
    of the classes are instantiated automatically by parts of the framework
    in order to be run.
    """

    # This attribute determines which exception will be raised when
    # the instance's assertion methods fail; test methods raising this
    # exception will be deemed to have 'failed' rather than 'errored'
    # (run() reports it via result.addFailure instead of result.addError).

    failureException = AssertionError

    # This attribute determines whether long messages (including repr of
    # objects used in assert methods) will be printed on failure in *addition*
    # to any explicit message passed.  Read by _formatMessage: when False an
    # explicit msg replaces the standard message; when True the result is
    # "<standard message> : <msg>".

    longMessage = False

    # This attribute sets the maximum length of a diff in failure messages
    # by assert methods using difflib. It is looked up as an instance attribute
    # so can be configured by individual tests if required.  _truncateMessage
    # replaces a longer diff with the DIFF_OMITTED note; None disables the
    # limit.  80*8 == 640 characters.

    maxDiff = 80*8

    # If a string is longer than _diffThreshold, use normal comparison instead
    # of difflib.  See #11763.  (Not read by any method visible in this chunk;
    # presumably used by the string-comparison asserts further down.)
    _diffThreshold = 2**16

    # Attribute used by TestSuite for classSetUp
    # (presumably flipped to True when setUpClass fails -- confirm in suite.py)

    _classSetupFailed = False
179
180    def __init__(self, methodName='runTest'):
181        """Create an instance of the class that will use the named test
182           method when executed. Raises a ValueError if the instance does
183           not have a method with the specified name.
184        """
185        self._testMethodName = methodName
186        self._resultForDoCleanups = None
187        try:
188            testMethod = getattr(self, methodName)
189        except AttributeError:
190            raise ValueError("no such test method in %s: %s" %
191                  (self.__class__, methodName))
192        self._testMethodDoc = testMethod.__doc__
193        self._cleanups = []
194
195        # Map types to custom assertEqual functions that will compare
196        # instances of said type in more detail to generate a more useful
197        # error message.
198        self._type_equality_funcs = {}
199        self.addTypeEqualityFunc(dict, self.assertDictEqual)
200        self.addTypeEqualityFunc(list, self.assertListEqual)
201        self.addTypeEqualityFunc(tuple, self.assertTupleEqual)
202        self.addTypeEqualityFunc(set, self.assertSetEqual)
203        self.addTypeEqualityFunc(frozenset, self.assertSetEqual)
204        self.addTypeEqualityFunc(unicode, self.assertMultiLineEqual)
205
206    def addTypeEqualityFunc(self, typeobj, function):
207        """Add a type specific assertEqual style function to compare a type.
208
209        This method is for use by TestCase subclasses that need to register
210        their own type equality functions to provide nicer error messages.
211
212        Args:
213            typeobj: The data type to call this function on when both values
214                    are of the same type in assertEqual().
215            function: The callable taking two arguments and an optional
216                    msg= argument that raises self.failureException with a
217                    useful error message when the two arguments are not equal.
218        """
219        self._type_equality_funcs[typeobj] = function
220
221    def addCleanup(self, function, *args, **kwargs):
222        """Add a function, with arguments, to be called when the test is
223        completed. Functions added are called on a LIFO basis and are
224        called after tearDown on test failure or success.
225
226        Cleanup items are called even if setUp fails (unlike tearDown)."""
227        self._cleanups.append((function, args, kwargs))
228
229    def setUp(self):
230        "Hook method for setting up the test fixture before exercising it."
231        pass
232
233    def tearDown(self):
234        "Hook method for deconstructing the test fixture after testing it."
235        pass
236
237    @classmethod
238    def setUpClass(cls):
239        "Hook method for setting up class fixture before running tests in the class."
240
241    @classmethod
242    def tearDownClass(cls):
243        "Hook method for deconstructing the class fixture after running all tests in the class."
244
245    def countTestCases(self):
246        return 1
247
248    def defaultTestResult(self):
249        return result.TestResult()
250
251    def shortDescription(self):
252        """Returns a one-line description of the test, or None if no
253        description has been provided.
254
255        The default implementation of this method returns the first line of
256        the specified test method's docstring.
257        """
258        doc = self._testMethodDoc
259        return doc and doc.split("\n")[0].strip() or None
260
261
262    def id(self):
263        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
264
265    def __eq__(self, other):
266        if type(self) is not type(other):
267            return NotImplemented
268
269        return self._testMethodName == other._testMethodName
270
271    def __ne__(self, other):
272        return not self == other
273
274    def __hash__(self):
275        return hash((type(self), self._testMethodName))
276
277    def __str__(self):
278        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
279
280    def __repr__(self):
281        return "<%s testMethod=%s>" % \
282               (strclass(self.__class__), self._testMethodName)
283
284    def _addSkip(self, result, reason):
285        addSkip = getattr(result, 'addSkip', None)
286        if addSkip is not None:
287            addSkip(self, reason)
288        else:
289            warnings.warn("TestResult has no addSkip method, skips not reported",
290                          RuntimeWarning, 2)
291            result.addSuccess(self)
292
293    def run(self, result=None):
294        orig_result = result
295        if result is None:
296            result = self.defaultTestResult()
297            startTestRun = getattr(result, 'startTestRun', None)
298            if startTestRun is not None:
299                startTestRun()
300
301        self._resultForDoCleanups = result
302        result.startTest(self)
303
304        testMethod = getattr(self, self._testMethodName)
305        if (getattr(self.__class__, "__unittest_skip__", False) or
306            getattr(testMethod, "__unittest_skip__", False)):
307            # If the class or method was skipped.
308            try:
309                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
310                            or getattr(testMethod, '__unittest_skip_why__', ''))
311                self._addSkip(result, skip_why)
312            finally:
313                result.stopTest(self)
314            return
315        try:
316            success = False
317            try:
318                self.setUp()
319            except SkipTest as e:
320                self._addSkip(result, str(e))
321            except KeyboardInterrupt:
322                raise
323            except:
324                result.addError(self, sys.exc_info())
325            else:
326                try:
327                    testMethod()
328                except KeyboardInterrupt:
329                    raise
330                except self.failureException:
331                    result.addFailure(self, sys.exc_info())
332                except _ExpectedFailure as e:
333                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
334                    if addExpectedFailure is not None:
335                        addExpectedFailure(self, e.exc_info)
336                    else:
337                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
338                                      RuntimeWarning)
339                        result.addSuccess(self)
340                except _UnexpectedSuccess:
341                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
342                    if addUnexpectedSuccess is not None:
343                        addUnexpectedSuccess(self)
344                    else:
345                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
346                                      RuntimeWarning)
347                        result.addFailure(self, sys.exc_info())
348                except SkipTest as e:
349                    self._addSkip(result, str(e))
350                except:
351                    result.addError(self, sys.exc_info())
352                else:
353                    success = True
354
355                try:
356                    self.tearDown()
357                except KeyboardInterrupt:
358                    raise
359                except:
360                    result.addError(self, sys.exc_info())
361                    success = False
362
363            cleanUpSuccess = self.doCleanups()
364            success = success and cleanUpSuccess
365            if success:
366                result.addSuccess(self)
367        finally:
368            result.stopTest(self)
369            if orig_result is None:
370                stopTestRun = getattr(result, 'stopTestRun', None)
371                if stopTestRun is not None:
372                    stopTestRun()
373
374    def doCleanups(self):
375        """Execute all cleanup functions. Normally called for you after
376        tearDown."""
377        result = self._resultForDoCleanups
378        ok = True
379        while self._cleanups:
380            function, args, kwargs = self._cleanups.pop(-1)
381            try:
382                function(*args, **kwargs)
383            except KeyboardInterrupt:
384                raise
385            except:
386                ok = False
387                result.addError(self, sys.exc_info())
388        return ok
389
390    def __call__(self, *args, **kwds):
391        return self.run(*args, **kwds)
392
393    def debug(self):
394        """Run the test without collecting errors in a TestResult"""
395        self.setUp()
396        getattr(self, self._testMethodName)()
397        self.tearDown()
398        while self._cleanups:
399            function, args, kwargs = self._cleanups.pop(-1)
400            function(*args, **kwargs)
401
402    def skipTest(self, reason):
403        """Skip this test."""
404        raise SkipTest(reason)
405
406    def fail(self, msg=None):
407        """Fail immediately, with the given message."""
408        raise self.failureException(msg)
409
410    def assertFalse(self, expr, msg=None):
411        """Check that the expression is false."""
412        if expr:
413            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
414            raise self.failureException(msg)
415
416    def assertTrue(self, expr, msg=None):
417        """Check that the expression is true."""
418        if not expr:
419            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
420            raise self.failureException(msg)
421
422    def _formatMessage(self, msg, standardMsg):
423        """Honour the longMessage attribute when generating failure messages.
424        If longMessage is False this means:
425        * Use only an explicit message if it is provided
426        * Otherwise use the standard message for the assert
427
428        If longMessage is True:
429        * Use the standard message
430        * If an explicit message is provided, plus ' : ' and the explicit message
431        """
432        if not self.longMessage:
433            return msg or standardMsg
434        if msg is None:
435            return standardMsg
436        try:
437            # don't switch to '{}' formatting in Python 2.X
438            # it changes the way unicode input is handled
439            return '%s : %s' % (standardMsg, msg)
440        except UnicodeDecodeError:
441            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
442
443
444    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
445        """Fail unless an exception of class excClass is thrown
446           by callableObj when invoked with arguments args and keyword
447           arguments kwargs. If a different type of exception is
448           thrown, it will not be caught, and the test case will be
449           deemed to have suffered an error, exactly as for an
450           unexpected exception.
451
452           If called with callableObj omitted or None, will return a
453           context object used like this::
454
455                with self.assertRaises(SomeException):
456                    do_something()
457
458           The context manager keeps a reference to the exception as
459           the 'exception' attribute. This allows you to inspect the
460           exception after the assertion::
461
462               with self.assertRaises(SomeException) as cm:
463                   do_something()
464               the_exception = cm.exception
465               self.assertEqual(the_exception.error_code, 3)
466        """
467        context = _AssertRaisesContext(excClass, self)
468        if callableObj is None:
469            return context
470        with context:
471            callableObj(*args, **kwargs)
472
473    def _getAssertEqualityFunc(self, first, second):
474        """Get a detailed comparison function for the types of the two args.
475
476        Returns: A callable accepting (first, second, msg=None) that will
477        raise a failure exception if first != second with a useful human
478        readable error message for those types.
479        """
480        #
481        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
482        # and vice versa.  I opted for the conservative approach in case
483        # subclasses are not intended to be compared in detail to their super
484        # class instances using a type equality func.  This means testing
485        # subtypes won't automagically use the detailed comparison.  Callers
486        # should use their type specific assertSpamEqual method to compare
487        # subclasses if the detailed comparison is desired and appropriate.
488        # See the discussion in http://bugs.python.org/issue2578.
489        #
490        if type(first) is type(second):
491            asserter = self._type_equality_funcs.get(type(first))
492            if asserter is not None:
493                return asserter
494
495        return self._baseAssertEqual
496
497    def _baseAssertEqual(self, first, second, msg=None):
498        """The default assertEqual implementation, not type specific."""
499        if not first == second:
500            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
501            msg = self._formatMessage(msg, standardMsg)
502            raise self.failureException(msg)
503
504    def assertEqual(self, first, second, msg=None):
505        """Fail if the two objects are unequal as determined by the '=='
506           operator.
507        """
508        assertion_func = self._getAssertEqualityFunc(first, second)
509        assertion_func(first, second, msg=msg)
510
511    def assertNotEqual(self, first, second, msg=None):
512        """Fail if the two objects are equal as determined by the '=='
513           operator.
514        """
515        if not first != second:
516            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
517                                                          safe_repr(second)))
518            raise self.failureException(msg)
519
520
521    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
522        """Fail if the two objects are unequal as determined by their
523           difference rounded to the given number of decimal places
524           (default 7) and comparing to zero, or by comparing that the
525           between the two objects is more than the given delta.
526
527           Note that decimal places (from zero) are usually not the same
528           as significant digits (measured from the most signficant digit).
529
530           If the two objects compare equal then they will automatically
531           compare almost equal.
532        """
533        if first == second:
534            # shortcut
535            return
536        if delta is not None and places is not None:
537            raise TypeError("specify delta or places not both")
538
539        if delta is not None:
540            if abs(first - second) <= delta:
541                return
542
543            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
544                                                        safe_repr(second),
545                                                        safe_repr(delta))
546        else:
547            if places is None:
548                places = 7
549
550            if round(abs(second-first), places) == 0:
551                return
552
553            standardMsg = '%s != %s within %r places' % (safe_repr(first),
554                                                          safe_repr(second),
555                                                          places)
556        msg = self._formatMessage(msg, standardMsg)
557        raise self.failureException(msg)
558
559    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
560        """Fail if the two objects are equal as determined by their
561           difference rounded to the given number of decimal places
562           (default 7) and comparing to zero, or by comparing that the
563           between the two objects is less than the given delta.
564
565           Note that decimal places (from zero) are usually not the same
566           as significant digits (measured from the most signficant digit).
567
568           Objects that are equal automatically fail.
569        """
570        if delta is not None and places is not None:
571            raise TypeError("specify delta or places not both")
572        if delta is not None:
573            if not (first == second) and abs(first - second) > delta:
574                return
575            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
576                                                        safe_repr(second),
577                                                        safe_repr(delta))
578        else:
579            if places is None:
580                places = 7
581            if not (first == second) and round(abs(second-first), places) != 0:
582                return
583            standardMsg = '%s == %s within %r places' % (safe_repr(first),
584                                                         safe_repr(second),
585                                                         places)
586
587        msg = self._formatMessage(msg, standardMsg)
588        raise self.failureException(msg)
589
590    # Synonyms for assertion methods
591
592    # The plurals are undocumented.  Keep them that way to discourage use.
593    # Do not add more.  Do not remove.
594    # Going through a deprecation cycle on these would annoy many people.
595    assertEquals = assertEqual
596    assertNotEquals = assertNotEqual
597    assertAlmostEquals = assertAlmostEqual
598    assertNotAlmostEquals = assertNotAlmostEqual
599    assert_ = assertTrue
600
601    # These fail* assertion method names are pending deprecation and will
602    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
603    def _deprecate(original_func):
604        def deprecated_func(*args, **kwargs):
605            warnings.warn(
606                'Please use {0} instead.'.format(original_func.__name__),
607                PendingDeprecationWarning, 2)
608            return original_func(*args, **kwargs)
609        return deprecated_func
610
611    failUnlessEqual = _deprecate(assertEqual)
612    failIfEqual = _deprecate(assertNotEqual)
613    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
614    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
615    failUnless = _deprecate(assertTrue)
616    failUnlessRaises = _deprecate(assertRaises)
617    failIf = _deprecate(assertFalse)
618
619    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
620        """An equality assertion for ordered sequences (like lists and tuples).
621
622        For the purposes of this function, a valid ordered sequence type is one
623        which can be indexed, has a length, and has an equality operator.
624
625        Args:
626            seq1: The first sequence to compare.
627            seq2: The second sequence to compare.
628            seq_type: The expected datatype of the sequences, or None if no
629                    datatype should be enforced.
630            msg: Optional message to use on failure instead of a list of
631                    differences.
632        """
633        if seq_type is not None:
634            seq_type_name = seq_type.__name__
635            if not isinstance(seq1, seq_type):
636                raise self.failureException('First sequence is not a %s: %s'
637                                        % (seq_type_name, safe_repr(seq1)))
638            if not isinstance(seq2, seq_type):
639                raise self.failureException('Second sequence is not a %s: %s'
640                                        % (seq_type_name, safe_repr(seq2)))
641        else:
642            seq_type_name = "sequence"
643
644        differing = None
645        try:
646            len1 = len(seq1)
647        except (TypeError, NotImplementedError):
648            differing = 'First %s has no length.    Non-sequence?' % (
649                    seq_type_name)
650
651        if differing is None:
652            try:
653                len2 = len(seq2)
654            except (TypeError, NotImplementedError):
655                differing = 'Second %s has no length.    Non-sequence?' % (
656                        seq_type_name)
657
658        if differing is None:
659            if seq1 == seq2:
660                return
661
662            seq1_repr = safe_repr(seq1)
663            seq2_repr = safe_repr(seq2)
664            if len(seq1_repr) > 30:
665                seq1_repr = seq1_repr[:30] + '...'
666            if len(seq2_repr) > 30:
667                seq2_repr = seq2_repr[:30] + '...'
668            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
669            differing = '%ss differ: %s != %s\n' % elements
670
671            for i in xrange(min(len1, len2)):
672                try:
673                    item1 = seq1[i]
674                except (TypeError, IndexError, NotImplementedError):
675                    differing += ('\nUnable to index element %d of first %s\n' %
676                                 (i, seq_type_name))
677                    break
678
679                try:
680                    item2 = seq2[i]
681                except (TypeError, IndexError, NotImplementedError):
682                    differing += ('\nUnable to index element %d of second %s\n' %
683                                 (i, seq_type_name))
684                    break
685
686                if item1 != item2:
687                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
688                                 (i, item1, item2))
689                    break
690            else:
691                if (len1 == len2 and seq_type is None and
692                    type(seq1) != type(seq2)):
693                    # The sequences are the same, but have differing types.
694                    return
695
696            if len1 > len2:
697                differing += ('\nFirst %s contains %d additional '
698                             'elements.\n' % (seq_type_name, len1 - len2))
699                try:
700                    differing += ('First extra element %d:\n%s\n' %
701                                  (len2, seq1[len2]))
702                except (TypeError, IndexError, NotImplementedError):
703                    differing += ('Unable to index element %d '
704                                  'of first %s\n' % (len2, seq_type_name))
705            elif len1 < len2:
706                differing += ('\nSecond %s contains %d additional '
707                             'elements.\n' % (seq_type_name, len2 - len1))
708                try:
709                    differing += ('First extra element %d:\n%s\n' %
710                                  (len1, seq2[len1]))
711                except (TypeError, IndexError, NotImplementedError):
712                    differing += ('Unable to index element %d '
713                                  'of second %s\n' % (len1, seq_type_name))
714        standardMsg = differing
715        diffMsg = '\n' + '\n'.join(
716            difflib.ndiff(pprint.pformat(seq1).splitlines(),
717                          pprint.pformat(seq2).splitlines()))
718        standardMsg = self._truncateMessage(standardMsg, diffMsg)
719        msg = self._formatMessage(msg, standardMsg)
720        self.fail(msg)
721
722    def _truncateMessage(self, message, diff):
723        max_diff = self.maxDiff
724        if max_diff is None or len(diff) <= max_diff:
725            return message + diff
726        return message + (DIFF_OMITTED % len(diff))
727
728    def assertListEqual(self, list1, list2, msg=None):
729        """A list-specific equality assertion.
730
731        Args:
732            list1: The first list to compare.
733            list2: The second list to compare.
734            msg: Optional message to use on failure instead of a list of
735                    differences.
736
737        """
738        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
739
740    def assertTupleEqual(self, tuple1, tuple2, msg=None):
741        """A tuple-specific equality assertion.
742
743        Args:
744            tuple1: The first tuple to compare.
745            tuple2: The second tuple to compare.
746            msg: Optional message to use on failure instead of a list of
747                    differences.
748        """
749        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
750
751    def assertSetEqual(self, set1, set2, msg=None):
752        """A set-specific equality assertion.
753
754        Args:
755            set1: The first set to compare.
756            set2: The second set to compare.
757            msg: Optional message to use on failure instead of a list of
758                    differences.
759
760        assertSetEqual uses ducktyping to support different types of sets, and
761        is optimized for sets specifically (parameters must support a
762        difference method).
763        """
764        try:
765            difference1 = set1.difference(set2)
766        except TypeError, e:
767            self.fail('invalid type when attempting set difference: %s' % e)
768        except AttributeError, e:
769            self.fail('first argument does not support set difference: %s' % e)
770
771        try:
772            difference2 = set2.difference(set1)
773        except TypeError, e:
774            self.fail('invalid type when attempting set difference: %s' % e)
775        except AttributeError, e:
776            self.fail('second argument does not support set difference: %s' % e)
777
778        if not (difference1 or difference2):
779            return
780
781        lines = []
782        if difference1:
783            lines.append('Items in the first set but not the second:')
784            for item in difference1:
785                lines.append(repr(item))
786        if difference2:
787            lines.append('Items in the second set but not the first:')
788            for item in difference2:
789                lines.append(repr(item))
790
791        standardMsg = '\n'.join(lines)
792        self.fail(self._formatMessage(msg, standardMsg))
793
794    def assertIn(self, member, container, msg=None):
795        """Just like self.assertTrue(a in b), but with a nicer default message."""
796        if member not in container:
797            standardMsg = '%s not found in %s' % (safe_repr(member),
798                                                  safe_repr(container))
799            self.fail(self._formatMessage(msg, standardMsg))
800
801    def assertNotIn(self, member, container, msg=None):
802        """Just like self.assertTrue(a not in b), but with a nicer default message."""
803        if member in container:
804            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
805                                                        safe_repr(container))
806            self.fail(self._formatMessage(msg, standardMsg))
807
808    def assertIs(self, expr1, expr2, msg=None):
809        """Just like self.assertTrue(a is b), but with a nicer default message."""
810        if expr1 is not expr2:
811            standardMsg = '%s is not %s' % (safe_repr(expr1),
812                                             safe_repr(expr2))
813            self.fail(self._formatMessage(msg, standardMsg))
814
815    def assertIsNot(self, expr1, expr2, msg=None):
816        """Just like self.assertTrue(a is not b), but with a nicer default message."""
817        if expr1 is expr2:
818            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
819            self.fail(self._formatMessage(msg, standardMsg))
820
821    def assertDictEqual(self, d1, d2, msg=None):
822        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
823        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
824
825        if d1 != d2:
826            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
827            diff = ('\n' + '\n'.join(difflib.ndiff(
828                           pprint.pformat(d1).splitlines(),
829                           pprint.pformat(d2).splitlines())))
830            standardMsg = self._truncateMessage(standardMsg, diff)
831            self.fail(self._formatMessage(msg, standardMsg))
832
833    def assertDictContainsSubset(self, expected, actual, msg=None):
834        """Checks whether actual is a superset of expected."""
835        missing = []
836        mismatched = []
837        for key, value in expected.iteritems():
838            if key not in actual:
839                missing.append(key)
840            elif value != actual[key]:
841                mismatched.append('%s, expected: %s, actual: %s' %
842                                  (safe_repr(key), safe_repr(value),
843                                   safe_repr(actual[key])))
844
845        if not (missing or mismatched):
846            return
847
848        standardMsg = ''
849        if missing:
850            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
851                                                    missing)
852        if mismatched:
853            if standardMsg:
854                standardMsg += '; '
855            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
856
857        self.fail(self._formatMessage(msg, standardMsg))
858
859    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
860        """An unordered sequence specific comparison. It asserts that
861        actual_seq and expected_seq have the same element counts.
862        Equivalent to::
863
864            self.assertEqual(Counter(iter(actual_seq)),
865                             Counter(iter(expected_seq)))
866
867        Asserts that each element has the same count in both sequences.
868        Example:
869            - [0, 1, 1] and [1, 0, 1] compare equal.
870            - [0, 0, 1] and [0, 1] compare unequal.
871        """
872        first_seq, second_seq = list(actual_seq), list(expected_seq)
873        with warnings.catch_warnings():
874            if sys.py3kwarning:
875                # Silence Py3k warning raised during the sorting
876                for _msg in ["(code|dict|type) inequality comparisons",
877                             "builtin_function_or_method order comparisons",
878                             "comparing unequal types"]:
879                    warnings.filterwarnings("ignore", _msg, DeprecationWarning)
880            try:
881                first = collections.Counter(first_seq)
882                second = collections.Counter(second_seq)
883            except TypeError:
884                # Handle case with unhashable elements
885                differences = _count_diff_all_purpose(first_seq, second_seq)
886            else:
887                if first == second:
888                    return
889                differences = _count_diff_hashable(first_seq, second_seq)
890
891        if differences:
892            standardMsg = 'Element counts were not equal:\n'
893            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
894            diffMsg = '\n'.join(lines)
895            standardMsg = self._truncateMessage(standardMsg, diffMsg)
896            msg = self._formatMessage(msg, standardMsg)
897            self.fail(msg)
898
899    def assertMultiLineEqual(self, first, second, msg=None):
900        """Assert that two multi-line strings are equal."""
901        self.assertIsInstance(first, basestring,
902                'First argument is not a string')
903        self.assertIsInstance(second, basestring,
904                'Second argument is not a string')
905
906        if first != second:
907            # don't use difflib if the strings are too long
908            if (len(first) > self._diffThreshold or
909                len(second) > self._diffThreshold):
910                self._baseAssertEqual(first, second, msg)
911            firstlines = first.splitlines(True)
912            secondlines = second.splitlines(True)
913            if len(firstlines) == 1 and first.strip('\r\n') == first:
914                firstlines = [first + '\n']
915                secondlines = [second + '\n']
916            standardMsg = '%s != %s' % (safe_repr(first, True),
917                                        safe_repr(second, True))
918            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
919            standardMsg = self._truncateMessage(standardMsg, diff)
920            self.fail(self._formatMessage(msg, standardMsg))
921
922    def assertLess(self, a, b, msg=None):
923        """Just like self.assertTrue(a < b), but with a nicer default message."""
924        if not a < b:
925            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
926            self.fail(self._formatMessage(msg, standardMsg))
927
928    def assertLessEqual(self, a, b, msg=None):
929        """Just like self.assertTrue(a <= b), but with a nicer default message."""
930        if not a <= b:
931            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
932            self.fail(self._formatMessage(msg, standardMsg))
933
934    def assertGreater(self, a, b, msg=None):
935        """Just like self.assertTrue(a > b), but with a nicer default message."""
936        if not a > b:
937            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
938            self.fail(self._formatMessage(msg, standardMsg))
939
940    def assertGreaterEqual(self, a, b, msg=None):
941        """Just like self.assertTrue(a >= b), but with a nicer default message."""
942        if not a >= b:
943            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
944            self.fail(self._formatMessage(msg, standardMsg))
945
946    def assertIsNone(self, obj, msg=None):
947        """Same as self.assertTrue(obj is None), with a nicer default message."""
948        if obj is not None:
949            standardMsg = '%s is not None' % (safe_repr(obj),)
950            self.fail(self._formatMessage(msg, standardMsg))
951
952    def assertIsNotNone(self, obj, msg=None):
953        """Included for symmetry with assertIsNone."""
954        if obj is None:
955            standardMsg = 'unexpectedly None'
956            self.fail(self._formatMessage(msg, standardMsg))
957
958    def assertIsInstance(self, obj, cls, msg=None):
959        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
960        default message."""
961        if not isinstance(obj, cls):
962            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
963            self.fail(self._formatMessage(msg, standardMsg))
964
965    def assertNotIsInstance(self, obj, cls, msg=None):
966        """Included for symmetry with assertIsInstance."""
967        if isinstance(obj, cls):
968            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
969            self.fail(self._formatMessage(msg, standardMsg))
970
971    def assertRaisesRegexp(self, expected_exception, expected_regexp,
972                           callable_obj=None, *args, **kwargs):
973        """Asserts that the message in a raised exception matches a regexp.
974
975        Args:
976            expected_exception: Exception class expected to be raised.
977            expected_regexp: Regexp (re pattern object or string) expected
978                    to be found in error message.
979            callable_obj: Function to be called.
980            args: Extra args.
981            kwargs: Extra kwargs.
982        """
983        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
984        if callable_obj is None:
985            return context
986        with context:
987            callable_obj(*args, **kwargs)
988
989    def assertRegexpMatches(self, text, expected_regexp, msg=None):
990        """Fail the test unless the text matches the regular expression."""
991        if isinstance(expected_regexp, basestring):
992            expected_regexp = re.compile(expected_regexp)
993        if not expected_regexp.search(text):
994            msg = msg or "Regexp didn't match"
995            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
996            raise self.failureException(msg)
997
998    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
999        """Fail the test if the text matches the regular expression."""
1000        if isinstance(unexpected_regexp, basestring):
1001            unexpected_regexp = re.compile(unexpected_regexp)
1002        match = unexpected_regexp.search(text)
1003        if match:
1004            msg = msg or "Regexp matched"
1005            msg = '%s: %r matches %r in %r' % (msg,
1006                                               text[match.start():match.end()],
1007                                               unexpected_regexp.pattern,
1008                                               text)
1009            raise self.failureException(msg)
1010
1011
class FunctionTestCase(TestCase):
    """Adapt a plain test function into a TestCase.

    This lets existing test functions run inside the unittest framework.
    Set-up and tidy-up callables may optionally be supplied. As with
    TestCase, the tidy-up ('tearDown') callable always runs if the set-up
    ('setUp') callable completed successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        # The wrapped callables and the optional description are stored
        # privately; equality, hashing and display all derive from them.
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        """Call the user-supplied set-up callable, if one was given."""
        if self._setUpFunc is None:
            return
        self._setUpFunc()

    def tearDown(self):
        """Call the user-supplied tidy-up callable, if one was given."""
        if self._tearDownFunc is None:
            return
        self._tearDownFunc()

    def runTest(self):
        """Run the wrapped test function."""
        self._testFunc()

    def id(self):
        """Identify the test by the wrapped function's __name__."""
        return self._testFunc.__name__

    def __eq__(self, other):
        # Different (or unrelated) classes defer to the other operand.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hashes the same fields that __eq__ compares, plus the type.
        key = (type(self), self._setUpFunc, self._tearDownFunc,
               self._testFunc, self._description)
        return hash(key)

    def __str__(self):
        return "%s (%s)" % (strclass(self.__class__),
                            self._testFunc.__name__)

    def __repr__(self):
        return "<%s tec=%s>" % (strclass(self.__class__), self._testFunc)

    def shortDescription(self):
        """Return the explicit description if one was given.

        Otherwise return the stripped first line of the test function's
        docstring, or None when that is empty or missing.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        first_line = doc.split("\n")[0].strip()
        return first_line or None
1071