#-*- coding: iso-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import unittest
import unittest.mock
import sqlite3 as sqlite


# Maps the type names used in the SQL test statements to the exact Python
# type the aggregate callbacks compare their argument against.
_TYPE_BY_NAME = {
    "str": str,
    "int": int,
    "float": float,
    "None": type(None),
    "blob": bytes,
}


# --- Zero-argument callables: each returns one fixed value. ---------------

def func_returntext():
    """Return the text value "foo"."""
    return "foo"


def func_returnunicode():
    """Return the text value "bar"."""
    return "bar"


def func_returnint():
    """Return the integer 42."""
    return 42


def func_returnfloat():
    """Return the float 3.14."""
    return 3.14


def func_returnnull():
    """Return None."""
    return None


def func_returnblob():
    """Return the bytes value b"blob"."""
    return b"blob"


def func_returnlonglong():
    """Return 2**31, the first value that does not fit a signed 32-bit int."""
    return 2 ** 31


def func_raiseexception():
    """Always raise ZeroDivisionError."""
    5 / 0


# --- One-argument callables: each reports whether v has a given type. -----

def func_isstring(v):
    """Return True if v is exactly a str (subclasses do not count)."""
    return type(v) is str


def func_isint(v):
    """Return True if v is exactly an int (subclasses do not count)."""
    return type(v) is int


def func_isfloat(v):
    """Return True if v is exactly a float (subclasses do not count)."""
    return type(v) is float


def func_isnone(v):
    """Return True if v is None."""
    return v is None


def func_isblob(v):
    """Return True if v is a bytes or memoryview instance."""
    return isinstance(v, (bytes, memoryview))


def func_islonglong(v):
    """Return True if v is an int that is at least 2**31."""
    return isinstance(v, int) and v >= 2 ** 31


def func(*args):
    """Return the number of positional arguments received."""
    return len(args)


# --- Aggregate classes ----------------------------------------------------

class AggrNoStep:
    """Aggregate that defines finalize() but no step() method."""

    def __init__(self):
        """Create the aggregate; it keeps no state."""

    def finalize(self):
        """Return the constant 1."""
        return 1


class AggrNoFinalize:
    """Aggregate that defines step() but no finalize() method."""

    def __init__(self):
        """Create the aggregate; it keeps no state."""

    def step(self, x):
        """Accept one value and ignore it."""


class AggrExceptionInInit:
    """Aggregate whose constructor always raises ZeroDivisionError."""

    def __init__(self):
        5 / 0

    def step(self, x):
        """Accept one value and ignore it."""

    def finalize(self):
        """Return None."""


class AggrExceptionInStep:
    """Aggregate whose step() always raises ZeroDivisionError."""

    def __init__(self):
        """Create the aggregate; it keeps no state."""

    def step(self, x):
        5 / 0

    def finalize(self):
        """Return the constant 42."""
        return 42


class AggrExceptionInFinalize:
    """Aggregate whose finalize() always raises ZeroDivisionError."""

    def __init__(self):
        """Create the aggregate; it keeps no state."""

    def step(self, x):
        """Accept one value and ignore it."""

    def finalize(self):
        5 / 0


class AggrCheckType:
    """Aggregate reporting whether the last value seen had the named type.

    finalize() returns 1 or 0 for the most recent step() call, or None if
    step() was never called.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        """Record 1 if type(val) is exactly the type named by whichType."""
        matches = _TYPE_BY_NAME[whichType] is type(val)
        self.val = 1 if matches else 0

    def finalize(self):
        """Return the result recorded by the last step() call."""
        return self.val


class AggrCheckTypes:
    """Aggregate counting how many values had exactly the named type."""

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        """Add one to the counter for every value of the named type."""
        # The name is looked up once per value, so an unknown name passes
        # silently when no values are supplied.
        self.val += sum(
            int(_TYPE_BY_NAME[whichType] is type(item)) for item in vals
        )

    def finalize(self):
        """Return the number of matching values counted so far."""
        return self.val
class AggrSum:
    """Aggregate that adds up its single argument, starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        """Add val to the running total."""
        self.val += val

    def finalize(self):
        """Return the accumulated total."""
        return self.val


