• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
#                                  aggregates.
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
24
25import unittest
26import sqlite3 as sqlite
27
def func_returntext():
    """Return the text value ``"foo"``."""
    text_value = "foo"
    return text_value
def func_returnunicode():
    """Return the text value ``"bar"`` (an ordinary ``str`` on Python 3)."""
    text_value = "bar"
    return text_value
def func_returnint():
    """Return the integer 42."""
    answer = 42
    return answer
def func_returnfloat():
    """Return the float 3.14."""
    approx_pi = 3.14
    return approx_pi
def func_returnnull():
    """Return ``None`` so the SQL caller receives NULL."""
    nothing = None
    return nothing
def func_returnblob():
    """Return the bytes value ``b"blob"``."""
    payload = b"blob"
    return payload
def func_returnlonglong():
    """Return 2**31, one past the largest signed 32-bit integer."""
    return 2 ** 31
def func_raiseexception():
    """Always raise ``ZeroDivisionError`` (dividing 5 by zero)."""
    numerator, denominator = 5, 0
    numerator / denominator
44
def func_isstring(v):
    """Return True only if *v* is exactly a ``str`` (subclasses excluded)."""
    actual_type = type(v)
    return actual_type is str
def func_isint(v):
    """Return True only if *v* is exactly an ``int`` (``bool`` excluded)."""
    actual_type = type(v)
    return actual_type is int
def func_isfloat(v):
    """Return True only if *v* is exactly a ``float``."""
    actual_type = type(v)
    return actual_type is float
def func_isnone(v):
    """Return True if *v* is ``None`` (NoneType has a single instance)."""
    return v is None
def func_isblob(v):
    """Return True if *v* is a bytes-like blob: ``bytes`` or ``memoryview``."""
    return isinstance(v, bytes) or isinstance(v, memoryview)
def func_islonglong(v):
    """Return True if *v* is an int at or above 2**31 (beyond 32-bit range)."""
    if not isinstance(v, int):
        return False
    return v >= 2 ** 31
57
def func(*args):
    """Return how many positional arguments were passed."""
    arg_count = len(args)
    return arg_count
60
class AggrNoStep:
    """Aggregate fixture that deliberately has no ``step`` method.

    Only ``finalize`` exists, so using this class as an aggregate exercises
    the missing-``step`` error path.
    """

    def __init__(self):
        """No state to initialise."""

    def finalize(self):
        """Return the constant 1."""
        result = 1
        return result
67
class AggrNoFinalize:
    """Aggregate fixture that deliberately has no ``finalize`` method."""

    def __init__(self):
        """No state to initialise."""

    def step(self, x):
        """Accept one value and ignore it."""
74
class AggrExceptionInInit:
    """Aggregate fixture whose constructor always raises ZeroDivisionError."""

    def __init__(self):
        numerator, denominator = 5, 0
        numerator / denominator  # always raises ZeroDivisionError

    def step(self, x):
        """Ignore the value (never reached once construction has failed)."""

    def finalize(self):
        """Return None (never reached once construction has failed)."""
84
class AggrExceptionInStep:
    """Aggregate fixture whose ``step`` always raises ZeroDivisionError."""

    def __init__(self):
        """No state to initialise."""

    def step(self, x):
        numerator, denominator = 5, 0
        numerator / denominator  # always raises ZeroDivisionError

    def finalize(self):
        """Return 42; only reached if no step call happened."""
        result = 42
        return result
94
class AggrExceptionInFinalize:
    """Aggregate fixture whose ``finalize`` always raises ZeroDivisionError."""

    def __init__(self):
        """No state to initialise."""

    def step(self, x):
        """Accept one value and ignore it."""

    def finalize(self):
        numerator, denominator = 5, 0
        numerator / denominator  # always raises ZeroDivisionError
