1# pysqlite2/test/userfunctions.py: tests for user-defined functions and 2# aggregates. 3# 4# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> 5# 6# This file is part of pysqlite. 7# 8# This software is provided 'as-is', without any express or implied 9# warranty. In no event will the authors be held liable for any damages 10# arising from the use of this software. 11# 12# Permission is granted to anyone to use this software for any purpose, 13# including commercial applications, and to alter it and redistribute it 14# freely, subject to the following restrictions: 15# 16# 1. The origin of this software must not be misrepresented; you must not 17# claim that you wrote the original software. If you use this software 18# in a product, an acknowledgment in the product documentation would be 19# appreciated but is not required. 20# 2. Altered source versions must be plainly marked as such, and must not be 21# misrepresented as being the original software. 22# 3. This notice may not be removed or altered from any source distribution. 
import sys
import unittest
import sqlite3 as sqlite

from unittest.mock import Mock, patch
from test.support import bigmemtest, gc_collect

from .util import cx_limit, memory_database
from .util import with_tracebacks


# Scalar callbacks registered by FunctionTests.setUp().  Each returns a fixed
# value of one Python type, or raises one specific exception, so the tests can
# check how that result is converted (or reported) on the SQLite side.
def func_returntext():
    return "foo"
def func_returntextwithnull():
    return "1\x002"
def func_returnunicode():
    return "bar"
def func_returnint():
    return 42
def func_returnfloat():
    return 3.14
def func_returnnull():
    return None
def func_returnblob():
    return b"blob"
def func_returnlonglong():
    return 1<<31
def func_raiseexception():
    5/0
def func_memoryerror():
    raise MemoryError
def func_overflowerror():
    raise OverflowError


# Aggregate classes registered by AggregateTests.setUp().  The first group is
# deliberately broken (missing methods or methods that raise) to exercise the
# error paths; the rest compute simple results.

class AggrNoStep:
    """Aggregate with no ``step`` method."""

    def __init__(self):
        pass

    def finalize(self):
        return 1

class AggrNoFinalize:
    """Aggregate with no ``finalize`` method."""

    def __init__(self):
        pass

    def step(self, x):
        pass

class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError."""

    def __init__(self):
        5/0

    def step(self, x):
        pass

    def finalize(self):
        pass

class AggrExceptionInStep:
    """Aggregate whose ``step`` raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        5/0

    def finalize(self):
        return 42

class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        pass

    def finalize(self):
        5/0

class AggrCheckType:
    """Report (1/0) whether the last *val* seen has the type named by *whichType*.

    *whichType* is one of "str", "int", "float", "None" or "blob".  Each call
    to ``step`` overwrites the result, so only the last row counts; with no
    rows ``finalize`` returns None.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        theType = {"str": str, "int": int, "float": float, "None": type(None),
                   "blob": bytes}
        self.val = int(theType[whichType] is type(val))

    def finalize(self):
        return self.val

class AggrCheckTypes:
    """Count the values, across all rows, whose type is named by *whichType*."""

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        theType = {"str": str, "int": int, "float": float, "None": type(None),
                   "blob": bytes}
        for val in vals:
            self.val += int(theType[whichType] is type(val))

    def finalize(self):
        return self.val

class AggrSum:
    """Sum the stepped values, starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val += val

    def finalize(self):
        return self.val

class AggrText:
    """Concatenate the stepped strings."""

    def __init__(self):
        self.txt = ""
    def step(self, txt):
        self.txt = self.txt + txt
    def finalize(self):
        return self.txt


