• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for api_init_files.bzl and api_init_files_v1.bzl."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import sys
21
22# The unused imports are needed so that the python and lite modules are
23# available in sys.modules
24# pylint: disable=unused-import
25from tensorflow import python as _tf_for_api_traversal
26from tensorflow.lite.python import lite as _tflite_for_api_traversal
27from tensorflow.python import modules_with_exports
28from tensorflow.python.distribute import multi_process_runner
29from tensorflow.python.distribute import multi_worker_test_base
30from tensorflow.python.distribute import parameter_server_strategy_v2
31from tensorflow.python.distribute.coordinator import cluster_coordinator
32from tensorflow.python.framework import combinations
33from tensorflow.python.framework import test_combinations
34# pylint: enable=unused-import
35from tensorflow.python.platform import resource_loader
36from tensorflow.python.platform import test
37from tensorflow.python.util import tf_decorator
38
39
40def _get_module_from_symbol(symbol):
41  if '.' not in symbol:
42    return ''
43  return '.'.join(symbol.split('.')[:-1])
44
45
def _get_modules(package, attr_name, constants_attr_name):
  """Collects the names of TF API modules referenced by loaded modules.

  Args:
    package: Only modules whose __name__ contains this substring are scanned.
    attr_name: Attribute set on TF symbols that contains API names.
    constants_attr_name: Attribute set on TF modules that contains
      API constant names.

  Returns:
    Set of dotted API module names ('' for a symbol exported at the root).
  """
  api_modules = set()
  # TODO(annarev): split up the logic in create_python_api.py so that
  #   it can be reused in this test.
  # Iterate over a list copy of sys.modules in case it is mutated while
  # attributes are being inspected below.
  for candidate in list(sys.modules.values()):
    if not candidate or not hasattr(candidate, '__name__'):
      continue
    if package not in candidate.__name__:
      continue

    for member_name in dir(candidate):
      _, member = tf_decorator.unwrap(getattr(candidate, member_name))

      if member_name == constants_attr_name:
        # Module-level constants: pairs whose first element lists the
        # exported API names of the constant.
        for exported_names, _ in member:
          api_modules.update(
              {_get_module_from_symbol(name) for name in exported_names})
      elif hasattr(member, '__dict__') and attr_name in member.__dict__:
        # A symbol tagged directly with its exported API names.
        api_modules.update(
            {_get_module_from_symbol(name)
             for name in getattr(member, attr_name)})
  return api_modules
83
84
85def _get_files_set(path, start_tag, end_tag):
86  """Get set of file paths from the given file.
87
88  Args:
89    path: Path to file. File at `path` is expected to contain a list of paths
90      where entire list starts with `start_tag` and ends with `end_tag`. List
91      must be comma-separated and each path entry must be surrounded by double
92      quotes.
93    start_tag: String that indicates start of path list.
94    end_tag: String that indicates end of path list.
95
96  Returns:
97    List of string paths.
98  """
99  with open(path, 'r') as f:
100    contents = f.read()
101    start = contents.find(start_tag) + len(start_tag) + 1
102    end = contents.find(end_tag)
103    contents = contents[start:end]
104    file_paths = [
105        file_path.strip().strip('"') for file_path in contents.split(',')]
106    return set(file_path for file_path in file_paths if file_path)
107
108
109def _module_to_paths(module):
110  """Get all API __init__.py file paths for the given module.
111
112  Args:
113    module: Module to get file paths for.
114
115  Returns:
116    List of paths for the given module. For e.g. module foo.bar
117    requires 'foo/__init__.py' and 'foo/bar/__init__.py'.
118  """
119  submodules = []
120  module_segments = module.split('.')
121  for i in range(len(module_segments)):
122    submodules.append('.'.join(module_segments[:i+1]))
123  paths = []
124  for submodule in submodules:
125    if not submodule:
126      paths.append('__init__.py')
127      continue
128    paths.append('%s/__init__.py' % (submodule.replace('.', '/')))
129  return paths
130
131
class OutputInitFilesTest(test.TestCase):
  """Test that verifies files that list paths for TensorFlow API."""

  def _validate_paths_for_modules(
      self, actual_paths, expected_paths, file_to_update_on_error):
    """Validates that actual_paths match expected_paths.

    Args:
      actual_paths: */__init__.py file paths listed in file_to_update_on_error.
      expected_paths: */__init__.py file paths that we need to create for
        TensorFlow API.
      file_to_update_on_error: File that contains list of */__init__.py files.
        We include it in error message printed if the file list needs to be
        updated.
    """
    self.assertTrue(
        actual_paths, 'No paths were read from %s.' % file_to_update_on_error)
    self.assertTrue(expected_paths, 'No TensorFlow API modules were found.')
    missing_paths = expected_paths - actual_paths
    extra_paths = actual_paths - expected_paths

    # Surround paths with quotes so that they can be copy-pasted
    # from error messages as strings.
    missing_paths = ['\'%s\'' % path for path in missing_paths]
    extra_paths = ['\'%s\'' % path for path in extra_paths]

    self.assertFalse(
        missing_paths,
        'Please add %s to %s.' % (
            ',\n'.join(sorted(missing_paths)), file_to_update_on_error))
    self.assertFalse(
        extra_paths,
        'Redundant paths, please remove %s in %s.' % (
            ',\n'.join(sorted(extra_paths)), file_to_update_on_error))

  def _check_init_files(self, attr_name, constants_attr_name, bzl_file_name):
    """Checks that a generated .bzl file lists exactly the needed paths.

    Shared by the V1 and V2 tests, which differ only in the attribute names
    they read and the data file they compare against.

    Args:
      attr_name: Attribute set on TF symbols that contains API names.
      constants_attr_name: Attribute set on TF modules that contains
        API constant names.
      bzl_file_name: Name of the data file, resolved with
        resource_loader.get_path_to_datafile, that contains the generated
        list of */__init__.py paths.
    """
    modules = _get_modules('tensorflow', attr_name, constants_attr_name)
    file_path = resource_loader.get_path_to_datafile(bzl_file_name)
    paths = _get_files_set(
        file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES')
    module_paths = set(
        f for module in modules for f in _module_to_paths(module))
    self._validate_paths_for_modules(
        paths, module_paths, file_to_update_on_error=file_path)

  def test_V2_init_files(self):
    self._check_init_files(
        '_tf_api_names', '_tf_api_constants', 'api_init_files.bzl')

  def test_V1_init_files(self):
    self._check_init_files(
        '_tf_api_names_v1', '_tf_api_constants_v1', 'api_init_files_v1.bzl')
189
190
if __name__ == '__main__':
  # Script entry point: hand control to the imported `test` module's main().
  test.main()
193