1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for create_python_api.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import imp 22import sys 23 24from tensorflow.python.platform import test 25from tensorflow.python.tools.api.generator import create_python_api 26from tensorflow.python.util.tf_export import tf_export 27 28 29@tf_export('test_op', 'test_op1', 'test.test_op2') 30def test_op(): 31 pass 32 33 34@tf_export('TestClass', 'NewTestClass') 35class TestClass(object): 36 pass 37 38 39_TEST_CONSTANT = 5 40_MODULE_NAME = 'tensorflow.python.test_module' 41 42 43class CreatePythonApiTest(test.TestCase): 44 45 def setUp(self): 46 # Add fake op to a module that has 'tensorflow' in the name. 47 sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME) 48 setattr(sys.modules[_MODULE_NAME], 'test_op', test_op) 49 setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass) 50 test_op.__module__ = _MODULE_NAME 51 TestClass.__module__ = _MODULE_NAME 52 tf_export('consts._TEST_CONSTANT').export_constant( 53 _MODULE_NAME, '_TEST_CONSTANT') 54 55 def tearDown(self): 56 del sys.modules[_MODULE_NAME] 57 58 def testFunctionImportIsAdded(self): 59 imports = create_python_api.get_api_init_text( 60 packages=[create_python_api._DEFAULT_PACKAGE], 61 output_package='tensorflow', 62 api_name='tensorflow', 63 api_version=1) 64 expected_import = ( 65 'from tensorflow.python.test_module ' 66 'import test_op as test_op1') 67 self.assertTrue( 68 expected_import in str(imports), 69 msg='%s not in %s' % (expected_import, str(imports))) 70 71 expected_import = ('from tensorflow.python.test_module ' 72 'import test_op') 73 self.assertTrue( 74 expected_import in str(imports), 75 msg='%s not in %s' % (expected_import, str(imports))) 76 # Also check that compat.v1 is not added to imports. 77 self.assertFalse('compat.v1' in imports, 78 msg='compat.v1 in %s' % str(imports.keys())) 79 80 def testClassImportIsAdded(self): 81 imports = create_python_api.get_api_init_text( 82 packages=[create_python_api._DEFAULT_PACKAGE], 83 output_package='tensorflow', 84 api_name='tensorflow', 85 api_version=2) 86 expected_import = ('from tensorflow.python.test_module ' 87 'import TestClass') 88 self.assertTrue( 89 'TestClass' in str(imports), 90 msg='%s not in %s' % (expected_import, str(imports))) 91 92 def testConstantIsAdded(self): 93 imports = create_python_api.get_api_init_text( 94 packages=[create_python_api._DEFAULT_PACKAGE], 95 output_package='tensorflow', 96 api_name='tensorflow', 97 api_version=1) 98 expected = ('from tensorflow.python.test_module ' 99 'import _TEST_CONSTANT') 100 self.assertTrue(expected in str(imports), 101 msg='%s not in %s' % (expected, str(imports))) 102 103 def testCompatModuleIsAdded(self): 104 imports = create_python_api.get_api_init_text( 105 packages=[create_python_api._DEFAULT_PACKAGE], 106 output_package='tensorflow', 107 api_name='tensorflow', 108 api_version=2, 109 compat_api_versions=[1]) 110 self.assertTrue('compat.v1' in imports, 111 msg='compat.v1 not in %s' % str(imports.keys())) 112 self.assertTrue('compat.v1.test' in imports, 113 msg='compat.v1.test not in %s' % str(imports.keys())) 114 115 116if __name__ == '__main__': 117 test.main() 118