1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converting code to AST. 16 17Adapted from Tangent. 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import inspect 25import linecache 26import re 27import sys 28import textwrap 29import tokenize 30 31import astunparse 32import gast 33import six 34 35from tensorflow.python.autograph.pyct import errors 36from tensorflow.python.autograph.pyct import inspect_utils 37from tensorflow.python.util import tf_inspect 38 39 40PY2_PREAMBLE = textwrap.dedent(""" 41from __future__ import division 42from __future__ import print_function 43""") 44PY3_PREAMBLE = '' 45MAX_SIZE = 0 46 47if sys.version_info >= (3,): 48 STANDARD_PREAMBLE = PY3_PREAMBLE 49 MAX_SIZE = sys.maxsize 50else: 51 STANDARD_PREAMBLE = PY2_PREAMBLE 52 MAX_SIZE = sys.maxint 53 54STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__') 55 56 57_LEADING_WHITESPACE = re.compile(r'\s*') 58 59 60def _unfold_continuations(code_string): 61 """Removes any backslash line continuations from the code.""" 62 return code_string.replace('\\\n', '') 63 64 65def dedent_block(code_string): 66 """Dedents a code so that its first line starts at row zero.""" 67 68 code_string = _unfold_continuations(code_string) 69 70 token_gen = tokenize.generate_tokens(six.StringIO(code_string).readline) 71 72 block_indentation = None 73 tokens = [] 74 try: 75 for tok in token_gen: 76 tokens.append(tok) 77 except tokenize.TokenError: 78 # Resolution of lambda functions may yield incomplete code, which can 79 # in turn generate this error. We silently ignore this error because the 80 # parser may still be able to deal with it. 81 pass 82 83 for tok in tokens: 84 tok_type, tok_string, _, _, _ = tok 85 if tok_type == tokenize.INDENT: 86 block_indentation = tok_string 87 block_level = len(block_indentation) 88 break 89 elif tok_type not in ( 90 tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT): 91 block_indentation = '' 92 break 93 94 if not block_indentation: 95 return code_string 96 97 block_level = len(block_indentation) 98 first_indent_uses_tabs = '\t' in block_indentation 99 for i, tok in enumerate(tokens): 100 tok_type, tok_string, _, _, _ = tok 101 if tok_type == tokenize.INDENT: 102 if ((' ' in tok_string and first_indent_uses_tabs) 103 or ('\t' in tok_string and not first_indent_uses_tabs)): 104 # TODO(mdan): We could attempt to convert tabs to spaces by unix rule. 105 # See: 106 # https://docs.python.org/3/reference/lexical_analysis.html#indentation 107 raise errors.UnsupportedLanguageElementError( 108 'code mixing tabs and spaces for indentation is not allowed') 109 if len(tok_string) >= block_level: 110 tok_string = tok_string[block_level:] 111 tokens[i] = (tok_type, tok_string) 112 113 new_code = tokenize.untokenize(tokens) 114 115 # Note: untokenize respects the line structure, but not the whitespace within 116 # lines. For example, `def foo()` may be untokenized as `def foo ()` 117 # So instead of using the output of dedent, we match the leading whitespace 118 # on each line. 119 dedented_code = [] 120 for line, new_line in zip(code_string.split('\n'), new_code.split('\n')): 121 original_indent = re.match(_LEADING_WHITESPACE, line).group() 122 new_indent = re.match(_LEADING_WHITESPACE, new_line).group() 123 if len(original_indent) > len(new_indent): 124 dedented_line = line[len(original_indent) - len(new_indent):] 125 else: 126 dedented_line = line 127 dedented_code.append(dedented_line) 128 new_code = '\n'.join(dedented_code) 129 130 return new_code 131 132 133def parse_entity(entity, future_features): 134 """Returns the AST and source code of given entity. 135 136 Args: 137 entity: Any, Python function/method/class 138 future_features: Iterable[Text], future features to use (e.g. 139 'print_statement'). See 140 https://docs.python.org/2/reference/simple_stmts.html#future 141 142 Returns: 143 gast.AST, Text: the parsed AST node; the source code that was parsed to 144 generate the AST (including any prefixes that this function may have added). 145 """ 146 if inspect_utils.islambda(entity): 147 return _parse_lambda(entity) 148 149 try: 150 original_source = inspect_utils.getimmediatesource(entity) 151 except (IOError, OSError) as e: 152 raise ValueError( 153 'Unable to locate the source code of {}. Note that functions defined' 154 ' in certain environments, like the interactive Python shell do not' 155 ' expose their source code. If that is the case, you should to define' 156 ' them in a .py source file. If you are certain the code is' 157 ' graph-compatible, wrap the call using' 158 ' @tf.autograph.do_not_convert. Original error: {}'.format(entity, e)) 159 160 source = dedent_block(original_source) 161 162 future_statements = tuple( 163 'from __future__ import {}'.format(name) for name in future_features) 164 source = '\n'.join(future_statements + (source,)) 165 166 return parse(source, preamble_len=len(future_features)), source 167 168 169def _without_context(node, lines, minl, maxl): 170 """Returns a clean node and source code without indenting and context.""" 171 for n in gast.walk(node): 172 lineno = getattr(n, 'lineno', None) 173 if lineno is not None: 174 n.lineno = lineno - minl 175 end_lineno = getattr(n, 'end_lineno', None) 176 if end_lineno is not None: 177 n.end_lineno = end_lineno - minl 178 179 code_lines = lines[minl - 1:maxl] 180 181 # Attempt to clean up surrounding context code. 182 183 end_col_offset = getattr(node, 'end_col_offset', None) 184 if end_col_offset is not None: 185 # This is only available in 3.8. 186 code_lines[-1] = code_lines[-1][:end_col_offset] 187 188 col_offset = getattr(node, 'col_offset', None) 189 if col_offset is None: 190 # Older Python: try to find the "lambda" token. This is brittle. 191 match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0]) 192 if match is not None: 193 col_offset = match.start(0) 194 195 if col_offset is not None: 196 code_lines[0] = code_lines[0][col_offset:] 197 198 code_block = '\n'.join([c.rstrip() for c in code_lines]) 199 200 return node, code_block 201 202 203def _arg_name(node): 204 if node is None: 205 return None 206 if isinstance(node, gast.Name): 207 return node.id 208 assert isinstance(node, str) 209 return node 210 211 212def _node_matches_argspec(node, func): 213 """Returns True is node fits the argspec of func.""" 214 # TODO(mdan): Use just inspect once support for Python 2 is dropped. 215 arg_spec = tf_inspect.getfullargspec(func) 216 217 node_args = tuple(_arg_name(arg) for arg in node.args.args) 218 if node_args != tuple(arg_spec.args): 219 return False 220 221 if arg_spec.varargs != _arg_name(node.args.vararg): 222 return False 223 224 if arg_spec.varkw != _arg_name(node.args.kwarg): 225 return False 226 227 node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs) 228 if node_kwonlyargs != tuple(arg_spec.kwonlyargs): 229 return False 230 231 return True 232 233 234def _parse_lambda(lam): 235 """Returns the AST and source code of given lambda function. 236 237 Args: 238 lam: types.LambdaType, Python function/method/class 239 240 Returns: 241 gast.AST, Text: the parsed AST node; the source code that was parsed to 242 generate the AST (including any prefixes that this function may have added). 243 """ 244 # TODO(mdan): Use a fast path if the definition is not multi-line. 245 # We could detect that the lambda is in a multi-line expression by looking 246 # at the surrounding code - an surrounding set of parentheses indicates a 247 # potential multi-line definition. 248 249 mod = inspect.getmodule(lam) 250 f = inspect.getsourcefile(lam) 251 def_line = lam.__code__.co_firstlineno 252 253 # This method is more robust that just calling inspect.getsource(mod), as it 254 # works in interactive shells, where getsource would fail. This is the 255 # same procedure followed by inspect for non-modules: 256 # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772 257 lines = linecache.getlines(f, mod.__dict__) 258 source = ''.join(lines) 259 260 # Narrow down to the last node starting before our definition node. 261 all_nodes = parse(source, preamble_len=0, single_node=False) 262 search_nodes = [] 263 for node in all_nodes: 264 # Also include nodes without a line number, for safety. This is defensive - 265 # we don't know whether such nodes might exist, and if they do, whether 266 # they are not safe to skip. 267 # TODO(mdan): Replace this check with an assertion or skip such nodes. 268 if getattr(node, 'lineno', def_line) <= def_line: 269 search_nodes.append(node) 270 else: 271 # Found a node starting past our lambda - can stop the search. 272 break 273 274 # Extract all lambda nodes from the shortlist. 275 lambda_nodes = [] 276 for node in search_nodes: 277 lambda_nodes.extend( 278 n for n in gast.walk(node) if isinstance(n, gast.Lambda)) 279 280 # Filter down to lambda nodes which span our actual lambda. 281 candidates = [] 282 for ln in lambda_nodes: 283 minl, maxl = MAX_SIZE, 0 284 for n in gast.walk(ln): 285 minl = min(minl, getattr(n, 'lineno', minl)) 286 lineno = getattr(n, 'lineno', maxl) 287 end_lineno = getattr(n, 'end_lineno', None) 288 if end_lineno is not None: 289 # end_lineno is more precise, but lineno should almost always work too. 290 lineno = end_lineno 291 maxl = max(maxl, lineno) 292 if minl <= def_line <= maxl: 293 candidates.append((ln, minl, maxl)) 294 295 # Happy path: exactly one node found. 296 if len(candidates) == 1: 297 (node, minl, maxl), = candidates # pylint:disable=unbalanced-tuple-unpacking 298 return _without_context(node, lines, minl, maxl) 299 300 elif not candidates: 301 raise errors.UnsupportedLanguageElementError( 302 'could not parse the source code of {}:' 303 ' no matching AST found'.format(lam)) 304 305 # Attempt to narrow down selection by signature is multiple nodes are found. 306 matches = [v for v in candidates if _node_matches_argspec(v[0], lam)] 307 if len(matches) == 1: 308 (node, minl, maxl), = matches 309 return _without_context(node, lines, minl, maxl) 310 311 # Give up if could not narrow down to a single node. 312 matches = '\n'.join( 313 'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False)) 314 for i, (node, _, _) in enumerate(matches)) 315 raise errors.UnsupportedLanguageElementError( 316 'could not parse the source code of {}: found multiple definitions with' 317 ' identical signatures at the location. This error' 318 ' may be avoided by defining each lambda on a single line and with' 319 ' unique argument names.\n{}'.format(lam, matches)) 320 321 322# TODO(mdan): This should take futures as input instead. 323def parse(src, preamble_len=0, single_node=True): 324 """Returns the AST of given piece of code. 325 326 Args: 327 src: Text 328 preamble_len: Int, indicates leading nodes in the parsed AST which should be 329 dropped. 330 single_node: Bool, whether `src` is assumed to be represented by exactly one 331 AST node. 332 333 Returns: 334 ast.AST 335 """ 336 module_node = gast.parse(src) 337 nodes = module_node.body 338 if preamble_len: 339 nodes = nodes[preamble_len:] 340 if single_node: 341 if len(nodes) != 1: 342 raise ValueError('expected exactly one node, found {}'.format(nodes)) 343 return nodes[0] 344 return nodes 345 346 347def parse_expression(src): 348 """Returns the AST of given identifier. 349 350 Args: 351 src: A piece of code that represents a single Python expression 352 Returns: 353 A gast.AST object. 354 Raises: 355 ValueError: if src does not consist of a single Expression. 356 """ 357 src = STANDARD_PREAMBLE + src.strip() 358 node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True) 359 if __debug__: 360 if not isinstance(node, gast.Expr): 361 raise ValueError( 362 'expected a single expression, found instead {}'.format(node)) 363 return node.value 364 365 366def unparse(node, indentation=None, include_encoding_marker=True): 367 """Returns the source code of given AST. 368 369 Args: 370 node: The code to compile, as an AST object. 371 indentation: Unused, deprecated. The returning code will always be indented 372 at 4 spaces. 373 include_encoding_marker: Bool, whether to include a comment on the first 374 line to explicitly specify UTF-8 encoding. 375 376 Returns: 377 code: The source code generated from the AST object 378 source_mapping: A mapping between the user and AutoGraph generated code. 379 """ 380 del indentation # astunparse doesn't allow configuring it. 381 if not isinstance(node, (list, tuple)): 382 node = (node,) 383 384 codes = [] 385 if include_encoding_marker: 386 codes.append('# coding=utf-8') 387 for n in node: 388 if isinstance(n, gast.AST): 389 n = gast.gast_to_ast(n) 390 codes.append(astunparse.unparse(n).strip()) 391 392 return '\n'.join(codes) 393