1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for api_init_files.bzl and api_init_files_v1.bzl.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import sys 21 22# The unused imports are needed so that the python and lite modules are 23# available in sys.modules 24# pylint: disable=unused-import 25from tensorflow import python as _tf_for_api_traversal 26from tensorflow.lite.python import lite as _tflite_for_api_traversal 27# pylint: enable=unused-import 28from tensorflow.python.platform import test 29from tensorflow.python.util import tf_decorator 30 31 32def _get_module_from_symbol(symbol): 33 if '.' not in symbol: 34 return '' 35 return '.'.join(symbol.split('.')[:-1]) 36 37 38def _get_modules(package, attr_name, constants_attr_name): 39 """Get list of TF API modules. 40 41 Args: 42 package: We only look at modules that contain package in the name. 43 attr_name: Attribute set on TF symbols that contains API names. 44 constants_attr_name: Attribute set on TF modules that contains 45 API constant names. 46 47 Returns: 48 Set of TensorFlow API modules. 49 """ 50 modules = set() 51 # TODO(annarev): split up the logic in create_python_api.py so that 52 # it can be reused in this test. 53 for module in list(sys.modules.values()): 54 if (not module or not hasattr(module, '__name__') or 55 package not in module.__name__): 56 continue 57 58 for module_contents_name in dir(module): 59 attr = getattr(module, module_contents_name) 60 _, attr = tf_decorator.unwrap(attr) 61 62 # Add modules to _tf_api_constants attribute. 63 if module_contents_name == constants_attr_name: 64 for exports, _ in attr: 65 modules.update( 66 [_get_module_from_symbol(export) for export in exports]) 67 continue 68 69 # Add modules for _tf_api_names attribute. 70 if (hasattr(attr, '__dict__') and attr_name in attr.__dict__): 71 modules.update([ 72 _get_module_from_symbol(export) 73 for export in getattr(attr, attr_name)]) 74 return modules 75 76 77def _get_files_set(path, start_tag, end_tag): 78 """Get set of file paths from the given file. 79 80 Args: 81 path: Path to file. File at `path` is expected to contain a list of paths 82 where entire list starts with `start_tag` and ends with `end_tag`. List 83 must be comma-separated and each path entry must be surrounded by double 84 quotes. 85 start_tag: String that indicates start of path list. 86 end_tag: String that indicates end of path list. 87 88 Returns: 89 List of string paths. 90 """ 91 with open(path, 'r') as f: 92 contents = f.read() 93 start = contents.find(start_tag) + len(start_tag) + 1 94 end = contents.find(end_tag) 95 contents = contents[start:end] 96 file_paths = [ 97 file_path.strip().strip('"') for file_path in contents.split(',')] 98 return set(file_path for file_path in file_paths if file_path) 99 100 101def _module_to_paths(module): 102 """Get all API __init__.py file paths for the given module. 103 104 Args: 105 module: Module to get file paths for. 106 107 Returns: 108 List of paths for the given module. For e.g. module foo.bar 109 requires 'foo/__init__.py' and 'foo/bar/__init__.py'. 110 """ 111 submodules = [] 112 module_segments = module.split('.') 113 for i in range(len(module_segments)): 114 submodules.append('.'.join(module_segments[:i+1])) 115 paths = [] 116 for submodule in submodules: 117 if not submodule: 118 paths.append('__init__.py') 119 continue 120 paths.append('%s/__init__.py' % (submodule.replace('.', '/'))) 121 return paths 122 123 124class OutputInitFilesTest(test.TestCase): 125 """Test that verifies files that list paths for TensorFlow API.""" 126 127 def _validate_paths_for_modules( 128 self, actual_paths, expected_paths, file_to_update_on_error): 129 """Validates that actual_paths match expected_paths. 130 131 Args: 132 actual_paths: */__init__.py file paths listed in file_to_update_on_error. 133 expected_paths: */__init__.py file paths that we need to create for 134 TensorFlow API. 135 file_to_update_on_error: File that contains list of */__init__.py files. 136 We include it in error message printed if the file list needs to be 137 updated. 138 """ 139 self.assertTrue(actual_paths) 140 self.assertTrue(expected_paths) 141 missing_paths = expected_paths - actual_paths 142 extra_paths = actual_paths - expected_paths 143 144 # Surround paths with quotes so that they can be copy-pasted 145 # from error messages as strings. 146 missing_paths = ['\'%s\'' % path for path in missing_paths] 147 extra_paths = ['\'%s\'' % path for path in extra_paths] 148 149 self.assertFalse( 150 missing_paths, 151 'Please add %s to %s.' % ( 152 ',\n'.join(sorted(missing_paths)), file_to_update_on_error)) 153 self.assertFalse( 154 extra_paths, 155 'Redundant paths, please remove %s in %s.' % ( 156 ',\n'.join(sorted(extra_paths)), file_to_update_on_error)) 157 158 def test_V2_init_files(self): 159 modules = _get_modules( 160 'tensorflow', '_tf_api_names', '_tf_api_constants') 161 file_path = ( 162 'tensorflow/python/tools/api/generator/api_init_files.bzl') 163 paths = _get_files_set( 164 file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES') 165 module_paths = set( 166 f for module in modules for f in _module_to_paths(module)) 167 self._validate_paths_for_modules( 168 paths, module_paths, file_to_update_on_error=file_path) 169 170 def test_V1_init_files(self): 171 modules = _get_modules( 172 'tensorflow', '_tf_api_names_v1', '_tf_api_constants_v1') 173 file_path = ( 174 'tensorflow/python/tools/api/generator/' 175 'api_init_files_v1.bzl') 176 paths = _get_files_set( 177 file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES') 178 module_paths = set( 179 f for module in modules for f in _module_to_paths(module)) 180 self._validate_paths_for_modules( 181 paths, module_paths, file_to_update_on_error=file_path) 182 183 184if __name__ == '__main__': 185 test.main() 186