# -*- coding: utf-8 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
#     aggregates.
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import unittest
import sqlite3 as sqlite


# ---------------------------------------------------------------------------
# Zero-argument scalar functions.  Each returns a value of one known Python
# type so the tests can check how it comes back out of SQLite.
# ---------------------------------------------------------------------------

def func_returntext():
    """Return a str."""
    return "foo"


def func_returnunicode():
    """Return a second str value (exercised by CheckFuncReturnUnicode)."""
    return "bar"


def func_returnint():
    """Return a small int."""
    return 42


def func_returnfloat():
    """Return a float."""
    return 3.14


def func_returnnull():
    """Return None (maps to SQL NULL)."""
    return None


def func_returnblob():
    """Return a bytes object (maps to a BLOB)."""
    return b"blob"


def func_returnlonglong():
    """Return an int that does not fit in a signed 32-bit integer."""
    return 1 << 31


def func_raiseexception():
    """Always raise ZeroDivisionError, to test error propagation out of SQL."""
    5 / 0


# ---------------------------------------------------------------------------
# One-argument scalar predicates: report the exact Python type of the value
# SQLite passes in.
# ---------------------------------------------------------------------------

def func_isstring(v):
    """Return True if *v* is exactly a str."""
    return type(v) is str


def func_isint(v):
    """Return True if *v* is exactly an int (bool excluded)."""
    return type(v) is int


def func_isfloat(v):
    """Return True if *v* is exactly a float."""
    return type(v) is float


def func_isnone(v):
    """Return True if *v* is None (SQL NULL)."""
    return v is None


def func_isblob(v):
    """Return True if *v* is a bytes or memoryview object."""
    return isinstance(v, (bytes, memoryview))


def func_islonglong(v):
    """Return True if *v* is an int at least as large as 2**31."""
    return isinstance(v, int) and v >= 1 << 31


def func(*args):
    """Variadic function: return the number of arguments received."""
    return len(args)


# ---------------------------------------------------------------------------
# Aggregate classes.  Several are deliberately broken so the tests can check
# how the sqlite3 module reports each failure mode.
# ---------------------------------------------------------------------------

# Type name (as passed from SQL) -> exact Python type expected for the value
# argument(s).  Shared by AggrCheckType and AggrCheckTypes so the mapping is
# built once instead of on every step() call.
_AGGR_TYPE_MAP = {
    "str": str,
    "int": int,
    "float": float,
    "None": type(None),
    "blob": bytes,
}


class AggrNoStep:
    """Aggregate missing a step() method."""

    def __init__(self):
        pass

    def finalize(self):
        return 1


class AggrNoFinalize:
    """Aggregate missing a finalize() method."""

    def __init__(self):
        pass

    def step(self, x):
        pass


class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError."""

    def __init__(self):
        5 / 0

    def step(self, x):
        pass

    def finalize(self):
        pass


class AggrExceptionInStep:
    """Aggregate whose step() raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        5 / 0

    def finalize(self):
        return 42


class AggrExceptionInFinalize:
    """Aggregate whose finalize() raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        pass

    def finalize(self):
        5 / 0


class AggrCheckType:
    """Two-argument aggregate: is the last *val* exactly of type *whichType*?

    finalize() returns 1 or 0 for the most recent row, or None if step()
    was never called.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        self.val = int(_AGGR_TYPE_MAP[whichType] is type(val))

    def finalize(self):
        return self.val


class AggrCheckTypes:
    """Variadic aggregate: count values that are exactly of type *whichType*."""

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        expected = _AGGR_TYPE_MAP[whichType]
        self.val += sum(int(expected is type(val)) for val in vals)

    def finalize(self):
        return self.val


class AggrSum:
    """Sum aggregate; accumulates into a float."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val += val

    def finalize(self):
        return self.val


class FunctionTests(unittest.TestCase):
    """Scalar user-defined functions registered with create_function()."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)
        # -1 registers a variadic function (see CheckAnyArguments).
        self.con.create_function("spam", -1, func)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # -100 is not an accepted argument count, so registration must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        # NOTE(review): storing f in globals() keeps it alive independently
        # of the connection; the commented-out call below is the stricter
        # variant where the connection would hold the only reference.
        globals()["foo"] = f
        # self.con.create_function("reftest", 0, getfunc())
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1 << 31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1 << 42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAnyArguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)


class AggregateTests(unittest.TestCase):
    """User-defined aggregates registered with create_aggregate().

    NOTE: the expected exception types and messages below match the exact
    wording produced by the sqlite3 module this suite was written against.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the in-memory database opened in setUp().
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # Registering an *aggregate* with an invalid argument count must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamsInt(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)


class AuthorizerTests(unittest.TestCase):
    """Authorizer callbacks: any access to t2 or to column c2 is refused."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny everything except SELECT, and deny t2 / column c2 outright.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: compile this statement *before* the
        # authorizer is installed.
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        self.con.close()

    def test_table_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select * from t2")
        self.assertIn('prohibited', str(cm.exception))

    def test_column_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select c2 from t1")
        self.assertIn('prohibited', str(cm.exception))


class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Same checks, but the callback signals denial by raising."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK


class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Same checks, but the callback signals denial with a non-int result."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 0.0
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK


class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Same checks, but the callback signals denial with an out-of-range int."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 2**32
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK


def _load(test_case_class, prefix="test"):
    """Return a TestSuite of *test_case_class* methods starting with *prefix*.

    Replacement for unittest.makeSuite(), which is deprecated since Python
    3.11 and removed in 3.13.
    """
    loader = unittest.TestLoader()
    loader.testMethodPrefix = prefix
    return loader.loadTestsFromTestCase(test_case_class)


def suite():
    """Build the suite: the 'Check*' function/aggregate tests plus the
    'test*' authorizer tests."""
    return unittest.TestSuite((
        _load(FunctionTests, "Check"),
        _load(AggregateTests, "Check"),
        _load(AuthorizerTests),
        _load(AuthorizerRaiseExceptionTests),
        _load(AuthorizerIllegalTypeTests),
        _load(AuthorizerLargeIntegerTests),
    ))


def test():
    """Run suite() with a text runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()