1import contextlib 2import functools 3import io 4import re 5import sqlite3 6import test.support 7 8 9# Helper for temporary memory databases 10def memory_database(*args, **kwargs): 11 cx = sqlite3.connect(":memory:", *args, **kwargs) 12 return contextlib.closing(cx) 13 14 15# Temporarily limit a database connection parameter 16@contextlib.contextmanager 17def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128): 18 try: 19 _prev = cx.setlimit(category, limit) 20 yield limit 21 finally: 22 cx.setlimit(category, _prev) 23 24 25def with_tracebacks(exc, regex="", name=""): 26 """Convenience decorator for testing callback tracebacks.""" 27 def decorator(func): 28 _regex = re.compile(regex) if regex else None 29 @functools.wraps(func) 30 def wrapper(self, *args, **kwargs): 31 with test.support.catch_unraisable_exception() as cm: 32 # First, run the test with traceback enabled. 33 with check_tracebacks(self, cm, exc, _regex, name): 34 func(self, *args, **kwargs) 35 36 # Then run the test with traceback disabled. 37 func(self, *args, **kwargs) 38 return wrapper 39 return decorator 40 41 42@contextlib.contextmanager 43def check_tracebacks(self, cm, exc, regex, obj_name): 44 """Convenience context manager for testing callback tracebacks.""" 45 sqlite3.enable_callback_tracebacks(True) 46 try: 47 buf = io.StringIO() 48 with contextlib.redirect_stderr(buf): 49 yield 50 51 self.assertEqual(cm.unraisable.exc_type, exc) 52 if regex: 53 msg = str(cm.unraisable.exc_value) 54 self.assertIsNotNone(regex.search(msg)) 55 if obj_name: 56 self.assertEqual(cm.unraisable.object.__name__, obj_name) 57 finally: 58 sqlite3.enable_callback_tracebacks(False) 59 60 61class MemoryDatabaseMixin: 62 63 def setUp(self): 64 self.con = sqlite3.connect(":memory:") 65 self.cur = self.con.cursor() 66 67 def tearDown(self): 68 self.cur.close() 69 self.con.close() 70 71 @property 72 def cx(self): 73 return self.con 74 75 @property 76 def cu(self): 77 return self.cur 78