• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2'''
3Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's
4Smalltalk testing framework.
5
6This module contains the core framework classes that form the basis of
7specific test cases and suites (TestCase, TestSuite etc.), and also a
8text-based utility class for running the tests and reporting the results
9 (TextTestRunner).
10
11Simple usage:
12
13    import unittest
14
    class IntegerArithmeticTestCase(unittest.TestCase):
16        def testAdd(self):  ## test method names begin 'test*'
17            self.assertEqual((1 + 2), 3)
18            self.assertEqual(0 + 1, 1)
19        def testMultiply(self):
20            self.assertEqual((0 * 10), 0)
21            self.assertEqual((5 * 8), 40)
22
23    if __name__ == '__main__':
24        unittest.main()
25
26Further information is available in the bundled documentation, and from
27
28  http://docs.python.org/library/unittest.html
29
30Copyright (c) 1999-2003 Steve Purcell
31Copyright (c) 2003-2009 Python Software Foundation
32Copyright (c) 2009      Garrett Cooper
33This module is free software, and you may redistribute it and/or modify
34it under the same terms as Python itself, so long as this copyright message
35and disclaimer are retained in their original form.
36
37IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
38SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
39THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
40DAMAGE.
41
42THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
43LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
44PARTICULAR PURPOSE.  THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
45AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
46SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
47
48Garrett: This module was backported using source from r71263 with fixes noted
49in Issue 5771.
50'''
51
52import difflib
53try:
54    import functools
55except ImportError:
56    # we put a local copy of this in our repository
57    #  http://pypi.python.org/simple/functools/
58    import functools_24
59    functools = functools_24
60
61import os
62import pprint
63import re
64import sys
65import time
66import traceback
67import types
68import warnings
69
70##############################################################################
71# Exported classes and functions
72##############################################################################
__all__ = [
    'TestResult', 'TestCase', 'TestSuite', 'ClassTestSuite',
    'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main',
    'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
    'expectedFailure',
]

# Obsolete module-level helpers, still exported so that code written
# against older versions of this module keeps importing cleanly.
__all__ += ['getTestCaseNames', 'makeSuite', 'findTestCases']
80
81
82##############################################################################
83# Backward compatibility
84##############################################################################
85
def _CmpToKey(mycmp):
    """Convert a cmp= style function into a key= style class.

    *mycmp(a, b)* follows the cmp() protocol: negative when a < b, zero
    when equal, positive when a > b.  Only the sign is meaningful, so the
    wrapper tests ``< 0`` rather than ``== -1``; user-supplied comparison
    functions (e.g. ``lambda a, b: a - b``) may return any negative number.
    Python's sort only needs ``<``, which is the comparison provided.
    """
    class K(object):
        def __init__(self, obj):
            self.obj = obj
        def __lt__(self, other):
            return mycmp(self.obj, other.obj) < 0
    return K
94
def _EmulateWith(context, func):
    """Run ``func()`` inside *context* the way a ``with`` block would.

    Calls ``context.__enter__()``, then ``func()``.  If ``func`` raises,
    ``context.__exit__`` receives the exception triple and the exception
    is re-raised unless ``__exit__`` returns a true value; otherwise
    ``__exit__(None, None, None)`` is called.  The return values of
    ``__enter__`` and ``func`` are discarded.  This exists for
    interpreters that lack the ``with`` statement.
    """
    context.__enter__()
    try:
        func()
    except:
        # Bare except on purpose: ``with`` passes every exception to
        # __exit__, not only Exception subclasses.  sys.exc_info() replaces
        # the deprecated, thread-unsafe sys.exc_type/exc_value/exc_traceback
        # globals (which no longer exist in Python 3).
        if not context.__exit__(*sys.exc_info()):
            raise
    else:
        context.__exit__(None, None, None)
104
105##############################################################################
106# Test framework core
107##############################################################################
108
def _strclass(cls):
    """Return the dotted ``module.ClassName`` name of *cls*."""
    module_name = cls.__module__
    class_name = cls.__name__
    return "%s.%s" % (module_name, class_name)
111
112
class SkipTest(Exception):
    """Signal that the currently running test should be skipped.

    When raised from a test method or from setUp, TestCase.run() records
    the test as skipped with str(exception) as the reason.  The skip
    decorators and TestCase.skipTest() raise it for you, so raising it
    directly is rarely necessary.
    """
121
class _ExpectedFailure(Exception):
    """Internal signal: a test marked @expectedFailure failed as expected.

    The ``sys.exc_info()`` triple of the original failure is kept in
    ``exc_info`` so that TestCase.run() can report it.
    """

    def __init__(self, exc_info):
        super(_ExpectedFailure, self).__init__()
        # Triple captured by the expectedFailure wrapper.
        self.exc_info = exc_info
132
class _UnexpectedSuccess(Exception):
    """Internal signal: a test marked @expectedFailure passed instead."""
138
def _id(obj):
    """No-op decorator: hand *obj* back unchanged (used when no skip applies)."""
    unchanged = obj
    return unchanged
141
def skip(reason):
    """Build a decorator that unconditionally skips a test.

    On a TestCase subclass, the class is tagged with ``__unittest_skip__``
    and ``__unittest_skip_why__`` and returned as is.  On anything else
    (normally a test method), the item is replaced by a wrapper carrying
    its metadata that raises SkipTest(reason) whenever it is called.
    """
    def decorator(test_item):
        is_case_class = (isinstance(test_item, type) and
                         issubclass(test_item, TestCase))
        if not is_case_class:
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            return functools.wraps(test_item)(skip_wrapper)
        test_item.__unittest_skip__ = True
        test_item.__unittest_skip_why__ = reason
        return test_item
    return decorator
156
def skipIf(condition, reason):
    """Skip the decorated test when *condition* is true; otherwise no-op."""
    if not condition:
        return _id
    return skip(reason)
164
def skipUnless(condition, reason):
    """Skip the decorated test unless *condition* is true; otherwise no-op."""
    if condition:
        return _id
    return skip(reason)
