• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Tests for rich comparisons
2
3import unittest
4from test import support
5
6import operator
7
class Number:
    """Thin wrapper whose rich comparisons delegate to the wrapped value.

    Every comparison method compares ``self.x`` against the *other* operand
    exactly as given and returns whatever that comparison returns.  When the
    other operand is itself a Number, it is reached through Python's normal
    reflected-operator dispatch.  ``__cmp__`` exists only as a tripwire: it
    fails the test if anything ever calls it.
    """

    def __init__(self, x):
        # The value that every comparison is forwarded to.
        self.x = x

    def __lt__(self, other):
        return operator.lt(self.x, other)

    def __le__(self, other):
        return operator.le(self.x, other)

    def __eq__(self, other):
        return operator.eq(self.x, other)

    def __ne__(self, other):
        return operator.ne(self.x, other)

    def __gt__(self, other):
        return operator.gt(self.x, other)

    def __ge__(self, other):
        return operator.ge(self.x, other)

    def __cmp__(self, other):
        # Python 3 never consults __cmp__; reaching this means something is wrong.
        raise support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x,)
36
class Vector:
    """Fixed-length sequence whose comparisons work element by element.

    Comparing a Vector with another Vector, or with any sized sequence of
    the same length, yields a new Vector of per-item results rather than a
    single bool.  Vectors refuse truth testing and hashing.  ``__cmp__`` is a
    tripwire that fails the test if it is ever called.
    """

    # Vectors cannot be hashed
    __hash__ = None

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __bool__(self):
        # Item-wise results must never be silently collapsed to one bool.
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data,)

    def __lt__(self, other):
        return self.__itemwise(operator.lt, other)

    def __le__(self, other):
        return self.__itemwise(operator.le, other)

    def __eq__(self, other):
        return self.__itemwise(operator.eq, other)

    def __ne__(self, other):
        return self.__itemwise(operator.ne, other)

    def __gt__(self, other):
        return self.__itemwise(operator.gt, other)

    def __ge__(self, other):
        return self.__itemwise(operator.ge, other)

    def __itemwise(self, op, other):
        """Apply *op* to each pair of items and wrap the results in a Vector."""
        rhs = self.__cast(other)
        return Vector([op(lhs_item, rhs_item)
                       for lhs_item, rhs_item in zip(self.data, rhs)])

    def __cast(self, other):
        """Return *other* as a plain sequence, checking it matches our length.

        Raises:
            ValueError: if the lengths differ.
        """
        rhs = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(rhs):
            raise ValueError("Cannot compare vectors of different length")
        return rhs
86
# Map each comparison name to three interchangeable ways of performing it:
# the operator syntax itself (via a lambda), the public ``operator``
# function, and the ``operator.__xx__`` spelling of the same function.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
95
class VectorTest(unittest.TestCase):
    """Comparisons involving Vector objects must give item-wise results."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of *opname* in opmap raises *error*."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert that every spelling of *opname* yields the items *expres*.

        The result is a Vector, and a Vector refuses truth testing, so
        ``assertEqual(realres, expres)`` cannot be used: it would call
        ``bool()`` on the Vector produced by ``==``.  Instead, compare the
        lengths and then each item by identity, since the items are bools.
        """
        for op in opmap[opname]:
            realres = op(a, b)
            self.assertEqual(len(realres), len(expres))
            for i, want in enumerate(expres):
                self.assertIs(realres[i], want)

    def test_mixed(self):
        # Comparisons involving Vector objects must return rich results,
        # i.e. Vectors holding the item-wise comparison outcomes.
        short = Vector(range(2))
        longer = Vector(range(3))
        # Every comparison must fail for vectors of different lengths.
        for opname in opmap:
            self.checkfail(ValueError, opname, short, longer)

        a = list(range(5))
        b = 5 * [2]
        # Mix plain lists and Vectors on either side.  A plain (a, b) pair is
        # excluded because list-vs-list returns a single bool, not a vector.
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for left, right in args:
            with self.subTest(left=left, right=right):
                self.checkequal("lt", left, right, [True,  True,  False, False, False])
                self.checkequal("le", left, right, [True,  True,  True,  False, False])
                self.checkequal("eq", left, right, [False, False, True,  False, False])
                self.checkequal("ne", left, right, [True,  True,  False, True,  True ])
                self.checkequal("gt", left, right, [False, False, False, True,  True ])
                self.checkequal("ge", left, right, [False, False, True,  True,  True ])

                for ops in opmap.values():
                    for op in ops:
                        # bool() calls Vector.__bool__, which must raise.
                        self.assertRaises(TypeError, bool, op(left, right))
