• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.util.module_wrapper."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import pickle
22import types
23
24from tensorflow.python.platform import test
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.util import module_wrapper
27from tensorflow.python.util import tf_inspect
28from tensorflow.tools.compatibility import all_renames_v2
29
# Pin the wrapper's private per-module warning limit to 5 for this test file.
# NOTE(review): presumably this caps how many deprecation warnings a single
# wrapped module may log -- confirm against module_wrapper.
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
31
32
class MockModule(types.ModuleType):
  """Empty `types.ModuleType` subclass the tests hand to `TFModuleWrapper`."""
35
36
class DeprecationWrapperTest(test.TestCase):
  """Tests module identity and deprecation logging of `TFModuleWrapper`."""

  def testWrapperIsAModule(self):
    # A wrapped module must still be reported as a module by tf_inspect.
    module = MockModule('test')
    wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
    self.assertTrue(tf_inspect.ismodule(wrapped_module))

  # patch.dict restores `symbol_renames` to its prior contents when the test
  # exits, so the fake renames registered here cannot leak into other tests
  # running in the same process.
  @test.mock.patch.dict(all_renames_v2.symbol_renames, {
      'tf.test.bar': 'tf.bar2',
      'tf.test.baz': 'tf.compat.v1.baz',
  })
  @test.mock.patch.object(logging, 'warning', autospec=True)
  def testDeprecationWarnings(self, mock_warning):
    # `bar` and `baz` have rename entries; `foo` does not.
    module = MockModule('test')
    module.foo = 1
    module.bar = 2
    module.baz = 3

    wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
    self.assertTrue(tf_inspect.ismodule(wrapped_module))

    # Nothing is logged until an attribute is actually read.
    self.assertEqual(0, mock_warning.call_count)
    # First read of a renamed symbol logs one warning.
    bar = wrapped_module.bar
    self.assertEqual(1, mock_warning.call_count)
    # Reading a symbol without a rename entry logs nothing.
    foo = wrapped_module.foo
    self.assertEqual(1, mock_warning.call_count)
    baz = wrapped_module.baz
    self.assertEqual(2, mock_warning.call_count)
    # A repeated read of the same symbol does not log again.
    baz = wrapped_module.baz
    self.assertEqual(2, mock_warning.call_count)

    # Check that values stayed the same
    self.assertEqual(module.foo, foo)
    self.assertEqual(module.bar, bar)
    self.assertEqual(module.baz, baz)
69
70
class LazyLoadingWrapperTest(test.TestCase):
  """Tests lazy loading of public APIs through `TFModuleWrapper`."""

  def testLazyLoad(self):
    import cmd as _cmd  # pylint: disable=g-import-not-at-top
    from abc import ABCMeta as _ABCMeta  # pylint: disable=g-import-not-at-top, g-importing-member
    lazy_apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')}
    wrapper = module_wrapper.TFModuleWrapper(
        MockModule('test'), 'test', public_apis=lazy_apis, deprecation=False)
    for api_name, expected in (('cmd', _cmd), ('ABCMeta', _ABCMeta)):
      # Not in the FastModuleType cache before the first access...
      self.assertFalse(wrapper._fastdict_key_in(api_name))
      self.assertEqual(getattr(wrapper, api_name), expected)
      # ...and present in the cache right after it.
      self.assertTrue(wrapper._fastdict_key_in(api_name))

  def testLazyLoadLocalOverride(self):
    # Assigning on the wrapper can replace a lazily loaded API and can also
    # introduce a brand new attribute.
    import cmd as _cmd  # pylint: disable=g-import-not-at-top
    lazy_apis = {'cmd': ('', 'cmd')}
    wrapper = module_wrapper.TFModuleWrapper(
        MockModule('test'), 'test', public_apis=lazy_apis, deprecation=False)
    self.assertEqual(wrapper.cmd, _cmd)
    wrapper.cmd = 1
    wrapper.cgi = 2
    self.assertEqual(wrapper.cmd, 1)  # override
    # The FastModuleType cache has to reflect the new values too.
    self.assertEqual(wrapper._fastdict_get('cmd'), 1)
    self.assertEqual(wrapper.cgi, 2)  # add
    self.assertEqual(wrapper._fastdict_get('cgi'), 2)

  def testLazyLoadDict(self):
    # Test that __dict__ gains an entry both when an API is lazily loaded and
    # when an attribute is assigned on the wrapper.
    import cmd as _cmd  # pylint: disable=g-import-not-at-top
    lazy_apis = {'cmd': ('', 'cmd')}
    wrapper = module_wrapper.TFModuleWrapper(
        MockModule('test'), 'test', public_apis=lazy_apis, deprecation=False)
    # Before the first reference there is no 'cmd' key in __dict__.
    self.assertNotIn('cmd', wrapper.__dict__)
    # Referencing the attribute lazy-loads it into __dict__.
    _ = wrapper.cmd
    self.assertEqual(wrapper.__dict__['cmd'], _cmd)
    # Attribute assignment is mirrored into __dict__ as well.
    wrapper.cmd2 = _cmd
    self.assertEqual(wrapper.__dict__['cmd2'], _cmd)

  def testLazyLoadWildcardImport(self):
    # Public APIs and assigned attributes appear in __all__; underscore-
    # prefixed attributes of the wrapped module do not.
    inner = MockModule('test')
    inner._should_not_be_public = 5
    lazy_apis = {'cmd': ('', 'cmd')}
    wrapper = module_wrapper.TFModuleWrapper(
        inner, 'test', public_apis=lazy_apis, deprecation=False)
    wrapper.hello = 1
    self.assertIn('hello', wrapper.__all__)
    self.assertIn('cmd', wrapper.__all__)
    self.assertNotIn('_should_not_be_public', wrapper.__all__)

  def testLazyLoadCorrectLiteModule(self):
    # With has_lite=True, `lite` resolves through the public API list even
    # though the wrapped module already carries a `lite` attribute.
    import cmd as _cmd  # pylint: disable=g-import-not-at-top
    inner = MockModule('test')
    inner.lite = 5
    lazy_apis = {'lite': ('', 'cmd')}
    wrapper = module_wrapper.TFModuleWrapper(
        inner, 'test', public_apis=lazy_apis, deprecation=False, has_lite=True)
    self.assertEqual(wrapper.lite, _cmd)

  def testInitCachesAttributes(self):
    # These names must already be in the fast cache right after construction.
    wrapper = module_wrapper.TFModuleWrapper(MockModule('test'), 'test')
    for cached_name in ('_fastdict_key_in', '_tfmw_module_name', '__all__'):
      self.assertTrue(wrapper._fastdict_key_in(cached_name))
149
150
class PickleTest(test.TestCase):
  """Round-trips a `TFModuleWrapper` through pickle."""

  def testPickleSubmodule(self):
    # The current module is a submodule, so its name is used for the wrapper.
    submodule_name = PickleTest.__module__
    wrapper = module_wrapper.TFModuleWrapper(
        MockModule(submodule_name), submodule_name)
    payload = pickle.dumps(wrapper)
    restored = pickle.loads(payload)
    # The restored object keeps the name and exposes this module's contents.
    self.assertEqual(restored.__name__, submodule_name)
    self.assertIsNotNone(restored.PickleTest)
159
160
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  test.main()
163