1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for unspect_utils module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23import imp 24import types 25import weakref 26 27import six 28 29from tensorflow.python import lib 30from tensorflow.python.autograph.pyct import inspect_utils 31from tensorflow.python.autograph.pyct.testing import future_import_module 32from tensorflow.python.eager import function 33from tensorflow.python.framework import constant_op 34from tensorflow.python.platform import test 35 36 37def decorator(f): 38 return f 39 40 41def function_decorator(): 42 def dec(f): 43 return f 44 return dec 45 46 47def wrapping_decorator(): 48 def dec(f): 49 def replacement(*_): 50 return None 51 52 @functools.wraps(f) 53 def wrapper(*args, **kwargs): 54 return replacement(*args, **kwargs) 55 return wrapper 56 return dec 57 58 59class TestClass(object): 60 61 def member_function(self): 62 pass 63 64 @decorator 65 def decorated_member(self): 66 pass 67 68 @function_decorator() 69 def fn_decorated_member(self): 70 pass 71 72 @wrapping_decorator() 73 def wrap_decorated_member(self): 74 pass 75 76 @staticmethod 77 def static_method(): 78 pass 79 80 @classmethod 81 def class_method(cls): 82 pass 83 
def free_function():
  """Module-level function; a lookup target for the namespace tests."""


def factory():
  """Returns the module-level `free_function`, referenced as a global."""
  return free_function


def free_factory():
  """Returns a freshly defined nested function that belongs to no class."""

  def local_function():
    pass

  return local_function


