1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A node transformer that includes utilities for SCT.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import enum 23 24import gast 25 26from tensorflow.python.autograph.pyct import anno 27from tensorflow.python.autograph.pyct import parser 28from tensorflow.python.autograph.pyct import pretty_printer 29from tensorflow.python.autograph.pyct import templates 30 31 32class AnalysisLevel(enum.IntEnum): 33 34 NONE = 0 35 ACTIVITY = 1 36 DEFINEDNESS = 2 37 LIVENESS = 3 38 39 40# TODO(znado): Use namedtuple. 41class Context(object): 42 """Contains information about a source code transformation. 43 44 This object is mutable, and is updated during conversion. Not thread safe. 45 46 Attributes: 47 info: EntityInfo, immutable. 48 namer: naming.Namer. 49 current_origin: origin_info.OriginInfo, holds the OriginInfo of the last 50 AST node to be processed successfully. Useful for error handling. 51 user: An user-supplied context object. The object is opaque to the 52 infrastructure, but will pe passed through to all custom transformations. 53 """ 54 55 def __init__(self, info, namer, user_context): 56 self.info = info 57 self.namer = namer 58 self.current_origin = None 59 self.user = user_context 60 61 62# TODO(mdan): Move to a standalone file. 63class EntityInfo( 64 collections.namedtuple( 65 'EntityInfo', 66 ('name', 'source_code', 'source_file', 'future_features', 'namespace')) 67): 68 """Contains information about a Python entity. 69 70 Immutable. 71 72 Examples of entities include functions and classes. 73 74 Attributes: 75 name: The name that identifies this entity. 76 source_code: The entity's source code. 77 source_file: The entity's source file. 78 future_features: Tuple[Text], the future features that this entity was 79 compiled with. See 80 https://docs.python.org/2/reference/simple_stmts.html#future. 81 namespace: Dict[str, ], containing symbols visible to the entity (excluding 82 parameters). 83 """ 84 pass 85 86 87class _StateStack(object): 88 """Templated context manager. 89 90 This class provides syntactic sugar for a stack of objects of known 91 type. It allows accessing attributes of the object at the top of the stack 92 directly against this object, which allows for very terse syntax. 93 94 For example, this code: 95 96 stack = _StateStack(Foo) 97 stack.enter() 98 stack.bar 99 100 Is equivalent to: 101 102 stack = [] 103 stack.append(Foo()) 104 foo = stack[-1] 105 foo.bar 106 107 See _State for more on how this is used. 108 109 Attributes: 110 type: Any, the type of objects that this stack holds 111 level: int, the current stack depth 112 stack: List[Any], the actual stack 113 value: Any, the instance of the object at the top of the stack 114 """ 115 116 def __init__(self, type_): 117 # Because we override __setattr__, we need to attach these attributes using 118 # the superclass' setattr. 119 object.__setattr__(self, 'type', type_) 120 object.__setattr__(self, '_stack', []) 121 if not hasattr(type_, 'no_root'): 122 self.enter() 123 124 def __enter__(self): 125 self.enter() 126 return self 127 128 def __exit__(self, exc_type, exc_value, traceback): 129 self.exit() 130 131 def enter(self): 132 self._stack.append(self.type()) 133 134 def exit(self): 135 self._stack.pop() 136 137 @property 138 def stack(self): 139 return self._stack 140 141 @property 142 def level(self): 143 return len(self._stack) 144 145 @property 146 def value(self): 147 return self._stack[-1] 148 149 def __iter__(self): 150 return iter(self._stack) 151 152 def __getattr__(self, key): 153 return getattr(self._stack[-1], key) 154 155 def __setattr__(self, key, value): 156 setattr(self._stack[-1], key, value) 157 158 159class _State(object): 160 """Syntactic sugar for accessing an instance of a StateStack context manager. 161 162 This structure offers syntactic sugar over a dict of stacks of objects 163 of known type. These structures are useful to keep state during AST walks. 164 Multiple different scopes can be tracked in parallel. For example: 165 166 s = _State() 167 168 s[foo].enter() 169 s[bar].enter() # this will not affect s[foo] 170 171 Element access has special semantics: 172 * keys are a data type 173 * element values are _StateStack(type=key) objects 174 * missing elements are automatically added, similarly to defaultdict 175 176 For example, the following block : 177 178 _State s 179 s[Foo] 180 181 Is equivalent to: 182 183 s = {} 184 if Foo not in s: 185 s[Foo] = Foo() 186 s[Foo] 187 188 See Base for how it's used. 189 """ 190 191 def __init__(self): 192 self._value = {} 193 194 def __getitem__(self, key): 195 if key not in self._value: 196 self._value[key] = _StateStack(key) 197 return self._value[key] 198 199 200class NodeStateTracker(object): 201 """Base class for general-purpose Python code transformation. 202 203 This abstract class provides helpful functions, like state tracking within 204 the scope of arbitrary node, helpers for processing code blocks, debugging, 205 mapping of transformed code to original code, and others. 206 207 Scope-local state tracking: to keep state across nodes, at the level of 208 (possibly nested) scopes, use enter/exit_local_scope and set/get_local. 209 You must call enter/exit_local_scope manually, but the transformer detects 210 when they are not properly paired. 211 212 The transformer allows keeping state across calls that is local 213 to arbitrary nodes and their descendants, using the self.state attribute. 214 Multiple independent scopes are allowed and automatically constructed. 215 216 For example, to keep track of the `If` node that encloses any `Name` node, 217 one can write: 218 219 ``` 220 class FooType(object): 221 222 def __init__(self): 223 self.foo_property = None 224 225 class DummyTransformer(NodeStateTracker, ast.NodeTransformer): 226 227 def visit_If(self, node): 228 self.state[FooType].enter() 229 self.state[FooType].foo_property = node 230 node = self.veneric_visit(node) 231 self.state[FooType].exit() 232 return node 233 234 def visit_Name(self, node): 235 self.state[FooType].foo_property # will hold the innermost enclosing if 236 ``` 237 238 Alternatively, the `enter()`/`exit()` calls can be managed by a `with` 239 statement: 240 241 ``` 242 def visit_If(self, node): 243 with self.state[FooType] as foo: 244 foo.foo_property = node 245 return self.generic_visit(node) 246 ``` 247 """ 248 249 # TODO(mdan): Document all extra features. 250 251 def __init__(self, ctx): 252 """Initialize the transformer. 253 254 Subclasses should call this. 255 256 Args: 257 ctx: A Context object. 258 """ 259 self._lineno = 0 260 self._col_offset = 0 261 self.ctx = ctx 262 263 # Allows scoping of local variables to keep state across calls to visit_* 264 # methods. Multiple scope hierarchies may exist and are keyed by tag. A 265 # scope is valid at one or more nodes and all its children. Scopes created 266 # in child nodes supersede their parent. Scopes are isolated from one 267 # another. 268 self.state = _State() 269 270 def debug_print(self, node): 271 """Helper method useful for debugging. Prints the AST.""" 272 if __debug__: 273 print(pretty_printer.fmt(node)) 274 return node 275 276 def debug_print_src(self, node): 277 """Helper method useful for debugging. Prints the AST as code.""" 278 if __debug__: 279 print(parser.unparse(node)) 280 return node 281 282 def visit_block(self, nodes, before_visit=None, after_visit=None): 283 """A more powerful version of generic_visit for statement blocks. 284 285 An example of a block is the body of an if statement. 286 287 This function allows specifying a postprocessing callback (the 288 after_visit argument) argument which can be used to move nodes to a new 289 destination. This is done by after_visit by returning a non-null 290 second return value, e.g. return new_node, new_destination. 291 292 For example, a transformer could perform the following move: 293 294 foo() 295 bar() 296 baz() 297 298 foo() 299 if cond: 300 bar() 301 baz() 302 303 The above could be done with a postprocessor of this kind: 304 305 def after_visit(node): 306 if node_is_function_call(bar): 307 new_container_node = build_cond() 308 new_container_node.body.append(node) 309 return new_container_node, new_container_node.body 310 else: 311 # Once we set a new destination, all subsequent items will be 312 # moved to it, so we don't need to explicitly handle baz. 313 return node, None 314 315 Args: 316 nodes: enumerable of AST node objects. If None, the function returns None. 317 before_visit: optional callable that is called before visiting each item 318 in nodes 319 after_visit: optional callable that takes in an AST node and returns a 320 tuple (new_node, new_destination). It is called after visiting each item 321 in nodes. Is used in the same was as the 322 visit_* methods: new_node will replace the node; if not None, 323 new_destination must be a list, and subsequent nodes will be placed 324 in this list instead of the list returned by visit_block. 325 326 Returns: 327 A list of AST node objects containing the transformed items fron nodes, 328 except those nodes that have been relocated using after_visit. 329 """ 330 if nodes is None: 331 return None 332 333 results = [] 334 node_destination = results 335 for node in nodes: 336 if before_visit: 337 # TODO(mdan): We can modify node here too, if ever needed. 338 before_visit() 339 340 replacement = self.visit(node) 341 342 if after_visit and replacement: 343 replacement, new_destination = after_visit(replacement) 344 else: 345 new_destination = None 346 347 if replacement: 348 if isinstance(replacement, (list, tuple)): 349 node_destination.extend(replacement) 350 else: 351 node_destination.append(replacement) 352 353 # Allow the postprocessor to reroute the remaining nodes to a new list. 354 if new_destination is not None: 355 node_destination = new_destination 356 return results 357 358 359# TODO(mdan): Rename to PythonCodeTransformer. 360class Base(NodeStateTracker, gast.NodeTransformer): 361 """Base class for general-purpose Python-to-Python code transformation. 362 363 This is an extension of ast.NodeTransformer that provides the additional 364 functions offered by NodeStateTracker. 365 """ 366 367 def create_assignment(self, target, expression): 368 template = """ 369 target = expression 370 """ 371 return templates.replace(template, target=target, expression=expression) 372 373 # TODO(mdan): Remove. 374 def apply_to_single_assignments(self, targets, values, apply_fn): 375 """Applies a function to each individual assignment. 376 377 This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. 378 It tries to break down the unpacking if possible. In effect, it has the same 379 effect as passing the assigned values in SSA form to apply_fn. 380 381 Examples: 382 383 The following will result in apply_fn(a, c), apply_fn(b, d): 384 385 a, b = c, d 386 387 The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): 388 389 a, b = c 390 391 The following will result in apply_fn(a, (b, c)): 392 393 a = b, c 394 395 It uses the visitor pattern to allow subclasses to process single 396 assignments individually. 397 398 Args: 399 targets: list, tuple of or individual AST node. Should be used with the 400 targets field of an ast.Assign node. 401 values: an AST node. 402 apply_fn: a function of a single argument, which will be called with the 403 respective nodes of each single assignment. The signature is 404 apply_fn(target, value), no return value. 405 """ 406 if not isinstance(targets, (list, tuple)): 407 targets = (targets,) 408 for target in targets: 409 if isinstance(target, (gast.Tuple, gast.List)): 410 for i in range(len(target.elts)): 411 target_el = target.elts[i] 412 if isinstance(values, (gast.Tuple, gast.List)): 413 value_el = values.elts[i] 414 else: 415 value_el = gast.Subscript(values, i, ctx=gast.Store()) 416 self.apply_to_single_assignments(target_el, value_el, apply_fn) 417 else: 418 # TODO(mdan): Look into allowing to rewrite the AST here. 419 apply_fn(target, values) 420 421 def visit(self, node): 422 if not isinstance(node, gast.AST): 423 # This is not that uncommon a mistake: various node bodies are lists, for 424 # example, posing a land mine for transformers that need to recursively 425 # call `visit`. The error needs to be raised before the exception handler 426 # below is installed, because said handler will mess up if `node` is not, 427 # in fact, a node. 428 msg = ('invalid value for "node": expected "ast.AST", got "{}"; to' 429 ' visit lists of nodes, use "visit_block" instead').format( 430 type(node)) 431 raise ValueError(msg) 432 433 if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): 434 return node 435 436 parent_origin = self.ctx.current_origin 437 if anno.hasanno(node, anno.Basic.ORIGIN): 438 self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN) 439 440 try: 441 processing_expr_node = isinstance(node, gast.Expr) 442 if processing_expr_node: 443 entry_expr_value = node.value 444 445 result = super(Base, self).visit(node) 446 447 # Adjust for consistency: replacing the value of an Expr with 448 # an Assign node removes the need for the Expr node. 449 if (processing_expr_node and isinstance(result, gast.Expr) and 450 (result.value is not entry_expr_value)): 451 # When the replacement is a list, it is assumed that the list came 452 # from a template that contained a number of statements, which 453 # themselves are standalone and don't require an enclosing Expr. 454 if isinstance(result.value, 455 (list, tuple, gast.Assign, gast.AugAssign)): 456 result = result.value 457 458 # By default, all replacements receive the origin info of the replaced 459 # node. 460 if result is not node and result is not None: 461 inherited_origin = anno.getanno( 462 node, anno.Basic.ORIGIN, default=parent_origin) 463 if inherited_origin is not None: 464 nodes_to_adjust = result 465 if isinstance(result, (list, tuple)): 466 nodes_to_adjust = result 467 else: 468 nodes_to_adjust = (result,) 469 for n in nodes_to_adjust: 470 if not anno.hasanno(n, anno.Basic.ORIGIN): 471 anno.setanno(n, anno.Basic.ORIGIN, inherited_origin) 472 finally: 473 self.ctx.current_origin = parent_origin 474 475 return result 476 477 478class CodeGenerator(NodeStateTracker, gast.NodeVisitor): 479 """Base class for general-purpose Python-to-string code transformation. 480 481 Similar to Base, but outputs arbitrary strings instead of a Python AST. 482 483 This uses the same visitor mechanism that the standard NodeVisitor uses, 484 meaning that subclasses write handlers for the different kinds of nodes. 485 New code is generated using the emit method, which appends to a code buffer 486 that can be afterwards obtained from code_buffer. 487 488 Example: 489 490 class SimpleCodeGen(CodeGenerator): 491 492 def visitIf(self, node): 493 self.emit('if ') 494 self.visit(node.test) 495 self.emit(' { ') 496 self.visit(node.body) 497 self.emit(' } else { ') 498 self.visit(node.orelse) 499 self.emit(' } ') 500 501 node = ast.parse(...) 502 gen = SimpleCodeGen() 503 gen.visit(node) 504 # gen.code_buffer contains the resulting code 505 """ 506 507 def __init__(self, ctx): 508 super(CodeGenerator, self).__init__(ctx) 509 510 self._output_code = '' 511 self.source_map = {} 512 513 def emit(self, code): 514 self._output_code += code 515 516 @property 517 def code_buffer(self): 518 return self._output_code 519 520 def visit(self, node): 521 if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): 522 return 523 524 parent_origin = self.ctx.current_origin 525 eof_before = len(self._output_code) 526 if anno.hasanno(node, anno.Basic.ORIGIN): 527 self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN) 528 529 try: 530 ret = super(CodeGenerator, self).visit(node) 531 532 # By default, all replacements receive the origin info of the replaced 533 # node. 534 eof_after = len(self._output_code) 535 if eof_before - eof_after: 536 inherited_origin = anno.getanno( 537 node, anno.Basic.ORIGIN, default=parent_origin) 538 if inherited_origin is not None: 539 self.source_map[(eof_before, eof_after)] = inherited_origin 540 return ret 541 finally: 542 self.ctx.current_origin = parent_origin 543