# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A node transformer that includes utilities for SCT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import templates


# TODO(znado): Use namedtuple.
class Context(object):
  """Contains information about a source code transformation.

  This object is mutable, and is updated during conversion. Not thread safe.

  Attributes:
    info: EntityInfo, immutable.
    current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
      AST node to be processed successfully. Useful for error handling.
  """

  def __init__(self, info):
    self.info = info
    self.current_origin = None


# TODO(mdan): Use namedtuple.
class EntityInfo(object):
  """Contains information about a Python entity.

  Immutable by convention (immutability is not enforced).

  Examples of entities include functions and classes.

  Attributes:
    source_code: The entity's source code.
    source_file: The entity's source file.
    namespace: Dict[str, Any], containing symbols visible to the entity
      (excluding parameters).
    arg_values: dict[str->*], containing parameter values, if known. An empty
      dict when None is passed.
    arg_types: dict[str->*], containing parameter types, if known. An empty
      dict when None is passed.
  """

  # TODO(mdan): Remove the default and update tests.
  def __init__(self, source_code, source_file, namespace, arg_values,
               arg_types):
    self.source_code = source_code
    self.source_file = source_file
    self.namespace = namespace
    self.arg_values = {} if arg_values is None else arg_values
    self.arg_types = {} if arg_types is None else arg_types


class _StateStack(object):
  """Typed stack abstraction.

  This class provides syntactic sugar for a stack of objects of known
  type. It allows accessing attributes of the object at the top of the stack
  directly against this object, which allows for very terse syntax.

  For example, this code:

    stack = _StateStack(Foo)
    stack.enter()
    stack.bar

  Is equivalent to:

    stack = []
    stack.append(Foo())
    foo = stack[-1]
    foo.bar

  Unless `type_` has a `no_root` attribute, the constructor also pushes one
  instance immediately, so the stack starts at level 1.

  See _State for more on how this is used.

  Attributes:
    type: Any, the type of objects that this stack holds
    level: int, the current stack depth
    value: Any, the instance of the object at the top of the stack
  """

  def __init__(self, type_):
    # Because we override __setattr__, we need to attach these attributes using
    # the superclass' setattr.
    object.__setattr__(self, 'type', type_)
    object.__setattr__(self, '_stack', [])
    if not hasattr(type_, 'no_root'):
      self.enter()

  def enter(self):
    """Pushes a freshly constructed instance of `type` onto the stack."""
    self._stack.append(self.type())

  def exit(self):
    """Pops and returns the instance at the top of the stack."""
    return self._stack.pop()

  @property
  def level(self):
    return len(self._stack)

  @property
  def value(self):
    return self._stack[-1]

  def __iter__(self):
    return iter(self._stack)

  def __getattr__(self, key):
    # Only reached for names not found on the stack object itself (i.e. not
    # `type`, `_stack` or the properties); forwards to the top instance.
    return getattr(self._stack[-1], key)

  def __setattr__(self, key, value):
    # All attribute writes go to the top instance, never to the stack itself.
    setattr(self._stack[-1], key, value)


class _State(object):
  """Supporting class for nested scope variable space for converter.Base.

  This structure offers syntactic sugar over a dict of stacks of objects
  of known type. These structures are useful to keep state during AST walks.
  Multiple different scopes can be tracked in parallel. For example:

    s = _State()

    s[foo].enter()
    s[bar].enter()  # this will not affect s[foo]

  Element access has special semantics:
    * keys are a data type
    * element values are _StateStack(type=key) objects
    * missing elements are automatically added, similarly to defaultdict

  For example, the following block:

    _State s
    s[Foo]

  Is equivalent to:

    s = {}
    if Foo not in s:
      s[Foo] = _StateStack(Foo)
    s[Foo]

  See Base for how it's used.
  """

  def __init__(self):
    self._value = {}

  def __getitem__(self, key):
    if key not in self._value:
      self._value[key] = _StateStack(key)
    return self._value[key]


