"""Unit tests for contextlib.py, and other context managers."""

import os
import sys
import tempfile
import unittest
from contextlib import *  # Tests __all__
from test import test_support
try:
    import threading
except ImportError:
    threading = None


class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A ``finally`` around the yield still runs when the with-body raises,
        # and the exception keeps propagating out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # When the thrown exception simply escapes the generator, __exit__
        # reports "not handled" with a false return value instead of raising.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after an exception was thrown into it
        # is misbehaving; __exit__ must turn that into a RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:  # deliberately bare: catch whatever __exit__ throws in
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # If the generator catches the exception, it is suppressed: the with
        # statement completes normally.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager-wrapped function carrying extra metadata.

        The wrapped function ``baz`` has the docstring "Whee!" and a custom
        attribute ``foo == 'bar'``; callers check that the decorator keeps them.
        """
        def attribs(**kw):
            # Decorator factory that sets each keyword as a function attribute.
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        # The decorator must preserve __name__ and custom function attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        # The decorator must preserve the wrapped function's docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    def test_keywords(self):
        # Ensure no keyword arguments are inhibited: names such as ``self``,
        # ``func``, ``args`` and ``kwds`` must reach the generator function.
        @contextmanager
        def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))


class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested(), which emits a DeprecationWarning."""

    # XXX This needs more work

    def test_nested(self):
        # The values returned by each __enter__ come back as a tuple, in order.
        @contextmanager
        def a():
            yield 1
        @contextmanager
        def b():
            yield 2
        @contextmanager
        def c():
            yield 3
        with nested(a(), b(), c()) as (x, y, z):
            self.assertEqual(x, 1)
            self.assertEqual(y, 2)
            self.assertEqual(z, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, even
        # when the body raises: a-enter(1), b-enter(4), body(2, 5),
        # b-exit(6), a-exit(3).
        state = []
        @contextmanager
        def a():
            state.append(1)
            try:
                yield 2
            finally:
                state.append(3)
        @contextmanager
        def b():
            state.append(4)
            try:
                yield 5
            finally:
                state.append(6)
        with self.assertRaises(ZeroDivisionError):
            with nested(a(), b()) as (x, y):
                state.append(x)
                state.append(y)
                1 // 0
        self.assertEqual(state, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        # An exception raised *and handled* inside an inner __exit__ must not
        # replace the body's exception seen by the outer managers.
        @contextmanager
        def a():
            yield 1
        class b(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:  # deliberately swallowed; see comment above
                    pass
        with self.assertRaises(ZeroDivisionError):
            with nested(a(), b()) as (x, y):
                1 // 0
        self.assertEqual((x, y), (1, 2))

    def test_nested_b_swallows(self):
        # When the inner manager suppresses the exception, nothing escapes.
        @contextmanager
        def a():
            yield
        @contextmanager
        def b():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(a(), b()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # ``break`` inside the body leaves the loop immediately.
        @contextmanager
        def a():
            yield
        state = 0
        while True:
            state += 1
            with nested(a(), a()):
                break
            state += 10
        self.assertEqual(state, 1)

    def test_nested_continue(self):
        # ``continue`` inside the body skips the rest of the loop iteration.
        @contextmanager
        def a():
            yield
        state = 0
        while state < 3:
            state += 1
            with nested(a(), a()):
                continue
            state += 10
        self.assertEqual(state, 3)

    def test_nested_return(self):
        # ``return`` inside the body returns from the enclosing function.
        @contextmanager
        def a():
            try:
                yield
            except:
                pass
        def foo():
            with nested(a(), a()):
                return 1
            return 10
        self.assertEqual(foo(), 1)


class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # closing() hands back the wrapped object and calls close() on exit.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with closing(x) as y:
            self.assertEqual(x, y)
        self.assertEqual(state, [1])

    def test_closing_error(self):
        # close() is still called when the body raises.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(x) as y:
                self.assertEqual(x, y)
                1 // 0
        self.assertEqual(state, [1])


class FileContextTestCase(unittest.TestCase):
    """Tests for file objects used as context managers."""

    def testWithOpen(self):
        # Files opened in a with statement are closed on exit, both normally
        # and when the body raises.
        # mkstemp() creates the file atomically; the deprecated mktemp() only
        # picks a name, which another process could claim before we open it.
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)  # only the path is needed; the test reopens it
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            self.assertTrue(f.closed)
        finally:
            test_support.unlink(tfn)


@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests for threading synchronization primitives as context managers."""

    def boilerPlate(self, lock, locked):
        """Check that ``lock`` is held inside a with block and released after.

        ``locked`` is a zero-argument callable reporting whether ``lock`` is
        currently held. Release is checked both on a normal exit and on an
        exception raised inside the block.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        def locked():
            # A non-blocking acquire succeeds only if the semaphore is free;
            # undo it immediately so the probe has no lasting effect.
            if lock.acquire(False):
                lock.release()
                return False
            else:
                return True
        self.boilerPlate(lock, locked)

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        def locked():
            # Same non-blocking probe as testWithSemaphore.
            if lock.acquire(False):
                lock.release()
                return False
            else:
                return True
        self.boilerPlate(lock, locked)


# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Run every test in this module, capturing the nested() warning."""
    with test_support.check_warnings(("With-statements now directly support "
                                      "multiple context managers",
                                      DeprecationWarning)):
        test_support.run_unittest(__name__)


if __name__ == "__main__":
    test_main()