• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# lint as: python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""A tool to generate api_docs for TensorFlow2.
17
18```
19python generate2.py --output_dir=/tmp/out
20```
21
22Requires a local installation of `tensorflow_docs`:
23
24```
25pip install git+https://github.com/tensorflow/docs
26```
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import pathlib
34import textwrap
35
36from absl import app
37from absl import flags
38
39import tensorflow as tf
40
41from tensorflow_docs.api_generator import doc_controls
42from tensorflow_docs.api_generator import doc_generator_visitor
43from tensorflow_docs.api_generator import generate_lib
44
45from tensorflow.python.framework import ops
46from tensorflow.python.util import tf_export
47from tensorflow.python.util import tf_inspect
48
49# Caution: the google and oss versions of this import are different.
50import base_dir
51
# pylint: disable=g-import-not-at-top
# `doc_typealias` is optional: when the import fails there are simply no
# extra docs to hand to the generator.
try:
  from tensorflow.python.types import doc_typealias
except ImportError:
  _EXTRA_DOCS = {}
else:
  # The module may be importable without defining `_EXTRA_DOCS`.
  _EXTRA_DOCS = getattr(doc_typealias, "_EXTRA_DOCS", {})
  # Only the mapping is needed; drop the module from this namespace.
  del doc_typealias
# pylint: enable=g-import-not-at-top
60
# The doc generator recognizes `__all__` as the list of public symbols, but
# the `__all__` that `tf` ships with doesn't list important things like
# `keras`. Replace it with the name of every member of `tf`.
tf.__all__ = [member_name for member_name, _ in tf_inspect.getmembers(tf)]

# tf_export generated two copies of the module objects. Pointing
# `tf.compat.v2` at `tf` just lists compat.v2 as an alias for tf. Close
# enough, and it avoids duplicating all the module skeleton files.
tf.compat.v2 = tf

# Shared handle to the command line flags defined in this file.
FLAGS = flags.FLAGS
72
# Command line flags. `main` reads these from `FLAGS` and forwards them to
# `build_docs`; `site_path` is read directly from `FLAGS` where it is needed.
flags.DEFINE_string(
    "code_url_prefix",
    "/code/stable/tensorflow",
    "A url to prepend to code paths when creating links to defining code")

flags.DEFINE_string("output_dir", "/tmp/out",
                    "A directory, where the docs will be output to.")

flags.DEFINE_bool("search_hints", True,
                  "Include meta-data search hints at the top of each file.")

flags.DEFINE_string(
    "site_path", "",
    "The path prefix (up to `.../api_docs/python`) used in the "
    "`_toc.yaml` and `_redirects.yaml` files")

# Adjacent string literals are concatenated verbatim, so the first fragment
# must end with a space to keep "the" and "docstrings" apart in the help text.
flags.DEFINE_bool("gen_report", False,
                  ("Generate an API report containing the health of the "
                   "docstrings of the public API."))
92
# Names to block, keyed by the dotted path of the module that contains them.
# Passed to the doc generator as its `private_map`.
_PRIVATE_MAP = {
    "tf": ["python", "core", "compiler", "examples", "tools", "contrib"],
}
# There's some aliasing between the compats and v1/2s, so it's easier to
# block by name and location than by deleting, or hiding objects.
for _compat_parent in ("tf.compat.v1.compat", "tf.compat.v2.compat"):
  # Each entry gets its own list object.
  _PRIVATE_MAP[_compat_parent] = ["v1", "v2"]
