• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Activity analysis.
16
17Requires qualified name annotations (see qual_names.py).
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import copy
25import weakref
26
27import gast
28import six
29
30from tensorflow.python.autograph.pyct import anno
31from tensorflow.python.autograph.pyct import qual_names
32from tensorflow.python.autograph.pyct import transformer
33from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
34
35# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
36# TODO(alexbw): Ignore named literals (e.g. None)
37
38
class Scope(object):
  """Encloses local symbol definition and usage information.

  This can track for instance whether a symbol is modified in the current scope.
  Note that scopes do not necessarily align with Python's scopes. For example,
  the body of an if statement may be considered a separate scope.

  Caution - the AST references held by this object are weak.

  Attributes:
    parent: Optional[Scope], the enclosing scope.
    isolated: bool, when true, modifications recorded in this scope are not
      propagated to the parent scope. Reads are propagated regardless (see
      mark_read).
    add_unknown_symbols: bool, see __init__.
    modified: Set[qual_names.QN], identifiers modified in this scope
    read: Set[qual_names.QN], identifiers read in this scope
    deleted: Set[qual_names.QN], identifiers deleted in this scope
    params: WeakValueDictionary[qual_names.QN, ast.Node], function arguments
      visible in this scope, mapped to the function node that defines them

  Note - simple statements may never delete and modify a symbol at the same
  time. However, compound ones like if statements can. In that latter case, it's
  undefined whether the symbol is actually modified or deleted upon statement
  exit. Certain analyses like reaching definitions need to be careful about
  this.
  """

  def __init__(self, parent, isolated=True, add_unknown_symbols=False):
    """Create a new scope.

    Args:
      parent: A Scope or None.
      isolated: Whether the scope is isolated, that is, whether variables
          modified in this scope should be kept out of the parent scope.
          Reads are propagated to the parent regardless of this setting.
      add_unknown_symbols: Whether to handle attributes and subscripts
          without having first seen the base name.
          E.g., analyzing the statement 'x.y = z' without first having seen 'x'.
    """
    self.isolated = isolated
    self.parent = parent
    self.add_unknown_symbols = add_unknown_symbols
    self.modified = set()
    self.read = set()
    self.deleted = set()
    self.params = weakref.WeakValueDictionary()

  @property
  def affects_parent(self):
    """Whether modifications made in this scope are also made in the parent."""
    return not self.isolated and self.parent is not None

  @property
  def referenced(self):
    """Symbols read in this scope and, while not isolated, in its ancestors."""
    if self.affects_parent:
      return self.read | self.parent.referenced
    return self.read

  def __repr__(self):
    return 'Scope{r=%s, w=%s}' % (tuple(self.read), tuple(self.modified))

  def copy_from(self, other):
    """Recursively copies the contents of this scope from another scope.

    The parent chains are copied level by level, so both scopes must have
    parent chains of the same length. The symbol containers are copied, so
    later changes to `other` do not affect this scope (and vice versa).

    Args:
      other: Scope, the scope to copy from.

    Raises:
      ValueError: if the parent chains of the two scopes differ in length.
    """
    if (self.parent is None) != (other.parent is None):
      raise ValueError('cannot copy scopes of different structures')
    if other.parent is not None:
      self.parent.copy_from(other.parent)
    self.isolated = other.isolated
    self.add_unknown_symbols = other.add_unknown_symbols
    self.modified = copy.copy(other.modified)
    self.read = copy.copy(other.read)
    # Deletions are part of the scope's contents; a copy must not lose them.
    self.deleted = copy.copy(other.deleted)
    self.params = copy.copy(other.params)

  @classmethod
  def copy_of(cls, other):
    """Returns an independent copy of `other`, including its parent chain.

    The symbol sets are copied; the symbols they contain are shared.
    """
    if other.parent is not None:
      parent = cls.copy_of(other.parent)
    else:
      parent = None
    new_copy = cls(parent)
    new_copy.copy_from(other)
    return new_copy

  def merge_from(self, other):
    """Adds everything recorded in `other` to this scope, level by level.

    Args:
      other: Scope, whose parent chain must have the same length as this one's.

    Raises:
      ValueError: if the parent chains of the two scopes differ in length.
    """
    if (self.parent is None) != (other.parent is None):
      raise ValueError('cannot merge scopes of different structures')
    if other.parent is not None:
      self.parent.merge_from(other.parent)
    self.modified |= other.modified
    self.read |= other.read
    self.deleted |= other.deleted
    self.params.update(other.params)

  def mark_read(self, name):
    """Records a read of `name` here and in all ancestors.

    Propagation stops at the first scope in which `name` is a parameter.
    Reads propagate even out of isolated scopes.
    """
    self.read.add(name)
    if self.parent is not None and name not in self.params:
      self.parent.mark_read(name)

  def mark_modified(self, name):
    """Records a write to `name`, propagating it while scopes are not isolated."""
    self.modified.add(name)
    if self.affects_parent:
      self.parent.mark_modified(name)

  def mark_deleted(self, name):
    """Records a deletion of `name` in this scope only (never propagated)."""
    self.deleted.add(name)

  def mark_param(self, name, owner):
    """Records `name` as a parameter of the function node `owner`."""
    # Assumption: all AST nodes have the same life span. This lets us use
    # a weak reference to mark the connection between a symbol node and the
    # function node whose argument that symbol is.
    self.params[name] = owner
