• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Generates and prints out imports and constants for new TensorFlow python api."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import collections
22import importlib
23import os
24import sys
25
26from tensorflow.python.tools.api.generator import doc_srcs
27from tensorflow.python.util import tf_decorator
28from tensorflow.python.util import tf_export
29
# Per-API attribute tables from tf_export, indexed by API name (e.g.
# 'tensorflow'). Each entry's `.names` / `.constants` fields give the
# attribute names holding a symbol's export names and a module's exported
# constants.
API_ATTRS = tf_export.API_ATTRS
API_ATTRS_V1 = tf_export.API_ATTRS_V1

# Default for whether generated __init__.py files load symbols lazily.
_LAZY_LOADING = False
# API versions this generator can produce.
_API_VERSIONS = [1, 2]

# Module-name templates for compat.vN and nested compat.vN.compat.vK modules.
_COMPAT_MODULE_TEMPLATE = 'compat.v%d'
_SUBCOMPAT_MODULE_TEMPLATE = 'compat.v%d.compat.v%d'
_COMPAT_MODULE_PREFIX = 'compat.v'

# Default value of --packages.
_DEFAULT_PACKAGE = 'tensorflow.python'
# Currently unused in this file.
_GENFILES_DIR_SUFFIX = 'genfiles/'

# Fully qualified names whose values are never fetched and inspected.
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS',
}

# Header of a generated __init__.py; %s is the module docstring.
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

from __future__ import print_function as _print_function

import sys as _sys

"""
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'

# Wraps the generated module in TFModuleWrapper. Filled with: module name,
# public-API table name, deprecation flag, has_lite flag.
_DEPRECATION_FOOTER = """
from tensorflow.python.util import module_wrapper as _module_wrapper

if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
  _sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
      _sys.modules[__name__], "%s", public_apis=%s, deprecation=%s,
      has_lite=%s)
