• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Generates and prints out imports and constants for new TensorFlow python api.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import collections
23import importlib
24import os
25import sys
26
27from tensorflow.python.tools.api.generator import doc_srcs
28from tensorflow.python.util import tf_decorator
29from tensorflow.python.util import tf_export
30
# Attribute-name tables re-exported from tf_export. Both are indexed by API
# name below and expose `.names` and `.constants`; API_ATTRS_V1 is used when
# the requested API version is 1, API_ATTRS otherwise.
API_ATTRS = tf_export.API_ATTRS
API_ATTRS_V1 = tf_export.API_ATTRS_V1

# Allowed values of the --apiversion flag; also the versions whose compat
# prefix is stripped when looking up module docstrings.
_API_VERSIONS = [1, 2]
# Name of the compat submodule for a given API version (formatted with an int).
_COMPAT_MODULE_TEMPLATE = 'compat.v%d'
# Default value of the --packages flag.
_DEFAULT_PACKAGE = 'tensorflow.python'
# NOTE(review): not referenced anywhere in the visible code -- confirm whether
# it is still needed.
_GENFILES_DIR_SUFFIX = 'genfiles/'
# Fully qualified names that are never inspected while scanning modules.
_SYMBOLS_TO_SKIP_EXPLICITLY = {
    # Overrides __getattr__, so that unwrapping tf_decorator
    # would have side effects.
    'tensorflow.python.platform.flags.FLAGS'
}
# Header of generated __init__.py files that are not built from a template.
# The single %s is filled with the module docstring.
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

from __future__ import print_function as _print_function

