1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Activity analysis. 16 17Requires qualified name annotations (see qual_names.py). 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import copy 25import weakref 26 27import gast 28import six 29 30from tensorflow.python.autograph.pyct import anno 31from tensorflow.python.autograph.pyct import qual_names 32from tensorflow.python.autograph.pyct import transformer 33from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno 34 35# TODO(mdan): Add support for PY3 (e.g. Param vs arg). 36# TODO(alexbw): Ignore named literals (e.g. None) 37 38 39class Scope(object): 40 """Encloses local symbol definition and usage information. 41 42 This can track for instance whether a symbol is modified in the current scope. 43 Note that scopes do not necessarily align with Python's scopes. For example, 44 the body of an if statement may be considered a separate scope. 45 46 Caution - the AST references held by this object are weak. 47 48 Attributes: 49 modified: Set[qual_names.QN], identifiers modified in this scope 50 read: Set[qual_names.QN], identifiers read in this scope 51 deleted: Set[qual_names.QN], identifiers deleted in this scope 52 params: WeakValueDictionary[qual_names.QN, ast.Node], function arguments 53 visible in this scope, mapped to the function node that defines them 54 55 Note - simple statements may never delete and modify a symbol at the same 56 time. However, compound ones like if statements can. In that latter case, it's 57 undefined whether the symbol is actually modified or deleted upon statement 58 exit. Certain analyses like reaching definitions need to be careful about 59 this. 60 """ 61 62 def __init__(self, parent, isolated=True, add_unknown_symbols=False): 63 """Create a new scope. 64 65 Args: 66 parent: A Scope or None. 67 isolated: Whether the scope is isolated, that is, whether variables 68 modified in this scope should be considered modified in the parent 69 scope. 70 add_unknown_symbols: Whether to handle attributed and subscripts 71 without having first seen the base name. 72 E.g., analyzing the statement 'x.y = z' without first having seen 'x'. 73 """ 74 self.isolated = isolated 75 self.parent = parent 76 self.add_unknown_symbols = add_unknown_symbols 77 self.modified = set() 78 self.read = set() 79 self.deleted = set() 80 self.params = weakref.WeakValueDictionary() 81 82 @property 83 def affects_parent(self): 84 return not self.isolated and self.parent is not None 85 86 @property 87 def referenced(self): 88 if self.affects_parent: 89 return self.read | self.parent.referenced 90 return self.read 91 92 def __repr__(self): 93 return 'Scope{r=%s, w=%s}' % (tuple(self.read), tuple(self.modified)) 94 95 def copy_from(self, other): 96 """Recursively copies the contents of this scope from another scope.""" 97 if (self.parent is None) != (other.parent is None): 98 raise ValueError('cannot copy scopes of different structures') 99 if other.parent is not None: 100 self.parent.copy_from(other.parent) 101 self.isolated = other.isolated 102 self.modified = copy.copy(other.modified) 103 self.read = copy.copy(other.read) 104 self.params = copy.copy(other.params) 105 106 @classmethod 107 def copy_of(cls, other): 108 if other.parent is not None: 109 parent = cls.copy_of(other.parent) 110 else: 111 parent = None 112 new_copy = cls(parent) 113 new_copy.copy_from(other) 114 return new_copy 115 116 def merge_from(self, other): 117 if (self.parent is None) != (other.parent is None): 118 raise ValueError('cannot merge scopes of different structures') 119 if other.parent is not None: 120 self.parent.merge_from(other.parent) 121 self.modified |= other.modified 122 self.read |= other.read 123 self.params.update(other.params) 124 125 def mark_read(self, name): 126 self.read.add(name) 127 if self.parent is not None and name not in self.params: 128 self.parent.mark_read(name) 129 130 def mark_modified(self, name): 131 self.modified.add(name) 132 if self.affects_parent: 133 self.parent.mark_modified(name) 134 135 def mark_deleted(self, name): 136 self.deleted.add(name) 137 138 def mark_param(self, name, owner): 139 # Assumption: all AST nodes have the same life span. This lets us use 140 # a weak reference to mark the connection between a symbol node and the 141 # function node whose argument that symbol is. 142 self.params[name] = owner 143 144 145class _Lambda(object): 146 147 no_root = True 148 149 def __init__(self): 150 self.args = set() 151 152 153class _Comprehension(object): 154 155 no_root = True 156 157 def __init__(self): 158 self.targets = set() 159 160 161class ActivityAnalyzer(transformer.Base): 162 """Annotates nodes with local scope information. 163 164 See Scope. 165 166 The use of this class requires that qual_names.resolve() has been called on 167 the node. This class will ignore nodes have not been 168 annotated with their qualified names. 169 """ 170 171 def __init__(self, context, parent_scope=None, add_unknown_symbols=False): 172 super(ActivityAnalyzer, self).__init__(context) 173 self.scope = Scope(parent_scope, None, add_unknown_symbols) 174 175 # Note: all these flags crucially rely on the respective nodes are 176 # leaves in the AST, that is, they cannot contain other statements. 177 self._in_aug_assign = False 178 self._in_function_def_args = False 179 180 @property 181 def _in_constructor(self): 182 if len(self.enclosing_entities) > 1: 183 innermost = self.enclosing_entities[-1] 184 parent = self.enclosing_entities[-2] 185 return isinstance(parent, gast.ClassDef) and innermost.name == '__init__' 186 return False 187 188 def _node_sets_self_attribute(self, node): 189 if anno.hasanno(node, anno.Basic.QN): 190 qn = anno.getanno(node, anno.Basic.QN) 191 # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'. 192 if qn.has_attr and qn.parent.qn == ('self',): 193 return True 194 return False 195 196 def _track_symbol(self, node, composite_writes_alter_parent=False): 197 # A QN may be missing when we have an attribute (or subscript) on a function 198 # call. Example: a().b 199 if not anno.hasanno(node, anno.Basic.QN): 200 return 201 qn = anno.getanno(node, anno.Basic.QN) 202 203 # When inside a lambda, ignore any of the lambda's arguments. 204 # This includes attributes or slices of those arguments. 205 for l in self.state[_Lambda]: 206 if qn in l.args: 207 return 208 if qn.owner_set & set(l.args): 209 return 210 211 # When inside a comprehension, ignore any of the comprehensions's targets. 212 # This includes attributes or slices of those arguments. 213 # This is not true in Python2, which leaks symbols. 214 if six.PY3: 215 for l in self.state[_Comprehension]: 216 if qn in l.targets: 217 return 218 if qn.owner_set & set(l.targets): 219 return 220 221 if isinstance(node.ctx, gast.Store): 222 # In comprehensions, modified symbols are the comprehension targets. 223 if six.PY3 and self.state[_Comprehension].level > 0: 224 # Like a lambda's args, they are tracked separately in Python3. 225 self.state[_Comprehension].targets.add(qn) 226 else: 227 self.scope.mark_modified(qn) 228 if qn.is_composite and composite_writes_alter_parent: 229 self.scope.mark_modified(qn.parent) 230 if self._in_aug_assign: 231 self.scope.mark_read(qn) 232 elif isinstance(node.ctx, gast.Load): 233 self.scope.mark_read(qn) 234 elif isinstance(node.ctx, gast.Param): 235 if self._in_function_def_args: 236 # In function defs have the meaning of defining a variable. 237 self.scope.mark_modified(qn) 238 self.scope.mark_param(qn, self.enclosing_entities[-1]) 239 elif self.state[_Lambda].level: 240 # In lambdas, they are tracked separately. 241 self.state[_Lambda].args.add(qn) 242 else: 243 # TODO(mdan): Is this case possible at all? 244 raise NotImplementedError( 245 'Param "{}" outside a function arguments or lambda.'.format(qn)) 246 elif isinstance(node.ctx, gast.Del): 247 # The read matches the Python semantics - attempting to delete an 248 # undefined symbol is illegal. 249 self.scope.mark_read(qn) 250 self.scope.mark_deleted(qn) 251 else: 252 raise ValueError('Unknown context {} for node "{}".'.format( 253 type(node.ctx), qn)) 254 255 def _enter_scope(self, isolated): 256 self.scope = Scope(self.scope, isolated=isolated) 257 258 def _exit_scope(self): 259 self.scope = self.scope.parent 260 261 def _process_statement(self, node): 262 self._enter_scope(False) 263 node = self.generic_visit(node) 264 anno.setanno(node, anno.Static.SCOPE, self.scope) 265 self._exit_scope() 266 return node 267 268 def visit_Nonlocal(self, node): 269 raise NotImplementedError() 270 271 def visit_Global(self, node): 272 raise NotImplementedError() 273 274 def visit_Expr(self, node): 275 return self._process_statement(node) 276 277 def visit_Return(self, node): 278 return self._process_statement(node) 279 280 def visit_Assign(self, node): 281 return self._process_statement(node) 282 283 def visit_AugAssign(self, node): 284 # Special rules for AugAssign. In Assign, the target is only written, 285 # but in AugAssig (e.g. a += b), the target is both read and written. 286 self._in_aug_assign = True 287 node = self._process_statement(node) 288 self._in_aug_assign = False 289 return node 290 291 def visit_Delete(self, node): 292 return self._process_statement(node) 293 294 def visit_Name(self, node): 295 node = self.generic_visit(node) 296 self._track_symbol(node) 297 return node 298 299 def visit_Attribute(self, node): 300 node = self.generic_visit(node) 301 if self._in_constructor and self._node_sets_self_attribute(node): 302 self._track_symbol(node, composite_writes_alter_parent=True) 303 else: 304 self._track_symbol(node) 305 return node 306 307 def visit_Subscript(self, node): 308 node = self.generic_visit(node) 309 # Subscript writes (e.g. a[b] = "value") are considered to modify 310 # both the element itself (a[b]) and its parent (a). 311 self._track_symbol(node) 312 return node 313 314 def visit_Print(self, node): 315 self._enter_scope(False) 316 node.values = self.visit_block(node.values) 317 anno.setanno(node, anno.Static.SCOPE, self.scope) 318 anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope) 319 self._exit_scope() 320 return node 321 322 def visit_Assert(self, node): 323 return self._process_statement(node) 324 325 def visit_Call(self, node): 326 self._enter_scope(False) 327 node.args = self.visit_block(node.args) 328 node.keywords = self.visit_block(node.keywords) 329 # TODO(mdan): Account starargs, kwargs 330 anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope) 331 self._exit_scope() 332 node.func = self.visit(node.func) 333 return node 334 335 def _process_block_node(self, node, block, scope_name): 336 self._enter_scope(False) 337 block = self.visit_block(block) 338 anno.setanno(node, scope_name, self.scope) 339 self._exit_scope() 340 return node 341 342 def _process_parallel_blocks(self, parent, children): 343 # Because the scopes are not isolated, processing any child block 344 # modifies the parent state causing the other child blocks to be 345 # processed incorrectly. So we need to checkpoint the parent scope so that 346 # each child sees the same context. 347 before_parent = Scope.copy_of(self.scope) 348 after_children = [] 349 for child, scope_name in children: 350 self.scope.copy_from(before_parent) 351 parent = self._process_block_node(parent, child, scope_name) 352 after_child = Scope.copy_of(self.scope) 353 after_children.append(after_child) 354 for after_child in after_children: 355 self.scope.merge_from(after_child) 356 return parent 357 358 def visit_Lambda(self, node): 359 assert not self._in_function_def_args 360 self.state[_Lambda].enter() 361 node = self.generic_visit(node) 362 self.state[_Lambda].exit() 363 return node 364 365 def _process_iterable_comprehension(self, node): 366 # This handles ListComp, SetComp, GeneratorExp. 367 self.state[_Comprehension].enter() 368 # Note: it's important to visit the generators first to properly account 369 # for the variables local to these generators. Example: `x` is local to the 370 # expression `x for x in y`. 371 node.generators = self.visit_block(node.generators) 372 node.elt = self.visit(node.elt) 373 self.state[_Comprehension].exit() 374 return node 375 376 def visit_DictComp(self, node): 377 # Identical to _process_iterable_comprehension, different node names. 378 self.state[_Comprehension].enter() 379 node.generators = self.visit_block(node.generators) 380 node.key = self.visit(node.key) 381 node.value = self.visit(node.value) 382 self.state[_Comprehension].exit() 383 return node 384 385 def visit_ListComp(self, node): 386 return self._process_iterable_comprehension(node) 387 388 def visit_SetComp(self, node): 389 return self._process_iterable_comprehension(node) 390 391 def visit_GeneratorExp(self, node): 392 return self._process_iterable_comprehension(node) 393 394 def visit_arguments(self, node): 395 return self._process_statement(node) 396 397 def visit_FunctionDef(self, node): 398 # The FunctionDef node itself has a Scope object that tracks the creation 399 # of its name, along with the usage of any decorator accompanying it. 400 self._enter_scope(False) 401 node.decorator_list = self.visit_block(node.decorator_list) 402 self.scope.mark_modified(qual_names.QN(node.name)) 403 anno.setanno(node, anno.Static.SCOPE, self.scope) 404 self._exit_scope() 405 406 # A separate Scope tracks the actual function definition. 407 self._enter_scope(True) 408 assert not (self._in_function_def_args or self.state[_Lambda].level) 409 self._in_function_def_args = True 410 node.args = self.visit(node.args) 411 self._in_function_def_args = False 412 413 # Track the body separately. This is for compatibility reasons, it may not 414 # be strictly needed. 415 self._enter_scope(False) 416 node.body = self.visit_block(node.body) 417 anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope) 418 self._exit_scope() 419 420 self._exit_scope() 421 return node 422 423 def visit_With(self, node): 424 self._enter_scope(False) 425 node = self.generic_visit(node) 426 anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope) 427 self._exit_scope() 428 return node 429 430 def visit_withitem(self, node): 431 return self._process_statement(node) 432 433 def visit_If(self, node): 434 self._enter_scope(False) 435 node.test = self.visit(node.test) 436 anno.setanno(node, NodeAnno.COND_SCOPE, self.scope) 437 anno.setanno(node.test, anno.Static.SCOPE, self.scope) 438 self._exit_scope() 439 node = self._process_parallel_blocks(node, 440 ((node.body, NodeAnno.BODY_SCOPE), 441 (node.orelse, NodeAnno.ORELSE_SCOPE))) 442 return node 443 444 def visit_For(self, node): 445 self._enter_scope(False) 446 node.target = self.visit(node.target) 447 node.iter = self.visit(node.iter) 448 anno.setanno(node.iter, anno.Static.SCOPE, self.scope) 449 self._exit_scope() 450 node = self._process_parallel_blocks(node, 451 ((node.body, NodeAnno.BODY_SCOPE), 452 (node.orelse, NodeAnno.ORELSE_SCOPE))) 453 return node 454 455 def visit_While(self, node): 456 self._enter_scope(False) 457 node.test = self.visit(node.test) 458 anno.setanno(node, NodeAnno.COND_SCOPE, self.scope) 459 anno.setanno(node.test, anno.Static.SCOPE, self.scope) 460 self._exit_scope() 461 node = self._process_parallel_blocks(node, 462 ((node.body, NodeAnno.BODY_SCOPE), 463 (node.orelse, NodeAnno.ORELSE_SCOPE))) 464 return node 465 466 467def resolve(node, context, parent_scope=None): 468 return ActivityAnalyzer(context, parent_scope).visit(node) 469