• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# pysqlite2/test/userfunctions.py: tests for user-defined functions and
2#                                  aggregates.
3#
4# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
5#
6# This file is part of pysqlite.
7#
8# This software is provided 'as-is', without any express or implied
9# warranty.  In no event will the authors be held liable for any damages
10# arising from the use of this software.
11#
12# Permission is granted to anyone to use this software for any purpose,
13# including commercial applications, and to alter it and redistribute it
14# freely, subject to the following restrictions:
15#
16# 1. The origin of this software must not be misrepresented; you must not
17#    claim that you wrote the original software. If you use this software
18#    in a product, an acknowledgment in the product documentation would be
19#    appreciated but is not required.
20# 2. Altered source versions must be plainly marked as such, and must not be
21#    misrepresented as being the original software.
22# 3. This notice may not be removed or altered from any source distribution.
23
24import sys
25import unittest
26import sqlite3 as sqlite
27
28from unittest.mock import Mock, patch
29from test.support import bigmemtest, gc_collect
30
31from .util import cx_limit, memory_database
32from .util import with_tracebacks
33
34
def func_returntext():
    """Return the fixed text value ``"foo"``."""
    result = "foo"
    return result
def func_returntextwithnull():
    """Return a string with an embedded NUL character between "1" and "2"."""
    return "1" + "\x00" + "2"
def func_returnunicode():
    """Return the fixed ``str`` value ``"bar"``."""
    result = "bar"
    return result
def func_returnint():
    """Return the integer constant 42."""
    answer = 42
    return answer
def func_returnfloat():
    """Return the float constant 3.14."""
    value = 3.14
    return value
def func_returnnull():
    """Return ``None`` (surfaced to SQL as NULL)."""
    return
def func_returnblob():
    """Return the bytes value ``b"blob"``."""
    return "blob".encode("ascii")
def func_returnlonglong():
    """Return 2**31, one past the largest signed 32-bit integer."""
    return 2 ** 31
def func_raiseexception():
    """Always raise ZeroDivisionError by dividing by zero."""
    numerator, denominator = 5, 0
    return numerator / denominator
def func_memoryerror():
    """Always raise MemoryError."""
    raise MemoryError()
def func_overflowerror():
    """Always raise OverflowError."""
    raise OverflowError()
57
class AggrNoStep:
    """Aggregate class that deliberately lacks a ``step`` method."""

    def __init__(self):
        """No state to initialise."""

    def finalize(self):
        result = 1
        return result
64
class AggrNoFinalize:
    """Aggregate class that deliberately lacks a ``finalize`` method."""

    def __init__(self):
        """No state to initialise."""

    def step(self, x):
        """Ignore the incoming value *x*."""
71
class AggrExceptionInInit:
    """Aggregate class whose constructor raises ZeroDivisionError."""

    def __init__(self):
        dividend, divisor = 5, 0
        dividend / divisor

    def step(self, x):
        """Ignore the incoming value *x*."""

    def finalize(self):
        return None
81
class AggrExceptionInStep:
    """Aggregate class whose ``step`` raises ZeroDivisionError."""

    def __init__(self):
        """No state to initialise."""

    def step(self, x):
        dividend, divisor = 5, 0
        dividend / divisor

    def finalize(self):
        answer = 42
        return answer
91
class AggrExceptionInFinalize:
    """Aggregate class whose ``finalize`` raises ZeroDivisionError."""

    def __init__(self):
        """No state to initialise."""

    def step(self, x):
        """Ignore the incoming value *x*."""

    def finalize(self):
        dividend, divisor = 5, 0
        dividend / divisor
101
class AggrCheckType:
    """Aggregate reporting (1/0) whether the last value has the named type.

    ``whichType`` is one of "str", "int", "float", "None" or "blob"; the
    result reflects only the most recent ``step`` call, or ``None`` if
    ``step`` was never called.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        # Exact type match (not isinstance), so e.g. bool does not count as int.
        expected = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }[whichType]
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        return self.val
113
class AggrCheckTypes:
    """Aggregate counting how many values (across all steps) have the named type.

    ``whichType`` is one of "str", "int", "float", "None" or "blob"; each
    ``step`` may pass any number of values to check.
    """

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        type_by_name = {"str": str, "int": int, "float": float,
                        "None": type(None), "blob": bytes}
        # The lookup happens per value, so an unknown name is only an error
        # when at least one value is supplied.
        self.val += sum(type(item) is type_by_name[whichType] for item in vals)

    def finalize(self):
        return self.val
126
class AggrSum:
    """Aggregate summing its inputs, starting from the float ``0.0``."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