"""
# Footer of those same files: removes the alias imported by the header.
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'
52
53
class SymbolExposedTwiceError(Exception):
  """Signals that two different symbols were exported under one API name."""
57
58
def format_import(source_module_name, source_name, dest_name):
  """Renders a single Python import statement.

  Args:
    source_module_name: (string) Module to import from. When empty, a plain
      `import ...` statement is produced instead of `from ... import ...`.
    source_name: (string) Name of the symbol being imported.
    dest_name: (string) Name the symbol should be bound to. An `as` clause is
      emitted only when it differs from `source_name`.

  Returns:
    The import statement as a string.
  """
  # Start from the bare import and decorate it with the optional parts.
  statement = 'import %s' % source_name
  if source_module_name:
    statement = 'from %s %s' % (source_module_name, statement)
  if source_name != dest_name:
    statement = '%s as %s' % (statement, dest_name)
  return statement
81
82
83class _ModuleInitCodeBuilder(object):
84  """Builds a map from module name to imports included in that module."""
85
86  def __init__(self, output_package):
87    self._output_package = output_package
88    self._module_imports = collections.defaultdict(
89        lambda: collections.defaultdict(set))
90    self._dest_import_to_id = collections.defaultdict(int)
91    # Names that start with underscore in the root module.
92    self._underscore_names_in_root = []
93
94  def add_import(
95      self, symbol_id, dest_module_name, source_module_name, source_name,
96      dest_name):
97    """Adds this import to module_imports.
98
99    Args:
100      symbol_id: (number) Unique identifier of the symbol to import.
101      dest_module_name: (string) Module name to add import to.
102      source_module_name: (string) Module to import from.
103      source_name: (string) Name of the symbol to import.
104      dest_name: (string) Import the symbol using this name.
105
106    Raises:
107      SymbolExposedTwiceError: Raised when an import with the same
108        dest_name has already been added to dest_module_name.
109    """
110    import_str = format_import(source_module_name, source_name, dest_name)
111
112    # Check if we are trying to expose two different symbols with same name.
113    full_api_name = dest_name
114    if dest_module_name:
115      full_api_name = dest_module_name + '.' + full_api_name
116    if (full_api_name in self._dest_import_to_id and
117        symbol_id != self._dest_import_to_id[full_api_name] and
118        symbol_id != -1):
119      raise SymbolExposedTwiceError(
120          'Trying to export multiple symbols with same name: %s.' %
121          full_api_name)
122    self._dest_import_to_id[full_api_name] = symbol_id
123
124    if not dest_module_name and dest_name.startswith('_'):
125      self._underscore_names_in_root.append(dest_name)
126
127    # The same symbol can be available in multiple modules.
128    # We store all possible ways of importing this symbol and later pick just
129    # one.
130    self._module_imports[dest_module_name][full_api_name].add(import_str)
131
132  def _import_submodules(self):
133    """Add imports for all destination modules in self._module_imports."""
134    # Import all required modules in their parent modules.
135    # For e.g. if we import 'foo.bar.Value'. Then, we also
136    # import 'bar' in 'foo'.
137    imported_modules = set(self._module_imports.keys())
138    for module in imported_modules:
139      if not module:
140        continue
141      module_split = module.split('.')
142      parent_module = ''  # we import submodules in their parent_module
143
144      for submodule_index in range(len(module_split)):
145        if submodule_index > 0:
146          submodule = module_split[submodule_index-1]
147          parent_module += '.' + submodule if parent_module else submodule
148        import_from = self._output_package
149        if submodule_index > 0:
150          import_from += '.' + '.'.join(module_split[:submodule_index])
151        self.add_import(
152            -1, parent_module, import_from,
153            module_split[submodule_index], module_split[submodule_index])
154
155  def build(self):
156    """Get a map from destination module to __init__.py code for that module.
157
158    Returns:
159      A dictionary where
160        key: (string) destination module (for e.g. tf or tf.consts).
161        value: (string) text that should be in __init__.py files for
162          corresponding modules.
163    """
164    self._import_submodules()
165    module_text_map = {}
166    for dest_module, dest_name_to_imports in self._module_imports.items():
167      # Sort all possible imports for a symbol and pick the first one.
168      imports_list = [
169          sorted(imports)[0]
170          for _, imports in dest_name_to_imports.items()]
171      module_text_map[dest_module] = '\n'.join(sorted(imports_list))
172
173    # Expose exported symbols with underscores in root module
174    # since we import from it using * import.
175    underscore_names_str = ', '.join(
176        '\'%s\'' % name for name in self._underscore_names_in_root)
177    # We will always generate a root __init__.py file to let us handle *
178    # imports consistently. Be sure to have a root __init__.py file listed in
179    # the script outputs.
180    module_text_map[''] = module_text_map.get('', '') + '''
181_names_with_underscore = [%s]
182__all__ = [_s for _s in dir() if not _s.startswith('_')]
183__all__.extend([_s for _s in _names_with_underscore])
184''' % underscore_names_str
185
186    return module_text_map
187
188
189def _get_name_and_module(full_name):
190  """Split full_name into module and short name.
191
192  Args:
193    full_name: Full name of symbol that includes module.
194
195  Returns:
196    Full module name and short symbol name.
197  """
198  name_segments = full_name.split('.')
199  return '.'.join(name_segments[:-1]), name_segments[-1]
200
201
202def _join_modules(module1, module2):
203  """Concatenate 2 module components.
204
205  Args:
206    module1: First module to join.
207    module2: Second module to join.
208
209  Returns:
210    Given two modules aaa.bbb and ccc.ddd, returns a joined
211    module aaa.bbb.ccc.ddd.
212  """
213  if not module1:
214    return module2
215  if not module2:
216    return module1
217  return '%s.%s' % (module1, module2)
218
219
def add_imports_for_symbol(
    module_code_builder,
    symbol,
    source_module_name,
    source_name,
    api_name,
    api_version,
    output_module_prefix=''):
  """Registers every export of `symbol` with `module_code_builder`.

  Args:
    module_code_builder: `_ModuleInitCodeBuilder` instance.
    symbol: A symbol.
    source_module_name: Module that we can import the symbol from.
    source_name: Name we can import the symbol with.
    api_name: API name. Currently, must be either `tensorflow` or `estimator`.
    api_version: API version.
    output_module_prefix: Prefix to prepend to destination module.
  """
  # Version 1 has its own attribute table; every other version shares one.
  attrs_by_api = API_ATTRS_V1 if api_version == 1 else API_ATTRS
  api_attrs = attrs_by_api[api_name]
  names_attr = api_attrs.names
  constants_attr = api_attrs.constants

  def _register(symbol_id, export, import_name):
    # Splits the exported name and records the import under the prefix.
    dest_module, dest_name = _get_name_and_module(export)
    module_code_builder.add_import(
        symbol_id, _join_modules(output_module_prefix, dest_module),
        source_module_name, import_name, dest_name)

  # The constants attribute holds (exports, name) pairs; register each export
  # without a symbol id (-1).
  if source_name == constants_attr:
    for exports, name in symbol:
      for export in exports:
        _register(-1, export, name)

  # A symbol that carries the names attribute in its own __dict__ lists the
  # API names it is exported under.
  if hasattr(symbol, '__dict__') and names_attr in symbol.__dict__:
    for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
      _register(id(symbol), export, source_name)
262
263
def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None):
  """Collects the __init__.py import text for every destination module.

  Scans the modules currently present in `sys.modules` and registers the
  exports found on their attributes.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  builder = _ModuleInitCodeBuilder(output_package)

  def _should_scan(module):
    # Only named modules whose name contains one of `packages` are scanned.
    if not module or not hasattr(module, '__name__'):
      return False
    name = module.__name__
    if name is None or not any(package in name for package in packages):
      return False
    # Do not generate __init__.py files for contrib modules for now
    # (names containing '.lite' are still scanned).
    is_contrib = '.contrib.' in name or name.endswith('.contrib')
    return not (is_contrib and '.lite' not in name)

  # Iterate over a copy of the values: sys.modules may change during the scan.
  for module in list(sys.modules.values()):
    if not _should_scan(module):
      continue
    module_name = module.__name__

    for attr_name in dir(module):
      qualified_name = module_name + '.' + attr_name
      if qualified_name in _SYMBOLS_TO_SKIP_EXPLICITLY:
        continue
      _, symbol = tf_decorator.unwrap(getattr(module, attr_name))

      add_imports_for_symbol(
          builder, symbol, module_name, attr_name, api_name, api_version)
      # Register the same symbol again under each compat.vN prefix.
      for compat_version in compat_api_versions:
        add_imports_for_symbol(
            builder, symbol, module_name, attr_name, api_name,
            compat_version, _COMPAT_MODULE_TEMPLATE % compat_version)

  return builder.build()
