# coding: utf-8
"""
Compatibility shim that back-fills newer unittest.TestCase assertion methods
on Python 2.

Call patch() once before running tests. On Python 3 it does nothing.
"""
from __future__ import unicode_literals, division, absolute_import, print_function

import sys
import unittest
import re


if sys.version_info < (3,):
    str_cls = unicode  # noqa
else:
    str_cls = str


# Mutable module-level state so patch() can record that it has already run
_non_local = {'patched': False}


def patch():
    """
    Adds missing assertion methods to unittest.TestCase on Python 2.

    On Python 2.6 the pure-Python fallbacks from this module are installed.
    On Python 2.7 only the Python 3 spellings assertRegex and
    assertRaisesRegex are aliased to their 2.7 names. On Python 3, or when
    called a second time, nothing is changed.
    """

    if sys.version_info >= (3, 0):
        return

    if _non_local['patched']:
        return

    if sys.version_info < (2, 7):
        unittest.TestCase.assertIsInstance = _assert_is_instance
        unittest.TestCase.assertRegex = _assert_regex
        unittest.TestCase.assertRaises = _assert_raises
        unittest.TestCase.assertRaisesRegex = _assert_raises_regex
        unittest.TestCase.assertGreaterEqual = _assert_greater_equal
        unittest.TestCase.assertLess = _assert_less
        unittest.TestCase.assertLessEqual = _assert_less_equal
        unittest.TestCase.assertIn = _assert_in
        unittest.TestCase.assertNotIn = _assert_not_in
    else:
        unittest.TestCase.assertRegex = unittest.TestCase.assertRegexpMatches
        unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
    _non_local['patched'] = True


def _safe_repr(obj):
    """
    Returns repr(obj), falling back to the default object repr if the
    object's own __repr__ raises.
    """

    try:
        return repr(obj)
    except Exception:
        # A failure message must still be produced for objects whose
        # __repr__ is broken
        return object.__repr__(obj)


def _format_message(msg, standard_msg):
    """
    Returns the caller-supplied msg if it is truthy, otherwise standard_msg.
    """

    return msg or standard_msg


def _assert_greater_equal(self, a, b, msg=None):
    """Fail the test unless a >= b."""

    if not a >= b:
        standard_msg = '%s not greater than or equal to %s' % (_safe_repr(a), _safe_repr(b))
        self.fail(_format_message(msg, standard_msg))


def _assert_less(self, a, b, msg=None):
    """Fail the test unless a < b."""

    if not a < b:
        standard_msg = '%s not less than %s' % (_safe_repr(a), _safe_repr(b))
        self.fail(_format_message(msg, standard_msg))


def _assert_less_equal(self, a, b, msg=None):
    """Fail the test unless a <= b."""

    if not a <= b:
        standard_msg = '%s not less than or equal to %s' % (_safe_repr(a), _safe_repr(b))
        self.fail(_format_message(msg, standard_msg))


def _assert_is_instance(self, obj, cls, msg=None):
    """Fail the test unless obj is an instance of cls."""

    if not isinstance(obj, cls):
        # _safe_repr() keeps the failure report working for objects whose
        # __str__/__repr__ raise, and matches the sibling assertions
        standard_msg = '%s is not an instance of %r' % (_safe_repr(obj), cls)
        self.fail(_format_message(msg, standard_msg))


def _assert_in(self, member, container, msg=None):
    """Fail the test unless member is in container."""

    if member not in container:
        standard_msg = '%s not found in %s' % (_safe_repr(member), _safe_repr(container))
        self.fail(_format_message(msg, standard_msg))


def _assert_not_in(self, member, container, msg=None):
    """Fail the test if member is in container."""

    if member in container:
        standard_msg = '%s found in %s' % (_safe_repr(member), _safe_repr(container))
        self.fail(_format_message(msg, standard_msg))


def _assert_regex(self, text, expected_regexp, msg=None):
    """
    Fail the test unless the text matches the regular expression.

    :param text:
        The string to search

    :param expected_regexp:
        A text or byte string pattern, or an object with a search() method
        and a pattern attribute, such as a compiled regular expression

    :param msg:
        An optional prefix for the failure message
    """

    # Byte strings must be compiled too, otherwise .search() is looked up on
    # the raw string and raises AttributeError
    if isinstance(expected_regexp, (str_cls, bytes)):
        expected_regexp = re.compile(expected_regexp)
    if not expected_regexp.search(text):
        msg = msg or "Regexp didn't match"
        msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
        self.fail(msg)


def _assert_raises(self, excClass, callableObj=None, *args, **kwargs):  # noqa
    """
    Fail the test unless an exception of type excClass is raised.

    :param excClass:
        The exception class, or tuple of classes, that is expected

    :param callableObj:
        An optional callable to invoke with args and kwargs. If None, a
        context manager is returned for use in a with statement.

    :return:
        An _AssertRaisesContext when callableObj is None, otherwise None
    """

    context = _AssertRaisesContext(excClass, self)
    if callableObj is None:
        return context
    with context:
        callableObj(*args, **kwargs)


def _assert_raises_regex(self, expected_exception, expected_regexp, callable_obj=None, *args, **kwargs):
    """
    Fail the test unless an exception of type expected_exception is raised
    and its string form matches expected_regexp.

    :param expected_exception:
        The exception class, or tuple of classes, that is expected

    :param expected_regexp:
        A pattern string or compiled regular expression, or None to skip
        the message check

    :param callable_obj:
        An optional callable to invoke with args and kwargs. If None, a
        context manager is returned for use in a with statement.

    :return:
        An _AssertRaisesContext when callable_obj is None, otherwise None
    """

    if expected_regexp is not None:
        expected_regexp = re.compile(expected_regexp)
    context = _AssertRaisesContext(expected_exception, self, expected_regexp)
    if callable_obj is None:
        return context
    with context:
        callable_obj(*args, **kwargs)


class _AssertRaisesContext(object):
    """
    Context manager backing _assert_raises() and _assert_raises_regex().

    After a successful match the caught exception is available as the
    exception attribute.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        """
        :param expected:
            The exception class, or tuple of classes, that is expected

        :param test_case:
            The object whose failureException is raised on failure

        :param expected_regexp:
            None, or a compiled regular expression that the string form of
            the exception must match
        """

        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regexp = expected_regexp

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """
        Checks the outcome of the with block.

        :raises:
            self.failureException - when nothing was raised, or when the
            exception text does not match expected_regexp

        :return:
            True to swallow the expected exception, False to let an
            unexpected exception propagate
        """

        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                # e.g. a tuple of exception classes has no __name__
                exc_name = str(self.expected)
            raise self.failureException(
                "{0} not raised".format(exc_name))
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        self.exception = exc_value  # store for later retrieval
        if self.expected_regexp is None:
            return True

        expected_regexp = self.expected_regexp
        if not expected_regexp.search(str(exc_value)):
            raise self.failureException(
                '"%s" does not match "%s"' %
                (expected_regexp.pattern, str(exc_value))
            )
        return True