• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converting code to AST.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import inspect
25import linecache
26import re
27import sys
28import textwrap
29import tokenize
30
31import astunparse
32import gast
33import six
34
35from tensorflow.python.autograph.pyct import errors
36from tensorflow.python.autograph.pyct import inspect_utils
37from tensorflow.python.util import tf_inspect
38
39
# Preamble prepended to parsed source so that Python 2 code is parsed with
# Python 3 semantics for division and print. Empty on Python 3, where those
# semantics are the default.
PY2_PREAMBLE = textwrap.dedent("""
from __future__ import division
from __future__ import print_function
""")
PY3_PREAMBLE = ''
# Largest representable int; used as an "infinity" initializer when computing
# minimum line numbers (see _parse_lambda). Overwritten below.
MAX_SIZE = 0

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint  # Python 2 only; not defined on Python 3.

# Number of AST nodes the preamble contributes: one per __future__ import line.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


# Matches the leading whitespace (indentation) of a line.
_LEADING_WHITESPACE = re.compile(r'\s*')
58
59
60def _unfold_continuations(code_string):
61  """Removes any backslash line continuations from the code."""
62  return code_string.replace('\\\n', '')
63
64
def dedent_block(code_string):
  """Dedents a code block so that its first line starts at row zero.

  Args:
    code_string: Text, source code of a block, possibly indented (e.g. a
      method extracted from inside a class body).

  Returns:
    Text, the same code with the block-level indentation stripped from every
    line. Returned unchanged when the block is not indented.

  Raises:
    errors.UnsupportedLanguageElementError: if the block mixes tabs and spaces
      in its indentation.
  """
  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(six.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # Find the indentation of the block's first significant token. Newlines,
  # strings and comments may legally precede the first INDENT token.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      # The first significant token is unindented - nothing to dedent.
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      # untokenize only requires (type, string) pairs.
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
131
132
def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    ValueError: if the entity's source code cannot be located.
  """
  # Lambdas need a dedicated resolution procedure; see _parse_lambda.
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
    raise ValueError(
        'Unable to locate the source code of {}. Note that functions defined'
        ' in certain environments, like the interactive Python shell do not'
        ' expose their source code. If that is the case, you should define'
        ' them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        ' @tf.autograph.do_not_convert. Original error: {}'.format(entity, e))

  # Strip class/function-level indentation so the code parses standalone.
  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  # The prepended future imports are dropped from the returned AST.
  return parse(source, preamble_len=len(future_features)), source
167
168
def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context.

  Rebases the node's line numbers so the extracted snippet is self-contained,
  then trims context code that shares the snippet's first and last lines.

  Args:
    node: gast.AST, the node of interest.
    lines: List[Text], all source lines of the enclosing file.
    minl: Int, first source line (1-based) spanned by the node.
    maxl: Int, last source line (1-based) spanned by the node.

  Returns:
    Tuple of the adjusted node and the matching source code snippet.
  """
  # Rebase line numbers relative to the snippet's first line.
  for inner in gast.walk(node):
    lineno = getattr(inner, 'lineno', None)
    if lineno is not None:
      inner.lineno = lineno - minl
    end_lineno = getattr(inner, 'end_lineno', None)
    if end_lineno is not None:
      inner.end_lineno = end_lineno - minl

  snippet = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in 3.8.
    snippet[-1] = snippet[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', snippet[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    snippet[0] = snippet[0][col_offset:]

  code_block = '\n'.join(part.rstrip() for part in snippet)

  return node, code_block
201
202
203def _arg_name(node):
204  if node is None:
205    return None
206  if isinstance(node, gast.Name):
207    return node.id
208  assert isinstance(node, str)
209  return node
210
211
def _node_matches_argspec(node, func):
  """Returns True if node's argument list fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  spec = tf_inspect.getfullargspec(func)

  if tuple(spec.args) != tuple(_arg_name(a) for a in node.args.args):
    return False

  if spec.varargs != _arg_name(node.args.vararg):
    return False

  if spec.varkw != _arg_name(node.args.kwarg):
    return False

  return tuple(spec.kwonlyargs) == tuple(
      _arg_name(a) for a in node.args.kwonlyargs)
232
233
def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Because a source line may contain several lambdas with the same start line,
  this searches the whole defining module's AST for lambda nodes spanning the
  definition line, then disambiguates by signature if needed.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    errors.UnsupportedLanguageElementError: if no matching node is found, or
      if several indistinguishable candidates remain.
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    # Compute the [minl, maxl] line range covered by the candidate and all its
    # descendants; MAX_SIZE serves as the "infinity" initializer for minl.
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    raise errors.UnsupportedLanguageElementError(
        'could not parse the source code of {}:'
        ' no matching AST found'.format(lam))

  # Attempt to narrow down selection by signature if multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      'could not parse the source code of {}: found multiple definitions with'
      ' identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      ' unique argument names.\n{}'.format(lam, matches))
320
321
# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, number of leading nodes in the parsed AST to drop.
    single_node: Bool, whether `src` is assumed to be represented by exactly
      one AST node.

  Returns:
    ast.AST

  Raises:
    ValueError: if single_node is True and the parse yields a different node
      count.
  """
  parsed = gast.parse(src).body
  if preamble_len:
    parsed = parsed[preamble_len:]
  if not single_node:
    return parsed
  if len(parsed) != 1:
    raise ValueError('expected exactly one node, found {}'.format(parsed))
  return parsed[0]
345
346
def parse_expression(src):
  """Returns the AST of given identifier.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  # The standard preamble pins down Python 3 semantics; its nodes are dropped
  # from the result via preamble_len.
  node = parse(
      STANDARD_PREAMBLE + src.strip(),
      preamble_len=STANDARD_PREAMBLE_LEN,
      single_node=True)
  if __debug__ and not isinstance(node, gast.Expr):
    raise ValueError(
        'expected a single expression, found instead {}'.format(node))
  return node.value
364
365
def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object, or a list/tuple of AST
      objects.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text, the source code generated from the AST object(s), one node per line.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    if isinstance(n, gast.AST):
      # astunparse operates on native ast nodes; convert gast nodes first.
      n = gast.gast_to_ast(n)
    codes.append(astunparse.unparse(n).strip())

  return '\n'.join(codes)
393