137
class NumberTest(unittest.TestCase):
    """Comparisons involving Number objects must match plain int results."""

    def test_basic(self):
        # Check that comparisons involving Number objects give the same
        # results as comparing the corresponding ints.
        import itertools
        types = (int, Number)
        combos = itertools.product(range(3), range(3), types, types)
        for a, b, typea, typeb in combos:
            if typea is int and typeb is int:
                continue  # the combination int, int is useless
            ta = typea(a)
            tb = typeb(b)
            for ops in opmap.values():
                for op in ops:
                    realoutcome = op(a, b)
                    testoutcome = op(ta, tb)
                    self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Assert that *opname* on (a, b) gives exactly *expres*.

        Every int/Number combination of the two operands is tried, using
        every spelling of the operator in opmap.
        """
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    realres = op(ta, tb)
                    # Unwrap a result that carries an ``x`` attribute;
                    # plain bools have none and pass through unchanged.
                    realres = getattr(realres, "x", realres)
                    self.assertIs(realres, expres)

    def test_values(self):
        # Check all operators and all comparison results.  Each tuple lists
        # the expected outcomes in the order lt, le, eq, ne, gt, ge.
        opnames = ("lt", "le", "eq", "ne", "gt", "ge")
        expected = {
            (0, 0): (False, True,  True,  False, False, True ),
            (0, 1): (True,  True,  False, True,  False, False),
            (1, 0): (False, False, False, True,  True,  True ),
        }
        for (a, b), outcomes in expected.items():
            for opname, expres in zip(opnames, outcomes):
                with self.subTest(opname=opname, a=a, b=b):
                    self.checkvalue(opname, a, b, expres)
190
191class MiscTest(unittest.TestCase):
192
193    def test_misbehavin(self):
194        class Misb:
195            def __lt__(self_, other): return 0
196            def __gt__(self_, other): return 0
197            def __eq__(self_, other): return 0
198            def __le__(self_, other): self.fail("This shouldn't happen")
199            def __ge__(self_, other): self.fail("This shouldn't happen")
200            def __ne__(self_, other): self.fail("This shouldn't happen")
201        a = Misb()
202        b = Misb()
203        self.assertEqual(a<b, 0)
204        self.assertEqual(a==b, 0)
205        self.assertEqual(a>b, 0)
206
207    def test_not(self):
208        # Check that exceptions in __bool__ are properly
209        # propagated by the not operator
210        import operator
211        class Exc(Exception):
212            pass
213        class Bad:
214            def __bool__(self):
215                raise Exc
216
217        def do(bad):
218            not bad
219
220        for func in (do, operator.not_):
221            self.assertRaises(Exc, func, Bad())
222
223    @support.no_tracing
224    @support.infinite_recursion(25)
225    def test_recursion(self):
226        # Check that comparison for recursive objects fails gracefully
227        from collections import UserList
228        a = UserList()
229        b = UserList()
230        a.append(b)
231        b.append(a)
232        self.assertRaises(RecursionError, operator.eq, a, b)
233        self.assertRaises(RecursionError, operator.ne, a, b)
234        self.assertRaises(RecursionError, operator.lt, a, b)
235        self.assertRaises(RecursionError, operator.le, a, b)
236        self.assertRaises(RecursionError, operator.gt, a, b)
237        self.assertRaises(RecursionError, operator.ge, a, b)
238
239        b.append(17)
240        # Even recursive lists of different lengths are different,
241        # but they cannot be ordered
242        self.assertTrue(not (a == b))
243        self.assertTrue(a != b)
244        self.assertRaises(RecursionError, operator.lt, a, b)
245        self.assertRaises(RecursionError, operator.le, a, b)
246        self.assertRaises(RecursionError, operator.gt, a, b)
247        self.assertRaises(RecursionError, operator.ge, a, b)
248        a.append(17)
249        self.assertRaises(RecursionError, operator.eq, a, b)
250        self.assertRaises(RecursionError, operator.ne, a, b)
251        a.insert(0, 11)
252        b.insert(0, 12)
253        self.assertTrue(not (a == b))
254        self.assertTrue(a != b)
255        self.assertTrue(a < b)
256
257    def test_exception_message(self):
258        class Spam:
259            pass
260
261        tests = [
262            (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"),
263            (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"),
264            (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"),
265            (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"),
266            (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"),
267            (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"),
268            (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"),
269            (lambda: 42 < [], r"'<' .* of 'int' and 'list'"),
270            (lambda: () > [], r"'>' .* of 'tuple' and 'list'"),
271            (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"),
272            (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"),
273            (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"),
274            (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"),
275        ]
276        for i, test in enumerate(tests):
277            with self.subTest(test=i):
278                with self.assertRaisesRegex(TypeError, test[1]):
279                    test[0]()
280
281
class DictTest(unittest.TestCase):
    """Dict equality must work for keys/values that support only ==/!=."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__ (and
        # __hash__).  Complex numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            imag1a[random.randrange(100) * 1j] = random.randrange(100) * 1j
        # Rebuild the same contents in a shuffled insertion order.
        items = list(imag1a.items())
        random.shuffle(items)
        imag1b = dict(items)
        # A copy that differs in exactly one value: that of the last
        # shuffled item (named explicitly instead of reusing loop variables).
        last_key, last_value = items[-1]
        imag2 = imag1b.copy()
        imag2[last_key] = last_value + 1.0
        self.assertEqual(imag1a, imag1a)
        self.assertEqual(imag1a, imag1b)
        self.assertEqual(imag2, imag2)
        self.assertNotEqual(imag1a, imag2)
        # Dicts define no ordering, so every ordering operator must fail.
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, imag1a, imag2)
306
307class ListTest(unittest.TestCase):
308
309    def test_coverage(self):
310        # exercise all comparisons for lists
311        x = [42]
312        self.assertIs(x<x, False)
313        self.assertIs(x<=x, True)
314        self.assertIs(x==x, True)
315        self.assertIs(x!=x, False)
316        self.assertIs(x>x, False)
317        self.assertIs(x>=x, True)
318        y = [42, 42]
319        self.assertIs(x<y, True)
320        self.assertIs(x<=y, True)
321        self.assertIs(x==y, False)
322        self.assertIs(x!=y, True)
323        self.assertIs(x>y, False)
324        self.assertIs(x>=y, False)
325
326    def test_badentry(self):
327        # make sure that exceptions for item comparison are properly
328        # propagated in list comparisons
329        class Exc(Exception):
330            pass
331        class Bad:
332            def __eq__(self, other):
333                raise Exc
334
335        x = [Bad()]
336        y = [Bad()]
337
338        for op in opmap["eq"]:
339            self.assertRaises(Exc, op, x, y)
340
341    def test_goodentry(self):
342        # This test exercises the final call to PyObject_RichCompare()
343        # in Objects/listobject.c::list_richcompare()
344        class Good:
345            def __lt__(self, other):
346                return True
347
348        x = [Good()]
349        y = [Good()]
350
351        for op in opmap["lt"]:
352            self.assertIs(op(x, y), True)
353
354
if __name__ == "__main__":
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()
357