#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import unittest
import sqlite3 as sqlite
from test import test_support


# ---------------------------------------------------------------------------
# Scalar functions registered by FunctionTests.setUp.
#
# The func_return* helpers take no arguments and each return a value of one
# specific Python type.  The func_is* predicates take one argument and report
# whether it reached Python with the expected type.
# ---------------------------------------------------------------------------

def func_returntext():
    return "foo"

def func_returnunicode():
    return u"bar"

def func_returnint():
    return 42

def func_returnfloat():
    return 3.14

def func_returnnull():
    return None

def func_returnblob():
    # buffer() is Python-2-only; check_py3k_warnings keeps the -3 deprecation
    # warning it triggers from leaking out of the test run.
    with test_support.check_py3k_warnings():
        return buffer("blob")

def func_returnlonglong():
    # 2**31 is one past the largest signed 32-bit integer.
    return 1<<31

def func_raiseexception():
    # Deliberately raise ZeroDivisionError from inside the SQL function.
    5 // 0

def func_isstring(v):
    return type(v) is unicode

def func_isint(v):
    return type(v) is int

def func_isfloat(v):
    return type(v) is float

def func_isnone(v):
    return v is None

def func_isblob(v):
    return type(v) is buffer

def func_islonglong(v):
    return isinstance(v, (int, long)) and v >= 1<<31


# ---------------------------------------------------------------------------
# Aggregate classes registered by AggregateTests.setUp.
#
# These are intentionally old-style classes: CheckAggrNoStep asserts the
# old-style "<Class> instance has no attribute ..." AttributeError text.
# ---------------------------------------------------------------------------

class AggrNoStep:
    """Aggregate that lacks a step() method."""

    def __init__(self):
        pass

    def finalize(self):
        return 1


class AggrNoFinalize:
    """Aggregate that lacks a finalize() method."""

    def __init__(self):
        pass

    def step(self, x):
        pass


class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError."""

    def __init__(self):
        5 // 0

    def step(self, x):
        pass

    def finalize(self):
        pass


class AggrExceptionInStep:
    """Aggregate whose step() raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        5 // 0

    def finalize(self):
        return 42


class AggrExceptionInFinalize:
    """Aggregate whose finalize() raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        pass

    def finalize(self):
        5 // 0


class AggrCheckType:
    """Report (as 0/1) whether the last stepped value had the named type."""

    # Maps the type name passed from SQL to the Python type the value is
    # expected to arrive as.  Built once here rather than on every step().
    _types = {"str": unicode, "int": int, "float": float,
              "None": type(None), "blob": buffer}

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        self.val = int(self._types[whichType] is type(val))

    def finalize(self):
        return self.val


class AggrSum:
    """Sum the stepped values as a float."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val += val

    def finalize(self):
        return self.val


class FunctionTests(unittest.TestCase):
    """Tests for scalar functions registered with Connection.create_function."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # -100 is not a valid argument count, so registration must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        def getfunc():
            def f():
                return 1
            return f
        # Register a freshly created closure with no other reference in this
        # test, so calling it below relies on the connection keeping it alive.
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        # Same tolerance as the original 3.139 <= val <= 3.141 window.
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        with test_support.check_py3k_warnings():
            self.assertEqual(type(val), buffer)
            self.assertEqual(val, buffer("blob"))

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(cm.exception.args[0],
                         'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        with test_support.check_py3k_warnings():
            cur.execute("select isblob(?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)


class AggregateTests(unittest.TestCase):
    """Tests for aggregates registered with Connection.create_aggregate."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        with test_support.check_py3k_warnings():
            cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
                        ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the in-memory database, matching FunctionTests.tearDown.
        self.con.close()

    def _assertAggregateError(self, sql, expected):
        # Run sql and require an OperationalError carrying the given text.
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute(sql)
            cur.fetchone()
        self.assertEqual(cm.exception.args[0], expected)

    def CheckAggrErrorOnCreate(self):
        # -100 is not a valid argument count, so registration must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        # The missing step() surfaces as AttributeError directly from execute().
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(cm.exception.args[0],
                         "AggrNoStep instance has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        self._assertAggregateError(
            "select nofinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        self._assertAggregateError(
            "select excInit(t) from test",
            "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        self._assertAggregateError(
            "select excStep(t) from test",
            "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        self._assertAggregateError(
            "select excFinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        with test_support.check_py3k_warnings():
            cur.execute("select checkType('blob', ?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)


class AuthorizerTests(unittest.TestCase):
    """Authorizer that denies any non-SELECT action and any access to t2/c2."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # Query t2 before the authorizer is installed.  Presumably this checks
        # that statements prepared earlier cannot bypass the authorizer set
        # below -- TODO confirm intent.
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        self.con.close()

    def _assertProhibited(self, sql):
        # Run sql and require the authorizer to have blocked it.
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute(sql)
        msg = cm.exception.args[0]
        if not msg.endswith("prohibited"):
            self.fail("wrong exception text: %s" % msg)

    def test_table_access(self):
        self._assertProhibited("select * from t2")

    def test_column_access(self):
        self._assertProhibited("select c2 from t1")


class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Same access rules, but denial is signalled by raising ValueError."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK


class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Same access rules, but denial is signalled by returning a float."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 0.0
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK


class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Same access rules, but denial is signalled by returning 2**32."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 2**32
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK


def suite():
    """Build the suite: 'Check*' methods for function/aggregate tests,
    default 'test*' methods for the authorizer tests."""
    function_suite = unittest.makeSuite(FunctionTests, "Check")
    aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
    authorizer_suite = unittest.makeSuite(AuthorizerTests)
    return unittest.TestSuite((
            function_suite,
            aggregate_suite,
            authorizer_suite,
            unittest.makeSuite(AuthorizerRaiseExceptionTests),
            unittest.makeSuite(AuthorizerIllegalTypeTests),
            unittest.makeSuite(AuthorizerLargeIntegerTests),
        ))


def test():
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()