• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Generates and prints out imports and constants for new TensorFlow python api."""
16import argparse
17import collections
18import importlib
19import os
20import sys
21
22from tensorflow.python.tools.api.generator import doc_srcs
23from tensorflow.python.util import tf_decorator
24from tensorflow.python.util import tf_export
25
# Per-API attribute-name records from tf_export, keyed by API name; each entry
# exposes `.names` and `.constants` (read in add_imports_for_symbol).
API_ATTRS = tf_export.API_ATTRS
# Same as API_ATTRS, but used when generating API version 1.
API_ATTRS_V1 = tf_export.API_ATTRS_V1

# Default for whether generated __init__.py files load symbols lazily
# (used when --loading=default).
_LAZY_LOADING = False
# API versions that can be generated (choices for --apiversion).
_API_VERSIONS = [1, 2]
# Destination module names of compat packages, e.g. 'compat.v1' and
# 'compat.v2.compat.v1'.
_COMPAT_MODULE_TEMPLATE = 'compat.v%d'
_SUBCOMPAT_MODULE_TEMPLATE = 'compat.v%d.compat.v%d'
# In API v1, destination modules starting with this prefix are not wrapped
# with deprecation=True.
_COMPAT_MODULE_PREFIX = 'compat.v'
# Default value of the --packages flag.
_DEFAULT_PACKAGE = 'tensorflow.python'
# NOTE(review): not referenced by the code visible in this file; confirm
# before removing.
_GENFILES_DIR_SUFFIX = 'genfiles/'
# Fully qualified module attributes that get_api_init_text never inspects.
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS'
}
# Header of generated (non-template) module files; %s is the module docstring.
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

import sys as _sys

"""
_GENERATED_FILE_FOOTER = ''
# Module-wrapper footer. The %s placeholders are, in order: the module name,
# the public_apis expression, the deprecation flag and the has_lite flag.
_DEPRECATION_FOOTER = """
from tensorflow.python.util import module_wrapper as _module_wrapper

if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
  _sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
      _sys.modules[__name__], "%s", public_apis=%s, deprecation=%s,
      has_lite=%s)
"""
# Body of a lazily-loaded module; %s is the newline-joined `_PUBLIC_APIS`
# entries produced by _ModuleInitCodeBuilder.format_import.
_LAZY_LOADING_MODULE_TEXT_TEMPLATE = """
# Inform pytype that this module is dynamically populated (b/111239204).
_HAS_DYNAMIC_ATTRIBUTES = True
_PUBLIC_APIS = {
%s
}
"""
65
66
class SymbolExposedTwiceError(Exception):
  """Signals that two distinct symbols were exported under one API name."""
70
71
def get_canonical_import(import_set):
  """Obtain one single import from a set of possible sources of a symbol.

  One symbol might come from multiple places as it is being imported and
  reexported. To keep the generated API stable, we always pick the same
  import for the same symbol: the one with the highest priority, with ties
  broken by alphabetical ordering of the import string.

  Args:
    import_set: (set) Imports providing the same symbol. This is a set of tuples
      in the form (import, priority). We want to pick an import with highest
      priority.

  Returns:
    The import string (as produced by `_ModuleInitCodeBuilder.format_import`)
    of the selected entry.

  Raises:
    IndexError: if `import_set` is empty.
  """
  if not import_set:
    raise IndexError('import_set must contain at least one import.')
  # A single `min` pass suffices: the key orders by descending priority and
  # then ascending import string, so the choice does not depend on set
  # iteration order.
  best_import, _ = min(
      import_set,
      key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0]))
  return best_import
