1# pysqlite2/test/factory.py: tests for the various factories in pysqlite 2# 3# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> 4# 5# This file is part of pysqlite. 6# 7# This software is provided 'as-is', without any express or implied 8# warranty. In no event will the authors be held liable for any damages 9# arising from the use of this software. 10# 11# Permission is granted to anyone to use this software for any purpose, 12# including commercial applications, and to alter it and redistribute it 13# freely, subject to the following restrictions: 14# 15# 1. The origin of this software must not be misrepresented; you must not 16# claim that you wrote the original software. If you use this software 17# in a product, an acknowledgment in the product documentation would be 18# appreciated but is not required. 19# 2. Altered source versions must be plainly marked as such, and must not be 20# misrepresented as being the original software. 21# 3. This notice may not be removed or altered from any source distribution. 
"""Tests for the connection, cursor, row and text factories of sqlite3."""

import os
import sys
import unittest
import sqlite3 as sqlite
from collections.abc import Sequence


class MyConnection(sqlite.Connection):
    """Trivial Connection subclass used to exercise ``connect(factory=...)``."""

    def __init__(self, *args, **kwargs):
        sqlite.Connection.__init__(self, *args, **kwargs)


def dict_factory(cursor, row):
    """Row factory that maps each column name in ``cursor.description`` to its value.

    Columns are read in description order, so when two columns share a name
    the later column's value wins.
    """
    return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}


class MyCursor(sqlite.Cursor):
    """Cursor subclass whose rows are produced by :func:`dict_factory`."""

    def __init__(self, *args, **kwargs):
        sqlite.Cursor.__init__(self, *args, **kwargs)
        self.row_factory = dict_factory


class ConnectionFactoryTests(unittest.TestCase):
    """``sqlite.connect(factory=...)`` must return an instance of the factory."""

    def setUp(self):
        self.con = sqlite.connect(":memory:", factory=MyConnection)

    def tearDown(self):
        self.con.close()

    def test_is_instance(self):
        self.assertIsInstance(self.con, MyConnection)


class CursorFactoryTests(unittest.TestCase):
    """``Connection.cursor(factory)`` accepts cursor classes and callables."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def test_is_instance(self):
        cur = self.con.cursor()
        self.assertIsInstance(cur, sqlite.Cursor)
        cur = self.con.cursor(MyCursor)
        self.assertIsInstance(cur, MyCursor)
        cur = self.con.cursor(factory=lambda con: MyCursor(con))
        self.assertIsInstance(cur, MyCursor)

    def test_invalid_factory(self):
        # not a callable at all
        self.assertRaises(TypeError, self.con.cursor, None)
        # invalid callable that does not take exactly one argument
        self.assertRaises(TypeError, self.con.cursor, lambda: None)
        # invalid callable returning a non-cursor
        self.assertRaises(TypeError, self.con.cursor, lambda con: None)


class RowFactoryTestsBackwardsCompat(unittest.TestCase):
    """A row_factory set on a cursor subclass is honoured by that cursor."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def test_is_produced_by_factory(self):
        cur = self.con.cursor(factory=MyCursor)
        cur.execute("select 4+5 as foo")
        row = cur.fetchone()
        self.assertIsInstance(row, dict)
        cur.close()

    def tearDown(self):
        self.con.close()


