• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""tf_export tests."""
16
17# pylint: disable=unused-import
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import sys
23
24from tensorflow.python.platform import test
25from tensorflow.python.util import tf_decorator
26from tensorflow.python.util import tf_export
27
28
29def _test_function(unused_arg=0):
30  pass
31
32
33def _test_function2(unused_arg=0):
34  pass
35
36
class TestClassA(object):
  """Empty class used as an export target by the tests."""
39
40
class TestClassB(TestClassA):
  """Empty subclass of TestClassA, used to check exports aren't inherited."""
43
44
class ValidateExportTest(test.TestCase):
  """Tests for the tf_export and estimator_export decorators."""

  class MockModule(object):
    """Bare stand-in for a module object; only `__name__` is set."""

    def __init__(self, name):
      self.__name__ = name

  def setUp(self):
    super(ValidateExportTest, self).setUp()
    # Names registered in sys.modules by _CreateMockModule; tearDown removes
    # them again.
    self._modules = []

  def tearDown(self):
    for name in self._modules:
      del sys.modules[name]
    self._modules = []
    # Remove every export attribute a test may have attached so each test
    # starts from undecorated symbols; exporting a symbol that still carries
    # `_tf_api_names` raises SymbolAlreadyExposedError (see
    # testRaisesExceptionIfAlreadyHasAPINames). Only attributes stored on the
    # symbol itself are deleted: TestClassB subclasses TestClassA, and
    # deleting an attribute it merely inherits would raise AttributeError.
    export_attrs = ('_tf_api_names', '_tf_api_names_v1',
                    '_estimator_api_names', '_estimator_api_names_v1')
    for symbol in (_test_function, _test_function2, TestClassA, TestClassB):
      for attr in export_attrs:
        if attr in vars(symbol):
          delattr(symbol, attr)
    super(ValidateExportTest, self).tearDown()

  def _CreateMockModule(self, name):
    """Registers a MockModule under `name` in sys.modules and returns it."""
    mock_module = self.MockModule(name)
    sys.modules[name] = mock_module
    self._modules.append(name)  # So tearDown can unregister it.
    return mock_module

  def testExportSingleFunction(self):
    export_decorator = tf_export.tf_export('nameA', 'nameB')
    decorated_function = export_decorator(_test_function)
    self.assertEqual(decorated_function, _test_function)
    self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
    self.assertEqual(['nameA', 'nameB'],
                     tf_export.get_v1_names(decorated_function))
    self.assertEqual(['nameA', 'nameB'],
                     tf_export.get_v2_names(decorated_function))
    self.assertEqual(tf_export.get_symbol_from_name('nameA'),
                     decorated_function)
    self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                     decorated_function)
    self.assertEqual(
        tf_export.get_symbol_from_name(
            tf_export.get_canonical_name_for_symbol(decorated_function)),
        decorated_function)

  def testExportSingleFunctionV1Only(self):
    export_decorator = tf_export.tf_export(v1=['nameA', 'nameB'])
    decorated_function = export_decorator(_test_function)
    self.assertEqual(decorated_function, _test_function)
    self.assertAllEqual(('nameA', 'nameB'), decorated_function._tf_api_names_v1)
    self.assertAllEqual(['nameA', 'nameB'],
                        tf_export.get_v1_names(decorated_function))
    self.assertEqual([],
                     tf_export.get_v2_names(decorated_function))
    self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameA'),
                     decorated_function)
    self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameB'),
                     decorated_function)
    self.assertEqual(
        tf_export.get_symbol_from_name(
            tf_export.get_canonical_name_for_symbol(
                decorated_function, add_prefix_to_v1_names=True)),
        decorated_function)

  def testExportMultipleFunctions(self):
    export_decorator1 = tf_export.tf_export('nameA', 'nameB')
    export_decorator2 = tf_export.tf_export('nameC', 'nameD')
    decorated_function1 = export_decorator1(_test_function)
    decorated_function2 = export_decorator2(_test_function2)
    self.assertEqual(decorated_function1, _test_function)
    self.assertEqual(decorated_function2, _test_function2)
    self.assertEqual(('nameA', 'nameB'), decorated_function1._tf_api_names)
    self.assertEqual(('nameC', 'nameD'), decorated_function2._tf_api_names)
    self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                     decorated_function1)
    self.assertEqual(tf_export.get_symbol_from_name('nameD'),
                     decorated_function2)
    self.assertEqual(
        tf_export.get_symbol_from_name(
            tf_export.get_canonical_name_for_symbol(
                decorated_function1)),
        decorated_function1)
    self.assertEqual(
        tf_export.get_symbol_from_name(
            tf_export.get_canonical_name_for_symbol(
                decorated_function2)),
        decorated_function2)

  def testExportClasses(self):
    export_decorator_a = tf_export.tf_export('TestClassA1')
    export_decorator_a(TestClassA)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    # Exporting the base class must not give the subclass its own names.
    self.assertNotIn('_tf_api_names', TestClassB.__dict__)

    export_decorator_b = tf_export.tf_export('TestClassB1')
    export_decorator_b(TestClassB)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    self.assertEqual(('TestClassB1',), TestClassB._tf_api_names)
    self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
    self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB))

  def testExportClassInEstimator(self):
    export_decorator_a = tf_export.tf_export('TestClassA1')
    export_decorator_a(TestClassA)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)

    export_decorator_b = tf_export.estimator_export(
        'estimator.TestClassB1')
    export_decorator_b(TestClassB)
    self.assertNotIn('_tf_api_names', TestClassB.__dict__)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
    self.assertEqual(['estimator.TestClassB1'],
                     tf_export.get_v1_names(TestClassB))

  def testExportSingleConstant(self):
    module1 = self._CreateMockModule('module1')

    export_decorator = tf_export.tf_export('NAME_A', 'NAME_B')
    export_decorator.export_constant('module1', 'test_constant')
    self.assertEqual([(('NAME_A', 'NAME_B'), 'test_constant')],
                     module1._tf_api_constants)
    self.assertEqual([(('NAME_A', 'NAME_B'), 'test_constant')],
                     tf_export.get_v1_constants(module1))
    self.assertEqual([(('NAME_A', 'NAME_B'), 'test_constant')],
                     tf_export.get_v2_constants(module1))

  def testExportMultipleConstants(self):
    module1 = self._CreateMockModule('module1')
    module2 = self._CreateMockModule('module2')

    test_constant1 = 123
    test_constant2 = 'abc'
    test_constant3 = 0.5

    export_decorator1 = tf_export.tf_export('NAME_A', 'NAME_B')
    export_decorator2 = tf_export.tf_export('NAME_C', 'NAME_D')
    export_decorator3 = tf_export.tf_export('NAME_E', 'NAME_F')
    export_decorator1.export_constant('module1', test_constant1)
    export_decorator2.export_constant('module2', test_constant2)
    export_decorator3.export_constant('module2', test_constant3)
    self.assertEqual([(('NAME_A', 'NAME_B'), 123)], module1._tf_api_constants)
    self.assertEqual([(('NAME_C', 'NAME_D'), 'abc'),
                      (('NAME_E', 'NAME_F'), 0.5)], module2._tf_api_constants)

  def testRaisesExceptionIfAlreadyHasAPINames(self):
    _test_function._tf_api_names = ['abc']
    export_decorator = tf_export.tf_export('nameA', 'nameB')
    with self.assertRaises(tf_export.SymbolAlreadyExposedError):
      export_decorator(_test_function)

  def testRaisesExceptionIfInvalidSymbolName(self):
    # TensorFlow code is not allowed to export symbols under package
    # tf.estimator.
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.tf_export('estimator.invalid')

    # All symbols exported by Estimator must be under tf.estimator package.
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.estimator_export('invalid')
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.estimator_export('Estimator.invalid')
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.estimator_export('invalid.estimator')

  def testRaisesExceptionIfInvalidV1SymbolName(self):
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.tf_export('valid', v1=['estimator.invalid'])
    with self.assertRaises(tf_export.InvalidSymbolNameError):
      tf_export.estimator_export('estimator.valid', v1=['invalid'])

  def testOverridesFunction(self):
    # Seed both the v2 and the v1 name attributes explicitly. The override
    # path may remove either of them from the overridden symbol, and this
    # test must not rely on attributes left behind by other tests.
    _test_function2._tf_api_names = ['abc']
    _test_function2._tf_api_names_v1 = ['abc']

    export_decorator = tf_export.tf_export(
        'nameA', 'nameB', overrides=[_test_function2])
    export_decorator(_test_function)

    # _test_function overrides _test_function2. So, _tf_api_names
    # should be removed from _test_function2.
    self.assertFalse(hasattr(_test_function2, '_tf_api_names'))

  def testMultipleDecorators(self):
    def get_wrapper(func):
      def wrapper(*unused_args, **unused_kwargs):
        pass
      return tf_decorator.make_decorator(func, wrapper)
    decorated_function = get_wrapper(_test_function)

    export_decorator = tf_export.tf_export('nameA', 'nameB')
    exported_function = export_decorator(decorated_function)
    self.assertEqual(decorated_function, exported_function)
    # The API names land on the innermost (undecorated) target function.
    self.assertEqual(('nameA', 'nameB'), _test_function._tf_api_names)
242
243
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner, which discovers and runs the
  # test cases defined in this module.
  test.main()
246