• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""An analysis that determines the reach of a function definition.
16
17A function definition is said to reach a statement if that function may exist
18(and therefore may be called) when that statement executes.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import gast
26
27from tensorflow.python.autograph.pyct import anno
28from tensorflow.python.autograph.pyct import cfg
29from tensorflow.python.autograph.pyct import transformer
30
31
class Definition(object):
  """Describes one unique definition of a function.

  Attributes:
    def_node: the node that introduces this definition, stored as given.
  """

  def __init__(self, def_node):
    super(Definition, self).__init__()
    self.def_node = def_node
37
38
39class _NodeState(object):
40  """Abstraction for the state of the CFG walk for reaching definition analysis.
41
42  This is a value type. Only implements the strictly necessary operators.
43
44  Attributes:
45    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
46        their possible definitions
47  """
48
49  def __init__(self, init_from=None):
50    if init_from:
51      self.value = set(init_from)
52    else:
53      self.value = set()
54
55  def __eq__(self, other):
56    return self.value == other.value
57
58  def __ne__(self, other):
59    return self.value != other.value
60
61  def __or__(self, other):
62    assert isinstance(other, _NodeState)
63    result = _NodeState(self.value)
64    result.value.update(other.value)
65    return result
66
67  def __add__(self, value):
68    result = _NodeState(self.value)
69    result.value.add(value)
70    return result
71
72  def __repr__(self):
73    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
74
75
class Analyzer(cfg.GraphVisitor):
  """Forward CFG pass recording which function definitions reach each node."""

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # Definitions that already exist when the graph is entered (for example,
    # ones an enclosing function closes over). They seed the entry node.
    self.external_defs = external_defs

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    old_out = self.out[node]

    # The entry node starts from the external definitions. Every other node
    # starts from its own previous out-state, so its state only ever grows
    # across revisits.
    if node is not self.graph.entry:
      new_in = old_out
    else:
      new_in = _NodeState(self.external_defs)

    for pred in node.prev:
      new_in = new_in | self.out[pred]

    ast_node = node.ast_node
    is_fn_def = isinstance(ast_node, (gast.Lambda, gast.FunctionDef))
    new_out = new_in + ast_node if is_fn_def else new_in

    self.in_[node] = new_in
    self.out[node] = new_out

    # Report a change so the walk keeps iterating until a fixed point.
    return old_out != new_out
107
108
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here

  Each visited node that has a CFG node in the graph of the function currently
  being analyzed receives an `anno.Static.DEFINED_FNS_IN` annotation. The
  annotation holds the set of function definitions that reach that node.
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    # Maps each function / lambda node to its CFG.
    self.graphs = graphs
    self.allow_skips = False
    # Analyzer of the innermost function being visited; None outside any.
    self.current_analyzer = None

  def _defined_fns_at(self, node):
    """Returns the set of definitions reaching `node`, or None if unknown.

    The result is None when no function is being analyzed (this can happen
    before entering the top level function), or when `node` has no CFG node
    in the current function's graph.
    """
    analyzer = self.current_analyzer
    if analyzer is None or node not in analyzer.graph.index:
      return None
    cfg_node = analyzer.graph.index[node]
    return analyzer.in_[cfg_node].value

  def _process_function(self, node):
    """Runs the analysis for a function or lambda, then visits its body."""
    subgraph = self.graphs[node]

    # Definitions reaching the function in its enclosing scope are visible
    # inside it (closures).
    defined_in = self._defined_fns_at(node)
    if defined_in is None:
      defined_in = ()

    analyzer = Analyzer(subgraph, defined_in)
    analyzer.visit_forward()

    parent_analyzer = self.current_analyzer
    self.current_analyzer = analyzer
    try:
      node = self.generic_visit(node)
    finally:
      # Restore even on error, so the visitor is never left pointing at the
      # inner function's analyzer.
      self.current_analyzer = parent_analyzer
    return node

  # Backward-compatible alias for the original, misspelled method name.
  _proces_function = _process_function

  def visit_FunctionDef(self, node):
    return self._process_function(node)

  def visit_Lambda(self, node):
    return self._process_function(node)

  def visit(self, node):
    defined_fns = self._defined_fns_at(node)
    if defined_fns is not None:
      anno.setanno(node, anno.Static.DEFINED_FNS_IN, defined_fns)

    # Nodes may carry an extra loop-test node in their annotations. Annotate it
    # with the same guard as above; previously this dereferenced
    # `self.current_analyzer` even when it was None.
    extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
    if extra_node is not None:
      extra_fns = self._defined_fns_at(extra_node)
      if extra_fns is not None:
        anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN, extra_fns)

    return super(TreeAnnotator, self).visit(node)
168
169
def resolve(node, source_info, graphs):
  """Annotates `node` with the function definitions reaching each statement.

  Args:
    node: ast.AST, the tree to analyze and annotate.
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph], the CFG of each function.
  Returns:
    ast.AST, the tree returned by the annotating visitor.
  """
  return TreeAnnotator(source_info, graphs).visit(node)
177  Returns:
178    ast.AST
179  """
180  visitor = TreeAnnotator(source_info, graphs)
181  node = visitor.visit(node)
182  return node
183