172
173
def expectedFailure(func):
    """Mark a test method as expected to fail.

    The returned wrapper turns any Exception raised by *func* into
    _ExpectedFailure (carrying the original exc_info triple), and a normal
    return into _UnexpectedSuccess; TestCase.run() reports each outcome.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            failure_info = sys.exc_info()
            raise _ExpectedFailure(failure_info)
        else:
            raise _UnexpectedSuccess
    return wrapper
183
# Marker global: TestResult._is_relevant_tb_level() treats any traceback
# frame whose module globals contain the name '__unittest' as belonging to
# the test framework, so frames from this module are trimmed from reported
# tracebacks.  That check looks only for the name's presence, not its value.
__unittest = 1
185
class TestResult(object):
    """Accumulates the outcome of a test run.

    TestCase and TestSuite drive this object; test authors normally never
    manipulate it directly.  Public state:

    * testsRun - number of startTest() calls so far.
    * errors, failures, expectedFailures - lists of (test, text) pairs,
      where text is the formatted traceback of the exception.
    * skipped - list of (test, reason) pairs.
    * unexpectedSuccesses - list of tests.
    * shouldStop - becomes True once stop() is called.
    """

    def __init__(self):
        self.failures = []
        self.errors = []
        self.testsRun = 0
        self.skipped = []
        self.expectedFailures = []
        self.unexpectedSuccesses = []
        self.shouldStop = False

    def startTest(self, test):
        """Hook invoked just before *test* runs; counts it."""
        self.testsRun += 1

    def stopTest(self, test):
        """Hook invoked after *test* has finished, whatever the outcome."""

    def addError(self, test, err):
        """Record an error; *err* is a tuple as returned by sys.exc_info()."""
        formatted = self._exc_info_to_string(err, test)
        self.errors.append((test, formatted))

    def addFailure(self, test, err):
        """Record a failure (the test's failureException was raised).

        *err* is a tuple as returned by sys.exc_info().
        """
        formatted = self._exc_info_to_string(err, test)
        self.failures.append((test, formatted))

    def addSuccess(self, test):
        """Hook invoked when *test* passed; the base class keeps no record."""

    def addSkip(self, test, reason):
        """Record that *test* was skipped because of *reason*."""
        self.skipped.append((test, reason))

    def addExpectedFailure(self, test, err):
        """Record a failure/error of a test that was expected to fail."""
        formatted = self._exc_info_to_string(err, test)
        self.expectedFailures.append((test, formatted))

    def addUnexpectedSuccess(self, test):
        """Record that a test expected to fail actually succeeded."""
        self.unexpectedSuccesses.append(test)

    def wasSuccessful(self):
        """Return True when no failures and no errors were recorded."""
        return not self.failures and not self.errors

    def stop(self):
        """Ask the runner to abort the remaining tests."""
        self.shouldStop = True

    def _exc_info_to_string(self, err, test):
        """Format a sys.exc_info()-style triple as traceback text.

        Leading traceback entries from framework modules are dropped.  For
        the test's failureException, the traceback is also cut before the
        next framework frame, which hides the assert*() helper levels.
        """
        exctype, value, tb = err
        while tb and self._is_relevant_tb_level(tb):
            tb = tb.tb_next
        limit = None
        if exctype is test.failureException:
            limit = self._count_relevant_tb_levels(tb)
        lines = traceback.format_exception(exctype, value, tb, limit)
        return ''.join(lines)

    def _is_relevant_tb_level(self, tb):
        """True when *tb*'s frame comes from a module defining __unittest."""
        return '__unittest' in tb.tb_frame.f_globals

    def _count_relevant_tb_levels(self, tb):
        """Count traceback entries from *tb* up to the first framework frame."""
        depth = 0
        while tb and not self._is_relevant_tb_level(tb):
            depth += 1
            tb = tb.tb_next
        return depth

    def __repr__(self):
        return "<%s run=%i errors=%i failures=%i>" % (
            _strclass(self.__class__), self.testsRun,
            len(self.errors), len(self.failures))
276
277
class _AssertRaisesContext(object):
    """Context manager used to implement TestCase.assertRaises().

    On exit it checks that an exception of the expected class (or a
    subclass) was raised inside the block.  When *expected_regexp* is
    given, str(exception) must also match it (re.search semantics).  A
    matching exception is swallowed and kept in ``self.exception``;
    exceptions of other types propagate unchanged.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regex = expected_regexp
        # Filled in by __exit__ with the caught exception instance.
        self.exception = None

    def __enter__(self):
        # Return the context itself so ``with ... as cm:`` callers can
        # inspect cm.exception afterwards.
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            raise self.failureException(
                "%s not raised" % exc_name)
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        self.exception = exc_value
        if self.expected_regex is None:
            return True

        expected_regexp = self.expected_regex
        if isinstance(expected_regexp, basestring):
            expected_regexp = re.compile(expected_regexp)
        if not expected_regexp.search(str(exc_value)):
            raise self.failureException('"%s" does not match "%s"' %
                     (expected_regexp.pattern, str(exc_value)))
        return True
310
311
class _AssertWrapper(object):
    """Hold one entry of TestCase._type_equality_funcs.

    The registry stores bound assert*Equal methods of the test case.
    Wrapping them makes the registry deep-copyable: a deep copy shares the
    wrapper instead of copying it.
    """

    def __init__(self, function):
        # The comparison callable; read back by _getAssertEqualityFunc().
        self.function = function

    def __deepcopy__(self, memo):
        # copy.deepcopy() uses the *return value* of __deepcopy__ as the
        # copy.  Returning nothing produced None, which silently dropped the
        # type-specific comparison from deep-copied test cases.
        memo[id(self)] = self
        return self
321
322
323class TestCase(object):
324    """A class whose instances are single test cases.
325
326    By default, the test code itself should be placed in a method named
327    'runTest'.
328
329    If the fixture may be used for many test cases, create as
330    many test methods as are needed. When instantiating such a TestCase
331    subclass, specify in the constructor arguments the name of the test method
332    that the instance is to execute.
333
334    Test authors should subclass TestCase for their own tests. Construction
335    and deconstruction of the test's environment ('fixture') can be
336    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
337
338    If it is necessary to override the __init__ method, the base class
339    __init__ method must always be called. It is important that subclasses
340    should not change the signature of their __init__ method, since instances
341    of the classes are instantiated automatically by parts of the framework
342    in order to be run.
343    """
344
    # Exception class raised by this instance's assert*() methods and by
    # fail().  run() reports a test method that raises it as a *failure*
    # (result.addFailure); any other Exception is reported as an *error*
    # (result.addError).

    failureException = AssertionError

    # When True, _formatMessage() keeps the standard assert message (which
    # includes reprs of the compared objects) and appends the caller's
    # explicit msg after ' : '.  When False, an explicit msg replaces the
    # standard message entirely.

    longMessage = False
356
357
358    def __init__(self, methodName='runTest'):
359        """Create an instance of the class that will use the named test
360           method when executed. Raises a ValueError if the instance does
361           not have a method with the specified name.
362        """
363        self._testMethodName = methodName
364        try:
365            testMethod = getattr(self, methodName)
366        except AttributeError:
367            raise ValueError("no such test method in %s: %s" % \
368                  (self.__class__, methodName))
369        self._testMethodDoc = testMethod.__doc__
370
371        # Map types to custom assertEqual functions that will compare
372        # instances of said type in more detail to generate a more useful
373        # error message.
374        self._type_equality_funcs = {}
375        self.addTypeEqualityFunc(dict, self.assertDictEqual)
376        self.addTypeEqualityFunc(list, self.assertListEqual)
377        self.addTypeEqualityFunc(tuple, self.assertTupleEqual)
378        self.addTypeEqualityFunc(set, self.assertSetEqual)
379        self.addTypeEqualityFunc(frozenset, self.assertSetEqual)
380
381    def addTypeEqualityFunc(self, typeobj, function):
382        """Add a type specific assertEqual style function to compare a type.
383
384        This method is for use by TestCase subclasses that need to register
385        their own type equality functions to provide nicer error messages.
386
387        Args:
388            typeobj: The data type to call this function on when both values
389                    are of the same type in assertEqual().
390            function: The callable taking two arguments and an optional
391                    msg= argument that raises self.failureException with a
392                    useful error message when the two arguments are not equal.
393        """
394        self._type_equality_funcs[typeobj] = _AssertWrapper(function)
395
396    def setUp(self):
397        "Hook method for setting up the test fixture before exercising it."
398        pass
399
400    def tearDown(self):
401        "Hook method for deconstructing the test fixture after testing it."
402        pass
403
404    def countTestCases(self):
405        return 1
406
407    def defaultTestResult(self):
408        return TestResult()
409
410    def shortDescription(self):
411        """Returns both the test method name and first line of its docstring.
412
413        If no docstring is given, only returns the method name.
414
415        This method overrides unittest.TestCase.shortDescription(), which
416        only returns the first line of the docstring, obscuring the name
417        of the test upon failure.
418        """
419        desc = str(self)
420        doc_first_line = None
421
422        if self._testMethodDoc:
423            doc_first_line = self._testMethodDoc.split("\n")[0].strip()
424        if doc_first_line:
425            desc = '\n'.join((desc, doc_first_line))
426        return desc
427
428    def id(self):
429        return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
430
431    def __eq__(self, other):
432        if type(self) is not type(other):
433            return NotImplemented
434
435        return self._testMethodName == other._testMethodName
436
437    def __ne__(self, other):
438        return not self == other
439
440    def __hash__(self):
441        return hash((type(self), self._testMethodName))
442
443    def __str__(self):
444        return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))
445
446    def __repr__(self):
447        return "<%s testMethod=%s>" % \
448               (_strclass(self.__class__), self._testMethodName)
449
    def run(self, result=None):
        """Run this test and report every outcome to *result*.

        When *result* is None, defaultTestResult() supplies one.  Order of
        calls: result.startTest; setUp; the test method; tearDown;
        result.addSuccess if nothing went wrong; result.stopTest always (it
        sits in the finally clause).  Nothing is returned.

        Exceptions from the test method are reported as:
          * self.failureException -> result.addFailure
          * _ExpectedFailure      -> result.addExpectedFailure
          * _UnexpectedSuccess    -> result.addUnexpectedSuccess
          * SkipTest              -> result.addSkip(str(exception))
          * any other Exception   -> result.addError
        If setUp raises SkipTest or another Exception, the outcome is
        recorded and neither the test method nor tearDown runs.  Exceptions
        that do not derive from Exception are not caught here.
        """
        if result is None:
            result = self.defaultTestResult()
        result.startTest(self)
        # NOTE(review): this lookup is outside the try/finally, so if it
        # raised, stopTest() would not be called; __init__ has already looked
        # the name up successfully, so presumably that is not expected.
        testMethod = getattr(self, self._testMethodName)
        # try/except nested in try/finally: presumably because a combined
        # try/except/finally statement needs Python 2.5+, while this backport
        # also targets older interpreters.
        try:
            try:
                self.setUp()
            except SkipTest, e:
                result.addSkip(self, str(e))
                return
            except Exception:
                result.addError(self, sys.exc_info())
                return

            success = False
            try:
                testMethod()
            # Handler order matters: failureException is matched first so it
            # is never reported as a plain error by the catch-all below.
            except self.failureException:
                result.addFailure(self, sys.exc_info())
            except _ExpectedFailure, e:
                result.addExpectedFailure(self, e.exc_info)
            except _UnexpectedSuccess:
                result.addUnexpectedSuccess(self)
            except SkipTest, e:
                result.addSkip(self, str(e))
            except Exception:
                result.addError(self, sys.exc_info())
            else:
                success = True

            # tearDown runs whatever the test method's outcome; an error here
            # is recorded in addition to any earlier failure and cancels
            # addSuccess.
            try:
                self.tearDown()
            except Exception:
                result.addError(self, sys.exc_info())
                success = False
            if success:
                result.addSuccess(self)
        finally:
            result.stopTest(self)
490
491    def __call__(self, *args, **kwds):
492        return self.run(*args, **kwds)
493
494    def debug(self):
495        """Run the test without collecting errors in a TestResult"""
496        self.setUp()
497        getattr(self, self._testMethodName)()
498        self.tearDown()
499
500    def skipTest(self, reason):
501        """Skip this test."""
502        raise SkipTest(reason)
503
504    def fail(self, msg=None):
505        """Fail immediately, with the given message."""
506        raise self.failureException(msg)
507
508    def assertFalse(self, expr, msg=None):
509        "Fail the test if the expression is true."
510        if expr:
511            msg = self._formatMessage(msg, "%r is not False" % expr)
512            raise self.failureException(msg)
513
514    def assertTrue(self, expr, msg=None):
515        """Fail the test unless the expression is true."""
516        if not expr:
517            msg = self._formatMessage(msg, "%r is not True" % expr)
518            raise self.failureException(msg)
519
520    def _formatMessage(self, msg, standardMsg):
521        """Honour the longMessage attribute when generating failure messages.
522        If longMessage is False this means:
523        * Use only an explicit message if it is provided
524        * Otherwise use the standard message for the assert
525
526        If longMessage is True:
527        * Use the standard message
528        * If an explicit message is provided, plus ' : ' and the explicit message
529        """
530        if not self.longMessage:
531            return msg or standardMsg
532        if msg is None:
533            return standardMsg
534        return standardMsg + ' : ' + msg
535
536
537    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
538        """Fail unless an exception of class excClass is thrown
539           by callableObj when invoked with arguments args and keyword
540           arguments kwargs. If a different type of exception is
541           thrown, it will not be caught, and the test case will be
542           deemed to have suffered an error, exactly as for an
543           unexpected exception.
544
545           If called with callableObj omitted or None, will return a
546           context object used like this::
547
548                with self.assertRaises(some_error_class):
549                    do_something()
550        """
551        context = _AssertRaisesContext(excClass, self)
552        if callableObj is None:
553            return context
554
555        # XXX (garrcoop): `with context' isn't supported lexigram with 2.4/2.5,
556        # even though PEP 343 (sort of) implies that based on the publishing
557        # date. There may be another PEP which changed the syntax...
558        _EmulateWith(context, lambda: callableObj(*args, **kwargs))
559
560    def _getAssertEqualityFunc(self, first, second):
561        """Get a detailed comparison function for the types of the two args.
562
563        Returns: A callable accepting (first, second, msg=None) that will
564        raise a failure exception if first != second with a useful human
565        readable error message for those types.
566        """
567        #
568        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
569        # and vice versa.  I opted for the conservative approach in case
570        # subclasses are not intended to be compared in detail to their super
571        # class instances using a type equality func.  This means testing
572        # subtypes won't automagically use the detailed comparison.  Callers
573        # should use their type specific assertSpamEqual method to compare
574        # subclasses if the detailed comparison is desired and appropriate.
575        # See the discussion in http://bugs.python.org/issue2578.
576        #
577        if type(first) is type(second):
578            asserter = self._type_equality_funcs.get(type(first))
579            if asserter is not None:
580                return asserter.function
581
582        return self._baseAssertEqual
583
584    def _baseAssertEqual(self, first, second, msg=None):
585        """The default assertEqual implementation, not type specific."""
586        if not first == second:
587            standardMsg = '%r != %r' % (first, second)
588            msg = self._formatMessage(msg, standardMsg)
589            raise self.failureException(msg)
590
591    def assertEqual(self, first, second, msg=None):
592        """Fail if the two objects are unequal as determined by the '=='
593           operator.
594        """
595        assertion_func = self._getAssertEqualityFunc(first, second)
596        assertion_func(first, second, msg=msg)
597
598    def assertNotEqual(self, first, second, msg=None):
599        """Fail if the two objects are equal as determined by the '=='
600           operator.
601        """
602        if not first != second:
603            msg = self._formatMessage(msg, '%r == %r' % (first, second))
604            raise self.failureException(msg)
605
606    def assertAlmostEqual(self, first, second, places=7, msg=None):
607        """Fail if the two objects are unequal as determined by their
608           difference rounded to the given number of decimal places
609           (default 7) and comparing to zero.
610
611           Note that decimal places (from zero) are usually not the same
612           as significant digits (measured from the most signficant digit).
613        """
614        if round(abs(second-first), places) != 0:
615            standardMsg = '%r != %r within %r places' % (first, second, places)
616            msg = self._formatMessage(msg, standardMsg)
617            raise self.failureException(msg)
618
619    def assertNotAlmostEqual(self, first, second, places=7, msg=None):
620        """Fail if the two objects are equal as determined by their
621           difference rounded to the given number of decimal places
622           (default 7) and comparing to zero.
623
624           Note that decimal places (from zero) are usually not the same
625           as significant digits (measured from the most signficant digit).
626        """
627        if round(abs(second-first), places) == 0:
628            standardMsg = '%r == %r within %r places' % (first, second, places)
629            msg = self._formatMessage(msg, standardMsg)
630            raise self.failureException(msg)
631
    # Synonyms for assertion methods.
    #
    # Each alias is bound to the function object defined above when the
    # class body executes.  A subclass that overrides e.g. assertEqual
    # therefore does not change what assertEquals calls.

    # The plurals are undocumented.  Keep them that way to discourage use.
    # Do not add more.  Do not remove.
    # Going through a deprecation cycle on these would annoy many people.
    assertEquals = assertEqual
    assertNotEquals = assertNotEqual
    assertAlmostEquals = assertAlmostEqual
    assertNotAlmostEquals = assertNotAlmostEqual
    assert_ = assertTrue
642
643    # These fail* assertion method names are pending deprecation and will
644    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
645    def _deprecate(original_func):
646        def deprecated_func(*args, **kwargs):
647            warnings.warn(
648                'Please use %s instead.' % original_func.__name__,
649                PendingDeprecationWarning, 2)
650            return original_func(*args, **kwargs)
651        return deprecated_func
652
653    failUnlessEqual = _deprecate(assertEqual)
654    failIfEqual = _deprecate(assertNotEqual)
655    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
656    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
657    failUnless = _deprecate(assertTrue)
658    failUnlessRaises = _deprecate(assertRaises)
659    failIf = _deprecate(assertFalse)
660
    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message to use on failure instead of a list of
                    differences.

        Raises:
            self.failureException: if seq_type is given and either argument
                is not an instance of it, or if the sequences differ.  The
                standard message describes the first problem found, followed
                by an ndiff of the pprint-ed sequences; *msg* is combined
                with it via _formatMessage().

        With seq_type None, sequences of different types whose lengths and
        elements all match are accepted (e.g. [1, 2] and (1, 2)).
        """
        # Enforce the requested container type before comparing contents.
        if seq_type != None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %r'
                                            % (seq_type_name, seq1))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %r'
                                            % (seq_type_name, seq2))
        else:
            seq_type_name = "sequence"

        # `differing` holds the description of the first problem found;
        # None means no problem has been identified yet.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            # Fast path: equal sequences need no element-wise diagnosis.
            if seq1 == seq2:
                return

            # Walk the common prefix looking for the first unequal or
            # unindexable element.
            for i in xrange(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing = ('Unable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing = ('Unable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing = ('First differing element %d:\n%s\n%s\n' %
                                 (i, item1, item2))
                    break
            else:
                # for/else: reached only when the loop did not break, i.e.
                # every element of the common prefix compared equal.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return
                # A catch-all message for handling arbitrary user-defined
                # sequences.  It is replaced below by a more specific
                # message when the lengths differ.
                differing = '%ss differ:\n' % seq_type_name.capitalize()
                # "First extra element" means the first of the surplus
                # elements; in the len1 < len2 case it comes from seq2.
                if len1 > len2:
                    differing = ('First %s contains %d additional '
                                 'elements.\n' % (seq_type_name, len1 - len2))
                    try:
                        differing += ('First extra element %d:\n%s\n' %
                                      (len2, seq1[len2]))
                    except (TypeError, IndexError, NotImplementedError):
                        differing += ('Unable to index element %d '
                                      'of first %s\n' % (len2, seq_type_name))
                elif len1 < len2:
                    differing = ('Second %s contains %d additional '
                                 'elements.\n' % (seq_type_name, len2 - len1))
                    try:
                        differing += ('First extra element %d:\n%s\n' %
                                      (len1, seq2[len1]))
                    except (TypeError, IndexError, NotImplementedError):
                        differing += ('Unable to index element %d '
                                      'of second %s\n' % (len1, seq_type_name))
        # Append a line-by-line diff of the pretty-printed sequences.
        standardMsg = differing + '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(),
                                            pprint.pformat(seq2).splitlines()))
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
753
754    def assertListEqual(self, list1, list2, msg=None):
755        """A list-specific equality assertion.
756
757        Args:
758            list1: The first list to compare.
759            list2: The second list to compare.
760            msg: Optional message to use on failure instead of a list of
761                    differences.
762
763        """
764        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
765
766    def assertTupleEqual(self, tuple1, tuple2, msg=None):
767        """A tuple-specific equality assertion.
768
769        Args:
770            tuple1: The first tuple to compare.
771            tuple2: The second tuple to compare.
772            msg: Optional message to use on failure instead of a list of
773                    differences.
774        """
775        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
776
777    def assertSetEqual(self, set1, set2, msg=None):
778        """A set-specific equality assertion.
779
780        Args:
781            set1: The first set to compare.
782            set2: The second set to compare.
783            msg: Optional message to use on failure instead of a list of
784                    differences.
785
786        For more general containership equality, assertSameElements will work
787        with things other than sets.    This uses ducktyping to support
788        different types of sets, and is optimized for sets specifically
789        (parameters must support a difference method).
790        """
791        try:
792            difference1 = set1.difference(set2)
793        except TypeError, e:
794            self.fail('invalid type when attempting set difference: %s' % e)
795        except AttributeError, e:
796            self.fail('first argument does not support set difference: %s' % e)
797
798        try:
799            difference2 = set2.difference(set1)
800        except TypeError, e:
801            self.fail('invalid type when attempting set difference: %s' % e)
802        except AttributeError, e:
803            self.fail('second argument does not support set difference: %s' % e)
804
805        if not (difference1 or difference2):
806            return
807
808        lines = []
809        if difference1:
810            lines.append('Items in the first set but not the second:')
811            for item in difference1:
812                lines.append(repr(item))
813        if difference2:
814            lines.append('Items in the second set but not the first:')
815            for item in difference2:
816                lines.append(repr(item))
817
818        standardMsg = '\n'.join(lines)
819        self.fail(self._formatMessage(msg, standardMsg))
820
821    def assertIn(self, member, container, msg=None):
822        """Just like self.assertTrue(a in b), but with a nicer default message."""
823        if member not in container:
824            standardMsg = '%r not found in %r' % (member, container)
825            self.fail(self._formatMessage(msg, standardMsg))
826
827    def assertNotIn(self, member, container, msg=None):
828        """Just like self.assertTrue(a not in b), but with a nicer default message."""
829        if member in container:
830            standardMsg = '%r unexpectedly found in %r' % (member, container)
831            self.fail(self._formatMessage(msg, standardMsg))
832
833    def assertIs(self, expr1, expr2, msg=None):
834        """Just like self.assertTrue(a is b), but with a nicer default message."""
835        if expr1 is not expr2:
836            standardMsg = '%r is not %r' % (expr1, expr2)
837            self.fail(self._formatMessage(msg, standardMsg))
838
839    def assertIsNot(self, expr1, expr2, msg=None):
840        """Just like self.assertTrue(a is not b), but with a nicer default message."""
841        if expr1 is expr2:
842            standardMsg = 'unexpectedly identical: %r' % (expr1,)
843            self.fail(self._formatMessage(msg, standardMsg))
844
845    def assertDictEqual(self, d1, d2, msg=None):
846        self.assert_(isinstance(d1, dict), 'First argument is not a dictionary')
847        self.assert_(isinstance(d2, dict), 'Second argument is not a dictionary')
848
849        if d1 != d2:
850            standardMsg = ('\n' + '\n'.join(difflib.ndiff(
851                           pprint.pformat(d1).splitlines(),
852                           pprint.pformat(d2).splitlines())))
853            self.fail(self._formatMessage(msg, standardMsg))
854
855    def assertDictContainsSubset(self, expected, actual, msg=None):
856        """Checks whether actual is a superset of expected."""
857        missing = []
858        mismatched = []
859        for key, value in expected.iteritems():
860            if key not in actual:
861                missing.append(key)
862            elif value != actual[key]:
863                mismatched.append('%s, expected: %s, actual: %s' % (key, value,                                                                                                       actual[key]))
864
865        if not (missing or mismatched):
866            return
867
868        standardMsg = ''
869        if missing:
870            standardMsg = 'Missing: %r' % ','.join(missing)
871        if mismatched:
872            if standardMsg:
873                standardMsg += '; '
874            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
875
876        self.fail(self._formatMessage(msg, standardMsg))
877
878    def assertSameElements(self, expected_seq, actual_seq, msg=None):
879        """An unordered sequence specific comparison.
880
881        Raises with an error message listing which elements of expected_seq
882        are missing from actual_seq and vice versa if any.
883        """
884        try:
885            expected = set(expected_seq)
886            actual = set(actual_seq)
887            missing = list(expected.difference(actual))
888            unexpected = list(actual.difference(expected))
889            missing.sort()
890            unexpected.sort()
891        except TypeError:
892            # Fall back to slower list-compare if any of the objects are
893            # not hashable.
894            expected = list(expected_seq)
895            actual = list(actual_seq)
896            expected.sort()
897            actual.sort()
898            missing, unexpected = _SortedListDifference(expected, actual)
899        errors = []
900        if missing:
901            errors.append('Expected, but missing:\n    %r' % missing)
902        if unexpected:
903            errors.append('Unexpected, but present:\n    %r' % unexpected)
904        if errors:
905            standardMsg = '\n'.join(errors)
906            self.fail(self._formatMessage(msg, standardMsg))
907
908    def assertMultiLineEqual(self, first, second, msg=None):
909        """Assert that two multi-line strings are equal."""
910        self.assert_(isinstance(first, basestring), (
911                'First argument is not a string'))
912        self.assert_(isinstance(second, basestring), (
913                'Second argument is not a string'))
914
915        if first != second:
916            standardMsg = '\n' + ''.join(difflib.ndiff(first.splitlines(True), second.splitlines(True)))
917            self.fail(self._formatMessage(msg, standardMsg))
918
919    def assertLess(self, a, b, msg=None):
920        """Just like self.assertTrue(a < b), but with a nicer default message."""
921        if not a < b:
922            standardMsg = '%r not less than %r' % (a, b)
923            self.fail(self._formatMessage(msg, standardMsg))
924
925    def assertLessEqual(self, a, b, msg=None):
926        """Just like self.assertTrue(a <= b), but with a nicer default message."""
927        if not a <= b:
928            standardMsg = '%r not less than or equal to %r' % (a, b)
929            self.fail(self._formatMessage(msg, standardMsg))
930
931    def assertGreater(self, a, b, msg=None):
932        """Just like self.assertTrue(a > b), but with a nicer default message."""
933        if not a > b:
934            standardMsg = '%r not greater than %r' % (a, b)
935            self.fail(self._formatMessage(msg, standardMsg))
936
937    def assertGreaterEqual(self, a, b, msg=None):
938        """Just like self.assertTrue(a >= b), but with a nicer default message."""
939        if not a >= b:
940            standardMsg = '%r not greater than or equal to %r' % (a, b)
941            self.fail(self._formatMessage(msg, standardMsg))
942
943    def assertIsNone(self, obj, msg=None):
944        """Same as self.assertTrue(obj is None), with a nicer default message."""
945        if obj is not None:
946            standardMsg = '%r is not None' % obj
947            self.fail(self._formatMessage(msg, standardMsg))
948
949    def assertIsNotNone(self, obj, msg=None):
950        """Included for symmetry with assertIsNone."""
951        if obj is None:
952            standardMsg = 'unexpectedly None'
953            self.fail(self._formatMessage(msg, standardMsg))
954
955    def assertRaisesRegexp(self, expected_exception, expected_regexp,
956                           callable_obj=None, *args, **kwargs):
957        """Asserts that the message in a raised exception matches a regexp.
958
959        Args:
960            expected_exception: Exception class expected to be raised.
961            expected_regexp: Regexp (re pattern object or string) expected
962                    to be found in error message.
963            callable_obj: Function to be called.
964            args: Extra args.
965            kwargs: Extra kwargs.
966        """
967        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
968        if callable_obj is None:
969            return context
970        # XXX (garrcoop): See comment above about `with context'.
971        _EmulateWith(context, lambda: callable_obj(*args, **kwargs))
972
973    def assertRegexpMatches(self, text, expected_regex, msg=None):
974        if isinstance(expected_regex, basestring):
975            expected_regex = re.compile(expected_regex)
976        if not expected_regex.search(text):
977            msg = msg or "Regexp didn't match"
978            msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text)
979            raise self.failureException(msg)
980
981
982def _SortedListDifference(expected, actual):
983    """Finds elements in only one or the other of two, sorted input lists.
984
985    Returns a two-element tuple of lists.    The first list contains those
986    elements in the "expected" list but not in the "actual" list, and the
987    second contains those elements in the "actual" list but not in the
988    "expected" list.    Duplicate elements in either input list are ignored.
989    """
990    i = j = 0
991    missing = []
992    unexpected = []
993    while True:
994        try:
995            e = expected[i]
996            a = actual[j]
997            if e < a:
998                missing.append(e)
999                i += 1
1000                while expected[i] == e:
1001                    i += 1
1002            elif e > a:
1003                unexpected.append(a)
1004                j += 1
1005                while actual[j] == a:
1006                    j += 1
1007            else:
1008                i += 1
1009                try:
1010                    while expected[i] == e:
1011                        i += 1
1012                finally:
1013                    j += 1
1014                    while actual[j] == a:
1015                        j += 1
1016        except IndexError:
1017            missing.extend(expected[i:])
1018            unexpected.extend(actual[j:])
1019            break
1020    return missing, unexpected
1021
1022
1023class TestSuite(object):
1024    """A test suite is a composite test consisting of a number of TestCases.
1025
1026    For use, create an instance of TestSuite, then add test case instances.
1027    When all tests have been added, the suite can be passed to a test
1028    runner, such as TextTestRunner. It will run the individual test cases
1029    in the order in which they were added, aggregating the results. When
1030    subclassing, do not forget to call the base class constructor.
1031    """
1032    def __init__(self, tests=()):
1033        self._tests = []
1034        self.addTests(tests)
1035
1036    def __repr__(self):
1037        return "<%s tests=%s>" % (_strclass(self.__class__), list(self))
1038
1039    def __eq__(self, other):
1040        if not isinstance(other, self.__class__):
1041            return NotImplemented
1042        return self._tests == other._tests
1043
1044    def __ne__(self, other):
1045        return not self == other
1046
1047    # Can't guarantee hash invariant, so flag as unhashable
1048    __hash__ = None
1049
1050    def __iter__(self):
1051        return iter(self._tests)
1052
1053    def countTestCases(self):
1054        cases = 0
1055        for test in self:
1056            cases += test.countTestCases()
1057        return cases
1058
1059    def addTest(self, test):
1060        # sanity checks
1061        if not hasattr(test, '__call__'):
1062            raise TypeError("the test to add must be callable")
1063        if isinstance(test, type) and issubclass(test, (TestCase, TestSuite)):
1064            raise TypeError("TestCases and TestSuites must be instantiated "
1065                            "before passing them to addTest()")
1066        self._tests.append(test)
1067
1068    def addTests(self, tests):
1069        if isinstance(tests, basestring):
1070            raise TypeError("tests must be an iterable of tests, not a string")
1071        for test in tests:
1072            self.addTest(test)
1073
1074    def run(self, result):
1075        for test in self:
1076            if result.shouldStop:
1077                break
1078            test(result)
1079        return result
1080
1081    def __call__(self, *args, **kwds):
1082        return self.run(*args, **kwds)
1083
1084    def debug(self):
1085        """Run the tests without collecting errors in a TestResult"""
1086        for test in self:
1087            test.debug()
1088
1089
class ClassTestSuite(TestSuite):
    """
    Suite of tests derived from a single TestCase class.

    Remembers the class the tests were collected from.  If that class has a
    true ``__unittest_skip__`` attribute, running the suite records a single
    skip (with ``__unittest_skip_why__`` as the reason) instead of running
    the member tests.
    """

    def __init__(self, tests, class_collected_from):
        super(ClassTestSuite, self).__init__(tests)
        self.collected_from = class_collected_from

    def id(self):
        """Return "<module>.<ClassName>" for the collected class, or only its
        __name__ when it has no __module__ (or __module__ is None)."""
        source = self.collected_from
        module = getattr(source, "__module__", None)
        if module is None:
            return source.__name__
        return "%s.%s" % (str(module), str(source.__name__))

    def run(self, result):
        """Run the member tests, or report the whole class as skipped."""
        source = self.collected_from
        if not getattr(source, "__unittest_skip__", False):
            return super(ClassTestSuite, self).run(result)
        # Here the suite stands in for a TestCase so the result can report
        # it: it is passed to startTest/addSkip/stopTest directly.
        result.startTest(self)
        try:
            result.addSkip(self, source.__unittest_skip_why__)
        finally:
            result.stopTest(self)
        return result

    # The suite's short description is its id.
    shortDescription = id
