• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gast
22
23from tensorflow.python.autograph.pyct import anno
24from tensorflow.python.autograph.pyct import parser
25from tensorflow.python.autograph.pyct import transformer
26from tensorflow.python.platform import test
27
28
class TransformerTest(test.TestCase):
  """Tests for `transformer.Base` and its scope/state tracking helpers."""

  def _simple_context(self):
    """Returns a `transformer.Context` whose `EntityInfo` fields are all None."""
    entity_info = transformer.EntityInfo(
        source_code=None,
        source_file=None,
        namespace=None,
        arg_values=None,
        arg_types=None)
    return transformer.Context(entity_info)

  def test_entity_scope_tracking(self):
    """Checks the nesting chain recorded in `enclosing_entities`."""

    class TestTransformer(transformer.Base):

      # The choice of node to assign to is arbitrary. Using Assign because it's
      # easy to find in the tree.
      def visit_Assign(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # This will show up in the lambda function.
      def visit_BinOp(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

    tr = TestTransformer(self._simple_context())

    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    # Locate each nesting level by its fixed position in `test_function`.
    test_function_node = node
    test_class = test_function_node.body[1]
    test_method = test_class.body[0]
    inner_function = test_method.body[1]
    lambda_node = inner_function.body[1].value

    a = test_function_node.body[0]
    b = test_method.body[0]
    c = inner_function.body[0]
    lambda_expr = lambda_node.body

    self.assertEqual(
        (test_function_node,), anno.getanno(a, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method),
                     anno.getanno(b, 'enclosing_entities'))
    self.assertEqual(
        (test_function_node, test_class, test_method, inner_function),
        anno.getanno(c, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method,
                      inner_function, lambda_node),
                     anno.getanno(lambda_expr, 'enclosing_entities'))

  def assertSameAnno(self, first, second, key):
    """Asserts annotation `key` is the identical object on both nodes."""
    self.assertIs(anno.getanno(first, key), anno.getanno(second, key))

  def assertDifferentAnno(self, first, second, key):
    """Asserts annotation `key` is a different object on each node."""
    self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key))

  def test_state_tracking(self):
    """Checks that each While/If scope gets its own `state[...]` value."""

    class LoopState(object):
      pass

    class CondState(object):
      pass

    class TestTransformer(transformer.Base):

      # Snapshot the current state values onto every visited node.
      def visit(self, node):
        anno.setanno(node, 'loop_state', self.state[LoopState].value)
        anno.setanno(node, 'cond_state', self.state[CondState].value)
        return super(TestTransformer, self).visit(node)

      def visit_While(self, node):
        self.state[LoopState].enter()
        node = self.generic_visit(node)
        self.state[LoopState].exit()
        return node

      def visit_If(self, node):
        self.state[CondState].enter()
        node = self.generic_visit(node)
        self.state[CondState].exit()
        return node

    tr = TestTransformer(self._simple_context())

    def test_function(a):
      a = 1
      while a:
        _ = 'a'
        if a > 2:
          _ = 'b'
          while True:
            raise '1'
        if a > 3:
          _ = 'c'
          while True:
            raise '1'

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    fn_body = node.body
    outer_while_body = fn_body[1].body
    self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
    self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')

    first_if_body = outer_while_body[1].body
    self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
                             'cond_state')
    self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')

    first_inner_while_body = first_if_body[1].body
    self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
                        'cond_state')
    self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
                             'loop_state')

    # Sibling If statements must not share a cond_state value.
    second_if_body = outer_while_body[2].body
    self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
    self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')

    second_inner_while_body = second_if_body[1].body
    self.assertDifferentAnno(first_inner_while_body[0],
                             second_inner_while_body[0], 'cond_state')
    self.assertDifferentAnno(first_inner_while_body[0],
                             second_inner_while_body[0], 'loop_state')

  def test_local_scope_info_stack(self):
    """Checks that local-scope values are isolated per entered scope."""

    class TestTransformer(transformer.Base):

      # Extract all string constants from the block.
      def visit_Str(self, node):
        self.set_local('string', self.get_local('string', default='') + node.s)
        return self.generic_visit(node)

      def _annotate_result(self, node):
        self.enter_local_scope()
        node = self.generic_visit(node)
        anno.setanno(node, 'test', self.get_local('string'))
        self.exit_local_scope()
        return node

      def visit_While(self, node):
        return self._annotate_result(node)

      def visit_For(self, node):
        return self._annotate_result(node)

    tr = TestTransformer(self._simple_context())

    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    for_node = node.body[2]
    while_node = for_node.body[1].orelse[1]

    # The While's '1' lives in its own scope, so the For only collects 'abc'.
    self.assertFalse(anno.hasanno(for_node, 'string'))
    self.assertEqual('abc', anno.getanno(for_node, 'test'))
    self.assertFalse(anno.hasanno(while_node, 'string'))
    self.assertEqual('1', anno.getanno(while_node, 'test'))

  def test_local_scope_info_stack_checks_integrity(self):
    """Checks that unbalanced local-scope enter/exit raises AssertionError."""

    class TestTransformer(transformer.Base):

      # Deliberately unbalanced: If only enters, For only exits.
      def visit_If(self, node):
        self.enter_local_scope()
        return self.generic_visit(node)

      def visit_For(self, node):
        node = self.generic_visit(node)
        self.exit_local_scope()
        return node

    tr = TestTransformer(self._simple_context())

    def no_exit(a):
      if a > 0:
        print(a)
      return None

    node, _, _ = parser.parse_entity(no_exit)
    with self.assertRaises(AssertionError):
      tr.visit(node)

    def no_entry(a):
      for _ in a:
        print(a)

    node, _, _ = parser.parse_entity(no_entry)
    with self.assertRaises(AssertionError):
      tr.visit(node)

  def test_visit_block_postprocessing(self):
    """Checks that `visit_block`'s `after_visit` hook can nest statements.

    The hook wraps `z = y` in an `If` and returns the `If` body, so the
    statements that follow it (`return z`) end up inside that body.
    """

    class TestTransformer(transformer.Base):

      def _process_body_item(self, node):
        if isinstance(node, gast.Assign) and (node.value.id == 'y'):
          if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
          return if_node, if_node.body
        return node, None

      def visit_FunctionDef(self, node):
        node.body = self.visit_block(
            node.body, after_visit=self._process_body_item)
        return node

    def test_function(x, y):
      z = x
      z = y
      return z

    tr = TestTransformer(self._simple_context())

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    self.assertEqual(2, len(node.body))
    self.assertIsInstance(node.body[0], gast.Assign)
    self.assertIsInstance(node.body[1], gast.If)
    self.assertIsInstance(node.body[1].body[0], gast.Assign)
    self.assertIsInstance(node.body[1].body[1], gast.Return)

  def test_robust_error_on_list_visit(self):
    """Checks the error raised when a subclass passes a list to `visit`."""

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        # This is broken because visit expects a single node, not a list, and
        # the body of an if is a list.
        # Importantly, the default error handling in visit also expects a single
        # node.  Therefore, mistakes like this need to trigger a type error
        # before the visit called here installs its error handler.
        # That type error can then be caught by the enclosing call to visit,
        # and correctly blame the If node.
        self.visit(node.body)
        return node

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_context())

    _, _, all_nodes = parser.parse_entity(test_function)
    with self.assertRaises(ValueError) as cm:
      tr.visit(all_nodes)
    obtained_message = str(cm.exception)
    expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
    # NOTE(review): assertRegexpMatches is a deprecated alias of assertRegex on
    # Python 3 (removed in 3.12); presumably kept for Python 2 compatibility.
    # Switch to assertRegex once Python 2 support is no longer needed.
    self.assertRegexpMatches(obtained_message, expected_message)

  def test_robust_error_on_ast_corruption(self):
    """Checks that error handling survives a subclass corrupting the AST."""
    # A child class should not be able to be so broken that it causes the error
    # handling in `transformer.Base` to raise an exception.  Why not?  Because
    # then the original error location is dropped, and an error handler higher
    # up in the call stack gives misleading information.

    # Here we test that the error handling in `visit` completes, and blames the
    # correct original exception, even if the AST gets corrupted.

    class NotANode(object):
      pass

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        node.body = NotANode()
        raise ValueError('I blew up')

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_context())

    _, _, all_nodes = parser.parse_entity(test_function)
    with self.assertRaises(ValueError) as cm:
      tr.visit(all_nodes)
    obtained_message = str(cm.exception)
    # The message should reference the exception actually raised, not anything
    # from the exception handler.
    expected_substring = 'I blew up'
    self.assertIn(expected_substring, obtained_message)
344
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's platform test entry point.
  test.main()
347