# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test


class TransformerTest(test.TestCase):

  def _simple_context(self):
    """Returns a transformer.Context whose EntityInfo fields are all None.

    The tests below never read the entity info, so placeholder values suffice.
    """
    entity_info = transformer.EntityInfo(
        source_code=None,
        source_file=None,
        namespace=None,
        arg_values=None,
        arg_types=None)
    return transformer.Context(entity_info)

  def test_entity_scope_tracking(self):
    """Checks enclosing_entities lists the enclosing function/class/lambda."""

    class TestTransformer(transformer.Base):

      # The choice of node to assign to is arbitrary. Using Assign because it's
      # easy to find in the tree.
      def visit_Assign(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # This will show up in the lambda function.
      def visit_BinOp(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

    tr = TestTransformer(self._simple_context())

    # NOTE: the tree indices below depend on the exact statement layout of
    # this function; do not edit it without updating them.
    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    test_function_node = node
    test_class = test_function_node.body[1]
    test_method = test_class.body[0]
    inner_function = test_method.body[1]
    lambda_node = inner_function.body[1].value

    a = test_function_node.body[0]
    b = test_method.body[0]
    c = inner_function.body[0]
    lambda_expr = lambda_node.body

    self.assertEqual(
        (test_function_node,), anno.getanno(a, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method),
                     anno.getanno(b, 'enclosing_entities'))
    self.assertEqual(
        (test_function_node, test_class, test_method, inner_function),
        anno.getanno(c, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method,
                      inner_function, lambda_node),
                     anno.getanno(lambda_expr, 'enclosing_entities'))

  def assertSameAnno(self, first, second, key):
    """Asserts both nodes carry the identical (`is`) annotation under key."""
    self.assertIs(anno.getanno(first, key), anno.getanno(second, key))

  def assertDifferentAnno(self, first, second, key):
    """Asserts the nodes' annotations under key are distinct objects."""
    self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key))

  def test_state_tracking(self):
    """Checks loop and conditional state are entered/exited independently."""

    class LoopState(object):
      pass

    class CondState(object):
      pass

    class TestTransformer(transformer.Base):

      def visit(self, node):
        # Record the current state objects on every node so the assertions
        # below can compare them by identity.
        anno.setanno(node, 'loop_state', self.state[LoopState].value)
        anno.setanno(node, 'cond_state', self.state[CondState].value)
        return super(TestTransformer, self).visit(node)

      def visit_While(self, node):
        self.state[LoopState].enter()
        node = self.generic_visit(node)
        self.state[LoopState].exit()
        return node

      def visit_If(self, node):
        self.state[CondState].enter()
        node = self.generic_visit(node)
        self.state[CondState].exit()
        return node

    tr = TestTransformer(self._simple_context())

    # Only parsed, never executed; the tree indices below depend on its layout.
    def test_function(a):
      a = 1
      while a:
        _ = 'a'
        if a > 2:
          _ = 'b'
          while True:
            raise '1'
        if a > 3:
          _ = 'c'
          while True:
            raise '1'

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    fn_body = node.body
    outer_while_body = fn_body[1].body
    self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
    self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')

    first_if_body = outer_while_body[1].body
    self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
                             'cond_state')
    self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')

    first_inner_while_body = first_if_body[1].body
    self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
                        'cond_state')
    self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
                             'loop_state')

    second_if_body = outer_while_body[2].body
    self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
    self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')

    second_inner_while_body = second_if_body[1].body
    self.assertDifferentAnno(first_inner_while_body[0],
                             second_inner_while_body[0], 'cond_state')
    self.assertDifferentAnno(first_inner_while_body[0],
                             second_inner_while_body[0], 'loop_state')

  def test_local_scope_info_stack(self):
    """Checks local values are confined to the local scope that set them."""

    class TestTransformer(transformer.Base):

      # Extract all string constants from the block.
      def visit_Str(self, node):
        self.set_local('string', self.get_local('string', default='') + node.s)
        return self.generic_visit(node)

      def _annotate_result(self, node):
        # Collect strings in a fresh local scope and attach them as 'test'.
        self.enter_local_scope()
        node = self.generic_visit(node)
        anno.setanno(node, 'test', self.get_local('string'))
        self.exit_local_scope()
        return node

      def visit_While(self, node):
        return self._annotate_result(node)

      def visit_For(self, node):
        return self._annotate_result(node)

    tr = TestTransformer(self._simple_context())

    # Only parsed, never executed; the tree indices below depend on its layout.
    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    for_node = node.body[2]
    while_node = for_node.body[1].orelse[1]

    self.assertFalse(anno.hasanno(for_node, 'string'))
    self.assertEqual('abc', anno.getanno(for_node, 'test'))
    self.assertFalse(anno.hasanno(while_node, 'string'))
    self.assertEqual('1', anno.getanno(while_node, 'test'))

  def test_local_scope_info_stack_checks_integrity(self):
    """Checks unbalanced enter/exit_local_scope calls raise AssertionError."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        # Deliberately enters a scope without exiting it.
        self.enter_local_scope()
        return self.generic_visit(node)

      def visit_For(self, node):
        # Deliberately exits a scope that was never entered.
        node = self.generic_visit(node)
        self.exit_local_scope()
        return node

    tr = TestTransformer(self._simple_context())

    def no_exit(a):
      if a > 0:
        print(a)
      return None

    node, _, _ = parser.parse_entity(no_exit)
    with self.assertRaises(AssertionError):
      tr.visit(node)

    def no_entry(a):
      for _ in a:
        print(a)

    node, _, _ = parser.parse_entity(no_entry)
    with self.assertRaises(AssertionError):
      tr.visit(node)

  def test_visit_block_postprocessing(self):
    """Checks after_visit can redirect later statements into a new block."""

    class TestTransformer(transformer.Base):

      def _process_body_item(self, node):
        # Wrap `<target> = y` in `if x:`, and return the if's body so the
        # statements that follow are appended inside it.
        if isinstance(node, gast.Assign) and (node.value.id == 'y'):
          if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
          return if_node, if_node.body
        return node, None

      def visit_FunctionDef(self, node):
        node.body = self.visit_block(
            node.body, after_visit=self._process_body_item)
        return node

    def test_function(x, y):
      z = x
      z = y
      return z

    tr = TestTransformer(self._simple_context())

    node, _, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    self.assertEqual(2, len(node.body))
    self.assertIsInstance(node.body[0], gast.Assign)
    self.assertIsInstance(node.body[1], gast.If)
    self.assertIsInstance(node.body[1].body[0], gast.Assign)
    self.assertIsInstance(node.body[1].body[1], gast.Return)

  def test_robust_error_on_list_visit(self):
    """Checks that visiting a list instead of a node blames the caller node."""

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        # This is broken because visit expects a single node, not a list, and
        # the body of an if is a list.
        # Importantly, the default error handling in visit also expects a single
        # node. Therefore, mistakes like this need to trigger a type error
        # before the visit called here installs its error handler.
        # That type error can then be caught by the enclosing call to visit,
        # and correctly blame the If node.
        self.visit(node.body)
        return node

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_context())

    _, _, all_nodes = parser.parse_entity(test_function)
    with self.assertRaises(ValueError) as cm:
      tr.visit(all_nodes)
    obtained_message = str(cm.exception)
    expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
    self.assertRegexpMatches(obtained_message, expected_message)

  def test_robust_error_on_ast_corruption(self):
    """Checks the original error survives even when the AST is corrupted."""
    # A child class should not be able to be so broken that it causes the error
    # handling in `transformer.Base` to raise an exception. Why not? Because
    # then the original error location is dropped, and an error handler higher
    # up in the call stack gives misleading information.

    # Here we test that the error handling in `visit` completes, and blames the
    # correct original exception, even if the AST gets corrupted.

    class NotANode(object):
      pass

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        node.body = NotANode()
        raise ValueError('I blew up')

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_context())

    _, _, all_nodes = parser.parse_entity(test_function)
    with self.assertRaises(ValueError) as cm:
      tr.visit(all_nodes)
    obtained_message = str(cm.exception)
    # The message should reference the exception actually raised, not anything
    # from the exception handler.
    expected_substring = 'I blew up'
    self.assertIn(expected_substring, obtained_message)


if __name__ == '__main__':
  test.main()