1119
1120
class FunctionTestCase(TestCase):
    """A test case that wraps a plain test function.

    This lets pre-existing test functions be run by the unittest framework.
    Optional set-up and tidy-up callables may be supplied. As with TestCase,
    the tidy-up ('tearDown') function will always be called if the set-up
    ('setUp') function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        """Call the user-supplied set-up function, if one was given."""
        setUpFunc = self._setUpFunc
        if setUpFunc is not None:
            setUpFunc()

    def tearDown(self):
        """Call the user-supplied tidy-up function, if one was given."""
        tearDownFunc = self._tearDownFunc
        if tearDownFunc is not None:
            tearDownFunc()

    def runTest(self):
        """Call the wrapped test function."""
        self._testFunc()

    def id(self):
        """Return the wrapped function's __name__."""
        return self._testFunc.__name__

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        # Equal when all four wrapped pieces compare equal.
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        key = (type(self), self._setUpFunc, self._tearDownFunc,
               self._testFunc, self._description)
        return hash(key)

    def __str__(self):
        return "%s (%s)" % (_strclass(self.__class__), self._testFunc.__name__)

    def __repr__(self):
        return "<%s testFunc=%s>" % (_strclass(self.__class__), self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the first docstring line of
        the wrapped function, else None (also when that line is blank)."""
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if doc:
            first_line = doc.split("\n")[0].strip()
            if first_line:
                return first_line
        return None
1178
1179
1180
1181##############################################################################
1182# Locating and loading tests
1183##############################################################################
1184
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only attributes whose names start with this prefix are test methods.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # keeps the dir() order.
    sortTestMethodsUsing = cmp
    # Suite type used to wrap groups of tests.
    suiteClass = TestSuite
    # Suite type used for the tests collected from one TestCase class.
    classSuiteClass = ClassTestSuite

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        One instance is created per test method name; when there are none,
        a 'runTest' method (if present) is used instead.

        Raises:
            TypeError: if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        tests = [testCaseClass(name) for name in testCaseNames]
        return self.classSuiteClass(tests, testCaseClass)

    def loadTestsFromModule(self, module):
        """Return a suite with one class suite per TestCase subclass found
        (in dir() order) in the given module."""
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, TestCase):
                tests.append(self.loadTestsFromTestCase(obj))
        return self.suiteClass(tests)

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.
        Without a module, the longest importable dotted prefix of ``name`` is
        imported; the parts after the top-level package (which __import__
        returns) are then looked up as attributes.

        Raises:
            ImportError: if no dotted prefix of ``name`` can be imported.
            AttributeError: if one of the dotted parts cannot be found.
            TypeError: if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            parts = parts[1:]
        obj = module
        parent = None
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, TestCase)):
            # Instantiate with the attribute name that was looked up, not
            # obj.__name__: a decorated or aliased method's __name__ need not
            # be an attribute of the class at all.
            return self.suiteClass([parent(parts[-1])])
        elif isinstance(obj, TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, TestSuite):
                return test
            elif isinstance(test, TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute is callable.  Names are sorted with sortTestMethodsUsing
        when that is set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames
1281
1282
1283
# Shared module-level TestLoader instance; it is the default ``testLoader``
# argument of TestProgram.
defaultTestLoader = TestLoader()
1285
1286
1287##############################################################################
1288# Patches for old functions: these functions should be considered obsolete
1289##############################################################################
1290
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Build a TestLoader configured for the obsolete module-level helpers.

    ``suiteClass`` replaces the loader's default suite class only when it is
    a true value.
    """
    loader = TestLoader()
    loader.testMethodPrefix = prefix
    loader.sortTestMethodsUsing = sortUsing
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader
1297
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Obsolete: return the test method names of testCaseClass.

    Equivalent to TestLoader.getTestCaseNames on a loader configured with
    ``prefix`` and ``sortUsing``.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
1300
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
    """Obsolete: return a suite of the tests found in testCaseClass.

    Equivalent to TestLoader.loadTestsFromTestCase on a loader configured
    with ``prefix``, ``sortUsing`` and ``suiteClass``.

    NOTE(review): loadTestsFromTestCase builds its result with the loader's
    classSuiteClass, so ``suiteClass`` does not appear to affect the suite
    returned here -- confirm whether that is intended.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
1303
def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
    """Obsolete: return a suite of the TestCase classes found in ``module``.

    Equivalent to TestLoader.loadTestsFromModule on a loader configured with
    ``prefix``, ``sortUsing`` and ``suiteClass``.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)