class InspectUtilsTest(test.TestCase):
  """Tests for the helpers in `inspect_utils`.

  Note: the `getmethodclass` calls below are deliberately made directly from
  the test methods, in the same frame as the local classes/objects they refer
  to, so that the lookup can find them among the caller's local variables.
  """

  def test_islambda(self):

    def test_fn():
      pass

    self.assertTrue(inspect_utils.islambda(lambda x: x))
    self.assertFalse(inspect_utils.islambda(test_fn))

  def test_isnamedtuple(self):
    nt = collections.namedtuple('TestNamedTuple', ['a', 'b'])

    class NotANamedTuple(tuple):
      pass

    self.assertTrue(inspect_utils.isnamedtuple(nt))
    self.assertFalse(inspect_utils.isnamedtuple(NotANamedTuple))

  def test_isnamedtuple_confounder(self):
    """This test highlights false positives when detecting named tuples."""

    class NamedTupleLike(tuple):
      _fields = ('a', 'b')

    self.assertTrue(inspect_utils.isnamedtuple(NamedTupleLike))

  def test_isnamedtuple_subclass(self):
    """Checks that subclasses of a named tuple are detected as named tuples."""

    class NamedTupleSubclass(collections.namedtuple('Test', ['a', 'b'])):
      pass

    self.assertTrue(inspect_utils.isnamedtuple(NamedTupleSubclass))

  def test_getnamespace_globals(self):
    ns = inspect_utils.getnamespace(factory)
    self.assertEqual(ns['free_function'], free_function)

  def test_getnamespace_hermetic(self):

    # Intentionally hiding the global function to make sure we don't overwrite
    # it in the global namespace.
    free_function = object()  # pylint:disable=redefined-outer-name

    def test_fn():
      return free_function

    ns = inspect_utils.getnamespace(test_fn)
    globs = six.get_function_globals(test_fn)
    self.assertIs(ns['free_function'], free_function)
    self.assertIsNot(globs['free_function'], free_function)

  def test_getnamespace_locals(self):

    def called_fn():
      return 0

    closed_over_list = []
    closed_over_primitive = 1

    def local_fn():
      closed_over_list.append(1)
      local_var = 1
      return called_fn() + local_var + closed_over_primitive

    ns = inspect_utils.getnamespace(local_fn)
    self.assertEqual(ns['called_fn'], called_fn)
    self.assertEqual(ns['closed_over_list'], closed_over_list)
    self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
    self.assertNotIn('local_var', ns)

  def test_getqualifiedname(self):
    foo = object()
    # Modules are forged with `types.ModuleType`, the non-deprecated
    # equivalent of `imp.new_module` (the `imp` module is gone in Python 3.12).
    qux = types.ModuleType('quxmodule')
    bar = types.ModuleType('barmodule')
    baz = object()
    bar.baz = baz

    ns = {
        'foo': foo,
        'bar': bar,
        'qux': qux,
    }

    self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
    self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
    self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
    self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')

  def test_getqualifiedname_efficiency(self):
    foo = object()

    # We create a densely connected graph consisting of a relatively small
    # number of modules and hide our symbol in one of them. The path to the
    # symbol is at least 10, and each node has about 10 neighbors. However,
    # by skipping visited modules, the search should take much less.
    ns = {}
    prev_level = []
    for i in range(10):
      current_level = []
      for j in range(10):
        mod_name = 'mod_{}_{}'.format(i, j)
        mod = types.ModuleType(mod_name)
        current_level.append(mod)
        if i == 9 and j == 9:
          mod.foo = foo
      if prev_level:
        # Every module on the previous level refers to every module on this
        # level.
        for prev in prev_level:
          for child in current_level:
            prev.__dict__[child.__name__] = child
      else:
        # Only the first level is reachable directly from the namespace.
        for child in current_level:
          ns[child.__name__] = child
      prev_level = current_level

    self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
    self.assertIsNotNone(
        inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))

  def test_getqualifiedname_cycles(self):
    foo = object()

    # We create a graph of modules that contains circular references. The
    # search process should avoid them. The searched object is hidden at the
    # bottom of a path of length roughly 10.
    ns = {}
    mods = []
    for i in range(10):
      mod = types.ModuleType('mod_{}'.format(i))
      if i == 9:
        mod.foo = foo
      # Module i-1 refers to module i; only module 0 sits in the namespace.
      if mods:
        mods[-1].__dict__[mod.__name__] = mod
      else:
        ns[mod.__name__] = mod
      # Module i refers to all modules j < i.
      for prev in mods:
        mod.__dict__[prev.__name__] = prev
      mods.append(mod)

    self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
    self.assertIsNotNone(
        inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))

  def test_getqualifiedname_finds_via_parent_module(self):
    # TODO(mdan): This test is vulnerable to change in the lib module.
    # A better way to forge modules should be found.
    self.assertEqual(
        inspect_utils.getqualifiedname(
            lib.__dict__, lib.io.file_io.FileIO, max_depth=1),
        'io.file_io.FileIO')

  def test_getmethodclass(self):

    self.assertIsNone(inspect_utils.getmethodclass(free_function))
    self.assertIsNone(inspect_utils.getmethodclass(free_factory()))

    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.member_function),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.fn_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.wrap_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.static_method),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.class_method),
        TestClass)

    test_obj = TestClass()
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.member_function),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.static_method),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.class_method),
        TestClass)

  def test_getmethodclass_locals(self):

    def local_function():
      pass

    class LocalClass(object):

      def member_function(self):
        pass

      @decorator
      def decorated_member(self):
        pass

      @function_decorator()
      def fn_decorated_member(self):
        pass

      @wrapping_decorator()
      def wrap_decorated_member(self):
        pass

    self.assertIsNone(inspect_utils.getmethodclass(local_function))

    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.member_function),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.fn_decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.wrap_decorated_member),
        LocalClass)

    test_obj = LocalClass()
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.member_function),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
        LocalClass)

  def test_getmethodclass_callables(self):

    class TestCallable(object):

      def __call__(self):
        pass

    c = TestCallable()
    self.assertEqual(inspect_utils.getmethodclass(c), TestCallable)

  def test_getmethodclass_weakref_mechanism(self):
    test_obj = TestClass()

    def test_fn(self):
      return self

    bound_method = types.MethodType(
        test_fn,
        function.TfMethodTarget(
            weakref.ref(test_obj), test_obj.member_function))
    self.assertEqual(inspect_utils.getmethodclass(bound_method), TestClass)

  def test_getmethodclass_no_bool_conversion(self):

    tensor = constant_op.constant([1])
    self.assertEqual(
        inspect_utils.getmethodclass(tensor.get_shape), type(tensor))

  def test_getdefiningclass(self):

    class Superclass(object):

      def foo(self):
        pass

      def bar(self):
        pass

      @classmethod
      def class_method(cls):
        pass

    class Subclass(Superclass):

      def foo(self):
        pass

      def baz(self):
        pass

    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.foo, Subclass), Subclass)
    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.bar, Subclass), Superclass)
    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.baz, Subclass), Subclass)
    self.assertIs(
        inspect_utils.getdefiningclass(Subclass.class_method, Subclass),
        Superclass)

  def test_isbuiltin(self):
    self.assertTrue(inspect_utils.isbuiltin(enumerate))
    self.assertTrue(inspect_utils.isbuiltin(float))
    self.assertTrue(inspect_utils.isbuiltin(int))
    self.assertTrue(inspect_utils.isbuiltin(len))
    self.assertTrue(inspect_utils.isbuiltin(range))
    self.assertTrue(inspect_utils.isbuiltin(zip))
    self.assertFalse(inspect_utils.isbuiltin(function_decorator))

  def test_getfutureimports_simple_case(self):
    expected_imports = ('absolute_import', 'division', 'print_function',
                        'with_statement')
    self.assertEqual(inspect_utils.getfutureimports(future_import_module.f),
                     expected_imports)

  def test_super_wrapper_for_dynamic_attrs(self):

    a = object()
    b = object()

    class Base(object):

      def __init__(self):
        self.a = a

    class Subclass(Base):

      def __init__(self):
        super(Subclass, self).__init__()
        self.b = b

    base = Base()
    sub = Subclass()

    sub_super = super(Subclass, sub)
    sub_super_wrapped = inspect_utils.SuperWrapperForDynamicAttrs(sub_super)

    self.assertIs(base.a, a)
    self.assertIs(sub.a, a)

    self.assertFalse(hasattr(sub_super, 'a'))
    self.assertIs(sub_super_wrapped.a, a)

    # TODO(mdan): Is this side effect harmful? Can it be avoided?
    # Note that `b` was set in `Subclass.__init__`.
    self.assertIs(sub_super_wrapped.b, b)


if __name__ == '__main__':
  test.main()