# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parser module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import textwrap

import gast

from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.platform import test


class ParserTest(test.TestCase):
  """Tests for the autograph `parser` module.

  Covers `parse_entity` (functions and lambdas), `dedent_block`,
  `parse_expression` and `unparse`.

  NOTE: several lambda tests compare the source text returned by
  `parse_entity` against string literals that include line breaks,
  indentation and trailing `# pylint` comments. The layout of those lambda
  definitions is therefore part of the expected values; do not reformat it.
  """

  def assertAstMatches(self, actual_node, expected_node_src, expr=True):
    """Asserts that `actual_node` matches the AST parsed from source text.

    Args:
      actual_node: The AST node under test.
      expected_node_src: Python source text describing the expected node.
      expr: If True, `expected_node_src` is parsed as a single expression and
        that expression node is compared. Otherwise the first statement of the
        parsed source is compared.
    """
    if expr:
      # Ensure multi-line expressions parse. Wrapping the text in parentheses
      # makes it a single expression statement; `.value` then unwraps the
      # `Expr` statement to get the expression node itself.
      expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
      expected_node = expected_node.value
    else:
      expected_node = gast.parse(expected_node_src).body[0]

    msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        pretty_printer.fmt(expected_node),
        pretty_printer.fmt(actual_node))
    self.assertTrue(ast_util.matches(actual_node, expected_node), msg)

  def test_parse_entity(self):
    """Parsing a plain function yields a node named after that function."""

    def f(x):
      return x + 1

    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def test_parse_lambda(self):
    """Parses a lambda; its node matches its own source and the expected AST."""

    l = lambda x: x + 1
    expected_node_src = 'lambda x: (x + 1)'

    node, source = parser.parse_entity(l, future_features=())
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)

  def test_parse_lambda_prefix_cleanup(self):
    """Parses a lambda bound to a name that itself starts with `lambda`.

    NOTE(review): presumably guards against source extraction being confused
    by the `lambda` prefix of `lambda_lam`; confirm against parser.py.
    """

    lambda_lam = lambda x: x + 1
    expected_node_src = 'lambda x: (x + 1)'

    node, source = parser.parse_entity(lambda_lam, future_features=())
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)

  def test_parse_lambda_resolution_by_location(self):
    """Resolves the middle one of three identical lambdas on adjacent lines."""

    _ = lambda x: x + 1
    l = lambda x: x + 1
    _ = lambda x: x + 1
    expected_node_src = 'lambda x: (x + 1)'

    node, source = parser.parse_entity(l, future_features=())
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)
    self.assertEqual(source, 'lambda x: x + 1')

  def test_parse_lambda_resolution_by_signature(self):
    """Resolves nested lambdas on one line that differ in their arguments."""

    l = lambda x: lambda x, y: x + y

    # Outer lambda.
    node, source = parser.parse_entity(l, future_features=())
    expected_node_src = 'lambda x: (lambda x, y: (x + y))'
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)
    self.assertEqual(source, 'lambda x: lambda x, y: x + y')

    # Inner lambda, obtained by calling the outer one.
    node, source = parser.parse_entity(l(0), future_features=())
    expected_node_src = 'lambda x, y: (x + y)'
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)
    self.assertEqual(source, 'lambda x, y: x + y')

  def test_parse_lambda_resolution_ambiguous(self):
    """Nested lambdas with identical signatures on one line are rejected.

    For both the outer and the inner lambda, `parse_entity` is expected to
    raise `UnsupportedLanguageElementError`. The message must contain
    "found multiple definitions" followed by both candidate lambdas.
    """

    l = lambda x: lambda x: 2 * x

    # The optional `\(?` groups presumably tolerate candidates being printed
    # with extra parentheses; DOTALL lets `.+` span the multi-line message.
    expected_exception_text = re.compile(r'found multiple definitions'
                                         r'.+'
                                         r'\(?lambda x: \(?lambda x'
                                         r'.+'
                                         r'\(?lambda x: \(?2', re.DOTALL)

    with self.assertRaisesRegex(
        errors.UnsupportedLanguageElementError,
        expected_exception_text):
      parser.parse_entity(l, future_features=())

    with self.assertRaisesRegex(
        errors.UnsupportedLanguageElementError,
        expected_exception_text):
      parser.parse_entity(l(0), future_features=())

  def assertMatchesWithPotentialGarbage(self, source, expected, garbage):
    """Asserts `source` equals `expected`, optionally followed by `garbage`.

    Args:
      source: The source text returned by the parser.
      expected: The exact expected source text.
      garbage: Text that may trail `expected` when the runtime cannot delimit
        the end of the node precisely.
    """
    # In runtimes which don't track end_col_offset / end_lineno for AST nodes,
    # the source extends to the end of the line, which in turn may have
    # garbage from the surrounding context.
    self.assertIn(source, (expected, expected + garbage))

  def test_parse_lambda_multiline(self):
    """Parses nested lambdas whose body spans two lines."""

    l = (
        lambda x: lambda y: x + y  # pylint:disable=g-long-lambda
        - 1)

    node, source = parser.parse_entity(l, future_features=())
    expected_node_src = 'lambda x: (lambda y: ((x + y) - 1))'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, ('lambda x: lambda y: x + y  # pylint:disable=g-long-lambda\n'
                 '        - 1'), ')')

    node, source = parser.parse_entity(l(0), future_features=())
    expected_node_src = 'lambda y: ((x + y) - 1)'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, ('lambda y: x + y  # pylint:disable=g-long-lambda\n'
                 '        - 1'), ')')

  def test_parse_lambda_in_expression(self):
    """Parses lambdas defined as elements of a multi-line tuple."""

    l = (
        lambda x: lambda y: x + y + 1,
        lambda x: lambda y: x + y + 2,
    )

    node, source = parser.parse_entity(l[0], future_features=())
    expected_node_src = 'lambda x: (lambda y: ((x + y) + 1))'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, 'lambda x: lambda y: x + y + 1', ',')

    node, source = parser.parse_entity(l[0](0), future_features=())
    expected_node_src = 'lambda y: ((x + y) + 1)'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, 'lambda y: x + y + 1', ',')

    node, source = parser.parse_entity(l[1], future_features=())
    expected_node_src = 'lambda x: (lambda y: ((x + y) + 2))'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(source,
                                           'lambda x: lambda y: x + y + 2', ',')

    node, source = parser.parse_entity(l[1](0), future_features=())
    expected_node_src = 'lambda y: ((x + y) + 2)'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(source, 'lambda y: x + y + 2', ',')

  def test_parse_lambda_complex_body(self):
    """Parses a lambda with a multi-line body of calls, literals and slices."""

    l = lambda x: (  # pylint:disable=g-long-lambda
        x.y(
            [],
            x.z,
            (),
            x[0:2],
        ),
        x.u,
        'abc',
        1,
    )

    node, source = parser.parse_entity(l, future_features=())
    expected_node_src = "lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1)"
    self.assertAstMatches(node, expected_node_src)

    base_source = ('lambda x: (  # pylint:disable=g-long-lambda\n'
                   '        x.y(\n'
                   '            [],\n'
                   '            x.z,\n'
                   '            (),\n'
                   '            x[0:2],\n'
                   '        ),\n'
                   '        x.u,\n'
                   '        \'abc\',\n'
                   '        1,')
    # The complete source includes the trailing parenthesis. But that is only
    # detected in runtimes which correctly track end_lineno for ASTs.
    self.assertMatchesWithPotentialGarbage(source, base_source, '\n    )')

  def test_parse_lambda_function_call_definition(self):
    """Parses a lambda passed as a call argument after a line break."""

    # `unused_kwargs` only exists so the call below can pass `named_arg`,
    # which then trails the lambda on its source line.
    def do_parse_and_test(lam, **unused_kwargs):
      node, source = parser.parse_entity(lam, future_features=())
      expected_node_src = 'lambda x: x'
      self.assertAstMatches(node, expected_node_src)
      self.assertMatchesWithPotentialGarbage(
          source, 'lambda x: x', ', named_arg=1)')

    do_parse_and_test(  # Intentional line break
        lambda x: x, named_arg=1)

  def test_parse_entity_print_function(self):
    """Parses a function that calls `print` with `print_function` enabled."""

    def f(x):
      print(x)

    node, _ = parser.parse_entity(f, future_features=('print_function',))
    self.assertEqual('f', node.name)

  def test_parse_comments(self):
    """Parses a function whose body contains a comment line."""

    def f():
      # unindented comment
      pass

    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def test_parse_multiline_strings(self):
    """Parses a function containing an unindented multi-line string.

    The string's continuation lines start at column 0, i.e. to the left of
    the enclosing function's indentation.
    """

    def f():
      print("""
multiline
string""")

    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def _eval_code(self, code, name):
    """Executes `code` in a fresh namespace and returns the global `name`."""
    globs = {}
    # Test-only: every caller passes code built from literals in this file.
    exec(code, globs)  # pylint:disable=exec-used
    return globs[name]

  def test_dedent_block_basic(self):
    """Dedents a uniformly indented function definition."""

    code = """
    def f(x):
      if x > 0:
        return -x
      return x
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(1), -1)
    self.assertEqual(f(-1), -1)

  def test_dedent_block_comments_out_of_line(self):
    """Dedents code whose `###` comment lines are indented inconsistently.

    Some comment lines are indented less than the surrounding code (one sits
    at column 0), others more.
    """

    code = """
  ###
    def f(x):
