• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for inspect_utils module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import imp
24import types
25import weakref
26
27import six
28
29from tensorflow.python import lib
30from tensorflow.python.autograph.pyct import inspect_utils
31from tensorflow.python.autograph.pyct.testing import future_import_module
32from tensorflow.python.eager import function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.platform import test
35
36
def decorator(f):
  """Identity decorator: hands back `f` itself, without wrapping it."""
  decorated = f
  return decorated
39
40
def function_decorator():
  """Decorator factory; the decorator it builds returns its argument as-is."""

  def identity(fn):
    return fn

  return identity
45
46
def wrapping_decorator():
  """Return a decorator that swaps the decorated function for a stub.

  The stub copies the decorated function's metadata (via `functools.wraps`)
  but never calls the original; it accepts any arguments and returns None.
  """
  def dec(f):
    # `wrapper` forwards keyword arguments, so the stub must accept them too;
    # otherwise any keyword call to the decorated function raises TypeError.
    def replacement(*_, **__):
      return None

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
      return replacement(*args, **kwargs)
    return wrapper
  return dec
57
58
class TestClass(object):
  """Fixture class exposing one member of each method flavor under test."""

  def member_function(self):
    """Plain, undecorated instance method."""

  @decorator
  def decorated_member(self):
    """Instance method under an identity decorator."""

  @function_decorator()
  def fn_decorated_member(self):
    """Instance method under a decorator obtained from a factory call."""

  @wrapping_decorator()
  def wrap_decorated_member(self):
    """Instance method replaced by a `functools.wraps`-based wrapper."""

  @staticmethod
  def static_method():
    """Static method with no body."""

  @classmethod
  def class_method(cls):
    """Class method with no body."""
83
84
def free_function():
  """Module-level function that belongs to no class; returns None."""
  return None
87
88
def factory():
  """Return the module-level `free_function`, referenced as a global."""
  fn = free_function
  return fn
91
92
def free_factory():
  """Build and return a fresh nested function that belongs to no class."""

  def nested_function():
    return None

  return nested_function