97
98
class _ModuleInitCodeBuilder(object):
  """Collects API imports and renders the __init__.py text of each module.

  Imports are registered with `add_import`; `build` then picks one canonical
  import per API name and returns the generated text for every destination
  module.
  """

  def __init__(self,
               output_package,
               api_version,
               lazy_loading=_LAZY_LOADING,
               use_relative_imports=False):
    """Initializes an empty builder.

    Args:
      output_package: (string) Root output package that absolute submodule
        imports are rooted at.
      api_version: (int) API version being generated; version 1 adds
        deprecation module wrappers in `build`.
      lazy_loading: (bool) If True, emit `_PUBLIC_APIS` entries instead of
        static import statements.
      use_relative_imports: (bool) If True (and not lazy loading), submodules
        are imported from '.' instead of the absolute package path.
    """
    self._output_package = output_package
    # Maps API module to full API name to set of tuples of the form
    # (import string, priority).
    # The same symbol can be imported from multiple locations. Higher
    # "priority" indicates that import location is preferred over others.
    self._module_imports = collections.defaultdict(
        lambda: collections.defaultdict(set))
    # Maps full API name to id() of the exported object (-1 when none).
    self._dest_import_to_id = collections.defaultdict(int)
    # Names that start with underscore in the root module.
    self._underscore_names_in_root = set()
    self._api_version = api_version
    # Controls whether or not exported symbols are lazily loaded or statically
    # imported.
    self._lazy_loading = lazy_loading
    self._use_relative_imports = use_relative_imports

  def _check_already_imported(self, symbol_id, api_name):
    """Records `symbol_id` for `api_name`, raising on a conflicting export.

    Args:
      symbol_id: id() of the exported object, or -1 for imports without a
        backing object (constants and submodules).
      api_name: Full destination API name.

    Raises:
      SymbolExposedTwiceError: if `api_name` was already recorded with a
        different id and `symbol_id` is not -1.
    """
    # NOTE(review): an id of -1 never raises but still overwrites a recorded
    # id, so a later re-export of the original object under this name is
    # reported as a conflict.
    if (api_name in self._dest_import_to_id and
        symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
      raise SymbolExposedTwiceError(
          f'Trying to export multiple symbols with same name: {api_name}')
    self._dest_import_to_id[api_name] = symbol_id

  def add_import(self, symbol, source_module_name, source_name,
                 dest_module_name, dest_name):
    """Adds this import to module_imports.

    Args:
      symbol: TensorFlow Python symbol, or None for imports that have no
        backing object (constants and submodules).
      source_module_name: (string) Module to import from.
      source_name: (string) Name of the symbol to import.
      dest_module_name: (string) Module name to add import to.
      dest_name: (string) Import the symbol using this name.

    Raises:
      SymbolExposedTwiceError: Raised when an import with the same
        dest_name has already been added to dest_module_name.
    """
    # modules_with_exports.py is only used during API generation and
    # won't be available when actually importing tensorflow.
    if source_module_name.endswith('python.modules_with_exports'):
      source_module_name = symbol.__module__
    import_str = self.format_import(source_module_name, source_name, dest_name)

    # Check if we are trying to expose two different symbols with same name.
    full_api_name = dest_name
    if dest_module_name:
      full_api_name = dest_module_name + '.' + full_api_name
    # Compare against None instead of relying on truthiness: an exported
    # object may be falsy, or may not support bool() at all.
    symbol_id = -1 if symbol is None else id(symbol)
    self._check_already_imported(symbol_id, full_api_name)

    if not dest_module_name and dest_name.startswith('_'):
      self._underscore_names_in_root.add(dest_name)

    # The same symbol can be available in multiple modules.
    # We store all possible ways of importing this symbol and later pick just
    # one.
    priority = 0
    if symbol is not None:
      # Give higher priority to source module if it matches
      # symbol's original module.
      if hasattr(symbol, '__module__'):
        priority = int(source_module_name == symbol.__module__)
      # Give higher priority if symbol name matches its __name__.
      if hasattr(symbol, '__name__'):
        priority += int(source_name == symbol.__name__)
    self._module_imports[dest_module_name][full_api_name].add(
        (import_str, priority))

  def _import_submodules(self):
    """Add imports for all destination modules in self._module_imports.

    Every destination module is imported into its parent: for 'foo.bar',
    'foo' is imported into the root module and 'bar' into 'foo'.
    """
    # Snapshot the module names: add_import below may add new (parent)
    # destination modules to self._module_imports.
    for module in set(self._module_imports):
      if not module:
        continue
      module_split = module.split('.')
      parent_module = ''  # we import submodules in their parent_module

      for submodule_index, submodule_name in enumerate(module_split):
        if submodule_index > 0:
          previous = module_split[submodule_index - 1]
          parent_module = (
              parent_module + '.' + previous if parent_module else previous)
        if self._lazy_loading:
          import_from = (
              self._output_package + '.' +
              '.'.join(module_split[:submodule_index + 1]))
          self.add_import(
              symbol=None,
              source_module_name='',
              source_name=import_from,
              dest_module_name=parent_module,
              dest_name=submodule_name)
        else:
          if self._use_relative_imports:
            import_from = '.'
          elif submodule_index > 0:
            import_from = (
                self._output_package + '.' +
                '.'.join(module_split[:submodule_index]))
          else:
            import_from = self._output_package
          self.add_import(
              symbol=None,
              source_module_name=import_from,
              source_name=submodule_name,
              dest_module_name=parent_module,
              dest_name=submodule_name)

  def build(self):
    """Generates the text of every destination module.

    Returns:
      A tuple `(module_text_map, footer_text_map, root_module_footer)`:
        module_text_map: dict from destination module (for e.g. '' for the
          root, or 'compat.v1') to the import text for its __init__.py.
        footer_text_map: dict from destination module to its
          `_DEPRECATION_FOOTER` text; populated only for API version 1 or
          when lazy loading.
        root_module_footer: (string) `__all__` setup for the root module;
          empty when lazy loading.
    """
    self._import_submodules()
    module_text_map = {}
    footer_text_map = {}
    for dest_module, dest_name_to_imports in self._module_imports.items():
      # Pick one canonical import per API name; sort for stable output.
      imports_list = sorted(
          get_canonical_import(imports)
          for imports in dest_name_to_imports.values())
      imports_text = '\n'.join(imports_list)
      if self._lazy_loading:
        module_text_map[dest_module] = (
            _LAZY_LOADING_MODULE_TEXT_TEMPLATE % imports_text)
      else:
        module_text_map[dest_module] = imports_text

    # Expose exported symbols with underscores in root module since we import
    # from it using * import. Don't need this for lazy_loading because the
    # underscore symbols are already included in __all__ when passed in and
    # handled by TFModuleWrapper.
    root_module_footer = ''
    if not self._lazy_loading:
      underscore_names_str = ', '.join(
          '\'%s\'' % name for name in sorted(self._underscore_names_in_root))

      root_module_footer = """
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
""" % underscore_names_str

    # Add module wrapper if we need to print deprecation messages
    # or if we use lazy loading.
    if self._api_version == 1 or self._lazy_loading:
      public_apis_name = '_PUBLIC_APIS' if self._lazy_loading else 'None'
      for dest_module in self._module_imports:
        deprecation = 'False'
        # Add 1.* deprecations, except under the compat modules.
        if (self._api_version == 1 and
            not dest_module.startswith(_COMPAT_MODULE_PREFIX)):
          deprecation = 'True'
        # Workaround to make sure not load lite from lite/__init__.py
        has_lite = 'False'
        if (not dest_module and 'lite' in self._module_imports and
            self._lazy_loading):
          has_lite = 'True'
        footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
            dest_module, public_apis_name, deprecation, has_lite)

    return module_text_map, footer_text_map, root_module_footer

  def format_import(self, source_module_name, source_name, dest_name):
    """Formats import statement.

    Args:
      source_module_name: (string) Source module to import from.
      source_name: (string) Source symbol name to import.
      dest_name: (string) Destination alias name.

    Returns:
      An import statement string, or a `_PUBLIC_APIS` dict entry when lazy
      loading.
    """
    if self._lazy_loading:
      return "  '%s': ('%s', '%s')," % (dest_name, source_module_name,
                                        source_name)
    else:
      if source_module_name:
        if source_name == dest_name:
          return 'from %s import %s' % (source_module_name, source_name)
        else:
          return 'from %s import %s as %s' % (source_module_name, source_name,
                                              dest_name)
      else:
        if source_name == dest_name:
          return 'import %s' % source_name
        else:
          return 'import %s as %s' % (source_name, dest_name)

  def get_destination_modules(self):
    """Returns a snapshot set of all destination module names."""
    return set(self._module_imports)

  def copy_imports(self, from_dest_module, to_dest_module):
    """Replaces the imports of `to_dest_module` with a copy of another's.

    The per-name import sets are copied as well, so later additions to one
    module do not leak into the other.

    Args:
      from_dest_module: (string) Destination module to copy imports from.
      to_dest_module: (string) Destination module whose imports are replaced.
    """
    # Indexing the defaultdict creates an empty entry for a missing
    # `from_dest_module`, exactly as plain indexing did before.
    source_imports = self._module_imports[from_dest_module]
    copied_imports = collections.defaultdict(set)
    for api_name, imports in source_imports.items():
      copied_imports[api_name] = set(imports)
    self._module_imports[to_dest_module] = copied_imports
