1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Reaching definition analysis. 16 17This analysis attaches a set of a Definition objects to each symbol, one 18for each distinct definition that may reach it. The Definition objects are 19mutable and may be used by subsequent analyses to further annotate data like 20static type and value information. 21The analysis also attaches the set of the symbols defined at the entry of 22control flow statements. 23 24Requires activity analysis. 25""" 26 27from __future__ import absolute_import 28from __future__ import division 29from __future__ import print_function 30 31import weakref 32 33import gast 34 35from tensorflow.python.autograph.pyct import anno 36from tensorflow.python.autograph.pyct import cfg 37from tensorflow.python.autograph.pyct import transformer 38from tensorflow.python.autograph.pyct.static_analysis import annos 39 40 41class Definition(object): 42 """Definition objects describe a unique definition of a variable. 43 44 Subclasses of this may be used by passing an appropriate factory function to 45 resolve. 46 47 Attributes: 48 param_of: Optional[ast.AST] 49 """ 50 51 def __init__(self): 52 self.param_of = None 53 54 def __repr__(self): 55 return '%s[%d]' % (self.__class__.__name__, id(self)) 56 57 58class _NodeState(object): 59 """Abstraction for the state of the CFG walk for reaching definition analysis. 60 61 This is a value type. Only implements the strictly necessary operators. 62 63 Attributes: 64 value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and 65 their possible definitions 66 """ 67 68 def __init__(self, init_from=None): 69 if init_from: 70 if isinstance(init_from, _NodeState): 71 self.value = { 72 s: set(other_infos) for s, other_infos in init_from.value.items() 73 } 74 elif isinstance(init_from, dict): 75 self.value = {s: set((init_from[s],)) for s in init_from} 76 else: 77 assert False, init_from 78 else: 79 self.value = {} 80 81 def __eq__(self, other): 82 if frozenset(self.value.keys()) != frozenset(other.value.keys()): 83 return False 84 ret = all(self.value[s] == other.value[s] for s in self.value) 85 return ret 86 87 def __ne__(self, other): 88 return not self.__eq__(other) 89 90 def __or__(self, other): 91 assert isinstance(other, _NodeState) 92 result = _NodeState(self) 93 for s, other_infos in other.value.items(): 94 if s in result.value: 95 result.value[s].update(other_infos) 96 else: 97 result.value[s] = set(other_infos) 98 return result 99 100 def __sub__(self, other): 101 assert isinstance(other, set) 102 result = _NodeState(self) 103 for s in other: 104 result.value.pop(s, None) 105 return result 106 107 def __repr__(self): 108 return 'NodeState[%s]=%s' % (id(self), repr(self.value)) 109 110 111class Analyzer(cfg.GraphVisitor): 112 """CFG visitor that determines reaching definitions at statement level.""" 113 114 def __init__(self, graph, definition_factory): 115 self._definition_factory = definition_factory 116 super(Analyzer, self).__init__(graph) 117 # This allows communicating that nodes have extra reaching definitions, 118 # e.g. those that a function closes over. 119 self.extra_in = {} 120 121 self.gen_map = {} 122 123 def init_state(self, _): 124 return _NodeState() 125 126 def visit_node(self, node): 127 prev_defs_out = self.out[node] 128 129 defs_in = _NodeState(self.extra_in.get(node.ast_node, None)) 130 for n in node.prev: 131 defs_in |= self.out[n] 132 133 if anno.hasanno(node.ast_node, anno.Static.SCOPE): 134 node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE) 135 # The definition objects created by each node must be singletons because 136 # their ids are used in equality checks. 137 if node not in self.gen_map: 138 node_symbols = {} 139 for s in node_scope.modified: 140 def_ = self._definition_factory() 141 if s in node_scope.params: 142 def_.param_of = weakref.ref(node_scope.params[s]) 143 node_symbols[s] = def_ 144 self.gen_map[node] = _NodeState(node_symbols) 145 146 gen = self.gen_map[node] 147 kill = node_scope.modified | node_scope.deleted 148 defs_out = gen | (defs_in - kill) 149 150 else: 151 # Nodes that don't have a scope annotation are assumed not to touch any 152 # symbols. 153 # This Name node below is a literal name, e.g. False 154 # This can also happen if activity.py forgot to annotate the node with a 155 # scope object. 156 assert isinstance( 157 node.ast_node, 158 (gast.Name, gast.Break, gast.Continue, gast.Raise)), (node.ast_node, 159 node) 160 defs_out = defs_in 161 162 self.in_[node] = defs_in 163 self.out[node] = defs_out 164 165 # TODO(mdan): Move this to the superclass? 166 return prev_defs_out != defs_out 167 168 169class TreeAnnotator(transformer.Base): 170 """AST visitor that annotates each symbol name with its reaching definitions. 171 172 Simultaneously, the visitor runs the dataflow analysis on each function node, 173 accounting for the effect of closures. For example: 174 175 def foo(): 176 bar = 1 177 def baz(): 178 # bar = 1 reaches here 179 """ 180 181 def __init__(self, source_info, graphs, definition_factory): 182 super(TreeAnnotator, self).__init__(source_info) 183 self.definition_factory = definition_factory 184 self.graphs = graphs 185 self.current_analyzer = None 186 self.current_cfg_node = None 187 188 def visit_FunctionDef(self, node): 189 parent_analyzer = self.current_analyzer 190 subgraph = self.graphs[node] 191 192 # Preorder tree processing: 193 # 1. if this is a child function, the parent was already analyzed and it 194 # has the proper state value for the subgraph's entry 195 # 2. analyze the current function body 196 # 2. recursively walk the subtree; child functions will be processed 197 analyzer = Analyzer(subgraph, self.definition_factory) 198 if parent_analyzer is not None: 199 # Wire the state between the two subgraphs' analyzers. 200 parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]] 201 # Exception: symbols modified in the child function are local to it 202 body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) 203 parent_out_state -= body_scope.modified 204 analyzer.extra_in[node.args] = parent_out_state 205 206 # Complete the analysis for the local function and annotate its body. 207 analyzer.visit_forward() 208 209 # Recursively process any remaining subfunctions. 210 self.current_analyzer = analyzer 211 # Note: not visiting name, decorator_list and returns because they don't 212 # apply to this anlysis. 213 # TODO(mdan): Should we still process the function name? 214 node.args = self.visit(node.args) 215 node.body = self.visit_block(node.body) 216 self.current_analyzer = parent_analyzer 217 218 return node 219 220 def visit_Nonlocal(self, node): 221 raise NotImplementedError() 222 223 def visit_Global(self, node): 224 raise NotImplementedError() 225 226 def visit_ExceptHandler(self, node): 227 # TODO(b/123995141) Add Exception Handlers to the CFG 228 return node 229 230 def visit_Name(self, node): 231 if self.current_analyzer is None: 232 # Names may appear outside function defs - for example in class 233 # definitions. 234 return node 235 236 analyzer = self.current_analyzer 237 cfg_node = self.current_cfg_node 238 239 assert cfg_node is not None, ('name node, %s, outside of any statement?' 240 % node.id) 241 242 qn = anno.getanno(node, anno.Basic.QN) 243 if isinstance(node.ctx, gast.Load): 244 anno.setanno(node, anno.Static.DEFINITIONS, 245 tuple(analyzer.in_[cfg_node].value.get(qn, ()))) 246 else: 247 anno.setanno(node, anno.Static.DEFINITIONS, 248 tuple(analyzer.out[cfg_node].value.get(qn, ()))) 249 250 return node 251 252 def _aggregate_predecessors_defined_in(self, node): 253 preds = self.current_analyzer.graph.stmt_prev[node] 254 node_defined_in = set() 255 for p in preds: 256 node_defined_in |= set(self.current_analyzer.out[p].value.keys()) 257 anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in)) 258 259 def visit_If(self, node): 260 self._aggregate_predecessors_defined_in(node) 261 return self.generic_visit(node) 262 263 def visit_For(self, node): 264 self._aggregate_predecessors_defined_in(node) 265 266 # Manually accounting for the shortcoming described in 267 # cfg.AstToCfg.visit_For. 268 parent = self.current_cfg_node 269 self.current_cfg_node = self.current_analyzer.graph.index[node.iter] 270 node.target = self.visit(node.target) 271 self.current_cfg_node = parent 272 273 node.iter = self.visit(node.iter) 274 node.body = self.visit_block(node.body) 275 node.orelse = self.visit_block(node.orelse) 276 277 return node 278 279 def visit_While(self, node): 280 self._aggregate_predecessors_defined_in(node) 281 return self.generic_visit(node) 282 283 def visit(self, node): 284 parent = self.current_cfg_node 285 286 if (self.current_analyzer is not None and 287 node in self.current_analyzer.graph.index): 288 self.current_cfg_node = self.current_analyzer.graph.index[node] 289 node = super(TreeAnnotator, self).visit(node) 290 291 self.current_cfg_node = parent 292 return node 293 294 295def resolve(node, source_info, graphs, definition_factory): 296 """Resolves reaching definitions for each symbol. 297 298 Args: 299 node: ast.AST 300 source_info: transformer.SourceInfo 301 graphs: Dict[ast.FunctionDef, cfg.Graph] 302 definition_factory: Callable[[], Definition] 303 Returns: 304 ast.AST 305 """ 306 visitor = TreeAnnotator(source_info, graphs, definition_factory) 307 node = visitor.visit(node) 308 return node 309