• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from test import support
2import random
3import unittest
4from functools import cmp_to_key
5
# verbose: copied from test.support so progress output follows the test
#          runner's verbosity setting.
# nerrors: running count of mismatches reported by check().
verbose, nerrors = support.verbose, 0
8
9
def _report_mismatch(tag, detail, expected, orig, raw):
    """Print diagnostics for a failed check() and bump the module error count.

    *detail* is a tuple of values printed on one line to describe the failure,
    *orig* is the input as it was before sorting, and *raw* is the sorted list.
    """
    global nerrors
    print("error in", tag)
    print(*detail)
    print(expected)
    print(orig)
    print(raw)
    nerrors += 1


def check(tag, expected, raw, compare=None):
    """Sort *raw* in place and verify it now holds exactly *expected*.

    *raw* is sorted with ``list.sort()``, or, when *compare* is given, with
    ``key=cmp_to_key(compare)``.  The result must have the same length as
    *expected* and contain the very same objects (identity, not equality) in
    the same order.

    On failure, diagnostics (tag, reason, expected, original input, sorted
    result) are printed and the module-level ``nerrors`` counter is
    incremented.

    Returns True if the sorted list matches *expected*, False otherwise.
    """
    if verbose:
        print("    checking", tag)

    orig = raw[:]   # save input in case of error
    if compare:
        raw.sort(key=cmp_to_key(compare))
    else:
        raw.sort()

    if len(expected) != len(raw):
        _report_mismatch(tag, ("length mismatch;", len(expected), len(raw)),
                         expected, orig, raw)
        return False

    for i, (good, maybe) in enumerate(zip(expected, raw)):
        # Identity, not equality: sorting must rearrange the same objects.
        if good is not maybe:
            _report_mismatch(tag, ("out of order at index", i, good, maybe),
                             expected, orig, raw)
            return False
    return True