308
309
def add_nested_compat_imports(module_builder, compat_api_versions,
                              output_package):
  """Registers the compat.vN.compat.vK modules with `module_builder`.

  Only compat.vN.compat.vK and compat.vN.compat.vK.compat receive their own
  __init__.py content (copied from compat.vK and compat.vK.compat), which
  avoids circular imports. Every other child module is imported from the
  corresponding module under compat.vK.

  Args:
    module_builder: `_ModuleInitCodeBuilder` instance.
    compat_api_versions: Supported compatibility versions.
    output_package: Base output python package where generated API will be
      added.
  """
  # Snapshot taken before copying, so the loop at the bottom only visits
  # modules that existed beforehand.
  original_modules = module_builder.get_destination_modules()

  for outer_version in compat_api_versions:
    for inner_version in compat_api_versions:
      nested_module = _SUBCOMPAT_MODULE_TEMPLATE % (outer_version,
                                                     inner_version)
      source_module = _COMPAT_MODULE_TEMPLATE % inner_version
      module_builder.copy_imports(source_module, nested_module)
      module_builder.copy_imports(source_module + '.compat',
                                  nested_module + '.compat')

  # E.g. ('compat.v1.', 'compat.v2.').
  compat_prefixes = tuple(
      '%s.' % (_COMPAT_MODULE_TEMPLATE % version)
      for version in compat_api_versions)

  # The copies above only carry symbol/constant imports; child modules are
  # wired up here.
  for module_name in original_modules:
    if not module_name.startswith(compat_prefixes):
      continue
    parts = module_name.split('.')

    if len(parts) > 3 and parts[2] == 'compat':
      # compat.vK.compat.foo: import foo into compat.vN.compat.vK.compat.
      src_module = '.'.join(parts[:3])
      src_name = parts[3]
      assert src_name not in ('v1', 'v2'), module_name
    else:
      # compat.vK.foo: import foo into compat.vN.compat.vK.
      src_name = parts[2]
      if src_name == 'compat':
        continue  # compat.vN.compat.vK.compat is handled separately
      src_module = '.'.join(parts[:2])

    source_module_name = '%s.%s' % (output_package, src_module)
    for version in compat_api_versions:
      module_builder.add_import(
          symbol=None,
          source_module_name=source_module_name,
          source_name=src_name,
          dest_module_name='compat.v%d.%s' % (version, src_module),
          dest_name=src_name)
