# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Generates and prints out imports and constants for new TensorFlow python api."""
import argparse
import collections
import importlib
import os
import sys

from tensorflow.python.tools.api.generator import doc_srcs
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export

API_ATTRS = tf_export.API_ATTRS
API_ATTRS_V1 = tf_export.API_ATTRS_V1

_LAZY_LOADING = False
_API_VERSIONS = [1, 2]
_COMPAT_MODULE_TEMPLATE = 'compat.v%d'
_SUBCOMPAT_MODULE_TEMPLATE = 'compat.v%d.compat.v%d'
_COMPAT_MODULE_PREFIX = 'compat.v'
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS'
}
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

import sys as _sys

"""
_GENERATED_FILE_FOOTER = ''
_DEPRECATION_FOOTER = """
from tensorflow.python.util import module_wrapper as _module_wrapper

if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
  _sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
      _sys.modules[__name__], "%s", public_apis=%s, deprecation=%s,
      has_lite=%s)
"""
_LAZY_LOADING_MODULE_TEXT_TEMPLATE = """
# Inform pytype that this module is dynamically populated (b/111239204).
_HAS_DYNAMIC_ATTRIBUTES = True
_PUBLIC_APIS = {
%s
}
"""


class SymbolExposedTwiceError(Exception):
  """Raised when different symbols are exported with the same name."""


def get_canonical_import(import_set):
  """Obtain one single import from a set of possible sources of a symbol.

  One symbol might come from multiple places as it is being imported and
  reexported. To simplify API changes, we always use the same import for the
  same module, and give preference based on higher priority and alphabetical
  ordering.

  Args:
    import_set: (set) Imports providing the same symbol. This is a set of tuples
      in the form (import, priority). We want to pick an import with highest
      priority.

  Returns:
    The import string (as produced by `format_import`) with the highest
    priority; ties are broken by the alphabetically smallest import string.

  Raises:
    ValueError: if `import_set` is empty.
  """
  # A single pass suffices: highest priority first (hence the negation), then
  # alphabetical order, so the choice does not depend on set iteration order.
  import_str, _ = min(
      import_set,
      key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0]))
  return import_str


class _ModuleInitCodeBuilder(object):
  """Builds a map from module name to imports included in that module."""

  def __init__(self,
               output_package,
               api_version,
               lazy_loading=_LAZY_LOADING,
               use_relative_imports=False):
    self._output_package = output_package
    # Maps API module to full API symbol name to set of tuples of the form
    # (import string, priority).
    # The same symbol can be imported from multiple locations. Higher
    # "priority" indicates that import location is preferred over others.
    self._module_imports = collections.defaultdict(
        lambda: collections.defaultdict(set))
    # Maps full API name to id() of the exported symbol (-1 for no symbol),
    # used to detect two different symbols exported under the same name.
    self._dest_import_to_id = collections.defaultdict(int)
    # Names that start with underscore in the root module.
    self._underscore_names_in_root = set()
    self._api_version = api_version
    # Controls whether or not exported symbols are lazily loaded or statically
    # imported.
    self._lazy_loading = lazy_loading
    self._use_relative_imports = use_relative_imports

  def _check_already_imported(self, symbol_id, api_name):
    """Records that `api_name` exports the symbol with id `symbol_id`.

    Args:
      symbol_id: id() of the exported symbol, or -1 when there is no symbol
        (e.g. submodule imports). A new registration with id -1 never raises.
      api_name: Full API name, including the destination module.

    Raises:
      SymbolExposedTwiceError: if `api_name` was already registered with a
        different id.
    """
    if (api_name in self._dest_import_to_id and
        symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
      raise SymbolExposedTwiceError(
          f'Trying to export multiple symbols with same name: {api_name}')
    self._dest_import_to_id[api_name] = symbol_id

  def add_import(self, symbol, source_module_name, source_name,
                 dest_module_name, dest_name):
    """Adds this import to module_imports.

    Args:
      symbol: TensorFlow Python symbol.
      source_module_name: (string) Module to import from.
      source_name: (string) Name of the symbol to import.
      dest_module_name: (string) Module name to add import to.
      dest_name: (string) Import the symbol using this name.

    Raises:
      SymbolExposedTwiceError: Raised when an import with the same
        dest_name has already been added to dest_module_name.
    """
    # modules_with_exports.py is only used during API generation and
    # won't be available when actually importing tensorflow.
    if source_module_name.endswith('python.modules_with_exports'):
      source_module_name = symbol.__module__
    import_str = self.format_import(source_module_name, source_name, dest_name)

    # Check if we are trying to expose two different symbols with same name.
    full_api_name = dest_name
    if dest_module_name:
      full_api_name = dest_module_name + '.' + full_api_name
    symbol_id = -1 if not symbol else id(symbol)
    self._check_already_imported(symbol_id, full_api_name)

    if not dest_module_name and dest_name.startswith('_'):
      self._underscore_names_in_root.add(dest_name)

    # The same symbol can be available in multiple modules.
    # We store all possible ways of importing this symbol and later pick just
    # one.
    priority = 0
    if symbol:
      # Give higher priority to source module if it matches
      # symbol's original module.
      if hasattr(symbol, '__module__'):
        priority = int(source_module_name == symbol.__module__)
      # Give higher priority if symbol name matches its __name__.
      if hasattr(symbol, '__name__'):
        priority += int(source_name == symbol.__name__)
    self._module_imports[dest_module_name][full_api_name].add(
        (import_str, priority))

  def _import_submodules(self):
    """Add imports for all destination modules in self._module_imports."""
    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    # Iterate over a snapshot: add_import below adds new destination modules.
    imported_modules = set(self._module_imports.keys())
    for module in imported_modules:
      if not module:
        continue
      module_split = module.split('.')
      parent_module = ''  # we import submodules in their parent_module

      for submodule_index in range(len(module_split)):
        if submodule_index > 0:
          submodule = module_split[submodule_index - 1]
          parent_module += '.' + submodule if parent_module else submodule
        import_from = self._output_package
        if self._lazy_loading:
          import_from += '.' + '.'.join(module_split[:submodule_index + 1])
          self.add_import(
              symbol=None,
              source_module_name='',
              source_name=import_from,
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])
        else:
          if self._use_relative_imports:
            import_from = '.'
          elif submodule_index > 0:
            import_from += '.' + '.'.join(module_split[:submodule_index])
          self.add_import(
              symbol=None,
              source_module_name=import_from,
              source_name=module_split[submodule_index],
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])

  def build(self):
    """Get a map from destination module to __init__.py code for that module.

    Returns:
      A tuple `(module_text_map, footer_text_map, root_module_footer)`:
        module_text_map: dict mapping each destination module name, relative
          to the output package ('' for the root, e.g. 'compat.v1'), to the
          import text for its __init__.py file.
        footer_text_map: dict mapping destination module name to the
          module-wrapper footer text. Only populated for API version 1 or
          when lazy loading.
        root_module_footer: (string) code defining `__all__` for the root
          module; empty when lazy loading.
    """
    self._import_submodules()
    module_text_map = {}
    footer_text_map = {}
    for dest_module, dest_name_to_imports in self._module_imports.items():
      # Pick one canonical import per symbol, then sort for stable output.
      imports_list = [
          get_canonical_import(imports)
          for imports in dest_name_to_imports.values()
      ]
      imports_text = '\n'.join(sorted(imports_list))
      if self._lazy_loading:
        module_text_map[dest_module] = (
            _LAZY_LOADING_MODULE_TEXT_TEMPLATE % imports_text)
      else:
        module_text_map[dest_module] = imports_text

    # Expose exported symbols with underscores in root module since we import
    # from it using * import. Don't need this for lazy_loading because the
    # underscore symbols are already included in __all__ when passed in and
    # handled by TFModuleWrapper.
    root_module_footer = ''
    if not self._lazy_loading:
      underscore_names_str = ', '.join(
          '\'%s\'' % name for name in sorted(self._underscore_names_in_root))

      root_module_footer = """
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
""" % underscore_names_str

    # Add module wrapper if we need to print deprecation messages
    # or if we use lazy loading.
    if self._api_version == 1 or self._lazy_loading:
      for dest_module in self._module_imports:
        deprecation = 'False'
        has_lite = 'False'
        if self._api_version == 1:  # Add 1.* deprecations.
          if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
            deprecation = 'True'
        # Workaround to make sure not load lite from lite/__init__.py
        if (not dest_module and 'lite' in self._module_imports and
            self._lazy_loading):
          has_lite = 'True'
        if self._lazy_loading:
          public_apis_name = '_PUBLIC_APIS'
        else:
          public_apis_name = 'None'
        footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
            dest_module, public_apis_name, deprecation, has_lite)

    return module_text_map, footer_text_map, root_module_footer

  def format_import(self, source_module_name, source_name, dest_name):
    """Formats import statement.

    Args:
      source_module_name: (string) Source module to import from.
      source_name: (string) Source symbol name to import.
      dest_name: (string) Destination alias name.

    Returns:
      An import statement string, or, when lazy loading, a
      `'dest_name': ('source_module', 'source_name'),` dict entry line.
    """
    if self._lazy_loading:
      return "  '%s': ('%s', '%s')," % (dest_name, source_module_name,
                                        source_name)
    else:
      if source_module_name:
        if source_name == dest_name:
          return 'from %s import %s' % (source_module_name, source_name)
        else:
          return 'from %s import %s as %s' % (source_module_name, source_name,
                                              dest_name)
      else:
        if source_name == dest_name:
          return 'import %s' % source_name
        else:
          return 'import %s as %s' % (source_name, dest_name)

  def get_destination_modules(self):
    """Returns a new set of all destination module names added so far."""
    return set(self._module_imports.keys())

  def copy_imports(self, from_dest_module, to_dest_module):
    """Copies the symbol->imports map of one destination module to another.

    The copy is shallow: the per-symbol import sets are shared. Because
    `_module_imports` is a defaultdict, a missing `from_dest_module` gets an
    empty entry created for it.
    """
    self._module_imports[to_dest_module] = (
        self._module_imports[from_dest_module].copy())


def add_nested_compat_imports(module_builder, compat_api_versions,
                              output_package):
  """Adds compat.vN.compat.vK modules to module builder.

  To avoid circular imports, we want to add __init__.py files under
  compat.vN.compat.vK and under compat.vN.compat.vK.compat. For all other
  imports, we point to corresponding modules under compat.vK.

  Args:
    module_builder: `_ModuleInitCodeBuilder` instance.
    compat_api_versions: Supported compatibility versions.
    output_package: Base output python package where generated API will be
      added.
  """
  imported_modules = module_builder.get_destination_modules()

  # Copy over all imports in compat.vK to compat.vN.compat.vK and
  # all imports in compat.vK.compat to compat.vN.compat.vK.compat.
  for v in compat_api_versions:
    for sv in compat_api_versions:
      subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv)
      compat_module = _COMPAT_MODULE_TEMPLATE % sv
      module_builder.copy_imports(compat_module, subcompat_module)
      module_builder.copy_imports('%s.compat' % compat_module,
                                  '%s.compat' % subcompat_module)

  # Prefixes of modules under compatibility packages, for e.g. "compat.v1.".
  compat_prefixes = tuple(
      _COMPAT_MODULE_TEMPLATE % v + '.' for v in compat_api_versions)

  # Above, we only copied function, class and constant imports. Here
  # we also add imports for child modules.
  for imported_module in imported_modules:
    if not imported_module.startswith(compat_prefixes):
      continue
    module_split = imported_module.split('.')

    # Handle compat.vN.compat.vK.compat.foo case. That is,
    # import compat.vK.compat.foo in compat.vN.compat.vK.compat.
    if len(module_split) > 3 and module_split[2] == 'compat':
      src_module = '.'.join(module_split[:3])
      src_name = module_split[3]
      assert src_name != 'v1' and src_name != 'v2', imported_module
    else:  # Handle compat.vN.compat.vK.foo case.
      src_module = '.'.join(module_split[:2])
      src_name = module_split[2]
      if src_name == 'compat':
        continue  # compat.vN.compat.vK.compat is handled separately

    for compat_api_version in compat_api_versions:
      module_builder.add_import(
          symbol=None,
          source_module_name='%s.%s' % (output_package, src_module),
          source_name=src_name,
          dest_module_name='compat.v%d.%s' % (compat_api_version, src_module),
          dest_name=src_name)


def _get_name_and_module(full_name):
  """Split full_name into module and short name.

  Args:
    full_name: Full name of symbol that includes module.

  Returns:
    Full module name and short symbol name. The module name is '' when
    `full_name` contains no '.'.
  """
  module_name, _, short_name = full_name.rpartition('.')
  return module_name, short_name


def _join_modules(module1, module2):
  """Concatenate 2 module components.

  Args:
    module1: First module to join.
    module2: Second module to join.

  Returns:
    Given two modules aaa.bbb and ccc.ddd, returns a joined
    module aaa.bbb.ccc.ddd. If either component is empty, the other one is
    returned unchanged.
  """
  if not module1:
    return module2
  if not module2:
    return module1
  return '%s.%s' % (module1, module2)


def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix=''):
  """Add imports for the given symbol to `module_code_builder`.

  Args:
    module_code_builder: `_ModuleInitCodeBuilder` instance.
    symbol: A symbol.
    source_module_name: Module that we can import the symbol from.
    source_name: Name we can import the symbol with.
    api_name: API name. Currently, must be either `tensorflow` or `estimator`.
    api_version: API version.
    output_module_prefix: Prefix to prepend to destination module.
  """
  if api_version == 1:
    names_attr = API_ATTRS_V1[api_name].names
    constants_attr = API_ATTRS_V1[api_name].constants
  else:
    names_attr = API_ATTRS[api_name].names
    constants_attr = API_ATTRS[api_name].constants

  # If symbol is _tf_api_constants attribute, then add the constants.
  if source_name == constants_attr:
    for exports, name in symbol:
      for export in exports:
        dest_module, dest_name = _get_name_and_module(export)
        dest_module = _join_modules(output_module_prefix, dest_module)
        module_code_builder.add_import(None, source_module_name, name,
                                       dest_module, dest_name)

  # If symbol has _tf_api_names attribute, then add import for it.
  if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):

    # Generate import statements for symbols.
    for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
      dest_module, dest_name = _get_name_and_module(export)
      dest_module = _join_modules(output_module_prefix, dest_module)
      module_code_builder.add_import(symbol, source_module_name, source_name,
                                     dest_module, dest_name)


def get_api_init_text(packages,
                      packages_to_ignore,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None,
                      lazy_loading=_LAZY_LOADING,
                      use_relative_imports=False):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators. Empty entries are ignored.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when importing
      submodules.

  Returns:
    The tuple `(module_text_map, footer_text_map, root_module_footer)` from
    `_ModuleInitCodeBuilder.build`: module_text_map maps each destination
    module name (relative to `output_package`, '' for the root) to the import
    text for its __init__.py; footer_text_map maps destination modules to
    module-wrapper footer text; root_module_footer is the `__all__` code for
    the root module.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
                                               lazy_loading,
                                               use_relative_imports)

  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.

  def in_packages(m):
    return any(package in m for package in packages)

  # Snapshot sys.modules: getattr() below may trigger further imports.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, '__name__') or
        module.__name__ is None or not in_packages(module.__name__)):
      continue
    # Empty ignore entries are skipped; '' would otherwise match everything.
    if packages_to_ignore and any(
        p in module.__name__ for p in packages_to_ignore if p):
      continue

    # Do not generate __init__.py files for contrib modules for now.
    if (('.contrib.' in module.__name__ or module.__name__.endswith('.contrib'))
        and '.lite' not in module.__name__):
      continue

    for module_contents_name in dir(module):
      if (module.__name__ + '.' +
          module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      add_imports_for_symbol(module_code_builder, attr, module.__name__,
                             module_contents_name, api_name, api_version)
      for compat_api_version in compat_api_versions:
        add_imports_for_symbol(module_code_builder, attr, module.__name__,
                               module_contents_name, api_name,
                               compat_api_version,
                               _COMPAT_MODULE_TEMPLATE % compat_api_version)

  if compat_api_versions:
    add_nested_compat_imports(module_code_builder, compat_api_versions,
                              output_package)
  return module_code_builder.build()


def get_module(dir_path, relative_to_dir):
  """Get module that corresponds to path relative to relative_to_dir.

  Args:
    dir_path: Path to directory. Assumed to start with `relative_to_dir`.
    relative_to_dir: Get module relative to this directory.

  Returns:
    Name of module that corresponds to the given directory.
  """
  dir_path = dir_path[len(relative_to_dir):]
  # Convert path separators to '/' for easier parsing below.
  dir_path = dir_path.replace(os.sep, '/')
  return dir_path.replace('/', '.').strip('.')


def get_module_docstring(module_name, package, api_name):
  """Get docstring for the given module.

  This method looks for docstring in the following order:
  1. Checks if module has a docstring specified in doc_srcs.
  2. Checks if module has a docstring source module specified
     in doc_srcs. If it does, gets docstring from that module.
  3. Checks if module with module_name exists under base package.
     If it does, gets docstring from that module.
  4. Returns a default docstring.

  Args:
    module_name: module name relative to tensorflow (excluding 'tensorflow.'
      prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    Docstring to describe the module.
  """
  # Get the same module doc strings for any version. That is, for module
  # 'compat.v1.foo' we can get docstring from module 'foo'.
  for version in _API_VERSIONS:
    compat_prefix = _COMPAT_MODULE_TEMPLATE % version
    if module_name.startswith(compat_prefix):
      module_name = module_name[len(compat_prefix):].strip('.')

  # Module under base package to get a docstring from.
  docstring_module_name = module_name

  doc_sources = doc_srcs.get_doc_sources(api_name)

  if module_name in doc_sources:
    docsrc = doc_sources[module_name]
    if docsrc.docstring:
      return docsrc.docstring
    if docsrc.docstring_module_name:
      docstring_module_name = docsrc.docstring_module_name

  if package != 'keras':
    docstring_module_name = package + '.' + docstring_module_name
  if (docstring_module_name in sys.modules and
      sys.modules[docstring_module_name].__doc__):
    return sys.modules[docstring_module_name].__doc__

  return 'Public API for tf.%s namespace.' % module_name


def create_api_files(output_files,
                     packages,
                     packages_to_ignore,
                     root_init_template,
                     output_dir,
                     output_package,
                     api_name,
                     api_version,
                     compat_api_versions,
                     compat_init_templates,
                     lazy_loading=_LAZY_LOADING,
                     use_relative_imports=False):
  """Creates __init__.py files for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators.
    root_init_template: Template for top-level __init__.py file. "# API IMPORTS
      PLACEHOLDER" comment in the template file will be replaced with imports.
    output_dir: output API root directory.
    output_package: Base output package where generated API will be added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version to generate (1 or 2).
    compat_api_versions: Additional API versions to generate in compat/
      subdirectory.
    compat_init_templates: List of templates for top level compat init files in
      the same order as compat_api_versions.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when importing
      submodules.

  Raises:
    ValueError: if output_files list is missing a required file.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    module_name = get_module(os.path.dirname(output_file), output_dir)
    module_name_to_file_path[module_name] = os.path.normpath(output_file)

  # Create file for each expected output in genrule.
  for file_path in module_name_to_file_path.values():
    parent_dir = os.path.dirname(file_path)
    # os.makedirs('') raises, so skip files in the current directory.
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    open(file_path, 'a').close()

  (
      module_text_map,
      deprecation_footer_map,
      root_module_footer,
  ) = get_api_init_text(packages, packages_to_ignore, output_package, api_name,
                        api_version, compat_api_versions, lazy_loading,
                        use_relative_imports)

  # Add imports to output files.
  missing_output_files = []
  # Root modules are "" and "compat.v*".
  root_module = ''
  compat_module_to_template = {
      _COMPAT_MODULE_TEMPLATE % v: t
      for v, t in zip(compat_api_versions, compat_init_templates)
  }
  for v in compat_api_versions:
    compat_module_to_template.update({
        _SUBCOMPAT_MODULE_TEMPLATE % (v, vs): t
        for vs, t in zip(compat_api_versions, compat_init_templates)
    })

  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      module_file_path = '"%s/__init__.py"' % (module.replace('.', '/'))
      missing_output_files.append(module_file_path)
      continue

    contents = ''
    if module == root_module and root_init_template:
      # Read base init file for root module
      with open(root_init_template, 'r',
                encoding='utf-8') as root_init_template_file:
        contents = root_init_template_file.read()
        contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
        contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
    elif module in compat_module_to_template:
      # Read base init file for compat module
      with open(compat_module_to_template[module], 'r',
                encoding='utf-8') as init_template_file:
        contents = init_template_file.read()
        contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
    else:
      contents = (
          _GENERATED_FILE_HEADER %
          get_module_docstring(module, packages[0], api_name) + text +
          _GENERATED_FILE_FOOTER)
    if module in deprecation_footer_map:
      if '# WRAPPER_PLACEHOLDER' in contents:
        contents = contents.replace('# WRAPPER_PLACEHOLDER',
                                    deprecation_footer_map[module])
      else:
        contents += deprecation_footer_map[module]
    # Generated files are Python source; write UTF-8 explicitly because
    # module docstrings may contain non-ASCII characters.
    with open(module_name_to_file_path[module], 'w', encoding='utf-8') as fp:
      fp.write(contents)

  if missing_output_files:
    missing_files = ',\n'.join(sorted(missing_output_files))
    raise ValueError(
        f'Missing outputs for genrule:\n{missing_files}. Be sure to add these '
        'targets to tensorflow/python/tools/api/generator/api_init_files_v1.bzl'
        ' and tensorflow/python/tools/api/generator/api_init_files.bzl '
        '(tensorflow repo), keras/api/api_init_files.bzl (keras repo), or '
        'tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator '
        'repo)')


def _str_to_bool(value):
  """Parses a boolean command-line flag value.

  `argparse` with `type=bool` treats every non-empty string (including
  'False') as True, so boolean flags are parsed explicitly instead.

  Args:
    value: (string) Raw flag value.

  Returns:
    True for 'true'/'t'/'yes'/'y'/'on'/'1'; False for
    'false'/'f'/'no'/'n'/'off'/'0' or the empty string. Matching is
    case-insensitive and ignores surrounding whitespace.

  Raises:
    argparse.ArgumentTypeError: if `value` is not a recognized boolean.
  """
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      f'Expected a boolean value (true/false), got: {value!r}')


