1from test import support 2import unittest 3from types import MethodType 4 5def funcattrs(**kwds): 6 def decorate(func): 7 func.__dict__.update(kwds) 8 return func 9 return decorate 10 11class MiscDecorators (object): 12 @staticmethod 13 def author(name): 14 def decorate(func): 15 func.__dict__['author'] = name 16 return func 17 return decorate 18 19# ----------------------------------------------- 20 21class DbcheckError (Exception): 22 def __init__(self, exprstr, func, args, kwds): 23 # A real version of this would set attributes here 24 Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" % 25 (exprstr, func, args, kwds)) 26 27 28def dbcheck(exprstr, globals=None, locals=None): 29 "Decorator to implement debugging assertions" 30 def decorate(func): 31 expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval") 32 def check(*args, **kwds): 33 if not eval(expr, globals, locals): 34 raise DbcheckError(exprstr, func, args, kwds) 35 return func(*args, **kwds) 36 return check 37 return decorate 38 39# ----------------------------------------------- 40 41def countcalls(counts): 42 "Decorator to count calls to a function" 43 def decorate(func): 44 func_name = func.__name__ 45 counts[func_name] = 0 46 def call(*args, **kwds): 47 counts[func_name] += 1 48 return func(*args, **kwds) 49 call.__name__ = func_name 50 return call 51 return decorate 52 53# ----------------------------------------------- 54 55def memoize(func): 56 saved = {} 57 def call(*args): 58 try: 59 return saved[args] 60 except KeyError: 61 res = func(*args) 62 saved[args] = res 63 return res 64 except TypeError: 65 # Unhashable argument 66 return func(*args) 67 call.__name__ = func.__name__ 68 return call 69 70# ----------------------------------------------- 71 72class TestDecorators(unittest.TestCase): 73 74 def test_single(self): 75 class C(object): 76 @staticmethod 77 def foo(): return 42 78 self.assertEqual(C.foo(), 42) 79 self.assertEqual(C().foo(), 42) 80 81 def check_wrapper_attrs(self, method_wrapper, format_str): 82 def func(x): 83 return x 84 wrapper = method_wrapper(func) 85 86 self.assertIs(wrapper.__func__, func) 87 self.assertIs(wrapper.__wrapped__, func) 88 89 for attr in ('__module__', '__qualname__', '__name__', 90 '__doc__', '__annotations__'): 91 self.assertIs(getattr(wrapper, attr), 92 getattr(func, attr)) 93 94 self.assertEqual(repr(wrapper), format_str.format(func)) 95 return wrapper 96 97 def test_staticmethod(self): 98 wrapper = self.check_wrapper_attrs(staticmethod, '<staticmethod({!r})>') 99 100 # bpo-43682: Static methods are callable since Python 3.10 101 self.assertEqual(wrapper(1), 1) 102 103 def test_classmethod(self): 104 wrapper = self.check_wrapper_attrs(classmethod, '<classmethod({!r})>') 105 106 self.assertRaises(TypeError, wrapper, 1) 107 108 def test_dotted(self): 109 decorators = MiscDecorators() 110 @decorators.author('Cleese') 111 def foo(): return 42 112 self.assertEqual(foo(), 42) 113 self.assertEqual(foo.author, 'Cleese') 114 115 def test_argforms(self): 116 # A few tests of argument passing, as we use restricted form 117 # of expressions for decorators. 118 119 def noteargs(*args, **kwds): 120 def decorate(func): 121 setattr(func, 'dbval', (args, kwds)) 122 return func 123 return decorate 124 125 args = ( 'Now', 'is', 'the', 'time' ) 126 kwds = dict(one=1, two=2) 127 @noteargs(*args, **kwds) 128 def f1(): return 42 129 self.assertEqual(f1(), 42) 130 self.assertEqual(f1.dbval, (args, kwds)) 131 132 @noteargs('terry', 'gilliam', eric='idle', john='cleese') 133 def f2(): return 84 134 self.assertEqual(f2(), 84) 135 self.assertEqual(f2.dbval, (('terry', 'gilliam'), 136 dict(eric='idle', john='cleese'))) 137 138 @noteargs(1, 2,) 139 def f3(): pass 140 self.assertEqual(f3.dbval, ((1, 2), {})) 141 142 def test_dbcheck(self): 143 @dbcheck('args[1] is not None') 144 def f(a, b): 145 return a + b 146 self.assertEqual(f(1, 2), 3) 147 self.assertRaises(DbcheckError, f, 1, None) 148 149 def test_memoize(self): 150 counts = {} 151 152 @memoize 153 @countcalls(counts) 154 def double(x): 155 return x * 2 156 self.assertEqual(double.__name__, 'double') 157 158 self.assertEqual(counts, dict(double=0)) 159 160 # Only the first call with a given argument bumps the call count: 161 # 162 self.assertEqual(double(2), 4) 163 self.assertEqual(counts['double'], 1) 164 self.assertEqual(double(2), 4) 165 self.assertEqual(counts['double'], 1) 166 self.assertEqual(double(3), 6) 167 self.assertEqual(counts['double'], 2) 168 169 # Unhashable arguments do not get memoized: 170 # 171 self.assertEqual(double([10]), [10, 10]) 172 self.assertEqual(counts['double'], 3) 173 self.assertEqual(double([10]), [10, 10]) 174 self.assertEqual(counts['double'], 4) 175 176 def test_errors(self): 177 178 # Test SyntaxErrors: 179 for stmt in ("x,", "x, y", "x = y", "pass", "import sys"): 180 compile(stmt, "test", "exec") # Sanity check. 181 with self.assertRaises(SyntaxError): 182 compile(f"@{stmt}\ndef f(): pass", "test", "exec") 183 184 # Test TypeErrors that used to be SyntaxErrors: 185 for expr in ("1.+2j", "[1, 2][-1]", "(1, 2)", "True", "...", "None"): 186 compile(expr, "test", "eval") # Sanity check. 187 with self.assertRaises(TypeError): 188 exec(f"@{expr}\ndef f(): pass") 189 190 def unimp(func): 191 raise NotImplementedError 192 context = dict(nullval=None, unimp=unimp) 193 194 for expr, exc in [ ("undef", NameError), 195 ("nullval", TypeError), 196 ("nullval.attr", AttributeError), 197 ("unimp", NotImplementedError)]: 198 codestr = "@%s\ndef f(): pass\nassert f() is None" % expr 199 code = compile(codestr, "test", "exec") 200 self.assertRaises(exc, eval, code, context) 201 202 def test_expressions(self): 203 for expr in ( 204 "(x,)", "(x, y)", "x := y", "(x := y)", "x @y", "(x @ y)", "x[0]", 205 "w[x].y.z", "w + x - (y + z)", "x(y)()(z)", "[w, x, y][z]", "x.y", 206 ): 207 compile(f"@{expr}\ndef f(): pass", "test", "exec") 208 209 def test_double(self): 210 class C(object): 211 @funcattrs(abc=1, xyz="haha") 212 @funcattrs(booh=42) 213 def foo(self): return 42 214 self.assertEqual(C().foo(), 42) 215 self.assertEqual(C.foo.abc, 1) 216 self.assertEqual(C.foo.xyz, "haha") 217 self.assertEqual(C.foo.booh, 42) 218 219 def test_order(self): 220 # Test that decorators are applied in the proper order to the function 221 # they are decorating. 222 def callnum(num): 223 """Decorator factory that returns a decorator that replaces the 224 passed-in function with one that returns the value of 'num'""" 225 def deco(func): 226 return lambda: num 227 return deco 228 @callnum(2) 229 @callnum(1) 230 def foo(): return 42 231 self.assertEqual(foo(), 2, 232 "Application order of decorators is incorrect") 233 234 def test_eval_order(self): 235 # Evaluating a decorated function involves four steps for each 236 # decorator-maker (the function that returns a decorator): 237 # 238 # 1: Evaluate the decorator-maker name 239 # 2: Evaluate the decorator-maker arguments (if any) 240 # 3: Call the decorator-maker to make a decorator 241 # 4: Call the decorator 242 # 243 # When there are multiple decorators, these steps should be 244 # performed in the above order for each decorator, but we should 245 # iterate through the decorators in the reverse of the order they 246 # appear in the source. 247 248 actions = [] 249 250 def make_decorator(tag): 251 actions.append('makedec' + tag) 252 def decorate(func): 253 actions.append('calldec' + tag) 254 return func 255 return decorate 256 257 class NameLookupTracer (object): 258 def __init__(self, index): 259 self.index = index 260 261 def __getattr__(self, fname): 262 if fname == 'make_decorator': 263 opname, res = ('evalname', make_decorator) 264 elif fname == 'arg': 265 opname, res = ('evalargs', str(self.index)) 266 else: 267 assert False, "Unknown attrname %s" % fname 268 actions.append('%s%d' % (opname, self.index)) 269 return res 270 271 c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ]) 272 273 expected_actions = [ 'evalname1', 'evalargs1', 'makedec1', 274 'evalname2', 'evalargs2', 'makedec2', 275 'evalname3', 'evalargs3', 'makedec3', 276 'calldec3', 'calldec2', 'calldec1' ] 277 278 actions = [] 279 @c1.make_decorator(c1.arg) 280 @c2.make_decorator(c2.arg) 281 @c3.make_decorator(c3.arg) 282 def foo(): return 42 283 self.assertEqual(foo(), 42) 284 285 self.assertEqual(actions, expected_actions) 286 287 # Test the equivalence claim in chapter 7 of the reference manual. 288 # 289 actions = [] 290 def bar(): return 42 291 bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar))) 292 self.assertEqual(bar(), 42) 293 self.assertEqual(actions, expected_actions) 294 295 def test_wrapped_descriptor_inside_classmethod(self): 296 class BoundWrapper: 297 def __init__(self, wrapped): 298 self.__wrapped__ = wrapped 299 300 def __call__(self, *args, **kwargs): 301 return self.__wrapped__(*args, **kwargs) 302 303 class Wrapper: 304 def __init__(self, wrapped): 305 self.__wrapped__ = wrapped 306 307 def __get__(self, instance, owner): 308 bound_function = self.__wrapped__.__get__(instance, owner) 309 return BoundWrapper(bound_function) 310 311 def decorator(wrapped): 312 return Wrapper(wrapped) 313 314 class Class: 315 @decorator 316 @classmethod 317 def inner(cls): 318 # This should already work. 319 return 'spam' 320 321 @classmethod 322 @decorator 323 def outer(cls): 324 # Raised TypeError with a message saying that the 'Wrapper' 325 # object is not callable. 326 return 'eggs' 327 328 self.assertEqual(Class.inner(), 'spam') 329 self.assertEqual(Class.outer(), 'eggs') 330 self.assertEqual(Class().inner(), 'spam') 331 self.assertEqual(Class().outer(), 'eggs') 332 333 def test_wrapped_classmethod_inside_classmethod(self): 334 class MyClassMethod1: 335 def __init__(self, func): 336 self.func = func 337 338 def __call__(self, cls): 339 if hasattr(self.func, '__get__'): 340 return self.func.__get__(cls, cls)() 341 return self.func(cls) 342 343 def __get__(self, instance, owner=None): 344 if owner is None: 345 owner = type(instance) 346 return MethodType(self, owner) 347 348 class MyClassMethod2: 349 def __init__(self, func): 350 if isinstance(func, classmethod): 351 func = func.__func__ 352 self.func = func 353 354 def __call__(self, cls): 355 return self.func(cls) 356 357 def __get__(self, instance, owner=None): 358 if owner is None: 359 owner = type(instance) 360 return MethodType(self, owner) 361 362 for myclassmethod in [MyClassMethod1, MyClassMethod2]: 363 class A: 364 @myclassmethod 365 def f1(cls): 366 return cls 367 368 @classmethod 369 @myclassmethod 370 def f2(cls): 371 return cls 372 373 @myclassmethod 374 @classmethod 375 def f3(cls): 376 return cls 377 378 @classmethod 379 @classmethod 380 def f4(cls): 381 return cls 382 383 @myclassmethod 384 @MyClassMethod1 385 def f5(cls): 386 return cls 387 388 @myclassmethod 389 @MyClassMethod2 390 def f6(cls): 391 return cls 392 393 self.assertIs(A.f1(), A) 394 self.assertIs(A.f2(), A) 395 self.assertIs(A.f3(), A) 396 self.assertIs(A.f4(), A) 397 self.assertIs(A.f5(), A) 398 self.assertIs(A.f6(), A) 399 a = A() 400 self.assertIs(a.f1(), A) 401 self.assertIs(a.f2(), A) 402 self.assertIs(a.f3(), A) 403 self.assertIs(a.f4(), A) 404 self.assertIs(a.f5(), A) 405 self.assertIs(a.f6(), A) 406 407 def f(cls): 408 return cls 409 410 self.assertIs(myclassmethod(f).__get__(a)(), A) 411 self.assertIs(myclassmethod(f).__get__(a, A)(), A) 412 self.assertIs(myclassmethod(f).__get__(A, A)(), A) 413 self.assertIs(myclassmethod(f).__get__(A)(), type(A)) 414 self.assertIs(classmethod(f).__get__(a)(), A) 415 self.assertIs(classmethod(f).__get__(a, A)(), A) 416 self.assertIs(classmethod(f).__get__(A, A)(), A) 417 self.assertIs(classmethod(f).__get__(A)(), type(A)) 418 419class TestClassDecorators(unittest.TestCase): 420 421 def test_simple(self): 422 def plain(x): 423 x.extra = 'Hello' 424 return x 425 @plain 426 class C(object): pass 427 self.assertEqual(C.extra, 'Hello') 428 429 def test_double(self): 430 def ten(x): 431 x.extra = 10 432 return x 433 def add_five(x): 434 x.extra += 5 435 return x 436 437 @add_five 438 @ten 439 class C(object): pass 440 self.assertEqual(C.extra, 15) 441 442 def test_order(self): 443 def applied_first(x): 444 x.extra = 'first' 445 return x 446 def applied_second(x): 447 x.extra = 'second' 448 return x 449 @applied_second 450 @applied_first 451 class C(object): pass 452 self.assertEqual(C.extra, 'second') 453 454if __name__ == "__main__": 455 unittest.main() 456