# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ast
import inspect
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast
import six

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
from __future__ import division
from __future__ import print_function
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

# From Python 3.9 on, the standard library `ast` module provides `unparse`, so
# it is used in place of the third-party `astunparse` package.
if sys.version_info >= (3, 9):
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

# Number of `from __future__` statements in the preamble; each one becomes a
# leading AST node that `parse` is asked to drop.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


_LEADING_WHITESPACE = re.compile(r'\s*')


def _unfold_continuations(code_string):
  """Removes any backslash line continuations from the code.

  Args:
    code_string: Text, the source code to process.

  Returns:
    Text, the code with every backslash-newline pair removed, so that each
    continued statement sits on a single line.
  """
  return code_string.replace('\\\n', '')


def dedent_block(code_string):
  """Dedents a code so that its first line starts at row zero.

  Args:
    code_string: Text, the source code to dedent.

  Returns:
    Text, the code with the indentation of its first block removed from every
    line. If the code is not indented, it is returned with only its backslash
    line continuations unfolded.

  Raises:
    errors.UnsupportedLanguageElementError: if the indentation mixes tabs and
      spaces.
  """

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(six.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # Find the indentation of the first block: either the first INDENT token, or
  # no indentation at all if any other significant token comes first.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code


def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    ValueError: if the source code of the entity cannot be located.
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
    raise ValueError(
        'Unable to locate the source code of {}. Note that functions defined'
        ' in certain environments, like the interactive Python shell, do not'
        ' expose their source code. If that is the case, you should define'
        ' them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        ' @tf.autograph.experimental.do_not_convert. Original error: {}'.format(
            entity, e))

  source = dedent_block(original_source)

  # Each future feature becomes one leading statement, which is then dropped
  # from the parsed AST through preamble_len.
  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context.

  Note that the line numbers of `node` and of its descendants are modified in
  place.

  Args:
    node: gast.AST, the node to clean up.
    lines: List[Text], the lines of the source code containing the node.
    minl: Int, 1-based number of the first line spanned by the node.
    maxl: Int, 1-based number of the last line spanned by the node.

  Returns:
    gast.AST, Text: the same node, with line numbers shifted down by `minl`;
    the source lines between `minl` and `maxl`, trimmed to the node where
    column information is available.
  """
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  # Slicing copies the list, so trimming below does not modify `lines`.
  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  # The end is trimmed before the start, so that the end offset still refers to
  # the unmodified line when the node fits on a single line.
  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in 3.8.
    code_lines[-1] = code_lines[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    code_lines[0] = code_lines[0][col_offset:]

  code_block = '\n'.join([c.rstrip() for c in code_lines])

  return node, code_block


def _arg_name(node):
  """Returns the name of an argument node, or None if the node is None."""
  if node is None:
    return None
  if isinstance(node, gast.Name):
    return node.id
  assert isinstance(node, str)
  return node


def _node_matches_argspec(node, func):
  """Returns True if node fits the argspec of func.

  Args:
    node: gast.AST, a node with an `args` attribute describing its arguments.
    func: the function whose signature is compared against the node.

  Returns:
    Bool, whether the positional, variadic, keyword and keyword-only argument
    names of the node equal those of func.
  """
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True


def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    errors.UnsupportedLanguageElementError: if no lambda node spans the line
      where the function is defined, or if several do and they cannot be told
      apart by their signature.
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    raise errors.UnsupportedLanguageElementError(
        'could not parse the source code of {}:'
        ' no matching AST found'.format(lam))

  # Attempt to narrow down selection by signature if multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node. If no candidate matched
  # the signature, list all the candidates, so that the error message always
  # shows the definitions that could not be told apart.
  ambiguous = matches or candidates
  listing = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(ambiguous))
  raise errors.UnsupportedLanguageElementError(
      'could not parse the source code of {}: found multiple definitions with'
      ' identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      ' unique argument names.\n{}'.format(lam, listing))


# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should be
      dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly one
      AST node.

  Returns:
    gast.AST if `single_node` is True, otherwise the list of gast.AST nodes
    in the body of the parsed module (after dropping the preamble).

  Raises:
    ValueError: if `single_node` is True and `src`, without its preamble, does
      not consist of exactly one node.
  """
  module_node = gast.parse(src)
  nodes = module_node.body
  if preamble_len:
    nodes = nodes[preamble_len:]
  if single_node:
    if len(nodes) != 1:
      raise ValueError('expected exactly one node, found {}'.format(nodes))
    return nodes[0]
  return nodes


def parse_expression(src):
  """Returns the AST of given identifier.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  src = STANDARD_PREAMBLE + src.strip()
  node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  # Note: this validation is skipped when Python runs with optimizations (-O),
  # because __debug__ is then False.
  if __debug__:
    if not isinstance(node, gast.Expr):
      raise ValueError(
          'expected a single expression, found instead {}'.format(node))
  return node.value


def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object, or as a list or tuple of AST
      objects.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text, the source code generated from the AST object(s). The code of each
    node is stripped of surrounding whitespace and the pieces are joined with
    newlines.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    if isinstance(n, gast.AST):
      ast_n = gast.gast_to_ast(n)
    else:
      ast_n = n

    if astunparse is ast:
      ast.fix_missing_locations(ast_n)  # Only ast needs to call this.
    codes.append(astunparse.unparse(ast_n).strip())

  return '\n'.join(codes)