# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tools.docs.doc_generator_visitor."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import types

from tensorflow.python.platform import googletest
from tensorflow.tools.docs import doc_generator_visitor
from tensorflow.tools.docs import generate_lib


class NoDunderVisitor(doc_generator_visitor.DocGeneratorVisitor):
  """A `DocGeneratorVisitor` that ignores underscore-prefixed members.

  Filtering these out keeps the expected `tree`/`index`/`duplicates` values in
  the tests below limited to the names the tests create themselves.
  """

  def __call__(self, parent_name, parent, children):
    """Drops every child whose name starts with '_', then visits the rest.

    Note: this removes all underscore-prefixed names (e.g. `_private` as well
    as `__dunder__`), not only dunder names.

    Args:
      parent_name: Dotted name under which `parent` is recorded.
      parent: The object whose children are being visited.
      children: An iterable of `(name, object)` pairs for `parent`'s members.
    """
    children = [
        (name, obj) for (name, obj) in children if not name.startswith('_')
    ]
    super(NoDunderVisitor, self).__call__(parent_name, parent, children)


def _extract_from_tf(tf):
  """Runs `generate_lib.extract` over a fake `tf` module and returns the visitor.

  Every duplicate-resolution test below makes this same call: the module is
  registered under the root name 'tf', `private_map` and `do_not_descend_map`
  are empty, and `NoDunderVisitor` is used so only test-created names appear.

  Args:
    tf: A `types.ModuleType` standing in for the real `tf` module.

  Returns:
    The visitor produced by `generate_lib.extract`.
  """
  return generate_lib.extract(
      [('tf', tf)],
      private_map={},
      do_not_descend_map={},
      visitor_cls=NoDunderVisitor)


class DocGeneratorVisitorTest(googletest.TestCase):
  """Tests for `DocGeneratorVisitor` indexing and duplicate-name resolution."""

  def test_call_module(self):
    """Visiting a module records its child in `tree` and both in `index`."""
    visitor = doc_generator_visitor.DocGeneratorVisitor()
    visitor(
        'doc_generator_visitor', doc_generator_visitor,
        [('DocGeneratorVisitor', doc_generator_visitor.DocGeneratorVisitor)])

    self.assertEqual({'doc_generator_visitor': ['DocGeneratorVisitor']},
                     visitor.tree)
    self.assertEqual({
        'doc_generator_visitor': doc_generator_visitor,
        'doc_generator_visitor.DocGeneratorVisitor':
            doc_generator_visitor.DocGeneratorVisitor,
    }, visitor.index)

  def test_call_class(self):
    """Visiting a class records its member in `tree` and both in `index`."""
    visitor = doc_generator_visitor.DocGeneratorVisitor()
    visitor(
        'DocGeneratorVisitor', doc_generator_visitor.DocGeneratorVisitor,
        [('index', doc_generator_visitor.DocGeneratorVisitor.index)])

    self.assertEqual({'DocGeneratorVisitor': ['index']},
                     visitor.tree)
    self.assertEqual({
        'DocGeneratorVisitor': doc_generator_visitor.DocGeneratorVisitor,
        'DocGeneratorVisitor.index':
            doc_generator_visitor.DocGeneratorVisitor.index
    }, visitor.index)

  def test_call_raises(self):
    """Visiting an object that is neither a class nor a module raises."""
    visitor = doc_generator_visitor.DocGeneratorVisitor()
    with self.assertRaises(RuntimeError):
      visitor('non_class_or_module', 'non_class_or_module_object', [])

  def test_duplicates_module_class_depth(self):
    """`tf.submodule.Parent` (and its nested class) win over `tf.Parent`."""

    class Parent(object):

      class Nested(object):
        pass

    tf = types.ModuleType('tf')
    tf.Parent = Parent
    tf.submodule = types.ModuleType('submodule')
    tf.submodule.Parent = Parent

    visitor = _extract_from_tf(tf)

    self.assertEqual({
        'tf.submodule.Parent':
            sorted([
                'tf.Parent',
                'tf.submodule.Parent',
            ]),
        'tf.submodule.Parent.Nested':
            sorted([
                'tf.Parent.Nested',
                'tf.submodule.Parent.Nested',
            ]),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.Parent.Nested': 'tf.submodule.Parent.Nested',
        'tf.Parent': 'tf.submodule.Parent',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(Parent): 'tf.submodule.Parent',
        id(Parent.Nested): 'tf.submodule.Parent.Nested',
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
    }, visitor.reverse_index)

  def test_duplicates_contrib(self):
    """A non-contrib name (`tf.submodule.Parent`) wins over `tf.contrib`."""

    class Parent(object):
      pass

    tf = types.ModuleType('tf')
    tf.contrib = types.ModuleType('contrib')
    tf.submodule = types.ModuleType('submodule')
    tf.contrib.Parent = Parent
    tf.submodule.Parent = Parent

    visitor = _extract_from_tf(tf)

    self.assertEqual({
        'tf.submodule.Parent':
            sorted(['tf.contrib.Parent', 'tf.submodule.Parent']),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.contrib.Parent': 'tf.submodule.Parent',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
        id(Parent): 'tf.submodule.Parent',
        id(tf.contrib): 'tf.contrib',
    }, visitor.reverse_index)

  def test_duplicates_defining_class(self):
    """An inherited attribute is documented under the class defining it."""

    class Parent(object):
      obj1 = object()

    class Child(Parent):
      pass

    tf = types.ModuleType('tf')
    tf.Parent = Parent
    tf.Child = Child

    visitor = _extract_from_tf(tf)

    self.assertEqual({
        'tf.Parent.obj1': sorted([
            'tf.Parent.obj1',
            'tf.Child.obj1',
        ]),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.Child.obj1': 'tf.Parent.obj1',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(Parent): 'tf.Parent',
        id(Child): 'tf.Child',
        id(Parent.obj1): 'tf.Parent.obj1',
    }, visitor.reverse_index)

  def test_duplicates_module_depth(self):
    """`tf.Parent` wins over the deeper `tf.submodule.submodule2.Parent`."""

    class Parent(object):
      pass

    tf = types.ModuleType('tf')
    tf.submodule = types.ModuleType('submodule')
    tf.submodule.submodule2 = types.ModuleType('submodule2')
    tf.Parent = Parent
    tf.submodule.submodule2.Parent = Parent

    visitor = _extract_from_tf(tf)

    self.assertEqual({
        'tf.Parent': sorted(['tf.Parent', 'tf.submodule.submodule2.Parent']),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.submodule.submodule2.Parent': 'tf.Parent'
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
        id(tf.submodule.submodule2): 'tf.submodule.submodule2',
        id(Parent): 'tf.Parent',
    }, visitor.reverse_index)

  def test_duplicates_name(self):
    """Two attribute names for one object resolve to `obj1` as the master."""

    class Parent(object):
      obj1 = object()

    Parent.obj2 = Parent.obj1

    tf = types.ModuleType('tf')
    tf.submodule = types.ModuleType('submodule')
    tf.submodule.Parent = Parent

    visitor = _extract_from_tf(tf)

    self.assertEqual({
        'tf.submodule.Parent.obj1':
            sorted([
                'tf.submodule.Parent.obj1',
                'tf.submodule.Parent.obj2',
            ]),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.submodule.Parent.obj2': 'tf.submodule.Parent.obj1',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
        id(Parent): 'tf.submodule.Parent',
        id(Parent.obj1): 'tf.submodule.Parent.obj1',
    }, visitor.reverse_index)


if __name__ == '__main__':
  googletest.main()