1306
1307
1308##############################################################################
1309# Text UI
1310##############################################################################
1311
1312class _WritelnDecorator(object):
1313    """Used to decorate file-like objects with a handy 'writeln' method"""
1314    def __init__(self,stream):
1315        self.stream = stream
1316
1317    def __getattr__(self, attr):
1318        return getattr(self.stream,attr)
1319
1320    def writeln(self, arg=None):
1321        if arg:
1322            self.write(arg)
1323        self.write('\n') # text-mode streams translate to \r\n if needed
1324
1325
class _TextTestResult(TestResult):
    """A test result class that can print formatted text results to a stream.

    Used by TextTestRunner.  ``stream`` must offer write(), writeln() and
    flush().  The output mode comes from ``verbosity``: above 1, each test's
    description is printed followed by a word for its outcome; exactly 1,
    each outcome is a single character; any other value prints nothing
    while the tests run.
    """
    separator1 = '=' * 70
    separator2 = '-' * 70

    def __init__(self, stream, descriptions, verbosity):
        super(_TextTestResult, self).__init__()
        self.stream = stream
        self.descriptions = descriptions
        self.showAll = verbosity > 1
        self.dots = verbosity == 1

    def getDescription(self, test):
        """Describe ``test`` for output.

        Uses test.shortDescription() when descriptions are enabled and it
        is non-empty; otherwise falls back to str(test).
        """
        if not self.descriptions:
            return str(test)
        return test.shortDescription() or str(test)

    def _reportOutcome(self, word, mark):
        """Print ``word`` on its own line (verbose) or ``mark`` (dots mode)."""
        if self.showAll:
            self.stream.writeln(word)
        elif self.dots:
            self.stream.write(mark)
            self.stream.flush()

    def startTest(self, test):
        super(_TextTestResult, self).startTest(test)
        if not self.showAll:
            return
        self.stream.write(self.getDescription(test))
        self.stream.write(" ... ")
        self.stream.flush()

    def addSuccess(self, test):
        super(_TextTestResult, self).addSuccess(test)
        self._reportOutcome("ok", '.')

    def addError(self, test, err):
        super(_TextTestResult, self).addError(test, err)
        self._reportOutcome("ERROR", 'E')

    def addFailure(self, test, err):
        super(_TextTestResult, self).addFailure(test, err)
        self._reportOutcome("FAIL", 'F')

    def addSkip(self, test, reason):
        super(_TextTestResult, self).addSkip(test, reason)
        # Not routed through _reportOutcome: the reason is only formatted
        # when it is actually going to be printed.
        if self.showAll:
            self.stream.writeln("skipped '%s'" % str(reason))
        elif self.dots:
            self.stream.write("s")
            self.stream.flush()

    def addExpectedFailure(self, test, err):
        super(_TextTestResult, self).addExpectedFailure(test, err)
        self._reportOutcome("expected failure", "x")

    def addUnexpectedSuccess(self, test):
        super(_TextTestResult, self).addUnexpectedSuccess(test)
        self._reportOutcome("unexpected success", "u")

    def printErrors(self):
        """Print every recorded error, then every recorded failure."""
        if self.dots or self.showAll:
            # Separate the report from the progress output.
            self.stream.writeln()
        self.printErrorList('ERROR', self.errors)
        self.printErrorList('FAIL', self.failures)

    def printErrorList(self, flavour, errors):
        """Print each (test, err) pair under a banner labelled ``flavour``."""
        out = self.stream
        for test, err in errors:
            out.writeln(self.separator1)
            out.writeln("%s: %s" % (flavour, self.getDescription(test)))
            out.writeln(self.separator2)
            out.writeln("%s" % err)
