# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
23 24import unittest 25import unittest.mock 26import gc 27import sqlite3 as sqlite 28 29def func_returntext(): 30 return "foo" 31def func_returntextwithnull(): 32 return "1\x002" 33def func_returnunicode(): 34 return "bar" 35def func_returnint(): 36 return 42 37def func_returnfloat(): 38 return 3.14 39def func_returnnull(): 40 return None 41def func_returnblob(): 42 return b"blob" 43def func_returnlonglong(): 44 return 1<<31 45def func_raiseexception(): 46 5/0 47 48def func_isstring(v): 49 return type(v) is str 50def func_isint(v): 51 return type(v) is int 52def func_isfloat(v): 53 return type(v) is float 54def func_isnone(v): 55 return type(v) is type(None) 56def func_isblob(v): 57 return isinstance(v, (bytes, memoryview)) 58def func_islonglong(v): 59 return isinstance(v, int) and v >= 1<<31 60 61def func(*args): 62 return len(args) 63 64class AggrNoStep: 65 def __init__(self): 66 pass 67 68 def finalize(self): 69 return 1 70 71class AggrNoFinalize: 72 def __init__(self): 73 pass 74 75 def step(self, x): 76 pass 77 78class AggrExceptionInInit: 79 def __init__(self): 80 5/0 81 82 def step(self, x): 83 pass 84 85 def finalize(self): 86 pass 87 88class AggrExceptionInStep: 89 def __init__(self): 90 pass 91 92 def step(self, x): 93 5/0 94 95 def finalize(self): 96 return 42 97 98class AggrExceptionInFinalize: 99 def __init__(self): 100 pass 101 102 def step(self, x): 103 pass 104 105 def finalize(self): 106 5/0 107 108class AggrCheckType: 109 def __init__(self): 110 self.val = None 111 112 def step(self, whichType, val): 113 theType = {"str": str, "int": int, "float": float, "None": type(None), 114 "blob": bytes} 115 self.val = int(theType[whichType] is type(val)) 116 117 def finalize(self): 118 return self.val 119 120class AggrCheckTypes: 121 def __init__(self): 122 self.val = 0 123 124 def step(self, whichType, *vals): 125 theType = {"str": str, "int": int, "float": float, "None": type(None), 126 "blob": bytes} 127 for val in vals: 128 self.val += 
int(theType[whichType] is type(val)) 129 130 def finalize(self): 131 return self.val 132 133class AggrSum: 134 def __init__(self): 135 self.val = 0.0 136 137 def step(self, val): 138 self.val += val 139 140 def finalize(self): 141 return self.val 142 143class AggrText: 144 def __init__(self): 145 self.txt = "" 146 def step(self, txt): 147 self.txt = self.txt + txt 148 def finalize(self): 149 return self.txt 150 151 152class FunctionTests(unittest.TestCase): 153 def setUp(self): 154 self.con = sqlite.connect(":memory:") 155 156 self.con.create_function("returntext", 0, func_returntext) 157 self.con.create_function("returntextwithnull", 0, func_returntextwithnull) 158 self.con.create_function("returnunicode", 0, func_returnunicode) 159 self.con.create_function("returnint", 0, func_returnint) 160 self.con.create_function("returnfloat", 0, func_returnfloat) 161 self.con.create_function("returnnull", 0, func_returnnull) 162 self.con.create_function("returnblob", 0, func_returnblob) 163 self.con.create_function("returnlonglong", 0, func_returnlonglong) 164 self.con.create_function("raiseexception", 0, func_raiseexception) 165 166 self.con.create_function("isstring", 1, func_isstring) 167 self.con.create_function("isint", 1, func_isint) 168 self.con.create_function("isfloat", 1, func_isfloat) 169 self.con.create_function("isnone", 1, func_isnone) 170 self.con.create_function("isblob", 1, func_isblob) 171 self.con.create_function("islonglong", 1, func_islonglong) 172 self.con.create_function("spam", -1, func) 173 self.con.execute("create table test(t text)") 174 175 def tearDown(self): 176 self.con.close() 177 178 def test_func_error_on_create(self): 179 with self.assertRaises(sqlite.OperationalError): 180 self.con.create_function("bla", -100, lambda x: 2*x) 181 182 def test_func_ref_count(self): 183 def getfunc(): 184 def f(): 185 return 1 186 return f 187 f = getfunc() 188 globals()["foo"] = f 189 # self.con.create_function("reftest", 0, getfunc()) 190 
self.con.create_function("reftest", 0, f) 191 cur = self.con.cursor() 192 cur.execute("select reftest()") 193 194 def test_func_return_text(self): 195 cur = self.con.cursor() 196 cur.execute("select returntext()") 197 val = cur.fetchone()[0] 198 self.assertEqual(type(val), str) 199 self.assertEqual(val, "foo") 200 201 def test_func_return_text_with_null_char(self): 202 cur = self.con.cursor() 203 res = cur.execute("select returntextwithnull()").fetchone()[0] 204 self.assertEqual(type(res), str) 205 self.assertEqual(res, "1\x002") 206 207 def test_func_return_unicode(self): 208 cur = self.con.cursor() 209 cur.execute("select returnunicode()") 210 val = cur.fetchone()[0] 211 self.assertEqual(type(val), str) 212 self.assertEqual(val, "bar") 213 214 def test_func_return_int(self): 215 cur = self.con.cursor() 216 cur.execute("select returnint()") 217 val = cur.fetchone()[0] 218 self.assertEqual(type(val), int) 219 self.assertEqual(val, 42) 220 221 def test_func_return_float(self): 222 cur = self.con.cursor() 223 cur.execute("select returnfloat()") 224 val = cur.fetchone()[0] 225 self.assertEqual(type(val), float) 226 if val < 3.139 or val > 3.141: 227 self.fail("wrong value") 228 229 def test_func_return_null(self): 230 cur = self.con.cursor() 231 cur.execute("select returnnull()") 232 val = cur.fetchone()[0] 233 self.assertEqual(type(val), type(None)) 234 self.assertEqual(val, None) 235 236 def test_func_return_blob(self): 237 cur = self.con.cursor() 238 cur.execute("select returnblob()") 239 val = cur.fetchone()[0] 240 self.assertEqual(type(val), bytes) 241 self.assertEqual(val, b"blob") 242 243 def test_func_return_long_long(self): 244 cur = self.con.cursor() 245 cur.execute("select returnlonglong()") 246 val = cur.fetchone()[0] 247 self.assertEqual(val, 1<<31) 248 249 def test_func_exception(self): 250 cur = self.con.cursor() 251 with self.assertRaises(sqlite.OperationalError) as cm: 252 cur.execute("select raiseexception()") 253 cur.fetchone() 254 
self.assertEqual(str(cm.exception), 'user-defined function raised exception') 255 256 def test_param_string(self): 257 cur = self.con.cursor() 258 for text in ["foo", str()]: 259 with self.subTest(text=text): 260 cur.execute("select isstring(?)", (text,)) 261 val = cur.fetchone()[0] 262 self.assertEqual(val, 1) 263 264 def test_param_int(self): 265 cur = self.con.cursor() 266 cur.execute("select isint(?)", (42,)) 267 val = cur.fetchone()[0] 268 self.assertEqual(val, 1) 269 270 def test_param_float(self): 271 cur = self.con.cursor() 272 cur.execute("select isfloat(?)", (3.14,)) 273 val = cur.fetchone()[0] 274 self.assertEqual(val, 1) 275 276 def test_param_none(self): 277 cur = self.con.cursor() 278 cur.execute("select isnone(?)", (None,)) 279 val = cur.fetchone()[0] 280 self.assertEqual(val, 1) 281 282 def test_param_blob(self): 283 cur = self.con.cursor() 284 cur.execute("select isblob(?)", (memoryview(b"blob"),)) 285 val = cur.fetchone()[0] 286 self.assertEqual(val, 1) 287 288 def test_param_long_long(self): 289 cur = self.con.cursor() 290 cur.execute("select islonglong(?)", (1<<42,)) 291 val = cur.fetchone()[0] 292 self.assertEqual(val, 1) 293 294 def test_any_arguments(self): 295 cur = self.con.cursor() 296 cur.execute("select spam(?, ?)", (1, 2)) 297 val = cur.fetchone()[0] 298 self.assertEqual(val, 2) 299 300 def test_empty_blob(self): 301 cur = self.con.execute("select isblob(x'')") 302 self.assertTrue(cur.fetchone()[0]) 303 304 # Regarding deterministic functions: 305 # 306 # Between 3.8.3 and 3.15.0, deterministic functions were only used to 307 # optimize inner loops, so for those versions we can only test if the 308 # sqlite machinery has factored out a call or not. From 3.15.0 and onward, 309 # deterministic functions were permitted in WHERE clauses of partial 310 # indices, which allows testing based on syntax, iso. the query optimizer. 
311 @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") 312 def test_func_non_deterministic(self): 313 mock = unittest.mock.Mock(return_value=None) 314 self.con.create_function("nondeterministic", 0, mock, deterministic=False) 315 if sqlite.sqlite_version_info < (3, 15, 0): 316 self.con.execute("select nondeterministic() = nondeterministic()") 317 self.assertEqual(mock.call_count, 2) 318 else: 319 with self.assertRaises(sqlite.OperationalError): 320 self.con.execute("create index t on test(t) where nondeterministic() is not null") 321 322 @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") 323 def test_func_deterministic(self): 324 mock = unittest.mock.Mock(return_value=None) 325 self.con.create_function("deterministic", 0, mock, deterministic=True) 326 if sqlite.sqlite_version_info < (3, 15, 0): 327 self.con.execute("select deterministic() = deterministic()") 328 self.assertEqual(mock.call_count, 1) 329 else: 330 try: 331 self.con.execute("create index t on test(t) where deterministic() is not null") 332 except sqlite.OperationalError: 333 self.fail("Unexpected failure while creating partial index") 334 335 @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed") 336 def test_func_deterministic_not_supported(self): 337 with self.assertRaises(sqlite.NotSupportedError): 338 self.con.create_function("deterministic", 0, int, deterministic=True) 339 340 def test_func_deterministic_keyword_only(self): 341 with self.assertRaises(TypeError): 342 self.con.create_function("deterministic", 0, int, True) 343 344 def test_function_destructor_via_gc(self): 345 # See bpo-44304: The destructor of the user function can 346 # crash if is called without the GIL from the gc functions 347 dest = sqlite.connect(':memory:') 348 def md5sum(t): 349 return 350 351 dest.create_function("md5", 1, md5sum) 352 x = dest("create table lang (name, first_appeared)") 353 del md5sum, dest 354 
355 y = [x] 356 y.append(y) 357 358 del x,y 359 gc.collect() 360 361class AggregateTests(unittest.TestCase): 362 def setUp(self): 363 self.con = sqlite.connect(":memory:") 364 cur = self.con.cursor() 365 cur.execute(""" 366 create table test( 367 t text, 368 i integer, 369 f float, 370 n, 371 b blob 372 ) 373 """) 374 cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", 375 ("foo", 5, 3.14, None, memoryview(b"blob"),)) 376 377 self.con.create_aggregate("nostep", 1, AggrNoStep) 378 self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) 379 self.con.create_aggregate("excInit", 1, AggrExceptionInInit) 380 self.con.create_aggregate("excStep", 1, AggrExceptionInStep) 381 self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize) 382 self.con.create_aggregate("checkType", 2, AggrCheckType) 383 self.con.create_aggregate("checkTypes", -1, AggrCheckTypes) 384 self.con.create_aggregate("mysum", 1, AggrSum) 385 self.con.create_aggregate("aggtxt", 1, AggrText) 386 387 def tearDown(self): 388 #self.cur.close() 389 #self.con.close() 390 pass 391 392 def test_aggr_error_on_create(self): 393 with self.assertRaises(sqlite.OperationalError): 394 self.con.create_function("bla", -100, AggrSum) 395 396 def test_aggr_no_step(self): 397 cur = self.con.cursor() 398 with self.assertRaises(AttributeError) as cm: 399 cur.execute("select nostep(t) from test") 400 self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'") 401 402 def test_aggr_no_finalize(self): 403 cur = self.con.cursor() 404 with self.assertRaises(sqlite.OperationalError) as cm: 405 cur.execute("select nofinalize(t) from test") 406 val = cur.fetchone()[0] 407 self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") 408 409 def test_aggr_exception_in_init(self): 410 cur = self.con.cursor() 411 with self.assertRaises(sqlite.OperationalError) as cm: 412 cur.execute("select excInit(t) from test") 413 val = cur.fetchone()[0] 414 
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") 415 416 def test_aggr_exception_in_step(self): 417 cur = self.con.cursor() 418 with self.assertRaises(sqlite.OperationalError) as cm: 419 cur.execute("select excStep(t) from test") 420 val = cur.fetchone()[0] 421 self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") 422 423 def test_aggr_exception_in_finalize(self): 424 cur = self.con.cursor() 425 with self.assertRaises(sqlite.OperationalError) as cm: 426 cur.execute("select excFinalize(t) from test") 427 val = cur.fetchone()[0] 428 self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") 429 430 def test_aggr_check_param_str(self): 431 cur = self.con.cursor() 432 cur.execute("select checkTypes('str', ?, ?)", ("foo", str())) 433 val = cur.fetchone()[0] 434 self.assertEqual(val, 2) 435 436 def test_aggr_check_param_int(self): 437 cur = self.con.cursor() 438 cur.execute("select checkType('int', ?)", (42,)) 439 val = cur.fetchone()[0] 440 self.assertEqual(val, 1) 441 442 def test_aggr_check_params_int(self): 443 cur = self.con.cursor() 444 cur.execute("select checkTypes('int', ?, ?)", (42, 24)) 445 val = cur.fetchone()[0] 446 self.assertEqual(val, 2) 447 448 def test_aggr_check_param_float(self): 449 cur = self.con.cursor() 450 cur.execute("select checkType('float', ?)", (3.14,)) 451 val = cur.fetchone()[0] 452 self.assertEqual(val, 1) 453 454 def test_aggr_check_param_none(self): 455 cur = self.con.cursor() 456 cur.execute("select checkType('None', ?)", (None,)) 457 val = cur.fetchone()[0] 458 self.assertEqual(val, 1) 459 460 def test_aggr_check_param_blob(self): 461 cur = self.con.cursor() 462 cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),)) 463 val = cur.fetchone()[0] 464 self.assertEqual(val, 1) 465 466 def test_aggr_check_aggr_sum(self): 467 cur = self.con.cursor() 468 cur.execute("delete from test") 469 cur.executemany("insert 
into test(i) values (?)", [(10,), (20,), (30,)]) 470 cur.execute("select mysum(i) from test") 471 val = cur.fetchone()[0] 472 self.assertEqual(val, 60) 473 474 def test_aggr_no_match(self): 475 cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0") 476 val = cur.fetchone()[0] 477 self.assertIsNone(val) 478 479 def test_aggr_text(self): 480 cur = self.con.cursor() 481 for txt in ["foo", "1\x002"]: 482 with self.subTest(txt=txt): 483 cur.execute("select aggtxt(?) from test", (txt,)) 484 val = cur.fetchone()[0] 485 self.assertEqual(val, txt) 486 487 488class AuthorizerTests(unittest.TestCase): 489 @staticmethod 490 def authorizer_cb(action, arg1, arg2, dbname, source): 491 if action != sqlite.SQLITE_SELECT: 492 return sqlite.SQLITE_DENY 493 if arg2 == 'c2' or arg1 == 't2': 494 return sqlite.SQLITE_DENY 495 return sqlite.SQLITE_OK 496 497 def setUp(self): 498 self.con = sqlite.connect(":memory:") 499 self.con.executescript(""" 500 create table t1 (c1, c2); 501 create table t2 (c1, c2); 502 insert into t1 (c1, c2) values (1, 2); 503 insert into t2 (c1, c2) values (4, 5); 504 """) 505 506 # For our security test: 507 self.con.execute("select c2 from t2") 508 509 self.con.set_authorizer(self.authorizer_cb) 510 511 def tearDown(self): 512 pass 513 514 def test_table_access(self): 515 with self.assertRaises(sqlite.DatabaseError) as cm: 516 self.con.execute("select * from t2") 517 self.assertIn('prohibited', str(cm.exception)) 518 519 def test_column_access(self): 520 with self.assertRaises(sqlite.DatabaseError) as cm: 521 self.con.execute("select c2 from t1") 522 self.assertIn('prohibited', str(cm.exception)) 523 524class AuthorizerRaiseExceptionTests(AuthorizerTests): 525 @staticmethod 526 def authorizer_cb(action, arg1, arg2, dbname, source): 527 if action != sqlite.SQLITE_SELECT: 528 raise ValueError 529 if arg2 == 'c2' or arg1 == 't2': 530 raise ValueError 531 return sqlite.SQLITE_OK 532 533class AuthorizerIllegalTypeTests(AuthorizerTests): 534 
@staticmethod 535 def authorizer_cb(action, arg1, arg2, dbname, source): 536 if action != sqlite.SQLITE_SELECT: 537 return 0.0 538 if arg2 == 'c2' or arg1 == 't2': 539 return 0.0 540 return sqlite.SQLITE_OK 541 542class AuthorizerLargeIntegerTests(AuthorizerTests): 543 @staticmethod 544 def authorizer_cb(action, arg1, arg2, dbname, source): 545 if action != sqlite.SQLITE_SELECT: 546 return 2**32 547 if arg2 == 'c2' or arg1 == 't2': 548 return 2**32 549 return sqlite.SQLITE_OK 550 551 552def suite(): 553 tests = [ 554 AggregateTests, 555 AuthorizerIllegalTypeTests, 556 AuthorizerLargeIntegerTests, 557 AuthorizerRaiseExceptionTests, 558 AuthorizerTests, 559 FunctionTests, 560 ] 561 return unittest.TestSuite( 562 [unittest.TestLoader().loadTestsFromTestCase(t) for t in tests] 563 ) 564 565def test(): 566 runner = unittest.TextTestRunner() 567 runner.run(suite()) 568 569if __name__ == "__main__": 570 test() 571