• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for api_init_files.bzl and api_init_files_v1.bzl."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import sys
21
22# The unused imports are needed so that the python and lite modules are
23# available in sys.modules
24# pylint: disable=unused-import
25from tensorflow import python as _tf_for_api_traversal
26from tensorflow.lite.python import lite as _tflite_for_api_traversal
27# pylint: enable=unused-import
28from tensorflow.python.platform import test
29from tensorflow.python.util import tf_decorator
30
31
32def _get_module_from_symbol(symbol):
33  if '.' not in symbol:
34    return ''
35  return '.'.join(symbol.split('.')[:-1])
36
37
def _get_modules(package, attr_name, constants_attr_name):
  """Collect the TF API module names referenced by exported symbols.

  Scans every loaded module whose name contains `package`. For each
  attribute of those modules, the exported API names are read from either
  the module-level constants attribute or the per-symbol API-names
  attribute, and the module part of each exported name is recorded.

  Args:
    package: Substring that a loaded module's `__name__` must contain for
      the module to be scanned.
    attr_name: Attribute set on TF symbols that contains API names.
    constants_attr_name: Attribute set on TF modules that contains
      API constant names.

  Returns:
    Set of TensorFlow API module names (dotted strings, '' for the root).
  """
  api_modules = set()
  # TODO(annarev): split up the logic in create_python_api.py so that
  #   it can be reused in this test.
  # Iterate over a snapshot of sys.modules in case attribute access below
  # causes new modules to be registered while we loop.
  for loaded_module in list(sys.modules.values()):
    if not loaded_module:
      continue
    if not hasattr(loaded_module, '__name__'):
      continue
    if package not in loaded_module.__name__:
      continue

    for member_name in dir(loaded_module):
      _, member = tf_decorator.unwrap(getattr(loaded_module, member_name))

      if member_name == constants_attr_name:
        # The constants attribute is iterated as (exported_names, _) pairs.
        for exported_names, _ in member:
          api_modules.update(
              _get_module_from_symbol(exported_name)
              for exported_name in exported_names)
      elif hasattr(member, '__dict__') and attr_name in member.__dict__:
        # Only symbols that set `attr_name` on themselves (not inherited)
        # are counted.
        api_modules.update(
            _get_module_from_symbol(exported_name)
            for exported_name in getattr(member, attr_name))
  return api_modules
75
76
77def _get_files_set(path, start_tag, end_tag):
78  """Get set of file paths from the given file.
79
80  Args:
81    path: Path to file. File at `path` is expected to contain a list of paths
82      where entire list starts with `start_tag` and ends with `end_tag`. List
83      must be comma-separated and each path entry must be surrounded by double
84      quotes.
85    start_tag: String that indicates start of path list.
86    end_tag: String that indicates end of path list.
87
88  Returns:
89    List of string paths.
90  """
91  with open(path, 'r') as f:
92    contents = f.read()
93    start = contents.find(start_tag) + len(start_tag) + 1
94    end = contents.find(end_tag)
95    contents = contents[start:end]
96    file_paths = [
97        file_path.strip().strip('"') for file_path in contents.split(',')]
98    return set(file_path for file_path in file_paths if file_path)
99
100
101def _module_to_paths(module):
102  """Get all API __init__.py file paths for the given module.
103
104  Args:
105    module: Module to get file paths for.
106
107  Returns:
108    List of paths for the given module. For e.g. module foo.bar
109    requires 'foo/__init__.py' and 'foo/bar/__init__.py'.
110  """
111  submodules = []
112  module_segments = module.split('.')
113  for i in range(len(module_segments)):
114    submodules.append('.'.join(module_segments[:i+1]))
115  paths = []
116  for submodule in submodules:
117    if not submodule:
118      paths.append('__init__.py')
119      continue
120    paths.append('%s/__init__.py' % (submodule.replace('.', '/')))
121  return paths
122
123
class OutputInitFilesTest(test.TestCase):
  """Test that verifies files that list paths for TensorFlow API."""

  def _validate_paths_for_modules(
      self, actual_paths, expected_paths, file_to_update_on_error):
    """Validates that actual_paths match expected_paths.

    Args:
      actual_paths: */__init__.py file paths listed in file_to_update_on_error.
      expected_paths: */__init__.py file paths that we need to create for
        TensorFlow API.
      file_to_update_on_error: File that contains list of */__init__.py files.
        We include it in error message printed if the file list needs to be
        updated.
    """
    # Empty sets usually mean the marker tags were not found or no API
    # modules were loaded; say so instead of a bare "set() is not true".
    self.assertTrue(
        actual_paths,
        'No */__init__.py paths found in %s.' % file_to_update_on_error)
    self.assertTrue(
        expected_paths,
        'No TensorFlow API modules found to validate %s against.' %
        file_to_update_on_error)

    # Surround paths with quotes so that they can be copy-pasted
    # from error messages as strings.
    missing_paths = sorted(
        '\'%s\'' % path for path in expected_paths - actual_paths)
    extra_paths = sorted(
        '\'%s\'' % path for path in actual_paths - expected_paths)

    self.assertFalse(
        missing_paths,
        'Please add %s to %s.' % (
            ',\n'.join(missing_paths), file_to_update_on_error))
    self.assertFalse(
        extra_paths,
        'Redundant paths, please remove %s in %s.' % (
            ',\n'.join(extra_paths), file_to_update_on_error))

  def _check_init_files(self, attr_name, constants_attr_name, file_path):
    """Checks that `file_path` lists exactly the API __init__.py files needed.

    Args:
      attr_name: Attribute set on TF symbols that contains API names.
      constants_attr_name: Attribute set on TF modules that contains
        API constant names.
      file_path: .bzl file whose generated-files section is validated.
    """
    modules = _get_modules('tensorflow', attr_name, constants_attr_name)
    paths = _get_files_set(
        file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES')
    module_paths = set(
        f for module in modules for f in _module_to_paths(module))
    self._validate_paths_for_modules(
        paths, module_paths, file_to_update_on_error=file_path)

  def test_V2_init_files(self):
    self._check_init_files(
        '_tf_api_names', '_tf_api_constants',
        'tensorflow/python/tools/api/generator/api_init_files.bzl')

  def test_V1_init_files(self):
    self._check_init_files(
        '_tf_api_names_v1', '_tf_api_constants_v1',
        'tensorflow/python/tools/api/generator/api_init_files_v1.bzl')
182
183
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
186