#-*- coding: iso-8859-1 -*-
# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import unittest
import sqlite3 as sqlite

try:
    # Python 3.10+ keeps the filesystem helpers in test.support.os_helper;
    # older versions only export them from test.support itself.
    from test.support.os_helper import TESTFN, unlink
except ImportError:
    from test.support import TESTFN, unlink


class CollationTests(unittest.TestCase):
    """Tests for Connection.create_collation()."""

    def CheckCreateCollationNotString(self):
        con = sqlite.connect(":memory:")
        with self.assertRaises(TypeError):
            con.create_collation(None, lambda x, y: (x > y) - (x < y))

    def CheckCreateCollationNotCallable(self):
        con = sqlite.connect(":memory:")
        with self.assertRaises(TypeError) as cm:
            con.create_collation("X", 42)
        self.assertEqual(str(cm.exception), 'parameter must be callable')

    def CheckCreateCollationNotAscii(self):
        con = sqlite.connect(":memory:")
        with self.assertRaises(sqlite.ProgrammingError):
            # "\xe4" (a-umlaut) is written as an escape so the non-ASCII
            # name does not depend on the encoding this file is saved in.
            con.create_collation("coll\xe4", lambda x, y: (x > y) - (x < y))

    def CheckCreateCollationBadUpper(self):
        class BadUpperStr(str):
            def upper(self):
                return None
        con = sqlite.connect(":memory:")
        mycoll = lambda x, y: -((x > y) - (x < y))
        con.create_collation(BadUpperStr("mycoll"), mycoll)
        result = con.execute("""
            select x from (
            select 'a' as x
            union
            select 'b' as x
            ) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
                     'old SQLite versions crash on this test')
    def CheckCollationIsUsed(self):
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg='the expected order was not returned')

        # Passing None unregisters the collation; the same query must fail.
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute(sql).fetchall()
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')

    def CheckCollationReturnsLargeInteger(self):
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y)) * 2**32
        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg="the expected order was not returned")

    def CheckCollationRegisterTwice(self):
        """
        Register two different collation functions under the same name.
        Verify that the last one is actually used.
        """
        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
        result = con.execute("""
            select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def CheckDeregisterCollation(self):
        """
        Register a collation, then deregister it. Make sure an error is raised if we try
        to use it.
        """
        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')


class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def CheckProgressHandlerUsed(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = sqlite.connect(":memory:")
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def CheckOpcodeCount(self):
        """
        Test that the opcode argument is respected.
        """
        con = sqlite.connect(":memory:")
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)
        # Empty the list in place so the closure keeps recording into it.
        progress_calls.clear()
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        # Invoking the handler every 2 opcodes cannot produce more calls
        # than invoking it every opcode for an equivalent statement.
        self.assertGreaterEqual(first_count, second_count)

    def CheckCancelOperation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = sqlite.connect(":memory:")
        def progress():
            return 1
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        self.assertRaises(
            sqlite.OperationalError,
            curs.execute,
            "create table bar (a, b)")

    def CheckClearHandler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = sqlite.connect(":memory:")
        action = 0
        def progress():
            nonlocal action
            action = 1
            return 0
        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")


class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    def CheckTraceCallbackUsed(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        con = sqlite.connect(":memory:")
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(a, b)")
        self.assertTrue(traced_statements)
        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def CheckClearTraceCallback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        con = sqlite.connect(":memory:")
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.set_trace_callback(None)
        con.execute("create table foo(a, b)")
        self.assertFalse(traced_statements, "trace callback was not cleared")

    def CheckUnicodeContent(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        con = sqlite.connect(":memory:")
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(x)")
        # Can't execute bound parameters as their values don't appear
        # in traced statements before SQLite 3.6.21
        # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html)
        # The value is embedded as a single-quoted SQL string literal:
        # double quotes denote an identifier in SQL, and SQLite builds
        # compiled with SQLITE_DQS=0 refuse to treat them as strings.
        con.execute("insert into foo(x) values ('%s')" % unicode_value)
        con.commit()
        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                        "Unicode data %s garbled in trace callback: %s"
                        % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available")
    def CheckTraceCallbackContent(self):
        # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)

        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        self.addCleanup(unlink, TESTFN)
        # Cleanups run in LIFO order, so both connections are closed before
        # the database file is unlinked (an open file cannot be removed on
        # some platforms, e.g. Windows).
        con1 = sqlite.connect(TESTFN, isolation_level=None)
        self.addCleanup(con1.close)
        con2 = sqlite.connect(TESTFN)
        self.addCleanup(con2.close)
        con1.set_trace_callback(trace)
        cur = con1.cursor()
        cur.execute(queries[0])
        # A schema change from another connection forces con1 to re-prepare
        # its next statement; the re-prepare must not be traced twice.
        con2.execute("create table bar(x)")
        cur.execute(queries[1])
        self.assertEqual(traced_statements, queries)


def suite():
    """Return a TestSuite with every ``Check*`` method of the hook test cases."""
    # unittest.makeSuite() is deprecated since 3.11 and removed in 3.13; a
    # loader with a custom method prefix collects the same tests in the same
    # (default-sorted) order.
    loader = unittest.TestLoader()
    loader.testMethodPrefix = "Check"
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(case)
        for case in (CollationTests, ProgressTests, TraceCallbackTests)
    )


def test():
    """Run the hook test suite with a plain text runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()