class FunctionTests(unittest.TestCase):
    """Tests for scalar functions registered with create_function()."""

    def setUp(self):
        """Open an in-memory database and register the helper functions."""
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)
        # -1 registers the function for any number of arguments.
        self.con.create_function("spam", -1, func)

    def tearDown(self):
        """Close the per-test connection."""
        self.con.close()

    # An argument count below -1 must be rejected at registration time.
    def CheckFuncErrorOnCreate(self):
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    # Registers and calls a nested function; the test only checks that this
    # completes without error.
    def CheckFuncRefCount(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        # NOTE(review): presumably this module-global alias keeps an extra
        # reference to f alive for the reference-count scenario -- confirm
        # whether it is still needed.
        globals()["foo"] = f
        # self.con.create_function("reftest", 0, getfunc())
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    # An exception raised inside the callback surfaces as OperationalError.
    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAnyArguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    # Without the deterministic flag both calls in the expression must run.
    def CheckFuncNonDeterministic(self):
        mock = unittest.mock.Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=False)
        self.con.execute("select deterministic() = deterministic()")
        self.assertEqual(mock.call_count, 2)

    # With the deterministic flag the expression triggers a single call.
    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "deterministic parameter not supported")
    def CheckFuncDeterministic(self):
        mock = unittest.mock.Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=True)
        self.con.execute("select deterministic() = deterministic()")
        self.assertEqual(mock.call_count, 1)

    @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
    def CheckFuncDeterministicNotSupported(self):
        with self.assertRaises(sqlite.NotSupportedError):
            self.con.create_function("deterministic", 0, int, deterministic=True)

    # The deterministic flag may not be passed positionally.
    def CheckFuncDeterministicKeywordOnly(self):
        with self.assertRaises(TypeError):
            self.con.create_function("deterministic", 0, int, True)


class AggregateTests(unittest.TestCase):
    """Tests for aggregates registered with create_aggregate()."""

    def setUp(self):
        """Create a one-row test table and register the helper aggregates."""
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        """Close the per-test connection so it is not leaked."""
        self.con.close()

    # An argument count below -1 must be rejected when the aggregate is
    # registered.
    def CheckAggrErrorOnCreate(self):
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamsInt(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)


class AuthorizerTests(unittest.TestCase):
    """Tests for set_authorizer() with a callback that returns SQLITE_DENY.

    Subclasses override authorizer_cb to refuse access in other ways; the
    two test methods are inherited unchanged.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        """Allow only SQLITE_SELECT actions not involving 't2' or 'c2'."""
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        """Create two small tables, then install the authorizer callback."""
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test:
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        """Close the per-test connection so it is not leaked."""
        self.con.close()

    def test_table_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select * from t2")
        self.assertIn('prohibited', str(cm.exception))

    def test_column_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select c2 from t1")
        self.assertIn('prohibited', str(cm.exception))


class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Same tests, with a callback that raises ValueError instead of denying."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK


class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Same tests, with a callback that returns a float instead of denying."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 0.0
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK


class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Same tests, with a callback that returns 2**32 instead of denying."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 2**32
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK


def suite():
    """Build the suite containing every test case class in this module.

    FunctionTests and AggregateTests name their test methods "Check...", so
    they need a loader with that prefix; the authorizer classes use the
    default "test" prefix.  TestLoader is used instead of the deprecated
    unittest.makeSuite() helper.
    """
    check_loader = unittest.TestLoader()
    check_loader.testMethodPrefix = "Check"
    default_loader = unittest.TestLoader()

    function_suite = check_loader.loadTestsFromTestCase(FunctionTests)
    aggregate_suite = check_loader.loadTestsFromTestCase(AggregateTests)
    authorizer_suite = default_loader.loadTestsFromTestCase(AuthorizerTests)
    return unittest.TestSuite((
        function_suite,
        aggregate_suite,
        authorizer_suite,
        default_loader.loadTestsFromTestCase(AuthorizerRaiseExceptionTests),
        default_loader.loadTestsFromTestCase(AuthorizerIllegalTypeTests),
        default_loader.loadTestsFromTestCase(AuthorizerLargeIntegerTests),
    ))


def test():
    """Run the complete suite with a plain text runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()