class RowFactoryTests(unittest.TestCase):
    """Behaviour of custom row factories and of ``sqlite.Row``."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def test_custom_factory(self):
        self.con.row_factory = lambda cur, row: list(row)
        row = self.con.execute("select 1, 2").fetchone()
        self.assertIsInstance(row, list)

    def test_sqlite_row_index(self):
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a_1, 2 as b").fetchone()
        self.assertIsInstance(row, sqlite.Row)

        self.assertEqual(row["a_1"], 1, "by name: wrong result for column 'a_1'")
        self.assertEqual(row["b"], 2, "by name: wrong result for column 'b'")

        # Name lookup is case-insensitive.
        self.assertEqual(row["A_1"], 1, "by name: wrong result for column 'A_1'")
        self.assertEqual(row["B"], 2, "by name: wrong result for column 'B'")

        self.assertEqual(row[0], 1, "by index: wrong result for column 0")
        self.assertEqual(row[1], 2, "by index: wrong result for column 1")
        self.assertEqual(row[-1], 2, "by index: wrong result for column -1")
        self.assertEqual(row[-2], 1, "by index: wrong result for column -2")

        with self.assertRaises(IndexError):
            row['c']
        with self.assertRaises(IndexError):
            row['a_\x11']
        with self.assertRaises(IndexError):
            row['a\x7f1']
        with self.assertRaises(IndexError):
            row[2]
        with self.assertRaises(IndexError):
            row[-3]
        with self.assertRaises(IndexError):
            row[2**1000]

    def test_sqlite_row_index_unicode(self):
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as \xff").fetchone()
        self.assertEqual(row["\xff"], 1)
        # Case folding only applies to ASCII letters.
        with self.assertRaises(IndexError):
            row['\u0178']
        with self.assertRaises(IndexError):
            row['\xdf']

    def test_sqlite_row_slice(self):
        # A sqlite.Row can be sliced like a list.
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1, 2, 3, 4").fetchone()
        self.assertEqual(row[0:0], ())
        self.assertEqual(row[0:1], (1,))
        self.assertEqual(row[1:3], (2, 3))
        self.assertEqual(row[3:1], ())
        # Explicit bounds are optional.
        self.assertEqual(row[1:], (2, 3, 4))
        self.assertEqual(row[:3], (1, 2, 3))
        # Slices can use negative indices.
        self.assertEqual(row[-2:-1], (3,))
        self.assertEqual(row[-2:], (3, 4))
        # Slicing supports steps.
        self.assertEqual(row[0:4:2], (1, 3))
        self.assertEqual(row[3:0:-2], (4, 2))

    def test_sqlite_row_iter(self):
        """Checks if the row object is iterable"""
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        # Iterating must yield the column values in select order.
        self.assertEqual([col for col in row], [1, 2])

    def test_sqlite_row_as_tuple(self):
        """Checks if the row object can be converted to a tuple"""
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        t = tuple(row)
        self.assertEqual(t, (row['a'], row['b']))

    def test_sqlite_row_as_dict(self):
        """Checks if the row object can be correctly converted to a dictionary"""
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        d = dict(row)
        self.assertEqual(d["a"], row["a"])
        self.assertEqual(d["b"], row["b"])

    def test_sqlite_row_hash_cmp(self):
        """Checks if the row object compares and hashes correctly"""
        self.con.row_factory = sqlite.Row
        row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
        row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
        row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()
        row_4 = self.con.execute("select 1 as b, 2 as a").fetchone()
        row_5 = self.con.execute("select 2 as b, 1 as a").fetchone()

        self.assertTrue(row_1 == row_1)
        self.assertTrue(row_1 == row_2)
        self.assertFalse(row_1 == row_3)
        self.assertFalse(row_1 == row_4)
        self.assertFalse(row_1 == row_5)
        self.assertFalse(row_1 == object())

        self.assertFalse(row_1 != row_1)
        self.assertFalse(row_1 != row_2)
        self.assertTrue(row_1 != row_3)
        self.assertTrue(row_1 != row_4)
        self.assertTrue(row_1 != row_5)
        self.assertTrue(row_1 != object())

        # Rows support equality only, not ordering.
        with self.assertRaises(TypeError):
            row_1 > row_2
        with self.assertRaises(TypeError):
            row_1 < row_2
        with self.assertRaises(TypeError):
            row_1 >= row_2
        with self.assertRaises(TypeError):
            row_1 <= row_2

        self.assertEqual(hash(row_1), hash(row_2))

    def test_sqlite_row_as_sequence(self):
        """ Checks if the row object can act like a sequence """
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()

        as_tuple = tuple(row)
        self.assertEqual(list(reversed(row)), list(reversed(as_tuple)))
        self.assertIsInstance(row, Sequence)

    def test_fake_cursor_class(self):
        # Issue #24257: Incorrect use of PyObject_IsInstance() caused
        # segmentation fault.
        # Issue #27861: Also applies for cursor factory.
        class FakeCursor(str):
            __class__ = sqlite.Cursor
        self.con.row_factory = sqlite.Row
        self.assertRaises(TypeError, self.con.cursor, FakeCursor)
        self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())

    def tearDown(self):
        self.con.close()


class TextFactoryTests(unittest.TestCase):
    """Behaviour of ``Connection.text_factory`` for TEXT results."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def test_unicode(self):
        austria = "Österreich"
        row = self.con.execute("select ?", (austria,)).fetchone()
        self.assertIs(type(row[0]), str, "type of row[0] must be unicode")

    def test_string(self):
        self.con.text_factory = bytes
        austria = "Österreich"
        row = self.con.execute("select ?", (austria,)).fetchone()
        self.assertIs(type(row[0]), bytes, "type of row[0] must be bytes")
        self.assertEqual(row[0], austria.encode("utf-8"), "column must equal original data in UTF-8")

    def test_custom(self):
        self.con.text_factory = lambda x: str(x, "utf-8", "ignore")
        austria = "Österreich"
        row = self.con.execute("select ?", (austria,)).fetchone()
        self.assertIs(type(row[0]), str, "type of row[0] must be unicode")
        self.assertTrue(row[0].endswith("reich"), "column must contain original data")

    # OptimizedUnicode only emits a DeprecationWarning on 3.10 and 3.11: before
    # 3.10 it was not deprecated, and 3.12 removed it (AttributeError).
    @unittest.skipUnless((3, 10) <= sys.version_info < (3, 12),
                         "sqlite3.OptimizedUnicode is deprecated only in Python 3.10-3.11")
    def test_optimized_unicode(self):
        # OptimizedUnicode is deprecated as of Python 3.10
        with self.assertWarns(DeprecationWarning) as cm:
            self.con.text_factory = sqlite.OptimizedUnicode
        # The warning must point at this module (the code that accessed the
        # attribute). Compare against __file__ rather than a hard-coded name so
        # the check survives the module being renamed.
        self.assertEqual(os.path.basename(cm.filename), os.path.basename(__file__))
        austria = "Österreich"
        germany = "Deutchland"
        a_row = self.con.execute("select ?", (austria,)).fetchone()
        d_row = self.con.execute("select ?", (germany,)).fetchone()
        self.assertIs(type(a_row[0]), str, "type of non-ASCII row must be str")
        self.assertIs(type(d_row[0]), str, "type of ASCII-only row must be str")

    def tearDown(self):
        self.con.close()


