1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Unit tests for tf_decorator.""" 16 17# pylint: disable=unused-import 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23 24from tensorflow.python.platform import test 25from tensorflow.python.platform import tf_logging as logging 26from tensorflow.python.util import tf_decorator 27from tensorflow.python.util import tf_inspect 28 29 30def test_tfdecorator(decorator_name, decorator_doc=None): 31 32 def make_tf_decorator(target): 33 return tf_decorator.TFDecorator(decorator_name, target, decorator_doc) 34 35 return make_tf_decorator 36 37 38def test_decorator_increment_first_int_arg(target): 39 """This test decorator skips past `self` as args[0] in the bound case.""" 40 41 def wrapper(*args, **kwargs): 42 new_args = [] 43 found = False 44 for arg in args: 45 if not found and isinstance(arg, int): 46 new_args.append(arg + 1) 47 found = True 48 else: 49 new_args.append(arg) 50 return target(*new_args, **kwargs) 51 52 return tf_decorator.make_decorator(target, wrapper) 53 54 55def test_injectable_decorator_square(target): 56 57 def wrapper(x): 58 return wrapper.__wrapped__(x)**2 59 60 return tf_decorator.make_decorator(target, wrapper) 61 62 63def test_injectable_decorator_increment(target): 64 65 def wrapper(x): 66 return wrapper.__wrapped__(x) + 1 67 68 return tf_decorator.make_decorator(target, wrapper) 69 70 71def test_function(x): 72 """Test Function Docstring.""" 73 return x + 1 74 75 76@test_tfdecorator('decorator 1') 77@test_decorator_increment_first_int_arg 78@test_tfdecorator('decorator 3', 'decorator 3 documentation') 79def test_decorated_function(x): 80 """Test Decorated Function Docstring.""" 81 return x * 2 82 83 84@test_injectable_decorator_square 85@test_injectable_decorator_increment 86def test_rewrappable_decorated(x): 87 return x * 2 88 89 90@test_tfdecorator('decorator') 91class TestDecoratedClass(object): 92 """Test Decorated Class.""" 93 94 def __init__(self, two_attr=2): 95 self.two_attr = two_attr 96 97 @property 98 def two_prop(self): 99 return 2 100 101 def two_func(self): 102 return 2 103 104 @test_decorator_increment_first_int_arg 105 def return_params(self, a, b, c): 106 """Return parameters.""" 107 return [a, b, c] 108 109 110class TfDecoratorTest(test.TestCase): 111 112 def testInitCapturesTarget(self): 113 self.assertIs(test_function, 114 tf_decorator.TFDecorator('', test_function).decorated_target) 115 116 def testInitCapturesDecoratorName(self): 117 self.assertEqual('decorator name', 118 tf_decorator.TFDecorator('decorator name', 119 test_function).decorator_name) 120 121 def testInitCapturesDecoratorDoc(self): 122 self.assertEqual('decorator doc', 123 tf_decorator.TFDecorator('', test_function, 124 'decorator doc').decorator_doc) 125 126 def testInitCapturesNonNoneArgspec(self): 127 argspec = tf_inspect.ArgSpec( 128 args=['a', 'b', 'c'], 129 varargs=None, 130 keywords=None, 131 defaults=(1, 'hello')) 132 self.assertIs(argspec, 133 tf_decorator.TFDecorator('', test_function, '', 134 argspec).decorator_argspec) 135 136 def testInitSetsDecoratorNameToTargetName(self): 137 self.assertEqual('test_function', 138 tf_decorator.TFDecorator('', test_function).__name__) 139 140 def testInitSetsDecoratorQualNameToTargetQualName(self): 141 if hasattr(tf_decorator.TFDecorator('', test_function), '__qualname__'): 142 self.assertEqual('test_function', 143 tf_decorator.TFDecorator('', test_function).__qualname__) 144 145 def testInitSetsDecoratorDocToTargetDoc(self): 146 self.assertEqual('Test Function Docstring.', 147 tf_decorator.TFDecorator('', test_function).__doc__) 148 149 def testCallingATFDecoratorCallsTheTarget(self): 150 self.assertEqual(124, tf_decorator.TFDecorator('', test_function)(123)) 151 152 def testCallingADecoratedFunctionCallsTheTarget(self): 153 self.assertEqual((2 + 1) * 2, test_decorated_function(2)) 154 155 def testInitializingDecoratedClassWithInitParamsDoesntRaise(self): 156 try: 157 TestDecoratedClass(2) 158 except TypeError: 159 self.assertFail() 160 161 def testReadingClassAttributeOnDecoratedClass(self): 162 self.assertEqual(2, TestDecoratedClass().two_attr) 163 164 def testCallingClassMethodOnDecoratedClass(self): 165 self.assertEqual(2, TestDecoratedClass().two_func()) 166 167 def testReadingClassPropertyOnDecoratedClass(self): 168 self.assertEqual(2, TestDecoratedClass().two_prop) 169 170 def testNameOnBoundProperty(self): 171 self.assertEqual('return_params', 172 TestDecoratedClass().return_params.__name__) 173 174 def testQualNameOnBoundProperty(self): 175 if hasattr(TestDecoratedClass().return_params, '__qualname__'): 176 self.assertEqual('TestDecoratedClass.return_params', 177 TestDecoratedClass().return_params.__qualname__) 178 179 def testDocstringOnBoundProperty(self): 180 self.assertEqual('Return parameters.', 181 TestDecoratedClass().return_params.__doc__) 182 183 def testTarget__get__IsProxied(self): 184 class Descr(object): 185 186 def __get__(self, instance, owner): 187 return self 188 189 class Foo(object): 190 foo = tf_decorator.TFDecorator('Descr', Descr()) 191 192 self.assertIsInstance(Foo.foo, Descr) 193 194 195def test_wrapper(*args, **kwargs): 196 return test_function(*args, **kwargs) 197 198 199class TfMakeDecoratorTest(test.TestCase): 200 201 def testAttachesATFDecoratorAttr(self): 202 decorated = tf_decorator.make_decorator(test_function, test_wrapper) 203 decorator = getattr(decorated, '_tf_decorator') 204 self.assertIsInstance(decorator, tf_decorator.TFDecorator) 205 206 def testAttachesWrappedAttr(self): 207 decorated = tf_decorator.make_decorator(test_function, test_wrapper) 208 wrapped_attr = getattr(decorated, '__wrapped__') 209 self.assertIs(test_function, wrapped_attr) 210 211 def testSetsTFDecoratorNameToDecoratorNameArg(self): 212 decorated = tf_decorator.make_decorator(test_function, test_wrapper, 213 'test decorator name') 214 decorator = getattr(decorated, '_tf_decorator') 215 self.assertEqual('test decorator name', decorator.decorator_name) 216 217 def testSetsTFDecoratorDocToDecoratorDocArg(self): 218 decorated = tf_decorator.make_decorator( 219 test_function, test_wrapper, decorator_doc='test decorator doc') 220 decorator = getattr(decorated, '_tf_decorator') 221 self.assertEqual('test decorator doc', decorator.decorator_doc) 222 223 def testUpdatesDictWithMissingEntries(self): 224 test_function.foobar = True 225 decorated = tf_decorator.make_decorator(test_function, test_wrapper) 226 self.assertTrue(decorated.foobar) 227 del test_function.foobar 228 229 def testUpdatesDict_doesNotOverridePresentEntries(self): 230 test_function.foobar = True 231 test_wrapper.foobar = False 232 decorated = tf_decorator.make_decorator(test_function, test_wrapper) 233 self.assertFalse(decorated.foobar) 234 del test_function.foobar 235 del test_wrapper.foobar 236 237 def testSetsTFDecoratorArgSpec(self): 238 argspec = tf_inspect.ArgSpec( 239 args=['a', 'b', 'c'], 240 varargs=None, 241 keywords=None, 242 defaults=(1, 'hello')) 243 decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '', 244 argspec) 245 decorator = getattr(decorated, '_tf_decorator') 246 self.assertEqual(argspec, decorator.decorator_argspec) 247 248 def testSetsDecoratorNameToFunctionThatCallsMakeDecoratorIfAbsent(self): 249 250 def test_decorator_name(wrapper): 251 return tf_decorator.make_decorator(test_function, wrapper) 252 253 decorated = test_decorator_name(test_wrapper) 254 decorator = getattr(decorated, '_tf_decorator') 255 self.assertEqual('test_decorator_name', decorator.decorator_name) 256 257 def testCompatibleWithNamelessCallables(self): 258 259 class Callable(object): 260 261 def __call__(self): 262 pass 263 264 callable_object = Callable() 265 # Smoke test: This should not raise an exception, even though 266 # `callable_object` does not have a `__name__` attribute. 267 _ = tf_decorator.make_decorator(callable_object, test_wrapper) 268 269 partial = functools.partial(test_function, x=1) 270 # Smoke test: This should not raise an exception, even though `partial` does 271 # not have `__name__`, `__module__`, and `__doc__` attributes. 272 _ = tf_decorator.make_decorator(partial, test_wrapper) 273 274 275class TfDecoratorRewrapTest(test.TestCase): 276 277 def testRewrapMutatesAffectedFunction(self): 278 279 def new_target(x): 280 return x * 3 281 282 self.assertEqual((1 * 2 + 1) ** 2, test_rewrappable_decorated(1)) 283 prev_target, _ = tf_decorator.unwrap(test_rewrappable_decorated) 284 tf_decorator.rewrap(test_rewrappable_decorated, prev_target, new_target) 285 self.assertEqual((1 * 3 + 1) ** 2, test_rewrappable_decorated(1)) 286 287 def testRewrapOfDecoratorFunction(self): 288 289 def new_target(x): 290 return x * 3 291 292 prev_target = test_rewrappable_decorated._tf_decorator._decorated_target 293 # In this case, only the outer decorator (test_injectable_decorator_square) 294 # should be preserved. 295 tf_decorator.rewrap(test_rewrappable_decorated, prev_target, new_target) 296 self.assertEqual((1 * 3) ** 2, test_rewrappable_decorated(1)) 297 298 299class TfDecoratorUnwrapTest(test.TestCase): 300 301 def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self): 302 decorators, _ = tf_decorator.unwrap(test_function) 303 self.assertEqual(0, len(decorators)) 304 305 def testUnwrapReturnsUndecoratedFunctionAsTarget(self): 306 _, target = tf_decorator.unwrap(test_function) 307 self.assertIs(test_function, target) 308 309 def testUnwrapReturnsFinalFunctionAsTarget(self): 310 self.assertEqual((4 + 1) * 2, test_decorated_function(4)) 311 _, target = tf_decorator.unwrap(test_decorated_function) 312 self.assertTrue(tf_inspect.isfunction(target)) 313 self.assertEqual(4 * 2, target(4)) 314 315 def testUnwrapReturnsListOfUniqueTFDecorators(self): 316 decorators, _ = tf_decorator.unwrap(test_decorated_function) 317 self.assertEqual(3, len(decorators)) 318 self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator)) 319 self.assertTrue(isinstance(decorators[1], tf_decorator.TFDecorator)) 320 self.assertTrue(isinstance(decorators[2], tf_decorator.TFDecorator)) 321 self.assertIsNot(decorators[0], decorators[1]) 322 self.assertIsNot(decorators[1], decorators[2]) 323 self.assertIsNot(decorators[2], decorators[0]) 324 325 def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self): 326 decorators, _ = tf_decorator.unwrap(test_decorated_function) 327 self.assertEqual('decorator 1', decorators[0].decorator_name) 328 self.assertEqual('test_decorator_increment_first_int_arg', 329 decorators[1].decorator_name) 330 self.assertEqual('decorator 3', decorators[2].decorator_name) 331 self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc) 332 333 def testUnwrapBoundMethods(self): 334 test_decorated_class = TestDecoratedClass() 335 self.assertEqual([2, 2, 3], test_decorated_class.return_params(1, 2, 3)) 336 decorators, target = tf_decorator.unwrap(test_decorated_class.return_params) 337 self.assertEqual('test_decorator_increment_first_int_arg', 338 decorators[0].decorator_name) 339 self.assertEqual([1, 2, 3], target(test_decorated_class, 1, 2, 3)) 340 341 342if __name__ == '__main__': 343 test.main() 344