• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Reaching definition analysis.
16
This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches the set of symbols defined at the entry of
control flow statements.
23
24Requires activity analysis.
25"""
26
27from __future__ import absolute_import
28from __future__ import division
29from __future__ import print_function
30
31import weakref
32
33import gast
34
35from tensorflow.python.autograph.pyct import anno
36from tensorflow.python.autograph.pyct import cfg
37from tensorflow.python.autograph.pyct import transformer
38from tensorflow.python.autograph.pyct.static_analysis import annos
39
40
class Definition(object):
  """A single, distinct definition of a variable.

  Instances are compared by identity, so every definition site gets its own
  object. Callers may substitute a subclass by handing a suitable factory
  function to resolve.

  Attributes:
    param_of: None by default. Analyzer.visit_node in this module replaces it
      with a weakref.ref to the matching entry of the scope's params when the
      defined symbol is one of those params.
  """

  def __init__(self):
    self.param_of = None

  def __repr__(self):
    # Class name plus object id, e.g. "Definition[140230981]".
    return '{}[{:d}]'.format(type(self).__name__, id(self))
56
57
58class _NodeState(object):
59  """Abstraction for the state of the CFG walk for reaching definition analysis.
60
61  This is a value type. Only implements the strictly necessary operators.
62
63  Attributes:
64    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
65        their possible definitions
66  """
67
68  def __init__(self, init_from=None):
69    if init_from:
70      if isinstance(init_from, _NodeState):
71        self.value = {
72            s: set(other_infos) for s, other_infos in init_from.value.items()
73        }
74      elif isinstance(init_from, dict):
75        self.value = {s: set((init_from[s],)) for s in init_from}
76      else:
77        assert False, init_from
78    else:
79      self.value = {}
80
81  def __eq__(self, other):
82    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
83      return False
84    ret = all(self.value[s] == other.value[s] for s in self.value)
85    return ret
86
87  def __ne__(self, other):
88    return not self.__eq__(other)
89
90  def __or__(self, other):
91    assert isinstance(other, _NodeState)
92    result = _NodeState(self)
93    for s, other_infos in other.value.items():
94      if s in result.value:
95        result.value[s].update(other_infos)
96      else:
97        result.value[s] = set(other_infos)
98    return result
99
100  def __sub__(self, other):
101    assert isinstance(other, set)
102    result = _NodeState(self)
103    for s in other:
104      result.value.pop(s, None)
105    return result
106
107  def __repr__(self):
108    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
109
110
class Analyzer(cfg.GraphVisitor):
  """CFG visitor computing statement-level reaching definitions.

  Attributes:
    extra_in: dict keyed by AST node. When a visited CFG node's ast_node has
      an entry here, that entry seeds the node's incoming state. Used to
      communicate extra reaching definitions, e.g. those that a function
      closes over.
    gen_map: dict from CFG node to the _NodeState holding the definitions
      that node generates. Filled lazily, once per node.
  """

  def __init__(self, graph, definition_factory):
    """Creates the analyzer.

    Args:
      graph: the graph handed to cfg.GraphVisitor
      definition_factory: zero-argument callable producing a new definition
        object on each call
    """
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.extra_in = {}
    self.gen_map = {}

  def init_state(self, _):
    return _NodeState()

  def _generated_defs(self, node, scope):
    """Returns the cached gen state of `node`, creating it on first use.

    The definition objects of a node have to be created exactly once: they
    are compared by identity, so re-creating them on every visit would make
    the state look different each time.
    """
    if node not in self.gen_map:
      fresh = {}
      for symbol in scope.modified:
        definition = self._definition_factory()
        if symbol in scope.params:
          definition.param_of = weakref.ref(scope.params[symbol])
        fresh[symbol] = definition
      self.gen_map[node] = _NodeState(fresh)
    return self.gen_map[node]

  def visit_node(self, node):
    """Recomputes the in/out states of `node`.

    Returns:
      True if the node's out state differs from the one recorded before this
      visit, False otherwise.
    """
    previous_out = self.out[node]
    ast_node = node.ast_node

    # Incoming state: any externally supplied definitions, merged with the
    # out states of all predecessors.
    incoming = _NodeState(self.extra_in.get(ast_node))
    for pred in node.prev:
      incoming = incoming | self.out[pred]

    if not anno.hasanno(ast_node, anno.Static.SCOPE):
      # A node without a scope annotation is assumed not to touch any symbol.
      # A Name node here is a literal name, e.g. False. Anything else outside
      # the list below suggests activity.py forgot to annotate the node with a
      # scope object.
      assert isinstance(
          ast_node,
          (gast.Name, gast.Break, gast.Continue, gast.Raise)), (ast_node, node)
      outgoing = incoming
    else:
      scope = anno.getanno(ast_node, anno.Static.SCOPE)
      generated = self._generated_defs(node, scope)
      # Modified and deleted symbols shadow whatever was reaching before.
      killed = scope.modified | scope.deleted
      outgoing = generated | (incoming - killed)

    self.in_[node] = incoming
    self.out[node] = outgoing

    # TODO(mdan): Move this to the superclass?
    return previous_out != outgoing
167
168
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  The dataflow analysis of each function node is run as the function is
  encountered, and the state is threaded into nested functions so that
  closures are accounted for. For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here

  Attributes:
    definition_factory: callable forwarded to every Analyzer created here
    graphs: mapping from function node to its CFG
    current_analyzer: Analyzer of the function being walked, or None when
      outside of any function
    current_cfg_node: CFG node of the statement being walked, or None
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.definition_factory = definition_factory
    self.graphs = graphs
    self.current_analyzer = None
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    """Analyzes the function's CFG, then annotates its arguments and body."""
    enclosing_analyzer = self.current_analyzer

    # Preorder tree processing:
    #  1. if this is a child function, the parent was already analyzed and it
    #     has the proper state value for the subgraph's entry
    #  2. analyze the current function body
    #  3. recursively walk the subtree; child functions will be processed
    analyzer = Analyzer(self.graphs[node], self.definition_factory)
    if enclosing_analyzer is not None:
      # Seed the child's entry with what reaches the def statement in the
      # parent...
      inherited = enclosing_analyzer.out[enclosing_analyzer.graph.index[node]]
      # ...except for the symbols the child modifies, which are local to it.
      body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      analyzer.extra_in[node.args] = inherited - body_scope.modified

    # Run the analysis for this function to completion.
    analyzer.visit_forward()

    # Annotate the function's own nodes; nested functions are handled
    # recursively while this analyzer is the current one.
    # Note: name, decorator_list and returns are not visited because they
    # don't apply to this analysis.
    # TODO(mdan): Should we still process the function name?
    self.current_analyzer = analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = enclosing_analyzer

    return node

  def visit_Nonlocal(self, node):
    raise NotImplementedError()

  def visit_Global(self, node):
    raise NotImplementedError()

  def visit_ExceptHandler(self, node):
    # TODO(b/123995141) Add Exception Handlers to the CFG
    return node

  def visit_Name(self, node):
    """Attaches the tuple of reaching definitions to a Name node."""
    analyzer = self.current_analyzer
    if analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    cfg_node = self.current_cfg_node
    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    # A load sees the definitions at the statement's entry; any other context
    # sees the definitions at the statement's exit.
    states = analyzer.in_ if isinstance(node.ctx, gast.Load) else analyzer.out
    reaching = tuple(states[cfg_node].value.get(qn, ()))
    anno.setanno(node, anno.Static.DEFINITIONS, reaching)

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates `node` with the symbols defined on exit of its predecessors."""
    analyzer = self.current_analyzer
    defined = set()
    for pred in analyzer.graph.stmt_prev[node]:
      defined.update(analyzer.out[pred].value.keys())
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(defined))

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the target is visited as if it belonged to the
    # CFG node of the iterated expression.
    saved_cfg_node = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = saved_cfg_node

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit(self, node):
    """Visits `node`, tracking the CFG node it corresponds to, if any."""
    saved_cfg_node = self.current_cfg_node

    analyzer = self.current_analyzer
    if analyzer is not None and node in analyzer.graph.index:
      self.current_cfg_node = analyzer.graph.index[node]
    visited = super(TreeAnnotator, self).visit(node)

    # Restore the CFG node that was current before this call.
    self.current_cfg_node = saved_cfg_node
    return visited
293
294
def resolve(node, source_info, graphs, definition_factory):
  """Annotates every symbol under `node` with its reaching definitions.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]
  Returns:
    ast.AST, the result of running TreeAnnotator over `node`
  """
  return TreeAnnotator(source_info, graphs, definition_factory).visit(node)
309