323
324
def get_module(dir_path, relative_to_dir):
  """Converts a directory path into a dotted module name.

  Args:
    dir_path: Path to directory.
    relative_to_dir: Get module relative to this directory. The first
      `len(relative_to_dir)` characters of `dir_path` are dropped.

  Returns:
    Name of module that corresponds to the given directory.
  """
  relative_path = dir_path[len(relative_to_dir):]
  # Normalise native separators to '/' so that both kinds become dots.
  segments = relative_path.replace(os.sep, '/').split('/')
  return '.'.join(segments).strip('.')
339
340
def get_module_docstring(module_name, package, api_name):
  """Picks the docstring to put in a generated module.

  Sources are tried in this order:
  1. A docstring given directly in doc_srcs for the module.
  2. The docstring of the source module that doc_srcs names for the module.
  3. The docstring of the module with the same name under `package`.
  4. A default docstring.

  Args:
    module_name: module name relative to tensorflow
      (excluding 'tensorflow.' prefix) to get a docstring for.
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    One-line docstring to describe the module.
  """
  # Docstrings are shared across versions: 'compat.v1.foo' uses the docstring
  # of 'foo', so drop any compat prefix first.
  for version in _API_VERSIONS:
    prefix = _COMPAT_MODULE_TEMPLATE % version
    if module_name.startswith(prefix):
      module_name = module_name[len(prefix):].strip('.')

  # Module under the base package whose docstring is used as a fallback.
  fallback_module = module_name

  doc_sources = doc_srcs.get_doc_sources(api_name)
  if module_name in doc_sources:
    doc_source = doc_sources[module_name]
    if doc_source.docstring:
      return doc_source.docstring
    fallback_module = doc_source.docstring_module_name or module_name

  qualified_name = package + '.' + fallback_module
  if qualified_name in sys.modules:
    module_doc = sys.modules[qualified_name].__doc__
    if module_doc:
      return module_doc

  return 'Public API for tf.%s namespace.' % module_name