366
367
def _get_name_and_module(full_name):
  """Splits a dotted `full_name` at its last dot.

  Despite the function's name, the module part comes first in the result.

  Args:
    full_name: Full name of symbol that includes module.

  Returns:
    A `(module, name)` tuple: everything before the last '.', and the final
    segment. When `full_name` has no dot, the module part is ''.
  """
  module_part, _, short_name = full_name.rpartition('.')
  return module_part, short_name
379
380
def _join_modules(module1, module2):
  """Joins two dotted module paths, skipping an empty component.

  Args:
    module1: First module to join.
    module2: Second module to join.

  Returns:
    'aaa.bbb.ccc.ddd' for 'aaa.bbb' and 'ccc.ddd'. If `module1` is empty,
    `module2` is returned as is; otherwise, if `module2` is empty, `module1`
    is returned.
  """
  if module1 and module2:
    return f'{module1}.{module2}'
  return module1 or module2
397
398
def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix=''):
  """Registers every export of `symbol` with `module_code_builder`.

  Two kinds of exports are handled: a module's constants attribute (a list
  of `(exports, name)` pairs), and objects carrying the API names attribute
  in their own `__dict__`.

  Args:
    module_code_builder: `_ModuleInitCodeBuilder` instance.
    symbol: A symbol.
    source_module_name: Module that we can import the symbol from.
    source_name: Name we can import the symbol with.
    api_name: API name. Currently, must be either `tensorflow` or `estimator`.
    api_version: API version.
    output_module_prefix: Prefix to prepend to destination module.
  """
  api_attrs = API_ATTRS_V1 if api_version == 1 else API_ATTRS
  names_attr = api_attrs[api_name].names
  constants_attr = api_attrs[api_name].constants

  def _destination(export):
    # Split 'mod.sub.name' into the prefixed destination module and name.
    module_part, short_name = _get_name_and_module(export)
    return _join_modules(output_module_prefix, module_part), short_name

  # The constants attribute holds (exports, name) pairs for module constants.
  if source_name == constants_attr:
    for exports, constant_name in symbol:
      for export in exports:
        dest_module, dest_name = _destination(export)
        module_code_builder.add_import(None, source_module_name,
                                       constant_name, dest_module, dest_name)

  # Only objects that define the names attribute themselves (not inherited
  # through a class) are exported.
  if hasattr(symbol, '__dict__') and names_attr in symbol.__dict__:
    for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
      dest_module, dest_name = _destination(export)
      module_code_builder.add_import(symbol, source_module_name, source_name,
                                     dest_module, dest_name)
