1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Reaching definition analysis. 16 17This analysis attaches a set of a Definition objects to each symbol, one 18for each distinct definition that may reach it. The Definition objects are 19mutable and may be used by subsequent analyses to further annotate data like 20static type and value information. 21The analysis also attaches the set of the symbols defined at the entry of 22control flow statements. 23 24Requires activity analysis. 25""" 26 27from __future__ import absolute_import 28from __future__ import division 29from __future__ import print_function 30 31import weakref 32 33import gast 34 35from tensorflow.python.autograph.pyct import anno 36from tensorflow.python.autograph.pyct import cfg 37from tensorflow.python.autograph.pyct import transformer 38 39 40class Definition(object): 41 """Definition objects describe a unique definition of a variable. 42 43 Subclasses of this may be used by passing an appropriate factory function to 44 resolve. 45 46 Attributes: 47 param_of: Optional[ast.AST] 48 directives: Dict, optional definition annotations 49 """ 50 51 def __init__(self): 52 self.param_of = None 53 self.directives = {} 54 55 def __repr__(self): 56 return '%s[%d]' % (self.__class__.__name__, id(self)) 57 58 59class _NodeState(object): 60 """Abstraction for the state of the CFG walk for reaching definition analysis. 61 62 This is a value type. Only implements the strictly necessary operators. 63 64 Attributes: 65 value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and 66 their possible definitions 67 """ 68 69 def __init__(self, init_from=None): 70 if init_from: 71 if isinstance(init_from, _NodeState): 72 self.value = { 73 s: set(other_infos) for s, other_infos in init_from.value.items() 74 } 75 elif isinstance(init_from, dict): 76 self.value = {s: set((init_from[s],)) for s in init_from} 77 else: 78 assert False, init_from 79 else: 80 self.value = {} 81 82 def __eq__(self, other): 83 if frozenset(self.value.keys()) != frozenset(other.value.keys()): 84 return False 85 ret = all(self.value[s] == other.value[s] for s in self.value) 86 return ret 87 88 def __ne__(self, other): 89 return not self.__eq__(other) 90 91 def __or__(self, other): 92 assert isinstance(other, _NodeState) 93 result = _NodeState(self) 94 for s, other_infos in other.value.items(): 95 if s in result.value: 96 result.value[s].update(other_infos) 97 else: 98 result.value[s] = set(other_infos) 99 return result 100 101 def __sub__(self, other): 102 assert isinstance(other, set) 103 result = _NodeState(self) 104 for s in other: 105 result.value.pop(s, None) 106 return result 107 108 def __repr__(self): 109 return 'NodeState[%s]=%s' % (id(self), repr(self.value)) 110 111 112class Analyzer(cfg.GraphVisitor): 113 """CFG visitor that determines reaching definitions at statement level.""" 114 115 def __init__(self, graph, definition_factory): 116 self._definition_factory = definition_factory 117 super(Analyzer, self).__init__(graph) 118 self.gen_map = {} 119 120 def init_state(self, _): 121 return _NodeState() 122 123 def visit_node(self, node): 124 prev_defs_out = self.out[node] 125 126 defs_in = _NodeState() 127 for n in node.prev: 128 defs_in |= self.out[n] 129 130 if anno.hasanno(node.ast_node, anno.Static.SCOPE): 131 node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE) 132 # The definition objects created by each node must be singletons because 133 # their ids are used in equality checks. 134 if node not in self.gen_map: 135 node_symbols = {} 136 # Every binding operation (assign, nonlocal, global, etc.) counts as a 137 # definition, with the exception of del, which only deletes without 138 # creating a new variable. 139 newly_defined = ((node_scope.bound | node_scope.globals) - 140 node_scope.deleted) 141 for s in newly_defined: 142 def_ = self._definition_factory() 143 node_symbols[s] = def_ 144 # Every param receives a definition. Params are not necessarily 145 # considered as "modified". 146 for s, p in node_scope.params.items(): 147 def_ = self._definition_factory() 148 def_.param_of = weakref.ref(p) 149 node_symbols[s] = def_ 150 self.gen_map[node] = _NodeState(node_symbols) 151 152 gen = self.gen_map[node] 153 kill = node_scope.modified | node_scope.deleted 154 defs_out = gen | (defs_in - kill) 155 156 gen = self.gen_map[node] 157 defs_out = gen | (defs_in - kill) 158 159 else: 160 assert self.can_ignore(node), (node.ast_node, node) 161 defs_out = defs_in 162 163 self.in_[node] = defs_in 164 self.out[node] = defs_out 165 166 return prev_defs_out != defs_out 167 168 169class TreeAnnotator(transformer.Base): 170 """AST visitor that annotates each symbol name with its reaching definitions. 171 172 Simultaneously, the visitor runs the dataflow analysis on each function node, 173 accounting for the effect of closures. For example: 174 175 def foo(): 176 bar = 1 177 def baz(): 178 # bar = 1 reaches here 179 """ 180 181 def __init__(self, source_info, graphs, definition_factory): 182 super(TreeAnnotator, self).__init__(source_info) 183 self.allow_skips = False 184 self.definition_factory = definition_factory 185 self.graphs = graphs 186 self.current_analyzer = None 187 self.current_cfg_node = None 188 189 def visit_FunctionDef(self, node): 190 parent_analyzer = self.current_analyzer 191 subgraph = self.graphs[node] 192 193 analyzer = Analyzer(subgraph, self.definition_factory) 194 analyzer.visit_forward() 195 196 # Recursively process any remaining subfunctions. 197 self.current_analyzer = analyzer 198 node.args = self.visit(node.args) 199 node.body = self.visit_block(node.body) 200 self.current_analyzer = parent_analyzer 201 202 return node 203 204 def visit_Name(self, node): 205 if self.current_analyzer is None: 206 # Names may appear outside function defs - for example in class 207 # definitions. 208 return node 209 210 analyzer = self.current_analyzer 211 cfg_node = self.current_cfg_node 212 213 assert cfg_node is not None, ('name node, %s, outside of any statement?' 214 % node.id) 215 216 qn = anno.getanno(node, anno.Basic.QN) 217 if isinstance(node.ctx, gast.Load): 218 anno.setanno(node, anno.Static.DEFINITIONS, 219 tuple(analyzer.in_[cfg_node].value.get(qn, ()))) 220 else: 221 anno.setanno(node, anno.Static.DEFINITIONS, 222 tuple(analyzer.out[cfg_node].value.get(qn, ()))) 223 224 return node 225 226 def _aggregate_predecessors_defined_in(self, node): 227 preds = self.current_analyzer.graph.stmt_prev[node] 228 node_defined_in = set() 229 for p in preds: 230 node_defined_in |= set(self.current_analyzer.out[p].value.keys()) 231 anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in)) 232 233 def visit_If(self, node): 234 self._aggregate_predecessors_defined_in(node) 235 return self.generic_visit(node) 236 237 def visit_For(self, node): 238 self._aggregate_predecessors_defined_in(node) 239 240 # Manually accounting for the shortcoming described in 241 # cfg.AstToCfg.visit_For. 242 parent = self.current_cfg_node 243 self.current_cfg_node = self.current_analyzer.graph.index[node.iter] 244 node.target = self.visit(node.target) 245 self.current_cfg_node = parent 246 247 node.iter = self.visit(node.iter) 248 node.body = self.visit_block(node.body) 249 node.orelse = self.visit_block(node.orelse) 250 251 return node 252 253 def visit_While(self, node): 254 self._aggregate_predecessors_defined_in(node) 255 return self.generic_visit(node) 256 257 def visit_Try(self, node): 258 self._aggregate_predecessors_defined_in(node) 259 return self.generic_visit(node) 260 261 def visit_ExceptHandler(self, node): 262 self._aggregate_predecessors_defined_in(node) 263 # TODO(mdan): Also track the exception type / name symbols. 264 node.body = self.visit_block(node.body) 265 return node 266 267 def visit(self, node): 268 parent = self.current_cfg_node 269 270 if (self.current_analyzer is not None and 271 node in self.current_analyzer.graph.index): 272 self.current_cfg_node = self.current_analyzer.graph.index[node] 273 node = super(TreeAnnotator, self).visit(node) 274 275 self.current_cfg_node = parent 276 return node 277 278 279def resolve(node, source_info, graphs, definition_factory=Definition): 280 """Resolves reaching definitions for each symbol. 281 282 Args: 283 node: ast.AST 284 source_info: transformer.SourceInfo 285 graphs: Dict[ast.FunctionDef, cfg.Graph] 286 definition_factory: Callable[[], Definition] 287 Returns: 288 ast.AST 289 """ 290 visitor = TreeAnnotator(source_info, graphs, definition_factory) 291 node = visitor.visit(node) 292 return node 293