class FunctionTests(unittest.TestCase):
    """Scalar user-defined functions registered with create_function()."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("returnnan", 0, lambda: float("nan"))
        self.con.create_function("return_noncont_blob", 0,
                                 lambda: memoryview(b"blob")[::2])
        self.con.create_function("raiseexception", 0, func_raiseexception)
        self.con.create_function("memoryerror", 0, func_memoryerror)
        self.con.create_function("overflowerror", 0, func_overflowerror)

        self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes))
        self.con.create_function("isnone", 1, lambda x: x is None)
        # narg=-1: accepts any number of arguments and returns how many it got.
        self.con.create_function("spam", -1, lambda *x: len(x))
        self.con.execute("create table test(t text)")

    def tearDown(self):
        self.con.close()

    def test_func_error_on_create(self):
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def test_func_too_many_args(self):
        category = sqlite.SQLITE_LIMIT_FUNCTION_ARG
        msg = "too many arguments on function"
        with cx_limit(self.con, category=category, limit=1):
            self.con.execute("select abs(-1)")
            with self.assertRaisesRegex(sqlite.OperationalError, msg):
                self.con.execute("select max(1, 2)")

    def test_func_ref_count(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        # Hold an extra strong reference via a module global while the
        # function is registered and invoked.
        globals()["foo"] = f
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def test_func_return_text(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def test_func_return_text_with_null_char(self):
        cur = self.con.cursor()
        res = cur.execute("select returntextwithnull()").fetchone()[0]
        self.assertEqual(type(res), str)
        self.assertEqual(res, "1\x002")

    def test_func_return_unicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def test_func_return_int(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def test_func_return_float(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        if val < 3.139 or val > 3.141:
            self.fail("wrong value")

    def test_func_return_null(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def test_func_return_blob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def test_func_return_long_long(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def test_func_return_nan(self):
        cur = self.con.cursor()
        cur.execute("select returnnan()")
        self.assertIsNone(cur.fetchone()[0])

    @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
    def test_func_exception(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    @with_tracebacks(MemoryError, name="func_memoryerror")
    def test_func_memory_error(self):
        cur = self.con.cursor()
        with self.assertRaises(MemoryError):
            cur.execute("select memoryerror()")
            cur.fetchone()

    @with_tracebacks(OverflowError, name="func_overflowerror")
    def test_func_overflow_error(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.DataError):
            cur.execute("select overflowerror()")
            cur.fetchone()

    def test_any_arguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_empty_blob(self):
        cur = self.con.execute("select isblob(x'')")
        self.assertTrue(cur.fetchone()[0])

    def test_nan_float(self):
        cur = self.con.execute("select isnone(?)", (float("nan"),))
        # SQLite has no concept of nan; it is converted to NULL
        self.assertTrue(cur.fetchone()[0])

    def test_too_large_int(self):
        err = "Python int too large to convert to SQLite INTEGER"
        self.assertRaisesRegex(OverflowError, err, self.con.execute,
                               "select spam(?)", (1 << 65,))

    def test_non_contiguous_blob(self):
        self.assertRaisesRegex(BufferError,
                               "underlying buffer is not C-contiguous",
                               self.con.execute, "select spam(?)",
                               (memoryview(b"blob")[::2],))

    @with_tracebacks(BufferError, regex="buffer.*contiguous")
    def test_return_non_contiguous_blob(self):
        with self.assertRaises(sqlite.OperationalError):
            cur = self.con.execute("select return_noncont_blob()")
            cur.fetchone()

    def test_param_surrogates(self):
        self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed",
                               self.con.execute, "select spam(?)",
                               ("\ud803\ude6d",))

    def test_func_params(self):
        results = []
        def append_result(arg):
            results.append((arg, type(arg)))
        self.con.create_function("test_params", 1, append_result)

        dataset = [
            (42, int),
            (-1, int),
            (1234567890123456789, int),
            (4611686018427387905, int),  # 63-bit int with non-zero low bits
            (3.14, float),
            (float('inf'), float),
            ("text", str),
            ("1\x002", str),
            ("\u02e2q\u02e1\u2071\u1d57\u1d49", str),
            (b"blob", bytes),
            (bytearray(range(2)), bytes),
            (memoryview(b"blob"), bytes),
            (None, type(None)),
        ]
        for val, _ in dataset:
            cur = self.con.execute("select test_params(?)", (val,))
            cur.fetchone()
        self.assertEqual(dataset, results)

    # Regarding deterministic functions:
    #
    # Between SQLite 3.8.3 and 3.15.0, deterministic functions were only used
    # to optimize inner loops.  From SQLite 3.15.0 onward, deterministic
    # functions are permitted in WHERE clauses of partial indices, which lets
    # us test determinism via syntax instead of relying on the query optimizer.
    def test_func_non_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("nondeterministic", 0, mock, deterministic=False)
        with self.assertRaises(sqlite.OperationalError):
            self.con.execute("create index t on test(t) where nondeterministic() is not null")

    def test_func_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=True)
        try:
            self.con.execute("create index t on test(t) where deterministic() is not null")
        except sqlite.OperationalError:
            self.fail("Unexpected failure while creating partial index")

    def test_func_deterministic_keyword_only(self):
        with self.assertRaises(TypeError):
            self.con.create_function("deterministic", 0, int, True)

    def test_function_destructor_via_gc(self):
        # See bpo-44304: The destructor of the user function can
        # crash if it is called without the GIL from the gc functions
        def md5sum(t):
            return

        with memory_database() as dest:
            dest.create_function("md5", 1, md5sum)
            # NOTE(review): calling the connection object directly presumably
            # yields a statement that keeps the connection alive; the
            # self-referencing list below leaves it reachable only through a
            # reference cycle, so gc_collect() performs the final cleanup.
            x = dest("create table lang (name, first_appeared)")
            del md5sum, dest

            y = [x]
            y.append(y)

            del x,y
            gc_collect()

    @with_tracebacks(OverflowError)
    def test_func_return_too_large_int(self):
        cur = self.con.cursor()
        msg = "string or blob too big"
        for value in 2**63, -2**63-1, 2**64:
            self.con.create_function("largeint", 0, lambda value=value: value)
            with self.assertRaisesRegex(sqlite.DataError, msg):
                cur.execute("select largeint()")

    @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
    def test_func_return_text_with_surrogates(self):
        cur = self.con.cursor()
        self.con.create_function("pychr", 1, chr)
        for value in 0xd8ff, 0xdcff:
            with self.assertRaises(sqlite.OperationalError):
                cur.execute("select pychr(?)", (value,))

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=3, dry_run=False)
    def test_func_return_too_large_text(self, size):
        cur = self.con.cursor()
        for size in 2**31-1, 2**31:
            self.con.create_function("largetext", 0, lambda size=size: "b" * size)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largetext()")

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=2, dry_run=False)
    def test_func_return_too_large_blob(self, size):
        cur = self.con.cursor()
        for size in 2**31-1, 2**31:
            self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largeblob()")

    def test_func_return_illegal_value(self):
        self.con.create_function("badreturn", 0, lambda: self)
        msg = "user-defined function raised exception"
        self.assertRaisesRegex(sqlite.OperationalError, msg,
                               self.con.execute, "select badreturn()")

    def test_func_keyword_args(self):
        regex = (
            r"Passing keyword arguments 'name', 'narg' and 'func' to "
            r"_sqlite3.Connection.create_function\(\) is deprecated. "
            r"Parameters 'name', 'narg' and 'func' will become "
            r"positional-only in Python 3.15."
        )

        def noop():
            return None

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_function("noop", 0, func=noop)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_function("noop", narg=0, func=noop)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_function(name="noop", narg=0, func=noop)
        self.assertEqual(cm.filename, __file__)


class WindowSumInt:
    """Window aggregate keeping a running sum; ``inverse`` removes a value."""

    def __init__(self):
        self.count = 0

    def step(self, value):
        self.count += value

    def value(self):
        return self.count

    def inverse(self, value):
        self.count -= value

    def finalize(self):
        return self.count

class BadWindow(Exception):
    """Marker exception injected into WindowSumInt methods by the tests."""
    pass


@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
                 "Requires SQLite 3.25.0 or newer")
class WindowFunctionTests(unittest.TestCase):
    """Aggregate window functions registered with create_window_function()."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()

        # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
        values = [
            ("a", 4),
            ("b", 5),
            ("c", 3),
            ("d", 8),
            ("e", 1),
        ]
        with self.con:
            self.con.execute("create table test(x, y)")
            self.con.executemany("insert into test values(?, ?)", values)
        # Sum of each row's y with its neighbours (1 preceding, 1 following).
        self.expected = [
            ("a", 9),
            ("b", 12),
            ("c", 16),
            ("d", 12),
            ("e", 9),
        ]
        # %s is substituted with the name of the window function under test.
        self.query = """
            select x, %s(y) over (
                order by x rows between 1 preceding and 1 following
            ) as sum_y
            from test order by x
        """
        self.con.create_window_function("sumint", 1, WindowSumInt)

    def tearDown(self):
        self.cur.close()
        self.con.close()

    def test_win_sum_int(self):
        self.cur.execute(self.query % "sumint")
        self.assertEqual(self.cur.fetchall(), self.expected)

    def test_win_error_on_create(self):
        self.assertRaises(sqlite.ProgrammingError,
                          self.con.create_window_function,
                          "shouldfail", -100, WindowSumInt)

    @with_tracebacks(BadWindow)
    def test_win_exception_in_method(self):
        for meth in "__init__", "step", "value", "inverse":
            with self.subTest(meth=meth):
                with patch.object(WindowSumInt, meth, side_effect=BadWindow):
                    name = f"exc_{meth}"
                    self.con.create_window_function(name, 1, WindowSumInt)
                    msg = f"'{meth}' method raised error"
                    with self.assertRaisesRegex(sqlite.OperationalError, msg):
                        self.cur.execute(self.query % name)
                        self.cur.fetchall()

    @with_tracebacks(BadWindow)
    def test_win_exception_in_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
            name = "exception_in_finalize"
            self.con.create_window_function(name, 1, WindowSumInt)
            self.cur.execute(self.query % name)
            self.cur.fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_method(self):
        class MissingValue:
            def step(self, x): pass
            def inverse(self, x): pass
            def finalize(self): return 42

        class MissingInverse:
            def step(self, x): pass
            def value(self): return 42
            def finalize(self): return 42

        class MissingStep:
            def value(self): return 42
            def inverse(self, x): pass
            def finalize(self): return 42

        dataset = (
            ("step", MissingStep),
            ("value", MissingValue),
            ("inverse", MissingInverse),
        )
        for meth, cls in dataset:
            with self.subTest(meth=meth, cls=cls):
                name = f"exc_{meth}"
                self.con.create_window_function(name, 1, cls)
                with self.assertRaisesRegex(sqlite.OperationalError,
                                            f"'{meth}' method not defined"):
                    self.cur.execute(self.query % name)
                    self.cur.fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        class MissingFinalize:
            def step(self, x): pass
            def value(self): return 42
            def inverse(self, x): pass

        name = "missing_finalize"
        self.con.create_window_function(name, 1, MissingFinalize)
        self.cur.execute(self.query % name)
        self.cur.fetchall()

    def test_win_clear_function(self):
        self.con.create_window_function("sumint", 1, None)
        self.assertRaises(sqlite.OperationalError, self.cur.execute,
                          self.query % "sumint")

    def test_win_redefine_function(self):
        # Redefine WindowSumInt; adjust the expected results accordingly.
        class Redefined(WindowSumInt):
            def step(self, value): self.count += value * 2
            def inverse(self, value): self.count -= value * 2
        expected = [(v[0], v[1]*2) for v in self.expected]

        self.con.create_window_function("sumint", 1, Redefined)
        self.cur.execute(self.query % "sumint")
        self.assertEqual(self.cur.fetchall(), expected)

    def test_win_error_value_return(self):
        class ErrorValueReturn:
            def __init__(self): pass
            def step(self, x): pass
            def value(self): return 1 << 65

        self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
        self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
                               self.cur.execute, self.query % "err_val_ret")