41
42class TestBase(unittest.TestCase):
43    def testStressfully(self):
44        # Try a variety of sizes at and around powers of 2, and at powers of 10.
45        sizes = [0]
46        for power in range(1, 10):
47            n = 2 ** power
48            sizes.extend(range(n-1, n+2))
49        sizes.extend([10, 100, 1000])
50
51        class Complains(object):
52            maybe_complain = True
53
54            def __init__(self, i):
55                self.i = i
56
57            def __lt__(self, other):
58                if Complains.maybe_complain and random.random() < 0.001:
59                    if verbose:
60                        print("        complaining at", self, other)
61                    raise RuntimeError
62                return self.i < other.i
63
64            def __repr__(self):
65                return "Complains(%d)" % self.i
66
67        class Stable(object):
68            def __init__(self, key, i):
69                self.key = key
70                self.index = i
71
72            def __lt__(self, other):
73                return self.key < other.key
74
75            def __repr__(self):
76                return "Stable(%d, %d)" % (self.key, self.index)
77
78        for n in sizes:
79            x = list(range(n))
80            if verbose:
81                print("Testing size", n)
82
83            s = x[:]
84            check("identity", x, s)
85
86            s = x[:]
87            s.reverse()
88            check("reversed", x, s)
89
90            s = x[:]
91            random.shuffle(s)
92            check("random permutation", x, s)
93
94            y = x[:]
95            y.reverse()
96            s = x[:]
97            check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))
98
99            if verbose:
100                print("    Checking against an insane comparison function.")
101                print("        If the implementation isn't careful, this may segfault.")
102            s = x[:]
103            s.sort(key=cmp_to_key(lambda a, b:  int(random.random() * 3) - 1))
104            check("an insane function left some permutation", x, s)
105
106            if len(x) >= 2:
107                def bad_key(x):
108                    raise RuntimeError
109                s = x[:]
110                self.assertRaises(RuntimeError, s.sort, key=bad_key)
111
112            x = [Complains(i) for i in x]
113            s = x[:]
114            random.shuffle(s)
115            Complains.maybe_complain = True
116            it_complained = False
117            try:
118                s.sort()
119            except RuntimeError:
120                it_complained = True
121            if it_complained:
122                Complains.maybe_complain = False
123                check("exception during sort left some permutation", x, s)
124
125            s = [Stable(random.randrange(10), i) for i in range(n)]
126            augmented = [(e, e.index) for e in s]
127            augmented.sort()    # forced stable because ties broken by index
128            x = [e for e, i in augmented] # a stable sort of s
129            check("stability", x, s)
130
131    def test_small_stability(self):
132        from itertools import product
133        from operator import itemgetter
134
135        # Exhaustively test stability across all lists of small lengths
136        # and only a few distinct elements.
137        # This can provoke edge cases that randomization is unlikely to find.
138        # But it can grow very expensive quickly, so don't overdo it.
139        NELTS = 3
140        MAXSIZE = 9
141
142        pick0 = itemgetter(0)
143        for length in range(MAXSIZE + 1):
144            # There are NELTS ** length distinct lists.
145            for t in product(range(NELTS), repeat=length):
146                xs = list(zip(t, range(length)))
147                # Stability forced by index in each element.
148                forced = sorted(xs)
149                # Use key= to hide the index from compares.
150                native = sorted(xs, key=pick0)
151                self.assertEqual(forced, native)
152#==============================================================================
153
class TestBugs(unittest.TestCase):
    """Regression tests for historical list.sort() crashers and mutation holes."""

    def test_bug453523(self):
        # Regression for bug 453523, a list.sort() crasher: if this ever
        # breaks, a core dump is the likely symptom.  Mutating the list from
        # inside a comparison must surface as ValueError.

        class C:
            def __lt__(self, other):
                # Randomly shrink or grow the list that is being sorted.
                if victims and random.random() < 0.75:
                    victims.pop()
                else:
                    victims.append(3)
                return random.random() < 0.5

        victims = [C() for _ in range(50)]
        self.assertRaises(ValueError, victims.sort)

    def test_undetected_mutation(self):
        # Python 2.4a1 did not always detect mutation.  Each round keeps
        # growing an ever-deeper nested list so that allocations differ
        # between rounds.
        ballast = []
        for _ in range(20):
            def append_then_pop(a, b):
                target.append(3)
                target.pop()
                return (a > b) - (a < b)
            target = [1, 2]
            self.assertRaises(ValueError, target.sort,
                              key=cmp_to_key(append_then_pop))

            def append_then_clear(a, b):
                target.append(3)
                del target[:]
                return (a > b) - (a < b)
            self.assertRaises(ValueError, target.sort,
                              key=cmp_to_key(append_then_clear))

            ballast = [ballast]
