1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for api_init_files.bzl and api_init_files_v1.bzl.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import sys 21 22# The unused imports are needed so that the python and lite modules are 23# available in sys.modules 24# pylint: disable=unused-import 25from tensorflow import python as _tf_for_api_traversal 26from tensorflow.lite.python import lite as _tflite_for_api_traversal 27from tensorflow.python import modules_with_exports 28from tensorflow.python.distribute import multi_process_runner 29from tensorflow.python.distribute import multi_worker_test_base 30from tensorflow.python.distribute import parameter_server_strategy_v2 31from tensorflow.python.distribute.coordinator import cluster_coordinator 32from tensorflow.python.framework import combinations 33from tensorflow.python.framework import test_combinations 34# pylint: enable=unused-import 35from tensorflow.python.platform import resource_loader 36from tensorflow.python.platform import test 37from tensorflow.python.util import tf_decorator 38 39 40def _get_module_from_symbol(symbol): 41 if '.' not in symbol: 42 return '' 43 return '.'.join(symbol.split('.')[:-1]) 44 45 46def _get_modules(package, attr_name, constants_attr_name): 47 """Get list of TF API modules. 48 49 Args: 50 package: We only look at modules that contain package in the name. 51 attr_name: Attribute set on TF symbols that contains API names. 52 constants_attr_name: Attribute set on TF modules that contains 53 API constant names. 54 55 Returns: 56 Set of TensorFlow API modules. 57 """ 58 modules = set() 59 # TODO(annarev): split up the logic in create_python_api.py so that 60 # it can be reused in this test. 61 for module in list(sys.modules.values()): 62 if (not module or not hasattr(module, '__name__') or 63 package not in module.__name__): 64 continue 65 66 for module_contents_name in dir(module): 67 attr = getattr(module, module_contents_name) 68 _, attr = tf_decorator.unwrap(attr) 69 70 # Add modules to _tf_api_constants attribute. 71 if module_contents_name == constants_attr_name: 72 for exports, _ in attr: 73 modules.update( 74 [_get_module_from_symbol(export) for export in exports]) 75 continue 76 77 # Add modules for _tf_api_names attribute. 78 if (hasattr(attr, '__dict__') and attr_name in attr.__dict__): 79 modules.update([ 80 _get_module_from_symbol(export) 81 for export in getattr(attr, attr_name)]) 82 return modules 83 84 85def _get_files_set(path, start_tag, end_tag): 86 """Get set of file paths from the given file. 87 88 Args: 89 path: Path to file. File at `path` is expected to contain a list of paths 90 where entire list starts with `start_tag` and ends with `end_tag`. List 91 must be comma-separated and each path entry must be surrounded by double 92 quotes. 93 start_tag: String that indicates start of path list. 94 end_tag: String that indicates end of path list. 95 96 Returns: 97 List of string paths. 98 """ 99 with open(path, 'r') as f: 100 contents = f.read() 101 start = contents.find(start_tag) + len(start_tag) + 1 102 end = contents.find(end_tag) 103 contents = contents[start:end] 104 file_paths = [ 105 file_path.strip().strip('"') for file_path in contents.split(',')] 106 return set(file_path for file_path in file_paths if file_path) 107 108 109def _module_to_paths(module): 110 """Get all API __init__.py file paths for the given module. 111 112 Args: 113 module: Module to get file paths for. 114 115 Returns: 116 List of paths for the given module. For e.g. module foo.bar 117 requires 'foo/__init__.py' and 'foo/bar/__init__.py'. 118 """ 119 submodules = [] 120 module_segments = module.split('.') 121 for i in range(len(module_segments)): 122 submodules.append('.'.join(module_segments[:i+1])) 123 paths = [] 124 for submodule in submodules: 125 if not submodule: 126 paths.append('__init__.py') 127 continue 128 paths.append('%s/__init__.py' % (submodule.replace('.', '/'))) 129 return paths 130 131 132class OutputInitFilesTest(test.TestCase): 133 """Test that verifies files that list paths for TensorFlow API.""" 134 135 def _validate_paths_for_modules( 136 self, actual_paths, expected_paths, file_to_update_on_error): 137 """Validates that actual_paths match expected_paths. 138 139 Args: 140 actual_paths: */__init__.py file paths listed in file_to_update_on_error. 141 expected_paths: */__init__.py file paths that we need to create for 142 TensorFlow API. 143 file_to_update_on_error: File that contains list of */__init__.py files. 144 We include it in error message printed if the file list needs to be 145 updated. 146 """ 147 self.assertTrue(actual_paths) 148 self.assertTrue(expected_paths) 149 missing_paths = expected_paths - actual_paths 150 extra_paths = actual_paths - expected_paths 151 152 # Surround paths with quotes so that they can be copy-pasted 153 # from error messages as strings. 154 missing_paths = ['\'%s\'' % path for path in missing_paths] 155 extra_paths = ['\'%s\'' % path for path in extra_paths] 156 157 self.assertFalse( 158 missing_paths, 159 'Please add %s to %s.' % ( 160 ',\n'.join(sorted(missing_paths)), file_to_update_on_error)) 161 self.assertFalse( 162 extra_paths, 163 'Redundant paths, please remove %s in %s.' % ( 164 ',\n'.join(sorted(extra_paths)), file_to_update_on_error)) 165 166 def test_V2_init_files(self): 167 modules = _get_modules( 168 'tensorflow', '_tf_api_names', '_tf_api_constants') 169 file_path = resource_loader.get_path_to_datafile( 170 'api_init_files.bzl') 171 paths = _get_files_set( 172 file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES') 173 module_paths = set( 174 f for module in modules for f in _module_to_paths(module)) 175 self._validate_paths_for_modules( 176 paths, module_paths, file_to_update_on_error=file_path) 177 178 def test_V1_init_files(self): 179 modules = _get_modules( 180 'tensorflow', '_tf_api_names_v1', '_tf_api_constants_v1') 181 file_path = resource_loader.get_path_to_datafile( 182 'api_init_files_v1.bzl') 183 paths = _get_files_set( 184 file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES') 185 module_paths = set( 186 f for module in modules for f in _module_to_paths(module)) 187 self._validate_paths_for_modules( 188 paths, module_paths, file_to_update_on_error=file_path) 189 190 191if __name__ == '__main__': 192 test.main() 193