• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converting code to AST.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import ast
25import inspect
26import linecache
27import re
28import sys
29import textwrap
30import tokenize
31
32import astunparse
33import gast
34import six
35
36from tensorflow.python.autograph.pyct import errors
37from tensorflow.python.autograph.pyct import inspect_utils
38from tensorflow.python.util import tf_inspect
39
40
# Future-import preamble prepended to Python 2 sources so they parse with
# Python 3 semantics for division and print.
PY2_PREAMBLE = textwrap.dedent("""
from __future__ import division
from __future__ import print_function
""")
# Python 3 sources need no compatibility preamble.
PY3_PREAMBLE = ''
# Sentinel for "largest representable int"; used by _parse_lambda as the
# initial minimum line number. Reassigned below based on interpreter version.
MAX_SIZE = 0

if sys.version_info >= (3, 9):
  # From 3.9 on, the stdlib ast module provides unparse(); rebind the
  # astunparse name so unparse() below transparently uses the stdlib version.
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

# Number of nodes the standard preamble adds to a parsed module (one per
# future import line).
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


# Matches the (possibly empty) run of leading whitespace on a line.
_LEADING_WHITESPACE = re.compile(r'\s*')
62
63
64def _unfold_continuations(code_string):
65  """Removes any backslash line continuations from the code."""
66  return code_string.replace('\\\n', '')
67
68
def dedent_block(code_string):
  """Dedents a code block so that its first line starts at row zero.

  Args:
    code_string: Text, the code to dedent. Backslash line continuations are
      removed first, because they confuse per-line reindentation.

  Returns:
    Text, the dedented code. Returned unchanged when the code is already at
    top level (or when no indentation could be detected).

  Raises:
    errors.UnsupportedLanguageElementError: if the code mixes tabs and spaces
      for indentation.
  """
  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(six.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # Find the indentation of the first real code token; strings and comments
  # may legally sit to the left of the block's code.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      # First significant token is unindented - nothing to dedent.
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      # untokenize also accepts (type, string) 2-tuples, in which case it
      # ignores token positions and regenerates them.
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
135
136
def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # Lambdas need a bespoke resolution process; dispatch early.
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
    raise ValueError(
        'Unable to locate the source code of {}. Note that functions defined'
        ' in certain environments, like the interactive Python shell, do not'
        ' expose their source code. If that is the case, you should define'
        ' them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        ' @tf.autograph.experimental.do_not_convert. Original error: {}'.format(
            entity, e))

  dedented = dedent_block(original_source)

  # Prepend one future import per requested feature; parse() drops these
  # preamble nodes again via preamble_len.
  preamble = [
      'from __future__ import {}'.format(name) for name in future_features
  ]
  source = '\n'.join(preamble + [dedented])

  node = parse(source, preamble_len=len(future_features))
  return node, source
172
173
def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  # Rebase line numbers so the node is consistent with the extracted snippet.
  for inner in gast.walk(node):
    lineno = getattr(inner, 'lineno', None)
    if lineno is not None:
      inner.lineno = lineno - minl
    end_lineno = getattr(inner, 'end_lineno', None)
    if end_lineno is not None:
      inner.end_lineno = end_lineno - minl

  snippet = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col = getattr(node, 'end_col_offset', None)
  if end_col is not None:
    # This is only available in 3.8.
    snippet[-1] = snippet[-1][:end_col]

  start_col = getattr(node, 'col_offset', None)
  if start_col is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    found = re.search(r'(?<!\w)lambda(?!\w)', snippet[0])
    if found is not None:
      start_col = found.start(0)

  if start_col is not None:
    snippet[0] = snippet[0][start_col:]

  code_block = '\n'.join(s.rstrip() for s in snippet)

  return node, code_block
206
207
def _arg_name(node):
  """Extracts a plain argument name from an AST node (or None)."""
  if node is None:
    return None
  # Some Python versions represent arguments as plain strings.
  if isinstance(node, str):
    return node
  assert isinstance(node, gast.Name)
  return node.id
215
216
def _node_matches_argspec(node, func):
  """Returns True if node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  def names_of(args):
    # Normalizes AST argument nodes into a tuple of plain names.
    return tuple(_arg_name(a) for a in args)

  return (names_of(node.args.args) == tuple(arg_spec.args)
          and arg_spec.varargs == _arg_name(node.args.vararg)
          and arg_spec.varkw == _arg_name(node.args.kwarg)
          and names_of(node.args.kwonlyargs) == tuple(arg_spec.kwonlyargs))
237
238
def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    # minl/maxl bracket the lambda's line span across all of its child nodes.
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    raise errors.UnsupportedLanguageElementError(
        'could not parse the source code of {}:'
        ' no matching AST found'.format(lam))

  # Attempt to narrow down selection by signature if multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      'could not parse the source code of {}: found multiple definitions with'
      ' identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      ' unique argument names.\n{}'.format(lam, matches))
325
326
327# TODO(mdan): This should take futures as input instead.
# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should be
      dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly one
      AST node.

  Returns:
    ast.AST

  Raises:
    ValueError: if `single_node` is set and the source parses to more or fewer
      than one node after dropping the preamble.
  """
  module_node = gast.parse(src)
  nodes = module_node.body
  if preamble_len:
    # Drop the leading preamble nodes (e.g. future imports).
    nodes = nodes[preamble_len:]
  if not single_node:
    return nodes
  if len(nodes) != 1:
    raise ValueError('expected exactly one node, found {}'.format(nodes))
  return nodes[0]
350
351
def parse_expression(src):
  """Returns the AST of given identifier.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  full_src = STANDARD_PREAMBLE + src.strip()
  node = parse(full_src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  # Only validate under __debug__ (i.e. skipped with python -O).
  if __debug__ and not isinstance(node, gast.Expr):
    raise ValueError(
        'expected a single expression, found instead {}'.format(node))
  return node.value
369
370
def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object, or a list/tuple of AST objects
      to render one after another.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text: the source code generated from the AST object(s), joined with
    newlines.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    if isinstance(n, gast.AST):
      # Convert to a native ast node, which the unparser expects.
      ast_n = gast.gast_to_ast(n)
    else:
      ast_n = n

    if astunparse is ast:
      # astunparse was rebound to the stdlib ast module on Python 3.9+.
      ast.fix_missing_locations(ast_n)  # Only ast needs to call this.
    codes.append(astunparse.unparse(ast_n).strip())

  return '\n'.join(codes)
403