del _compat_parent
100
# Replace the module docstring of `tf` with a short markdown snippet. The
# text is assembled line by line; the leading empty entry and the trailing
# two-space entry reproduce the newline right after the opening quotes and
# the indentation in front of the closing quotes of a triple-quoted literal.
tf.__doc__ = "\n".join([
    "",
    "  ## TensorFlow",
    "",
    "  ```",
    "  pip install tensorflow",
    "  ```",
    "  ",
])
108
109
def generate_raw_ops_doc():
  """Builds the markdown content of the `tf.raw_ops` overview page.

  The page consists of a short note, a table header, and one table row for
  every name in `dir(tf.raw_ops)` that does not start with an underscore.
  Each row links to the op's page under `FLAGS.site_path` and shows whether
  a lookup of the op name in the gradient registry succeeded.

  Returns:
    A string with all parts of the page joined by newlines.
  """
  note = textwrap.dedent("""\n
    Note: `tf.raw_ops` provides direct/low level access to all TensorFlow ops.
    See [the RFC](https://github.com/tensorflow/community/blob/master/rfcs/20181225-tf-raw-ops.md)
    for details. Unless you are library writer, you likely do not need to use
    these ops directly.""")

  table_header = textwrap.dedent("""

      | Op Name | Has Gradient |
      |---------|:------------:|""")

  link_template = '<a id={op_name} href="{path}">{op_name}</a>'
  row_template = "| {link} | {has_gradient} |"

  sections = [note, table_header]

  for op_name in sorted(dir(tf.raw_ops)):
    # A failed registry lookup is what marks an op as having no gradient.
    try:
      ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
    except LookupError:
      has_gradient = "\N{CROSS MARK}"
    else:
      has_gradient = "\N{HEAVY CHECK MARK}\N{VARIATION SELECTOR-16}"

    # Underscore-prefixed names get no row in the table.
    if op_name.startswith("_"):
      continue

    page_path = pathlib.Path("/") / FLAGS.site_path / "tf/raw_ops" / op_name
    page_path = page_path.with_suffix(".md")
    link = link_template.format(op_name=op_name, path=str(page_path))
    sections.append(
        row_template.format(link=link, has_gradient=has_gradient))

  return "\n".join(sections)
142
143
144# The doc generator isn't aware of tf_export.
145# So prefix the score tuples with -1 when this is the canonical name, +1
146# otherwise. The generator chooses the name with the lowest score.
class TfExportAwareVisitor(doc_generator_visitor.DocGeneratorVisitor):
  """A `tf_export`, `keras_export` and `estimator_export` aware doc_visitor."""

  def _score_name(self, name):
    """Scores `name` so that the exported canonical name ranks first.

    Args:
      name: The name to score; it is used as a key into `self._index`.

    Returns:
      A tuple made of -1 when `name` equals "tf." plus the canonical name of
      the indexed symbol (1 otherwise), followed by the scores computed by the
      parent class.
    """
    api_names = (tf_export.TENSORFLOW_API_NAME,
                 tf_export.KERAS_API_NAME,
                 tf_export.ESTIMATOR_API_NAME)

    # Lazily query each API in order; the first non-None answer wins and the
    # remaining APIs are not consulted.
    lookups = (
        tf_export.get_canonical_name_for_symbol(
            self._index[name], api_name=api_name)
        for api_name in api_names)
    canonical = next(
        (found for found in lookups if found is not None), None)

    is_canonical = canonical is not None and name == "tf." + canonical
    canonical_score = -1 if is_canonical else 1

    return (canonical_score,) + super()._score_name(name)
167
168
def _skip_docs_for_optional_modules():
  """Marks a fixed set of `tf` submodules as not to be documented.

  Any `AttributeError` raised while looking up or marking one of the modules
  is deliberately ignored (presumably because not every TensorFlow build
  exposes all of them), so the remaining modules are still processed.
  """
  # Lambdas defer the attribute lookups so that each one happens inside the
  # `try` below.
  module_getters = [
      lambda: tf.__internal__,
      lambda: tf.keras.__internal__,
      lambda: tf.__operators__,
      lambda: tf.tools,
      lambda: tf.compat.v1.pywrap_tensorflow,
      lambda: tf.pywrap_tensorflow,
      lambda: tf.flags,
  ]
  for get_module in module_getters:
    try:
      doc_controls.do_not_generate_docs(get_module())
    except AttributeError:
      pass


def _find_link_mismatches(out_path, path_contents, should_contain):
  """Lists generated files whose text violates a "view source" expectation.

  Args:
    out_path: A `pathlib.Path`, the root of the generated docs.
    path_contents: A dict mapping a file path relative to `out_path` to a
      snippet of text to look for in that file.
    should_contain: Bool. If True a file is a mismatch when the snippet is
      missing from it; if False when the snippet is present.

  Returns:
    The list of mismatching paths, in the iteration order of `path_contents`.
  """
  mismatches = []
  for rel_path, snippet in path_contents.items():
    path = out_path / rel_path
    if (snippet in path.read_text()) != should_contain:
      mismatches.append(path)
  return mismatches


def _raise_for_bad_links(header, bad_paths):
  """Raises a `ValueError` listing `bad_paths` under `header`, if any."""
  if bad_paths:
    lines = [header] + ["  " + str(path) for path in bad_paths]
    raise ValueError("\n".join(lines))


