• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for create_python_api."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp
import sys
import types

from tensorflow.python.platform import test
from tensorflow.python.tools.api.generator import create_python_api
from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export('test_op', 'test_op1', 'test.test_op2')
def test_op():
  """Do-nothing fixture function.

  Exported through tf_export as `test_op`, `test_op1` and `test.test_op2`
  so the API generator has a function with several export names to emit.
  """
  return None
32
33
@tf_export('TestClass', 'NewTestClass')
class TestClass(object):
  """Empty fixture class exported as both `TestClass` and `NewTestClass`."""
37
38
39_TEST_CONSTANT = 5
40_MODULE_NAME = 'tensorflow.python.test_module'
41
42
class CreatePythonApiTest(test.TestCase):
  """Tests the import text produced by create_python_api.get_api_init_text.

  setUp registers a fake module named _MODULE_NAME that holds the exported
  fixtures (test_op, TestClass, and a constant export of _TEST_CONSTANT).
  Each test then inspects the generated result. The result is treated as a
  mapping whose keys include entries such as 'compat.v1'.
  """

  def setUp(self):
    super(CreatePythonApiTest, self).setUp()
    # Add fake op to a module that has 'tensorflow' in the name.
    # types.ModuleType is used instead of imp.new_module because the imp
    # module is deprecated and was removed in Python 3.12.
    fake_module = types.ModuleType(_MODULE_NAME)
    sys.modules[_MODULE_NAME] = fake_module
    fake_module.test_op = test_op
    fake_module.TestClass = TestClass
    # Relabel the fixtures as defined in the fake module.
    # NOTE(review): presumably create_python_api only emits a symbol from the
    # module named by its __module__; confirm against the generator.
    # The cleanups restore the real __module__ values so the relabeling does
    # not leak past this test. Cleanups also run if setUp fails later on.
    self.addCleanup(setattr, test_op, '__module__', test_op.__module__)
    self.addCleanup(setattr, TestClass, '__module__', TestClass.__module__)
    test_op.__module__ = _MODULE_NAME
    TestClass.__module__ = _MODULE_NAME
    tf_export('consts._TEST_CONSTANT').export_constant(
        _MODULE_NAME, '_TEST_CONSTANT')

  def tearDown(self):
    del sys.modules[_MODULE_NAME]
    super(CreatePythonApiTest, self).tearDown()

  def testFunctionImportIsAdded(self):
    """test_op is imported under its test_op1 alias and its own name (v1)."""
    imports = create_python_api.get_api_init_text(
        packages=[create_python_api._DEFAULT_PACKAGE],
        output_package='tensorflow',
        api_name='tensorflow',
        api_version=1)
    imports_text = str(imports)
    expected_import = (
        'from tensorflow.python.test_module '
        'import test_op as test_op1')
    self.assertIn(
        expected_import, imports_text,
        msg='%s not in %s' % (expected_import, imports_text))

    # NOTE(review): this string is also a prefix of the aliased import checked
    # above, so on its own it does not prove a separate un-aliased import.
    expected_import = ('from tensorflow.python.test_module '
                       'import test_op')
    self.assertIn(
        expected_import, imports_text,
        msg='%s not in %s' % (expected_import, imports_text))
    # Also check that compat.v1 is not added to imports.
    self.assertNotIn('compat.v1', imports,
                     msg='compat.v1 in %s' % str(imports.keys()))

  def testClassImportIsAdded(self):
    """TestClass is imported from the fake module in the v2 API."""
    imports = create_python_api.get_api_init_text(
        packages=[create_python_api._DEFAULT_PACKAGE],
        output_package='tensorflow',
        api_name='tensorflow',
        api_version=2)
    imports_text = str(imports)
    expected_import = ('from tensorflow.python.test_module '
                       'import TestClass')
    # Assert the full import statement, which is what the failure message
    # reports. The bare name 'TestClass' alone would also match unrelated text.
    self.assertIn(
        expected_import, imports_text,
        msg='%s not in %s' % (expected_import, imports_text))

  def testConstantIsAdded(self):
    """The constant exported in setUp produces an import of _TEST_CONSTANT."""
    imports = create_python_api.get_api_init_text(
        packages=[create_python_api._DEFAULT_PACKAGE],
        output_package='tensorflow',
        api_name='tensorflow',
        api_version=1)
    imports_text = str(imports)
    expected = ('from tensorflow.python.test_module '
                'import _TEST_CONSTANT')
    self.assertIn(expected, imports_text,
                  msg='%s not in %s' % (expected, imports_text))

  def testCompatModuleIsAdded(self):
    """Requesting compat API version 1 adds compat.v1 and compat.v1.test."""
    imports = create_python_api.get_api_init_text(
        packages=[create_python_api._DEFAULT_PACKAGE],
        output_package='tensorflow',
        api_name='tensorflow',
        api_version=2,
        compat_api_versions=[1])
    self.assertIn('compat.v1', imports,
                  msg='compat.v1 not in %s' % str(imports.keys()))
    self.assertIn('compat.v1.test', imports,
                  msg='compat.v1.test not in %s' % str(imports.keys()))
114
115
if __name__ == '__main__':
  # Run the tests through TensorFlow's test runner
  # (tensorflow.python.platform.test).
  test.main()
118