97
98
class InspectUtilsTest(test.TestCase):
  """Unit tests for the helpers in the `inspect_utils` module."""

  def test_islambda(self):
    def test_fn():
      pass

    self.assertTrue(inspect_utils.islambda(lambda x: x))
    self.assertFalse(inspect_utils.islambda(test_fn))

  def test_isnamedtuple(self):
    nt = collections.namedtuple('TestNamedTuple', ['a', 'b'])

    class NotANamedTuple(tuple):
      pass

    self.assertTrue(inspect_utils.isnamedtuple(nt))
    self.assertFalse(inspect_utils.isnamedtuple(NotANamedTuple))

  def test_isnamedtuple_confounder(self):
    """This test highlights false positives when detecting named tuples."""

    class NamedTupleLike(tuple):
      _fields = ('a', 'b')

    self.assertTrue(inspect_utils.isnamedtuple(NamedTupleLike))

  def test_isnamedtuple_subclass(self):
    """Subclasses of a namedtuple class are also reported as named tuples."""

    class NamedTupleSubclass(collections.namedtuple('Test', ['a', 'b'])):
      pass

    self.assertTrue(inspect_utils.isnamedtuple(NamedTupleSubclass))

  def test_getnamespace_globals(self):
    ns = inspect_utils.getnamespace(factory)
    self.assertEqual(ns['free_function'], free_function)

  def test_getnamespace_hermetic(self):
    """Closure values shadow globals without modifying the real globals."""

    # Intentionally hiding the global function to make sure we don't overwrite
    # it in the global namespace.
    free_function = object()  # pylint:disable=redefined-outer-name

    def test_fn():
      return free_function

    ns = inspect_utils.getnamespace(test_fn)
    globs = six.get_function_globals(test_fn)
    self.assertIs(ns['free_function'], free_function)
    self.assertIsNot(globs['free_function'], free_function)

  def test_getnamespace_locals(self):

    def called_fn():
      return 0

    closed_over_list = []
    closed_over_primitive = 1

    def local_fn():
      closed_over_list.append(1)
      local_var = 1
      return called_fn() + local_var + closed_over_primitive

    ns = inspect_utils.getnamespace(local_fn)
    self.assertEqual(ns['called_fn'], called_fn)
    self.assertEqual(ns['closed_over_list'], closed_over_list)
    self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
    # Variables local to `local_fn` itself are not part of its namespace.
    self.assertNotIn('local_var', ns)

  def test_getqualifiedname(self):
    foo = object()
    # types.ModuleType is the non-deprecated equivalent of imp.new_module.
    qux = types.ModuleType('quxmodule')
    bar = types.ModuleType('barmodule')
    baz = object()
    bar.baz = baz

    ns = {
        'foo': foo,
        'bar': bar,
        'qux': qux,
    }

    self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
    self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
    self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
    self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')

  def test_getqualifiedname_efficiency(self):
    foo = object()

    # We create a densely connected graph consisting of a relatively small
    # number of modules and hide our symbol in one of them. The path to the
    # symbol is at least 10, and each node has about 10 neighbors. However,
    # by skipping visited modules, the search should take much less.
    ns = {}
    prev_level = []
    for i in range(10):
      current_level = []
      for j in range(10):
        mod_name = 'mod_{}_{}'.format(i, j)
        mod = types.ModuleType(mod_name)
        current_level.append(mod)
        if i == 9 and j == 9:
          mod.foo = foo
      if prev_level:
        # All modules at level i refer to all modules at level i+1
        for prev in prev_level:
          for mod in current_level:
            prev.__dict__[mod.__name__] = mod
      else:
        for mod in current_level:
          ns[mod.__name__] = mod
      prev_level = current_level

    self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
    self.assertIsNotNone(
        inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))

  def test_getqualifiedname_cycles(self):
    foo = object()

    # We create a graph of modules that contains circular references. The
    # search process should avoid them. The searched object is hidden at the
    # bottom of a path of length roughly 10.
    ns = {}
    mods = []
    for i in range(10):
      mod = types.ModuleType('mod_{}'.format(i))
      if i == 9:
        mod.foo = foo
      # Module i refers to module i+1
      if mods:
        mods[-1].__dict__[mod.__name__] = mod
      else:
        ns[mod.__name__] = mod
      # Module i refers to all modules j < i.
      for prev in mods:
        mod.__dict__[prev.__name__] = prev
      mods.append(mod)

    self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
    self.assertIsNotNone(
        inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))

  def test_getqualifiedname_finds_via_parent_module(self):
    # TODO(mdan): This test is vulnerable to change in the lib module.
    # A better way to forge modules should be found.
    self.assertEqual(
        inspect_utils.getqualifiedname(
            lib.__dict__, lib.io.file_io.FileIO, max_depth=1),
        'io.file_io.FileIO')

  def test_getmethodclass(self):
    # NOTE: getmethodclass is called directly from this frame on purpose; it
    # may inspect its caller's frame, so do not move these calls into helpers.
    self.assertIsNone(inspect_utils.getmethodclass(free_function))
    self.assertIsNone(inspect_utils.getmethodclass(free_factory()))

    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.member_function),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.fn_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.wrap_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.static_method),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.class_method),
        TestClass)

    test_obj = TestClass()
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.member_function),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.static_method),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.class_method),
        TestClass)

  def test_getmethodclass_locals(self):
    """Classes defined locally in the calling function are also resolved."""

    def local_function():
      pass

    class LocalClass(object):

      def member_function(self):
        pass

      @decorator
      def decorated_member(self):
        pass

      @function_decorator()
      def fn_decorated_member(self):
        pass

      @wrapping_decorator()
      def wrap_decorated_member(self):
        pass

    self.assertIsNone(inspect_utils.getmethodclass(local_function))

    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.member_function),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.fn_decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.wrap_decorated_member),
        LocalClass)

    test_obj = LocalClass()
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.member_function),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
        LocalClass)

  def test_getmethodclass_callables(self):
    class TestCallable(object):

      def __call__(self):
        pass

    c = TestCallable()
    self.assertEqual(inspect_utils.getmethodclass(c), TestCallable)

  def test_getmethodclass_weakref_mechanism(self):
    """A method bound to a TfMethodTarget resolves to the target's class."""
    test_obj = TestClass()

    def test_fn(self):
      return self

    bound_method = types.MethodType(
        test_fn,
        function.TfMethodTarget(
            weakref.ref(test_obj), test_obj.member_function))
    self.assertEqual(inspect_utils.getmethodclass(bound_method), TestClass)

  def test_getmethodclass_no_bool_conversion(self):
    # Per the test name: resolving the class of a tensor's bound method
    # presumably must not convert the tensor to bool.
    tensor = constant_op.constant([1])
    self.assertEqual(
        inspect_utils.getmethodclass(tensor.get_shape), type(tensor))

  def test_getdefiningclass(self):
    class Superclass(object):

      def foo(self):
        pass

      def bar(self):
        pass

      @classmethod
      def class_method(cls):
        pass

    class Subclass(Superclass):

      def foo(self):
        pass

      def baz(self):
        pass

    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.foo, Subclass), Subclass)
    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.bar, Subclass), Superclass)
    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.baz, Subclass), Subclass)
    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.class_method, Subclass),
        Superclass)

  def test_isbuiltin(self):
    self.assertTrue(inspect_utils.isbuiltin(enumerate))
    self.assertTrue(inspect_utils.isbuiltin(float))
    self.assertTrue(inspect_utils.isbuiltin(int))
    self.assertTrue(inspect_utils.isbuiltin(len))
    self.assertTrue(inspect_utils.isbuiltin(range))
    self.assertTrue(inspect_utils.isbuiltin(zip))
    self.assertFalse(inspect_utils.isbuiltin(function_decorator))

  def test_getfutureimports_simple_case(self):
    expected_imports = ('absolute_import', 'division', 'print_function',
                        'with_statement')
    self.assertEqual(inspect_utils.getfutureimports(future_import_module.f),
                     expected_imports)

  def test_super_wrapper_for_dynamic_attrs(self):
    """Instance attributes set in __init__ are reachable via wrapped super()."""

    a = object()
    b = object()

    class Base(object):

      def __init__(self):
        self.a = a

    class Subclass(Base):

      def __init__(self):
        super(Subclass, self).__init__()
        self.b = b

    base = Base()
    sub = Subclass()

    sub_super = super(Subclass, sub)
    sub_super_wrapped = inspect_utils.SuperWrapperForDynamicAttrs(sub_super)

    self.assertIs(base.a, a)
    self.assertIs(sub.a, a)

    # A plain super() proxy does not expose instance attributes...
    self.assertFalse(hasattr(sub_super, 'a'))
    # ...but the wrapper does.
    self.assertIs(sub_super_wrapped.a, a)

    # TODO(mdan): Is this side effect harmful? Can it be avoided?
    # Note that `b` was set in `Subclass.__init__`.
    self.assertIs(sub_super_wrapped.b, b)
456
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  test.main()
459