1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Generates and prints out imports and constants for new TensorFlow python api. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import collections 23import importlib 24import os 25import sys 26 27from tensorflow.python.tools.api.generator import doc_srcs 28from tensorflow.python.util import tf_decorator 29from tensorflow.python.util import tf_export 30 31API_ATTRS = tf_export.API_ATTRS 32API_ATTRS_V1 = tf_export.API_ATTRS_V1 33 34_API_VERSIONS = [1, 2] 35_COMPAT_MODULE_TEMPLATE = 'compat.v%d' 36_DEFAULT_PACKAGE = 'tensorflow.python' 37_GENFILES_DIR_SUFFIX = 'genfiles/' 38_SYMBOLS_TO_SKIP_EXPLICITLY = { 39 # Overrides __getattr__, so that unwrapping tf_decorator 40 # would have side effects. 41 'tensorflow.python.platform.flags.FLAGS' 42} 43_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit. 44# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script. 45\"\"\"%s 46\"\"\" 47 48from __future__ import print_function as _print_function 49 50""" 51_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n' 52 53 54class SymbolExposedTwiceError(Exception): 55 """Raised when different symbols are exported with the same name.""" 56 pass 57 58 59def format_import(source_module_name, source_name, dest_name): 60 """Formats import statement. 61 62 Args: 63 source_module_name: (string) Source module to import from. 64 source_name: (string) Source symbol name to import. 65 dest_name: (string) Destination alias name. 66 67 Returns: 68 An import statement string. 69 """ 70 if source_module_name: 71 if source_name == dest_name: 72 return 'from %s import %s' % (source_module_name, source_name) 73 else: 74 return 'from %s import %s as %s' % ( 75 source_module_name, source_name, dest_name) 76 else: 77 if source_name == dest_name: 78 return 'import %s' % source_name 79 else: 80 return 'import %s as %s' % (source_name, dest_name) 81 82 83class _ModuleInitCodeBuilder(object): 84 """Builds a map from module name to imports included in that module.""" 85 86 def __init__(self, output_package): 87 self._output_package = output_package 88 self._module_imports = collections.defaultdict( 89 lambda: collections.defaultdict(set)) 90 self._dest_import_to_id = collections.defaultdict(int) 91 # Names that start with underscore in the root module. 92 self._underscore_names_in_root = [] 93 94 def add_import( 95 self, symbol_id, dest_module_name, source_module_name, source_name, 96 dest_name): 97 """Adds this import to module_imports. 98 99 Args: 100 symbol_id: (number) Unique identifier of the symbol to import. 101 dest_module_name: (string) Module name to add import to. 102 source_module_name: (string) Module to import from. 103 source_name: (string) Name of the symbol to import. 104 dest_name: (string) Import the symbol using this name. 105 106 Raises: 107 SymbolExposedTwiceError: Raised when an import with the same 108 dest_name has already been added to dest_module_name. 109 """ 110 import_str = format_import(source_module_name, source_name, dest_name) 111 112 # Check if we are trying to expose two different symbols with same name. 113 full_api_name = dest_name 114 if dest_module_name: 115 full_api_name = dest_module_name + '.' + full_api_name 116 if (full_api_name in self._dest_import_to_id and 117 symbol_id != self._dest_import_to_id[full_api_name] and 118 symbol_id != -1): 119 raise SymbolExposedTwiceError( 120 'Trying to export multiple symbols with same name: %s.' % 121 full_api_name) 122 self._dest_import_to_id[full_api_name] = symbol_id 123 124 if not dest_module_name and dest_name.startswith('_'): 125 self._underscore_names_in_root.append(dest_name) 126 127 # The same symbol can be available in multiple modules. 128 # We store all possible ways of importing this symbol and later pick just 129 # one. 130 self._module_imports[dest_module_name][full_api_name].add(import_str) 131 132 def _import_submodules(self): 133 """Add imports for all destination modules in self._module_imports.""" 134 # Import all required modules in their parent modules. 135 # For e.g. if we import 'foo.bar.Value'. Then, we also 136 # import 'bar' in 'foo'. 137 imported_modules = set(self._module_imports.keys()) 138 for module in imported_modules: 139 if not module: 140 continue 141 module_split = module.split('.') 142 parent_module = '' # we import submodules in their parent_module 143 144 for submodule_index in range(len(module_split)): 145 if submodule_index > 0: 146 submodule = module_split[submodule_index-1] 147 parent_module += '.' + submodule if parent_module else submodule 148 import_from = self._output_package 149 if submodule_index > 0: 150 import_from += '.' + '.'.join(module_split[:submodule_index]) 151 self.add_import( 152 -1, parent_module, import_from, 153 module_split[submodule_index], module_split[submodule_index]) 154 155 def build(self): 156 """Get a map from destination module to __init__.py code for that module. 157 158 Returns: 159 A dictionary where 160 key: (string) destination module (for e.g. tf or tf.consts). 161 value: (string) text that should be in __init__.py files for 162 corresponding modules. 163 """ 164 self._import_submodules() 165 module_text_map = {} 166 for dest_module, dest_name_to_imports in self._module_imports.items(): 167 # Sort all possible imports for a symbol and pick the first one. 168 imports_list = [ 169 sorted(imports)[0] 170 for _, imports in dest_name_to_imports.items()] 171 module_text_map[dest_module] = '\n'.join(sorted(imports_list)) 172 173 # Expose exported symbols with underscores in root module 174 # since we import from it using * import. 175 underscore_names_str = ', '.join( 176 '\'%s\'' % name for name in self._underscore_names_in_root) 177 # We will always generate a root __init__.py file to let us handle * 178 # imports consistently. Be sure to have a root __init__.py file listed in 179 # the script outputs. 180 module_text_map[''] = module_text_map.get('', '') + ''' 181_names_with_underscore = [%s] 182__all__ = [_s for _s in dir() if not _s.startswith('_')] 183__all__.extend([_s for _s in _names_with_underscore]) 184''' % underscore_names_str 185 186 return module_text_map 187 188 189def _get_name_and_module(full_name): 190 """Split full_name into module and short name. 191 192 Args: 193 full_name: Full name of symbol that includes module. 194 195 Returns: 196 Full module name and short symbol name. 197 """ 198 name_segments = full_name.split('.') 199 return '.'.join(name_segments[:-1]), name_segments[-1] 200 201 202def _join_modules(module1, module2): 203 """Concatenate 2 module components. 204 205 Args: 206 module1: First module to join. 207 module2: Second module to join. 208 209 Returns: 210 Given two modules aaa.bbb and ccc.ddd, returns a joined 211 module aaa.bbb.ccc.ddd. 212 """ 213 if not module1: 214 return module2 215 if not module2: 216 return module1 217 return '%s.%s' % (module1, module2) 218 219 220def add_imports_for_symbol( 221 module_code_builder, 222 symbol, 223 source_module_name, 224 source_name, 225 api_name, 226 api_version, 227 output_module_prefix=''): 228 """Add imports for the given symbol to `module_code_builder`. 229 230 Args: 231 module_code_builder: `_ModuleInitCodeBuilder` instance. 232 symbol: A symbol. 233 source_module_name: Module that we can import the symbol from. 234 source_name: Name we can import the symbol with. 235 api_name: API name. Currently, must be either `tensorflow` or `estimator`. 236 api_version: API version. 237 output_module_prefix: Prefix to prepend to destination module. 238 """ 239 if api_version == 1: 240 names_attr = API_ATTRS_V1[api_name].names 241 constants_attr = API_ATTRS_V1[api_name].constants 242 else: 243 names_attr = API_ATTRS[api_name].names 244 constants_attr = API_ATTRS[api_name].constants 245 246 # If symbol is _tf_api_constants attribute, then add the constants. 247 if source_name == constants_attr: 248 for exports, name in symbol: 249 for export in exports: 250 dest_module, dest_name = _get_name_and_module(export) 251 dest_module = _join_modules(output_module_prefix, dest_module) 252 module_code_builder.add_import( 253 -1, dest_module, source_module_name, name, dest_name) 254 255 # If symbol has _tf_api_names attribute, then add import for it. 256 if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__): 257 for export in getattr(symbol, names_attr): # pylint: disable=protected-access 258 dest_module, dest_name = _get_name_and_module(export) 259 dest_module = _join_modules(output_module_prefix, dest_module) 260 module_code_builder.add_import( 261 id(symbol), dest_module, source_module_name, source_name, dest_name) 262 263 264def get_api_init_text(packages, 265 output_package, 266 api_name, 267 api_version, 268 compat_api_versions=None): 269 """Get a map from destination module to __init__.py code for that module. 270 271 Args: 272 packages: Base python packages containing python with target tf_export 273 decorators. 274 output_package: Base output python package where generated API will be 275 added. 276 api_name: API you want to generate (e.g. `tensorflow` or `estimator`). 277 api_version: API version you want to generate (1 or 2). 278 compat_api_versions: Additional API versions to generate under compat/ 279 directory. 280 281 Returns: 282 A dictionary where 283 key: (string) destination module (for e.g. tf or tf.consts). 284 value: (string) text that should be in __init__.py files for 285 corresponding modules. 286 """ 287 if compat_api_versions is None: 288 compat_api_versions = [] 289 module_code_builder = _ModuleInitCodeBuilder(output_package) 290 # Traverse over everything imported above. Specifically, 291 # we want to traverse over TensorFlow Python modules. 292 293 def in_packages(m): 294 return any(package in m for package in packages) 295 296 for module in list(sys.modules.values()): 297 # Only look at tensorflow modules. 298 if (not module or not hasattr(module, '__name__') or 299 module.__name__ is None or not in_packages(module.__name__)): 300 continue 301 # Do not generate __init__.py files for contrib modules for now. 302 if (('.contrib.' in module.__name__ or module.__name__.endswith('.contrib')) 303 and '.lite' not in module.__name__): 304 continue 305 306 for module_contents_name in dir(module): 307 if (module.__name__ + '.' + module_contents_name 308 in _SYMBOLS_TO_SKIP_EXPLICITLY): 309 continue 310 attr = getattr(module, module_contents_name) 311 _, attr = tf_decorator.unwrap(attr) 312 313 add_imports_for_symbol( 314 module_code_builder, attr, module.__name__, module_contents_name, 315 api_name, api_version) 316 for compat_api_version in compat_api_versions: 317 add_imports_for_symbol( 318 module_code_builder, attr, module.__name__, module_contents_name, 319 api_name, compat_api_version, 320 _COMPAT_MODULE_TEMPLATE % compat_api_version) 321 322 return module_code_builder.build() 323 324 325def get_module(dir_path, relative_to_dir): 326 """Get module that corresponds to path relative to relative_to_dir. 327 328 Args: 329 dir_path: Path to directory. 330 relative_to_dir: Get module relative to this directory. 331 332 Returns: 333 Name of module that corresponds to the given directory. 334 """ 335 dir_path = dir_path[len(relative_to_dir):] 336 # Convert path separators to '/' for easier parsing below. 337 dir_path = dir_path.replace(os.sep, '/') 338 return dir_path.replace('/', '.').strip('.') 339 340 341def get_module_docstring(module_name, package, api_name): 342 """Get docstring for the given module. 343 344 This method looks for docstring in the following order: 345 1. Checks if module has a docstring specified in doc_srcs. 346 2. Checks if module has a docstring source module specified 347 in doc_srcs. If it does, gets docstring from that module. 348 3. Checks if module with module_name exists under base package. 349 If it does, gets docstring from that module. 350 4. Returns a default docstring. 351 352 Args: 353 module_name: module name relative to tensorflow 354 (excluding 'tensorflow.' prefix) to get a docstring for. 355 package: Base python package containing python with target tf_export 356 decorators. 357 api_name: API you want to generate (e.g. `tensorflow` or `estimator`). 358 359 Returns: 360 One-line docstring to describe the module. 361 """ 362 # Get the same module doc strings for any version. That is, for module 363 # 'compat.v1.foo' we can get docstring from module 'foo'. 364 for version in _API_VERSIONS: 365 compat_prefix = _COMPAT_MODULE_TEMPLATE % version 366 if module_name.startswith(compat_prefix): 367 module_name = module_name[len(compat_prefix):].strip('.') 368 369 # Module under base package to get a docstring from. 370 docstring_module_name = module_name 371 372 doc_sources = doc_srcs.get_doc_sources(api_name) 373 374 if module_name in doc_sources: 375 docsrc = doc_sources[module_name] 376 if docsrc.docstring: 377 return docsrc.docstring 378 if docsrc.docstring_module_name: 379 docstring_module_name = docsrc.docstring_module_name 380 381 docstring_module_name = package + '.' + docstring_module_name 382 if (docstring_module_name in sys.modules and 383 sys.modules[docstring_module_name].__doc__): 384 return sys.modules[docstring_module_name].__doc__ 385 386 return 'Public API for tf.%s namespace.' % module_name 387 388 389def create_api_files(output_files, packages, root_init_template, output_dir, 390 output_package, api_name, api_version, 391 compat_api_versions, compat_init_templates): 392 """Creates __init__.py files for the Python API. 393 394 Args: 395 output_files: List of __init__.py file paths to create. 396 packages: Base python packages containing python with target tf_export 397 decorators. 398 root_init_template: Template for top-level __init__.py file. 399 "# API IMPORTS PLACEHOLDER" comment in the template file will be replaced 400 with imports. 401 output_dir: output API root directory. 402 output_package: Base output package where generated API will be added. 403 api_name: API you want to generate (e.g. `tensorflow` or `estimator`). 404 api_version: API version to generate (`v1` or `v2`). 405 compat_api_versions: Additional API versions to generate in compat/ 406 subdirectory. 407 compat_init_templates: List of templates for top level compat init files 408 in the same order as compat_api_versions. 409 410 Raises: 411 ValueError: if output_files list is missing a required file. 412 """ 413 module_name_to_file_path = {} 414 for output_file in output_files: 415 module_name = get_module(os.path.dirname(output_file), output_dir) 416 module_name_to_file_path[module_name] = os.path.normpath(output_file) 417 418 # Create file for each expected output in genrule. 419 for module, file_path in module_name_to_file_path.items(): 420 if not os.path.isdir(os.path.dirname(file_path)): 421 os.makedirs(os.path.dirname(file_path)) 422 open(file_path, 'a').close() 423 424 module_text_map = get_api_init_text(packages, output_package, api_name, 425 api_version, compat_api_versions) 426 427 # Add imports to output files. 428 missing_output_files = [] 429 # Root modules are "" and "compat.v*". 430 root_module = '' 431 compat_module_to_template = { 432 _COMPAT_MODULE_TEMPLATE % v: t 433 for v, t in zip(compat_api_versions, compat_init_templates) 434 } 435 436 for module, text in module_text_map.items(): 437 # Make sure genrule output file list is in sync with API exports. 438 if module not in module_name_to_file_path: 439 module_file_path = '"%s/__init__.py"' % ( 440 module.replace('.', '/')) 441 missing_output_files.append(module_file_path) 442 continue 443 444 contents = '' 445 if module == root_module and root_init_template: 446 # Read base init file for root module 447 with open(root_init_template, 'r') as root_init_template_file: 448 contents = root_init_template_file.read() 449 contents = contents.replace('# API IMPORTS PLACEHOLDER', text) 450 elif module in compat_module_to_template: 451 # Read base init file for compat module 452 with open(compat_module_to_template[module], 'r') as init_template_file: 453 contents = init_template_file.read() 454 contents = contents.replace('# API IMPORTS PLACEHOLDER', text) 455 else: 456 contents = ( 457 _GENERATED_FILE_HEADER % get_module_docstring( 458 module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER) 459 with open(module_name_to_file_path[module], 'w') as fp: 460 fp.write(contents) 461 462 if missing_output_files: 463 raise ValueError( 464 """Missing outputs for genrule:\n%s. Be sure to add these targets to 465tensorflow/python/tools/api/generator/api_init_files_v1.bzl and 466tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or 467tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" 468 % ',\n'.join(sorted(missing_output_files))) 469 470 471def main(): 472 parser = argparse.ArgumentParser() 473 parser.add_argument( 474 'outputs', metavar='O', type=str, nargs='+', 475 help='If a single file is passed in, then we we assume it contains a ' 476 'semicolon-separated list of Python files that we expect this script to ' 477 'output. If multiple files are passed in, then we assume output files ' 478 'are listed directly as arguments.') 479 parser.add_argument( 480 '--packages', 481 default=_DEFAULT_PACKAGE, 482 type=str, 483 help='Base packages that import modules containing the target tf_export ' 484 'decorators.') 485 parser.add_argument( 486 '--root_init_template', default='', type=str, 487 help='Template for top level __init__.py file. ' 488 '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.') 489 parser.add_argument( 490 '--apidir', type=str, required=True, 491 help='Directory where generated output files are placed. ' 492 'gendir should be a prefix of apidir. Also, apidir ' 493 'should be a prefix of every directory in outputs.') 494 parser.add_argument( 495 '--apiname', required=True, type=str, 496 choices=API_ATTRS.keys(), 497 help='The API you want to generate.') 498 parser.add_argument( 499 '--apiversion', default=2, type=int, 500 choices=_API_VERSIONS, 501 help='The API version you want to generate.') 502 parser.add_argument( 503 '--compat_apiversions', default=[], type=int, action='append', 504 help='Additional versions to generate in compat/ subdirectory. ' 505 'If set to 0, then no additional version would be generated.') 506 parser.add_argument( 507 '--compat_init_templates', default=[], type=str, action='append', 508 help='Templates for top-level __init__ files under compat modules. ' 509 'The list of init file templates must be in the same order as ' 510 'list of versions passed with compat_apiversions.') 511 parser.add_argument( 512 '--output_package', default='tensorflow', type=str, 513 help='Root output package.') 514 args = parser.parse_args() 515 516 if len(args.outputs) == 1: 517 # If we only get a single argument, then it must be a file containing 518 # list of outputs. 519 with open(args.outputs[0]) as output_list_file: 520 outputs = [line.strip() for line in output_list_file.read().split(';')] 521 else: 522 outputs = args.outputs 523 524 # Populate `sys.modules` with modules containing tf_export(). 525 packages = args.packages.split(',') 526 for package in packages: 527 importlib.import_module(package) 528 create_api_files(outputs, packages, args.root_init_template, args.apidir, 529 args.output_package, args.apiname, args.apiversion, 530 args.compat_apiversions, args.compat_init_templates) 531 532 533if __name__ == '__main__': 534 main() 535