• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for tf_decorator."""
16
17# pylint: disable=unused-import
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23
24from tensorflow.python.platform import test
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.util import tf_decorator
27from tensorflow.python.util import tf_inspect
28
29
30def test_tfdecorator(decorator_name, decorator_doc=None):
31
32  def make_tf_decorator(target):
33    return tf_decorator.TFDecorator(decorator_name, target, decorator_doc)
34
35  return make_tf_decorator
36
37
def test_decorator_increment_first_int_arg(target):
  """Test decorator that adds 1 to the first positional int argument.

  In the bound-method case `args[0]` is `self`, which is not an int, so it is
  skipped and the first real int argument is the one incremented.
  """

  def wrapper(*args, **kwargs):
    adjusted = list(args)
    for position, value in enumerate(adjusted):
      if isinstance(value, int):
        adjusted[position] = value + 1
        break  # Only the first int is touched.
    return target(*adjusted, **kwargs)

  # Called directly from this function so the recorded decorator name is
  # 'test_decorator_increment_first_int_arg' (asserted by the unwrap tests).
  return tf_decorator.make_decorator(target, wrapper)
53
54
def test_injectable_decorator_square(target):
  """Test decorator that squares the result of its (rewrappable) target."""

  def wrapper(x):
    # Resolve the target through `__wrapped__` at call time rather than
    # closing over `target`, so `tf_decorator.rewrap` can swap it later.
    inner_result = wrapper.__wrapped__(x)
    return inner_result**2

  return tf_decorator.make_decorator(target, wrapper)
61
62
def test_injectable_decorator_increment(target):
  """Test decorator that adds 1 to the result of its (rewrappable) target."""

  def wrapper(x):
    # Looked up via `__wrapped__` on each call so a rewrap takes effect.
    inner_result = wrapper.__wrapped__(x)
    return inner_result + 1

  return tf_decorator.make_decorator(target, wrapper)
69
70
71def test_function(x):
72  """Test Function Docstring."""
73  return x + 1
74
75
@test_tfdecorator('decorator 1')
@test_decorator_increment_first_int_arg
@test_tfdecorator('decorator 3', 'decorator 3 documentation')
def test_decorated_function(x):
  """Test Decorated Function Docstring."""
  # Decorators apply bottom-up: 'decorator 3' wraps this first, then the
  # int-increment decorator, and 'decorator 1' ends up outermost.
  doubled = x * 2
  return doubled
82
83
@test_injectable_decorator_square
@test_injectable_decorator_increment
def test_rewrappable_decorated(x):
  # Innermost target. Both decorators above call through `__wrapped__`, so
  # `tf_decorator.rewrap` can replace this function or the inner wrapper.
  doubled = x * 2
  return doubled
88
89
@test_tfdecorator('decorator')
class TestDecoratedClass(object):
  """Test Decorated Class."""

  def __init__(self, two_attr=2):
    # Instance attribute read back through the TFDecorator-wrapped class.
    self.two_attr = two_attr

  def two_func(self):
    # Plain method: checks method calls pass through the class decorator.
    return 2

  @property
  def two_prop(self):
    # Property: checks descriptor access passes through the class decorator.
    return 2

  @test_decorator_increment_first_int_arg
  def return_params(self, a, b, c):
    """Return parameters."""
    collected = [a, b, c]
    return collected
108
109
class TfDecoratorTest(test.TestCase):
  """Tests for constructing, inspecting and calling `TFDecorator` directly."""

  def testInitCapturesTarget(self):
    self.assertIs(test_function,
                  tf_decorator.TFDecorator('', test_function).decorated_target)

  def testInitCapturesDecoratorName(self):
    self.assertEqual('decorator name',
                     tf_decorator.TFDecorator('decorator name',
                                              test_function).decorator_name)

  def testInitCapturesDecoratorDoc(self):
    self.assertEqual('decorator doc',
                     tf_decorator.TFDecorator('', test_function,
                                              'decorator doc').decorator_doc)

  def testInitCapturesNonNoneArgspec(self):
    argspec = tf_inspect.ArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        keywords=None,
        defaults=(1, 'hello'))
    self.assertIs(argspec,
                  tf_decorator.TFDecorator('', test_function, '',
                                           argspec).decorator_argspec)

  def testInitSetsDecoratorNameToTargetName(self):
    self.assertEqual('test_function',
                     tf_decorator.TFDecorator('', test_function).__name__)

  def testInitSetsDecoratorQualNameToTargetQualName(self):
    decorator = tf_decorator.TFDecorator('', test_function)
    # `__qualname__` may be absent (e.g. on Python 2), so only check it when
    # the attribute exists.
    if hasattr(decorator, '__qualname__'):
      self.assertEqual('test_function', decorator.__qualname__)

  def testInitSetsDecoratorDocToTargetDoc(self):
    self.assertEqual('Test Function Docstring.',
                     tf_decorator.TFDecorator('', test_function).__doc__)

  def testCallingATFDecoratorCallsTheTarget(self):
    self.assertEqual(124, tf_decorator.TFDecorator('', test_function)(123))

  def testCallingADecoratedFunctionCallsTheTarget(self):
    self.assertEqual((2 + 1) * 2, test_decorated_function(2))

  def testInitializingDecoratedClassWithInitParamsDoesntRaise(self):
    try:
      TestDecoratedClass(2)
    except TypeError as e:
      # `TestCase` has no `assertFail`; `fail` reports a proper test failure
      # instead of an unrelated AttributeError.
      self.fail('TestDecoratedClass(2) raised TypeError: %s' % e)

  def testReadingClassAttributeOnDecoratedClass(self):
    self.assertEqual(2, TestDecoratedClass().two_attr)

  def testCallingClassMethodOnDecoratedClass(self):
    self.assertEqual(2, TestDecoratedClass().two_func())

  def testReadingClassPropertyOnDecoratedClass(self):
    self.assertEqual(2, TestDecoratedClass().two_prop)

  def testNameOnBoundProperty(self):
    self.assertEqual('return_params',
                     TestDecoratedClass().return_params.__name__)

  def testQualNameOnBoundProperty(self):
    bound = TestDecoratedClass().return_params
    # As above, `__qualname__` is only checked where it is available.
    if hasattr(bound, '__qualname__'):
      self.assertEqual('TestDecoratedClass.return_params', bound.__qualname__)

  def testDocstringOnBoundProperty(self):
    self.assertEqual('Return parameters.',
                     TestDecoratedClass().return_params.__doc__)

  def testTarget__get__IsProxied(self):
    class Descr(object):

      def __get__(self, instance, owner):
        return self

    class Foo(object):
      foo = tf_decorator.TFDecorator('Descr', Descr())

    # Accessing the attribute must go through the target's own `__get__`.
    self.assertIsInstance(Foo.foo, Descr)
187        return self
188
189    class Foo(object):
190      foo = tf_decorator.TFDecorator('Descr', Descr())
191
192    self.assertIsInstance(Foo.foo, Descr)
193
194
def test_wrapper(*args, **kwargs):
  # Pass-through wrapper used as the `decorator_func` in make_decorator tests.
  result = test_function(*args, **kwargs)
  return result
197
198
class TfMakeDecoratorTest(test.TestCase):
  """Tests for `tf_decorator.make_decorator`."""

  @staticmethod
  def _deleteAttrIfPresent(obj, name):
    """Removes `name` from `obj.__dict__` if it is currently set there."""
    if name in getattr(obj, '__dict__', {}):
      delattr(obj, name)

  def _setTemporaryAttr(self, obj, name, value):
    """Sets `obj.<name> = value` and removes it again at test teardown.

    Using `addCleanup` guarantees removal even when an assertion fails, so
    mutations of the module-level test functions never leak across tests.
    """
    setattr(obj, name, value)
    self.addCleanup(self._deleteAttrIfPresent, obj, name)

  def testAttachesATFDecoratorAttr(self):
    decorated = tf_decorator.make_decorator(test_function, test_wrapper)
    decorator = getattr(decorated, '_tf_decorator')
    self.assertIsInstance(decorator, tf_decorator.TFDecorator)

  def testAttachesWrappedAttr(self):
    decorated = tf_decorator.make_decorator(test_function, test_wrapper)
    wrapped_attr = getattr(decorated, '__wrapped__')
    self.assertIs(test_function, wrapped_attr)

  def testSetsTFDecoratorNameToDecoratorNameArg(self):
    decorated = tf_decorator.make_decorator(test_function, test_wrapper,
                                            'test decorator name')
    decorator = getattr(decorated, '_tf_decorator')
    self.assertEqual('test decorator name', decorator.decorator_name)

  def testSetsTFDecoratorDocToDecoratorDocArg(self):
    decorated = tf_decorator.make_decorator(
        test_function, test_wrapper, decorator_doc='test decorator doc')
    decorator = getattr(decorated, '_tf_decorator')
    self.assertEqual('test decorator doc', decorator.decorator_doc)

  def testUpdatesDictWithMissingEntries(self):
    self._setTemporaryAttr(test_function, 'foobar', True)
    # The returned wrapper is `test_wrapper` itself, so the copied entry lands
    # on it. It must be removed there too, or it leaks into later tests.
    self.addCleanup(self._deleteAttrIfPresent, test_wrapper, 'foobar')
    decorated = tf_decorator.make_decorator(test_function, test_wrapper)
    self.assertTrue(decorated.foobar)

  def testUpdatesDict_doesNotOverridePresentEntries(self):
    self._setTemporaryAttr(test_function, 'foobar', True)
    self._setTemporaryAttr(test_wrapper, 'foobar', False)
    decorated = tf_decorator.make_decorator(test_function, test_wrapper)
    self.assertFalse(decorated.foobar)

  def testSetsTFDecoratorArgSpec(self):
    argspec = tf_inspect.ArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        keywords=None,
        defaults=(1, 'hello'))
    # Positional order: decorator_name, decorator_doc, then the argspec.
    decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '',
                                            argspec)
    decorator = getattr(decorated, '_tf_decorator')
    self.assertEqual(argspec, decorator.decorator_argspec)

  def testSetsDecoratorNameToFunctionThatCallsMakeDecoratorIfAbsent(self):

    def test_decorator_name(wrapper):
      return tf_decorator.make_decorator(test_function, wrapper)

    decorated = test_decorator_name(test_wrapper)
    decorator = getattr(decorated, '_tf_decorator')
    self.assertEqual('test_decorator_name', decorator.decorator_name)

  def testCompatibleWithNamelessCallables(self):

    class Callable(object):

      def __call__(self):
        pass

    callable_object = Callable()
    # Smoke test: This should not raise an exception, even though
    # `callable_object` does not have a `__name__` attribute.
    _ = tf_decorator.make_decorator(callable_object, test_wrapper)

    partial = functools.partial(test_function, x=1)
    # Smoke test: This should not raise an exception, even though `partial` does
    # not have `__name__`, `__module__`, and `__doc__` attributes.
    _ = tf_decorator.make_decorator(partial, test_wrapper)
273
274
class TfDecoratorRewrapTest(test.TestCase):
  """Tests for `tf_decorator.rewrap` on `test_rewrappable_decorated`.

  `rewrap` mutates the module-level `test_rewrappable_decorated` in place.
  Each test therefore registers a cleanup that rewraps the original target
  back in, so the tests do not depend on execution order.
  """

  def _rewrapWithCleanup(self, previous_target, new_target):
    """Swaps `previous_target` for `new_target` and undoes it at teardown."""
    tf_decorator.rewrap(test_rewrappable_decorated, previous_target, new_target)
    self.addCleanup(tf_decorator.rewrap, test_rewrappable_decorated,
                    new_target, previous_target)

  def testRewrapMutatesAffectedFunction(self):

    def new_target(x):
      return x * 3

    self.assertEqual((1 * 2 + 1) ** 2, test_rewrappable_decorated(1))
    # `unwrap` returns (decorators, target); we need the innermost target.
    _, prev_target = tf_decorator.unwrap(test_rewrappable_decorated)
    self._rewrapWithCleanup(prev_target, new_target)
    self.assertEqual((1 * 3 + 1) ** 2, test_rewrappable_decorated(1))

  def testRewrapOfDecoratorFunction(self):

    def new_target(x):
      return x * 3

    prev_target = test_rewrappable_decorated._tf_decorator.decorated_target
    # In this case, only the outer decorator (test_injectable_decorator_square)
    # should be preserved.
    self._rewrapWithCleanup(prev_target, new_target)
    self.assertEqual((1 * 3) ** 2, test_rewrappable_decorated(1))
297
298
class TfDecoratorUnwrapTest(test.TestCase):
  """Tests for `tf_decorator.unwrap`, which returns (decorators, target)."""

  def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self):
    decorators, _ = tf_decorator.unwrap(test_function)
    self.assertEqual(0, len(decorators))

  def testUnwrapReturnsUndecoratedFunctionAsTarget(self):
    _, target = tf_decorator.unwrap(test_function)
    self.assertIs(test_function, target)

  def testUnwrapReturnsFinalFunctionAsTarget(self):
    self.assertEqual((4 + 1) * 2, test_decorated_function(4))
    _, target = tf_decorator.unwrap(test_decorated_function)
    self.assertTrue(tf_inspect.isfunction(target))
    # The bare target skips the int-increment decorator.
    self.assertEqual(4 * 2, target(4))

  def testUnwrapReturnsListOfUniqueTFDecorators(self):
    decorators, _ = tf_decorator.unwrap(test_decorated_function)
    self.assertEqual(3, len(decorators))
    for decorator in decorators:
      self.assertIsInstance(decorator, tf_decorator.TFDecorator)
    self.assertIsNot(decorators[0], decorators[1])
    self.assertIsNot(decorators[1], decorators[2])
    self.assertIsNot(decorators[2], decorators[0])

  def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
    decorators, _ = tf_decorator.unwrap(test_decorated_function)
    self.assertEqual('decorator 1', decorators[0].decorator_name)
    self.assertEqual('test_decorator_increment_first_int_arg',
                     decorators[1].decorator_name)
    self.assertEqual('decorator 3', decorators[2].decorator_name)
    self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc)

  def testUnwrapBoundMethods(self):
    test_decorated_class = TestDecoratedClass()
    self.assertEqual([2, 2, 3], test_decorated_class.return_params(1, 2, 3))
    decorators, target = tf_decorator.unwrap(test_decorated_class.return_params)
    self.assertEqual('test_decorator_increment_first_int_arg',
                     decorators[0].decorator_name)
    # The unwrapped target is the plain function, so `self` is passed
    # explicitly and no argument is incremented.
    self.assertEqual([1, 2, 3], target(test_decorated_class, 1, 2, 3))
340
341
if __name__ == '__main__':
  # Entry point when run as a script: delegate to TensorFlow's `test.main`.
  test.main()
344