• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for api_init_files.bzl and api_init_files_v1.bzl."""
16import sys
17
# The unused imports below are needed so that the modules they load (python,
# dtensor, lite, distribute, etc.) are available in sys.modules for
# _get_modules to scan.
20# pylint: disable=unused-import
21from tensorflow import python as _tf_for_api_traversal
22from tensorflow.dtensor import python as _dtensor_for_api_traversal
23from tensorflow.lite.python import lite as _tflite_for_api_traversal
24from tensorflow.lite.python.authoring import authoring
25from tensorflow.python import modules_with_exports
26from tensorflow.python.distribute import merge_call_interim
27from tensorflow.python.distribute import multi_process_runner
28from tensorflow.python.distribute import multi_worker_test_base
29from tensorflow.python.distribute import parameter_server_strategy_v2
30from tensorflow.python.distribute import sharded_variable
31from tensorflow.python.distribute.coordinator import cluster_coordinator
32from tensorflow.python.distribute.failure_handling import failure_handling
33from tensorflow.python.framework import combinations
34from tensorflow.python.framework import test_combinations
35# pylint: enable=unused-import
36from tensorflow.python.platform import resource_loader
37from tensorflow.python.platform import test
38from tensorflow.python.util import tf_decorator
39
40
41def _get_module_from_symbol(symbol):
42  if '.' not in symbol:
43    return ''
44  return '.'.join(symbol.split('.')[:-1])
45
46
def _get_modules(package, attr_name, constants_attr_name):
  """Collects the TF API module names exported by currently loaded modules.

  Every module in `sys.modules` whose `__name__` contains `package` is
  scanned. For each exported symbol name found, its parent module (as
  computed by `_get_module_from_symbol`) is recorded.

  Args:
    package: Substring that a module's `__name__` must contain for the module
      to be scanned.
    attr_name: Name of the attribute set on TF symbols that holds their API
      names.
    constants_attr_name: Name of the module attribute that holds API constant
      exports.

  Returns:
    Set of TensorFlow API module names.
  """
  api_modules = set()
  # TODO(annarev): split up the logic in create_python_api.py so that
  #   it can be reused in this test.
  # Iterate over a copy of the values: sys.modules could change size during
  # the scan (e.g. if an attribute access triggers an import).
  for loaded in list(sys.modules.values()):
    if not loaded:
      continue
    if not hasattr(loaded, '__name__'):
      continue
    if package not in loaded.__name__:
      continue

    for member_name in dir(loaded):
      _, member = tf_decorator.unwrap(getattr(loaded, member_name))

      if member_name == constants_attr_name:
        # Each constants entry is unpacked as a pair; only its first item
        # (the exported names) is used here.
        for exported_names, _ in member:
          api_modules.update(
              _get_module_from_symbol(name) for name in exported_names)
        continue

      # Symbols record their exported API names in their own `attr_name`.
      if hasattr(member, '__dict__') and attr_name in member.__dict__:
        api_modules.update(
            _get_module_from_symbol(name)
            for name in getattr(member, attr_name))
  return api_modules
84
85
86def _get_files_set(path, start_tag, end_tag):
87  """Get set of file paths from the given file.
88
89  Args:
90    path: Path to file. File at `path` is expected to contain a list of paths
91      where entire list starts with `start_tag` and ends with `end_tag`. List
92      must be comma-separated and each path entry must be surrounded by double
93      quotes.
94    start_tag: String that indicates start of path list.
95    end_tag: String that indicates end of path list.
96
97  Returns:
98    List of string paths.
99  """
100  with open(path, 'r') as f:
101    contents = f.read()
102    start = contents.find(start_tag) + len(start_tag) + 1
103    end = contents.find(end_tag)
104    contents = contents[start:end]
105    file_paths = [
106        file_path.strip().strip('"') for file_path in contents.split(',')]
107    return set(file_path for file_path in file_paths if file_path)
108
109
110def _module_to_paths(module):
111  """Get all API __init__.py file paths for the given module.
112
113  Args:
114    module: Module to get file paths for.
115
116  Returns:
117    List of paths for the given module. For e.g. module foo.bar
118    requires 'foo/__init__.py' and 'foo/bar/__init__.py'.
119  """
120  submodules = []
121  module_segments = module.split('.')
122  for i in range(len(module_segments)):
123    submodules.append('.'.join(module_segments[:i+1]))
124  paths = []
125  for submodule in submodules:
126    if not submodule:
127      paths.append('__init__.py')
128      continue
129    paths.append('%s/__init__.py' % (submodule.replace('.', '/')))
130  return paths
131
132
class OutputInitFilesTest(test.TestCase):
  """Test that verifies files that list paths for TensorFlow API."""

  def _validate_paths_for_modules(
      self, actual_paths, expected_paths, file_to_update_on_error):
    """Validates that actual_paths match expected_paths.

    Args:
      actual_paths: Set of */__init__.py file paths listed in
        file_to_update_on_error.
      expected_paths: Set of */__init__.py file paths that we need to create
        for TensorFlow API.
      file_to_update_on_error: File that contains list of */__init__.py files.
        We include it in error message printed if the file list needs to be
        updated.
    """
    # Guard against a vacuous pass on empty inputs, with a message that
    # says which side was empty.
    self.assertTrue(
        actual_paths,
        'No */__init__.py paths were read from %s.' % file_to_update_on_error)
    self.assertTrue(
        expected_paths,
        'No expected */__init__.py paths were computed from TensorFlow API '
        'modules.')
    missing_paths = expected_paths - actual_paths
    extra_paths = actual_paths - expected_paths

    # Surround paths with quotes so that they can be copy-pasted
    # from error messages as strings.
    missing_paths = ['\'%s\'' % path for path in missing_paths]
    extra_paths = ['\'%s\'' % path for path in extra_paths]

    self.assertFalse(
        missing_paths,
        'Please add %s to %s.' % (
            ',\n'.join(sorted(missing_paths)), file_to_update_on_error))
    self.assertFalse(
        extra_paths,
        'Redundant paths, please remove %s in %s.' % (
            ',\n'.join(sorted(extra_paths)), file_to_update_on_error))

  def _check_init_files(self, attr_name, constants_attr_name, bzl_file_name):
    """Checks that a .bzl data file lists exactly the API __init__.py files.

    Args:
      attr_name: Attribute set on TF symbols that contains API names.
      constants_attr_name: Attribute set on TF modules that contains API
        constant names.
      bzl_file_name: Data file name, resolved with
        `resource_loader.get_path_to_datafile`, whose generated-files section
        lists the __init__.py paths.
    """
    modules = _get_modules('tensorflow', attr_name, constants_attr_name)
    file_path = resource_loader.get_path_to_datafile(bzl_file_name)
    paths = _get_files_set(
        file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES')
    module_paths = {
        f for module in modules for f in _module_to_paths(module)}
    self._validate_paths_for_modules(
        paths, module_paths, file_to_update_on_error=file_path)

  def test_V2_init_files(self):
    self._check_init_files(
        '_tf_api_names', '_tf_api_constants', 'api_init_files.bzl')

  def test_V1_init_files(self):
    self._check_init_files(
        '_tf_api_names_v1', '_tf_api_constants_v1', 'api_init_files_v1.bzl')
190
191
if __name__ == '__main__':
  # Run the test cases via tensorflow.python.platform.test's entry point.
  test.main()
194