• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import contextlib
2import functools
3import io
4import re
5import sqlite3
6import test.support
7
8
def memory_database(*args, **kwargs):
    """Return a context manager wrapping a fresh in-memory SQLite connection.

    Positional and keyword arguments are forwarded to ``sqlite3.connect()``
    after the ``":memory:"`` database name.  The connection is closed when
    the ``with`` block exits (via ``contextlib.closing``).
    """
    return contextlib.closing(sqlite3.connect(":memory:", *args, **kwargs))
13
14
15# Temporarily limit a database connection parameter
16@contextlib.contextmanager
17def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
18    try:
19        _prev = cx.setlimit(category, limit)
20        yield limit
21    finally:
22        cx.setlimit(category, _prev)
23
24
def with_tracebacks(exc, regex="", name=""):
    """Decorate a test method to check tracebacks raised from callbacks.

    The wrapped test method runs twice:

    1. inside ``test.support.catch_unraisable_exception()`` and
       ``check_tracebacks``, which verify that an unraisable exception of
       type *exc* was reported, optionally with a message matching *regex*
       and an object whose ``__name__`` is *name*;
    2. again on its own, after ``check_tracebacks`` has switched callback
       tracebacks back off.

    *regex* is compiled once, when the decorator is applied; an empty
    *regex* disables the message check.
    """
    def decorator(func):
        pattern = None
        if regex:
            pattern = re.compile(regex)

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Pass 1: tracebacks enabled, unraisable hook captured and checked.
            with (
                test.support.catch_unraisable_exception() as cm,
                check_tracebacks(self, cm, exc, pattern, name),
            ):
                func(self, *args, **kwargs)

            # Pass 2: tracebacks disabled.
            func(self, *args, **kwargs)

        return wrapper

    return decorator
40
41
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
    """Convenience context manager for testing callback tracebacks.

    Enables sqlite3 callback tracebacks and redirects stderr into a
    throwaway buffer while the ``with`` body runs.  Afterwards it asserts on
    the unraisable exception recorded in ``cm.unraisable``:

    * one must have been recorded at all;
    * its type must be *exc*;
    * if *regex* is truthy, it must match ``str(exc_value)``;
    * if *obj_name* is truthy, the reporting object's ``__name__`` must
      equal it.

    Callback tracebacks are disabled again on exit, even on failure.

    :param self: the running ``unittest.TestCase``, used for assertions.
    :param cm: the object yielded by
        ``test.support.catch_unraisable_exception()``.
    :param exc: expected exception type.
    :param regex: compiled pattern or pattern string; falsy to skip.
    :param obj_name: expected ``__name__`` of the object; falsy to skip.
    """
    sqlite3.enable_callback_tracebacks(True)
    try:
        # Keep anything written to stderr during the body out of test output.
        buf = io.StringIO()
        with contextlib.redirect_stderr(buf):
            yield

        unraisable = cm.unraisable
        # Fail with a clear assertion (not AttributeError on None) when no
        # callback raised.
        self.assertIsNotNone(unraisable,
                             "no unraisable exception was reported")
        self.assertEqual(unraisable.exc_type, exc)
        if regex:
            # assertRegex accepts compiled patterns and strings, and reports
            # the message and pattern on failure.
            self.assertRegex(str(unraisable.exc_value), regex)
        if obj_name:
            self.assertEqual(unraisable.object.__name__, obj_name)
    finally:
        sqlite3.enable_callback_tracebacks(False)
59
60
class MemoryDatabaseMixin:
    """Provide a fresh in-memory connection and cursor per test.

    ``setUp`` opens ``self.con`` (an in-memory ``sqlite3`` connection) and
    ``self.cur`` (a cursor on it); ``tearDown`` closes both.  The method
    names match ``unittest.TestCase`` hooks, so this is presumably combined
    with a TestCase.  ``cx`` and ``cu`` are read-only aliases for the two
    attributes.
    """

    def setUp(self):
        self.con = sqlite3.connect(":memory:")
        self.cur = self.con.cursor()

    def tearDown(self):
        # Close the connection even if closing the cursor raises, so the
        # database handle is never leaked.
        try:
            self.cur.close()
        finally:
            self.con.close()

    @property
    def cx(self):
        """Alias for ``self.con``."""
        return self.con

    @property
    def cu(self):
        """Alias for ``self.cur``."""
        return self.cur
78