1# Tests for rich comparisons 2 3import unittest 4from test import support 5 6import operator 7 8class Number: 9 10 def __init__(self, x): 11 self.x = x 12 13 def __lt__(self, other): 14 return self.x < other 15 16 def __le__(self, other): 17 return self.x <= other 18 19 def __eq__(self, other): 20 return self.x == other 21 22 def __ne__(self, other): 23 return self.x != other 24 25 def __gt__(self, other): 26 return self.x > other 27 28 def __ge__(self, other): 29 return self.x >= other 30 31 def __cmp__(self, other): 32 raise support.TestFailed("Number.__cmp__() should not be called") 33 34 def __repr__(self): 35 return "Number(%r)" % (self.x, ) 36 37class Vector: 38 39 def __init__(self, data): 40 self.data = data 41 42 def __len__(self): 43 return len(self.data) 44 45 def __getitem__(self, i): 46 return self.data[i] 47 48 def __setitem__(self, i, v): 49 self.data[i] = v 50 51 __hash__ = None # Vectors cannot be hashed 52 53 def __bool__(self): 54 raise TypeError("Vectors cannot be used in Boolean contexts") 55 56 def __cmp__(self, other): 57 raise support.TestFailed("Vector.__cmp__() should not be called") 58 59 def __repr__(self): 60 return "Vector(%r)" % (self.data, ) 61 62 def __lt__(self, other): 63 return Vector([a < b for a, b in zip(self.data, self.__cast(other))]) 64 65 def __le__(self, other): 66 return Vector([a <= b for a, b in zip(self.data, self.__cast(other))]) 67 68 def __eq__(self, other): 69 return Vector([a == b for a, b in zip(self.data, self.__cast(other))]) 70 71 def __ne__(self, other): 72 return Vector([a != b for a, b in zip(self.data, self.__cast(other))]) 73 74 def __gt__(self, other): 75 return Vector([a > b for a, b in zip(self.data, self.__cast(other))]) 76 77 def __ge__(self, other): 78 return Vector([a >= b for a, b in zip(self.data, self.__cast(other))]) 79 80 def __cast(self, other): 81 if isinstance(other, Vector): 82 other = other.data 83 if len(self.data) != len(other): 84 raise ValueError("Cannot compare vectors of different length") 85 return other 86 87opmap = { 88 "lt": (lambda a,b: a< b, operator.lt, operator.__lt__), 89 "le": (lambda a,b: a<=b, operator.le, operator.__le__), 90 "eq": (lambda a,b: a==b, operator.eq, operator.__eq__), 91 "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__), 92 "gt": (lambda a,b: a> b, operator.gt, operator.__gt__), 93 "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__) 94} 95 96class VectorTest(unittest.TestCase): 97 98 def checkfail(self, error, opname, *args): 99 for op in opmap[opname]: 100 self.assertRaises(error, op, *args) 101 102 def checkequal(self, opname, a, b, expres): 103 for op in opmap[opname]: 104 realres = op(a, b) 105 # can't use assertEqual(realres, expres) here 106 self.assertEqual(len(realres), len(expres)) 107 for i in range(len(realres)): 108 # results are bool, so we can use "is" here 109 self.assertTrue(realres[i] is expres[i]) 110 111 def test_mixed(self): 112 # check that comparisons involving Vector objects 113 # which return rich results (i.e. Vectors with itemwise 114 # comparison results) work 115 a = Vector(range(2)) 116 b = Vector(range(3)) 117 # all comparisons should fail for different length 118 for opname in opmap: 119 self.checkfail(ValueError, opname, a, b) 120 121 a = list(range(5)) 122 b = 5 * [2] 123 # try mixed arguments (but not (a, b) as that won't return a bool vector) 124 args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))] 125 for (a, b) in args: 126 self.checkequal("lt", a, b, [True, True, False, False, False]) 127 self.checkequal("le", a, b, [True, True, True, False, False]) 128 self.checkequal("eq", a, b, [False, False, True, False, False]) 129 self.checkequal("ne", a, b, [True, True, False, True, True ]) 130 self.checkequal("gt", a, b, [False, False, False, True, True ]) 131 self.checkequal("ge", a, b, [False, False, True, True, True ]) 132 133 for ops in opmap.values(): 134 for op in ops: 135 # calls __bool__, which should fail 136 self.assertRaises(TypeError, bool, op(a, b)) 137 138class NumberTest(unittest.TestCase): 139 140 def test_basic(self): 141 # Check that comparisons involving Number objects 142 # give the same results give as comparing the 143 # corresponding ints 144 for a in range(3): 145 for b in range(3): 146 for typea in (int, Number): 147 for typeb in (int, Number): 148 if typea==typeb==int: 149 continue # the combination int, int is useless 150 ta = typea(a) 151 tb = typeb(b) 152 for ops in opmap.values(): 153 for op in ops: 154 realoutcome = op(a, b) 155 testoutcome = op(ta, tb) 156 self.assertEqual(realoutcome, testoutcome) 157 158 def checkvalue(self, opname, a, b, expres): 159 for typea in (int, Number): 160 for typeb in (int, Number): 161 ta = typea(a) 162 tb = typeb(b) 163 for op in opmap[opname]: 164 realres = op(ta, tb) 165 realres = getattr(realres, "x", realres) 166 self.assertTrue(realres is expres) 167 168 def test_values(self): 169 # check all operators and all comparison results 170 self.checkvalue("lt", 0, 0, False) 171 self.checkvalue("le", 0, 0, True ) 172 self.checkvalue("eq", 0, 0, True ) 173 self.checkvalue("ne", 0, 0, False) 174 self.checkvalue("gt", 0, 0, False) 175 self.checkvalue("ge", 0, 0, True ) 176 177 self.checkvalue("lt", 0, 1, True ) 178 self.checkvalue("le", 0, 1, True ) 179 self.checkvalue("eq", 0, 1, False) 180 self.checkvalue("ne", 0, 1, True ) 181 self.checkvalue("gt", 0, 1, False) 182 self.checkvalue("ge", 0, 1, False) 183 184 self.checkvalue("lt", 1, 0, False) 185 self.checkvalue("le", 1, 0, False) 186 self.checkvalue("eq", 1, 0, False) 187 self.checkvalue("ne", 1, 0, True ) 188 self.checkvalue("gt", 1, 0, True ) 189 self.checkvalue("ge", 1, 0, True ) 190 191class MiscTest(unittest.TestCase): 192 193 def test_misbehavin(self): 194 class Misb: 195 def __lt__(self_, other): return 0 196 def __gt__(self_, other): return 0 197 def __eq__(self_, other): return 0 198 def __le__(self_, other): self.fail("This shouldn't happen") 199 def __ge__(self_, other): self.fail("This shouldn't happen") 200 def __ne__(self_, other): self.fail("This shouldn't happen") 201 a = Misb() 202 b = Misb() 203 self.assertEqual(a<b, 0) 204 self.assertEqual(a==b, 0) 205 self.assertEqual(a>b, 0) 206 207 def test_not(self): 208 # Check that exceptions in __bool__ are properly 209 # propagated by the not operator 210 import operator 211 class Exc(Exception): 212 pass 213 class Bad: 214 def __bool__(self): 215 raise Exc 216 217 def do(bad): 218 not bad 219 220 for func in (do, operator.not_): 221 self.assertRaises(Exc, func, Bad()) 222 223 @support.no_tracing 224 @support.infinite_recursion(25) 225 def test_recursion(self): 226 # Check that comparison for recursive objects fails gracefully 227 from collections import UserList 228 a = UserList() 229 b = UserList() 230 a.append(b) 231 b.append(a) 232 self.assertRaises(RecursionError, operator.eq, a, b) 233 self.assertRaises(RecursionError, operator.ne, a, b) 234 self.assertRaises(RecursionError, operator.lt, a, b) 235 self.assertRaises(RecursionError, operator.le, a, b) 236 self.assertRaises(RecursionError, operator.gt, a, b) 237 self.assertRaises(RecursionError, operator.ge, a, b) 238 239 b.append(17) 240 # Even recursive lists of different lengths are different, 241 # but they cannot be ordered 242 self.assertTrue(not (a == b)) 243 self.assertTrue(a != b) 244 self.assertRaises(RecursionError, operator.lt, a, b) 245 self.assertRaises(RecursionError, operator.le, a, b) 246 self.assertRaises(RecursionError, operator.gt, a, b) 247 self.assertRaises(RecursionError, operator.ge, a, b) 248 a.append(17) 249 self.assertRaises(RecursionError, operator.eq, a, b) 250 self.assertRaises(RecursionError, operator.ne, a, b) 251 a.insert(0, 11) 252 b.insert(0, 12) 253 self.assertTrue(not (a == b)) 254 self.assertTrue(a != b) 255 self.assertTrue(a < b) 256 257 def test_exception_message(self): 258 class Spam: 259 pass 260 261 tests = [ 262 (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"), 263 (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"), 264 (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"), 265 (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"), 266 (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"), 267 (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"), 268 (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"), 269 (lambda: 42 < [], r"'<' .* of 'int' and 'list'"), 270 (lambda: () > [], r"'>' .* of 'tuple' and 'list'"), 271 (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"), 272 (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"), 273 (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"), 274 (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"), 275 ] 276 for i, test in enumerate(tests): 277 with self.subTest(test=i): 278 with self.assertRaisesRegex(TypeError, test[1]): 279 test[0]() 280 281 282class DictTest(unittest.TestCase): 283 284 def test_dicts(self): 285 # Verify that __eq__ and __ne__ work for dicts even if the keys and 286 # values don't support anything other than __eq__ and __ne__ (and 287 # __hash__). Complex numbers are a fine example of that. 288 import random 289 imag1a = {} 290 for i in range(50): 291 imag1a[random.randrange(100)*1j] = random.randrange(100)*1j 292 items = list(imag1a.items()) 293 random.shuffle(items) 294 imag1b = {} 295 for k, v in items: 296 imag1b[k] = v 297 imag2 = imag1b.copy() 298 imag2[k] = v + 1.0 299 self.assertEqual(imag1a, imag1a) 300 self.assertEqual(imag1a, imag1b) 301 self.assertEqual(imag2, imag2) 302 self.assertTrue(imag1a != imag2) 303 for opname in ("lt", "le", "gt", "ge"): 304 for op in opmap[opname]: 305 self.assertRaises(TypeError, op, imag1a, imag2) 306 307class ListTest(unittest.TestCase): 308 309 def test_coverage(self): 310 # exercise all comparisons for lists 311 x = [42] 312 self.assertIs(x<x, False) 313 self.assertIs(x<=x, True) 314 self.assertIs(x==x, True) 315 self.assertIs(x!=x, False) 316 self.assertIs(x>x, False) 317 self.assertIs(x>=x, True) 318 y = [42, 42] 319 self.assertIs(x<y, True) 320 self.assertIs(x<=y, True) 321 self.assertIs(x==y, False) 322 self.assertIs(x!=y, True) 323 self.assertIs(x>y, False) 324 self.assertIs(x>=y, False) 325 326 def test_badentry(self): 327 # make sure that exceptions for item comparison are properly 328 # propagated in list comparisons 329 class Exc(Exception): 330 pass 331 class Bad: 332 def __eq__(self, other): 333 raise Exc 334 335 x = [Bad()] 336 y = [Bad()] 337 338 for op in opmap["eq"]: 339 self.assertRaises(Exc, op, x, y) 340 341 def test_goodentry(self): 342 # This test exercises the final call to PyObject_RichCompare() 343 # in Objects/listobject.c::list_richcompare() 344 class Good: 345 def __lt__(self, other): 346 return True 347 348 x = [Good()] 349 y = [Good()] 350 351 for op in opmap["lt"]: 352 self.assertIs(op(x, y), True) 353 354 355if __name__ == "__main__": 356 unittest.main() 357