• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A tool to generate api_docs for TensorFlow2.
16
17```
18python generate2.py --output_dir=/tmp/out
19```
20
21Requires a local installation of `tensorflow_docs`:
22
23```
24pip install git+https://github.com/tensorflow/docs
25```
26"""
27import contextlib
28import distutils
29import pathlib
30import textwrap
31
32from typing import NamedTuple
33
34from absl import app
35from absl import flags
36
37import tensorflow as tf
38
39from tensorflow_docs.api_generator import doc_controls
40from tensorflow_docs.api_generator import doc_generator_visitor
41from tensorflow_docs.api_generator import generate_lib
42from tensorflow_docs.api_generator.pretty_docs import base_page
43from tensorflow_docs.api_generator.pretty_docs import module_page
44
45import yaml
46
47from tensorflow.python.framework import ops
48from tensorflow.python.util import tf_export
49from tensorflow.python.util import tf_inspect
50
51# Caution: the google and oss versions of this import are different.
52import base_dir
53
# pylint: disable=g-import-not-at-top
# `doc_typealias` may be absent from some TensorFlow builds; in that case no
# extra docs are passed to the generator (see `extra_docs` in `build_docs`).
try:
  from tensorflow.python.types import doc_typealias
except ImportError:
  _EXTRA_DOCS = {}
else:
  _EXTRA_DOCS = getattr(doc_typealias, "_EXTRA_DOCS", {})
  # Keep the helper module itself out of this module's namespace.
  del doc_typealias
# pylint: enable=g-import-not-at-top
62
# `tf` has an `__all__` that doesn't list important things like `keras`, and
# the doc generator treats `__all__` as the list of public symbols. Overwrite
# it so that every member of `tf` is considered public.
tf.__all__ = [member_name for member_name, _ in tf_inspect.getmembers(tf)]

# tf_export generates two copies of the module objects. Aliasing compat.v2 to
# `tf` lists it as an alias instead of duplicating every module skeleton file.
tf.compat.v2 = tf

# Expose the Keras sub-modules under the short top-level names `tf.losses`,
# `tf.metrics`, `tf.optimizers` and `tf.initializers` (assigned in that order).
for _short_name in ("losses", "metrics", "optimizers", "initializers"):
  setattr(tf, _short_name, getattr(tf.keras, _short_name))
del _short_name  # Don't leave the loop variable in the module namespace.
77
# Sanity-check threshold used at the end of `build_docs`: the generated output
# directory must contain at least this many entries.
MIN_NUM_FILES_EXPECTED = 2_000
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="code_url_prefix",
    default="/code/stable/tensorflow",
    help="A url to prepend to code paths when creating links to defining code")
flags.DEFINE_string(
    name="output_dir",
    default="/tmp/out",
    help="A directory, where the docs will be output to.")
flags.DEFINE_bool(
    name="search_hints",
    default=True,
    help="Include meta-data search hints at the top of each file.")
flags.DEFINE_string(
    name="site_path",
    default="",
    help=("The path prefix (up to `.../api_docs/python`) used in the "
          "`_toc.yaml` and `_redirects.yaml` files"))
96
# Maps a module's full name to child names the doc generator must treat as
# private (passed as `private_map` to the `DocGenerator`).
_PRIVATE_MAP = {
    "tf": ["python", "core", "compiler", "examples", "tools", "contrib"],
    # The compat modules alias each other and v1/v2, so it's easier to block
    # these children by name and location than by deleting or hiding objects.
    **{
        compat_module: ["v1", "v2"]
        for compat_module in ("tf.compat.v1.compat", "tf.compat.v2.compat")
    },
}
104
# Replace the docstring of the `tf` package with a short Markdown overview.
# NOTE(review): presumably the doc generator renders this as the body of the
# top-level `tf` page — TODO confirm against tensorflow_docs.
tf.__doc__ = """
  ## TensorFlow

  ```
  pip install tensorflow
  ```
  """
112
113
class RawOpsPageInfo(module_page.ModulePageInfo):
  """Generates a custom page for `tf.raw_ops`.

  The page is the template-rendered page content followed by a usage note and
  a Markdown table listing each public op in `tf.raw_ops` together with
  whether a gradient is registered for it.
  """
  DEFAULT_BUILDER_CLASS = base_page.TemplatePageBuilder

  def build(self):
    """Returns the page text: templated content plus the raw-ops table."""
    # Skip the ModulePage implementation, which doesn't use a template.
    content = base_page.PageInfo.build(self)

    raw_ops_doc = self.generate_raw_ops_doc()

    return "\n".join([content, raw_ops_doc])

  def generate_raw_ops_doc(self):
    """Generates docs for `tf.raw_ops`.

    Returns:
      A Markdown string: a warning note followed by a table with one row per
      public name in `tf.raw_ops`. Each row links to
      `/<FLAGS.site_path>/tf/raw_ops/<op>.md` and shows a check mark if
      `ops._gradient_registry` has an entry for the op, a cross otherwise.
    """
    del self  # Unused.

    warning = textwrap.dedent("""\n
      Note: `tf.raw_ops` provides direct/low level access to all TensorFlow ops.
      See [the RFC](https://github.com/tensorflow/community/blob/master/rfcs/20181225-tf-raw-ops.md)
      for details. Unless you are a library writer, you likely do not need to use
      these ops directly.""")

    table_header = textwrap.dedent("""

        | Op Name | Has Gradient |
        |---------|:------------:|""")

    parts = [warning, table_header]
    # The link directory is the same for every op, so build it once.
    raw_ops_dir = pathlib.Path("/") / FLAGS.site_path / "tf/raw_ops"

    for op_name in sorted(dir(tf.raw_ops)):
      # Private names never get a row, so skip them before querying the
      # gradient registry.
      if op_name.startswith("_"):
        continue

      try:
        ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
      except LookupError:
        has_gradient = "\N{CROSS MARK}"
      else:
        has_gradient = "\N{HEAVY CHECK MARK}\N{VARIATION SELECTOR-16}"

      # Append ".md" instead of using `Path.with_suffix`, which would replace
      # any text after a "." in the name rather than extending it.
      path = raw_ops_dir / f"{op_name}.md"
      link = f'<a id={op_name} href="{path}">{op_name}</a>'
      parts.append(f"| {link} | {has_gradient} |")

    return "\n".join(parts)
159
160
# `DocGeneratorVisitor` knows nothing about `tf_export`, so each name's score is
# wrapped in a tuple whose first element is -1 for the canonical exported name
# and +1 for any other alias. The generator picks the name with the lowest
# score.
class TfExportAwareVisitor(doc_generator_visitor.DocGeneratorVisitor):
  """A `tf_export`, `keras_export` and `estimator_export` aware doc_visitor."""

  class TfNameScore(NamedTuple):
    # NOTE: the spelling of this field is kept as-is because it is part of the
    # tuple's attribute interface; only its position matters for ordering.
    cannonical_score: int
    name_score: doc_generator_visitor.DocGeneratorVisitor.NameScore

  def _score_name(self, path: doc_generator_visitor.ApiPath) -> TfNameScore:
    """Scores `path`, ranking the symbol's canonical exported name first.

    Args:
      path: The API path of a symbol, as a sequence of name parts.

    Returns:
      A `TfNameScore` of (-1 if `path` is the canonical `tf.`-prefixed name of
      the symbol, else 1; the base visitor's score for `path`).
    """
    full_name = ".".join(path)
    symbol = self._index[full_name]

    # Ask each export API in turn (tensorflow, keras, estimator); the first
    # one that knows a canonical name for the symbol wins.
    candidates = (
        tf_export.get_canonical_name_for_symbol(symbol, api_name=api_name)
        for api_name in (tf_export.TENSORFLOW_API_NAME,
                         tf_export.KERAS_API_NAME,
                         tf_export.ESTIMATOR_API_NAME))
    canonical = next(
        (candidate for candidate in candidates if candidate is not None), None)

    is_canonical = canonical is not None and full_name == "tf." + canonical
    return self.TfNameScore(-1 if is_canonical else 1,
                            super()._score_name(path))
188
189
def _tf_version_at_least(version, minimum):
  """Returns whether the numeric prefix of `version` is at least `minimum`.

  Replaces `distutils.version.LooseVersion`: `distutils` is deprecated
  (PEP 632) and removed in Python 3.12, and a bare `import distutils` does not
  load the `distutils.version` submodule.

  Args:
    version: A version string, e.g. "2.10.0", "2.9.0-rc1" or "2.11.0.dev2022".
    minimum: A tuple of ints to compare against, e.g. `(2, 9)`.

  Returns:
    True if the leading dotted-integer part of `version`, as a tuple of ints,
    compares >= `minimum`; False if `version` does not start with a digit.
  """
  import re  # pylint: disable=g-import-not-at-top

  match = re.match(r"\d+(?:\.\d+)*", version)
  if match is None:
    return False
  numbers = tuple(int(part) for part in match.group(0).split("."))
  return numbers >= tuple(minimum)


@contextlib.contextmanager
def _edit_yaml_file(path):
  """Loads a YAML file, yields its content for editing, then writes it back.

  The file is only rewritten if the `with` body completes without raising.

  Args:
    path: `pathlib.Path` of the YAML file.

  Yields:
    The object returned by `yaml.safe_load`; mutate it in place.
  """
  content = yaml.safe_load(path.read_text(encoding="utf-8"))
  yield content

  with path.open("w", encoding="utf-8") as f:
    yaml.dump(content, f, default_flow_style=False)


def _find_bad_files(output_dir, path_contents, *, expect_present):
  """Returns descriptions of generated files that fail a content check.

  Args:
    output_dir: `pathlib.Path` of the generated docs root.
    path_contents: Dict mapping a path relative to `output_dir` to a substring
      to look for in that file.
    expect_present: If True, a file fails when the substring is absent or the
      file does not exist. If False, a file fails when the substring is
      present; a missing file cannot contain it and is not reported.

  Returns:
    A list of strings, one per failing file: its path, suffixed with
    " (file not found)" when the file does not exist.
  """
  bad_files = []
  for rel_path, snippet in path_contents.items():
    path = output_dir / rel_path
    try:
      text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
      if expect_present:
        bad_files.append(f"{path} (file not found)")
      continue
    if (snippet in text) != expect_present:
      bad_files.append(str(path))
  return bad_files


def build_docs(output_dir, code_url_prefix, search_hints):
  """Build api docs for tensorflow v2.

  Configures doc controls on `tf`, generates the docs, patches the generated
  `tf/_toc.yaml` and `tf/_redirects.yaml`, then sanity-checks the output.

  Args:
    output_dir: A string path, where to put the files.
    code_url_prefix: prefix for "Defined in" links.
    search_hints: Bool. Include meta-data search hints at the top of each file.

  Raises:
    ValueError: If an expected "view source" link is missing (including when
      its file was not generated), a known-bad link is present, or fewer than
      `MIN_NUM_FILES_EXPECTED` entries were written under `output_dir`.
  """
  output_dir = pathlib.Path(output_dir)
  site_path = pathlib.Path("/", FLAGS.site_path)

  if _tf_version_at_least(tf.__version__, (2, 9)):
    doc_controls.set_deprecated(tf.compat.v1)
    doc_controls.set_deprecated(tf.estimator)
    doc_controls.set_deprecated(tf.feature_column)
    doc_controls.set_deprecated(tf.keras.preprocessing)

  # The custom page will be used for raw_ops.md not the one generated above.
  doc_controls.set_custom_page_builder_cls(tf.raw_ops, RawOpsPageInfo)

  # Hide raw_ops from search.
  for name, obj in tf_inspect.getmembers(tf.raw_ops):
    if not name.startswith("_"):
      doc_controls.hide_from_search(obj)

  for cls in [tf.Module, tf.keras.layers.Layer, tf.keras.optimizers.Optimizer]:
    doc_controls.decorate_all_class_attributes(
        decorator=doc_controls.do_not_doc_in_subclasses,
        cls=cls,
        skip=["__init__"])

  do_not_document = ["tf.__internal__",
                     "tf.keras.__internal__",
                     "tf.__operators__",
                     "tf.tools",
                     "tf.compat.v1.pywrap_tensorflow",
                     "tf.pywrap_tensorflow",
                     "tf.flags",
                     "tf.batch_mat_mul_v3",
                     "tf.sparse_segment_sum_grad"]
  for path in do_not_document:
    item = tf
    # Missing attributes resolve to None (getattr on None also yields None).
    for part in path.split(".")[1:]:
      item = getattr(item, part, None)
    if item is None:
      continue
    doc_controls.do_not_generate_docs(item)

  base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes(
      code_url_prefix)
  doc_generator = generate_lib.DocGenerator(
      root_title="TensorFlow 2",
      py_modules=[("tf", tf)],
      base_dir=base_dirs,
      search_hints=search_hints,
      code_url_prefix=code_url_prefixes,
      site_path=site_path,
      visitor_cls=TfExportAwareVisitor,
      private_map=_PRIVATE_MAP,
      extra_docs=_EXTRA_DOCS,
      callbacks=base_dir.get_callbacks())

  doc_generator.build(output_dir)

  toc_path = output_dir / "tf/_toc.yaml"
  with _edit_yaml_file(toc_path) as toc:
    # Replace the overview path for 'TensorFlow' to
    # `/api_docs/python/tf_overview`. This will be redirected to
    # `/api_docs/python/tf`.
    toc["toc"][0]["section"][0]["path"] = str(site_path / "tf_overview")

  redirects_path = output_dir / "tf/_redirects.yaml"
  with _edit_yaml_file(redirects_path) as redirects:
    redirects["redirects"].append({
        "from": str(site_path / "tf_overview"),
        "to": str(site_path / "tf"),
    })

  expected_path_contents = {
      "tf/summary/audio.md":
          "tensorboard/plugins/audio/summary_v2.py",
      "tf/estimator/DNNClassifier.md":
          "tensorflow_estimator/python/estimator/canned/dnn.py",
      "tf/nn/sigmoid_cross_entropy_with_logits.md":
          "python/ops/nn_impl.py",
      "tf/keras/Model.md":
          "keras/engine/training.py",
  }
  broken = _find_bad_files(
      output_dir, expected_path_contents, expect_present=True)
  if broken:
    raise ValueError("\n".join(
        ['Some "view source" links seem to be broken, please check:'] +
        ["  " + entry for entry in broken]))

  rejected_path_contents = {
      "tf/keras/optimizers.md": "keras/optimizers/__init__.py",
  }
  rejected = _find_bad_files(
      output_dir, rejected_path_contents, expect_present=False)
  if rejected:
    raise ValueError("\n".join(
        ['Bad "view source" links in generated files, please check:'] +
        ["  " + entry for entry in rejected]))

  # Counts every entry (files and directories) under output_dir.
  num_files = sum(1 for _ in output_dir.rglob("*"))
  if num_files < MIN_NUM_FILES_EXPECTED:
    raise ValueError(
        f"The TensorFlow api should have at least {MIN_NUM_FILES_EXPECTED} "
        f"files (found {num_files}).")
323
324
def main(argv):
  """Builds the API docs using the values of the command-line flags.

  Args:
    argv: Positional command-line arguments from `app.run`; ignored.
  """
  del argv  # Unused.
  options = {
      "output_dir": FLAGS.output_dir,
      "code_url_prefix": FLAGS.code_url_prefix,
      "search_hints": FLAGS.search_hints,
  }
  build_docs(**options)


if __name__ == "__main__":
  app.run(main)
335