def build_docs(output_dir, code_url_prefix, search_hints, gen_report):
  """Build api docs for tensorflow v2.

  After generating the docs (and unless `gen_report` is set) the output is
  sanity checked: a few pages must contain the expected "view source" path,
  one page must not contain a known bad path, and the output directory must
  hold at least 2000 entries.

  Args:
    output_dir: A string path, where to put the files.
    code_url_prefix: prefix for "Defined in" links.
    search_hints: Bool. Include meta-data search hints at the top of each file.
    gen_report: Bool. Generates an API report containing the health of the
      docstrings of the public API.

  Raises:
    ValueError: If one of the sanity checks on the generated files fails.
  """
  # The custom page will be used for raw_ops.md not the one generated above.
  doc_controls.set_custom_page_content(tf.raw_ops, generate_raw_ops_doc())

  # Hide raw_ops from search.
  for name, obj in tf_inspect.getmembers(tf.raw_ops):
    if not name.startswith("_"):
      doc_controls.hide_from_search(obj)

  # Apply `do_not_doc_in_subclasses` to every attribute of these base classes
  # except `__init__`.
  for cls in [tf.Module, tf.keras.layers.Layer, tf.keras.optimizers.Optimizer]:
    doc_controls.decorate_all_class_attributes(
        decorator=doc_controls.do_not_doc_in_subclasses,
        cls=cls,
        skip=["__init__"])

  _skip_docs_for_optional_modules()

  base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes(
      code_url_prefix)
  doc_generator = generate_lib.DocGenerator(
      root_title="TensorFlow 2",
      py_modules=[("tf", tf)],
      base_dir=base_dirs,
      search_hints=search_hints,
      code_url_prefix=code_url_prefixes,
      site_path=FLAGS.site_path,
      visitor_cls=TfExportAwareVisitor,
      private_map=_PRIVATE_MAP,
      gen_report=gen_report,
      extra_docs=_EXTRA_DOCS
  )

  doc_generator.build(output_dir)

  # The checks below are skipped when an API report was requested.
  if gen_report:
    return

  out_path = pathlib.Path(output_dir)

  # Each of these pages must mention the source file its symbol is defined in.
  expected_path_contents = {
      "tf/summary/audio.md":
          "tensorboard/plugins/audio/summary_v2.py",
      "tf/estimator/DNNClassifier.md":
          "tensorflow_estimator/python/estimator/canned/dnn.py",
      "tf/nn/sigmoid_cross_entropy_with_logits.md":
          "python/ops/nn_impl.py",
      "tf/keras/Model.md":
          "keras/engine/training.py",
      "tf/keras/preprocessing/image/random_brightness.md":
          "keras_preprocessing/image/affine_transformations.py"
  }
  _raise_for_bad_links(
      'Some "view source" links seem to be broken, please check:',
      _find_link_mismatches(
          out_path, expected_path_contents, should_contain=True))

  # These pages must NOT mention the given source file.
  rejected_path_contents = {
      "tf/keras/optimizers.md": "keras/optimizers/__init__.py",
  }
  _raise_for_bad_links(
      'Bad "view source" links in generated files, please check:',
      _find_link_mismatches(
          out_path, rejected_path_contents, should_contain=False))

  # Guard against a run that silently produced only a fraction of the API.
  # `rglob("*")` yields every entry below `out_path`.
  num_files = sum(1 for _ in out_path.rglob("*"))
  if num_files < 2000:
    # The first fragment ends with a space so that the adjacent literals do
    # not run together as "files(found".
    raise ValueError("The TensorFlow api should be more than 2000 files "
                     "(found {}).".format(num_files))
298
299
def main(argv):
  """Builds the docs using the values of the command line flags.

  Args:
    argv: Positional command line arguments; not used.
  """
  del argv  # Unused.

  output_dir = FLAGS.output_dir
  code_url_prefix = FLAGS.code_url_prefix
  search_hints = FLAGS.search_hints
  gen_report = FLAGS.gen_report

  build_docs(
      output_dir=output_dir,
      code_url_prefix=code_url_prefix,
      search_hints=search_hints,
      gen_report=gen_report)


# Run `main` through absl's `app.run` when executed as a script.
if __name__ == "__main__":
  app.run(main)
311