1# lint as: python3 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""A tool to generate api_docs for TensorFlow2. 17 18``` 19python generate2.py --output_dir=/tmp/out 20``` 21 22Requires a local installation of `tensorflow_docs`: 23 24``` 25pip install git+https://github.com/tensorflow/docs 26``` 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import pathlib 34import textwrap 35 36from absl import app 37from absl import flags 38 39import tensorflow as tf 40 41from tensorflow_docs.api_generator import doc_controls 42from tensorflow_docs.api_generator import doc_generator_visitor 43from tensorflow_docs.api_generator import generate_lib 44 45from tensorflow.python.framework import ops 46from tensorflow.python.util import tf_export 47from tensorflow.python.util import tf_inspect 48 49# Caution: the google and oss versions of this import are different. 50import base_dir 51 52# pylint: disable=g-import-not-at-top 53try: 54 from tensorflow.python.types import doc_typealias 55 _EXTRA_DOCS = getattr(doc_typealias, "_EXTRA_DOCS", {}) 56 del doc_typealias 57except ImportError: 58 _EXTRA_DOCS = {} 59# pylint: enable=g-import-not-at-top 60 61# `tf` has an `__all__` that doesn't list important things like `keras`. 62# The doc generator recognizes `__all__` as the list of public symbols. 63# So patch `tf.__all__` to list everything. 64tf.__all__ = [item_name for item_name, value in tf_inspect.getmembers(tf)] 65 66# tf_export generated two copies of the module objects. 67# This will just list compat.v2 as an alias for tf. Close enough, let's not 68# duplicate all the module skeleton files. 69tf.compat.v2 = tf 70 71FLAGS = flags.FLAGS 72 73flags.DEFINE_string( 74 "code_url_prefix", 75 "/code/stable/tensorflow", 76 "A url to prepend to code paths when creating links to defining code") 77 78flags.DEFINE_string("output_dir", "/tmp/out", 79 "A directory, where the docs will be output to.") 80 81flags.DEFINE_bool("search_hints", True, 82 "Include meta-data search hints at the top of each file.") 83 84flags.DEFINE_string( 85 "site_path", "", 86 "The path prefix (up to `.../api_docs/python`) used in the " 87 "`_toc.yaml` and `_redirects.yaml` files") 88 89flags.DEFINE_bool("gen_report", False, 90 ("Generate an API report containing the health of the" 91 "docstrings of the public API.")) 92 93_PRIVATE_MAP = { 94 "tf": ["python", "core", "compiler", "examples", "tools", "contrib"], 95 # There's some aliasing between the compats and v1/2s, so it's easier to 96 # block by name and location than by deleting, or hiding objects. 97 "tf.compat.v1.compat": ["v1", "v2"], 98 "tf.compat.v2.compat": ["v1", "v2"] 99} 100 101tf.__doc__ = """ 102 ## TensorFlow 103 104 ``` 105 pip install tensorflow 106 ``` 107 """ 108 109 110def generate_raw_ops_doc(): 111 """Generates docs for `tf.raw_ops`.""" 112 113 warning = textwrap.dedent("""\n 114 Note: `tf.raw_ops` provides direct/low level access to all TensorFlow ops. 115 See [the RFC](https://github.com/tensorflow/community/blob/master/rfcs/20181225-tf-raw-ops.md) 116 for details. Unless you are library writer, you likely do not need to use 117 these ops directly.""") 118 119 table_header = textwrap.dedent(""" 120 121 | Op Name | Has Gradient | 122 |---------|:------------:|""") 123 124 parts = [warning, table_header] 125 126 for op_name in sorted(dir(tf.raw_ops)): 127 try: 128 ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access 129 has_gradient = "\N{HEAVY CHECK MARK}\N{VARIATION SELECTOR-16}" 130 except LookupError: 131 has_gradient = "\N{CROSS MARK}" 132 133 if not op_name.startswith("_"): 134 path = pathlib.Path("/") / FLAGS.site_path / "tf/raw_ops" / op_name 135 path = path.with_suffix(".md") 136 link = ('<a id={op_name} href="{path}">{op_name}</a>').format( 137 op_name=op_name, path=str(path)) 138 parts.append("| {link} | {has_gradient} |".format( 139 link=link, has_gradient=has_gradient)) 140 141 return "\n".join(parts) 142 143 144# The doc generator isn't aware of tf_export. 145# So prefix the score tuples with -1 when this is the canonical name, +1 146# otherwise. The generator chooses the name with the lowest score. 147class TfExportAwareVisitor(doc_generator_visitor.DocGeneratorVisitor): 148 """A `tf_export`, `keras_export` and `estimator_export` aware doc_visitor.""" 149 150 def _score_name(self, name): 151 all_exports = [tf_export.TENSORFLOW_API_NAME, 152 tf_export.KERAS_API_NAME, 153 tf_export.ESTIMATOR_API_NAME] 154 155 for api_name in all_exports: 156 canonical = tf_export.get_canonical_name_for_symbol( 157 self._index[name], api_name=api_name) 158 if canonical is not None: 159 break 160 161 canonical_score = 1 162 if canonical is not None and name == "tf." + canonical: 163 canonical_score = -1 164 165 scores = super()._score_name(name) 166 return (canonical_score,) + scores 167 168 169def build_docs(output_dir, code_url_prefix, search_hints, gen_report): 170 """Build api docs for tensorflow v2. 171 172 Args: 173 output_dir: A string path, where to put the files. 174 code_url_prefix: prefix for "Defined in" links. 175 search_hints: Bool. Include meta-data search hints at the top of each file. 176 gen_report: Bool. Generates an API report containing the health of the 177 docstrings of the public API. 178 """ 179 # The custom page will be used for raw_ops.md not the one generated above. 180 doc_controls.set_custom_page_content(tf.raw_ops, generate_raw_ops_doc()) 181 182 # Hide raw_ops from search. 183 for name, obj in tf_inspect.getmembers(tf.raw_ops): 184 if not name.startswith("_"): 185 doc_controls.hide_from_search(obj) 186 187 for cls in [tf.Module, tf.keras.layers.Layer, tf.keras.optimizers.Optimizer]: 188 doc_controls.decorate_all_class_attributes( 189 decorator=doc_controls.do_not_doc_in_subclasses, 190 cls=cls, 191 skip=["__init__"]) 192 193 try: 194 doc_controls.do_not_generate_docs(tf.__internal__) 195 except AttributeError: 196 pass 197 198 try: 199 doc_controls.do_not_generate_docs(tf.keras.__internal__) 200 except AttributeError: 201 pass 202 203 try: 204 doc_controls.do_not_generate_docs(tf.__operators__) 205 except AttributeError: 206 pass 207 208 try: 209 doc_controls.do_not_generate_docs(tf.tools) 210 except AttributeError: 211 pass 212 213 try: 214 doc_controls.do_not_generate_docs(tf.compat.v1.pywrap_tensorflow) 215 except AttributeError: 216 pass 217 218 try: 219 doc_controls.do_not_generate_docs(tf.pywrap_tensorflow) 220 except AttributeError: 221 pass 222 223 try: 224 doc_controls.do_not_generate_docs(tf.flags) 225 except AttributeError: 226 pass 227 228 base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes( 229 code_url_prefix) 230 doc_generator = generate_lib.DocGenerator( 231 root_title="TensorFlow 2", 232 py_modules=[("tf", tf)], 233 base_dir=base_dirs, 234 search_hints=search_hints, 235 code_url_prefix=code_url_prefixes, 236 site_path=FLAGS.site_path, 237 visitor_cls=TfExportAwareVisitor, 238 private_map=_PRIVATE_MAP, 239 gen_report=gen_report, 240 extra_docs=_EXTRA_DOCS 241 ) 242 243 doc_generator.build(output_dir) 244 245 if gen_report: 246 return 247 248 out_path = pathlib.Path(output_dir) 249 250 expected_path_contents = { 251 "tf/summary/audio.md": 252 "tensorboard/plugins/audio/summary_v2.py", 253 "tf/estimator/DNNClassifier.md": 254 "tensorflow_estimator/python/estimator/canned/dnn.py", 255 "tf/nn/sigmoid_cross_entropy_with_logits.md": 256 "python/ops/nn_impl.py", 257 "tf/keras/Model.md": 258 "keras/engine/training.py", 259 "tf/keras/preprocessing/image/random_brightness.md": 260 "keras_preprocessing/image/affine_transformations.py" 261 } 262 263 all_passed = True 264 error_msg_parts = [ 265 'Some "view source" links seem to be broken, please check:' 266 ] 267 268 for (rel_path, contents) in expected_path_contents.items(): 269 path = out_path / rel_path 270 if contents not in path.read_text(): 271 all_passed = False 272 error_msg_parts.append(" " + str(path)) 273 274 if not all_passed: 275 raise ValueError("\n".join(error_msg_parts)) 276 277 rejected_path_contents = { 278 "tf/keras/optimizers.md": "keras/optimizers/__init__.py", 279 } 280 281 all_passed = True 282 error_msg_parts = [ 283 'Bad "view source" links in generated files, please check:' 284 ] 285 for rel_path, content in rejected_path_contents.items(): 286 path = out_path / rel_path 287 if content in path.read_text(): 288 all_passed = False 289 error_msg_parts.append(" " + str(path)) 290 291 if not all_passed: 292 raise ValueError("\n".join(error_msg_parts)) 293 294 num_files = len(list(out_path.rglob("*"))) 295 if num_files < 2000: 296 raise ValueError("The TensorFlow api should be more than 2000 files" 297 "(found {}).".format(num_files)) 298 299 300def main(argv): 301 del argv 302 build_docs( 303 output_dir=FLAGS.output_dir, 304 code_url_prefix=FLAGS.code_url_prefix, 305 search_hints=FLAGS.search_hints, 306 gen_report=FLAGS.gen_report,) 307 308 309if __name__ == "__main__": 310 app.run(main) 311