###
      if x > 0:
        ###
        return -x
          ###
  ###
      return x
      ###
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(1), -1)
    self.assertEqual(f(-1), -1)

  def test_dedent_block_multiline_string(self):
    """Dedenting leaves the contents of multi-line strings untouched."""

    code = """
    def f():
      '''
      Docstring.
      '''
      return '''
      1
      2
      3'''
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f.__doc__, '\n      Docstring.\n      ')
    self.assertEqual(f(), '\n      1\n      2\n      3')

  def test_dedent_block_multiline_expression(self):
    """Dedents a parenthesized expression that has a line at column 0."""

    code = """
    def f():
      return (1,
2,
        3)
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(), (1, 2, 3))

  def test_dedent_block_continuation(self):
    """Dedents code that uses a backslash line continuation."""

    code = r"""
    def f():
      a = \
          1
      return a
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(), 1)

  def test_dedent_block_continuation_in_string(self):
    """Dedents code with a backslash continuation inside a string literal.

    The continued line's leading whitespace ends up inside the string, hence
    the run of spaces in the expected value.
    """

    code = r"""
    def f():
      a = "a \
  b"
      return a
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(), 'a   b')

  def test_parse_expression(self):
    """Parses `a.b` into an attribute node on the name `a`."""
    node = parser.parse_expression('a.b')
    self.assertEqual('a', node.value.id)
    self.assertEqual('b', node.attr)

  def test_unparse(self):
    """Unparses a hand-built `if 1: a = b else: a = 'c'` gast tree."""
    node = gast.If(
        test=gast.Constant(1, kind=None),
        body=[
            gast.Assign(
                targets=[
                    gast.Name(
                        'a',
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=gast.Name(
                    'b', ctx=gast.Load(), annotation=None, type_comment=None))
        ],
        orelse=[
            gast.Assign(
                targets=[
                    gast.Name(
                        'a',
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=gast.Constant('c', kind=None))
        ])

    # NOTE(review): `indentation='  '` is passed, yet the expected text below
    # is indented by 4 spaces; presumably `unparse` ignores this argument.
    # Confirm against parser.py.
    source = parser.unparse(node, indentation='  ')
    self.assertEqual(
        textwrap.dedent("""
            # coding=utf-8
            if 1:
                a = b
            else:
                a = 'c'
        """).strip(), source.strip())

  def test_ext_slice_roundtrip(self):
    """Round-trips a function using extended slices through `unparse`."""
    def ext_slice(n):
      return n[:, :], n[0, :], n[:, 0]

    node, _ = parser.parse_entity(ext_slice, future_features=())
    source = parser.unparse(node)
    # Re-parsing the unparsed text must yield an AST matching the original.
    self.assertAstMatches(node, source, expr=False)


if __name__ == '__main__':
  test.main()