• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import sys
2import test.support
3import unittest
4from contextlib import closing
5from functools import partial
6from pathlib import Path
7from test.support import cpython_only, import_helper, os_helper
8
9dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
10# N.B. The test will fail on some platforms without sqlite3
11# if the sqlite3 import is above the import of dbm.sqlite3.
12# This is deliberate: if the import helper managed to import dbm.sqlite3,
13# we must inevitably be able to import sqlite3. Else, we have a problem.
14import sqlite3
15from dbm.sqlite3 import _normalize_uri
16
17
class _SQLiteDbmTests(unittest.TestCase):
    """Shared fixture for dbm.sqlite3 tests.

    Each test starts with a database file at ``self.filename`` that has been
    created (flag "c") and closed again.  After the test the file and the
    "-wal"/"-shm" companion files SQLite may leave behind are removed.
    """

    def setUp(self):
        # Create the file up front and release the handle immediately;
        # subclasses reopen it with whatever flag they need.
        self.filename = os_helper.TESTFN
        dbm_sqlite3.open(self.filename, "c").close()

    def tearDown(self):
        base = self.filename
        for path in (base, f"{base}-wal", f"{base}-shm"):
            os_helper.unlink(path)
28
29
30class URI(unittest.TestCase):
31
32    def test_uri_substitutions(self):
33        dataset = (
34            ("/absolute/////b/c", "/absolute/b/c"),
35            ("PRE#MID##END", "PRE%23MID%23%23END"),
36            ("%#?%%#", "%25%23%3F%25%25%23"),
37        )
38        for path, normalized in dataset:
39            with self.subTest(path=path, normalized=normalized):
40                self.assertTrue(_normalize_uri(path).endswith(normalized))
41
42    @unittest.skipUnless(sys.platform == "win32", "requires Windows")
43    def test_uri_windows(self):
44        dataset = (
45            # Relative subdir.
46            (r"2018\January.xlsx",
47             "2018/January.xlsx"),
48            # Absolute with drive letter.
49            (r"C:\Projects\apilibrary\apilibrary.sln",
50             "/C:/Projects/apilibrary/apilibrary.sln"),
51            # Relative with drive letter.
52            (r"C:Projects\apilibrary\apilibrary.sln",
53             "/C:Projects/apilibrary/apilibrary.sln"),
54        )
55        for path, normalized in dataset:
56            with self.subTest(path=path, normalized=normalized):
57                if not Path(path).is_absolute():
58                    self.skipTest(f"skipping relative path: {path!r}")
59                self.assertTrue(_normalize_uri(path).endswith(normalized))
60
61
class ReadOnly(_SQLiteDbmTests):
    """Behavior of a database opened with flag "r"."""

    def setUp(self):
        super().setUp()
        # Populate through a writable handle, then reopen read-only.
        with dbm_sqlite3.open(self.filename, "w") as writer:
            writer[b"key1"] = "value1"
            writer[b"key2"] = "value2"
        self.db = dbm_sqlite3.open(self.filename, "r")

    def tearDown(self):
        self.db.close()
        super().tearDown()

    def test_readonly_read(self):
        # str values written above come back as bytes.
        expected = {b"key1": b"value1", b"key2": b"value2"}
        for key, value in expected.items():
            self.assertEqual(self.db[key], value)

    def test_readonly_write(self):
        with self.assertRaises(dbm_sqlite3.error):
            self.db[b"new"] = "value"

    def test_readonly_delete(self):
        with self.assertRaises(dbm_sqlite3.error):
            del self.db[b"key1"]

    def test_readonly_keys(self):
        self.assertEqual(self.db.keys(), [b"key1", b"key2"])

    def test_readonly_iter(self):
        seen = [key for key in self.db]
        self.assertEqual(seen, [b"key1", b"key2"])
