1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""tf_export tests.""" 16 17# pylint: disable=unused-import 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import sys 23 24from tensorflow.python.platform import test 25from tensorflow.python.util import tf_decorator 26from tensorflow.python.util import tf_export 27 28 29def _test_function(unused_arg=0): 30 pass 31 32 33def _test_function2(unused_arg=0): 34 pass 35 36 37class TestClassA(object): 38 pass 39 40 41class TestClassB(TestClassA): 42 pass 43 44 45class ValidateExportTest(test.TestCase): 46 """Tests for tf_export class.""" 47 48 class MockModule(object): 49 50 def __init__(self, name): 51 self.__name__ = name 52 53 def setUp(self): 54 self._modules = [] 55 56 def tearDown(self): 57 for name in self._modules: 58 del sys.modules[name] 59 self._modules = [] 60 for symbol in [_test_function, _test_function, TestClassA, TestClassB]: 61 if hasattr(symbol, '_tf_api_names'): 62 del symbol._tf_api_names 63 if hasattr(symbol, '_tf_api_names_v1'): 64 del symbol._tf_api_names_v1 65 if hasattr(symbol, '_estimator_api_names'): 66 del symbol._estimator_api_names 67 if hasattr(symbol, '_estimator_api_names_v1'): 68 del symbol._estimator_api_names_v1 69 70 def _CreateMockModule(self, name): 71 mock_module = self.MockModule(name) 72 sys.modules[name] = mock_module 73 self._modules.append(name) 74 return mock_module 75 76 def testExportSingleFunction(self): 77 export_decorator = tf_export.tf_export('nameA', 'nameB') 78 decorated_function = export_decorator(_test_function) 79 self.assertEqual(decorated_function, _test_function) 80 self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names) 81 self.assertEqual(['nameA', 'nameB'], 82 tf_export.get_v1_names(decorated_function)) 83 self.assertEqual(['nameA', 'nameB'], 84 tf_export.get_v2_names(decorated_function)) 85 self.assertEqual(tf_export.get_symbol_from_name('nameA'), 86 decorated_function) 87 self.assertEqual(tf_export.get_symbol_from_name('nameB'), 88 decorated_function) 89 self.assertEqual( 90 tf_export.get_symbol_from_name( 91 tf_export.get_canonical_name_for_symbol(decorated_function)), 92 decorated_function) 93 94 def testExportSingleFunctionV1Only(self): 95 export_decorator = tf_export.tf_export(v1=['nameA', 'nameB']) 96 decorated_function = export_decorator(_test_function) 97 self.assertEqual(decorated_function, _test_function) 98 self.assertAllEqual(('nameA', 'nameB'), decorated_function._tf_api_names_v1) 99 self.assertAllEqual(['nameA', 'nameB'], 100 tf_export.get_v1_names(decorated_function)) 101 self.assertEqual([], 102 tf_export.get_v2_names(decorated_function)) 103 self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameA'), 104 decorated_function) 105 self.assertEqual(tf_export.get_symbol_from_name('compat.v1.nameB'), 106 decorated_function) 107 self.assertEqual( 108 tf_export.get_symbol_from_name( 109 tf_export.get_canonical_name_for_symbol( 110 decorated_function, add_prefix_to_v1_names=True)), 111 decorated_function) 112 113 def testExportMultipleFunctions(self): 114 export_decorator1 = tf_export.tf_export('nameA', 'nameB') 115 export_decorator2 = tf_export.tf_export('nameC', 'nameD') 116 decorated_function1 = export_decorator1(_test_function) 117 decorated_function2 = export_decorator2(_test_function2) 118 self.assertEqual(decorated_function1, _test_function) 119 self.assertEqual(decorated_function2, _test_function2) 120 self.assertEqual(('nameA', 'nameB'), decorated_function1._tf_api_names) 121 self.assertEqual(('nameC', 'nameD'), decorated_function2._tf_api_names) 122 self.assertEqual(tf_export.get_symbol_from_name('nameB'), 123 decorated_function1) 124 self.assertEqual(tf_export.get_symbol_from_name('nameD'), 125 decorated_function2) 126 self.assertEqual( 127 tf_export.get_symbol_from_name( 128 tf_export.get_canonical_name_for_symbol( 129 decorated_function1)), 130 decorated_function1) 131 self.assertEqual( 132 tf_export.get_symbol_from_name( 133 tf_export.get_canonical_name_for_symbol( 134 decorated_function2)), 135 decorated_function2) 136 137 def testExportClasses(self): 138 export_decorator_a = tf_export.tf_export('TestClassA1') 139 export_decorator_a(TestClassA) 140 self.assertEqual(('TestClassA1',), TestClassA._tf_api_names) 141 self.assertTrue('_tf_api_names' not in TestClassB.__dict__) 142 143 export_decorator_b = tf_export.tf_export('TestClassB1') 144 export_decorator_b(TestClassB) 145 self.assertEqual(('TestClassA1',), TestClassA._tf_api_names) 146 self.assertEqual(('TestClassB1',), TestClassB._tf_api_names) 147 self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA)) 148 self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB)) 149 150 def testExportClassInEstimator(self): 151 export_decorator_a = tf_export.tf_export('TestClassA1') 152 export_decorator_a(TestClassA) 153 self.assertEqual(('TestClassA1',), TestClassA._tf_api_names) 154 155 export_decorator_b = tf_export.estimator_export( 156 'estimator.TestClassB1') 157 export_decorator_b(TestClassB) 158 self.assertTrue('_tf_api_names' not in TestClassB.__dict__) 159 self.assertEqual(('TestClassA1',), TestClassA._tf_api_names) 160 self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA)) 161 self.assertEqual(['estimator.TestClassB1'], 162 tf_export.get_v1_names(TestClassB)) 163 164 def testExportSingleConstant(self): 165 module1 = self._CreateMockModule('module1') 166 167 export_decorator = tf_export.tf_export('NAME_A', 'NAME_B') 168 export_decorator.export_constant('module1', 'test_constant') 169 self.assertEqual([(('NAME_A', 'NAME_B'), 'test_constant')], 170 module1._tf_api_constants) 171 self.assertEqual([(('NAME_A', 'NAME_B'), 'test_constant')], 172 tf_export.get_v1_constants(module1)) 173 self.assertEqual([(('NAME_A', 'NAME_B'), 'test_constant')], 174 tf_export.get_v2_constants(module1)) 175 176 def testExportMultipleConstants(self): 177 module1 = self._CreateMockModule('module1') 178 module2 = self._CreateMockModule('module2') 179 180 test_constant1 = 123 181 test_constant2 = 'abc' 182 test_constant3 = 0.5 183 184 export_decorator1 = tf_export.tf_export('NAME_A', 'NAME_B') 185 export_decorator2 = tf_export.tf_export('NAME_C', 'NAME_D') 186 export_decorator3 = tf_export.tf_export('NAME_E', 'NAME_F') 187 export_decorator1.export_constant('module1', test_constant1) 188 export_decorator2.export_constant('module2', test_constant2) 189 export_decorator3.export_constant('module2', test_constant3) 190 self.assertEqual([(('NAME_A', 'NAME_B'), 123)], module1._tf_api_constants) 191 self.assertEqual([(('NAME_C', 'NAME_D'), 'abc'), 192 (('NAME_E', 'NAME_F'), 0.5)], module2._tf_api_constants) 193 194 def testRaisesExceptionIfAlreadyHasAPINames(self): 195 _test_function._tf_api_names = ['abc'] 196 export_decorator = tf_export.tf_export('nameA', 'nameB') 197 with self.assertRaises(tf_export.SymbolAlreadyExposedError): 198 export_decorator(_test_function) 199 200 def testRaisesExceptionIfInvalidSymbolName(self): 201 # TensorFlow code is not allowed to export symbols under package 202 # tf.estimator 203 with self.assertRaises(tf_export.InvalidSymbolNameError): 204 tf_export.tf_export('estimator.invalid') 205 206 # All symbols exported by Estimator must be under tf.estimator package. 207 with self.assertRaises(tf_export.InvalidSymbolNameError): 208 tf_export.estimator_export('invalid') 209 with self.assertRaises(tf_export.InvalidSymbolNameError): 210 tf_export.estimator_export('Estimator.invalid') 211 with self.assertRaises(tf_export.InvalidSymbolNameError): 212 tf_export.estimator_export('invalid.estimator') 213 214 def testRaisesExceptionIfInvalidV1SymbolName(self): 215 with self.assertRaises(tf_export.InvalidSymbolNameError): 216 tf_export.tf_export('valid', v1=['estimator.invalid']) 217 with self.assertRaises(tf_export.InvalidSymbolNameError): 218 tf_export.estimator_export('estimator.valid', v1=['invalid']) 219 220 def testOverridesFunction(self): 221 _test_function2._tf_api_names = ['abc'] 222 223 export_decorator = tf_export.tf_export( 224 'nameA', 'nameB', overrides=[_test_function2]) 225 export_decorator(_test_function) 226 227 # _test_function overrides _test_function2. So, _tf_api_names 228 # should be removed from _test_function2. 229 self.assertFalse(hasattr(_test_function2, '_tf_api_names')) 230 231 def testMultipleDecorators(self): 232 def get_wrapper(func): 233 def wrapper(*unused_args, **unused_kwargs): 234 pass 235 return tf_decorator.make_decorator(func, wrapper) 236 decorated_function = get_wrapper(_test_function) 237 238 export_decorator = tf_export.tf_export('nameA', 'nameB') 239 exported_function = export_decorator(decorated_function) 240 self.assertEqual(decorated_function, exported_function) 241 self.assertEqual(('nameA', 'nameB'), _test_function._tf_api_names) 242 243 244if __name__ == '__main__': 245 test.main() 246