class AggregateTests(unittest.TestCase):
    """Aggregate functions registered with create_aggregate()."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))
        cur.close()

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)
        self.con.create_aggregate("aggtxt", 1, AggrText)

    def tearDown(self):
        self.con.close()

    def test_aggr_error_on_create(self):
        # Exercise create_aggregate() (not create_function()) with an
        # argument count that SQLite rejects.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    @with_tracebacks(AttributeError, name="AggrNoStep")
    def test_aggr_no_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception),
                         "user-defined aggregate's 'step' method not defined")

    def test_aggr_no_finalize(self):
        cur = self.con.cursor()
        msg = "user-defined aggregate's 'finalize' method not defined"
        with self.assertRaisesRegex(sqlite.OperationalError, msg):
            cur.execute("select nofinalize(t) from test")
            val = cur.fetchone()[0]

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
    def test_aggr_exception_in_init(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            val = cur.fetchone()[0]
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
    def test_aggr_exception_in_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            val = cur.fetchone()[0]
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
    def test_aggr_exception_in_finalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            val = cur.fetchone()[0]
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def test_aggr_check_param_str(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('str', ?, ?)", ("foo", str()))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_int(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_params_int(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_float(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_none(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_blob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_aggr_sum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)

    def test_aggr_no_match(self):
        cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def test_aggr_text(self):
        cur = self.con.cursor()
        for txt in ["foo", "1\x002"]:
            with self.subTest(txt=txt):
                cur.execute("select aggtxt(?) from test", (txt,))
                val = cur.fetchone()[0]
                self.assertEqual(val, txt)

    def test_agg_keyword_args(self):
        regex = (
            r"Passing keyword arguments 'name', 'n_arg' and 'aggregate_class' to "
            r"_sqlite3.Connection.create_aggregate\(\) is deprecated. "
            r"Parameters 'name', 'n_arg' and 'aggregate_class' will become "
            r"positional-only in Python 3.15."
        )

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_aggregate("test", 1, aggregate_class=AggrText)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_aggregate("test", n_arg=1, aggregate_class=AggrText)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_aggregate(name="test", n_arg=0,
                                      aggregate_class=AggrText)
        self.assertEqual(cm.filename, __file__)


class AuthorizerTests(unittest.TestCase):
    """Authorizer callbacks installed with set_authorizer().

    Subclasses override ``authorizer_cb`` with alternative ways of denying
    access; the inherited tests require each of them to be treated as a
    denial (the error message must contain 'prohibited').
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny every non-SELECT action, and any access to column c2 or to
        # table t2.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: run a query that will later be denied
        # *before* the authorizer is installed.
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        self.con.close()

    def test_table_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select * from t2")
        self.assertIn('prohibited', str(cm.exception))

    def test_column_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select c2 from t1")
        self.assertIn('prohibited', str(cm.exception))

    def test_clear_authorizer(self):
        self.con.set_authorizer(None)
        self.con.execute("select * from t2")
        self.con.execute("select c2 from t1")

    def test_authorizer_keyword_args(self):
        regex = (
            r"Passing keyword argument 'authorizer_callback' to "
            r"_sqlite3.Connection.set_authorizer\(\) is deprecated. "
            r"Parameter 'authorizer_callback' will become positional-only in "
            r"Python 3.15."
        )

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.set_authorizer(authorizer_callback=lambda: None)
        self.assertEqual(cm.filename, __file__)


class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Authorizer that raises ValueError instead of returning SQLITE_DENY."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_table_access(self):
        super().test_table_access()

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_column_access(self):
        # Previously delegated to test_table_access(), so column-level
        # denial was never exercised for the raising authorizer.
        super().test_column_access()

class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Authorizer that returns a float (0.0) where an int is expected."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 0.0
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK

class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Authorizer that returns the large integer 2**32 to signal denial."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 2**32
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK


if __name__ == "__main__":
    unittest.main()