class Base(gast.NodeTransformer):
  """Base class for general-purpose code transformers.

  This is an extension of ast.NodeTransformer that provides a few additional
  functions, like state tracking within the scope of arbitrary node, helpers
  for processing code blocks, debugging, mapping of transformed code to
  original code, and others.

  Scope-local state tracking (deprecated in favor of self.state): to keep
  state across nodes, at the level of (possibly nested) scopes, use
  enter/exit_local_scope and set/get_local. You must call
  enter/exit_local_scope manually, but the transformer detects when they are
  not properly paired.

  The transformer allows keeping state across calls to visit_* that is local to
  arbitrary nodes and their descendants, using the self.state attribute.
  Multiple independent scopes are allowed and automatically constructed.

  For example, to keep track of the If node that encloses any Name node, one can
  write:

    class FooType(object):

      def __init__(self):
        self.foo_property = None

    class DummyTransformer(Base):

      def visit_If(self, node):
        self.state[FooType].enter()
        self.state[FooType].foo_property = node
        node = self.generic_visit(node)
        self.state[FooType].exit()
        return node

      def visit_Name(self, node):
        self.state[FooType].foo_property  # will hold the innermost enclosing if
  """

  # TODO(mdan): Document all extra features.

  def __init__(self, ctx):
    """Initialize the transformer.

    Subclasses should call this.

    Args:
      ctx: A Context object.
    """
    self._lineno = 0
    self._col_offset = 0
    self.ctx = ctx
    self._enclosing_entities = []

    # A stack that allows keeping mutable, scope-local state where scopes may be
    # nested. For example, it can be used to track the usage of break
    # statements in each loop, where loops may be nested.
    self._local_scope_state = []
    self.enter_local_scope()

    # Allows scoping of local variables to keep state across calls to visit_*
    # methods. Multiple scope hierarchies may exist and are keyed by tag. A
    # scope is valid at one or more nodes and all its children. Scopes created
    # in child nodes supersede their parent. Scopes are isolated from one
    # another.
    self.state = _State()

  @property
  def enclosing_entities(self):
    """Tuple of the FunctionDef/ClassDef/Lambda nodes currently being visited."""
    return tuple(self._enclosing_entities)

  @property
  def local_scope_level(self):
    return len(self._local_scope_state)

  def enter_local_scope(self, inherit=None):
    """Deprecated.

    Use self.state instead.

    Marks entry into a new local scope.

    Args:
      inherit: Optional enumerable of variable names to copy from the parent
        scope.
    """
    scope_entered = {}
    if inherit:
      this_scope = self._local_scope_state[-1]
      for name in inherit:
        if name in this_scope:
          scope_entered[name] = this_scope[name]
    self._local_scope_state.append(scope_entered)

  def exit_local_scope(self, keep=None):
    """Deprecated.

    Use self.state instead.

    Marks exit from the current local scope.

    Args:
      keep: Optional enumerable of variable names to copy into the parent scope.

    Returns:
      A dict containing the scope that has just been exited.
    """
    scope_left = self._local_scope_state.pop()
    if keep:
      this_scope = self._local_scope_state[-1]
      for name in keep:
        if name in scope_left:
          this_scope[name] = scope_left[name]
    return scope_left

  def set_local(self, name, value):
    """Deprecated. Use self.state instead."""
    self._local_scope_state[-1][name] = value

  def get_local(self, name, default=None):
    """Deprecated. Use self.state instead."""
    return self._local_scope_state[-1].get(name, default)

  def debug_print(self, node):
    """Helper method useful for debugging. Prints the AST."""
    if __debug__:
      print(pretty_printer.fmt(node))
    return node

  def debug_print_src(self, node):
    """Helper method useful for debugging. Prints the AST as code."""
    if __debug__:
      # ast_to_source returns a (source, extra) pair, as unpacked in
      # _get_source; only the source text is printed.
      source, _ = compiler.ast_to_source(node)
      print(source)
    return node

  def create_assignment(self, target, expression):
    """Returns the statements produced by the template `target = expression`."""
    template = """
      target = expression
    """
    return templates.replace(template, target=target, expression=expression)

  def visit_block(self, nodes, before_visit=None, after_visit=None):
    """A more powerful version of generic_visit for statement blocks.

    An example of a block is the body of an if statement.

    This function allows specifying a postprocessing callback (the
    after_visit argument) which can be used to move nodes to a new
    destination. This is done by after_visit by returning a non-null
    second return value, e.g. return new_node, new_destination.

    For example, a transformer could perform the following move:

        foo()
        bar()
        baz()

        foo()
        if cond:
          bar()
          baz()

    The above could be done with a postprocessor of this kind:

        def after_visit(node):
          if node_is_function_call(bar):
            new_container_node = build_cond()
            new_container_node.body.append(node)
            return new_container_node, new_container_node.body
          else:
            # Once we set a new destination, all subsequent items will be
            # moved to it, so we don't need to explicitly handle baz.
            return node, None

    Args:
      nodes: enumerable of AST node objects. If None, the function returns None.
      before_visit: optional callable that is called before visiting each item
        in nodes
      after_visit: optional callable that takes in an AST node and returns a
        tuple (new_node, new_destination). It is called after visiting each item
        in nodes. It is used in the same way as the
        visit_* methods: new_node will replace the node; if not None,
        new_destination must be a list, and subsequent nodes will be placed
        in this list instead of the list returned by visit_block.

    Returns:
      A list of AST node objects containing the transformed items from nodes,
      except those nodes that have been relocated using after_visit.
    """
    if nodes is None:
      return None

    results = []
    node_destination = results
    for node in nodes:
      if before_visit:
        # TODO(mdan): We can modify node here too, if ever needed.
        before_visit()

      replacement = self.visit(node)

      if after_visit and replacement:
        replacement, new_destination = after_visit(replacement)
      else:
        new_destination = None

      if replacement:
        if isinstance(replacement, (list, tuple)):
          node_destination.extend(replacement)
        else:
          node_destination.append(replacement)

      # Allow the postprocessor to reroute the remaining nodes to a new list.
      if new_destination is not None:
        node_destination = new_destination
    return results

  # TODO(mdan): Remove.
  def apply_to_single_assignments(self, targets, values, apply_fn):
    """Applies a function to each individual assignment.

    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
    It tries to break down the unpacking if possible. In effect, it has the same
    effect as passing the assigned values in SSA form to apply_fn.

    Examples:

    The following will result in apply_fn(a, c), apply_fn(b, d):

        a, b = c, d

    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

        a, b = c

    The following will result in apply_fn(a, (b, c)):

        a = b, c

    It uses the visitor pattern to allow subclasses to process single
    assignments individually.

    Args:
      targets: list, tuple of or individual AST node. Should be used with the
        targets field of an ast.Assign node.
      values: an AST node.
      apply_fn: a function of two arguments, which will be called with the
        respective nodes of each single assignment. The signature is
        apply_fn(target, value), no return value.
    """
    if not isinstance(targets, (list, tuple)):
      targets = (targets,)
    for target in targets:
      if isinstance(target, (gast.Tuple, gast.List)):
        for i, target_el in enumerate(target.elts):
          if isinstance(values, (gast.Tuple, gast.List)):
            value_el = values.elts[i]
          else:
            # The element is read out of `values` (as in `a, b = c`, which
            # reads c[0] and c[1]), so the subscript has a Load context.
            # NOTE(review): gast.Index wraps a raw int rather than a numeric
            # literal node; confirm consumers of value_el accept that.
            value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Load())
          self.apply_to_single_assignments(target_el, value_el, apply_fn)
      else:
        # TODO(mdan): Look into allowing to rewrite the AST here.
        apply_fn(target, values)

  def _get_source(self, node):
    """Returns the source code of `node`, or a placeholder if it can't be made."""
    try:
      source, _ = compiler.ast_to_source(node)
      return source
    # pylint: disable=broad-except
    # This function is used for error reporting. If an exception occurs here,
    # it should be suppressed, in favor of emitting as informative a message
    # about the original error as possible.
    except Exception:
      return '<could not convert AST to source>'

  def visit(self, node):
    """Visits `node`, tracking enclosing entities, origin info and scopes.

    Args:
      node: gast.AST, the node to visit.

    Returns:
      Whatever the matching visit_* method returns (a node, a list of nodes, or
      None). Nodes annotated with anno.Basic.SKIP_PROCESSING are not visited
      and are returned unchanged. An Expr node whose value was replaced by a
      list, tuple, Assign or AugAssign is unwrapped to that value.

    Raises:
      ValueError: if `node` is not a gast.AST instance.
      AssertionError: if enter_local_scope/exit_local_scope calls made while
        visiting `node` are not balanced.
    """
    if not isinstance(node, gast.AST):
      # This is not that uncommon a mistake: various node bodies are lists, for
      # example, posing a land mine for transformers that need to recursively
      # call `visit`. Fail early with a clear message, before any of the
      # bookkeeping below runs on something that is not, in fact, a node.
      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
             ' visit lists of nodes, use "visit_block" instead').format(
                 type(node))
      raise ValueError(msg)

    local_scope_size_at_entry = len(self._local_scope_state)
    parent_origin = self.ctx.current_origin

    # These node types are mutually exclusive, so at most one flag is set.
    did_enter_function = isinstance(
        node, (gast.FunctionDef, gast.ClassDef, gast.Lambda))
    processing_expr_node = isinstance(node, gast.Expr)

    if did_enter_function:
      self._enclosing_entities.append(node)

    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    if processing_expr_node:
      entry_expr_value = node.value

    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      # Skipped nodes are passed through unchanged.
      result = node
    else:
      result = super(Base, self).visit(node)
    self.ctx.current_origin = parent_origin

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if processing_expr_node:
      if isinstance(result, gast.Expr) and result.value != entry_expr_value:
        # When the replacement is a list, it is assumed that the list came
        # from a template that contained a number of statements, which
        # themselves are standalone and don't require an enclosing Expr.
        if isinstance(result.value,
                      (list, tuple, gast.Assign, gast.AugAssign)):
          result = result.value

    # NOTE: if the visit above raises, the cleanup below does not run, so the
    # enclosing entity and local scope stacks are not guaranteed to be intact.
    if did_enter_function:
      self._enclosing_entities.pop()

    if local_scope_size_at_entry != len(self._local_scope_state):
      raise AssertionError(
          'Inconsistent local scope stack. Before entering node %s, the'
          ' stack had length %d, after exit it has length %d. This'
          ' indicates enter_local_scope and exit_local_scope are not'
          ' well paired.' % (node, local_scope_size_at_entry,
                             len(self._local_scope_state)))
    return result