1414
1415
class TextTestRunner(object):
    """A test runner class that displays results in textual form.

    It prints out the names of tests as they are run, errors as they
    occur, and a summary of the results at the end of the test run.

    Constructor arguments:
        stream: File-like object to write to.  None (the default) means
            sys.stderr, looked up when the runner is created.
        descriptions: Whether test short descriptions are preferred over
            str(test) in the output.
        verbosity: Passed to the result object; 1 prints one character per
            test, values above 1 print one line per test, other values
            print no per-test progress.
    """
    def __init__(self, stream=None, descriptions=1, verbosity=1):
        # Resolve the default when the runner is created rather than when
        # the module is imported, so that a sys.stderr replaced in the
        # meantime (e.g. by an output-capturing harness) is honoured.
        if stream is None:
            stream = sys.stderr
        self.stream = _WritelnDecorator(stream)
        self.descriptions = descriptions
        self.verbosity = verbosity

    def _makeResult(self):
        return _TextTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        """Run the given test case or test suite.

        Prints the collected errors and a timing/summary footer to the
        stream and returns the result object.
        """
        result = self._makeResult()
        startTime = time.time()
        test(result)
        stopTime = time.time()
        timeTaken = stopTime - startTime
        result.printErrors()
        self.stream.writeln(result.separator2)
        run = result.testsRun
        self.stream.writeln("Ran %d test%s in %.3fs" %
                            (run, run != 1 and "s" or "", timeTaken))
        self.stream.writeln()
        expectedFails, unexpectedSuccesses, skipped = map(
            len, (result.expectedFailures,
                  result.unexpectedSuccesses,
                  result.skipped))
        infos = []
        if not result.wasSuccessful():
            self.stream.write("FAILED")
            failed, errored = map(len, (result.failures, result.errors))
            if failed:
                infos.append("failures=%d" % failed)
            if errored:
                infos.append("errors=%d" % errored)
        else:
            self.stream.write("OK")
        if skipped:
            infos.append("skipped=%d" % skipped)
        if expectedFails:
            infos.append("expected failures=%d" % expectedFails)
        if unexpectedSuccesses:
            infos.append("unexpected successes=%d" % unexpectedSuccesses)
        if infos:
            self.stream.writeln(" (%s)" % (", ".join(infos),))
        else:
            self.stream.write("\n")
        return result
