• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the autograph `pyct.parser` module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import textwrap

import gast

from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.platform import test


class ParserTest(test.TestCase):
  """Tests for the `parser` module.

  Covers `parse_entity` on functions and lambdas (including recovery of a
  lambda's source text and the error raised for ambiguous lambdas),
  `dedent_block`, `parse_expression` and `unparse`.
  """

  def assertAstMatches(self, actual_node, expected_node_src, expr=True):
    """Asserts that `actual_node` matches the AST parsed from source text.

    Args:
      actual_node: The gast node under test.
      expected_node_src: Python source text for the expected node.
      expr: If True, `expected_node_src` is parsed as a single expression and
        that expression node is compared. Otherwise the first statement of the
        parsed module is compared.
    """
    if expr:
      # Wrap in parentheses so multi-line expressions parse; `.value` unwraps
      # the resulting expression statement.
      expected_node = gast.parse('({})'.format(expected_node_src)).body[0]
      expected_node = expected_node.value
    else:
      expected_node = gast.parse(expected_node_src).body[0]

    msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        pretty_printer.fmt(expected_node),
        pretty_printer.fmt(actual_node))
    self.assertTrue(ast_util.matches(actual_node, expected_node), msg)

  def test_parse_entity(self):

    def f(x):
      return x + 1

    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def test_parse_lambda(self):

    l = lambda x: x + 1
    expected_node_src = 'lambda x: (x + 1)'

    node, source = parser.parse_entity(l, future_features=())
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)

  def test_parse_lambda_prefix_cleanup(self):

    lambda_lam = lambda x: x + 1
    expected_node_src = 'lambda x: (x + 1)'

    node, source = parser.parse_entity(lambda_lam, future_features=())
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)

  def test_parse_lambda_resolution_by_location(self):

    _ = lambda x: x + 1
    l = lambda x: x + 1
    _ = lambda x: x + 1
    expected_node_src = 'lambda x: (x + 1)'

    node, source = parser.parse_entity(l, future_features=())
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)
    self.assertEqual(source, 'lambda x: x + 1')

  def test_parse_lambda_resolution_by_signature(self):

    l = lambda x: lambda x, y: x + y

    node, source = parser.parse_entity(l, future_features=())
    expected_node_src = 'lambda x: (lambda x, y: (x + y))'
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)
    self.assertEqual(source, 'lambda x: lambda x, y: x + y')

    node, source = parser.parse_entity(l(0), future_features=())
    expected_node_src = 'lambda x, y: (x + y)'
    self.assertAstMatches(node, source)
    self.assertAstMatches(node, expected_node_src)
    self.assertEqual(source, 'lambda x, y: x + y')

  def test_parse_lambda_resolution_ambiguous(self):

    l = lambda x: lambda x: 2 * x

    expected_exception_text = re.compile(r'found multiple definitions'
                                         r'.+'
                                         r'\(?lambda x: \(?lambda x'
                                         r'.+'
                                         r'\(?lambda x: \(?2', re.DOTALL)

    with self.assertRaisesRegex(
        errors.UnsupportedLanguageElementError,
        expected_exception_text):
      parser.parse_entity(l, future_features=())

    with self.assertRaisesRegex(
        errors.UnsupportedLanguageElementError,
        expected_exception_text):
      parser.parse_entity(l(0), future_features=())

  def assertMatchesWithPotentialGarbage(self, source, expected, garbage):
    """Asserts `source` is `expected`, optionally followed by `garbage`.

    Args:
      source: The source text recovered by the parser.
      expected: The exact source text of the entity.
      garbage: Trailing text from the surrounding context that may also be
        captured.
    """
    # In runtimes which don't track end_col_offset, the source contains the
    # entire line, which in turn may have garbage from the surrounding context.
    self.assertIn(source, (expected, expected + garbage))

  def test_parse_lambda_multiline(self):

    l = (
        lambda x: lambda y: x + y  # pylint:disable=g-long-lambda
        - 1)

    node, source = parser.parse_entity(l, future_features=())
    expected_node_src = 'lambda x: (lambda y: ((x + y) - 1))'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, ('lambda x: lambda y: x + y  # pylint:disable=g-long-lambda\n'
                 '        - 1'), ')')

    node, source = parser.parse_entity(l(0), future_features=())
    expected_node_src = 'lambda y: ((x + y) - 1)'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, ('lambda y: x + y  # pylint:disable=g-long-lambda\n'
                 '        - 1'), ')')

  def test_parse_lambda_in_expression(self):

    l = (
        lambda x: lambda y: x + y + 1,
        lambda x: lambda y: x + y + 2,
        )

    node, source = parser.parse_entity(l[0], future_features=())
    expected_node_src = 'lambda x: (lambda y: ((x + y) + 1))'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, 'lambda x: lambda y: x + y + 1', ',')

    node, source = parser.parse_entity(l[0](0), future_features=())
    expected_node_src = 'lambda y: ((x + y) + 1)'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, 'lambda y: x + y + 1', ',')

    node, source = parser.parse_entity(l[1], future_features=())
    expected_node_src = 'lambda x: (lambda y: ((x + y) + 2))'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(
        source, 'lambda x: lambda y: x + y + 2', ',')

    node, source = parser.parse_entity(l[1](0), future_features=())
    expected_node_src = 'lambda y: ((x + y) + 2)'
    self.assertAstMatches(node, expected_node_src)
    self.assertMatchesWithPotentialGarbage(source, 'lambda y: x + y + 2', ',')

  def test_parse_lambda_complex_body(self):

    l = lambda x: (  # pylint:disable=g-long-lambda
        x.y(
            [],
            x.z,
            (),
            x[0:2],
        ),
        x.u,
        'abc',
        1,
    )

    node, source = parser.parse_entity(l, future_features=())
    expected_node_src = "lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1)"
    self.assertAstMatches(node, expected_node_src)

    base_source = ('lambda x: (  # pylint:disable=g-long-lambda\n'
                   '        x.y(\n'
                   '            [],\n'
                   '            x.z,\n'
                   '            (),\n'
                   '            x[0:2],\n'
                   '        ),\n'
                   '        x.u,\n'
                   '        \'abc\',\n'
                   '        1,')
    # The complete source includes the trailing parenthesis. But that is only
    # detected in runtimes which correctly track end_lineno for ASTs.
    self.assertMatchesWithPotentialGarbage(source, base_source, '\n    )')

  def test_parse_lambda_function_call_definition(self):

    def do_parse_and_test(lam, **unused_kwargs):
      node, source = parser.parse_entity(lam, future_features=())
      expected_node_src = 'lambda x: x'
      self.assertAstMatches(node, expected_node_src)
      self.assertMatchesWithPotentialGarbage(
          source, 'lambda x: x', ', named_arg=1)')

    do_parse_and_test(  # Intentional line break
        lambda x: x, named_arg=1)

  def test_parse_entity_print_function(self):

    def f(x):
      print(x)

    node, _ = parser.parse_entity(f, future_features=('print_function',))
    self.assertEqual('f', node.name)

  def test_parse_comments(self):

    def f():
      # unindented comment
      pass

    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def test_parse_multiline_strings(self):

    def f():
      print("""
multiline
string""")

    node, _ = parser.parse_entity(f, future_features=())
    self.assertEqual('f', node.name)

  def _eval_code(self, code, name):
    """Executes `code` in a fresh globals dict and returns the value of `name`.

    Args:
      code: Python source text to execute.
      name: The global name to look up after execution.

    Returns:
      The object bound to `name` by `code`.

    Raises:
      KeyError: If `code` does not bind `name`.
    """
    globs = {}
    exec(code, globs)  # pylint:disable=exec-used
    return globs[name]

  def test_dedent_block_basic(self):

    code = """
    def f(x):
      if x > 0:
        return -x
      return x
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(1), -1)
    self.assertEqual(f(-1), -1)

  def test_dedent_block_comments_out_of_line(self):

    code = """
  ###
    def f(x):
