# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.docs.generate2."""

import os
import pathlib
import shutil
import types
from unittest import mock

import tensorflow as tf
from tensorflow import estimator as tf_estimator

import yaml

from tensorflow.python.platform import googletest
from tensorflow.tools.docs import generate2


class AutoModule(types.ModuleType):
  """A module that creates and caches empty submodules on attribute access.

  Reading any missing public attribute creates a new `AutoModule` with that
  name, stores it on the parent and returns it. This lets a nested fake
  package be built with plain assignments such as
  `fake_tf.keras.layers.Layer = ...` without declaring each level first.

  Names with a leading underscore (private and dunder names) are never
  auto-created. Probing for them fails with `AttributeError` as it would on
  an ordinary module.
  """

  def __getattr__(self, name):
    # Python only calls `__getattr__` after normal attribute lookup fails.
    if name.startswith('_'):
      # Name the module and attribute so failures are diagnosable. The
      # exception type is unchanged, so `hasattr`/`getattr(m, n, default)`
      # behave exactly as before.
      raise AttributeError(
          f'module {self.__name__!r} has no attribute {name!r}')
    mod = AutoModule(name)
    # Cache the child so repeated access returns the same module object.
    setattr(self, name, mod)
    return mod

# Make a mock tensorflow package that won't take too long to test.
fake_tf = AutoModule('FakeTensorFlow')
fake_tf.Module = tf.Module  # pylint: disable=invalid-name
fake_tf.estimator.DNNClassifier = tf_estimator.DNNClassifier
fake_tf.feature_column.numeric_column = tf.feature_column.numeric_column
fake_tf.keras.Model = tf.keras.Model
fake_tf.keras.preprocessing = tf.keras.preprocessing
fake_tf.keras.layers.Layer = tf.keras.layers.Layer
fake_tf.keras.optimizers.Optimizer = tf.keras.optimizers.Optimizer
fake_tf.nn.sigmoid_cross_entropy_with_logits = (
    tf.nn.sigmoid_cross_entropy_with_logits)
fake_tf.raw_ops.Add = tf.raw_ops.Add
fake_tf.summary.audio = tf.summary.audio
# The same function exported under a second name. This presumably exercises
# the generator's handling of aliased symbols -- TODO confirm intent.
fake_tf.summary.audio2 = tf.summary.audio
fake_tf.__version__ = tf.__version__


class Generate2Test(googletest.TestCase):
  """End-to-end test of `generate2.build_docs` against the small `fake_tf`."""

  # Both globals are patched rather than assigned, so they are restored when
  # the test finishes and do not leak into other tests in the same process.
  # The file-count floor is lowered because the fake package yields only a
  # handful of pages (presumably `build_docs` checks its output count
  # against this value -- confirm in generate2).
  @mock.patch.object(generate2, 'tf', fake_tf)
  @mock.patch.object(generate2, 'MIN_NUM_FILES_EXPECTED', 1)
  def test_end_to_end(self):
    """Builds docs for `fake_tf` and checks raw_ops, TOC and redirects."""
    output_dir = pathlib.Path(googletest.GetTempDir()) / 'output'
    # Start from an empty directory so files left over from an earlier run
    # cannot satisfy the assertions below.
    if output_dir.exists():
      shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    generate2.build_docs(
        output_dir=output_dir,
        code_url_prefix='',
        search_hints=True,
    )

    # The raw_ops index page links to the page for the `Add` op.
    raw_ops_page = (output_dir / 'tf/raw_ops.md').read_text()
    self.assertIn('/tf/raw_ops/Add.md', raw_ops_page)

    # The first entry of the first TOC section is an Overview page...
    toc = yaml.safe_load((output_dir / 'tf/_toc.yaml').read_text())
    self.assertEqual({
        'title': 'Overview',
        'path': '/tf_overview'
    }, toc['toc'][0]['section'][0])
    # ...and that Overview path is redirected to the /tf landing page.
    redirects = yaml.safe_load((output_dir / 'tf/_redirects.yaml').read_text())
    self.assertIn({'from': '/tf_overview', 'to': '/tf'}, redirects['redirects'])


if __name__ == '__main__':
  googletest.main()