92
93
class ReadWrite(_SQLiteDbmTests):
    """Behavior of a database opened with flag "w"."""

    def setUp(self):
        super().setUp()
        self.db = dbm_sqlite3.open(self.filename, "w")

    def tearDown(self):
        self.db.close()
        super().tearDown()

    def db_content(self):
        """Return ``(keys, values)`` read directly from the ``Dict`` table.

        Both lists come from one query, so ``keys[i]`` and ``values[i]``
        always belong to the same row.  Two independent SELECTs without
        ORDER BY are not guaranteed to visit rows in the same order (for
        example, the key-only query may be answered from an index on
        ``key``).
        """
        with closing(sqlite3.connect(self.filename)) as cx:
            rows = cx.execute("SELECT key, value FROM Dict").fetchall()
        keys = [key for key, _ in rows]
        vals = [value for _, value in rows]
        return keys, vals

    def test_readwrite_unique_key(self):
        # Assigning an existing key replaces the row instead of adding one.
        self.db["key"] = "value"
        self.db["key"] = "other"
        keys, vals = self.db_content()
        self.assertEqual(keys, [b"key"])
        self.assertEqual(vals, [b"other"])

    def test_readwrite_delete(self):
        self.db["key"] = "value"
        self.db["new"] = "other"

        del self.db[b"new"]
        keys, vals = self.db_content()
        self.assertEqual(keys, [b"key"])
        self.assertEqual(vals, [b"value"])

        del self.db[b"key"]
        keys, vals = self.db_content()
        self.assertEqual(keys, [])
        self.assertEqual(vals, [])

    def test_readwrite_null_key(self):
        with self.assertRaises(dbm_sqlite3.error):
            self.db[None] = "value"

    def test_readwrite_null_value(self):
        with self.assertRaises(dbm_sqlite3.error):
            self.db[b"key"] = None
138
139
class Misuse(_SQLiteDbmTests):
    """Incorrect or unusual use of the dbm.sqlite3 API."""

    def setUp(self):
        super().setUp()
        self.db = dbm_sqlite3.open(self.filename, "w")

    def tearDown(self):
        self.db.close()
        super().tearDown()

    def test_misuse_double_create(self):
        # Reopening an existing database with "c" keeps its contents.
        self.db["key"] = "value"
        with dbm_sqlite3.open(self.filename, "c") as db:
            self.assertEqual(db[b"key"], b"value")

    def test_misuse_double_close(self):
        # tearDown closes self.db a second time; that must not raise.
        self.db.close()

    def test_misuse_invalid_flag(self):
        regex = "must be.*'r'.*'w'.*'c'.*'n', not 'invalid'"
        with self.assertRaisesRegex(ValueError, regex):
            dbm_sqlite3.open(self.filename, flag="invalid")

    def test_misuse_double_delete(self):
        self.db["key"] = "value"
        del self.db[b"key"]
        with self.assertRaises(KeyError):
            del self.db[b"key"]

    def test_misuse_invalid_key(self):
        with self.assertRaises(KeyError):
            self.db[b"key"]

    def test_misuse_iter_close1(self):
        # An iterator created but not yet advanced fails after close().
        self.db["1"] = 1
        it = iter(self.db)
        self.db.close()
        with self.assertRaises(dbm_sqlite3.error):
            next(it)

    def test_misuse_iter_close2(self):
        # A partially consumed iterator also fails after close().
        self.db["1"] = 1
        self.db["2"] = 2
        it = iter(self.db)
        next(it)
        self.db.close()
        with self.assertRaises(dbm_sqlite3.error):
            next(it)

    def test_misuse_use_after_close(self):
        self.db.close()
        with self.assertRaises(dbm_sqlite3.error):
            self.db[b"read"]
        with self.assertRaises(dbm_sqlite3.error):
            self.db[b"write"] = "value"
        with self.assertRaises(dbm_sqlite3.error):
            del self.db[b"del"]
        with self.assertRaises(dbm_sqlite3.error):
            len(self.db)
        with self.assertRaises(dbm_sqlite3.error):
            self.db.keys()

    def test_misuse_reinit(self):
        with self.assertRaises(dbm_sqlite3.error):
            self.db.__init__("new.db", flag="n", mode=0o666)

    def test_misuse_empty_filename(self):
        # Pass the loop's flag so every open mode is actually exercised;
        # a hard-coded flag would just repeat the same check four times.
        for flag in "r", "w", "c", "n":
            with self.subTest(flag=flag):
                with self.assertRaises(dbm_sqlite3.error):
                    dbm_sqlite3.open("", flag=flag)
