1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A tool to generate api_docs for TensorFlow2. 16 17``` 18python generate2.py --output_dir=/tmp/out 19``` 20 21Requires a local installation of `tensorflow_docs`: 22 23``` 24pip install git+https://github.com/tensorflow/docs 25``` 26""" 27import contextlib 28import distutils 29import pathlib 30import textwrap 31 32from typing import NamedTuple 33 34from absl import app 35from absl import flags 36 37import tensorflow as tf 38 39from tensorflow_docs.api_generator import doc_controls 40from tensorflow_docs.api_generator import doc_generator_visitor 41from tensorflow_docs.api_generator import generate_lib 42from tensorflow_docs.api_generator.pretty_docs import base_page 43from tensorflow_docs.api_generator.pretty_docs import module_page 44 45import yaml 46 47from tensorflow.python.framework import ops 48from tensorflow.python.util import tf_export 49from tensorflow.python.util import tf_inspect 50 51# Caution: the google and oss versions of this import are different. 
import base_dir

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.python.types import doc_typealias
except ImportError:
  # Without `doc_typealias` there are no extra docs to hand to the generator.
  _EXTRA_DOCS = {}
else:
  _EXTRA_DOCS = getattr(doc_typealias, "_EXTRA_DOCS", {})
  del doc_typealias
# pylint: enable=g-import-not-at-top

# `tf` has an `__all__` that doesn't list important things like `keras`.
# The doc generator recognizes `__all__` as the list of public symbols.
# So patch `tf.__all__` to list everything.
tf.__all__ = [member_name for member_name, _ in tf_inspect.getmembers(tf)]

# tf_export generated two copies of the module objects.
# This will just list compat.v2 as an alias for tf. Close enough, let's not
# duplicate all the module skeleton files.
tf.compat.v2 = tf

# Expose the keras implementations under their top-level `tf.*` names.
tf.losses = tf.keras.losses
tf.metrics = tf.keras.metrics
tf.optimizers = tf.keras.optimizers
tf.initializers = tf.keras.initializers

# `build_docs` fails if fewer files than this end up in the output directory.
MIN_NUM_FILES_EXPECTED = 2000
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="code_url_prefix",
    default="/code/stable/tensorflow",
    help="A url to prepend to code paths when creating links to defining code")

flags.DEFINE_string(
    name="output_dir",
    default="/tmp/out",
    help="A directory, where the docs will be output to.")

flags.DEFINE_bool(
    name="search_hints",
    default=True,
    help="Include meta-data search hints at the top of each file.")

flags.DEFINE_string(
    name="site_path",
    default="",
    help=("The path prefix (up to `.../api_docs/python`) used in the "
          "`_toc.yaml` and `_redirects.yaml` files"))

# Maps a module's full name to the child names that are passed to the doc
# generator as `private_map`.
_PRIVATE_MAP = {
    "tf": ["python", "core", "compiler", "examples", "tools", "contrib"],
    # There's some aliasing between the compats and v1/2s, so it's easier to
    # block by name and location than by deleting, or hiding objects.
    "tf.compat.v1.compat": ["v1", "v2"],
    "tf.compat.v2.compat": ["v1", "v2"],
}

tf.__doc__ = """
  ## TensorFlow

  ```
  pip install tensorflow
  ```
  """


class RawOpsPageInfo(module_page.ModulePageInfo):
  """Generates a custom page for `tf.raw_ops`.

  The page is the templated `PageInfo` output followed by a table that lists
  every public raw op and whether a gradient is registered for it.
  """
  DEFAULT_BUILDER_CLASS = base_page.TemplatePageBuilder

  def build(self):
    """Returns the page text: the templated content plus the raw-ops table."""
    # Skip the ModulePage implementation, which doesn't use a template.
    page_text = base_page.PageInfo.build(self)
    return "\n".join((page_text, self.generate_raw_ops_doc()))

  def generate_raw_ops_doc(self):
    """Generates docs for `tf.raw_ops`.

    Returns:
      A markdown string: a short note about `tf.raw_ops` followed by a table
      with one row per public name in `dir(tf.raw_ops)`. Each row links to the
      op's page under `FLAGS.site_path` and shows whether a gradient is found
      in the gradient registry.
    """
    notice = textwrap.dedent("""\n
      Note: `tf.raw_ops` provides direct/low level access to all TensorFlow ops.
      See [the RFC](https://github.com/tensorflow/community/blob/master/rfcs/20181225-tf-raw-ops.md)
      for details. Unless you are library writer, you likely do not need to use
      these ops directly.""")

    header = textwrap.dedent("""

      | Op Name | Has Gradient |
      |---------|:------------:|""")

    lines = [notice, header]

    for op_name in sorted(dir(tf.raw_ops)):
      # Private and dunder names never get a row.
      if op_name.startswith("_"):
        continue

      try:
        ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
      except LookupError:
        gradient_mark = "\N{CROSS MARK}"
      else:
        gradient_mark = "\N{HEAVY CHECK MARK}\N{VARIATION SELECTOR-16}"

      page_path = pathlib.Path(
          "/", FLAGS.site_path, "tf/raw_ops", op_name).with_suffix(".md")
      link = f'<a id={op_name} href="{page_path}">{op_name}</a>'
      lines.append(f"| {link} | {gradient_mark} |")

    return "\n".join(lines)


# The doc generator isn't aware of tf_export.
# So prefix the score tuples with -1 when this is the canonical name, +1
# otherwise.
# The generator chooses the name with the lowest score.
class TfExportAwareVisitor(doc_generator_visitor.DocGeneratorVisitor):
  """A `tf_export`, `keras_export` and `estimator_export` aware doc_visitor.

  The base visitor's name score is wrapped in a tuple whose first element is
  -1 when the scored path is the canonical exported name and +1 otherwise, so
  canonical names always sort first.
  """

  class TfNameScore(NamedTuple):
    # -1 if the path is the canonical export name, +1 otherwise. The spelling
    # of this field is kept unchanged for backward compatibility.
    cannonical_score: int
    # Score computed by the base visitor; it orders names that share the same
    # `cannonical_score` because tuples compare element by element.
    name_score: doc_generator_visitor.DocGeneratorVisitor.NameScore

  def _score_name(self, path: doc_generator_visitor.ApiPath) -> TfNameScore:
    """Scores `path`, preferring the canonical `tf_export` name.

    Args:
      path: The api path (a sequence of name parts) to score.

    Returns:
      A `TfNameScore` combining the canonical-name score with the score from
      the base class.
    """
    name = ".".join(path)
    symbol = self._index[name]

    # Use the canonical name from the first export API that knows the symbol.
    canonical = None
    for api_name in (tf_export.TENSORFLOW_API_NAME,
                     tf_export.KERAS_API_NAME,
                     tf_export.ESTIMATOR_API_NAME):
      canonical = tf_export.get_canonical_name_for_symbol(
          symbol, api_name=api_name)
      if canonical is not None:
        break

    is_canonical = canonical is not None and name == "tf." + canonical
    canonical_score = -1 if is_canonical else 1

    return self.TfNameScore(canonical_score, super()._score_name(path))


def _tf_version_at_least(version, major, minor):
  """Returns whether `version` is at least `major.minor`.

  Only the first two dot-separated components are compared, so suffixes on
  later components (for example `2.9.0-rc1`) do not affect the result.

  Args:
    version: A version string such as `tf.__version__`.
    major: The minimum major version, an int.
    minor: The minimum minor version, an int.

  Returns:
    True if the leading `(major, minor)` of `version` is >= `(major, minor)`.

  Raises:
    ValueError: If one of the first two components is not an integer.
  """
  release = tuple(int(part) for part in version.split(".")[:2])
  return release >= (major, minor)


def _resolve_symbol(root, dotted_path):
  """Returns the object at `dotted_path` under `root`, or None if missing.

  Args:
    root: The object that the first component of `dotted_path` refers to.
    dotted_path: A dotted name such as "tf.keras.__internal__". The first
      component names `root` itself and is skipped.

  Returns:
    The resolved object, or None as soon as any attribute along the path is
    missing (or is itself None).
  """
  item = root
  for part in dotted_path.split(".")[1:]:
    item = getattr(item, part, None)
    if item is None:
      return None
  return item


@contextlib.contextmanager
def _edit_yaml_file(path):
  """Yields the parsed yaml content of `path` and writes it back afterwards.

  The file is rewritten only when the `with` body completes without raising.

  Args:
    path: A `pathlib.Path` to an existing yaml file.

  Yields:
    The object loaded from the file; mutate it in place to change the file.
  """
  content = yaml.safe_load(path.read_text())
  yield content

  with path.open("w") as f:
    yaml.dump(content, f, default_flow_style=False)


def _check_view_source_links(output_dir, snippet_by_page, should_contain,
                             error_header):
  """Checks generated pages for the presence or absence of a text snippet.

  Args:
    output_dir: A `pathlib.Path`, the root of the generated docs.
    snippet_by_page: Dict mapping a page path (relative to `output_dir`) to
      the snippet to look for in that page.
    should_contain: Bool. If True every page must contain its snippet; if
      False no page may contain it.
    error_header: First line of the error message.

  Raises:
    ValueError: If any page fails the check. The message lists those pages.
  """
  failing_pages = []
  for rel_path, snippet in snippet_by_page.items():
    page = output_dir / rel_path
    if (snippet in page.read_text()) != should_contain:
      failing_pages.append(" " + str(page))

  if failing_pages:
    raise ValueError("\n".join([error_header] + failing_pages))


def build_docs(output_dir, code_url_prefix, search_hints):
  """Build api docs for tensorflow v2.

  Besides the arguments, this reads `FLAGS.site_path` to build the site paths
  used in the generated pages, `_toc.yaml` and `_redirects.yaml`.

  Args:
    output_dir: A string path, where to put the files.
    code_url_prefix: prefix for "Defined in" links.
    search_hints: Bool. Include meta-data search hints at the top of each file.

  Raises:
    ValueError: If a generated page is missing an expected "view source" link,
      contains a rejected one, or if fewer than `MIN_NUM_FILES_EXPECTED` files
      were generated.
  """
  output_dir = pathlib.Path(output_dir)
  site_path = pathlib.Path("/", FLAGS.site_path)

  # Compare the version locally rather than through `distutils.version`:
  # a bare `import distutils` does not load that submodule, and `distutils`
  # is no longer part of the standard library in newer Python versions.
  if _tf_version_at_least(tf.__version__, 2, 9):
    doc_controls.set_deprecated(tf.compat.v1)
    doc_controls.set_deprecated(tf.estimator)
    doc_controls.set_deprecated(tf.feature_column)
    doc_controls.set_deprecated(tf.keras.preprocessing)

  # The custom page will be used for raw_ops.md not the one generated above.
  doc_controls.set_custom_page_builder_cls(tf.raw_ops, RawOpsPageInfo)

  # Hide raw_ops from search.
  for name, obj in tf_inspect.getmembers(tf.raw_ops):
    if not name.startswith("_"):
      doc_controls.hide_from_search(obj)

  for cls in [tf.Module, tf.keras.layers.Layer, tf.keras.optimizers.Optimizer]:
    doc_controls.decorate_all_class_attributes(
        decorator=doc_controls.do_not_doc_in_subclasses,
        cls=cls,
        skip=["__init__"])

  do_not_document = ["tf.__internal__",
                     "tf.keras.__internal__",
                     "tf.__operators__",
                     "tf.tools",
                     "tf.compat.v1.pywrap_tensorflow",
                     "tf.pywrap_tensorflow",
                     "tf.flags",
                     "tf.batch_mat_mul_v3",
                     "tf.sparse_segment_sum_grad"]
  for path in do_not_document:
    item = _resolve_symbol(tf, path)
    if item is None:
      # The symbol does not exist on this `tf` module; there is nothing to
      # hide, and `None` must not be passed on to `doc_controls`.
      continue
    # Only the fully resolved symbol is hidden, never its parent modules.
    doc_controls.do_not_generate_docs(item)

  base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes(
      code_url_prefix)
  doc_generator = generate_lib.DocGenerator(
      root_title="TensorFlow 2",
      py_modules=[("tf", tf)],
      base_dir=base_dirs,
      search_hints=search_hints,
      code_url_prefix=code_url_prefixes,
      site_path=site_path,
      visitor_cls=TfExportAwareVisitor,
      private_map=_PRIVATE_MAP,
      extra_docs=_EXTRA_DOCS,
      callbacks=base_dir.get_callbacks())

  doc_generator.build(output_dir)

  toc_path = output_dir / "tf/_toc.yaml"
  with _edit_yaml_file(toc_path) as toc:
    # Replace the overview path for 'TensorFlow' to
    # `/api_docs/python/tf_overview`. This will be redirected to
    # `/api_docs/python/tf`.
    toc["toc"][0]["section"][0]["path"] = str(site_path / "tf_overview")

  redirects_path = output_dir / "tf/_redirects.yaml"
  with _edit_yaml_file(redirects_path) as redirects:
    redirects["redirects"].append({
        "from": str(site_path / "tf_overview"),
        "to": str(site_path / "tf"),
    })

  # Spot-check that pages from each source package link to the right file.
  expected_path_contents = {
      "tf/summary/audio.md":
          "tensorboard/plugins/audio/summary_v2.py",
      "tf/estimator/DNNClassifier.md":
          "tensorflow_estimator/python/estimator/canned/dnn.py",
      "tf/nn/sigmoid_cross_entropy_with_logits.md":
          "python/ops/nn_impl.py",
      "tf/keras/Model.md":
          "keras/engine/training.py",
  }
  _check_view_source_links(
      output_dir,
      expected_path_contents,
      should_contain=True,
      error_header='Some "view source" links seem to be broken, please check:')

  rejected_path_contents = {
      "tf/keras/optimizers.md": "keras/optimizers/__init__.py",
  }
  _check_view_source_links(
      output_dir,
      rejected_path_contents,
      should_contain=False,
      error_header='Bad "view source" links in generated files, please check:')

  # Count every entry under the output directory without building a list.
  num_files = sum(1 for _ in output_dir.rglob("*"))
  if num_files < MIN_NUM_FILES_EXPECTED:
    raise ValueError(
        f"The TensorFlow api should be more than {MIN_NUM_FILES_EXPECTED} files "
        f"(found {num_files}).")


def main(argv):
  """Builds the docs using the values of the command line flags.

  Args:
    argv: Positional command line arguments passed by `app.run`; unused.
  """
  del argv
  build_docs(
      output_dir=FLAGS.output_dir,
      code_url_prefix=FLAGS.code_url_prefix,
      search_hints=FLAGS.search_hints)


if __name__ == "__main__":
  app.run(main)