• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Tests for tensorflow.python.tools.api.generator.doc_srcs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import importlib
23import sys
24
25from tensorflow.python.platform import test
26from tensorflow.python.tools.api.generator import doc_srcs
27
28
# Parsed command-line flags (an argparse.Namespace); assigned by the
# `if __name__ == '__main__':` block before test.main() runs the tests.
FLAGS = None


class DocSrcsTest(test.TestCase):
  """Validates the DocSource table returned by `doc_srcs.get_doc_sources`.

  Reads the module-level `FLAGS`:
    * `FLAGS.api_name` selects which DocSource table to check.
    * `FLAGS.outputs` lists the create_python_api output files.
    * `FLAGS.package` is the base package docstring modules live under.
  """

  def testModulesAreValidAPIModules(self):
    """Every documented module must have an `__init__.py` in FLAGS.outputs."""
    for module_name in doc_srcs.get_doc_sources(FLAGS.api_name):
      # Convert module_name to corresponding __init__.py file path; an empty
      # module_name maps to the top-level '__init__.py'.
      file_path = module_name.replace('.', '/')
      if file_path:
        file_path += '/'
      file_path += '__init__.py'

      self.assertIn(
          file_path, FLAGS.outputs,
          msg='%s is not a valid API module' % module_name)

  def testHaveDocstringOrDocstringModule(self):
    """A DocSource may set `docstring` or `docstring_module_name`, not both."""
    for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
      self.assertFalse(
          docsrc.docstring and docsrc.docstring_module_name,
          msg=('%s contains a DocSource that has both a docstring and a '
               'docstring_module_name. Only one of "docstring" or '
               '"docstring_module_name" should be set.') % module_name)

  def testDocstringModulesAreValidModules(self):
    """Every `docstring_module_name` must resolve to an imported module.

    The fully qualified name `<FLAGS.package>.<docstring_module_name>` is
    looked up in `sys.modules`. The check therefore only passes if that module
    has already been imported. Presumably this happens transitively when the
    `__main__` block imports `FLAGS.package`.
    """
    # Only the DocSource values are needed here, not the module-name keys.
    for docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).values():
      if not docsrc.docstring_module_name:
        continue
      doc_module_name = '.'.join(
          [FLAGS.package, docsrc.docstring_module_name])
      self.assertIn(
          doc_module_name, sys.modules,
          msg=('docsources_module %s is not a valid module under %s.' %
               (docsrc.docstring_module_name, FLAGS.package)))
57        doc_module_name = '.'.join([
58            FLAGS.package, docsrc.docstring_module_name])
59        self.assertIn(
60            doc_module_name, sys.modules,
61            msg=('docsources_module %s is not a valid module under %s.' %
62                 (docsrc.docstring_module_name, FLAGS.package)))
63
64
def _parse_flags(argv=None):
  """Parses the command-line flags used by this test binary.

  Args:
    argv: Optional list of argument strings, excluding the program name.
      When None, argparse falls back to `sys.argv[1:]`.

  Returns:
    A `(flags, unparsed)` tuple:
      * `flags` is the parsed `argparse.Namespace`.
      * `unparsed` lists the arguments argparse did not recognise, which are
        forwarded to the unittest runner.

  Raises:
    SystemExit: if no output file is given, or `--package` or `--api_name` is
      missing (argparse usage error).
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      help='create_python_api output files.')
  # --package and --api_name are dereferenced unconditionally: by
  # importlib.import_module below, and by the tests via FLAGS. Requiring them
  # turns an opaque crash (e.g. import_module(None)) into a clear usage error.
  parser.add_argument(
      '--package', type=str, required=True,
      help='Base package that imports modules containing the target tf_export '
           'decorators.')
  parser.add_argument(
      '--api_name', type=str, required=True,
      help='API name: tensorflow or estimator')
  return parser.parse_known_args(argv)


if __name__ == '__main__':
  # Module-level assignment: rebinds the global FLAGS read by the tests.
  FLAGS, unparsed = _parse_flags()

  # Import the base package so that the modules it pulls in are registered in
  # sys.modules; testDocstringModulesAreValidModules looks them up there.
  importlib.import_module(FLAGS.package)

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()
83  test.main()
84