def main():
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs',
      metavar='O',
      type=str,
      nargs='+',
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--packages',
      default=_DEFAULT_PACKAGE,
      type=str,
      help='Base packages that import modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--packages_to_ignore',
      default='',
      type=str,
      help='Packages to exclude from the api generation. This is used to hide '
      'certain packages from this script when multiple copy of code exists, '
      'eg Keras. It is useful to avoid the SymbolExposedTwiceError.'
      )
  parser.add_argument(
      '--root_init_template',
      default='',
      type=str,
      help='Template for top level __init__.py file. '
      '"# API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir',
      type=str,
      required=True,
      help='Directory where generated output files are placed. '
      'gendir should be a prefix of apidir. Also, apidir '
      'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname',
      required=True,
      type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--apiversion',
      default=2,
      type=int,
      choices=_API_VERSIONS,
      help='The API version you want to generate.')
  parser.add_argument(
      '--compat_apiversions',
      default=[],
      type=int,
      action='append',
      help='Additional versions to generate in compat/ subdirectory. '
      'If set to 0, then no additional version would be generated.')
  parser.add_argument(
      '--compat_init_templates',
      default=[],
      type=str,
      action='append',
      help='Templates for top-level __init__ files under compat modules. '
      'The list of init file templates must be in the same order as '
      'list of versions passed with compat_apiversions.')
  parser.add_argument(
      '--output_package',
      default='tensorflow',
      type=str,
      help='Root output package.')
  parser.add_argument(
      '--loading',
      default='default',
      type=str,
      choices=['lazy', 'static', 'default'],
      help='Controls how the generated __init__.py file loads the exported '
      'symbols. \'lazy\' means the symbols are loaded when first used. '
      '\'static\' means all exported symbols are loaded in the '
      '__init__.py file. \'default\' uses the value of the '
      '_LAZY_LOADING constant in create_python_api.py.')
  parser.add_argument(
      '--use_relative_imports',
      default=False,
      # type=bool would turn '--use_relative_imports=False' into True.
      type=_str_to_bool,
      help='Whether to import submodules using relative imports or absolute '
      'imports')
  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs.
    with open(args.outputs[0], encoding='utf-8') as output_list_file:
      # Drop empty entries (e.g. from a trailing ';' or newline): an empty
      # path maps to the root module and could overwrite its output path.
      outputs = [
          path.strip()
          for path in output_list_file.read().split(';')
          if path.strip()
      ]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export().
  # Empty entries (e.g. from a trailing comma) are dropped: '' cannot be
  # imported and would match every module by substring.
  packages = [p for p in args.packages.split(',') if p]
  if not packages:
    parser.error('--packages must list at least one package.')
  for package in packages:
    importlib.import_module(package)
  packages_to_ignore = [p for p in args.packages_to_ignore.split(',') if p]

  # Determine if the modules shall be loaded lazily or statically.
  if args.loading == 'default':
    lazy_loading = _LAZY_LOADING
  elif args.loading == 'lazy':
    lazy_loading = True
  elif args.loading == 'static':
    lazy_loading = False
  else:
    # This should never happen (tm).
    raise ValueError(f'Invalid value for --loading flag: {args.loading}. Must '
                     'be one of lazy, static, default.')

  create_api_files(outputs, packages, packages_to_ignore,
                   args.root_init_template, args.apidir,
                   args.output_package, args.apiname, args.apiversion,
                   args.compat_apiversions, args.compat_init_templates,
                   lazy_loading, args.use_relative_imports)


if __name__ == '__main__':
  main()