210
211
class DataTypes(_SQLiteDbmTests):
    """Keys and values of several Python types are stored as bytes."""

    dataset = (
        # (value as written, bytes read back)
        (42, b"42"),
        (3.14, b"3.14"),
        ("string", b"string"),
        (b"bytes", b"bytes"),
    )

    def setUp(self):
        super().setUp()
        self.db = dbm_sqlite3.open(self.filename, "w")

    def tearDown(self):
        self.db.close()
        super().tearDown()

    def test_datatypes_values(self):
        db = self.db
        for raw, coerced in self.dataset:
            with self.subTest(raw=raw, coerced=coerced):
                db["key"] = raw
                self.assertEqual(db[b"key"], coerced)

    def test_datatypes_keys(self):
        db = self.db
        for raw, coerced in self.dataset:
            with self.subTest(raw=raw, coerced=coerced):
                db[raw] = "value"
                self.assertEqual(db[coerced], b"value")
                # Lookups with the original, non-bytes key are coerced too.
                self.assertEqual(db[raw], b"value")
                del db[raw]

    def test_datatypes_replace_coerced(self):
        # "10", b"10" and 10 all coerce to the same stored key.
        for key in ("10", b"10", 10):
            self.db[key] = "value"
        self.assertEqual(self.db.keys(), [b"10"])
250
251
class CorruptDatabase(_SQLiteDbmTests):
    """Verify that database exceptions are raised as dbm.sqlite3.error."""

    def setUp(self):
        super().setUp()
        # Replace the Dict table with one whose schema dbm.sqlite3 does not
        # expect, so that every operation on it fails inside SQLite.
        with closing(sqlite3.connect(self.filename)) as cx:
            with cx:
                cx.execute("DROP TABLE IF EXISTS Dict")
                cx.execute("CREATE TABLE Dict (invalid_schema)")

    def check(self, flag, fn, should_succeed=False):
        """Open the database with *flag* and call ``fn(db)``.

        By default ``fn`` must raise dbm.sqlite3.error.  With
        ``should_succeed=True`` it must instead complete without raising.
        The database is closed afterwards in either case.
        """
        with closing(dbm_sqlite3.open(self.filename, flag)) as db:
            if should_succeed:
                fn(db)
            else:
                with self.assertRaises(dbm_sqlite3.error):
                    fn(db)

    # Operations handed to check().  Inside these static methods the names
    # iter/len resolve to the builtins, not to the class attributes below,
    # because class scope is not searched from function bodies.

    @staticmethod
    def read(db):
        return db["key"]

    @staticmethod
    def write(db):
        db["key"] = "value"

    @staticmethod
    def iter(db):
        next(iter(db))

    @staticmethod
    def keys(db):
        db.keys()

    @staticmethod
    def del_(db):
        del db["key"]

    @staticmethod
    def len_(db):
        len(db)

    def test_corrupt_readwrite(self):
        for flag in "r", "w", "c":
            with self.subTest(flag=flag):
                check = partial(self.check, flag=flag)
                check(fn=self.read)
                check(fn=self.write)
                check(fn=self.iter)
                check(fn=self.keys)
                check(fn=self.del_)
                check(fn=self.len_)

    def test_corrupt_force_new(self):
        # Flag "n" starts from a fresh database, so the corrupt table from
        # setUp no longer matters and normal operations succeed.
        with closing(dbm_sqlite3.open(self.filename, "n")) as db:
            db["foo"] = "write"
            _ = db[b"foo"]
            next(iter(db))
            del db[b"foo"]
308
309
if __name__ == "__main__":
    # Running the module directly executes every test case defined above.
    unittest.main()
312