104
class AggrCheckType:
    """Aggregate reporting whether the most recently stepped value has a type.

    ``step(whichType, val)`` records 1 if ``type(val)`` is exactly the type
    named by *whichType* ("str", "int", "float", "None" or "blob"), else 0.
    Each call overwrites the previous result. ``finalize`` returns that
    result, or None if ``step`` was never called. An unknown *whichType*
    raises KeyError.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }[whichType]
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        result = self.val
        return result
116
class AggrCheckTypes:
    """Aggregate counting how many stepped values have exactly a named type.

    ``step(whichType, *vals)`` adds the number of *vals* whose type is exactly
    the one named by *whichType* ("str", "int", "float", "None" or "blob").
    The name is only looked up when at least one value is given.
    """

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        type_map = {
            "str": str,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": bytes,
        }
        # The lookup stays inside the generator so it is evaluated per value,
        # never when vals is empty.
        matches = sum(1 for item in vals if type(item) is type_map[whichType])
        self.val += matches

    def finalize(self):
        total = self.val
        return total
129
class AggrSum:
    """Aggregate summing every stepped value into a running total."""

    def __init__(self):
        # Start from a float, so the result is a float even for integer input.
        self.val = 0.0

    def step(self, val):
        """Add *val* to the running total."""
        self.val = self.val + val

    def finalize(self):
        total = self.val
        return total
139
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (``Connection.create_function``).

    Test methods use the legacy ``Check`` prefix and are collected with that
    prefix by ``suite()``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, each returning a value of one Python type.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates that inspect the Python type of their argument.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)
        # narg=-1 registers a function that accepts any number of arguments.
        self.con.create_function("spam", -1, func)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # -100 is not an acceptable argument count, so registration must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        # Register a freshly created closure that nothing else references.
        # The connection must hold its own reference for the later call to
        # work. Keeping an extra reference (e.g. in globals()) would defeat
        # the check.
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertIs(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertIs(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertIs(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertIs(type(val), float)
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertIs(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        # The error may surface at execute() or at fetchone(); both are
        # inside the context manager.
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAnyArguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)
277
278
279class AggregateTests(unittest.TestCase):
280    def setUp(self):
281        self.con = sqlite.connect(":memory:")
282        cur = self.con.cursor()
283        cur.execute("""
284            create table test(
285                t text,
286                i integer,
287                f float,
288                n,
289                b blob
290                )
291            """)
292        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
293            ("foo", 5, 3.14, None, memoryview(b"blob"),))
294
295        self.con.create_aggregate("nostep", 1, AggrNoStep)
296        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
297        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
298        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
299        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
300        self.con.create_aggregate("checkType", 2, AggrCheckType)
301        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
302        self.con.create_aggregate("mysum", 1, AggrSum)
303
304    def tearDown(self):
305        #self.cur.close()
306        #self.con.close()
307        pass
308
309    def CheckAggrErrorOnCreate(self):
310        with self.assertRaises(sqlite.OperationalError):
311            self.con.create_function("bla", -100, AggrSum)
312
313    def CheckAggrNoStep(self):
314        cur = self.con.cursor()
315        with self.assertRaises(AttributeError) as cm:
316            cur.execute("select nostep(t) from test")
317        self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
318
319    def CheckAggrNoFinalize(self):
320        cur = self.con.cursor()
321        with self.assertRaises(sqlite.OperationalError) as cm:
322            cur.execute("select nofinalize(t) from test")
323            val = cur.fetchone()[0]
324        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
325
326    def CheckAggrExceptionInInit(self):
327        cur = self.con.cursor()
328        with self.assertRaises(sqlite.OperationalError) as cm:
329            cur.execute("select excInit(t) from test")
330            val = cur.fetchone()[0]
331        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
332
333    def CheckAggrExceptionInStep(self):
334        cur = self.con.cursor()
335        with self.assertRaises(sqlite.OperationalError) as cm:
336            cur.execute("select excStep(t) from test")
337            val = cur.fetchone()[0]
338        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
339
340    def CheckAggrExceptionInFinalize(self):
341        cur = self.con.cursor()
342        with self.assertRaises(sqlite.OperationalError) as cm:
343            cur.execute("select excFinalize(t) from test")
344            val = cur.fetchone()[0]
345        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
346
347    def CheckAggrCheckParamStr(self):
348        cur = self.con.cursor()
349        cur.execute("select checkType('str', ?)", ("foo",))
350        val = cur.fetchone()[0]
351        self.assertEqual(val, 1)
352
353    def CheckAggrCheckParamInt(self):
354        cur = self.con.cursor()
355        cur.execute("select checkType('int', ?)", (42,))
356        val = cur.fetchone()[0]
357        self.assertEqual(val, 1)
358
359    def CheckAggrCheckParamsInt(self):
360        cur = self.con.cursor()
361        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
362        val = cur.fetchone()[0]
363        self.assertEqual(val, 2)
364
365    def CheckAggrCheckParamFloat(self):
366        cur = self.con.cursor()
367        cur.execute("select checkType('float', ?)", (3.14,))
368        val = cur.fetchone()[0]
369        self.assertEqual(val, 1)
370
371    def CheckAggrCheckParamNone(self):
372        cur = self.con.cursor()
373        cur.execute("select checkType('None', ?)", (None,))
374        val = cur.fetchone()[0]
375        self.assertEqual(val, 1)
376
377    def CheckAggrCheckParamBlob(self):
378        cur = self.con.cursor()
379        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
380        val = cur.fetchone()[0]
381        self.assertEqual(val, 1)
382
383    def CheckAggrCheckAggrSum(self):
384        cur = self.con.cursor()
385        cur.execute("delete from test")
386        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
387        cur.execute("select mysum(i) from test")
388        val = cur.fetchone()[0]
389        self.assertEqual(val, 60)
390
class AuthorizerTests(unittest.TestCase):
    """Tests for ``Connection.set_authorizer``.

    The authorizer denies column c2 of any table and every column of table
    t2. Each test checks that a denied access fails with a "prohibited"
    DatabaseError. Subclasses in this module override ``authorizer_cb`` and
    reuse these tests.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Admit the SELECT itself and column reads. SQLITE_SELECT callbacks
        # carry no table/column names, so if SQLITE_READ were denied here the
        # table/column rule below could never be reached.
        if action not in (sqlite.SQLITE_SELECT, sqlite.SQLITE_READ):
            return sqlite.SQLITE_DENY
        # For SQLITE_READ, arg1 is the table name and arg2 the column name.
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: run a t2 query before the authorizer is
        # installed. NOTE(review): presumably this primes the statement cache
        # so that cached statements cannot bypass the authorizer -- confirm.
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Release the per-test connection (also inherited by the subclasses).
        self.con.close()

    def test_table_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select * from t2")
        self.assertIn('prohibited', str(cm.exception))

    def test_column_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select c2 from t1")
        self.assertIn('prohibited', str(cm.exception))
