• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding: utf-8
2from __future__ import unicode_literals, division, absolute_import, print_function
3
4import sys
5import unittest
6import re
7
8
# Native text type: ``unicode`` on Python 2, ``str`` on Python 3.
str_cls = unicode if sys.version_info < (3,) else str  # noqa

# Mutable module-level state; patch() flips 'patched' to True via item
# assignment (no ``global`` statement needed) so it only runs once.
_non_local = {'patched': False}
16
17
def patch():
    """Backport missing ``unittest.TestCase`` assertion methods on Python 2.

    Does nothing on Python 3, or when it has already run (tracked in
    ``_non_local['patched']``).

    On Python 2.6 it installs this module's implementations of
    ``assertIsInstance``, ``assertRegex``, ``assertRaises``,
    ``assertRaisesRegex``, ``assertGreaterEqual``, ``assertLess``,
    ``assertLessEqual``, ``assertIn`` and ``assertNotIn``.

    On Python 2.7 it only aliases the Python 3 names ``assertRegex`` and
    ``assertRaisesRegex`` to the existing ``assertRegexpMatches`` and
    ``assertRaisesRegexp``.
    """
    # ``or`` short-circuits, so the flag is never consulted on Python 3.
    if sys.version_info >= (3, 0) or _non_local['patched']:
        return

    test_case = unittest.TestCase
    if sys.version_info < (2, 7):
        # assertRaises is replaced as well so that it can be used as a
        # context manager (see _assert_raises).
        replacements = (
            ('assertIsInstance', _assert_is_instance),
            ('assertRegex', _assert_regex),
            ('assertRaises', _assert_raises),
            ('assertRaisesRegex', _assert_raises_regex),
            ('assertGreaterEqual', _assert_greater_equal),
            ('assertLess', _assert_less),
            ('assertLessEqual', _assert_less_equal),
            ('assertIn', _assert_in),
            ('assertNotIn', _assert_not_in),
        )
    else:
        # Python 2.7 already has these under their older names.
        replacements = (
            ('assertRegex', test_case.assertRegexpMatches),
            ('assertRaisesRegex', test_case.assertRaisesRegexp),
        )

    for attr_name, implementation in replacements:
        setattr(test_case, attr_name, implementation)
    _non_local['patched'] = True
39
40
41def _safe_repr(obj):
42    try:
43        return repr(obj)
44    except Exception:
45        return object.__repr__(obj)
46
47
48def _format_message(msg, standard_msg):
49    return msg or standard_msg
50
51
52def _assert_greater_equal(self, a, b, msg=None):
53    if not a >= b:
54        standard_msg = '%s not greater than or equal to %s' % (_safe_repr(a), _safe_repr(b))
55        self.fail(_format_message(msg, standard_msg))
56
57
58def _assert_less(self, a, b, msg=None):
59    if not a < b:
60        standard_msg = '%s not less than %s' % (_safe_repr(a), _safe_repr(b))
61        self.fail(_format_message(msg, standard_msg))
62
63
64def _assert_less_equal(self, a, b, msg=None):
65    if not a <= b:
66        standard_msg = '%s not less than or equal to %s' % (_safe_repr(a), _safe_repr(b))
67        self.fail(_format_message(msg, standard_msg))
68
69
70def _assert_is_instance(self, obj, cls, msg=None):
71    if not isinstance(obj, cls):
72        if not msg:
73            msg = '%s is not an instance of %r' % (obj, cls)
74        self.fail(msg)
75
76
77def _assert_in(self, member, container, msg=None):
78    if member not in container:
79        standard_msg = '%s not found in %s' % (_safe_repr(member), _safe_repr(container))
80        self.fail(_format_message(msg, standard_msg))
81
82
83def _assert_not_in(self, member, container, msg=None):
84    if member in container:
85        standard_msg = '%s found in %s' % (_safe_repr(member), _safe_repr(container))
86        self.fail(_format_message(msg, standard_msg))
87
88
89def _assert_regex(self, text, expected_regexp, msg=None):
90    """Fail the test unless the text matches the regular expression."""
91    if isinstance(expected_regexp, str_cls):
92        expected_regexp = re.compile(expected_regexp)
93    if not expected_regexp.search(text):
94        msg = msg or "Regexp didn't match"
95        msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
96        self.fail(msg)
97
98
def _assert_raises(self, excClass, callableObj=None, *args, **kwargs):  # noqa
    """Backport of ``assertRaises`` that also works as a context manager.

    When ``callableObj`` is None, returns an ``_AssertRaisesContext`` that
    checks the ``with`` block raises ``excClass``. Otherwise calls
    ``callableObj(*args, **kwargs)`` under that same check and returns None.
    """
    ctx = _AssertRaisesContext(excClass, self)
    if callableObj is not None:
        with ctx:
            callableObj(*args, **kwargs)
        return None
    return ctx
105
106
def _assert_raises_regex(self, expected_exception, expected_regexp, callable_obj=None, *args, **kwargs):
    """Backport of ``assertRaisesRegex``.

    Like ``_assert_raises``, but the raised exception's ``str()`` must also
    match ``expected_regexp`` (a pattern string or compiled pattern; None
    skips the message check). When ``callable_obj`` is None, the context
    manager is returned instead of calling anything.
    """
    pattern = None if expected_regexp is None else re.compile(expected_regexp)
    ctx = _AssertRaisesContext(expected_exception, self, pattern)
    if callable_obj is None:
        return ctx
    with ctx:
        callable_obj(*args, **kwargs)
115
116
117class _AssertRaisesContext(object):
118    def __init__(self, expected, test_case, expected_regexp=None):
119        self.expected = expected
120        self.failureException = test_case.failureException
121        self.expected_regexp = expected_regexp
122
123    def __enter__(self):
124        return self
125
126    def __exit__(self, exc_type, exc_value, tb):
127        if exc_type is None:
128            try:
129                exc_name = self.expected.__name__
130            except AttributeError:
131                exc_name = str(self.expected)
132            raise self.failureException(
133                "{0} not raised".format(exc_name))
134        if not issubclass(exc_type, self.expected):
135            # let unexpected exceptions pass through
136            return False
137        self.exception = exc_value  # store for later retrieval
138        if self.expected_regexp is None:
139            return True
140
141        expected_regexp = self.expected_regexp
142        if not expected_regexp.search(str(exc_value)):
143            raise self.failureException(
144                '"%s" does not match "%s"' %
145                (expected_regexp.pattern, str(exc_value))
146            )
147        return True
148