• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.tools.docs.generate2."""
16
17import os
18import pathlib
19import shutil
20import types
21from unittest import mock
22
23import tensorflow as tf
24from tensorflow import estimator as tf_estimator
25
26import yaml
27
28from tensorflow.python.platform import googletest
29from tensorflow.tools.docs import generate2
30
31
class AutoModule(types.ModuleType):
  """Module whose missing public attributes are created on first access.

  Reading an undefined, non-underscore attribute `x` creates a child
  `AutoModule('x')`, stores it on the parent with `setattr` (so later reads
  return the same object), and returns it. This lets a deep fake API such as
  `fake_tf.keras.layers.Layer = ...` be declared without building each
  intermediate module by hand.
  """

  def __getattr__(self, name):
    # Only reached when normal lookup fails. Underscore names are refused so
    # introspection probes (e.g. `hasattr(obj, '__wrapped__')`) see a normal
    # missing attribute instead of a fabricated submodule.
    if name.startswith('_'):
      raise AttributeError(
          f'module {self.__name__!r} has no attribute {name!r}')
    mod = AutoModule(name)
    # Cache the child so repeated access yields the same module object.
    setattr(self, name, mod)
    return mod
40
# Make a mock tensorflow package that won't take too long to test.
# `AutoModule` creates every intermediate submodule on first access, so each
# assignment below only has to name the leaf symbol to expose. The test swaps
# this object in for `generate2.tf`, so the doc generator walks this small API
# instead of all of `tf`.
fake_tf = AutoModule('FakeTensorFlow')
fake_tf.Module = tf.Module  # pylint: disable=invalid-name
fake_tf.estimator.DNNClassifier = tf_estimator.DNNClassifier
fake_tf.feature_column.numeric_column = tf.feature_column.numeric_column
fake_tf.keras.Model = tf.keras.Model
fake_tf.keras.preprocessing = tf.keras.preprocessing
fake_tf.keras.layers.Layer = tf.keras.layers.Layer
fake_tf.keras.optimizers.Optimizer = tf.keras.optimizers.Optimizer
fake_tf.nn.sigmoid_cross_entropy_with_logits = (
    tf.nn.sigmoid_cross_entropy_with_logits)
fake_tf.raw_ops.Add = tf.raw_ops.Add
# The same function is deliberately exposed under two names (presumably to
# exercise the generator's handling of aliases).
fake_tf.summary.audio = tf.summary.audio
fake_tf.summary.audio2 = tf.summary.audio
fake_tf.__version__ = tf.__version__
56
57
class Generate2Test(googletest.TestCase):
  """End-to-end test of `generate2.build_docs` run against `fake_tf`."""

  # `generate2.tf` is replaced by the small `fake_tf` package. The tiny fake
  # API also cannot reach the real `MIN_NUM_FILES_EXPECTED` (presumably a
  # minimum generated-file count checked by `build_docs`), so it is lowered
  # to 1. Both are patched rather than assigned so the originals are restored
  # afterwards and other tests in the same process are unaffected.
  @mock.patch.object(generate2, 'tf', fake_tf)
  @mock.patch.object(generate2, 'MIN_NUM_FILES_EXPECTED', 1)
  def test_end_to_end(self):
    """Builds the docs, then spot-checks raw_ops, the TOC and redirects."""
    output_dir = pathlib.Path(googletest.GetTempDir()) / 'output'
    # Start from an empty directory so output from an earlier run cannot
    # satisfy the assertions below.
    if os.path.exists(output_dir):
      shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    generate2.build_docs(
        output_dir=output_dir,
        code_url_prefix='',
        search_hints=True,
    )

    # The raw_ops overview page must link to the exposed `Add` op.
    raw_ops_page = (output_dir / 'tf/raw_ops.md').read_text(encoding='utf-8')
    self.assertIn('/tf/raw_ops/Add.md', raw_ops_page)

    # The first TOC entry is the overview page...
    toc = yaml.safe_load(
        (output_dir / 'tf/_toc.yaml').read_text(encoding='utf-8'))
    self.assertEqual({
        'title': 'Overview',
        'path': '/tf_overview'
    }, toc['toc'][0]['section'][0])
    # ...and that overview path redirects to the `/tf` landing page.
    redirects = yaml.safe_load(
        (output_dir / 'tf/_redirects.yaml').read_text(encoding='utf-8'))
    self.assertIn({'from': '/tf_overview', 'to': '/tf'}, redirects['redirects'])
83
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's googletest entry point.
  googletest.main()
86