1468
1469
1470
1471##############################################################################
1472# Facilities for running tests from the command line
1473##############################################################################
1474
class TestProgram(object):
    """A command-line program that runs a set of tests.

    Mainly used to make test modules conveniently executable.  Constructing
    an instance parses the command line and runs the selected tests, then
    exits the process via sys.exit().
    """
    USAGE = """\
Usage: %(progName)s [options] [test] [...]

Options:
  -h, --help       Show this message
  -v, --verbose    Verbose output
  -q, --quiet      Minimal output

Examples:
  %(progName)s                               - run default set of tests
  %(progName)s MyTestSuite                   - run suite 'MyTestSuite'
  %(progName)s MyTestCase.testSomething      - run MyTestCase.testSomething
  %(progName)s MyTestCase                    - run all 'test*' test methods
                                               in MyTestCase
"""
    def __init__(self, module='__main__', defaultTest=None,
                 argv=None, testRunner=TextTestRunner,
                 testLoader=defaultTestLoader):
        if isinstance(module, basestring):
            # __import__ returns the top-level package; walk down to the
            # requested submodule.
            target = __import__(module)
            for part in module.split('.')[1:]:
                target = getattr(target, part)
            self.module = target
        else:
            self.module = module
        if argv is None:
            argv = sys.argv
        self.verbosity = 1
        self.defaultTest = defaultTest
        self.testRunner = testRunner
        self.testLoader = testLoader
        self.progName = os.path.basename(argv[0])
        self.parseArgs(argv)
        self.runTests()

    def usageExit(self, msg=None):
        """Print ``msg`` (if any) and the usage text, then exit with status 2."""
        if msg:
            print(msg)
        print(self.USAGE % self.__dict__)
        sys.exit(2)

    def parseArgs(self, argv):
        """Apply the options in argv[1:] and choose the tests to run."""
        import getopt
        long_opts = ['help', 'verbose', 'quiet']
        try:
            options, args = getopt.getopt(argv[1:], 'hHvq', long_opts)
            for opt, value in options:
                if opt in ('-h', '-H', '--help'):
                    self.usageExit()
                if opt in ('-q', '--quiet'):
                    self.verbosity = 0
                if opt in ('-v', '--verbose'):
                    self.verbosity = 2
            if args:
                self.testNames = args
            elif self.defaultTest is None:
                # No names at all: load everything from the module.
                self.test = self.testLoader.loadTestsFromModule(self.module)
                return
            else:
                self.testNames = (self.defaultTest,)
            self.createTests()
        except getopt.error:
            self.usageExit(sys.exc_info()[1])

    def createTests(self):
        """Load self.testNames (relative to self.module) into self.test."""
        self.test = self.testLoader.loadTestsFromNames(self.testNames,
                                                       self.module)

    def runTests(self):
        """Run self.test and exit with status 0 on success, 1 otherwise."""
        runner = self.testRunner
        if isinstance(runner, (type, types.ClassType)):
            try:
                runner = runner(verbosity=self.verbosity)
            except TypeError:
                # didn't accept the verbosity argument
                runner = runner()
        # Otherwise it is assumed to be a TestRunner instance already.
        result = runner.run(self.test)
        sys.exit(not result.wasSuccessful())
1558
# ``main`` is an alias: calling main(...) constructs a TestProgram, whose
# constructor parses the command line and runs the selected tests.
main = TestProgram
1560
1561
1562##############################################################################
1563# Executing this module from the command line
1564##############################################################################
1565
if __name__ == "__main__":
    # Run as a script: no module is supplied, so test names given on the
    # command line are imported by loadTestsFromName; with no names, the
    # default loader's loadTestsFromModule(None) collects no TestCase classes.
    main(module=None)
1568