• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Live variable analysis.
16
17See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of
18the following idioms: live variable, live in, live out, which are used
19throughout this file.
20
21This analysis attaches the following:
22 * symbols that are live at the exit of control flow statements
23 * symbols that are live at the entry of control flow statements
24
25Requires activity analysis.
26"""
27
28from __future__ import absolute_import
29from __future__ import division
30from __future__ import print_function
31
32import gast
33
34from tensorflow.python.autograph.pyct import anno
35from tensorflow.python.autograph.pyct import cfg
36from tensorflow.python.autograph.pyct import transformer
37from tensorflow.python.autograph.pyct.static_analysis import annos
38
39
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that performs liveness analysis at statement level.

  For every CFG node, the visitor maintains two sets of symbols:
   * `self.in_[node]`: the symbols live at the entry of the node
   * `self.out[node]`: the symbols live at the exit of the node

  Both are derived from the `anno.Static.SCOPE` annotation of the node's AST
  node, which is why activity analysis must have run beforehand.
  """

  def __init__(self, graph, include_annotations):
    """Creates a new analyzer.

    Args:
      graph: cfg.Graph, the control flow graph to analyze.
      include_annotations: Bool, whether the symbols listed in a scope's
        `annotations` count as reads. When False, they are excluded from the
        live sets.
    """
    super(Analyzer, self).__init__(graph)
    self.include_annotations = include_annotations

  def init_state(self, _):
    """Returns the initial state of a node: no live symbols."""
    return set()

  def visit_node(self, node):
    """Recomputes the live-in and live-out sets of a single CFG node.

    Args:
      node: the CFG node to process.

    Returns:
      Bool, True if the live-in set of the node changed during this visit.
    """
    prev_live_in = self.in_[node]

    # A symbol is live at the exit of a node if it is live at the entry of any
    # of its successors.
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      gen = node_scope.read
      if not self.include_annotations:
        # Deliberately not `gen -= ...`: on a mutable set that would update
        # `node_scope.read` in place, altering the results of the activity
        # analysis stored on the AST node.
        gen = gen - node_scope.annotations
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. whether x needs to be added if x.y is live. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified | node_scope.deleted

      live_in = gen | (live_out - kill)

      reaching_functions = anno.getanno(
          node.ast_node, anno.Static.DEFINED_FNS_IN)
      for fn_ast_node in reaching_functions:
        if isinstance(fn_ast_node, gast.Lambda):
          # Exception: lambda functions are assumed to be used only in the
          # place where they are defined, and not later.
          continue
        fn_scope = anno.getanno(fn_ast_node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
        # Any closure of a reaching function definition is conservatively
        # considered live.
        live_in |= (fn_scope.read - fn_scope.bound)

    else:
      assert self.can_ignore(node), (node.ast_node, node)

      # Nodes without a scope neither read nor write symbols, so liveness
      # passes through them unchanged.
      live_in = live_out

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in
94
95
class TreeAnnotator(transformer.Base):
  """Annotates the statements of each function in the AST with liveness info.

  If a function defines other local functions, those will have separate CFGs.
  However, dataflow analysis needs to tie up these CFGs to properly emulate the
  effect of closures. In the case of liveness, the parent function's live
  variables must account for the variables that are live at the entry of each
  subfunction. For example:

    def foo():
      # baz is live from here on
      def bar():
        print(baz)

  This annotator runs a separate `Analyzer` on the CFG of each individual
  function and then copies the results onto the AST nodes, as
  `anno.Static.LIVE_VARS_IN` and `anno.Static.LIVE_VARS_OUT` annotations.
  """

  def __init__(self, source_info, graphs, include_annotations):
    """Creates a new annotator.

    Args:
      source_info: passed through to the `transformer.Base` constructor.
      graphs: mapping from function AST nodes to their control flow graphs.
      include_annotations: Bool, forwarded to each `Analyzer`.
    """
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.include_annotations = include_annotations
    self.allow_skips = False
    # The analyzer of the function currently being visited, if any.
    self.current_analyzer = None

  def visit(self, node):
    """Visits a node, then records live-in for statements found in the CFG."""
    node = super(TreeAnnotator, self).visit(node)

    analyzer = self.current_analyzer
    if analyzer is None or not isinstance(node, gast.stmt):
      return node

    cfg_index = analyzer.graph.index
    if node in cfg_index:
      live_in = frozenset(analyzer.in_[cfg_index[node]])
      anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node

  def _analyze_function(self, node, is_lambda):
    """Runs the analysis on a function's CFG, then annotates its subtree."""
    del is_lambda  # Unused.
    enclosing_analyzer = self.current_analyzer

    fn_analyzer = Analyzer(self.graphs[node], self.include_annotations)
    fn_analyzer.visit_reverse()

    # The function's subtree is annotated using its own analyzer; the
    # enclosing one is put back once the subtree has been visited.
    self.current_analyzer = fn_analyzer
    node = self.generic_visit(node)
    self.current_analyzer = enclosing_analyzer

    return node

  def visit_Lambda(self, node):
    return self._analyze_function(node, is_lambda=True)

  def visit_FunctionDef(self, node):
    return self._analyze_function(node, is_lambda=False)

  def _block_statement_live_out(self, node):
    """Sets live-out of a block statement from the live-in of its successors."""
    analyzer = self.current_analyzer
    live_out = frozenset().union(
        *(analyzer.in_[succ] for succ in analyzer.graph.stmt_next[node]))
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
    return node

  def _block_statement_live_in(self, node, entry_node):
    """Sets live-in of a block statement to the live-in of its entry node."""
    cfg_index = self.current_analyzer.graph.index
    if entry_node not in cfg_index:
      # The entry is not a CFG node; it is expected to have been annotated
      # already, while visiting the children.
      assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
          'If not matching a CFG node, must be a block statement:'
          ' {}'.format(entry_node))
      live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
    else:
      entry_cfg_node = cfg_index[entry_node]
      live_in = frozenset(self.current_analyzer.in_[entry_cfg_node])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node

  def visit_If(self, node):
    visited = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(visited, visited.test)

  def visit_For(self, node):
    visited = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(visited, visited.iter)

  def visit_While(self, node):
    visited = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(visited, visited.test)

  def visit_Try(self, node):
    visited = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(visited, visited.body[0])

  def visit_ExceptHandler(self, node):
    visited = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(visited, visited.body[0])

  def visit_With(self, node):
    # Only live-in is recorded for with statements.
    visited = self.generic_visit(node)
    return self._block_statement_live_in(visited, visited.items[0])

  def visit_Expr(self, node):
    """Records live-out for expression statements, from their CFG node."""
    node = self.generic_visit(node)
    analyzer = self.current_analyzer
    expr_cfg_node = analyzer.graph.index[node]
    live_out = frozenset(analyzer.out[expr_cfg_node])
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
    return node
203
204
205# TODO(mdan): Investigate the possibility of removing include_annotations.
def resolve(node, source_info, graphs, include_annotations=True):
  """Annotates an AST with the live symbols around its statements.

  This is the entry point of the analysis: it runs a `TreeAnnotator` over the
  given node and returns the result.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    include_annotations: Bool, whether type annotations should be included in
      the analysis.
  Returns:
    ast.AST
  """
  annotator = TreeAnnotator(source_info, graphs, include_annotations)
  return annotator.visit(node)
220