1import copy 2import functools 3import sys 4import unittest 5from test import test_support 6from weakref import proxy 7import pickle 8 9@staticmethod 10def PythonPartial(func, *args, **keywords): 11 'Pure Python approximation of partial()' 12 def newfunc(*fargs, **fkeywords): 13 newkeywords = keywords.copy() 14 newkeywords.update(fkeywords) 15 return func(*(args + fargs), **newkeywords) 16 newfunc.func = func 17 newfunc.args = args 18 newfunc.keywords = keywords 19 return newfunc 20 21def capture(*args, **kw): 22 """capture all positional and keyword arguments""" 23 return args, kw 24 25def signature(part): 26 """ return the signature of a partial object """ 27 return (part.func, part.args, part.keywords, part.__dict__) 28 29class MyTuple(tuple): 30 pass 31 32class BadTuple(tuple): 33 def __add__(self, other): 34 return list(self) + list(other) 35 36class MyDict(dict): 37 pass 38 39class TestPartial(unittest.TestCase): 40 41 partial = functools.partial 42 43 def test_basic_examples(self): 44 p = self.partial(capture, 1, 2, a=10, b=20) 45 self.assertEqual(p(3, 4, b=30, c=40), 46 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 47 p = self.partial(map, lambda x: x*10) 48 self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40]) 49 50 def test_attributes(self): 51 p = self.partial(capture, 1, 2, a=10, b=20) 52 # attributes should be readable 53 self.assertEqual(p.func, capture) 54 self.assertEqual(p.args, (1, 2)) 55 self.assertEqual(p.keywords, dict(a=10, b=20)) 56 # attributes should not be writable 57 self.assertRaises(TypeError, setattr, p, 'func', map) 58 self.assertRaises(TypeError, setattr, p, 'args', (1, 2)) 59 self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2)) 60 61 p = self.partial(hex) 62 try: 63 del p.__dict__ 64 except TypeError: 65 pass 66 else: 67 self.fail('partial object allowed __dict__ to be deleted') 68 69 def test_argument_checking(self): 70 self.assertRaises(TypeError, self.partial) # need at least a func arg 71 try: 72 self.partial(2)() 73 except TypeError: 74 pass 75 else: 76 self.fail('First arg not checked for callability') 77 78 def test_protection_of_callers_dict_argument(self): 79 # a caller's dictionary should not be altered by partial 80 def func(a=10, b=20): 81 return a 82 d = {'a':3} 83 p = self.partial(func, a=5) 84 self.assertEqual(p(**d), 3) 85 self.assertEqual(d, {'a':3}) 86 p(b=7) 87 self.assertEqual(d, {'a':3}) 88 89 def test_arg_combinations(self): 90 # exercise special code paths for zero args in either partial 91 # object or the caller 92 p = self.partial(capture) 93 self.assertEqual(p(), ((), {})) 94 self.assertEqual(p(1,2), ((1,2), {})) 95 p = self.partial(capture, 1, 2) 96 self.assertEqual(p(), ((1,2), {})) 97 self.assertEqual(p(3,4), ((1,2,3,4), {})) 98 99 def test_kw_combinations(self): 100 # exercise special code paths for no keyword args in 101 # either the partial object or the caller 102 p = self.partial(capture) 103 self.assertEqual(p.keywords, {}) 104 self.assertEqual(p(), ((), {})) 105 self.assertEqual(p(a=1), ((), {'a':1})) 106 p = self.partial(capture, a=1) 107 self.assertEqual(p.keywords, {'a':1}) 108 self.assertEqual(p(), ((), {'a':1})) 109 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 110 # keyword args in the call override those in the partial object 111 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 112 113 def test_positional(self): 114 # make sure positional arguments are captured correctly 115 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 116 p = self.partial(capture, *args) 117 expected = args + ('x',) 118 got, empty = p('x') 119 self.assertTrue(expected == got and empty == {}) 120 121 def test_keyword(self): 122 # make sure keyword arguments are captured correctly 123 for a in ['a', 0, None, 3.5]: 124 p = self.partial(capture, a=a) 125 expected = {'a':a,'x':None} 126 empty, got = p(x=None) 127 self.assertTrue(expected == got and empty == ()) 128 129 def test_no_side_effects(self): 130 # make sure there are no side effects that affect subsequent calls 131 p = self.partial(capture, 0, a=1) 132 args1, kw1 = p(1, b=2) 133 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 134 args2, kw2 = p() 135 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 136 137 def test_error_propagation(self): 138 def f(x, y): 139 x // y 140 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 141 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 142 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 143 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 144 145 def test_weakref(self): 146 f = self.partial(int, base=16) 147 p = proxy(f) 148 self.assertEqual(f.func, p.func) 149 f = None 150 self.assertRaises(ReferenceError, getattr, p, 'func') 151 152 def test_with_bound_and_unbound_methods(self): 153 data = map(str, range(10)) 154 join = self.partial(str.join, '') 155 self.assertEqual(join(data), '0123456789') 156 join = self.partial(''.join) 157 self.assertEqual(join(data), '0123456789') 158 159 def test_pickle(self): 160 f = self.partial(signature, ['asdf'], bar=[True]) 161 f.attr = [] 162 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 163 f_copy = pickle.loads(pickle.dumps(f, proto)) 164 self.assertEqual(signature(f_copy), signature(f)) 165 166 def test_copy(self): 167 f = self.partial(signature, ['asdf'], bar=[True]) 168 f.attr = [] 169 f_copy = copy.copy(f) 170 self.assertEqual(signature(f_copy), signature(f)) 171 self.assertIs(f_copy.attr, f.attr) 172 self.assertIs(f_copy.args, f.args) 173 self.assertIs(f_copy.keywords, f.keywords) 174 175 def test_deepcopy(self): 176 f = self.partial(signature, ['asdf'], bar=[True]) 177 f.attr = [] 178 f_copy = copy.deepcopy(f) 179 self.assertEqual(signature(f_copy), signature(f)) 180 self.assertIsNot(f_copy.attr, f.attr) 181 self.assertIsNot(f_copy.args, f.args) 182 self.assertIsNot(f_copy.args[0], f.args[0]) 183 self.assertIsNot(f_copy.keywords, f.keywords) 184 self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) 185 186 def test_setstate(self): 187 f = self.partial(signature) 188 f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) 189 self.assertEqual(signature(f), 190 (capture, (1,), dict(a=10), dict(attr=[]))) 191 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 192 193 f.__setstate__((capture, (1,), dict(a=10), None)) 194 self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) 195 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 196 197 f.__setstate__((capture, (1,), None, None)) 198 #self.assertEqual(signature(f), (capture, (1,), {}, {})) 199 self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) 200 self.assertEqual(f(2), ((1, 2), {})) 201 self.assertEqual(f(), ((1,), {})) 202 203 f.__setstate__((capture, (), {}, None)) 204 self.assertEqual(signature(f), (capture, (), {}, {})) 205 self.assertEqual(f(2, b=20), ((2,), {'b': 20})) 206 self.assertEqual(f(2), ((2,), {})) 207 self.assertEqual(f(), ((), {})) 208 209 def test_setstate_errors(self): 210 f = self.partial(signature) 211 self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) 212 self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) 213 self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) 214 self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) 215 self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) 216 self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) 217 self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) 218 219 def test_setstate_subclasses(self): 220 f = self.partial(signature) 221 f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) 222 s = signature(f) 223 self.assertEqual(s, (capture, (1,), dict(a=10), {})) 224 self.assertIs(type(s[1]), tuple) 225 self.assertIs(type(s[2]), dict) 226 r = f() 227 self.assertEqual(r, ((1,), {'a': 10})) 228 self.assertIs(type(r[0]), tuple) 229 self.assertIs(type(r[1]), dict) 230 231 f.__setstate__((capture, BadTuple((1,)), {}, None)) 232 s = signature(f) 233 self.assertEqual(s, (capture, (1,), {}, {})) 234 self.assertIs(type(s[1]), tuple) 235 r = f(2) 236 self.assertEqual(r, ((1, 2), {})) 237 self.assertIs(type(r[0]), tuple) 238 239 def test_recursive_pickle(self): 240 f = self.partial(capture) 241 f.__setstate__((f, (), {}, {})) 242 try: 243 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 244 with self.assertRaises(RuntimeError): 245 pickle.dumps(f, proto) 246 finally: 247 f.__setstate__((capture, (), {}, {})) 248 249 f = self.partial(capture) 250 f.__setstate__((capture, (f,), {}, {})) 251 try: 252 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 253 f_copy = pickle.loads(pickle.dumps(f, proto)) 254 try: 255 self.assertIs(f_copy.args[0], f_copy) 256 finally: 257 f_copy.__setstate__((capture, (), {}, {})) 258 finally: 259 f.__setstate__((capture, (), {}, {})) 260 261 f = self.partial(capture) 262 f.__setstate__((capture, (), {'a': f}, {})) 263 try: 264 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 265 f_copy = pickle.loads(pickle.dumps(f, proto)) 266 try: 267 self.assertIs(f_copy.keywords['a'], f_copy) 268 finally: 269 f_copy.__setstate__((capture, (), {}, {})) 270 finally: 271 f.__setstate__((capture, (), {}, {})) 272 273 # Issue 6083: Reference counting bug 274 def test_setstate_refcount(self): 275 class BadSequence: 276 def __len__(self): 277 return 4 278 def __getitem__(self, key): 279 if key == 0: 280 return max 281 elif key == 1: 282 return tuple(range(1000000)) 283 elif key in (2, 3): 284 return {} 285 raise IndexError 286 287 f = self.partial(object) 288 self.assertRaises(TypeError, f.__setstate__, BadSequence()) 289 290class PartialSubclass(functools.partial): 291 pass 292 293class TestPartialSubclass(TestPartial): 294 295 partial = PartialSubclass 296 297class TestPythonPartial(TestPartial): 298 299 partial = PythonPartial 300 301 # the python version isn't picklable 302 test_pickle = None 303 test_setstate = None 304 test_setstate_errors = None 305 test_setstate_subclasses = None 306 test_setstate_refcount = None 307 test_recursive_pickle = None 308 309 # the python version isn't deepcopyable 310 test_deepcopy = None 311 312 # the python version isn't a type 313 test_attributes = None 314 315class TestUpdateWrapper(unittest.TestCase): 316 317 def check_wrapper(self, wrapper, wrapped, 318 assigned=functools.WRAPPER_ASSIGNMENTS, 319 updated=functools.WRAPPER_UPDATES): 320 # Check attributes were assigned 321 for name in assigned: 322 self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name)) 323 # Check attributes were updated 324 for name in updated: 325 wrapper_attr = getattr(wrapper, name) 326 wrapped_attr = getattr(wrapped, name) 327 for key in wrapped_attr: 328 self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) 329 330 def _default_update(self): 331 def f(): 332 """This is a test""" 333 pass 334 f.attr = 'This is also a test' 335 def wrapper(): 336 pass 337 functools.update_wrapper(wrapper, f) 338 return wrapper, f 339 340 def test_default_update(self): 341 wrapper, f = self._default_update() 342 self.check_wrapper(wrapper, f) 343 self.assertEqual(wrapper.__name__, 'f') 344 self.assertEqual(wrapper.attr, 'This is also a test') 345 346 @unittest.skipIf(sys.flags.optimize >= 2, 347 "Docstrings are omitted with -O2 and above") 348 def test_default_update_doc(self): 349 wrapper, f = self._default_update() 350 self.assertEqual(wrapper.__doc__, 'This is a test') 351 352 def test_no_update(self): 353 def f(): 354 """This is a test""" 355 pass 356 f.attr = 'This is also a test' 357 def wrapper(): 358 pass 359 functools.update_wrapper(wrapper, f, (), ()) 360 self.check_wrapper(wrapper, f, (), ()) 361 self.assertEqual(wrapper.__name__, 'wrapper') 362 self.assertEqual(wrapper.__doc__, None) 363 self.assertFalse(hasattr(wrapper, 'attr')) 364 365 def test_selective_update(self): 366 def f(): 367 pass 368 f.attr = 'This is a different test' 369 f.dict_attr = dict(a=1, b=2, c=3) 370 def wrapper(): 371 pass 372 wrapper.dict_attr = {} 373 assign = ('attr',) 374 update = ('dict_attr',) 375 functools.update_wrapper(wrapper, f, assign, update) 376 self.check_wrapper(wrapper, f, assign, update) 377 self.assertEqual(wrapper.__name__, 'wrapper') 378 self.assertEqual(wrapper.__doc__, None) 379 self.assertEqual(wrapper.attr, 'This is a different test') 380 self.assertEqual(wrapper.dict_attr, f.dict_attr) 381 382 @test_support.requires_docstrings 383 def test_builtin_update(self): 384 # Test for bug #1576241 385 def wrapper(): 386 pass 387 functools.update_wrapper(wrapper, max) 388 self.assertEqual(wrapper.__name__, 'max') 389 self.assertTrue(wrapper.__doc__.startswith('max(')) 390 391class TestWraps(TestUpdateWrapper): 392 393 def _default_update(self): 394 def f(): 395 """This is a test""" 396 pass 397 f.attr = 'This is also a test' 398 @functools.wraps(f) 399 def wrapper(): 400 pass 401 self.check_wrapper(wrapper, f) 402 return wrapper 403 404 def test_default_update(self): 405 wrapper = self._default_update() 406 self.assertEqual(wrapper.__name__, 'f') 407 self.assertEqual(wrapper.attr, 'This is also a test') 408 409 @unittest.skipIf(sys.flags.optimize >= 2, 410 "Docstrings are omitted with -O2 and above") 411 def test_default_update_doc(self): 412 wrapper = self._default_update() 413 self.assertEqual(wrapper.__doc__, 'This is a test') 414 415 def test_no_update(self): 416 def f(): 417 """This is a test""" 418 pass 419 f.attr = 'This is also a test' 420 @functools.wraps(f, (), ()) 421 def wrapper(): 422 pass 423 self.check_wrapper(wrapper, f, (), ()) 424 self.assertEqual(wrapper.__name__, 'wrapper') 425 self.assertEqual(wrapper.__doc__, None) 426 self.assertFalse(hasattr(wrapper, 'attr')) 427 428 def test_selective_update(self): 429 def f(): 430 pass 431 f.attr = 'This is a different test' 432 f.dict_attr = dict(a=1, b=2, c=3) 433 def add_dict_attr(f): 434 f.dict_attr = {} 435 return f 436 assign = ('attr',) 437 update = ('dict_attr',) 438 @functools.wraps(f, assign, update) 439 @add_dict_attr 440 def wrapper(): 441 pass 442 self.check_wrapper(wrapper, f, assign, update) 443 self.assertEqual(wrapper.__name__, 'wrapper') 444 self.assertEqual(wrapper.__doc__, None) 445 self.assertEqual(wrapper.attr, 'This is a different test') 446 self.assertEqual(wrapper.dict_attr, f.dict_attr) 447 448 449class TestReduce(unittest.TestCase): 450 451 def test_reduce(self): 452 class Squares: 453 454 def __init__(self, max): 455 self.max = max 456 self.sofar = [] 457 458 def __len__(self): return len(self.sofar) 459 460 def __getitem__(self, i): 461 if not 0 <= i < self.max: raise IndexError 462 n = len(self.sofar) 463 while n <= i: 464 self.sofar.append(n*n) 465 n += 1 466 return self.sofar[i] 467 468 reduce = functools.reduce 469 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') 470 self.assertEqual( 471 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), 472 ['a','c','d','w'] 473 ) 474 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040) 475 self.assertEqual( 476 reduce(lambda x, y: x*y, range(2,21), 1L), 477 2432902008176640000L 478 ) 479 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285) 480 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285) 481 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0) 482 self.assertRaises(TypeError, reduce) 483 self.assertRaises(TypeError, reduce, 42, 42) 484 self.assertRaises(TypeError, reduce, 42, 42, 42) 485 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item 486 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item 487 self.assertRaises(TypeError, reduce, 42, (42, 42)) 488 489class TestCmpToKey(unittest.TestCase): 490 def test_cmp_to_key(self): 491 def mycmp(x, y): 492 return y - x 493 self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), 494 [4, 3, 2, 1, 0]) 495 496 def test_hash(self): 497 def mycmp(x, y): 498 return y - x 499 key = functools.cmp_to_key(mycmp) 500 k = key(10) 501 self.assertRaises(TypeError, hash(k)) 502 503class TestTotalOrdering(unittest.TestCase): 504 505 def test_total_ordering_lt(self): 506 @functools.total_ordering 507 class A: 508 def __init__(self, value): 509 self.value = value 510 def __lt__(self, other): 511 return self.value < other.value 512 def __eq__(self, other): 513 return self.value == other.value 514 self.assertTrue(A(1) < A(2)) 515 self.assertTrue(A(2) > A(1)) 516 self.assertTrue(A(1) <= A(2)) 517 self.assertTrue(A(2) >= A(1)) 518 self.assertTrue(A(2) <= A(2)) 519 self.assertTrue(A(2) >= A(2)) 520 521 def test_total_ordering_le(self): 522 @functools.total_ordering 523 class A: 524 def __init__(self, value): 525 self.value = value 526 def __le__(self, other): 527 return self.value <= other.value 528 def __eq__(self, other): 529 return self.value == other.value 530 self.assertTrue(A(1) < A(2)) 531 self.assertTrue(A(2) > A(1)) 532 self.assertTrue(A(1) <= A(2)) 533 self.assertTrue(A(2) >= A(1)) 534 self.assertTrue(A(2) <= A(2)) 535 self.assertTrue(A(2) >= A(2)) 536 537 def test_total_ordering_gt(self): 538 @functools.total_ordering 539 class A: 540 def __init__(self, value): 541 self.value = value 542 def __gt__(self, other): 543 return self.value > other.value 544 def __eq__(self, other): 545 return self.value == other.value 546 self.assertTrue(A(1) < A(2)) 547 self.assertTrue(A(2) > A(1)) 548 self.assertTrue(A(1) <= A(2)) 549 self.assertTrue(A(2) >= A(1)) 550 self.assertTrue(A(2) <= A(2)) 551 self.assertTrue(A(2) >= A(2)) 552 553 def test_total_ordering_ge(self): 554 @functools.total_ordering 555 class A: 556 def __init__(self, value): 557 self.value = value 558 def __ge__(self, other): 559 return self.value >= other.value 560 def __eq__(self, other): 561 return self.value == other.value 562 self.assertTrue(A(1) < A(2)) 563 self.assertTrue(A(2) > A(1)) 564 self.assertTrue(A(1) <= A(2)) 565 self.assertTrue(A(2) >= A(1)) 566 self.assertTrue(A(2) <= A(2)) 567 self.assertTrue(A(2) >= A(2)) 568 569 def test_total_ordering_no_overwrite(self): 570 # new methods should not overwrite existing 571 @functools.total_ordering 572 class A(str): 573 pass 574 self.assertTrue(A("a") < A("b")) 575 self.assertTrue(A("b") > A("a")) 576 self.assertTrue(A("a") <= A("b")) 577 self.assertTrue(A("b") >= A("a")) 578 self.assertTrue(A("b") <= A("b")) 579 self.assertTrue(A("b") >= A("b")) 580 581 def test_no_operations_defined(self): 582 with self.assertRaises(ValueError): 583 @functools.total_ordering 584 class A: 585 pass 586 587 def test_bug_10042(self): 588 @functools.total_ordering 589 class TestTO: 590 def __init__(self, value): 591 self.value = value 592 def __eq__(self, other): 593 if isinstance(other, TestTO): 594 return self.value == other.value 595 return False 596 def __lt__(self, other): 597 if isinstance(other, TestTO): 598 return self.value < other.value 599 raise TypeError 600 with self.assertRaises(TypeError): 601 TestTO(8) <= () 602 603def test_main(verbose=None): 604 test_classes = ( 605 TestPartial, 606 TestPartialSubclass, 607 TestPythonPartial, 608 TestUpdateWrapper, 609 TestTotalOrdering, 610 TestWraps, 611 TestReduce, 612 ) 613 test_support.run_unittest(*test_classes) 614 615 # verify reference counting 616 if verbose and hasattr(sys, "gettotalrefcount"): 617 import gc 618 counts = [None] * 5 619 for i in xrange(len(counts)): 620 test_support.run_unittest(*test_classes) 621 gc.collect() 622 counts[i] = sys.gettotalrefcount() 623 print counts 624 625if __name__ == '__main__': 626 test_main(verbose=True) 627