442
443
def get_api_init_text(packages,
                      packages_to_ignore,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None,
                      lazy_loading=_LAZY_LOADING,
                      use_relative_imports=False):
  """Get the generated __init__.py text for every destination API module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators. Empty entries are ignored.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when importing
      submodules.

  Returns:
    A tuple `(module_text_map, footer_text_map, root_module_footer)`:
      module_text_map: dict from destination module (for e.g. '' for the
        root, or 'compat.v1') to the import text for its __init__.py.
      footer_text_map: dict from destination module to its module-wrapper
        footer text.
      root_module_footer: (string) `__all__` setup text for the root module.

  Raises:
    SymbolExposedTwiceError: if different symbols are exported under the
      same API name.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
                                               lazy_loading,
                                               use_relative_imports)

  def in_packages(module_name):
    # NOTE: substring match (not a prefix match) against each base package.
    return any(package in module_name for package in packages)

  def is_ignored(module_name):
    # Empty entries (e.g. from splitting an empty flag value) never match.
    return any(ignored and ignored in module_name
               for ignored in packages_to_ignore or ())

  # Traverse over every loaded module. Iterate over a snapshot in case
  # sys.modules changes while module attributes are read below.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, '__name__') or
        module.__name__ is None or not in_packages(module.__name__)):
      continue
    if is_ignored(module.__name__):
      continue

    # Do not generate __init__.py files for contrib modules for now.
    if (('.contrib.' in module.__name__ or module.__name__.endswith('.contrib'))
        and '.lite' not in module.__name__):
      continue

    for module_contents_name in dir(module):
      if (module.__name__ + '.' +
          module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      add_imports_for_symbol(module_code_builder, attr, module.__name__,
                             module_contents_name, api_name, api_version)
      for compat_api_version in compat_api_versions:
        add_imports_for_symbol(module_code_builder, attr, module.__name__,
                               module_contents_name, api_name,
                               compat_api_version,
                               _COMPAT_MODULE_TEMPLATE % compat_api_version)

  if compat_api_versions:
    add_nested_compat_imports(module_code_builder, compat_api_versions,
                              output_package)
  return module_code_builder.build()
521
522
def get_module(dir_path, relative_to_dir):
  """Maps a directory path to its dotted module name.

  The first `len(relative_to_dir)` characters of `dir_path` are dropped (the
  caller is expected to pass a prefix), path separators become dots and any
  leading or trailing dots are removed.

  Args:
    dir_path: Path to directory.
    relative_to_dir: Get module relative to this directory.

  Returns:
    Name of module that corresponds to the given directory.
  """
  relative_part = dir_path[len(relative_to_dir):]
  # Normalize native separators to '/' first, then turn every '/' into '.'.
  dotted = relative_part.replace(os.sep, '/').replace('/', '.')
  return dotted.strip('.')
537
538
def get_module_docstring(module_name, package, api_name):
  """Get docstring for the given module.

  This method looks for docstring in the following order:
  1. Checks if module has a docstring specified in doc_srcs.
  2. Checks if module has a docstring source module specified
     in doc_srcs. If it does, gets docstring from that module.
  3. Checks if module with module_name exists under base package.
     If it does, gets docstring from that module.
  4. Returns a default docstring.

  Leading compat.vN components are removed first, so 'compat.v1.foo',
  'compat.v1.compat.v2.foo' and 'compat.v2.compat.v1.foo' all resolve to the
  docstring of 'foo'.

  Args:
    module_name: module name relative to tensorflow (excluding 'tensorflow.'
      prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    One-line docstring to describe the module.
  """
  # Strip compat prefixes repeatedly: a single ordered pass would leave
  # 'compat.v2.compat.v1.foo' as 'compat.v1.foo'.
  compat_prefixes = [
      _COMPAT_MODULE_TEMPLATE % version for version in _API_VERSIONS
  ]
  stripped_prefix = True
  while stripped_prefix:
    stripped_prefix = False
    for compat_prefix in compat_prefixes:
      # Match whole components only, so 'compat.v1' never matches
      # 'compat.v10'.
      if (module_name == compat_prefix or
          module_name.startswith(compat_prefix + '.')):
        module_name = module_name[len(compat_prefix):].strip('.')
        stripped_prefix = True

  # Module under base package to get a docstring from.
  docstring_module_name = module_name

  doc_sources = doc_srcs.get_doc_sources(api_name)

  if module_name in doc_sources:
    docsrc = doc_sources[module_name]
    if docsrc.docstring:
      return docsrc.docstring
    if docsrc.docstring_module_name:
      docstring_module_name = docsrc.docstring_module_name

  # NOTE(review): for package 'keras' the module name is looked up in
  # sys.modules without a package prefix; confirm this matches the Keras build.
  if package != 'keras':
    docstring_module_name = package + '.' + docstring_module_name
  if (docstring_module_name in sys.modules and
      sys.modules[docstring_module_name].__doc__):
    return sys.modules[docstring_module_name].__doc__

  return 'Public API for tf.%s namespace.' % module_name
586
587
def create_api_files(output_files,
                     packages,
                     packages_to_ignore,
                     root_init_template,
                     output_dir,
                     output_package,
                     api_name,
                     api_version,
                     compat_api_versions,
                     compat_init_templates,
                     lazy_loading=_LAZY_LOADING,
                     use_relative_imports=False):
  """Creates __init__.py files for the Python API.

  Every file in `output_files` is created (possibly empty). Files of modules
  that receive API imports are then written with their generated contents,
  encoded as UTF-8.

  Args:
    output_files: List of __init__.py file paths to create.
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators.
    root_init_template: Template for top-level __init__.py file. "# API IMPORTS
      PLACEHOLDER" comment in the template file will be replaced with imports,
      and "# __all__ PLACEHOLDER" with the root module footer.
    output_dir: output API root directory.
    output_package: Base output package where generated API will be added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version to generate (`v1` or `v2`).
    compat_api_versions: Additional API versions to generate in compat/
      subdirectory.
    compat_init_templates: List of templates for top level compat init files in
      the same order as compat_api_versions.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when import
      submodules.

  Raises:
    ValueError: if output_files list is missing a required file.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    module_name = get_module(os.path.dirname(output_file), output_dir)
    module_name_to_file_path[module_name] = os.path.normpath(output_file)

  # Create file for each expected output in genrule. Files of modules that
  # get no imports below are left empty.
  for file_path in module_name_to_file_path.values():
    parent_dir = os.path.dirname(file_path)
    # A bare file name has no directory component to create.
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    open(file_path, 'a').close()

  module_text_map, deprecation_footer_map, root_module_footer = (
      get_api_init_text(packages, packages_to_ignore, output_package,
                        api_name, api_version, compat_api_versions,
                        lazy_loading, use_relative_imports))

  # Root modules are "" and "compat.v*"; compat.vN and compat.vN.compat.vK
  # are rendered from the template of version K (or N for compat.vN).
  root_module = ''
  compat_module_to_template = {
      _COMPAT_MODULE_TEMPLATE % v: t
      for v, t in zip(compat_api_versions, compat_init_templates)
  }
  for v in compat_api_versions:
    compat_module_to_template.update({
        _SUBCOMPAT_MODULE_TEMPLATE % (v, vs): t
        for vs, t in zip(compat_api_versions, compat_init_templates)
    })

  missing_output_files = []
  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      missing_output_files.append('"%s/__init__.py"' %
                                  module.replace('.', '/'))
      continue

    if module == root_module and root_init_template:
      # Read base init file for root module.
      with open(root_init_template, 'r',
                encoding='utf-8') as root_init_template_file:
        contents = root_init_template_file.read()
      contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
      contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
    elif module in compat_module_to_template:
      # Read base init file for compat module.
      with open(compat_module_to_template[module], 'r',
                encoding='utf-8') as init_template_file:
        contents = init_template_file.read()
      contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
    else:
      contents = (
          _GENERATED_FILE_HEADER %
          get_module_docstring(module, packages[0], api_name) + text +
          _GENERATED_FILE_FOOTER)
    if module in deprecation_footer_map:
      if '# WRAPPER_PLACEHOLDER' in contents:
        contents = contents.replace('# WRAPPER_PLACEHOLDER',
                                    deprecation_footer_map[module])
      else:
        contents += deprecation_footer_map[module]
    # Generated files are Python source, whose default encoding is UTF-8.
    with open(module_name_to_file_path[module], 'w', encoding='utf-8') as fp:
      fp.write(contents)

  if missing_output_files:
    missing_files = ',\n'.join(sorted(missing_output_files))
    raise ValueError(
        f'Missing outputs for genrule:\n{missing_files}. Be sure to add these '
        'targets to tensorflow/python/tools/api/generator/api_init_files_v1.bzl'
        ' and tensorflow/python/tools/api/generator/api_init_files.bzl '
        '(tensorflow repo), keras/api/api_init_files.bzl (keras repo), or '
        'tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator '
        'repo)')
701
702
def _parse_bool_flag(value):
  """Parses a boolean command-line value such as 'True' or 'false'.

  argparse with `type=bool` treats every non-empty string, including
  'False', as True, so boolean flag values are parsed explicitly.

  Args:
    value: (string) Raw flag value.

  Returns:
    The parsed boolean. An empty string parses as False.

  Raises:
    argparse.ArgumentTypeError: if `value` is not a recognized boolean.
  """
  normalized = value.strip().lower()
  if normalized in ('', '0', 'false', 'f', 'no', 'n', 'off'):
    return False
  if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
    return True
  raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


def main():
  """Parses command-line flags and writes the generated API files."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs',
      metavar='O',
      type=str,
      nargs='+',
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--packages',
      default=_DEFAULT_PACKAGE,
      type=str,
      help='Base packages that import modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--packages_to_ignore',
      default='',
      type=str,
      help='Packages to exclude from the api generation. This is used to hide '
      'certain packages from this script when multiple copy of code exists, '
      'eg Keras. It is useful to avoid the SymbolExposedTwiceError.'
      )
  parser.add_argument(
      '--root_init_template',
      default='',
      type=str,
      help='Template for top level __init__.py file. '
      '"# API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir',
      type=str,
      required=True,
      help='Directory where generated output files are placed. '
      'gendir should be a prefix of apidir. Also, apidir '
      'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname',
      required=True,
      type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--apiversion',
      default=2,
      type=int,
      choices=_API_VERSIONS,
      help='The API version you want to generate.')
  # NOTE(review): the help text says 0 disables compat generation, but a 0
  # value is passed through to create_api_files unchanged; confirm intent.
  parser.add_argument(
      '--compat_apiversions',
      default=[],
      type=int,
      action='append',
      help='Additional versions to generate in compat/ subdirectory. '
      'If set to 0, then no additional version would be generated.')
  parser.add_argument(
      '--compat_init_templates',
      default=[],
      type=str,
      action='append',
      help='Templates for top-level __init__ files under compat modules. '
      'The list of init file templates must be in the same order as '
      'list of versions passed with compat_apiversions.')
  parser.add_argument(
      '--output_package',
      default='tensorflow',
      type=str,
      help='Root output package.')
  parser.add_argument(
      '--loading',
      default='default',
      type=str,
      choices=['lazy', 'static', 'default'],
      help='Controls how the generated __init__.py file loads the exported '
      'symbols. \'lazy\' means the symbols are loaded when first used. '
      '\'static\' means all exported symbols are loaded in the '
      '__init__.py file. \'default\' uses the value of the '
      '_LAZY_LOADING constant in create_python_api.py.')
  parser.add_argument(
      '--use_relative_imports',
      default=False,
      type=_parse_bool_flag,
      help='Whether to import submodules using relative imports or absolute '
      'imports')
  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs.
    with open(args.outputs[0]) as output_list_file:
      outputs = [line.strip() for line in output_list_file.read().split(';')]
    # Drop empty entries (e.g. from a trailing ';' or newline); an empty path
    # would otherwise be mapped to the root module's output file.
    outputs = [output for output in outputs if output]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export(). Empty entries
  # (e.g. from a trailing comma) are skipped.
  packages = [p.strip() for p in args.packages.split(',') if p.strip()]
  if not packages:
    parser.error('--packages must name at least one package.')
  for package in packages:
    importlib.import_module(package)
  packages_to_ignore = [
      p.strip() for p in args.packages_to_ignore.split(',') if p.strip()
  ]

  # Determine if the modules shall be loaded lazily or statically.
  if args.loading == 'default':
    lazy_loading = _LAZY_LOADING
  elif args.loading == 'lazy':
    lazy_loading = True
  elif args.loading == 'static':
    lazy_loading = False
  else:
    # This should never happen (tm).
    raise ValueError(f'Invalid value for --loading flag: {args.loading}. Must '
                     'be one of lazy, static, default.')

  create_api_files(outputs, packages, packages_to_ignore,
                   args.root_init_template, args.apidir,
                   args.output_package, args.apiname, args.apiversion,
                   args.compat_apiversions, args.compat_init_templates,
                   lazy_loading, args.use_relative_imports)


if __name__ == '__main__':
  main()
826