1"""Unittests for heapq.""" 2 3import random 4import unittest 5import doctest 6 7from test import support 8from test.support import import_helper 9from unittest import TestCase, skipUnless 10from operator import itemgetter 11 12py_heapq = import_helper.import_fresh_module('heapq', blocked=['_heapq']) 13c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq']) 14 15# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when 16# _heapq is imported, so check them there 17func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace', 18 '_heappop_max', '_heapreplace_max', '_heapify_max'] 19 20class TestModules(TestCase): 21 def test_py_functions(self): 22 for fname in func_names: 23 self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq') 24 25 @skipUnless(c_heapq, 'requires _heapq') 26 def test_c_functions(self): 27 for fname in func_names: 28 self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq') 29 30 31def load_tests(loader, tests, ignore): 32 # The 'merge' function has examples in its docstring which we should test 33 # with 'doctest'. 34 # 35 # However, doctest can't easily find all docstrings in the module (loading 36 # it through import_fresh_module seems to confuse it), so we specifically 37 # create a finder which returns the doctests from the merge method. 38 39 class HeapqMergeDocTestFinder: 40 def find(self, *args, **kwargs): 41 dtf = doctest.DocTestFinder() 42 return dtf.find(py_heapq.merge) 43 44 tests.addTests(doctest.DocTestSuite(py_heapq, 45 test_finder=HeapqMergeDocTestFinder())) 46 return tests 47 48class TestHeap: 49 50 def test_push_pop(self): 51 # 1) Push 256 random numbers and pop them off, verifying all's OK. 52 heap = [] 53 data = [] 54 self.check_invariant(heap) 55 for i in range(256): 56 item = random.random() 57 data.append(item) 58 self.module.heappush(heap, item) 59 self.check_invariant(heap) 60 results = [] 61 while heap: 62 item = self.module.heappop(heap) 63 self.check_invariant(heap) 64 results.append(item) 65 data_sorted = data[:] 66 data_sorted.sort() 67 self.assertEqual(data_sorted, results) 68 # 2) Check that the invariant holds for a sorted array 69 self.check_invariant(results) 70 71 self.assertRaises(TypeError, self.module.heappush, []) 72 try: 73 self.assertRaises(TypeError, self.module.heappush, None, None) 74 self.assertRaises(TypeError, self.module.heappop, None) 75 except AttributeError: 76 pass 77 78 def check_invariant(self, heap): 79 # Check the heap invariant. 80 for pos, item in enumerate(heap): 81 if pos: # pos 0 has no parent 82 parentpos = (pos-1) >> 1 83 self.assertTrue(heap[parentpos] <= item) 84 85 def test_heapify(self): 86 for size in list(range(30)) + [20000]: 87 heap = [random.random() for dummy in range(size)] 88 self.module.heapify(heap) 89 self.check_invariant(heap) 90 91 self.assertRaises(TypeError, self.module.heapify, None) 92 93 def test_naive_nbest(self): 94 data = [random.randrange(2000) for i in range(1000)] 95 heap = [] 96 for item in data: 97 self.module.heappush(heap, item) 98 if len(heap) > 10: 99 self.module.heappop(heap) 100 heap.sort() 101 self.assertEqual(heap, sorted(data)[-10:]) 102 103 def heapiter(self, heap): 104 # An iterator returning a heap's elements, smallest-first. 105 try: 106 while 1: 107 yield self.module.heappop(heap) 108 except IndexError: 109 pass 110 111 def test_nbest(self): 112 # Less-naive "N-best" algorithm, much faster (if len(data) is big 113 # enough <wink>) than sorting all of data. However, if we had a max 114 # heap instead of a min heap, it could go faster still via 115 # heapify'ing all of data (linear time), then doing 10 heappops 116 # (10 log-time steps). 117 data = [random.randrange(2000) for i in range(1000)] 118 heap = data[:10] 119 self.module.heapify(heap) 120 for item in data[10:]: 121 if item > heap[0]: # this gets rarer the longer we run 122 self.module.heapreplace(heap, item) 123 self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) 124 125 self.assertRaises(TypeError, self.module.heapreplace, None) 126 self.assertRaises(TypeError, self.module.heapreplace, None, None) 127 self.assertRaises(IndexError, self.module.heapreplace, [], None) 128 129 def test_nbest_with_pushpop(self): 130 data = [random.randrange(2000) for i in range(1000)] 131 heap = data[:10] 132 self.module.heapify(heap) 133 for item in data[10:]: 134 self.module.heappushpop(heap, item) 135 self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) 136 self.assertEqual(self.module.heappushpop([], 'x'), 'x') 137 138 def test_heappushpop(self): 139 h = [] 140 x = self.module.heappushpop(h, 10) 141 self.assertEqual((h, x), ([], 10)) 142 143 h = [10] 144 x = self.module.heappushpop(h, 10.0) 145 self.assertEqual((h, x), ([10], 10.0)) 146 self.assertEqual(type(h[0]), int) 147 self.assertEqual(type(x), float) 148 149 h = [10] 150 x = self.module.heappushpop(h, 9) 151 self.assertEqual((h, x), ([10], 9)) 152 153 h = [10] 154 x = self.module.heappushpop(h, 11) 155 self.assertEqual((h, x), ([11], 10)) 156 157 def test_heappop_max(self): 158 # _heapop_max has an optimization for one-item lists which isn't 159 # covered in other tests, so test that case explicitly here 160 h = [3, 2] 161 self.assertEqual(self.module._heappop_max(h), 3) 162 self.assertEqual(self.module._heappop_max(h), 2) 163 164 def test_heapsort(self): 165 # Exercise everything with repeated heapsort checks 166 for trial in range(100): 167 size = random.randrange(50) 168 data = [random.randrange(25) for i in range(size)] 169 if trial & 1: # Half of the time, use heapify 170 heap = data[:] 171 self.module.heapify(heap) 172 else: # The rest of the time, use heappush 173 heap = [] 174 for item in data: 175 self.module.heappush(heap, item) 176 heap_sorted = [self.module.heappop(heap) for i in range(size)] 177 self.assertEqual(heap_sorted, sorted(data)) 178 179 def test_merge(self): 180 inputs = [] 181 for i in range(random.randrange(25)): 182 row = [] 183 for j in range(random.randrange(100)): 184 tup = random.choice('ABC'), random.randrange(-500, 500) 185 row.append(tup) 186 inputs.append(row) 187 188 for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]: 189 for reverse in [False, True]: 190 seqs = [] 191 for seq in inputs: 192 seqs.append(sorted(seq, key=key, reverse=reverse)) 193 self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse), 194 list(self.module.merge(*seqs, key=key, reverse=reverse))) 195 self.assertEqual(list(self.module.merge()), []) 196 197 def test_empty_merges(self): 198 # Merging two empty lists (with or without a key) should produce 199 # another empty list. 200 self.assertEqual(list(self.module.merge([], [])), []) 201 self.assertEqual(list(self.module.merge([], [], key=lambda: 6)), []) 202 203 def test_merge_does_not_suppress_index_error(self): 204 # Issue 19018: Heapq.merge suppresses IndexError from user generator 205 def iterable(): 206 s = list(range(10)) 207 for i in range(20): 208 yield s[i] # IndexError when i > 10 209 with self.assertRaises(IndexError): 210 list(self.module.merge(iterable(), iterable())) 211 212 def test_merge_stability(self): 213 class Int(int): 214 pass 215 inputs = [[], [], [], []] 216 for i in range(20000): 217 stream = random.randrange(4) 218 x = random.randrange(500) 219 obj = Int(x) 220 obj.pair = (x, stream) 221 inputs[stream].append(obj) 222 for stream in inputs: 223 stream.sort() 224 result = [i.pair for i in self.module.merge(*inputs)] 225 self.assertEqual(result, sorted(result)) 226 227 def test_nsmallest(self): 228 data = [(random.randrange(2000), i) for i in range(1000)] 229 for f in (None, lambda x: x[0] * 547 % 2000): 230 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 231 self.assertEqual(list(self.module.nsmallest(n, data)), 232 sorted(data)[:n]) 233 self.assertEqual(list(self.module.nsmallest(n, data, key=f)), 234 sorted(data, key=f)[:n]) 235 236 def test_nlargest(self): 237 data = [(random.randrange(2000), i) for i in range(1000)] 238 for f in (None, lambda x: x[0] * 547 % 2000): 239 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 240 self.assertEqual(list(self.module.nlargest(n, data)), 241 sorted(data, reverse=True)[:n]) 242 self.assertEqual(list(self.module.nlargest(n, data, key=f)), 243 sorted(data, key=f, reverse=True)[:n]) 244 245 def test_comparison_operator(self): 246 # Issue 3051: Make sure heapq works with both __lt__ 247 # For python 3.0, __le__ alone is not enough 248 def hsort(data, comp): 249 data = [comp(x) for x in data] 250 self.module.heapify(data) 251 return [self.module.heappop(data).x for i in range(len(data))] 252 class LT: 253 def __init__(self, x): 254 self.x = x 255 def __lt__(self, other): 256 return self.x > other.x 257 class LE: 258 def __init__(self, x): 259 self.x = x 260 def __le__(self, other): 261 return self.x >= other.x 262 data = [random.random() for i in range(100)] 263 target = sorted(data, reverse=True) 264 self.assertEqual(hsort(data, LT), target) 265 self.assertRaises(TypeError, data, LE) 266 267 268class TestHeapPython(TestHeap, TestCase): 269 module = py_heapq 270 271 272@skipUnless(c_heapq, 'requires _heapq') 273class TestHeapC(TestHeap, TestCase): 274 module = c_heapq 275 276 277#============================================================================== 278 279class LenOnly: 280 "Dummy sequence class defining __len__ but not __getitem__." 281 def __len__(self): 282 return 10 283 284class CmpErr: 285 "Dummy element that always raises an error during comparison" 286 def __eq__(self, other): 287 raise ZeroDivisionError 288 __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __eq__ 289 290def R(seqn): 291 'Regular generator' 292 for i in seqn: 293 yield i 294 295class G: 296 'Sequence using __getitem__' 297 def __init__(self, seqn): 298 self.seqn = seqn 299 def __getitem__(self, i): 300 return self.seqn[i] 301 302class I: 303 'Sequence using iterator protocol' 304 def __init__(self, seqn): 305 self.seqn = seqn 306 self.i = 0 307 def __iter__(self): 308 return self 309 def __next__(self): 310 if self.i >= len(self.seqn): raise StopIteration 311 v = self.seqn[self.i] 312 self.i += 1 313 return v 314 315class Ig: 316 'Sequence using iterator protocol defined with a generator' 317 def __init__(self, seqn): 318 self.seqn = seqn 319 self.i = 0 320 def __iter__(self): 321 for val in self.seqn: 322 yield val 323 324class X: 325 'Missing __getitem__ and __iter__' 326 def __init__(self, seqn): 327 self.seqn = seqn 328 self.i = 0 329 def __next__(self): 330 if self.i >= len(self.seqn): raise StopIteration 331 v = self.seqn[self.i] 332 self.i += 1 333 return v 334 335class N: 336 'Iterator missing __next__()' 337 def __init__(self, seqn): 338 self.seqn = seqn 339 self.i = 0 340 def __iter__(self): 341 return self 342 343class E: 344 'Test propagation of exceptions' 345 def __init__(self, seqn): 346 self.seqn = seqn 347 self.i = 0 348 def __iter__(self): 349 return self 350 def __next__(self): 351 3 // 0 352 353class S: 354 'Test immediate stop' 355 def __init__(self, seqn): 356 pass 357 def __iter__(self): 358 return self 359 def __next__(self): 360 raise StopIteration 361 362from itertools import chain 363def L(seqn): 364 'Test multiple tiers of iterators' 365 return chain(map(lambda x:x, R(Ig(G(seqn))))) 366 367 368class SideEffectLT: 369 def __init__(self, value, heap): 370 self.value = value 371 self.heap = heap 372 373 def __lt__(self, other): 374 self.heap[:] = [] 375 return self.value < other.value 376 377 378class TestErrorHandling: 379 380 def test_non_sequence(self): 381 for f in (self.module.heapify, self.module.heappop): 382 self.assertRaises((TypeError, AttributeError), f, 10) 383 for f in (self.module.heappush, self.module.heapreplace, 384 self.module.nlargest, self.module.nsmallest): 385 self.assertRaises((TypeError, AttributeError), f, 10, 10) 386 387 def test_len_only(self): 388 for f in (self.module.heapify, self.module.heappop): 389 self.assertRaises((TypeError, AttributeError), f, LenOnly()) 390 for f in (self.module.heappush, self.module.heapreplace): 391 self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10) 392 for f in (self.module.nlargest, self.module.nsmallest): 393 self.assertRaises(TypeError, f, 2, LenOnly()) 394 395 def test_cmp_err(self): 396 seq = [CmpErr(), CmpErr(), CmpErr()] 397 for f in (self.module.heapify, self.module.heappop): 398 self.assertRaises(ZeroDivisionError, f, seq) 399 for f in (self.module.heappush, self.module.heapreplace): 400 self.assertRaises(ZeroDivisionError, f, seq, 10) 401 for f in (self.module.nlargest, self.module.nsmallest): 402 self.assertRaises(ZeroDivisionError, f, 2, seq) 403 404 def test_arg_parsing(self): 405 for f in (self.module.heapify, self.module.heappop, 406 self.module.heappush, self.module.heapreplace, 407 self.module.nlargest, self.module.nsmallest): 408 self.assertRaises((TypeError, AttributeError), f, 10) 409 410 def test_iterable_args(self): 411 for f in (self.module.nlargest, self.module.nsmallest): 412 for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)): 413 for g in (G, I, Ig, L, R): 414 self.assertEqual(list(f(2, g(s))), list(f(2,s))) 415 self.assertEqual(list(f(2, S(s))), []) 416 self.assertRaises(TypeError, f, 2, X(s)) 417 self.assertRaises(TypeError, f, 2, N(s)) 418 self.assertRaises(ZeroDivisionError, f, 2, E(s)) 419 420 # Issue #17278: the heap may change size while it's being walked. 421 422 def test_heappush_mutating_heap(self): 423 heap = [] 424 heap.extend(SideEffectLT(i, heap) for i in range(200)) 425 # Python version raises IndexError, C version RuntimeError 426 with self.assertRaises((IndexError, RuntimeError)): 427 self.module.heappush(heap, SideEffectLT(5, heap)) 428 429 def test_heappop_mutating_heap(self): 430 heap = [] 431 heap.extend(SideEffectLT(i, heap) for i in range(200)) 432 # Python version raises IndexError, C version RuntimeError 433 with self.assertRaises((IndexError, RuntimeError)): 434 self.module.heappop(heap) 435 436 def test_comparison_operator_modifiying_heap(self): 437 # See bpo-39421: Strong references need to be taken 438 # when comparing objects as they can alter the heap 439 class EvilClass(int): 440 def __lt__(self, o): 441 heap.clear() 442 return NotImplemented 443 444 heap = [] 445 self.module.heappush(heap, EvilClass(0)) 446 self.assertRaises(IndexError, self.module.heappushpop, heap, 1) 447 448 def test_comparison_operator_modifiying_heap_two_heaps(self): 449 450 class h(int): 451 def __lt__(self, o): 452 list2.clear() 453 return NotImplemented 454 455 class g(int): 456 def __lt__(self, o): 457 list1.clear() 458 return NotImplemented 459 460 list1, list2 = [], [] 461 462 self.module.heappush(list1, h(0)) 463 self.module.heappush(list2, g(0)) 464 465 self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1)) 466 self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1)) 467 468class TestErrorHandlingPython(TestErrorHandling, TestCase): 469 module = py_heapq 470 471@skipUnless(c_heapq, 'requires _heapq') 472class TestErrorHandlingC(TestErrorHandling, TestCase): 473 module = c_heapq 474 475 476if __name__ == "__main__": 477 unittest.main() 478