136
class AggrText:
    """Aggregate concatenating its text inputs in step order."""

    def __init__(self):
        self.txt = ""

    def step(self, txt):
        self.txt += txt

    def finalize(self):
        return self.txt
144
145
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (Connection.create_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("returnnan", 0, lambda: float("nan"))
        self.con.create_function("return_noncont_blob", 0,
                                 lambda: memoryview(b"blob")[::2])
        self.con.create_function("raiseexception", 0, func_raiseexception)
        self.con.create_function("memoryerror", 0, func_memoryerror)
        self.con.create_function("overflowerror", 0, func_overflowerror)

        self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes))
        self.con.create_function("isnone", 1, lambda x: x is None)
        # narg=-1 registers a variadic function; "spam" returns its arg count.
        self.con.create_function("spam", -1, lambda *x: len(x))
        self.con.execute("create table test(t text)")

    def tearDown(self):
        self.con.close()

    def test_func_error_on_create(self):
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def test_func_too_many_args(self):
        category = sqlite.SQLITE_LIMIT_FUNCTION_ARG
        msg = "too many arguments on function"
        with cx_limit(self.con, category=category, limit=1):
            # One argument is within the limit; two exceed it.
            self.con.execute("select abs(-1)")
            with self.assertRaisesRegex(sqlite.OperationalError, msg):
                self.con.execute("select max(1, 2)")

    def test_func_ref_count(self):
        # The connection must keep its own strong reference to the callback:
        # register a closure nothing else refers to, force a collection, and
        # check the function is still callable from SQL.
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        gc_collect()
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def test_func_return_text(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def test_func_return_text_with_null_char(self):
        cur = self.con.cursor()
        res = cur.execute("select returntextwithnull()").fetchone()[0]
        self.assertEqual(type(res), str)
        self.assertEqual(res, "1\x002")

    def test_func_return_unicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def test_func_return_int(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def test_func_return_float(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        # Same tolerance as before: accept values in [3.139, 3.141].
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def test_func_return_null(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def test_func_return_blob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def test_func_return_long_long(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def test_func_return_nan(self):
        cur = self.con.cursor()
        cur.execute("select returnnan()")
        self.assertIsNone(cur.fetchone()[0])

    @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
    def test_func_exception(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    @with_tracebacks(MemoryError, name="func_memoryerror")
    def test_func_memory_error(self):
        cur = self.con.cursor()
        with self.assertRaises(MemoryError):
            cur.execute("select memoryerror()")
            cur.fetchone()

    @with_tracebacks(OverflowError, name="func_overflowerror")
    def test_func_overflow_error(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.DataError):
            cur.execute("select overflowerror()")
            cur.fetchone()

    def test_any_arguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_empty_blob(self):
        cur = self.con.execute("select isblob(x'')")
        self.assertTrue(cur.fetchone()[0])

    def test_nan_float(self):
        cur = self.con.execute("select isnone(?)", (float("nan"),))
        # SQLite has no concept of nan; it is converted to NULL
        self.assertTrue(cur.fetchone()[0])

    def test_too_large_int(self):
        err = "Python int too large to convert to SQLite INTEGER"
        self.assertRaisesRegex(OverflowError, err, self.con.execute,
                               "select spam(?)", (1 << 65,))

    def test_non_contiguous_blob(self):
        self.assertRaisesRegex(BufferError,
                               "underlying buffer is not C-contiguous",
                               self.con.execute, "select spam(?)",
                               (memoryview(b"blob")[::2],))

    @with_tracebacks(BufferError, regex="buffer.*contiguous")
    def test_return_non_contiguous_blob(self):
        with self.assertRaises(sqlite.OperationalError):
            cur = self.con.execute("select return_noncont_blob()")
            cur.fetchone()

    def test_param_surrogates(self):
        self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed",
                               self.con.execute, "select spam(?)",
                               ("\ud803\ude6d",))

    def test_func_params(self):
        results = []
        def append_result(arg):
            results.append((arg, type(arg)))
        self.con.create_function("test_params", 1, append_result)

        dataset = [
            (42, int),
            (-1, int),
            (1234567890123456789, int),
            (4611686018427387905, int),  # 63-bit int with non-zero low bits
            (3.14, float),
            (float('inf'), float),
            ("text", str),
            ("1\x002", str),
            ("\u02e2q\u02e1\u2071\u1d57\u1d49", str),
            (b"blob", bytes),
            (bytearray(range(2)), bytes),
            (memoryview(b"blob"), bytes),
            (None, type(None)),
        ]
        for val, _ in dataset:
            cur = self.con.execute("select test_params(?)", (val,))
            cur.fetchone()
        self.assertEqual(dataset, results)

    # Regarding deterministic functions:
    #
    # Between 3.8.3 and 3.15.0, deterministic functions were only used to
    # optimize inner loops. From 3.15.0 and onward, deterministic functions
    # were permitted in WHERE clauses of partial indices, which allows testing
    # based on syntax, iso. the query optimizer.
    def test_func_non_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("nondeterministic", 0, mock, deterministic=False)
        with self.assertRaises(sqlite.OperationalError):
            self.con.execute("create index t on test(t) where nondeterministic() is not null")

    def test_func_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=True)
        try:
            self.con.execute("create index t on test(t) where deterministic() is not null")
        except sqlite.OperationalError:
            self.fail("Unexpected failure while creating partial index")

    def test_func_deterministic_keyword_only(self):
        with self.assertRaises(TypeError):
            self.con.create_function("deterministic", 0, int, True)

    def test_function_destructor_via_gc(self):
        # See bpo-44304: The destructor of the user function can
        # crash if is called without the GIL from the gc functions
        def md5sum(t):
            return

        with memory_database() as dest:
            dest.create_function("md5", 1, md5sum)
            # NOTE(review): this calls the Connection object directly; it is
            # presumably relied on to create an object referencing the
            # connection, which the self-referencing list below then keeps in
            # a cycle for the gc to break — confirm.
            x = dest("create table lang (name, first_appeared)")
            del md5sum, dest

            y = [x]
            y.append(y)

            del x,y
            gc_collect()

    @with_tracebacks(OverflowError)
    def test_func_return_too_large_int(self):
        cur = self.con.cursor()
        msg = "string or blob too big"
        for value in 2**63, -2**63-1, 2**64:
            self.con.create_function("largeint", 0, lambda value=value: value)
            with self.assertRaisesRegex(sqlite.DataError, msg):
                cur.execute("select largeint()")

    @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
    def test_func_return_text_with_surrogates(self):
        cur = self.con.cursor()
        self.con.create_function("pychr", 1, chr)
        for value in 0xd8ff, 0xdcff:
            with self.assertRaises(sqlite.OperationalError):
                cur.execute("select pychr(?)", (value,))

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=3, dry_run=False)
    def test_func_return_too_large_text(self, size):
        cur = self.con.cursor()
        for size in 2**31-1, 2**31:
            self.con.create_function("largetext", 0, lambda size=size: "b" * size)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largetext()")

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=2, dry_run=False)
    def test_func_return_too_large_blob(self, size):
        cur = self.con.cursor()
        for size in 2**31-1, 2**31:
            self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largeblob()")

    def test_func_return_illegal_value(self):
        self.con.create_function("badreturn", 0, lambda: self)
        msg = "user-defined function raised exception"
        self.assertRaisesRegex(sqlite.OperationalError, msg,
                               self.con.execute, "select badreturn()")

    def test_func_keyword_args(self):
        regex = (
            r"Passing keyword arguments 'name', 'narg' and 'func' to "
            r"_sqlite3.Connection.create_function\(\) is deprecated. "
            r"Parameters 'name', 'narg' and 'func' will become "
            r"positional-only in Python 3.15."
        )

        def noop():
            return None

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_function("noop", 0, func=noop)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_function("noop", narg=0, func=noop)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_function(name="noop", narg=0, func=noop)
        self.assertEqual(cm.filename, __file__)
446
447
class WindowSumInt:
    """Window aggregate keeping a running integer total in ``count``.

    ``step`` adds a value to the total and ``inverse`` subtracts one;
    ``value`` and ``finalize`` both report the current total.
    """

    def __init__(self):
        self.count = 0

    def step(self, value):
        self.count = self.count + value

    def value(self):
        return self.count

    def inverse(self, value):
        self.count = self.count - value

    def finalize(self):
        return self.count
463
class BadWindow(Exception):
    """Marker exception injected into window-function callbacks by the tests."""
466
467
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
                 "Requires SQLite 3.25.0 or newer")
class WindowFunctionTests(unittest.TestCase):
    """Tests for user-defined aggregate window functions."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()

        # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
        keys = ("a", "b", "c", "d", "e")
        rows = list(zip(keys, (4, 5, 3, 8, 1)))
        with self.con:
            self.con.execute("create table test(x, y)")
            self.con.executemany("insert into test values(?, ?)", rows)
        # Sum of y over each row and its immediate neighbours (ordered by x).
        self.expected = list(zip(keys, (9, 12, 16, 12, 9)))
        self.query = """
            select x, %s(y) over (
                order by x rows between 1 preceding and 1 following
            ) as sum_y
            from test order by x
        """
        self.con.create_window_function("sumint", 1, WindowSumInt)

    def tearDown(self):
        self.cur.close()
        self.con.close()

    def test_win_sum_int(self):
        rows = self.cur.execute(self.query % "sumint").fetchall()
        self.assertEqual(rows, self.expected)

    def test_win_error_on_create(self):
        with self.assertRaises(sqlite.ProgrammingError):
            self.con.create_window_function("shouldfail", -100, WindowSumInt)

    @with_tracebacks(BadWindow)
    def test_win_exception_in_method(self):
        for meth in ("__init__", "step", "value", "inverse"):
            fn_name = f"exc_{meth}"
            expected_msg = f"'{meth}' method raised error"
            with self.subTest(meth=meth):
                with patch.object(WindowSumInt, meth, side_effect=BadWindow):
                    self.con.create_window_function(fn_name, 1, WindowSumInt)
                    with self.assertRaisesRegex(sqlite.OperationalError,
                                                expected_msg):
                        self.cur.execute(self.query % fn_name).fetchall()

    @with_tracebacks(BadWindow)
    def test_win_exception_in_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        fn_name = "exception_in_finalize"
        with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
            self.con.create_window_function(fn_name, 1, WindowSumInt)
            self.cur.execute(self.query % fn_name).fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_method(self):
        class MissingValue:
            def step(self, x):
                pass

            def inverse(self, x):
                pass

            def finalize(self):
                return 42

        class MissingInverse:
            def step(self, x):
                pass

            def value(self):
                return 42

            def finalize(self):
                return 42

        class MissingStep:
            def value(self):
                return 42

            def inverse(self, x):
                pass

            def finalize(self):
                return 42

        cases = [
            ("step", MissingStep),
            ("value", MissingValue),
            ("inverse", MissingInverse),
        ]
        for meth, cls in cases:
            with self.subTest(meth=meth, cls=cls):
                fn_name = f"exc_{meth}"
                expected_msg = f"'{meth}' method not defined"
                self.con.create_window_function(fn_name, 1, cls)
                with self.assertRaisesRegex(sqlite.OperationalError,
                                            expected_msg):
                    self.cur.execute(self.query % fn_name).fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        class MissingFinalize:
            def step(self, x):
                pass

            def value(self):
                return 42

            def inverse(self, x):
                pass

        fn_name = "missing_finalize"
        self.con.create_window_function(fn_name, 1, MissingFinalize)
        self.cur.execute(self.query % fn_name).fetchall()

    def test_win_clear_function(self):
        # Registering None removes the previously defined "sumint".
        self.con.create_window_function("sumint", 1, None)
        with self.assertRaises(sqlite.OperationalError):
            self.cur.execute(self.query % "sumint")

    def test_win_redefine_function(self):
        # Redefine WindowSumInt so each value counts twice; the expected
        # sums double accordingly.
        class Redefined(WindowSumInt):
            def step(self, value):
                self.count += 2 * value

            def inverse(self, value):
                self.count -= 2 * value

        doubled = [(key, total * 2) for key, total in self.expected]
        self.con.create_window_function("sumint", 1, Redefined)
        rows = self.cur.execute(self.query % "sumint").fetchall()
        self.assertEqual(rows, doubled)

    def test_win_error_value_return(self):
        class ErrorValueReturn:
            def __init__(self):
                pass

            def step(self, x):
                pass

            def value(self):
                return 1 << 65

        self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
        with self.assertRaisesRegex(sqlite.DataError, "string or blob too big"):
            self.cur.execute(self.query % "err_val_ret")
608
609
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates (Connection.create_aggregate)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))
        cur.close()

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)
        self.con.create_aggregate("aggtxt", 1, AggrText)

    def tearDown(self):
        self.con.close()

    def test_aggr_error_on_create(self):
        # This is about aggregate registration, so it must go through
        # create_aggregate() rather than create_function().
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    @with_tracebacks(AttributeError, name="AggrNoStep")
    def test_aggr_no_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception),
                         "user-defined aggregate's 'step' method not defined")

    def test_aggr_no_finalize(self):
        cur = self.con.cursor()
        msg = "user-defined aggregate's 'finalize' method not defined"
        with self.assertRaisesRegex(sqlite.OperationalError, msg):
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
    def test_aggr_exception_in_init(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
    def test_aggr_exception_in_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
    def test_aggr_exception_in_finalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def test_aggr_check_param_str(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('str', ?, ?)", ("foo", str()))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_int(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_params_int(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_float(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_none(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_blob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_aggr_sum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)

    def test_aggr_no_match(self):
        # An aggregate over zero rows yields NULL.
        cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def test_aggr_text(self):
        cur = self.con.cursor()
        for txt in ["foo", "1\x002"]:
            with self.subTest(txt=txt):
                cur.execute("select aggtxt(?) from test", (txt,))
                val = cur.fetchone()[0]
                self.assertEqual(val, txt)

    def test_agg_keyword_args(self):
        regex = (
            r"Passing keyword arguments 'name', 'n_arg' and 'aggregate_class' to "
            r"_sqlite3.Connection.create_aggregate\(\) is deprecated. "
            r"Parameters 'name', 'n_arg' and 'aggregate_class' will become "
            r"positional-only in Python 3.15."
        )

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_aggregate("test", 1, aggregate_class=AggrText)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_aggregate("test", n_arg=1, aggregate_class=AggrText)
        self.assertEqual(cm.filename, __file__)

        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
            self.con.create_aggregate(name="test", n_arg=0,
                                      aggregate_class=AggrText)
        self.assertEqual(cm.filename, __file__)
760
761
762class AuthorizerTests(unittest.TestCase):
763    @staticmethod
764    def authorizer_cb(action, arg1, arg2, dbname, source):
765        if action != sqlite.SQLITE_SELECT:
766            return sqlite.SQLITE_DENY
767        if arg2 == 'c2' or arg1 == 't2':
768            return sqlite.SQLITE_DENY
769        return sqlite.SQLITE_OK
770
771    def setUp(self):
772        self.con = sqlite.connect(":memory:")
773        self.con.executescript("""
774            create table t1 (c1, c2);
775            create table t2 (c1, c2);
776            insert into t1 (c1, c2) values (1, 2);
777            insert into t2 (c1, c2) values (4, 5);
778            """)
779
780        # For our security test:
781        self.con.execute("select c2 from t2")
782
783        self.con.set_authorizer(self.authorizer_cb)
784
785    def tearDown(self):
786        self.con.close()
787
788    def test_table_access(self):
789        with self.assertRaises(sqlite.DatabaseError) as cm:
790            self.con.execute("select * from t2")
791        self.assertIn('prohibited', str(cm.exception))
792
793    def test_column_access(self):
794        with self.assertRaises(sqlite.DatabaseError) as cm:
795            self.con.execute("select c2 from t1")
796        self.assertIn('prohibited', str(cm.exception))
797
798    def test_clear_authorizer(self):
799        self.con.set_authorizer(None)
800        self.con.execute("select * from t2")
801        self.con.execute("select c2 from t1")
802
803    def test_authorizer_keyword_args(self):
804        regex = (
805            r"Passing keyword argument 'authorizer_callback' to "
806            r"_sqlite3.Connection.set_authorizer\(\) is deprecated. "
807            r"Parameter 'authorizer_callback' will become positional-only in "
808            r"Python 3.15."
809        )
810
811        with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
812            self.con.set_authorizer(authorizer_callback=lambda: None)
813        self.assertEqual(cm.filename, __file__)
814
815
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Authorizer tests where the callback raises ValueError instead of denying."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_table_access(self):
        super().test_table_access()

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_column_access(self):
        # Must delegate to the column test; delegating to test_table_access
        # left column-level denial untested for a raising authorizer.
        super().test_column_access()
832
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Authorizer tests where rejections are signalled by returning a float."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Every case the base callback would deny instead returns 0.0.
        rejected = (action != sqlite.SQLITE_SELECT
                    or arg2 == 'c2'
                    or arg1 == 't2')
        return 0.0 if rejected else sqlite.SQLITE_OK
841
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Authorizer tests where rejections are signalled by returning 2**32."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Every case the base callback would deny instead returns 2**32.
        rejected = (action != sqlite.SQLITE_SELECT
                    or arg2 == 'c2'
                    or arg1 == 't2')
        return 2**32 if rejected else sqlite.SQLITE_OK
850
851
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
854