188
189#==============================================================================
190
class TestDecorateSortUndecorate(unittest.TestCase):
    """Tests for list.sort()'s key= and reverse= arguments."""

    def test_decorated(self):
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        copy = data[:]
        random.shuffle(data)
        data.sort(key=str.lower)
        def my_cmp(x, y):
            xlower, ylower = x.lower(), y.lower()
            return (xlower > ylower) - (xlower < ylower)
        copy.sort(key=cmp_to_key(my_cmp))
        # key=str.lower and the equivalent cmp function must agree.  The only
        # tie under lowercasing is between the two identical 'The' strings,
        # so the results match even though the inputs started in different
        # orders.
        self.assertEqual(data, copy)

    def test_baddecorator(self):
        # The key function takes two arguments but sort() calls it with one.
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)

    def test_stability(self):
        data = [(random.randrange(100), i) for i in range(200)]
        copy = data[:]
        data.sort(key=lambda t: t[0])   # sort on the random first field
        copy.sort()                     # sort using both fields
        self.assertEqual(data, copy)    # should get the same result

    def test_key_with_exception(self):
        # Verify that the wrapper has been removed: the key raises at x == 0,
        # and the list must be left exactly as it was.
        data = list(range(-2, 2))
        dup = data[:]
        self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
        self.assertEqual(data, dup)

    def test_key_with_mutation(self):
        # The key function replaces the list's contents mid-sort.
        data = list(range(10))
        def k(x):
            del data[:]
            data[:] = range(20)
            return x
        self.assertRaises(ValueError, data.sort, key=k)

    def test_key_with_mutating_del(self):
        # The key objects mutate the list when they are destroyed.
        data = list(range(10))
        class SortKiller(object):
            def __init__(self, x):
                pass
            def __del__(self):
                del data[:]
                data[:] = range(20)
            def __lt__(self, other):
                return id(self) < id(other)
        self.assertRaises(ValueError, data.sort, key=SortKiller)

    def test_key_with_mutating_del_and_exception(self):
        data = list(range(10))
        ## dup = data[:]
        class SortKiller(object):
            def __init__(self, x):
                if x > 2:
                    raise RuntimeError
            def __del__(self):
                del data[:]
                data[:] = list(range(20))
        self.assertRaises(RuntimeError, data.sort, key=SortKiller)
        ## major honking subtlety: we *can't* do:
        ##
        ## self.assertEqual(data, dup)
        ##
        ## because there is a reference to a SortKiller in the
        ## traceback and by the time it dies we're outside the call to
        ## .sort() and so the list protection gimmicks are out of
        ## date (this cost some brain cells to figure out...).

    def test_reverse(self):
        data = list(range(100))
        random.shuffle(data)
        data.sort(reverse=True)
        self.assertEqual(data, list(range(99,-1,-1)))

    def test_reverse_stability(self):
        # reverse=True must keep equal elements in their original order,
        # matching a forward sort with a reversed comparison.
        data = [(random.randrange(100), i) for i in range(200)]
        copy1 = data[:]
        copy2 = data[:]
        def my_cmp(x, y):
            x0, y0 = x[0], y[0]
            return (x0 > y0) - (x0 < y0)
        def my_cmp_reversed(x, y):
            x0, y0 = x[0], y[0]
            return (y0 > x0) - (y0 < x0)
        data.sort(key=cmp_to_key(my_cmp), reverse=True)
        copy1.sort(key=cmp_to_key(my_cmp_reversed))
        self.assertEqual(data, copy1)
        copy2.sort(key=lambda x: x[0], reverse=True)
        self.assertEqual(data, copy2)
282
283#==============================================================================
def check_against_PyObject_RichCompareBool(self, L):
    ## The idea here is to exploit the fact that unsafe_tuple_compare uses
    ## PyObject_RichCompareBool for the second elements of tuples. So we have,
    ## for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])]
    ## This will work as long as __eq__ => not __lt__ for all the objects in L,
    ## which holds for all the types used below.
    ##
    ## Testing this way ensures that the optimized implementation remains consistent
    ## with the naive implementation, even if changes are made to any of the
    ## richcompares.
    ##
    ## This function shuffles a copy of L once and then tests sorting for three
    ## lists built from that same shuffled order:
    ##                        1. L
    ##                        2. [(x,) for x in L]
    ##                        3. [((x,),) for x in L]
    ##
    ## `self` is the TestCase whose assertIs() reports failures.  The caller's
    ## list is not modified.

    # A private fixed-seed RNG gives the same permutation as the former
    # random.seed(0) + random.shuffle(), without resetting the global `random`
    # state that the other randomized tests in this module share.
    shuffled = list(L)
    random.Random(0).shuffle(shuffled)
    variants = [shuffled,
                [(x,) for x in shuffled],
                [((x,),) for x in shuffled]]
    for seq in variants:
        optimized = sorted(seq)
        reference = [y[1] for y in sorted([(0, x) for x in seq])]
        for opt, ref in zip(optimized, reference):
            # Note: not assertEqual! We want to ensure *identical* behavior.
            self.assertIs(opt, ref)
