# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for api_init_files.bzl and api_init_files_v1.bzl."""
import sys

# The unused imports are needed so that the python and lite modules are
# available in sys.modules
# pylint: disable=unused-import
from tensorflow import python as _tf_for_api_traversal
from tensorflow.dtensor import python as _dtensor_for_api_traversal
from tensorflow.lite.python import lite as _tflite_for_api_traversal
from tensorflow.lite.python.authoring import authoring
from tensorflow.python import modules_with_exports
from tensorflow.python.distribute import merge_call_interim
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.distribute.failure_handling import failure_handling
from tensorflow.python.framework import combinations
from tensorflow.python.framework import test_combinations
# pylint: enable=unused-import
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.util import tf_decorator

# Markers that delimit the generated list of paths inside the .bzl files.
_BEGIN_TAG = '# BEGIN GENERATED FILES'
_END_TAG = '# END GENERATED FILES'


def _get_module_from_symbol(symbol):
  """Returns the module part of a dotted API symbol name.

  Args:
    symbol: Dotted API name, e.g. 'foo.bar.baz'.

  Returns:
    Everything before the last '.', e.g. 'foo.bar' for 'foo.bar.baz', or ''
    when `symbol` contains no '.' (i.e. it is exported from the root module).
  """
  return symbol.rpartition('.')[0]


def _get_modules(package, attr_name, constants_attr_name):
  """Get set of TF API modules.

  Scans every module currently present in `sys.modules` whose name contains
  `package` and collects the API module names referenced by the export
  metadata found on its members.

  Args:
    package: Only modules whose `__name__` contains this string (a substring
      match, not a prefix match) are inspected.
    attr_name: Attribute set on TF symbols that contains API names.
    constants_attr_name: Attribute set on TF modules that contains
      API constant names.

  Returns:
    Set of TensorFlow API module names as dotted strings ('' stands for the
    root API module).
  """
  modules = set()
  # TODO(annarev): split up the logic in create_python_api.py so that
  # it can be reused in this test.
  # Iterate over a snapshot: the getattr() calls below can presumably trigger
  # lazy imports that add entries to sys.modules while we are looping.
  for module in list(sys.modules.values()):
    if (not module or not hasattr(module, '__name__') or
        package not in module.__name__):
      continue

    for module_contents_name in dir(module):
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      # Add modules to _tf_api_constants attribute.
      if module_contents_name == constants_attr_name:
        # Entries are 2-tuples whose first element lists the exported API
        # names; the second element is not needed here.
        for exports, _ in attr:
          modules.update(_get_module_from_symbol(export) for export in exports)
        continue

      # Add modules for _tf_api_names attribute. Checking __dict__ (rather
      # than hasattr) only counts names stored directly on `attr`, not ones
      # it would inherit through its class or base classes.
      if hasattr(attr, '__dict__') and attr_name in attr.__dict__:
        modules.update(
            _get_module_from_symbol(export)
            for export in getattr(attr, attr_name))
  return modules


def _get_files_set(path, start_tag, end_tag):
  """Get set of file paths from the given file.

  Args:
    path: Path to file. File at `path` is expected to contain a list of paths
      where entire list starts with `start_tag` and ends with `end_tag`. List
      must be comma-separated and each path entry must be surrounded by double
      quotes.
    start_tag: String that indicates start of path list.
    end_tag: String that indicates end of path list.

  Returns:
    Set of string paths (quotes and surrounding whitespace removed).

  Raises:
    ValueError: If `start_tag` is missing from the file, or `end_tag` does not
      appear after it.
  """
  with open(path, 'r', encoding='utf-8') as f:
    contents = f.read()

  start = contents.find(start_tag)
  if start == -1:
    raise ValueError('Could not find %r in %s.' % (start_tag, path))
  start += len(start_tag)
  # Look for the end tag only after the start tag so that a stray earlier
  # occurrence cannot produce an empty or inverted slice.
  end = contents.find(end_tag, start)
  if end == -1:
    raise ValueError(
        'Could not find %r after %r in %s.' % (end_tag, start_tag, path))

  # strip() removes the newline/indentation around each entry, so no fixed
  # number of characters needs to be skipped after the start tag.
  file_paths = (
      entry.strip().strip('"') for entry in contents[start:end].split(','))
  # Drop empty entries, e.g. the one produced by a trailing comma.
  return {file_path for file_path in file_paths if file_path}


def _module_to_paths(module):
  """Get all API __init__.py file paths for the given module.

  Args:
    module: Dotted module name to get file paths for ('' for the root module).

  Returns:
    List of paths required for the given module, outermost first. For example,
    module 'foo.bar' requires 'foo/__init__.py' and 'foo/bar/__init__.py',
    and the root module '' requires '__init__.py'.
  """
  segments = module.split('.')
  paths = []
  for depth in range(1, len(segments) + 1):
    # Joining segments with '/' is the dotted prefix with '.' -> '/'.
    package_dir = '/'.join(segments[:depth])
    if package_dir:
      paths.append('%s/__init__.py' % package_dir)
    else:
      paths.append('__init__.py')
  return paths


class OutputInitFilesTest(test.TestCase):
  """Test that verifies files that list paths for TensorFlow API."""

  def _validate_paths_for_modules(
      self, actual_paths, expected_paths, file_to_update_on_error):
    """Validates that actual_paths match expected_paths.

    Args:
      actual_paths: */__init__.py file paths listed in file_to_update_on_error.
      expected_paths: */__init__.py file paths that we need to create for
        TensorFlow API.
      file_to_update_on_error: File that contains list of */__init__.py files.
        We include it in error message printed if the file list needs to be
        updated.
    """
    self.assertTrue(actual_paths)
    self.assertTrue(expected_paths)

    # Surround paths with quotes so that they can be copy-pasted
    # from error messages as strings.
    missing_paths = sorted(
        '\'%s\'' % path for path in expected_paths - actual_paths)
    extra_paths = sorted(
        '\'%s\'' % path for path in actual_paths - expected_paths)

    self.assertFalse(
        missing_paths,
        'Please add %s to %s.' % (
            ',\n'.join(missing_paths), file_to_update_on_error))
    self.assertFalse(
        extra_paths,
        'Redundant paths, please remove %s in %s.' % (
            ',\n'.join(extra_paths), file_to_update_on_error))

  def _assert_init_files_match_api(
      self, attr_name, constants_attr_name, bzl_file_name):
    """Checks that a .bzl file lists exactly the __init__.py files needed.

    Args:
      attr_name: Attribute on TF symbols holding the API names to check.
      constants_attr_name: Attribute on TF modules holding API constant names.
      bzl_file_name: Data file (relative to this test) listing the generated
        __init__.py paths between the BEGIN/END GENERATED FILES markers.
    """
    modules = _get_modules('tensorflow', attr_name, constants_attr_name)
    file_path = resource_loader.get_path_to_datafile(bzl_file_name)
    paths = _get_files_set(file_path, _BEGIN_TAG, _END_TAG)
    module_paths = {
        f for module in modules for f in _module_to_paths(module)}
    self._validate_paths_for_modules(
        paths, module_paths, file_to_update_on_error=file_path)

  def test_V2_init_files(self):
    self._assert_init_files_match_api(
        '_tf_api_names', '_tf_api_constants', 'api_init_files.bzl')

  def test_V1_init_files(self):
    self._assert_init_files_match_api(
        '_tf_api_names_v1', '_tf_api_constants_v1', 'api_init_files_v1.bzl')


if __name__ == '__main__':
  test.main()