# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import os
import sys
import unittest
import sqlite3 as sqlite

try:
    from test.support.os_helper import TESTFN, unlink
except ImportError:
    # CPython's regression-test helpers are not shipped with every Python
    # installation; fall back to minimal stand-ins so the module still runs.
    TESTFN = "@test_{}_tmp".format(os.getpid())

    def unlink(filename):
        """Remove *filename*; a file that is already gone is not an error."""
        try:
            os.unlink(filename)
        except FileNotFoundError:
            pass


def _connect(testcase, database=":memory:", **kwargs):
    """Open a SQLite connection that is closed when *testcase* finishes.

    The close is registered with ``addCleanup``, so it runs even when the
    test fails, and runs before any cleanup that was registered earlier
    (cleanups are executed last-in, first-out).
    """
    con = sqlite.connect(database, **kwargs)
    testcase.addCleanup(con.close)
    return con


def _ascending(x, y):
    """Three-way comparison: negative, zero or positive like C's strcmp."""
    return (x > y) - (x < y)


def _descending(x, y):
    """Three-way comparison that sorts in reverse order."""
    return -_ascending(x, y)


class CollationTests(unittest.TestCase):
    """Tests for Connection.create_collation()."""

    def test_create_collation_not_string(self):
        """A collation name that is not a str is rejected with TypeError."""
        con = _connect(self)
        with self.assertRaises(TypeError):
            con.create_collation(None, _ascending)

    def test_create_collation_not_callable(self):
        """A collation that is neither callable nor None is rejected."""
        con = _connect(self)
        with self.assertRaises(TypeError) as cm:
            con.create_collation("X", 42)
        self.assertEqual(str(cm.exception), 'parameter must be callable')

    def test_create_collation_not_ascii(self):
        """Check the handling of a collation name outside ASCII.

        Before Python 3.11 only ASCII names were allowed and anything else
        raised ProgrammingError; from 3.11 on any Unicode name is accepted.
        """
        con = _connect(self)
        if sys.version_info >= (3, 11):
            # Must simply succeed.
            con.create_collation("collä", _ascending)
        else:
            with self.assertRaises(sqlite.ProgrammingError):
                con.create_collation("collä", _ascending)

    def test_create_collation_bad_upper(self):
        """A str subclass with a broken upper() still registers correctly."""
        class BadUpperStr(str):
            def upper(self):
                return None

        con = _connect(self)
        con.create_collation(BadUpperStr("mycoll"), _descending)
        result = con.execute("""
            select x from (
            select 'a' as x
            union
            select 'b' as x
            ) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def test_collation_is_used(self):
        """A registered collation orders rows; once removed it is unknown."""
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        con = _connect(self)
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg='the expected order was not returned')

        # Passing None removes the collation again.
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            result = con.execute(sql).fetchall()
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')

    def test_collation_returns_large_integer(self):
        """A result scaled by 2**32 must still order rows by its sign."""
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y)) * 2**32

        con = _connect(self)
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg="the expected order was not returned")

    def test_collation_register_twice(self):
        """
        Register two different collation functions under the same name.
        Verify that the last one is actually used.
        """
        con = _connect(self)
        con.create_collation("mycoll", _ascending)
        con.create_collation("mycoll", _descending)
        result = con.execute("""
            select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def test_deregister_collation(self):
        """
        Register a collation, then deregister it. Make sure an error is raised if we try
        to use it.
        """
        con = _connect(self)
        con.create_collation("mycoll", _ascending)
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')


class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def test_progress_handler_used(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = _connect(self)
        progress_calls = []

        def progress():
            progress_calls.append(None)
            return 0

        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def test_opcode_count(self):
        """
        Test that the opcode argument is respected.
        """
        con = _connect(self)
        progress_calls = []

        def progress():
            progress_calls.append(None)
            return 0

        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)

        # Start counting afresh for the second, coarser-grained handler.
        progress_calls.clear()
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        self.assertGreaterEqual(first_count, second_count)

    def test_cancel_operation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = _connect(self)

        def progress():
            return 1

        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        with self.assertRaises(sqlite.OperationalError):
            curs.execute("create table bar (a, b)")

    def test_clear_handler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = _connect(self)
        action = 0

        def progress():
            nonlocal action
            action = 1
            return 0

        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")


class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    def test_trace_callback_used(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        con = _connect(self)
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        con.set_trace_callback(trace)
        con.execute("create table foo(a, b)")
        self.assertTrue(traced_statements)
        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def test_clear_trace_callback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        con = _connect(self)
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        con.set_trace_callback(trace)
        con.set_trace_callback(None)
        con.execute("create table foo(a, b)")
        self.assertFalse(traced_statements, "trace callback was not cleared")

    def test_unicode_content(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        con = _connect(self)
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        con.set_trace_callback(trace)
        con.execute("create table foo(x)")
        # The value is placed in the SQL text itself (not bound as a
        # parameter) because the assertion looks for it in the traced text.
        con.execute("insert into foo(x) values ('%s')" % unicode_value)
        con.commit()
        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                        "Unicode data %s garbled in trace callback: %s"
                        % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    def test_trace_callback_content(self):
        """set_trace_callback() shouldn't produce duplicate content (bpo-26187)."""
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        # Registered first so that it runs last: both connections are closed
        # before the database file is removed.
        self.addCleanup(unlink, TESTFN)
        con1 = _connect(self, TESTFN, isolation_level=None)
        con2 = _connect(self, TESTFN)
        con1.set_trace_callback(trace)
        cur = con1.cursor()
        cur.execute(queries[0])
        # Executed on the second connection, which has no trace callback.
        con2.execute("create table bar(x)")
        cur.execute(queries[1])
        self.assertEqual(traced_statements, queries)


def suite():
    """Return a TestSuite containing every test case class of this module."""
    tests = [
        CollationTests,
        ProgressTests,
        TraceCallbackTests,
    ]
    return unittest.TestSuite(
        [unittest.TestLoader().loadTestsFromTestCase(t) for t in tests]
    )


def test():
    """Run the module's test suite with a plain text runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()