143
144
class _Lambda(object):
  """State kept on the analyzer's state stack while visiting a lambda.

  Attributes:
    args: Set[qual_names.QN], the parameters declared by the lambda. While
      inside the lambda, uses of these symbols (and of their attributes or
      subscripts) are not recorded in any Scope.
  """

  # NOTE(review): presumably tells the transformer's state stack not to push
  # an initial entry, so its `level` is 0 outside any lambda; confirm in
  # transformer.py.
  no_root = True

  def __init__(self):
    # Filled in as the lambda's Param-context names are visited.
    self.args = set()
151
152
class _Comprehension(object):
  """State kept on the analyzer's state stack while visiting a comprehension.

  Used for list/set/dict comprehensions and generator expressions.

  Attributes:
    targets: Set[qual_names.QN], the symbols bound by the comprehension's
      `for` clauses. Under Python 3 these are collected here instead of being
      marked as modified in the enclosing scope, and uses of them (including
      their attributes or subscripts) are not recorded in any Scope.
  """

  # NOTE(review): presumably tells the transformer's state stack not to push
  # an initial entry, so its `level` is 0 outside any comprehension; confirm
  # in transformer.py.
  no_root = True

  def __init__(self):
    # Filled in as Store-context names inside the comprehension are visited.
    self.targets = set()
