1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.util.module_wrapper.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import pickle 22import types 23 24from tensorflow.python.platform import test 25from tensorflow.python.platform import tf_logging as logging 26from tensorflow.python.util import module_wrapper 27from tensorflow.python.util import tf_inspect 28from tensorflow.tools.compatibility import all_renames_v2 29 30module_wrapper._PER_MODULE_WARNING_LIMIT = 5 31 32 33class MockModule(types.ModuleType): 34 pass 35 36 37class DeprecationWrapperTest(test.TestCase): 38 39 def testWrapperIsAModule(self): 40 module = MockModule('test') 41 wrapped_module = module_wrapper.TFModuleWrapper(module, 'test') 42 self.assertTrue(tf_inspect.ismodule(wrapped_module)) 43 44 @test.mock.patch.object(logging, 'warning', autospec=True) 45 def testDeprecationWarnings(self, mock_warning): 46 module = MockModule('test') 47 module.foo = 1 48 module.bar = 2 49 module.baz = 3 50 all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2' 51 all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz' 52 53 wrapped_module = module_wrapper.TFModuleWrapper(module, 'test') 54 self.assertTrue(tf_inspect.ismodule(wrapped_module)) 55 56 self.assertEqual(0, mock_warning.call_count) 57 bar = wrapped_module.bar 58 self.assertEqual(1, mock_warning.call_count) 59 foo = wrapped_module.foo 60 self.assertEqual(1, mock_warning.call_count) 61 baz = wrapped_module.baz # pylint: disable=unused-variable 62 self.assertEqual(2, mock_warning.call_count) 63 baz = wrapped_module.baz 64 self.assertEqual(2, mock_warning.call_count) 65 66 # Check that values stayed the same 67 self.assertEqual(module.foo, foo) 68 self.assertEqual(module.bar, bar) 69 70 71class LazyLoadingWrapperTest(test.TestCase): 72 73 def testLazyLoad(self): 74 module = MockModule('test') 75 apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')} 76 wrapped_module = module_wrapper.TFModuleWrapper( 77 module, 'test', public_apis=apis, deprecation=False) 78 import cmd as _cmd # pylint: disable=g-import-not-at-top 79 from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member 80 self.assertFalse(wrapped_module._fastdict_key_in('cmd')) 81 self.assertEqual(wrapped_module.cmd, _cmd) 82 # Verify that the APIs are added to the cache of FastModuleType object 83 self.assertTrue(wrapped_module._fastdict_key_in('cmd')) 84 self.assertFalse(wrapped_module._fastdict_key_in('ABCMeta')) 85 self.assertEqual(wrapped_module.ABCMeta, _ABCMeta) 86 self.assertTrue(wrapped_module._fastdict_key_in('ABCMeta')) 87 88 def testLazyLoadLocalOverride(self): 89 # Test that we can override and add fields to the wrapped module. 90 module = MockModule('test') 91 apis = {'cmd': ('', 'cmd')} 92 wrapped_module = module_wrapper.TFModuleWrapper( 93 module, 'test', public_apis=apis, deprecation=False) 94 import cmd as _cmd # pylint: disable=g-import-not-at-top 95 self.assertEqual(wrapped_module.cmd, _cmd) 96 setattr(wrapped_module, 'cmd', 1) 97 setattr(wrapped_module, 'cgi', 2) 98 self.assertEqual(wrapped_module.cmd, 1) # override 99 # Verify that the values are also updated in the cache 100 # of the FastModuleType object 101 self.assertEqual(wrapped_module._fastdict_get('cmd'), 1) 102 self.assertEqual(wrapped_module.cgi, 2) # add 103 self.assertEqual(wrapped_module._fastdict_get('cgi'), 2) 104 105 def testLazyLoadDict(self): 106 # Test that we can override and add fields to the wrapped module. 107 module = MockModule('test') 108 apis = {'cmd': ('', 'cmd')} 109 wrapped_module = module_wrapper.TFModuleWrapper( 110 module, 'test', public_apis=apis, deprecation=False) 111 import cmd as _cmd # pylint: disable=g-import-not-at-top 112 # At first cmd key does not exist in __dict__ 113 self.assertNotIn('cmd', wrapped_module.__dict__) 114 # After it is referred (lazyloaded), it gets added to __dict__ 115 wrapped_module.cmd # pylint: disable=pointless-statement 116 self.assertEqual(wrapped_module.__dict__['cmd'], _cmd) 117 # When we call setattr, it also gets added to __dict__ 118 setattr(wrapped_module, 'cmd2', _cmd) 119 self.assertEqual(wrapped_module.__dict__['cmd2'], _cmd) 120 121 def testLazyLoadWildcardImport(self): 122 # Test that public APIs are in __all__. 123 module = MockModule('test') 124 module._should_not_be_public = 5 125 apis = {'cmd': ('', 'cmd')} 126 wrapped_module = module_wrapper.TFModuleWrapper( 127 module, 'test', public_apis=apis, deprecation=False) 128 setattr(wrapped_module, 'hello', 1) 129 self.assertIn('hello', wrapped_module.__all__) 130 self.assertIn('cmd', wrapped_module.__all__) 131 self.assertNotIn('_should_not_be_public', wrapped_module.__all__) 132 133 def testLazyLoadCorrectLiteModule(self): 134 # If set, always load lite module from public API list. 135 module = MockModule('test') 136 apis = {'lite': ('', 'cmd')} 137 module.lite = 5 138 import cmd as _cmd # pylint: disable=g-import-not-at-top 139 wrapped_module = module_wrapper.TFModuleWrapper( 140 module, 'test', public_apis=apis, deprecation=False, has_lite=True) 141 self.assertEqual(wrapped_module.lite, _cmd) 142 143 def testInitCachesAttributes(self): 144 module = MockModule('test') 145 wrapped_module = module_wrapper.TFModuleWrapper(module, 'test') 146 self.assertTrue(wrapped_module._fastdict_key_in('_fastdict_key_in')) 147 self.assertTrue(wrapped_module._fastdict_key_in('_tfmw_module_name')) 148 self.assertTrue(wrapped_module._fastdict_key_in('__all__')) 149 150 151class PickleTest(test.TestCase): 152 153 def testPickleSubmodule(self): 154 name = PickleTest.__module__ # The current module is a submodule. 155 module = module_wrapper.TFModuleWrapper(MockModule(name), name) 156 restored = pickle.loads(pickle.dumps(module)) 157 self.assertEqual(restored.__name__, name) 158 self.assertIsNotNone(restored.PickleTest) 159 160 161if __name__ == '__main__': 162 test.main() 163