# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An analysis that determines the reach of a function definition.

A function definition is said to reach a statement if that function may exist
(and therefore may be called) when that statement executes.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer


class Definition(object):
  """Definition objects describe a unique definition of a function."""

  def __init__(self, def_node):
    self.def_node = def_node


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type: `|` and `+` return a new instance and never mutate
  their operands. Only the strictly necessary operators are implemented.

  Attributes:
    value: Set[Union[gast.FunctionDef, gast.Lambda]], the function definition
      nodes that may reach the current point of the walk.
  """

  def __init__(self, init_from=None):
    # Copy, so the new state never aliases the caller's collection.
    if init_from:
      self.value = set(init_from)
    else:
      self.value = set()

  def __eq__(self, other):
    return self.value == other.value

  def __ne__(self, other):
    return self.value != other.value

  def __or__(self, other):
    """Returns a new state holding the union of both states' definitions."""
    assert isinstance(other, _NodeState)
    result = _NodeState(self.value)
    result.value.update(other.value)
    return result

  def __add__(self, value):
    """Returns a new state that additionally contains the definition `value`."""
    result = _NodeState(self.value)
    result.value.add(value)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching function definitions per statement.

  Args:
    graph: cfg.Graph, the control flow graph to analyze.
    external_defs: iterable of function definition nodes considered to reach
      the graph's entry from outside (e.g. those a nested function closes
      over).
  """

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # This allows communicating that nodes have extra reaching definitions,
    # e.g. those that a function closes over.
    self.external_defs = external_defs

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the in/out states of `node`.

    Returns:
      True if the out state of `node` changed since its previous visit.
    """
    prev_defs_out = self.out[node]

    if node is self.graph.entry:
      defs_in = _NodeState(self.external_defs)
    else:
      # Seeding from the previous output keeps the state monotonically
      # growing across repeated visits of the same node.
      defs_in = prev_defs_out

    for n in node.prev:
      # _NodeState has no __ior__, so `|=` falls back to __or__ and rebinds
      # `defs_in` to a new object; `prev_defs_out` is never mutated, which
      # keeps the change check below meaningful.
      defs_in |= self.out[n]

    defs_out = defs_in
    if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
      defs_out += node.ast_node

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates nodes with their reaching function definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here
        pass
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.allow_skips = False
    self.current_analyzer = None

  def _process_function(self, node):
    """Analyzes a function or lambda, then visits its body with that analysis.

    The definitions reaching `node` in the enclosing function's analysis (if
    any) are passed to the nested analysis as external definitions, which is
    how closures are modeled.

    Args:
      node: the gast.FunctionDef or gast.Lambda node; must be a key of
        `self.graphs`.

    Returns:
      The node returned by `generic_visit`.
    """
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    if (parent_analyzer is not None
        and node in parent_analyzer.graph.index):
      cfg_node = parent_analyzer.graph.index[node]
      defined_in = parent_analyzer.in_[cfg_node].value
    else:
      defined_in = ()

    analyzer = Analyzer(subgraph, defined_in)
    analyzer.visit_forward()

    self.current_analyzer = analyzer
    try:
      node = self.generic_visit(node)
    finally:
      # Restore the enclosing analyzer even if visiting the body raises, so
      # the visitor is never left pointing at the nested function's analysis.
      self.current_analyzer = parent_analyzer
    return node

  # Backward-compatible alias for the original (misspelled) method name.
  _proces_function = _process_function

  def visit_FunctionDef(self, node):
    return self._process_function(node)

  def visit_Lambda(self, node):
    return self._process_function(node)

  def visit(self, node):
    analyzer = self.current_analyzer
    # The analyzer is None before entering the top level function; nodes that
    # are not part of the current CFG are not annotated.
    if analyzer is not None and node in analyzer.graph.index:
      cfg_node = analyzer.graph.index[node]
      anno.setanno(node, anno.Static.DEFINED_FNS_IN,
                   analyzer.in_[cfg_node].value)

      # A node may carry a separate CFG node under EXTRA_LOOP_TEST
      # (presumably a loop's test); annotate it from its own CFG node. This
      # must stay under the analyzer check above, since it dereferences it.
      extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
      if extra_node is not None:
        cfg_node = analyzer.graph.index[extra_node]
        anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN,
                     analyzer.in_[cfg_node].value)

    return super(TreeAnnotator, self).visit(node)


def resolve(node, source_info, graphs):
  """Annotates AST nodes with the function definitions that reach them.

  Each node that appears in the CFG of its enclosing function receives an
  `anno.Static.DEFINED_FNS_IN` annotation: the set of function definition
  nodes that may exist when that node executes.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[Union[ast.FunctionDef, ast.Lambda], cfg.Graph]; must contain
      a graph for every function and lambda visited in `node`.
  Returns:
    ast.AST
  """
  visitor = TreeAnnotator(source_info, graphs)
  node = visitor.visit(node)
  return node