159
160
class ActivityAnalyzer(transformer.Base):
  """Annotates nodes with local scope information.

  See Scope.

  The use of this class requires that qual_names.resolve() has been called on
  the node. This class will ignore nodes that have not been annotated with
  their qualified names.

  Annotations written (all values are Scope objects):
    anno.Static.SCOPE: on simple statements (Expr, Return, Assign, AugAssign,
      Delete, Assert, Print), on `arguments` and `withitem` nodes, on
      FunctionDef nodes (covering the name and decorators), and on the test /
      iter expressions of If, While and For.
    NodeAnno.ARGS_SCOPE: on Call and Print nodes, covering their arguments.
    NodeAnno.BODY_SCOPE: on FunctionDef, With, If, For and While nodes.
    NodeAnno.ORELSE_SCOPE: on If, For and While nodes.
    NodeAnno.COND_SCOPE: on If and While nodes, covering their test.
  """

  def __init__(self, context, parent_scope=None, add_unknown_symbols=False):
    super(ActivityAnalyzer, self).__init__(context)
    # The root scope is not isolated: when parent_scope is given, the
    # modifications recorded during analysis are also recorded in it.
    self.scope = Scope(
        parent_scope, isolated=False, add_unknown_symbols=add_unknown_symbols)

    # Note: all these flags crucially rely on the respective nodes being
    # leaves in the AST, that is, they cannot contain other statements.
    self._in_aug_assign = False
    self._in_function_def_args = False

  @property
  def _in_constructor(self):
    """Whether the innermost entity is a method named `__init__` of a class."""
    if len(self.enclosing_entities) > 1:
      innermost = self.enclosing_entities[-1]
      parent = self.enclosing_entities[-2]
      return isinstance(parent, gast.ClassDef) and innermost.name == '__init__'
    return False

  def _node_sets_self_attribute(self, node):
    """Whether `node`'s QN is an attribute of a symbol named `self`."""
    if anno.hasanno(node, anno.Basic.QN):
      qn = anno.getanno(node, anno.Basic.QN)
      # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
      if qn.has_attr and qn.parent.qn == ('self',):
        return True
    return False

  def _track_symbol(self, node, composite_writes_alter_parent=False):
    """Records the access that `node` makes to its symbol in the current scope.

    The kind of access (read, write, deletion, parameter declaration) is taken
    from `node.ctx`.

    Args:
      node: a Name, Attribute or Subscript node.
      composite_writes_alter_parent: if True, a write to a composite symbol
        (e.g. `a.b`) also marks its parent (`a`) as modified.

    Raises:
      NotImplementedError: for a Param context outside function arguments or
        a lambda.
      ValueError: for an unrecognized context.
    """
    # A QN may be missing when we have an attribute (or subscript) on a function
    # call. Example: a().b
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)

    # When inside a lambda, ignore any of the lambda's arguments.
    # This includes attributes or slices of those arguments.
    for l in self.state[_Lambda]:
      if qn in l.args:
        return
      if qn.owner_set & set(l.args):
        return

    # When inside a comprehension, ignore any of the comprehension's targets.
    # This includes attributes or slices of those targets.
    # This is not true in Python2, which leaks symbols.
    if six.PY3:
      for l in self.state[_Comprehension]:
        if qn in l.targets:
          return
        if qn.owner_set & set(l.targets):
          return

    if isinstance(node.ctx, gast.Store):
      # In comprehensions, modified symbols are the comprehension targets.
      if six.PY3 and self.state[_Comprehension].level > 0:
        # Like a lambda's args, they are tracked separately in Python3.
        self.state[_Comprehension].targets.add(qn)
      else:
        self.scope.mark_modified(qn)
        if qn.is_composite and composite_writes_alter_parent:
          self.scope.mark_modified(qn.parent)
        if self._in_aug_assign:
          # `a += b` reads `a` before writing it.
          self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Param):
      if self._in_function_def_args:
        # In function defs have the meaning of defining a variable.
        self.scope.mark_modified(qn)
        self.scope.mark_param(qn, self.enclosing_entities[-1])
      elif self.state[_Lambda].level:
        # In lambdas, they are tracked separately.
        self.state[_Lambda].args.add(qn)
      else:
        # TODO(mdan): Is this case possible at all?
        raise NotImplementedError(
            'Param "{}" outside a function arguments or lambda.'.format(qn))
    elif isinstance(node.ctx, gast.Del):
      # The read matches the Python semantics - attempting to delete an
      # undefined symbol is illegal.
      self.scope.mark_read(qn)
      self.scope.mark_deleted(qn)
    else:
      raise ValueError('Unknown context {} for node "{}".'.format(
          type(node.ctx), qn))

  def _enter_scope(self, isolated):
    """Pushes a new child of the current scope and makes it current."""
    self.scope = Scope(self.scope, isolated=isolated)

  def _exit_scope(self):
    """Makes the parent of the current scope current again."""
    self.scope = self.scope.parent

  def _process_statement(self, node):
    """Visits `node` in its own non-isolated scope, stored as Static.SCOPE."""
    self._enter_scope(False)
    node = self.generic_visit(node)
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    return node

  def visit_Nonlocal(self, node):
    raise NotImplementedError('nonlocal statements are not supported')

  def visit_Global(self, node):
    raise NotImplementedError('global statements are not supported')

  def visit_Expr(self, node):
    return self._process_statement(node)

  def visit_Return(self, node):
    return self._process_statement(node)

  def visit_Assign(self, node):
    return self._process_statement(node)

  def visit_AugAssign(self, node):
    # Special rules for AugAssign. In Assign, the target is only written,
    # but in AugAssign (e.g. a += b), the target is both read and written.
    self._in_aug_assign = True
    try:
      node = self._process_statement(node)
    finally:
      # Never leak the flag into later statements, even on error.
      self._in_aug_assign = False
    return node

  def visit_Delete(self, node):
    return self._process_statement(node)

  def visit_Name(self, node):
    node = self.generic_visit(node)
    self._track_symbol(node)
    return node

  def visit_Attribute(self, node):
    node = self.generic_visit(node)
    if self._in_constructor and self._node_sets_self_attribute(node):
      # In a constructor, `self.x = ...` is also treated as modifying `self`.
      self._track_symbol(node, composite_writes_alter_parent=True)
    else:
      self._track_symbol(node)
    return node

  def visit_Subscript(self, node):
    node = self.generic_visit(node)
    # Only the subscripted symbol itself (e.g. `a[b]` in `a[b] = "value"`) is
    # tracked here. Its base `a` was visited by generic_visit above (normally
    # in a Load context, so it is recorded as read) but is not marked as
    # modified.
    self._track_symbol(node)
    return node

  def visit_Print(self, node):
    self._enter_scope(False)
    node.values = self.visit_block(node.values)
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
    self._exit_scope()
    return node

  def visit_Assert(self, node):
    return self._process_statement(node)

  def visit_Call(self, node):
    self._enter_scope(False)
    node.args = self.visit_block(node.args)
    node.keywords = self.visit_block(node.keywords)
    # TODO(mdan): Account starargs, kwargs
    anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
    self._exit_scope()
    # The callee is visited outside ARGS_SCOPE, in the enclosing scope.
    node.func = self.visit(node.func)
    return node

  def _process_block_node(self, node, block, scope_name):
    """Visits `block` in a fresh scope, stored on `node` under `scope_name`."""
    self._enter_scope(False)
    # NOTE(review): the visited block is not written back to `node`; this is
    # only harmless because the visitors in this class return the nodes they
    # were given instead of replacing them.
    block = self.visit_block(block)
    anno.setanno(node, scope_name, self.scope)
    self._exit_scope()
    return node

  def _process_parallel_blocks(self, parent, children):
    """Visits alternative blocks (e.g. if/else) from the same starting state.

    Args:
      parent: the compound statement node that owns the blocks.
      children: iterable of (block, scope annotation key) pairs.

    Returns:
      `parent`, annotated with one scope per child block.
    """
    # Because the scopes are not isolated, processing any child block
    # modifies the parent state causing the other child blocks to be
    # processed incorrectly. So we need to checkpoint the parent scope so that
    # each child sees the same context.
    before_parent = Scope.copy_of(self.scope)
    after_children = []
    for child, scope_name in children:
      self.scope.copy_from(before_parent)
      parent = self._process_block_node(parent, child, scope_name)
      after_child = Scope.copy_of(self.scope)
      after_children.append(after_child)
    # The resulting state is the union of what every branch did.
    for after_child in after_children:
      self.scope.merge_from(after_child)
    return parent

  def visit_Lambda(self, node):
    # A lambda may appear inside a function's argument list, e.g. as a default
    # value: `def f(key=lambda x: x)`. Its own parameters must be tracked as
    # lambda args, not as parameters of the enclosing function, so the
    # function-args flag is suspended while visiting it.
    saved_in_function_def_args = self._in_function_def_args
    self.state[_Lambda].enter()
    self._in_function_def_args = False
    try:
      node = self.generic_visit(node)
    finally:
      self._in_function_def_args = saved_in_function_def_args
      self.state[_Lambda].exit()
    return node

  def _process_iterable_comprehension(self, node):
    # This handles ListComp, SetComp, GeneratorExp.
    self.state[_Comprehension].enter()
    # Note: it's important to visit the generators first to properly account
    # for the variables local to these generators. Example: `x` is local to the
    # expression `x for x in y`.
    node.generators = self.visit_block(node.generators)
    node.elt = self.visit(node.elt)
    self.state[_Comprehension].exit()
    return node

  def visit_DictComp(self, node):
    # Identical to _process_iterable_comprehension, different node names.
    self.state[_Comprehension].enter()
    node.generators = self.visit_block(node.generators)
    node.key = self.visit(node.key)
    node.value = self.visit(node.value)
    self.state[_Comprehension].exit()
    return node

  def visit_ListComp(self, node):
    return self._process_iterable_comprehension(node)

  def visit_SetComp(self, node):
    return self._process_iterable_comprehension(node)

  def visit_GeneratorExp(self, node):
    return self._process_iterable_comprehension(node)

  def visit_arguments(self, node):
    return self._process_statement(node)

  def visit_FunctionDef(self, node):
    # The FunctionDef node itself has a Scope object that tracks the creation
    # of its name, along with the usage of any decorator accompanying it.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    self.scope.mark_modified(qual_names.QN(node.name))
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()

    # A separate, isolated Scope tracks the actual function definition.
    self._enter_scope(True)
    assert not (self._in_function_def_args or self.state[_Lambda].level)
    self._in_function_def_args = True
    try:
      node.args = self.visit(node.args)
    finally:
      # Never leak the flag into the function body, even on error.
      self._in_function_def_args = False

    # Track the body separately. This is for compatibility reasons, it may not
    # be strictly needed.
    self._enter_scope(False)
    node.body = self.visit_block(node.body)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()

    self._exit_scope()
    return node

  def visit_With(self, node):
    self._enter_scope(False)
    node = self.generic_visit(node)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()
    return node

  def visit_withitem(self, node):
    return self._process_statement(node)

  def visit_If(self, node):
    self._enter_scope(False)
    node.test = self.visit(node.test)
    anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
    anno.setanno(node.test, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_For(self, node):
    self._enter_scope(False)
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_While(self, node):
    self._enter_scope(False)
    node.test = self.visit(node.test)
    anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
    anno.setanno(node.test, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node
465
466
def resolve(node, context, parent_scope=None):
  """Runs activity analysis on `node`, annotating it with Scope objects.

  Args:
    node: the AST to analyze; qual_names.resolve() should already have been
      applied to it.
    context: the transformer context, forwarded to ActivityAnalyzer.
    parent_scope: Optional[Scope], the scope enclosing `node`.

  Returns:
    Whatever ActivityAnalyzer.visit returns for `node`.
  """
  analyzer = ActivityAnalyzer(context, parent_scope)
  return analyzer.visit(node)
469