311
312class TestOptimizedCompares(unittest.TestCase):
313    def test_safe_object_compare(self):
314        heterogeneous_lists = [[0, 'foo'],
315                               [0.0, 'foo'],
316                               [('foo',), 'foo']]
317        for L in heterogeneous_lists:
318            self.assertRaises(TypeError, L.sort)
319            self.assertRaises(TypeError, [(x,) for x in L].sort)
320            self.assertRaises(TypeError, [((x,),) for x in L].sort)
321
322        float_int_lists = [[1,1.1],
323                           [1<<70,1.1],
324                           [1.1,1],
325                           [1.1,1<<70]]
326        for L in float_int_lists:
327            check_against_PyObject_RichCompareBool(self, L)
328
329    def test_unsafe_object_compare(self):
330
331        # This test is by ppperry. It ensures that unsafe_object_compare is
332        # verifying ms->key_richcompare == tp->richcompare before comparing.
333
334        class WackyComparator(int):
335            def __lt__(self, other):
336                elem.__class__ = WackyList2
337                return int.__lt__(self, other)
338
339        class WackyList1(list):
340            pass
341
342        class WackyList2(list):
343            def __lt__(self, other):
344                raise ValueError
345
346        L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
347        elem = L[-1]
348        with self.assertRaises(ValueError):
349            L.sort()
350
351        L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
352        elem = L[-1]
353        with self.assertRaises(ValueError):
354            [(x,) for x in L].sort()
355
356        # The following test is also by ppperry. It ensures that
357        # unsafe_object_compare handles Py_NotImplemented appropriately.
358        class PointlessComparator:
359            def __lt__(self, other):
360                return NotImplemented
361        L = [PointlessComparator(), PointlessComparator()]
362        self.assertRaises(TypeError, L.sort)
363        self.assertRaises(TypeError, [(x,) for x in L].sort)
364
365        # The following tests go through various types that would trigger
366        # ms->key_compare = unsafe_object_compare
367        lists = [list(range(100)) + [(1<<70)],
368                 [str(x) for x in range(100)] + ['\uffff'],
369                 [bytes(x) for x in range(100)],
370                 [cmp_to_key(lambda x,y: x<y)(x) for x in range(100)]]
371        for L in lists:
372            check_against_PyObject_RichCompareBool(self, L)
373
374    def test_unsafe_latin_compare(self):
375        check_against_PyObject_RichCompareBool(self, [str(x) for
376                                                      x in range(100)])
377
378    def test_unsafe_long_compare(self):
379        check_against_PyObject_RichCompareBool(self, [x for
380                                                      x in range(100)])
381
382    def test_unsafe_float_compare(self):
383        check_against_PyObject_RichCompareBool(self, [float(x) for
384                                                      x in range(100)])
385
386    def test_unsafe_tuple_compare(self):
387        # This test was suggested by Tim Peters. It verifies that the tuple
388        # comparison respects the current tuple compare semantics, which do not
389        # guarantee that x < x <=> (x,) < (x,)
390        #
391        # Note that we don't have to put anything in tuples here, because
392        # the check function does a tuple test automatically.
393
394        check_against_PyObject_RichCompareBool(self, [float('nan')]*100)
395        check_against_PyObject_RichCompareBool(self, [float('nan') for
396                                                      _ in range(100)])
397
398    def test_not_all_tuples(self):
399        self.assertRaises(TypeError, [(1.0, 1.0), (False, "A"), 6].sort)
400        self.assertRaises(TypeError, [('a', 1), (1, 'a')].sort)
401        self.assertRaises(TypeError, [(1, 'a'), ('a', 1)].sort)
402
403    def test_none_in_tuples(self):
404        expected = [(None, 1), (None, 2)]
405        actual = sorted([(None, 2), (None, 1)])
406        self.assertEqual(actual, expected)
407
408#==============================================================================
409
if __name__ == "__main__":
    # Executed as a script rather than imported: hand control to unittest's
    # command-line runner for the TestCase classes in this module.
    unittest.main()
412