"""
TestCases for python DB duplicate and Btree key comparison function.
"""

import sys, os, re
import test_all
from cStringIO import StringIO

import unittest

from test_all import db, dbshelve, test_support, \
        get_new_environment_path, get_new_database_path


# Needed for python 3. "cmp" vanished in 3.0.1
def cmp(a, b) :
    """Return 0 if a == b, -1 if a < b, and 1 otherwise."""
    if a==b : return 0
    if a<b : return -1
    return 1

lexical_cmp = cmp

def lowercase_cmp(left, right) :
    """Compare two strings after lower-casing both (case-insensitive)."""
    return cmp(left.lower(), right.lower())

def make_reverse_comparator(cmp) :
    """Return a comparator whose result is the negation of *cmp*'s."""
    def reverse(left, right, delegate=cmp) :
        return - delegate(left, right)
    return reverse

_expected_lexical_test_data = ['', 'CCCP', 'a', 'aaa', 'b', 'c', 'cccce', 'ccccf']
_expected_lowercase_test_data = ['', 'a', 'aaa', 'b', 'c', 'CC', 'cccce', 'ccccf', 'CCCP']


class ComparatorTests(unittest.TestCase) :
    """Check the pure-Python comparators before handing them to the DB."""

    def comparator_test_helper(self, comparator, expected_data) :
        """Assert that sorting with *comparator* reproduces *expected_data*.

        The data is fed in reversed order.  Sorting a list that is already
        in the expected order would pass even for a comparator that never
        reports "greater than", which would make the check meaningless.
        """
        data = expected_data[:]
        data.reverse()

        try :
            from functools import cmp_to_key
        except ImportError :
            # functools.cmp_to_key does not exist on older Pythons.
            cmp_to_key = None

        if cmp_to_key is not None :
            data.sort(key=cmp_to_key(comparator))
        elif sys.version_info < (3, 0) :
            data.sort(cmp=comparator)
        else :
            # Stable insertion sort: put each item in front of the first
            # already-placed element that compares greater than it.  Any
            # positive result means "greater", not only exactly 1.
            placed = []
            for item in data :
                for pos, existing in enumerate(placed) :
                    if comparator(existing, item) > 0 :
                        placed.insert(pos, item)
                        break
                else :
                    placed.append(item)
            data = placed

        self.assertEqual(data, expected_data,
                "comparator `%s' is not right: %s vs. %s"
                % (comparator, expected_data, data))

    def test_lexical_comparator(self) :
        self.comparator_test_helper(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_comparator(self) :
        rev = _expected_lexical_test_data[:]
        rev.reverse()
        self.comparator_test_helper(make_reverse_comparator(lexical_cmp),
                rev)

    def test_lowercase_comparator(self) :
        self.comparator_test_helper(lowercase_cmp,
                _expected_lowercase_test_data)


class AbstractBtreeKeyCompareTestCase(unittest.TestCase) :
    """Fixture: a DBEnv in a fresh home directory, plus helpers that create
    a DB_BTREE database whose key order is set via set_bt_compare()."""
    env = None
    db = None

    if (sys.version_info < (2, 7)) or ((sys.version_info >= (3,0)) and
            (sys.version_info < (3, 2))) :
        # Fallback for the interpreter versions this check treats as
        # lacking unittest's assertLess.
        def assertLess(self, a, b, msg=None) :
            return self.assertTrue(a<b, msg=msg)

    def setUp(self) :
        self.filename = self.__class__.__name__ + '.db'
        self.homeDir = get_new_environment_path()
        env = db.DBEnv()
        env.open(self.homeDir,
                db.DB_CREATE | db.DB_INIT_MPOOL
                | db.DB_INIT_LOCK | db.DB_THREAD)
        self.env = env

    def tearDown(self) :
        # Nested finally blocks so that a failure while closing the DB or
        # the environment still removes the home directory.
        try :
            self.closeDB()
        finally :
            try :
                if self.env is not None:
                    self.env.close()
                    self.env = None
            finally :
                test_support.rmtree(self.homeDir)

    def addDataToDB(self, data) :
        """Store each item of *data* as a key, with its index (as str) as value."""
        for i, item in enumerate(data) :
            self.db.put(item, str(i))

    def createDB(self, key_comparator) :
        """Create and open the btree DB using *key_comparator* for keys."""
        # Close any handle left over from an earlier (possibly failed) call,
        # so repeated createDB() calls within one test do not leak handles.
        self.closeDB()
        self.db = db.DB(self.env)
        self.setupDB(key_comparator)
        self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)

    def setupDB(self, key_comparator) :
        self.db.set_bt_compare(key_comparator)

    def closeDB(self) :
        if self.db is not None:
            self.db.close()
            self.db = None

    def startTest(self) :
        pass

    def finishTest(self, expected = None) :
        """Optionally verify the DB contents, then close the DB."""
        if expected is not None:
            self.check_results(expected)
        self.closeDB()

    def check_results(self, expected) :
        """Assert that a cursor walk yields exactly the keys in *expected*."""
        curs = self.db.cursor()
        try:
            index = 0
            rec = curs.first()
            while rec:
                key = rec[0]
                self.assertLess(index, len(expected),
                        "too many values returned from cursor")
                self.assertEqual(expected[index], key,
                        "expected value `%s' at %d but got `%s'"
                        % (expected[index], index, key))
                index = index + 1
                rec = curs.next()
            self.assertEqual(index, len(expected),
                    "not enough values returned from cursor")
        finally:
            curs.close()


