• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tools.docs.doc_generator_visitor."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import types
22
23from tensorflow.python.platform import googletest
24from tensorflow.tools.docs import doc_generator_visitor
25from tensorflow.tools.docs import generate_lib
26
27
class NoDunderVisitor(doc_generator_visitor.DocGeneratorVisitor):
  """DocGeneratorVisitor that ignores every underscore-prefixed member.

  Despite its name, this visitor filters out *all* names that start with '_'
  (single-underscore private names as well as dunders). This keeps the
  expected trees and indexes in the tests small.
  """

  def __call__(self, parent_name, parent, children):
    """Visits `parent` after dropping children whose names start with '_'.

    Args:
      parent_name: The (dotted) name under which `parent` is being visited.
      parent: The object whose members are being visited.
      children: A list of (name, object) pairs for `parent`'s members.
    """
    public_children = [
        (name, obj) for (name, obj) in children if not name.startswith('_')
    ]
    super(NoDunderVisitor, self).__call__(
        parent_name, parent, public_children)
36
37
class DocGeneratorVisitorTest(googletest.TestCase):
  """Tests for DocGeneratorVisitor and its use by generate_lib.extract."""

  @staticmethod
  def _extract(tf):
    """Runs generate_lib.extract over a fake `tf` module tree.

    Args:
      tf: A module object to walk, registered under the name 'tf'.

    Returns:
      The visitor returned by generate_lib.extract, built with
      NoDunderVisitor and empty private / do-not-descend maps.
    """
    return generate_lib.extract(
        [('tf', tf)],
        private_map={},
        do_not_descend_map={},
        visitor_cls=NoDunderVisitor)

  def test_call_module(self):
    """Visiting a module records its children in `tree` and `index`."""
    visitor = doc_generator_visitor.DocGeneratorVisitor()
    visitor(
        'doc_generator_visitor', doc_generator_visitor,
        [('DocGeneratorVisitor', doc_generator_visitor.DocGeneratorVisitor)])

    self.assertEqual({'doc_generator_visitor': ['DocGeneratorVisitor']},
                     visitor.tree)
    self.assertEqual({
        'doc_generator_visitor': doc_generator_visitor,
        'doc_generator_visitor.DocGeneratorVisitor':
        doc_generator_visitor.DocGeneratorVisitor,
    }, visitor.index)

  def test_call_class(self):
    """Visiting a class records its members in `tree` and `index`."""
    visitor = doc_generator_visitor.DocGeneratorVisitor()
    visitor(
        'DocGeneratorVisitor', doc_generator_visitor.DocGeneratorVisitor,
        [('index', doc_generator_visitor.DocGeneratorVisitor.index)])

    self.assertEqual({'DocGeneratorVisitor': ['index']},
                     visitor.tree)
    self.assertEqual({
        'DocGeneratorVisitor': doc_generator_visitor.DocGeneratorVisitor,
        'DocGeneratorVisitor.index':
        doc_generator_visitor.DocGeneratorVisitor.index
    }, visitor.index)

  def test_call_raises(self):
    """Visiting something that is neither a class nor a module raises."""
    visitor = doc_generator_visitor.DocGeneratorVisitor()
    with self.assertRaises(RuntimeError):
      visitor('non_class_or_module', 'non_class_or_module_object', [])

  def test_duplicates_module_class_depth(self):
    """A class (and its nested class) reachable from tf and tf.submodule.

    Expects the tf.submodule paths to be chosen as the master names.
    """

    class Parent(object):

      class Nested(object):
        pass

    tf = types.ModuleType('tf')
    tf.Parent = Parent
    tf.submodule = types.ModuleType('submodule')
    tf.submodule.Parent = Parent

    visitor = self._extract(tf)

    self.assertEqual({
        'tf.submodule.Parent':
            sorted([
                'tf.Parent',
                'tf.submodule.Parent',
            ]),
        'tf.submodule.Parent.Nested':
            sorted([
                'tf.Parent.Nested',
                'tf.submodule.Parent.Nested',
            ]),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.Parent.Nested': 'tf.submodule.Parent.Nested',
        'tf.Parent': 'tf.submodule.Parent',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(Parent): 'tf.submodule.Parent',
        id(Parent.Nested): 'tf.submodule.Parent.Nested',
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
    }, visitor.reverse_index)

  def test_duplicates_contrib(self):
    """Expects a tf.submodule path to win over a tf.contrib alias."""

    class Parent(object):
      pass

    tf = types.ModuleType('tf')
    tf.contrib = types.ModuleType('contrib')
    tf.submodule = types.ModuleType('submodule')
    tf.contrib.Parent = Parent
    tf.submodule.Parent = Parent

    visitor = self._extract(tf)

    self.assertEqual({
        'tf.submodule.Parent':
            sorted(['tf.contrib.Parent', 'tf.submodule.Parent']),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.contrib.Parent': 'tf.submodule.Parent',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
        id(Parent): 'tf.submodule.Parent',
        id(tf.contrib): 'tf.contrib',
    }, visitor.reverse_index)

  def test_duplicates_defining_class(self):
    """Expects an inherited attribute to be mastered on its defining class."""

    class Parent(object):
      obj1 = object()

    class Child(Parent):
      pass

    tf = types.ModuleType('tf')
    tf.Parent = Parent
    tf.Child = Child

    visitor = self._extract(tf)

    self.assertEqual({
        'tf.Parent.obj1': sorted([
            'tf.Parent.obj1',
            'tf.Child.obj1',
        ]),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.Child.obj1': 'tf.Parent.obj1',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(Parent): 'tf.Parent',
        id(Child): 'tf.Child',
        id(Parent.obj1): 'tf.Parent.obj1',
    }, visitor.reverse_index)

  def test_duplicates_module_depth(self):
    """Expects tf.Parent to win over tf.submodule.submodule2.Parent."""

    class Parent(object):
      pass

    tf = types.ModuleType('tf')
    tf.submodule = types.ModuleType('submodule')
    tf.submodule.submodule2 = types.ModuleType('submodule2')
    tf.Parent = Parent
    tf.submodule.submodule2.Parent = Parent

    visitor = self._extract(tf)

    self.assertEqual({
        'tf.Parent': sorted(['tf.Parent', 'tf.submodule.submodule2.Parent']),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.submodule.submodule2.Parent': 'tf.Parent'
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
        id(tf.submodule.submodule2): 'tf.submodule.submodule2',
        id(Parent): 'tf.Parent',
    }, visitor.reverse_index)

  def test_duplicates_name(self):
    """Two attribute names on one class bound to the same object.

    Expects obj1 to be the master name and obj2 to be its duplicate.
    """

    class Parent(object):
      obj1 = object()

    Parent.obj2 = Parent.obj1

    tf = types.ModuleType('tf')
    tf.submodule = types.ModuleType('submodule')
    tf.submodule.Parent = Parent

    visitor = self._extract(tf)

    self.assertEqual({
        'tf.submodule.Parent.obj1':
            sorted([
                'tf.submodule.Parent.obj1',
                'tf.submodule.Parent.obj2',
            ]),
    }, visitor.duplicates)

    self.assertEqual({
        'tf.submodule.Parent.obj2': 'tf.submodule.Parent.obj1',
    }, visitor.duplicate_of)

    self.assertEqual({
        id(tf): 'tf',
        id(tf.submodule): 'tf.submodule',
        id(Parent): 'tf.submodule.Parent',
        id(Parent.obj1): 'tf.submodule.Parent.obj1',
    }, visitor.reverse_index)
252
if __name__ == '__main__':
  # Script entry point: hand control to the googletest runner imported above.
  googletest.main()
255