class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
    """TEXT values containing NUL bytes must round-trip through text factories."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.execute("create table test (value text)")
        self.con.execute("insert into test (value) values (?)", ("a\x00b",))

    def test_string(self):
        # text_factory defaults to str
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), str)
        self.assertEqual(row[0], "a\x00b")

    def test_bytes(self):
        self.con.text_factory = bytes
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), bytes)
        self.assertEqual(row[0], b"a\x00b")

    def test_bytearray(self):
        self.con.text_factory = bytearray
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), bytearray)
        self.assertEqual(row[0], b"a\x00b")

    def test_custom(self):
        # A custom factory should receive a bytes argument
        self.con.text_factory = lambda x: x
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), bytes)
        self.assertEqual(row[0], b"a\x00b")

    def tearDown(self):
        self.con.close()


def suite():
    """Return a TestSuite containing every test case class in this module."""
    tests = [
        ConnectionFactoryTests,
        CursorFactoryTests,
        RowFactoryTests,
        RowFactoryTestsBackwardsCompat,
        TextFactoryTests,
        TextFactoryTestsWithEmbeddedZeroBytes,
    ]
    return unittest.TestSuite(
        [unittest.TestLoader().loadTestsFromTestCase(t) for t in tests]
    )


def test():
    """Run :func:`suite` with a text runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()