# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Generates and prints out imports and constants for new TensorFlow python api."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import importlib
import os
import sys

from tensorflow.python.tools.api.generator import doc_srcs
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export

API_ATTRS = tf_export.API_ATTRS
API_ATTRS_V1 = tf_export.API_ATTRS_V1

_LAZY_LOADING = False
_API_VERSIONS = [1, 2]
_COMPAT_MODULE_TEMPLATE = 'compat.v%d'
_SUBCOMPAT_MODULE_TEMPLATE = 'compat.v%d.compat.v%d'
_COMPAT_MODULE_PREFIX = 'compat.v'
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS'
}
# Header written at the top of every generated (non-template) __init__.py.
# The single %s is filled with the module docstring.
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

from __future__ import print_function as _print_function

import sys as _sys

"""
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'
# Footer that wraps the generated module in a TFModuleWrapper. Format
# arguments, in order: module name, public_apis expression, deprecation flag,
# has_lite flag.
_DEPRECATION_FOOTER = """
from tensorflow.python.util import module_wrapper as _module_wrapper

if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
  _sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
      _sys.modules[__name__], "%s", public_apis=%s, deprecation=%s,
      has_lite=%s)
"""
# Module text used in lazy-loading mode. The %s is filled with one dictionary
# entry per exported name (see `_ModuleInitCodeBuilder.format_import`).
_LAZY_LOADING_MODULE_TEXT_TEMPLATE = """
# Inform pytype that this module is dynamically populated (b/111239204).
_HAS_DYNAMIC_ATTRIBUTES = True
_PUBLIC_APIS = {
%s
}
"""


class SymbolExposedTwiceError(Exception):
  """Raised when different symbols are exported with the same name."""
  pass


def get_canonical_import(import_set):
  """Obtain one single import from a set of possible sources of a symbol.

  One symbol might come from multiple places as it is being imported and
  reexported. To simplify API changes, we always use the same import for the
  same module, and give preference based on higher priority and alphabetical
  ordering.

  Args:
    import_set: (set) Imports providing the same symbol. This is a set of tuples
      in the form (import, priority). We want to pick an import with highest
      priority. Must not be empty.

  Returns:
    A module name to import
  """
  # Sort by priority (higher preferred) and then alphabetically by import
  # string, so that the choice is deterministic regardless of set ordering.
  import_list = sorted(
      import_set,
      key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0]))
  return import_list[0][0]


class _ModuleInitCodeBuilder(object):
  """Builds a map from module name to imports included in that module."""

  def __init__(self,
               output_package,
               api_version,
               lazy_loading=_LAZY_LOADING,
               use_relative_imports=False):
    """Creates an empty builder.

    Args:
      output_package: Base output python package where generated API will be
        added.
      api_version: API version being generated.
      lazy_loading: If True, emit lazy-loading dictionary entries instead of
        static import statements.
      use_relative_imports: If True, import submodules relative to their parent
        (`from . import foo`). Only consulted when not lazy loading.
    """
    self._output_package = output_package
    # Maps API module to API symbol name to set of tuples of the form
    # (module name, priority).
    # The same symbol can be imported from multiple locations. Higher
    # "priority" indicates that import location is preferred over others.
    self._module_imports = collections.defaultdict(
        lambda: collections.defaultdict(set))
    # Maps full API name to the id() of the symbol exported under that name
    # (-1 when the import was added without a symbol).
    self._dest_import_to_id = collections.defaultdict(int)
    # Names that start with underscore in the root module.
    self._underscore_names_in_root = []
    self._api_version = api_version
    # Controls whether or not exported symbols are lazily loaded or statically
    # imported.
    self._lazy_loading = lazy_loading
    self._use_relative_imports = use_relative_imports

  def _check_already_imported(self, symbol_id, api_name):
    """Records `symbol_id` for `api_name`, rejecting conflicting exports.

    Args:
      symbol_id: id() of the exported symbol, or -1 if there is no symbol.
      api_name: Full API name (destination module plus name).

    Raises:
      SymbolExposedTwiceError: if `api_name` was already recorded with a
        different id and `symbol_id` is not -1.
    """
    if (api_name in self._dest_import_to_id and
        symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
      raise SymbolExposedTwiceError(
          'Trying to export multiple symbols with same name: %s.' % api_name)
    self._dest_import_to_id[api_name] = symbol_id

  def add_import(self, symbol, source_module_name, source_name,
                 dest_module_name, dest_name):
    """Adds this import to module_imports.

    Args:
      symbol: TensorFlow Python symbol.
      source_module_name: (string) Module to import from.
      source_name: (string) Name of the symbol to import.
      dest_module_name: (string) Module name to add import to.
      dest_name: (string) Import the symbol using this name.

    Raises:
      SymbolExposedTwiceError: Raised when an import with the same
        dest_name has already been added to dest_module_name.
    """
    # modules_with_exports.py is only used during API generation and
    # won't be available when actually importing tensorflow.
    if source_module_name.endswith('python.modules_with_exports'):
      source_module_name = symbol.__module__
    import_str = self.format_import(source_module_name, source_name, dest_name)

    # Check if we are trying to expose two different symbols with same name.
    full_api_name = dest_name
    if dest_module_name:
      full_api_name = dest_module_name + '.' + full_api_name
    symbol_id = -1 if not symbol else id(symbol)
    self._check_already_imported(symbol_id, full_api_name)

    if not dest_module_name and dest_name.startswith('_'):
      self._underscore_names_in_root.append(dest_name)

    # The same symbol can be available in multiple modules.
    # We store all possible ways of importing this symbol and later pick just
    # one.
    priority = 0
    if symbol:
      # Give higher priority to source module if it matches
      # symbol's original module.
      if hasattr(symbol, '__module__'):
        priority = int(source_module_name == symbol.__module__)
      # Give higher priority if symbol name matches its __name__.
      if hasattr(symbol, '__name__'):
        priority += int(source_name == symbol.__name__)
    self._module_imports[dest_module_name][full_api_name].add(
        (import_str, priority))

  def _import_submodules(self):
    """Add imports for all destination modules in self._module_imports."""
    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    # Snapshot the keys: add_import() below inserts new destination modules
    # into self._module_imports while we iterate.
    imported_modules = set(self._module_imports.keys())
    for module in imported_modules:
      if not module:
        continue
      module_split = module.split('.')
      parent_module = ''  # we import submodules in their parent_module

      for submodule_index in range(len(module_split)):
        if submodule_index > 0:
          submodule = module_split[submodule_index - 1]
          parent_module += '.' + submodule if parent_module else submodule
        import_from = self._output_package
        if self._lazy_loading:
          # Lazy entries reference the submodule by its full dotted path.
          import_from += '.' + '.'.join(module_split[:submodule_index + 1])
          self.add_import(
              symbol=None,
              source_module_name='',
              source_name=import_from,
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])
        else:
          if self._use_relative_imports:
            import_from = '.'
          elif submodule_index > 0:
            import_from += '.' + '.'.join(module_split[:submodule_index])
          self.add_import(
              symbol=None,
              source_module_name=import_from,
              source_name=module_split[submodule_index],
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])

  def build(self):
    """Get a map from destination module to __init__.py code for that module.

    Returns:
      A tuple (module_text_map, footer_text_map, root_module_footer) where
        module_text_map: dictionary with
          key: (string) destination module (for e.g. tf or tf.consts).
          value: (string) text that should be in __init__.py files for
            corresponding modules.
        footer_text_map: dictionary from destination module to the module
          wrapper footer text. Only populated for API version 1 or when lazy
          loading is enabled.
        root_module_footer: (string) `__all__` definition for the root module;
          empty when lazy loading is enabled.
    """
    self._import_submodules()
    module_text_map = {}
    footer_text_map = {}
    for dest_module, dest_name_to_imports in self._module_imports.items():
      # Sort all possible imports for a symbol and pick the first one.
      imports_list = [
          get_canonical_import(imports)
          for imports in dest_name_to_imports.values()
      ]
      if self._lazy_loading:
        module_text_map[
            dest_module] = _LAZY_LOADING_MODULE_TEXT_TEMPLATE % '\n'.join(
                sorted(imports_list))
      else:
        module_text_map[dest_module] = '\n'.join(sorted(imports_list))

    # Expose exported symbols with underscores in root module since we import
    # from it using * import. Don't need this for lazy_loading because the
    # underscore symbols are already included in __all__ when passed in and
    # handled by TFModuleWrapper.
    root_module_footer = ''
    if not self._lazy_loading:
      underscore_names_str = ', '.join(
          '\'%s\'' % name for name in self._underscore_names_in_root)

      root_module_footer = """
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
""" % underscore_names_str

    # Add module wrapper if we need to print deprecation messages
    # or if we use lazy loading.
    if self._api_version == 1 or self._lazy_loading:
      for dest_module in self._module_imports:
        deprecation = 'False'
        has_lite = 'False'
        if self._api_version == 1:  # Add 1.* deprecations.
          if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
            deprecation = 'True'
        # Workaround to make sure not load lite from lite/__init__.py
        if (not dest_module and 'lite' in self._module_imports and
            self._lazy_loading):
          has_lite = 'True'
        if self._lazy_loading:
          public_apis_name = '_PUBLIC_APIS'
        else:
          public_apis_name = 'None'
        footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
            dest_module, public_apis_name, deprecation, has_lite)

    return module_text_map, footer_text_map, root_module_footer

  def format_import(self, source_module_name, source_name, dest_name):
    """Formats import statement.

    Args:
      source_module_name: (string) Source module to import from.
      source_name: (string) Source symbol name to import.
      dest_name: (string) Destination alias name.

    Returns:
      An import statement string, or, when lazy loading is enabled, a
      dictionary entry of the form `'dest': ('module', 'name'),`.
    """
    if self._lazy_loading:
      return "  '%s': ('%s', '%s')," % (dest_name, source_module_name,
                                        source_name)
    else:
      if source_module_name:
        if source_name == dest_name:
          return 'from %s import %s' % (source_module_name, source_name)
        else:
          return 'from %s import %s as %s' % (source_module_name, source_name,
                                              dest_name)
      else:
        if source_name == dest_name:
          return 'import %s' % source_name
        else:
          return 'import %s as %s' % (source_name, dest_name)

  def get_destination_modules(self):
    """Returns the set of destination module names added so far."""
    return set(self._module_imports.keys())

  def copy_imports(self, from_dest_module, to_dest_module):
    """Replaces imports of `to_dest_module` with a shallow copy of another's.

    Args:
      from_dest_module: (string) Destination module to copy imports from.
      to_dest_module: (string) Destination module whose imports are replaced.
    """
    self._module_imports[to_dest_module] = (
        self._module_imports[from_dest_module].copy())


def add_nested_compat_imports(module_builder, compat_api_versions,
                              output_package):
  """Adds compat.vN.compat.vK modules to module builder.

  To avoid circular imports, we want to add __init__.py files under
  compat.vN.compat.vK and under compat.vN.compat.vK.compat. For all other
  imports, we point to corresponding modules under compat.vK.

  Args:
    module_builder: `_ModuleInitCodeBuilder` instance.
    compat_api_versions: Supported compatibility versions.
    output_package: Base output python package where generated API will be
      added.
  """
  # Snapshot taken before the copies below add the nested modules.
  imported_modules = module_builder.get_destination_modules()

  # Copy over all imports in compat.vK to compat.vN.compat.vK and
  # all imports in compat.vK.compat to compat.vN.compat.vK.compat.
  for v in compat_api_versions:
    for sv in compat_api_versions:
      subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv)
      compat_module = _COMPAT_MODULE_TEMPLATE % sv
      module_builder.copy_imports(compat_module, subcompat_module)
      module_builder.copy_imports('%s.compat' % compat_module,
                                  '%s.compat' % subcompat_module)

  # Prefixes of modules under compatibility packages, for e.g. "compat.v1.".
  compat_prefixes = tuple(
      _COMPAT_MODULE_TEMPLATE % v + '.' for v in compat_api_versions)

  # Above, we only copied function, class and constant imports. Here
  # we also add imports for child modules.
  for imported_module in imported_modules:
    if not imported_module.startswith(compat_prefixes):
      continue
    module_split = imported_module.split('.')

    # Handle compat.vN.compat.vK.compat.foo case. That is,
    # import compat.vK.compat.foo in compat.vN.compat.vK.compat.
    if len(module_split) > 3 and module_split[2] == 'compat':
      src_module = '.'.join(module_split[:3])
      src_name = module_split[3]
      assert src_name != 'v1' and src_name != 'v2', imported_module
    else:  # Handle compat.vN.compat.vK.foo case.
      src_module = '.'.join(module_split[:2])
      src_name = module_split[2]
      if src_name == 'compat':
        continue  # compat.vN.compat.vK.compat is handled separately

    for compat_api_version in compat_api_versions:
      module_builder.add_import(
          symbol=None,
          source_module_name='%s.%s' % (output_package, src_module),
          source_name=src_name,
          dest_module_name='compat.v%d.%s' % (compat_api_version, src_module),
          dest_name=src_name)


def _get_name_and_module(full_name):
  """Split full_name into module and short name.

  Args:
    full_name: Full name of symbol that includes module.

  Returns:
    Full module name and short symbol name. The module name is an empty string
    when `full_name` contains no dot.
  """
  name_segments = full_name.split('.')
  return '.'.join(name_segments[:-1]), name_segments[-1]


def _join_modules(module1, module2):
  """Concatenate 2 module components.

  Args:
    module1: First module to join.
    module2: Second module to join.

  Returns:
    Given two modules aaa.bbb and ccc.ddd, returns a joined
    module aaa.bbb.ccc.ddd. If either component is empty, the other one is
    returned unchanged.
  """
  if not module1:
    return module2
  if not module2:
    return module1
  return '%s.%s' % (module1, module2)


def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix=''):
  """Add imports for the given symbol to `module_code_builder`.

  Args:
    module_code_builder: `_ModuleInitCodeBuilder` instance.
    symbol: A symbol.
    source_module_name: Module that we can import the symbol from.
    source_name: Name we can import the symbol with.
    api_name: API name. Currently, must be either `tensorflow` or `estimator`.
    api_version: API version.
    output_module_prefix: Prefix to prepend to destination module.
  """
  if api_version == 1:
    names_attr = API_ATTRS_V1[api_name].names
    constants_attr = API_ATTRS_V1[api_name].constants
  else:
    names_attr = API_ATTRS[api_name].names
    constants_attr = API_ATTRS[api_name].constants

  # If symbol is _tf_api_constants attribute, then add the constants.
  if source_name == constants_attr:
    for exports, name in symbol:
      for export in exports:
        dest_module, dest_name = _get_name_and_module(export)
        dest_module = _join_modules(output_module_prefix, dest_module)
        module_code_builder.add_import(None, source_module_name, name,
                                       dest_module, dest_name)

  # If symbol has _tf_api_names attribute, then add import for it.
  # Looking in __dict__ (rather than hasattr) only matches attributes set
  # directly on the symbol.
  if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):

    # Generate import statements for symbols.
    for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
      dest_module, dest_name = _get_name_and_module(export)
      dest_module = _join_modules(output_module_prefix, dest_module)
      module_code_builder.add_import(symbol, source_module_name, source_name,
                                     dest_module, dest_name)


def get_api_init_text(packages,
                      packages_to_ignore,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None,
                      lazy_loading=_LAZY_LOADING,
                      use_relative_imports=False):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when importing
      submodules.

  Returns:
    The tuple returned by `_ModuleInitCodeBuilder.build()`:
    (module_text_map, footer_text_map, root_module_footer), where
    module_text_map is a dictionary with
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
                                               lazy_loading,
                                               use_relative_imports)

  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.

  def in_packages(m):
    # Substring match: a module qualifies if any package name occurs in it.
    return any(package in m for package in packages)

  # Copy the values: getattr() below may import further modules and mutate
  # sys.modules while we iterate.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, '__name__') or
        module.__name__ is None or not in_packages(module.__name__)):
      continue
    # Empty entries (e.g. from splitting an empty flag value) never match.
    if packages_to_ignore and any(
        p and p in module.__name__ for p in packages_to_ignore):
      continue

    # Do not generate __init__.py files for contrib modules for now.
    if (('.contrib.' in module.__name__ or module.__name__.endswith('.contrib'))
        and '.lite' not in module.__name__):
      continue

    for module_contents_name in dir(module):
      if (module.__name__ + '.' +
          module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      add_imports_for_symbol(module_code_builder, attr, module.__name__,
                             module_contents_name, api_name, api_version)
      for compat_api_version in compat_api_versions:
        add_imports_for_symbol(module_code_builder, attr, module.__name__,
                               module_contents_name, api_name,
                               compat_api_version,
                               _COMPAT_MODULE_TEMPLATE % compat_api_version)

  if compat_api_versions:
    add_nested_compat_imports(module_code_builder, compat_api_versions,
                              output_package)
  return module_code_builder.build()


def get_module(dir_path, relative_to_dir):
  """Get module that corresponds to path relative to relative_to_dir.

  Args:
    dir_path: Path to directory.
    relative_to_dir: Get module relative to this directory.

  Returns:
    Name of module that corresponds to the given directory.
  """
  # Drops the first len(relative_to_dir) characters; relative_to_dir is
  # expected to be a prefix of dir_path.
  dir_path = dir_path[len(relative_to_dir):]
  # Convert path separators to '/' for easier parsing below.
  dir_path = dir_path.replace(os.sep, '/')
  return dir_path.replace('/', '.').strip('.')


def get_module_docstring(module_name, package, api_name):
  """Get docstring for the given module.

  This method looks for docstring in the following order:
  1. Checks if module has a docstring specified in doc_srcs.
  2. Checks if module has a docstring source module specified
     in doc_srcs. If it does, gets docstring from that module.
  3. Checks if module with module_name exists under base package.
     If it does, gets docstring from that module.
  4. Returns a default docstring.

  Args:
    module_name: module name relative to tensorflow (excluding 'tensorflow.'
      prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    One-line docstring to describe the module.
  """
  # Get the same module doc strings for any version. That is, for module
  # 'compat.v1.foo' we can get docstring from module 'foo'.
  for version in _API_VERSIONS:
    compat_prefix = _COMPAT_MODULE_TEMPLATE % version
    if module_name.startswith(compat_prefix):
      module_name = module_name[len(compat_prefix):].strip('.')

  # Module under base package to get a docstring from.
  docstring_module_name = module_name

  doc_sources = doc_srcs.get_doc_sources(api_name)

  if module_name in doc_sources:
    docsrc = doc_sources[module_name]
    if docsrc.docstring:
      return docsrc.docstring
    if docsrc.docstring_module_name:
      docstring_module_name = docsrc.docstring_module_name

  docstring_module_name = package + '.' + docstring_module_name
  if (docstring_module_name in sys.modules and
      sys.modules[docstring_module_name].__doc__):
    return sys.modules[docstring_module_name].__doc__

  return 'Public API for tf.%s namespace.' % module_name


def create_api_files(output_files,
                     packages,
                     packages_to_ignore,
                     root_init_template,
                     output_dir,
                     output_package,
                     api_name,
                     api_version,
                     compat_api_versions,
                     compat_init_templates,
                     lazy_loading=_LAZY_LOADING,
                     use_relative_imports=False):
  """Creates __init__.py files for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators.
    root_init_template: Template for top-level __init__.py file. "# API IMPORTS
      PLACEHOLDER" comment in the template file will be replaced with imports.
    output_dir: output API root directory.
    output_package: Base output package where generated API will be added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version to generate (1 or 2).
    compat_api_versions: Additional API versions to generate in compat/
      subdirectory.
    compat_init_templates: List of templates for top level compat init files in
      the same order as compat_api_versions.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when import
      submodules.

  Raises:
    ValueError: if output_files list is missing a required file.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    module_name = get_module(os.path.dirname(output_file), output_dir)
    module_name_to_file_path[module_name] = os.path.normpath(output_file)

  # Create file for each expected output in genrule.
  for file_path in module_name_to_file_path.values():
    if not os.path.isdir(os.path.dirname(file_path)):
      os.makedirs(os.path.dirname(file_path))
    # Append mode creates the file if missing without truncating existing
    # content.
    with open(file_path, 'a'):
      pass

  (
      module_text_map,
      deprecation_footer_map,
      root_module_footer,
  ) = get_api_init_text(packages, packages_to_ignore, output_package, api_name,
                        api_version, compat_api_versions, lazy_loading,
                        use_relative_imports)

  # Add imports to output files.
  missing_output_files = []
  # Root modules are "" and "compat.v*".
  root_module = ''
  compat_module_to_template = {
      _COMPAT_MODULE_TEMPLATE % v: t
      for v, t in zip(compat_api_versions, compat_init_templates)
  }
  # Nested compat.vN.compat.vK modules reuse the template of compat.vK.
  for v in compat_api_versions:
    compat_module_to_template.update({
        _SUBCOMPAT_MODULE_TEMPLATE % (v, vs): t
        for vs, t in zip(compat_api_versions, compat_init_templates)
    })

  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      module_file_path = '"%s/__init__.py"' % (module.replace('.', '/'))
      missing_output_files.append(module_file_path)
      continue

    contents = ''
    if module == root_module and root_init_template:
      # Read base init file for root module
      with open(root_init_template, 'r') as root_init_template_file:
        contents = root_init_template_file.read()
        contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
        contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
    elif module in compat_module_to_template:
      # Read base init file for compat module
      with open(compat_module_to_template[module], 'r') as init_template_file:
        contents = init_template_file.read()
        contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
    else:
      contents = (
          _GENERATED_FILE_HEADER %
          get_module_docstring(module, packages[0], api_name) + text +
          _GENERATED_FILE_FOOTER)
    if module in deprecation_footer_map:
      # Templates may pin where the wrapper goes; otherwise append it.
      if '# WRAPPER_PLACEHOLDER' in contents:
        contents = contents.replace('# WRAPPER_PLACEHOLDER',
                                    deprecation_footer_map[module])
      else:
        contents += deprecation_footer_map[module]
    with open(module_name_to_file_path[module], 'w') as fp:
      fp.write(contents)

  if missing_output_files:
    raise ValueError(
        """Missing outputs for genrule:\n%s. Be sure to add these targets to
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" %
        ',\n'.join(sorted(missing_output_files)))


def _str_to_bool(value):
  """Parses the value of a boolean command-line flag.

  `argparse` with `type=bool` calls `bool()` on the raw string, so every
  non-empty value (including 'False') would be parsed as True. This parser
  interprets the text instead.

  Args:
    value: Flag value; a string, or an already-parsed bool.

  Returns:
    The parsed boolean. An empty string parses as False.

  Raises:
    argparse.ArgumentTypeError: if `value` is not a recognized boolean string.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (e.g. True or False), got %r.' % value)


def main():
  """Parses command-line flags and generates the API __init__.py files."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs',
      metavar='O',
      type=str,
      nargs='+',
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--packages',
      default=_DEFAULT_PACKAGE,
      type=str,
      help='Base packages that import modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--packages_to_ignore',
      default='',
      type=str,
      help='Packages to exclude from the api generation. This is used to hide '
      'certain packages from this script when multiple copy of code exists, '
      'eg Keras. It is useful to avoid the SymbolExposedTwiceError.'
  )
  parser.add_argument(
      '--root_init_template',
      default='',
      type=str,
      help='Template for top level __init__.py file. '
      '"# API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir',
      type=str,
      required=True,
      help='Directory where generated output files are placed. '
      'gendir should be a prefix of apidir. Also, apidir '
      'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname',
      required=True,
      type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--apiversion',
      default=2,
      type=int,
      choices=_API_VERSIONS,
      help='The API version you want to generate.')
  parser.add_argument(
      '--compat_apiversions',
      default=[],
      type=int,
      action='append',
      help='Additional versions to generate in compat/ subdirectory. '
      'If set to 0, then no additional version would be generated.')
  parser.add_argument(
      '--compat_init_templates',
      default=[],
      type=str,
      action='append',
      help='Templates for top-level __init__ files under compat modules. '
      'The list of init file templates must be in the same order as '
      'list of versions passed with compat_apiversions.')
  parser.add_argument(
      '--output_package',
      default='tensorflow',
      type=str,
      help='Root output package.')
  parser.add_argument(
      '--loading',
      default='default',
      type=str,
      choices=['lazy', 'static', 'default'],
      help='Controls how the generated __init__.py file loads the exported '
      'symbols. \'lazy\' means the symbols are loaded when first used. '
      '\'static\' means all exported symbols are loaded in the '
      '__init__.py file. \'default\' uses the value of the '
      '_LAZY_LOADING constant in create_python_api.py.')
  parser.add_argument(
      '--use_relative_imports',
      default=False,
      # Not `type=bool`: bool('False') is True.
      type=_str_to_bool,
      help='Whether to import submodules using relative imports or absolute '
      'imports')
  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs.
    with open(args.outputs[0]) as output_list_file:
      outputs = [line.strip() for line in output_list_file.read().split(';')]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export().
  packages = args.packages.split(',')
  for package in packages:
    importlib.import_module(package)
  # Drop empty entries: ''.split(',') yields [''] for the default value.
  packages_to_ignore = [p for p in args.packages_to_ignore.split(',') if p]

  # Determine if the modules shall be loaded lazily or statically.
  if args.loading == 'default':
    lazy_loading = _LAZY_LOADING
  elif args.loading == 'lazy':
    lazy_loading = True
  elif args.loading == 'static':
    lazy_loading = False
  else:
    # This should never happen (tm).
    raise ValueError('Invalid value for --loading flag: %s. Must be one of '
                     'lazy, static, default.' % args.loading)

  create_api_files(outputs, packages, packages_to_ignore,
                   args.root_init_template, args.apidir,
                   args.output_package, args.apiname, args.apiversion,
                   args.compat_apiversions, args.compat_init_templates,
                   lazy_loading, args.use_relative_imports)


if __name__ == '__main__':
  main()