426
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Authorizer tests with a callback that denies access by raising."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Admit the SELECT itself and column reads. SQLITE_SELECT callbacks
        # carry no table/column names, so if SQLITE_READ were rejected here
        # the table/column rule below could never be reached.
        if action not in (sqlite.SQLITE_SELECT, sqlite.SQLITE_READ):
            raise ValueError
        # Deny column c2 of any table and every column of table t2 by raising.
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK
435
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Authorizer tests with a callback that denies by returning a non-int."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Admit the SELECT itself and column reads. SQLITE_SELECT callbacks
        # carry no table/column names, so if SQLITE_READ were rejected here
        # the table/column rule below could never be reached.
        if action not in (sqlite.SQLITE_SELECT, sqlite.SQLITE_READ):
            return 0.0
        # Deny column c2 of any table and every column of table t2 by
        # returning a float, which is not a valid authorizer result.
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK
444
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Authorizer tests with a callback that denies via an oversized integer."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Admit the SELECT itself and column reads. SQLITE_SELECT callbacks
        # carry no table/column names, so if SQLITE_READ were rejected here
        # the table/column rule below could never be reached.
        if action not in (sqlite.SQLITE_SELECT, sqlite.SQLITE_READ):
            return 2**32
        # Deny column c2 of any table and every column of table t2 by
        # returning 2**32, which does not fit in a 32-bit C int.
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK
453
454
def suite():
    """Build the combined test suite for this module.

    FunctionTests and AggregateTests name their test methods with the legacy
    ``Check`` prefix. The authorizer test cases use the default ``test``
    prefix. Explicit ``TestLoader`` instances are used because
    ``unittest.makeSuite`` was deprecated in Python 3.11 and removed in 3.13.

    Returns:
        unittest.TestSuite containing every test case in this module.
    """
    check_loader = unittest.TestLoader()
    check_loader.testMethodPrefix = "Check"
    default_loader = unittest.TestLoader()
    return unittest.TestSuite((
        check_loader.loadTestsFromTestCase(FunctionTests),
        check_loader.loadTestsFromTestCase(AggregateTests),
        default_loader.loadTestsFromTestCase(AuthorizerTests),
        default_loader.loadTestsFromTestCase(AuthorizerRaiseExceptionTests),
        default_loader.loadTestsFromTestCase(AuthorizerIllegalTypeTests),
        default_loader.loadTestsFromTestCase(AuthorizerLargeIntegerTests),
    ))
467
def test():
    """Run this module's suite with a default ``unittest.TextTestRunner``."""
    unittest.TextTestRunner().run(suite())
471
if __name__ == "__main__":
    # Run the tests only when executed as a script, not when imported.
    test()
474