"""

# Module body used for lazy loading; %s is the newline-joined table entries.
_LAZY_LOADING_MODULE_TEXT_TEMPLATE = """
# Inform pytype that this module is dynamically populated (b/111239204).
_HAS_DYNAMIC_ATTRIBUTES = True
_PUBLIC_APIS = {
%s
}
"""
71
72
class SymbolExposedTwiceError(Exception):
  """Raised when two different symbols are exported under one API name."""
76
77
def get_canonical_import(import_set):
  """Obtain one single import from a set of possible sources of a symbol.

  One symbol might be reachable from several modules because it is imported
  and re-exported. To keep generated files stable, one import is always
  chosen deterministically: the candidate with the highest priority, ties
  broken by the alphabetically smallest import string.

  Args:
    import_set: (set) Imports providing the same symbol, as tuples of the
      form (import, priority).

  Returns:
    The chosen import string.
  """
  def _preference(candidate):
    # Negating the priority makes an ascending sort put higher priorities
    # first; the import string then orders equal priorities alphabetically.
    return (-candidate[1], candidate[0])

  ranked = sorted(import_set, key=_preference)
  best_candidate = ranked[0]
  return best_candidate[0]
103
104
class _ModuleInitCodeBuilder(object):
  """Builds a map from module name to imports included in that module."""

  def __init__(self,
               output_package,
               api_version,
               lazy_loading=_LAZY_LOADING,
               use_relative_imports=False):
    """Initializes the builder.

    Args:
      output_package: (string) Base output package where the generated API
        is placed; used as the import root for generated submodules.
      api_version: (int) API version being generated (1 or 2).
      lazy_loading: (bool) If True, modules get a `_PUBLIC_APIS` table of
        lazily loaded symbols instead of import statements.
      use_relative_imports: (bool) If True, statically imported generated
        submodules use `from . import ...`.
    """
    self._output_package = output_package
    # Maps API module to API symbol name to set of tuples of the form
    # (import statement, priority).
    # The same symbol can be imported from multiple locations. Higher
    # "priority" indicates that import location is preferred over others.
    self._module_imports = collections.defaultdict(
        lambda: collections.defaultdict(set))
    # Maps full API name to id() of the symbol exported under it; -1 when
    # the import has no symbol (submodules and constants).
    self._dest_import_to_id = collections.defaultdict(int)
    # Names that start with underscore in the root module, in first-seen
    # order and without duplicates.
    self._underscore_names_in_root = []
    self._api_version = api_version
    # Controls whether or not exported symbols are lazily loaded or statically
    # imported.
    self._lazy_loading = lazy_loading
    self._use_relative_imports = use_relative_imports

  def _check_already_imported(self, symbol_id, api_name):
    """Records `symbol_id` for `api_name`, rejecting conflicting exports.

    Args:
      symbol_id: id() of the exported symbol, or -1 when there is none.
      api_name: (string) Fully qualified API name being exported.

    Raises:
      SymbolExposedTwiceError: If `api_name` was already recorded for a
        different symbol. A `symbol_id` of -1 never raises.
    """
    if (api_name in self._dest_import_to_id and
        symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
      raise SymbolExposedTwiceError(
          'Trying to export multiple symbols with same name: %s.' % api_name)
    self._dest_import_to_id[api_name] = symbol_id

  def add_import(self, symbol, source_module_name, source_name,
                 dest_module_name, dest_name):
    """Adds this import to module_imports.

    Args:
      symbol: TensorFlow Python symbol, or None when the import is not tied
        to a symbol (generated submodules and constants).
      source_module_name: (string) Module to import from.
      source_name: (string) Name of the symbol to import.
      dest_module_name: (string) Module name to add import to.
      dest_name: (string) Import the symbol using this name.

    Raises:
      SymbolExposedTwiceError: Raised when an import of a different symbol
        with the same dest_name has already been added to dest_module_name.
    """
    # modules_with_exports.py is only used during API generation and
    # won't be available when actually importing tensorflow.
    if source_module_name.endswith('python.modules_with_exports'):
      source_module_name = symbol.__module__
    import_str = self.format_import(source_module_name, source_name, dest_name)

    # Check if we are trying to expose two different symbols with same name.
    full_api_name = dest_name
    if dest_module_name:
      full_api_name = dest_module_name + '.' + full_api_name
    # Compare with None instead of relying on truthiness: a falsy exported
    # object is still a real symbol (and truthiness may even raise).
    symbol_id = -1 if symbol is None else id(symbol)
    self._check_already_imported(symbol_id, full_api_name)

    # The same name is reported once per module that re-exports it, so only
    # record it the first time.
    if (not dest_module_name and dest_name.startswith('_') and
        dest_name not in self._underscore_names_in_root):
      self._underscore_names_in_root.append(dest_name)

    # The same symbol can be available in multiple modules.
    # We store all possible ways of importing this symbol and later pick just
    # one.
    priority = 0
    if symbol is not None:
      # Give higher priority to source module if it matches
      # symbol's original module.
      if hasattr(symbol, '__module__'):
        priority = int(source_module_name == symbol.__module__)
      # Give higher priority if symbol name matches its __name__.
      if hasattr(symbol, '__name__'):
        priority += int(source_name == symbol.__name__)
    self._module_imports[dest_module_name][full_api_name].add(
        (import_str, priority))

  def _import_submodules(self):
    """Add imports for all destination modules in self._module_imports."""
    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    imported_modules = set(self._module_imports.keys())
    for module in imported_modules:
      if not module:
        continue
      module_split = module.split('.')
      parent_module = ''  # we import submodules in their parent_module

      for submodule_index in range(len(module_split)):
        if submodule_index > 0:
          submodule = module_split[submodule_index - 1]
          parent_module += '.' + submodule if parent_module else submodule
        import_from = self._output_package
        if self._lazy_loading:
          import_from += '.' + '.'.join(module_split[:submodule_index + 1])
          self.add_import(
              symbol=None,
              source_module_name='',
              source_name=import_from,
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])
        else:
          if self._use_relative_imports:
            import_from = '.'
          elif submodule_index > 0:
            import_from += '.' + '.'.join(module_split[:submodule_index])
          self.add_import(
              symbol=None,
              source_module_name=import_from,
              source_name=module_split[submodule_index],
              dest_module_name=parent_module,
              dest_name=module_split[submodule_index])

  def build(self):
    """Get a map from destination module to __init__.py code for that module.

    Returns:
      A tuple `(module_text_map, footer_text_map, root_module_footer)`:
        module_text_map: dict mapping each destination module (e.g. '' for
          the root, 'compat.v1') to the import text for its __init__.py.
        footer_text_map: dict mapping destination module to its
          TFModuleWrapper footer text; only populated for API version 1 or
          when lazy loading.
        root_module_footer: (string) code defining `__all__` for the root
          module; empty when lazy loading.
    """
    self._import_submodules()
    module_text_map = {}
    footer_text_map = {}
    for dest_module, dest_name_to_imports in self._module_imports.items():
      # Sort all possible imports for a symbol and pick the first one.
      imports_list = [
          get_canonical_import(imports)
          for imports in dest_name_to_imports.values()
      ]
      imports_text = '\n'.join(sorted(imports_list))
      if self._lazy_loading:
        module_text_map[dest_module] = (
            _LAZY_LOADING_MODULE_TEXT_TEMPLATE % imports_text)
      else:
        module_text_map[dest_module] = imports_text

    # Expose exported symbols with underscores in root module since we import
    # from it using * import. Don't need this for lazy_loading because the
    # underscore symbols are already included in __all__ when passed in and
    # handled by TFModuleWrapper.
    root_module_footer = ''
    if not self._lazy_loading:
      underscore_names_str = ', '.join(
          '\'%s\'' % name for name in self._underscore_names_in_root)

      root_module_footer = """
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
""" % underscore_names_str

    # Add module wrapper if we need to print deprecation messages
    # or if we use lazy loading.
    if self._api_version == 1 or self._lazy_loading:
      public_apis_name = '_PUBLIC_APIS' if self._lazy_loading else 'None'
      for dest_module in self._module_imports:
        deprecation = 'False'
        has_lite = 'False'
        if self._api_version == 1:  # Add 1.* deprecations.
          if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
            deprecation = 'True'
        # Workaround to make sure not load lite from lite/__init__.py
        if (not dest_module and 'lite' in self._module_imports and
            self._lazy_loading):
          has_lite = 'True'
        footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
            dest_module, public_apis_name, deprecation, has_lite)

    return module_text_map, footer_text_map, root_module_footer

  def format_import(self, source_module_name, source_name, dest_name):
    """Formats import statement.

    Args:
      source_module_name: (string) Source module to import from.
      source_name: (string) Source symbol name to import.
      dest_name: (string) Destination alias name.

    Returns:
      An import statement string, or a `_PUBLIC_APIS` table entry when
      lazy loading.
    """
    if self._lazy_loading:
      return "  '%s': ('%s', '%s')," % (dest_name, source_module_name,
                                        source_name)
    else:
      if source_module_name:
        if source_name == dest_name:
          return 'from %s import %s' % (source_module_name, source_name)
        else:
          return 'from %s import %s as %s' % (source_module_name, source_name,
                                              dest_name)
      else:
        if source_name == dest_name:
          return 'import %s' % source_name
        else:
          return 'import %s as %s' % (source_name, dest_name)

  def get_destination_modules(self):
    """Returns a new set of all destination module names added so far."""
    return set(self._module_imports.keys())

  def copy_imports(self, from_dest_module, to_dest_module):
    """Replaces `to_dest_module`'s imports with a copy of another module's.

    Each per-name set of candidate imports is copied too, so later additions
    to one module never leak into the other. As before, a missing
    `from_dest_module` is created as an empty module entry.

    Args:
      from_dest_module: (string) Module to copy imports from.
      to_dest_module: (string) Module whose imports are replaced.
    """
    source_imports = self._module_imports[from_dest_module]
    copied_imports = collections.defaultdict(set)
    for api_name, imports in source_imports.items():
      copied_imports[api_name] = set(imports)
    self._module_imports[to_dest_module] = copied_imports
314
315
def add_nested_compat_imports(module_builder, compat_api_versions,
                              output_package):
  """Adds compat.vN.compat.vK modules to module builder.

  To avoid circular imports, __init__.py files are generated under both
  compat.vN.compat.vK and compat.vN.compat.vK.compat. Every other child
  module is imported from the corresponding module under compat.vK.

  Args:
    module_builder: `_ModuleInitCodeBuilder` instance.
    compat_api_versions: Supported compatibility versions.
    output_package: Base output python package where generated API will be
      added.
  """
  # Snapshot (a fresh set) taken before the copies below add new modules.
  existing_modules = module_builder.get_destination_modules()

  # compat.vK -> compat.vN.compat.vK and
  # compat.vK.compat -> compat.vN.compat.vK.compat.
  for outer_version in compat_api_versions:
    for inner_version in compat_api_versions:
      source_module = _COMPAT_MODULE_TEMPLATE % inner_version
      target_module = _SUBCOMPAT_MODULE_TEMPLATE % (outer_version,
                                                    inner_version)
      module_builder.copy_imports(source_module, target_module)
      module_builder.copy_imports(source_module + '.compat',
                                  target_module + '.compat')

  # Module prefixes such as "compat.v1.".
  compat_prefixes = tuple(
      '%s.' % (_COMPAT_MODULE_TEMPLATE % version)
      for version in compat_api_versions)

  # The copies above only carried symbol imports; child modules are wired up
  # here.
  for module_name in existing_modules:
    if not module_name.startswith(compat_prefixes):
      continue
    parts = module_name.split('.')

    if len(parts) > 3 and parts[2] == 'compat':
      # compat.vK.compat.foo: import foo into compat.vN.compat.vK.compat.
      src_module = '.'.join(parts[:3])
      src_name = parts[3]
      assert src_name not in ('v1', 'v2'), module_name
    else:
      # compat.vK.foo: import foo into compat.vN.compat.vK.
      src_module = '.'.join(parts[:2])
      src_name = parts[2]
      if src_name == 'compat':
        continue  # compat.vN.compat.vK.compat is handled separately

    source_module_name = '%s.%s' % (output_package, src_module)
    for version in compat_api_versions:
      module_builder.add_import(
          symbol=None,
          source_module_name=source_module_name,
          source_name=src_name,
          dest_module_name='compat.v%d.%s' % (version, src_module),
          dest_name=src_name)
372
373
def _get_name_and_module(full_name):
  """Splits a dotted name at its last dot.

  Args:
    full_name: Full name of symbol that includes module.

  Returns:
    A `(module, short_name)` tuple; `module` is '' when `full_name` has no
    dot.
  """
  module_part, _, short_part = full_name.rpartition('.')
  return module_part, short_part
385
386
def _join_modules(module1, module2):
  """Joins two dotted module paths, skipping an empty component.

  Args:
    module1: First module to join.
    module2: Second module to join.

  Returns:
    'aaa.bbb.ccc.ddd' for inputs 'aaa.bbb' and 'ccc.ddd'. If either input is
    empty, the other is returned unchanged.
  """
  if module1 and module2:
    return '%s.%s' % (module1, module2)
  return module1 or module2
403
404
def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix=''):
  """Registers every API export of `symbol` with `module_code_builder`.

  Two kinds of exports are handled:
    * When `source_name` is the API's constants attribute, `symbol` is
      iterated as (exports, name) pairs and each constant is imported from
      `source_module_name` under every one of its export names.
    * When `symbol` has the API's names attribute in its own `__dict__`, the
      symbol is imported under each listed export name.

  Args:
    module_code_builder: `_ModuleInitCodeBuilder` instance.
    symbol: A symbol.
    source_module_name: Module that we can import the symbol from.
    source_name: Name we can import the symbol with.
    api_name: API name. Currently, must be either `tensorflow` or `estimator`.
    api_version: API version.
    output_module_prefix: Prefix to prepend to destination module.

  Raises:
    SymbolExposedTwiceError: Propagated from the builder when different
      symbols are exported under the same name.
  """
  attrs_table = API_ATTRS_V1 if api_version == 1 else API_ATTRS
  api_attrs = attrs_table[api_name]
  names_attr = api_attrs.names
  constants_attr = api_attrs.constants

  def _register(export, exported_symbol, import_name):
    # Export names are dotted; the last segment is the attribute name.
    export_module, export_name = _get_name_and_module(export)
    module_code_builder.add_import(
        exported_symbol, source_module_name, import_name,
        _join_modules(output_module_prefix, export_module), export_name)

  if source_name == constants_attr:
    for exports, constant_name in symbol:
      for export in exports:
        _register(export, None, constant_name)

  # Only the symbol's own __dict__ is consulted, presumably so that a class
  # does not pick up export names defined on a base class.
  if names_attr in getattr(symbol, '__dict__', ()):
    for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
      _register(export, symbol, source_name)
448
449
def get_api_init_text(packages,
                      packages_to_ignore,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None,
                      lazy_loading=_LAZY_LOADING,
                      use_relative_imports=False):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators. Empty entries are ignored.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when importing
      submodules.

  Returns:
    A tuple `(module_text_map, footer_text_map, root_module_footer)`:
      module_text_map: dict mapping destination module (e.g. '' for the root,
        'compat.v1') to the import text for its __init__.py.
      footer_text_map: dict mapping destination module to its module-wrapper
        footer text; only populated for API version 1 or lazy loading.
      root_module_footer: (string) code defining `__all__` for the root
        module; empty when lazy loading.

  Raises:
    SymbolExposedTwiceError: If different symbols are exported under the
      same API name.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
                                               lazy_loading,
                                               use_relative_imports)

  def in_packages(module_name):
    return any(package in module_name for package in packages)

  def is_ignored(module_name):
    # Empty entries (e.g. from splitting an empty flag value) match nothing.
    return any(p and p in module_name for p in packages_to_ignore or ())

  # Traverse over everything imported so far; specifically, TensorFlow
  # Python modules. sys.modules is snapshotted because attribute access below
  # may import further modules and mutate it.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, '__name__') or
        module.__name__ is None or not in_packages(module.__name__)):
      continue
    module_name = module.__name__
    if is_ignored(module_name):
      continue

    # Do not generate __init__.py files for contrib modules for now.
    if (('.contrib.' in module_name or module_name.endswith('.contrib'))
        and '.lite' not in module_name):
      continue

    for module_contents_name in dir(module):
      if (module_name + '.' +
          module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      add_imports_for_symbol(module_code_builder, attr, module_name,
                             module_contents_name, api_name, api_version)
      for compat_api_version in compat_api_versions:
        add_imports_for_symbol(module_code_builder, attr, module_name,
                               module_contents_name, api_name,
                               compat_api_version,
                               _COMPAT_MODULE_TEMPLATE % compat_api_version)

  if compat_api_versions:
    add_nested_compat_imports(module_code_builder, compat_api_versions,
                              output_package)
  return module_code_builder.build()
527
528
def get_module(dir_path, relative_to_dir):
  """Converts a directory path into a dotted module name.

  `relative_to_dir` is assumed (not checked) to be a prefix of `dir_path`.
  That prefix is sliced off and the remaining path components are joined
  with dots.

  Args:
    dir_path: Path to directory.
    relative_to_dir: Get module relative to this directory.

  Returns:
    Name of module that corresponds to the given directory ('' for the
    directory itself).
  """
  remainder = dir_path[len(relative_to_dir):]
  # Normalize OS separators to '/' first so one substitution turns path
  # separators into module separators.
  dotted_name = remainder.replace(os.sep, '/').replace('/', '.')
  return dotted_name.strip('.')
543
544
def get_module_docstring(module_name, package, api_name):
  """Get docstring for the given module.

  Leading `compat.vN` prefixes are removed first (repeatedly, so nested
  modules like `compat.v2.compat.v1.compat` are handled regardless of the
  version order). Then this method looks for docstring in the following
  order:
  1. Checks if module has a docstring specified in doc_srcs.
  2. Checks if module has a docstring source module specified
     in doc_srcs. If it does, gets docstring from that module.
  3. Checks if module with module_name exists under base package.
     If it does, gets docstring from that module.
  4. Returns a default docstring.

  Args:
    module_name: module name relative to tensorflow (excluding 'tensorflow.'
      prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    One-line docstring to describe the module.
  """
  # Get the same module doc strings for any version. That is, for module
  # 'compat.v1.foo' we can get docstring from module 'foo'. A single pass in
  # _API_VERSIONS order would leave e.g. 'compat.v1.compat' behind for
  # 'compat.v2.compat.v1.compat', so keep stripping until nothing matches.
  # Each strip shortens module_name, so the loop terminates.
  compat_prefixes = [_COMPAT_MODULE_TEMPLATE % version
                     for version in _API_VERSIONS]
  stripped_prefix = True
  while stripped_prefix:
    stripped_prefix = False
    for compat_prefix in compat_prefixes:
      if module_name.startswith(compat_prefix):
        module_name = module_name[len(compat_prefix):].strip('.')
        stripped_prefix = True

  # Module under base package to get a docstring from.
  docstring_module_name = module_name

  doc_sources = doc_srcs.get_doc_sources(api_name)

  if module_name in doc_sources:
    docsrc = doc_sources[module_name]
    if docsrc.docstring:
      return docsrc.docstring
    if docsrc.docstring_module_name:
      docstring_module_name = docsrc.docstring_module_name

  docstring_module_name = package + '.' + docstring_module_name
  if (docstring_module_name in sys.modules and
      sys.modules[docstring_module_name].__doc__):
    return sys.modules[docstring_module_name].__doc__

  return 'Public API for tf.%s namespace.' % module_name
591
592
def create_api_files(output_files,
                     packages,
                     packages_to_ignore,
                     root_init_template,
                     output_dir,
                     output_package,
                     api_name,
                     api_version,
                     compat_api_versions,
                     compat_init_templates,
                     lazy_loading=_LAZY_LOADING,
                     use_relative_imports=False):
  """Creates __init__.py files for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
    packages: Base python packages containing python with target tf_export
      decorators.
    packages_to_ignore: python packages to be ignored when checking for
      tf_export decorators.
    root_init_template: Template for top-level __init__.py file. "# API IMPORTS
      PLACEHOLDER" comment in the template file will be replaced with imports.
    output_dir: output API root directory.
    output_package: Base output package where generated API will be added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version to generate (1 or 2).
    compat_api_versions: Additional API versions to generate in compat/
      subdirectory.
    compat_init_templates: List of templates for top level compat init files in
      the same order as compat_api_versions.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when import
      submodules.

  Raises:
    ValueError: if output_files list is missing a required file.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    module_name = get_module(os.path.dirname(output_file), output_dir)
    module_name_to_file_path[module_name] = os.path.normpath(output_file)

  # Create file for each expected output in genrule. Modules that receive no
  # generated text below keep an empty file.
  for file_path in module_name_to_file_path.values():
    parent_dir = os.path.dirname(file_path)
    # parent_dir is '' for a bare file name; os.makedirs('') would fail.
    if parent_dir and not os.path.isdir(parent_dir):
      os.makedirs(parent_dir)
    open(file_path, 'a').close()

  (
      module_text_map,
      deprecation_footer_map,
      root_module_footer,
  ) = get_api_init_text(packages, packages_to_ignore, output_package, api_name,
                        api_version, compat_api_versions, lazy_loading,
                        use_relative_imports)

  # Add imports to output files.
  missing_output_files = []
  # Root modules are "" and "compat.v*".
  root_module = ''
  compat_module_to_template = {
      _COMPAT_MODULE_TEMPLATE % v: t
      for v, t in zip(compat_api_versions, compat_init_templates)
  }
  # Nested compat.vN.compat.vK modules reuse the template of compat.vK.
  for v in compat_api_versions:
    compat_module_to_template.update({
        _SUBCOMPAT_MODULE_TEMPLATE % (v, vs): t
        for vs, t in zip(compat_api_versions, compat_init_templates)
    })

  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      module_dir = module.replace('.', '/')
      if module_dir:
        module_file_path = '"%s/__init__.py"' % module_dir
      else:
        module_file_path = '"__init__.py"'
      missing_output_files.append(module_file_path)
      continue

    if module == root_module and root_init_template:
      # Read base init file for root module
      with open(root_init_template, 'r') as root_init_template_file:
        contents = root_init_template_file.read()
      contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
      contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
    elif module in compat_module_to_template:
      # Read base init file for compat module
      with open(compat_module_to_template[module], 'r') as init_template_file:
        contents = init_template_file.read()
      contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
    else:
      contents = (
          _GENERATED_FILE_HEADER %
          get_module_docstring(module, packages[0], api_name) + text +
          _GENERATED_FILE_FOOTER)
    if module in deprecation_footer_map:
      if '# WRAPPER_PLACEHOLDER' in contents:
        contents = contents.replace('# WRAPPER_PLACEHOLDER',
                                    deprecation_footer_map[module])
      else:
        contents += deprecation_footer_map[module]
    with open(module_name_to_file_path[module], 'w') as fp:
      fp.write(contents)

  if missing_output_files:
    raise ValueError(
        """Missing outputs for genrule:\n%s. Be sure to add these targets to
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" %
        ',\n'.join(sorted(missing_output_files)))
704
705
def _parse_bool_flag(value):
  """Parses a boolean command-line flag value.

  `argparse` with `type=bool` turns every non-empty string (including
  'False') into True, so boolean flags need an explicit parser.

  Args:
    value: (string) Flag value as given on the command line.

  Returns:
    The parsed boolean; '' parses as False, matching `bool('')`.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a recognized boolean.
  """
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (true/false), got %r.' % value)


def main():
  """Parses command-line flags and writes the generated API __init__ files."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs',
      metavar='O',
      type=str,
      nargs='+',
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--packages',
      default=_DEFAULT_PACKAGE,
      type=str,
      help='Base packages that import modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--packages_to_ignore',
      default='',
      type=str,
      help='Packages to exclude from the api generation. This is used to hide '
      'certain packages from this script when multiple copy of code exists, '
      'eg Keras. It is useful to avoid the SymbolExposedTwiceError.'
      )
  parser.add_argument(
      '--root_init_template',
      default='',
      type=str,
      help='Template for top level __init__.py file. '
      '"# API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir',
      type=str,
      required=True,
      help='Directory where generated output files are placed. '
      'gendir should be a prefix of apidir. Also, apidir '
      'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname',
      required=True,
      type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--apiversion',
      default=2,
      type=int,
      choices=_API_VERSIONS,
      help='The API version you want to generate.')
  parser.add_argument(
      '--compat_apiversions',
      default=[],
      type=int,
      action='append',
      help='Additional versions to generate in compat/ subdirectory. '
      'If set to 0, then no additional version would be generated.')
  parser.add_argument(
      '--compat_init_templates',
      default=[],
      type=str,
      action='append',
      help='Templates for top-level __init__ files under compat modules. '
      'The list of init file templates must be in the same order as '
      'list of versions passed with compat_apiversions.')
  parser.add_argument(
      '--output_package',
      default='tensorflow',
      type=str,
      help='Root output package.')
  parser.add_argument(
      '--loading',
      default='default',
      type=str,
      choices=['lazy', 'static', 'default'],
      help='Controls how the generated __init__.py file loads the exported '
      'symbols. \'lazy\' means the symbols are loaded when first used. '
      '\'static\' means all exported symbols are loaded in the '
      '__init__.py file. \'default\' uses the value of the '
      '_LAZY_LOADING constant in create_python_api.py.')
  parser.add_argument(
      '--use_relative_imports',
      default=False,
      type=_parse_bool_flag,
      help='Whether to import submodules using relative imports or absolute '
      'imports')
  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs.
    with open(args.outputs[0]) as output_list_file:
      outputs = [line.strip() for line in output_list_file.read().split(';')]
    # Drop blanks left by stray or trailing separators: an empty path would
    # otherwise be mapped to the root module by get_module().
    outputs = [output for output in outputs if output]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export(). Blank entries
  # (e.g. from a trailing comma) are skipped; importing '' would fail.
  packages = [package for package in args.packages.split(',') if package]
  if not packages:
    parser.error('--packages must name at least one package.')
  for package in packages:
    importlib.import_module(package)
  packages_to_ignore = [
      package for package in args.packages_to_ignore.split(',') if package
  ]

  # Per the flag's help text, 0 means "no additional version".
  compat_api_versions = [
      version for version in args.compat_apiversions if version != 0
  ]

  # Determine if the modules shall be loaded lazily or statically.
  if args.loading == 'default':
    lazy_loading = _LAZY_LOADING
  elif args.loading == 'lazy':
    lazy_loading = True
  elif args.loading == 'static':
    lazy_loading = False
  else:
    # This should never happen (tm).
    raise ValueError('Invalid value for --loading flag: %s. Must be one of '
                     'lazy, static, default.' % args.loading)

  create_api_files(outputs, packages, packages_to_ignore,
                   args.root_init_template, args.apidir,
                   args.output_package, args.apiname, args.apiversion,
                   compat_api_versions, args.compat_init_templates,
                   lazy_loading, args.use_relative_imports)


if __name__ == '__main__':
  main()
829