###
      if x > 0:
  ###
        return -x
          ###
  ###
      return x
      ###
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(1), -1)
    self.assertEqual(f(-1), -1)

  def test_dedent_block_multiline_string(self):

    code = """
    def f():
      '''
      Docstring.
      '''
      return '''
  1
    2
      3'''
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f.__doc__, '\n      Docstring.\n      ')
    self.assertEqual(f(), '\n  1\n    2\n      3')

  def test_dedent_block_multiline_expression(self):

    code = """
    def f():
      return (1,
2,
        3)
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(), (1, 2, 3))

  def test_dedent_block_continuation(self):

    code = r"""
    def f():
      a = \
          1
      return a
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(), 1)

  def test_dedent_block_continuation_in_string(self):

    code = r"""
    def f():
      a = "a \
  b"
      return a
    """

    f = self._eval_code(parser.dedent_block(code), 'f')
    self.assertEqual(f(), 'a   b')

  def test_parse_expression(self):
    node = parser.parse_expression('a.b')
    self.assertEqual('a', node.value.id)
    self.assertEqual('b', node.attr)

  def test_unparse(self):
    node = gast.If(
        test=gast.Constant(1, kind=None),
        body=[
            gast.Assign(
                targets=[
                    gast.Name(
                        'a',
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=gast.Name(
                    'b', ctx=gast.Load(), annotation=None, type_comment=None))
        ],
        orelse=[
            gast.Assign(
                targets=[
                    gast.Name(
                        'a',
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=gast.Constant('c', kind=None))
        ])

    source = parser.unparse(node, indentation='  ')
    self.assertEqual(
        textwrap.dedent("""
            # coding=utf-8
            if 1:
                a = b
            else:
                a = 'c'
        """).strip(), source.strip())

  def test_ext_slice_roundtrip(self):
    def ext_slice(n):
      return n[:, :], n[0, :], n[:, 0]

    node, _ = parser.parse_entity(ext_slice, future_features=())
    source = parser.unparse(node)
    self.assertAstMatches(node, source, expr=False)


if __name__ == '__main__':
  test.main()