class BtreeKeyCompareTestCase(AbstractBtreeKeyCompareTestCase) :
    def runCompareTest(self, comparator, data) :
        """Insert *data* using *comparator* and expect it back in order."""
        self.startTest()
        self.createDB(comparator)
        self.addDataToDB(data)
        self.finishTest(data)

    def test_lexical_ordering(self) :
        self.runCompareTest(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_ordering(self) :
        expected_rev_data = _expected_lexical_test_data[:]
        expected_rev_data.reverse()
        self.runCompareTest(make_reverse_comparator(lexical_cmp),
                expected_rev_data)

    def test_compare_function_useless(self) :
        self.startTest()
        def socialist_comparator(l, r) :
            return 0
        self.createDB(socialist_comparator)
        self.addDataToDB(['b', 'a', 'd'])
        # all things being equal the first key will be the only key
        # in the database...  (with the last key's value fwiw)
        self.finishTest(['b'])


class BtreeExceptionsTestCase(AbstractBtreeKeyCompareTestCase) :
    def test_raises_non_callable(self) :
        self.startTest()
        self.assertRaises(TypeError, self.createDB, 'abc')
        self.assertRaises(TypeError, self.createDB, None)
        self.finishTest()

    def test_set_bt_compare_with_function(self) :
        self.startTest()
        self.createDB(lexical_cmp)
        self.finishTest()

    def check_results(self, results) :
        pass

    def test_compare_function_incorrect(self) :
        self.startTest()
        def bad_comparator(l, r) :
            return 1
        # verify that set_bt_compare checks that comparator('', '') == 0
        self.assertRaises(TypeError, self.createDB, bad_comparator)
        self.finishTest()

    def verifyStderr(self, method, successRe) :
        """
        Call method() while capturing sys.stderr output internally and
        call self.fail() if successRe.search() does not match the stderr
        output.  This is used to test for uncatchable exceptions.

        If method() itself raises, that exception propagates unchanged;
        it is not masked by a stderr-mismatch failure.
        """
        stdErr = sys.stderr
        sys.stderr = StringIO()
        try:
            method()
        finally:
            temp = sys.stderr
            sys.stderr = stdErr
        errorOut = temp.getvalue()
        if not successRe.search(errorOut) :
            self.fail("unexpected stderr output:\n"+errorOut)
        if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
            sys.exc_traceback = sys.last_traceback = None

    def _test_compare_function_exception(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_bt_compare test
                return 0
            raise RuntimeError("i'm a naughty comparison function")
        self.createDB(bad_comparator)
        # The comparator's exceptions cannot be caught here; they are
        # expected to show up on stderr (the caller checks for two).
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_exception(self) :
        self.verifyStderr(
                self._test_compare_function_exception,
                re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
        )

    def _test_compare_function_bad_return(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_bt_compare test
                return 0
            return l
        self.createDB(bad_comparator)
        # Returning a non-int is reported on stderr rather than raised
        # here; the caller expects two such TypeError reports.
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_bad_return(self) :
        self.verifyStderr(
                self._test_compare_function_bad_return,
                re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
        )

    def test_cannot_assign_twice(self) :

        def my_compare(a, b) :
            return 0

        self.startTest()
        self.createDB(my_compare)
        self.assertRaises(RuntimeError, self.db.set_bt_compare, my_compare)


class AbstractDuplicateCompareTestCase(unittest.TestCase) :
    """Fixture: a DBEnv in a fresh home directory, plus helpers that create
    a DB_BTREE/DB_DUPSORT database whose duplicate order is set via
    set_dup_compare()."""
    env = None
    db = None

    if (sys.version_info < (2, 7)) or ((sys.version_info >= (3,0)) and
            (sys.version_info < (3, 2))) :
        # Fallback for the interpreter versions this check treats as
        # lacking unittest's assertLess.
        def assertLess(self, a, b, msg=None) :
            return self.assertTrue(a<b, msg=msg)

    def setUp(self) :
        self.filename = self.__class__.__name__ + '.db'
        self.homeDir = get_new_environment_path()
        env = db.DBEnv()
        env.open(self.homeDir,
                db.DB_CREATE | db.DB_INIT_MPOOL
                | db.DB_INIT_LOCK | db.DB_THREAD)
        self.env = env

    def tearDown(self) :
        # Nested finally blocks so that a failure while closing the DB or
        # the environment still removes the home directory.
        try :
            self.closeDB()
        finally :
            try :
                if self.env is not None:
                    self.env.close()
                    self.env = None
            finally :
                test_support.rmtree(self.homeDir)

    def addDataToDB(self, data) :
        """Store every item of *data* as a duplicate value under "key"."""
        for item in data:
            self.db.put("key", item)

    def createDB(self, dup_comparator) :
        """Create and open the DUPSORT DB using *dup_comparator*."""
        # Close any handle left over from an earlier (possibly failed) call,
        # so repeated createDB() calls within one test do not leak handles.
        self.closeDB()
        self.db = db.DB(self.env)
        self.setupDB(dup_comparator)
        self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)

    def setupDB(self, dup_comparator) :
        self.db.set_flags(db.DB_DUPSORT)
        self.db.set_dup_compare(dup_comparator)

    def closeDB(self) :
        if self.db is not None:
            self.db.close()
            self.db = None

    def startTest(self) :
        pass

    def finishTest(self, expected = None) :
        """Optionally verify the DB contents, then close the DB."""
        if expected is not None:
            self.check_results(expected)
        self.closeDB()

    def check_results(self, expected) :
        """Assert that a cursor walk yields exactly the values in *expected*."""
        curs = self.db.cursor()
        try:
            index = 0
            rec = curs.first()
            while rec:
                data = rec[1]
                self.assertLess(index, len(expected),
                        "too many values returned from cursor")
                self.assertEqual(expected[index], data,
                        "expected value `%s' at %d but got `%s'"
                        % (expected[index], index, data))
                index = index + 1
                rec = curs.next()
            self.assertEqual(index, len(expected),
                    "not enough values returned from cursor")
        finally:
            curs.close()


class DuplicateCompareTestCase(AbstractDuplicateCompareTestCase) :
    def runCompareTest(self, comparator, data) :
        """Insert *data* as duplicates and expect it back in order."""
        self.startTest()
        self.createDB(comparator)
        self.addDataToDB(data)
        self.finishTest(data)

    def test_lexical_ordering(self) :
        self.runCompareTest(lexical_cmp, _expected_lexical_test_data)

    def test_reverse_lexical_ordering(self) :
        expected_rev_data = _expected_lexical_test_data[:]
        expected_rev_data.reverse()
        self.runCompareTest(make_reverse_comparator(lexical_cmp),
                expected_rev_data)


class DuplicateExceptionsTestCase(AbstractDuplicateCompareTestCase) :
    def test_raises_non_callable(self) :
        self.startTest()
        self.assertRaises(TypeError, self.createDB, 'abc')
        self.assertRaises(TypeError, self.createDB, None)
        self.finishTest()

    def test_set_dup_compare_with_function(self) :
        self.startTest()
        self.createDB(lexical_cmp)
        self.finishTest()

    def check_results(self, results) :
        pass

    def test_compare_function_incorrect(self) :
        self.startTest()
        def bad_comparator(l, r) :
            return 1
        # verify that set_dup_compare checks that comparator('', '') == 0
        self.assertRaises(TypeError, self.createDB, bad_comparator)
        self.finishTest()

    def test_compare_function_useless(self) :
        self.startTest()
        def socialist_comparator(l, r) :
            return 0
        self.createDB(socialist_comparator)
        # DUPSORT does not allow "duplicate duplicates"
        self.assertRaises(db.DBKeyExistError, self.addDataToDB, ['b', 'a', 'd'])
        self.finishTest()

    def verifyStderr(self, method, successRe) :
        """
        Call method() while capturing sys.stderr output internally and
        call self.fail() if successRe.search() does not match the stderr
        output.  This is used to test for uncatchable exceptions.

        If method() itself raises, that exception propagates unchanged;
        it is not masked by a stderr-mismatch failure.
        """
        stdErr = sys.stderr
        sys.stderr = StringIO()
        try:
            method()
        finally:
            temp = sys.stderr
            sys.stderr = stdErr
        errorOut = temp.getvalue()
        if not successRe.search(errorOut) :
            self.fail("unexpected stderr output:\n"+errorOut)
        if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
            sys.exc_traceback = sys.last_traceback = None

    def _test_compare_function_exception(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_dup_compare test
                return 0
            raise RuntimeError("i'm a naughty comparison function")
        self.createDB(bad_comparator)
        # The comparator's exceptions cannot be caught here; they are
        # expected to show up on stderr (the caller checks for two).
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_exception(self) :
        self.verifyStderr(
                self._test_compare_function_exception,
                re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
        )

    def _test_compare_function_bad_return(self) :
        self.startTest()
        def bad_comparator(l, r) :
            if l == r:
                # pass the set_dup_compare test
                return 0
            return l
        self.createDB(bad_comparator)
        # Returning a non-int is reported on stderr rather than raised
        # here; the caller expects two such TypeError reports.
        self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
        self.finishTest()

    def test_compare_function_bad_return(self) :
        self.verifyStderr(
                self._test_compare_function_bad_return,
                re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
        )

    def test_cannot_assign_twice(self) :

        def my_compare(a, b) :
            return 0

        self.startTest()
        self.createDB(my_compare)
        self.assertRaises(RuntimeError, self.db.set_dup_compare, my_compare)


def test_suite() :
    """Build the suite of all comparator test cases in this module."""
    res = unittest.TestSuite()

    res.addTest(unittest.makeSuite(ComparatorTests))
    res.addTest(unittest.makeSuite(BtreeExceptionsTestCase))
    res.addTest(unittest.makeSuite(BtreeKeyCompareTestCase))
    res.addTest(unittest.makeSuite(DuplicateExceptionsTestCase))
    res.addTest(unittest.makeSuite(DuplicateCompareTestCase))
    return res


if __name__ == '__main__':
    # The suite factory defined above is test_suite(); this module has no
    # attribute named "suite", so defaultTest must name test_suite.
    unittest.main(defaultTest = 'test_suite')