387
388
def create_api_files(output_files, packages, root_init_template, output_dir,
                     output_package, api_name, api_version,
                     compat_api_versions, compat_init_templates):
  """Writes the generated __init__.py files of the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
    packages: Base python packages containing python with target tf_export
      decorators.
    root_init_template: Template for top-level __init__.py file.
      "# API IMPORTS PLACEHOLDER" comment in the template file will be replaced
      with imports.
    output_dir: output API root directory.
    output_package: Base output package where generated API will be added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version to generate; forwarded to `get_api_init_text`.
    compat_api_versions: Additional API versions to generate in compat/
      subdirectory.
    compat_init_templates: List of templates for top level compat init files
      in the same order as compat_api_versions.

  Raises:
    ValueError: if output_files list is missing a required file.
  """
  module_name_to_file_path = {
      get_module(os.path.dirname(output_file), output_dir):
          os.path.normpath(output_file)
      for output_file in output_files
  }

  # Create every expected genrule output up front, even those that end up
  # receiving no imports.
  for file_path in module_name_to_file_path.values():
    parent_dir = os.path.dirname(file_path)
    if not os.path.isdir(parent_dir):
      os.makedirs(parent_dir)
    open(file_path, 'a').close()

  module_text_map = get_api_init_text(packages, output_package, api_name,
                                      api_version, compat_api_versions)

  def _fill_template(template_path, imports_text):
    # Reads a template and substitutes the imports for its placeholder.
    with open(template_path, 'r') as template_file:
      template = template_file.read()
    return template.replace('# API IMPORTS PLACEHOLDER', imports_text)

  # Root modules are "" and "compat.v*".
  root_module = ''
  compat_module_to_template = {
      _COMPAT_MODULE_TEMPLATE % version: template
      for version, template in zip(compat_api_versions, compat_init_templates)
  }

  missing_output_files = []
  for module, text in module_text_map.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      missing_output_files.append(
          '"%s/__init__.py"' % module.replace('.', '/'))
      continue

    if module == root_module and root_init_template:
      contents = _fill_template(root_init_template, text)
    elif module in compat_module_to_template:
      contents = _fill_template(compat_module_to_template[module], text)
    else:
      # No template applies: wrap the imports in the generated header/footer.
      docstring = get_module_docstring(module, packages[0], api_name)
      contents = (
          (_GENERATED_FILE_HEADER % docstring) + text + _GENERATED_FILE_FOOTER)

    with open(module_name_to_file_path[module], 'w') as fp:
      fp.write(contents)

  if not missing_output_files:
    return
  raise ValueError(
      """Missing outputs for genrule:\n%s. Be sure to add these targets to
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)"""
      % ',\n'.join(sorted(missing_output_files)))
469
470
def main():
  """Parses command-line flags and generates the API __init__.py files.

  The positional `outputs` argument is either the list of files to generate
  or a single file that contains that list separated by semicolons.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      help='If a single file is passed in, then we assume it contains a '
      'semicolon-separated list of Python files that we expect this script to '
      'output. If multiple files are passed in, then we assume output files '
      'are listed directly as arguments.')
  parser.add_argument(
      '--packages',
      default=_DEFAULT_PACKAGE,
      type=str,
      help='Base packages that import modules containing the target tf_export '
      'decorators.')
  parser.add_argument(
      '--root_init_template', default='', type=str,
      help='Template for top level __init__.py file. '
           '"# API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
  parser.add_argument(
      '--apidir', type=str, required=True,
      help='Directory where generated output files are placed. '
           'gendir should be a prefix of apidir. Also, apidir '
           'should be a prefix of every directory in outputs.')
  parser.add_argument(
      '--apiname', required=True, type=str,
      choices=API_ATTRS.keys(),
      help='The API you want to generate.')
  parser.add_argument(
      '--apiversion', default=2, type=int,
      choices=_API_VERSIONS,
      help='The API version you want to generate.')
  # NOTE(review): the help text promises special handling of 0, but nothing in
  # this function filters it out -- confirm the intended behavior.
  parser.add_argument(
      '--compat_apiversions', default=[], type=int, action='append',
      help='Additional versions to generate in compat/ subdirectory. '
           'If set to 0, then no additional version would be generated.')
  parser.add_argument(
      '--compat_init_templates', default=[], type=str, action='append',
      help='Templates for top-level __init__ files under compat modules. '
           'The list of init file templates must be in the same order as '
           'list of versions passed with compat_apiversions.')
  parser.add_argument(
      '--output_package', default='tensorflow', type=str,
      help='Root output package.')
  args = parser.parse_args()

  if len(args.outputs) == 1:
    # If we only get a single argument, then it must be a file containing
    # list of outputs.
    with open(args.outputs[0]) as output_list_file:
      entries = output_list_file.read().split(';')
    # Drop blank entries (e.g. after a trailing ';'): an empty path would be
    # treated as the root module's output file.
    outputs = [entry.strip() for entry in entries if entry.strip()]
  else:
    outputs = args.outputs

  # Populate `sys.modules` with modules containing tf_export().
  packages = args.packages.split(',')
  for package in packages:
    importlib.import_module(package)
  create_api_files(outputs, packages, args.root_init_template, args.apidir,
                   args.output_package, args.apiname, args.apiversion,
                   args.compat_apiversions, args.compat_init_templates)


if __name__ == '__main__':
  main()
535