import typing as t

from . import nodes
from .visitor import NodeVisitor

# Load instructions. Every value in ``Symbols.loads`` is an
# ``(instruction, argument)`` tuple whose first element is one of these:
#
# * VAR_LOAD_PARAMETER -- recorded by ``Symbols.declare_parameter``;
#   the argument is None.
# * VAR_LOAD_RESOLVE   -- recorded by ``Symbols.load`` (and by
#   ``Symbols.branch_update``); the argument is the template name to resolve.
# * VAR_LOAD_ALIAS     -- the identifier aliases an identifier of an
#   enclosing scope; the argument is that outer identifier.
# * VAR_LOAD_UNDEFINED -- recorded by ``Symbols.store`` when no enclosing
#   scope references the name; the argument is None.
VAR_LOAD_PARAMETER = "param"
VAR_LOAD_RESOLVE = "resolve"
VAR_LOAD_ALIAS = "alias"
VAR_LOAD_UNDEFINED = "undefined"


def find_symbols(
    nodes: t.Iterable[nodes.Node], parent_symbols: t.Optional["Symbols"] = None
) -> "Symbols":
    """Collect the symbols used by ``nodes`` into a new scope.

    Each node is visited with a :class:`FrameSymbolVisitor`.

    :param nodes: The nodes to analyze. Inside this function the parameter
        shadows the ``nodes`` module; the annotation is evaluated at
        definition time and therefore still refers to the module.
    :param parent_symbols: The enclosing scope, or None for a root scope.
    :return: The newly created :class:`Symbols`.
    """
    sym = Symbols(parent=parent_symbols)
    visitor = FrameSymbolVisitor(sym)
    for node in nodes:
        visitor.visit(node)
    return sym


def symbols_for_node(
    node: nodes.Node, parent_symbols: t.Optional["Symbols"] = None
) -> "Symbols":
    """Create a new scope and populate it via :meth:`Symbols.analyze_node`.

    :param node: The node the scope is created for.
    :param parent_symbols: The enclosing scope, or None for a root scope.
    :return: The newly created :class:`Symbols`.
    """
    sym = Symbols(parent=parent_symbols)
    sym.analyze_node(node)
    return sym


class Symbols:
    """Name bookkeeping for a single scope level.

    * ``refs`` maps a template name to the identifier generated for it in
      this scope (``l_<level>_<name>``).
    * ``loads`` maps such identifiers to an ``(instruction, argument)``
      load instruction (see the ``VAR_LOAD_*`` constants).
    * ``stores`` is the set of template names assigned in this scope.

    Reference and load lookups that miss in this scope fall back to
    ``parent``.
    """

    def __init__(
        self, parent: t.Optional["Symbols"] = None, level: t.Optional[int] = None
    ) -> None:
        """:param parent: Enclosing scope, or None for a root scope.
        :param level: Nesting level used in generated identifiers. Defaults
            to 0 for a root scope and ``parent.level + 1`` otherwise.
        """
        if level is None:
            if parent is None:
                level = 0
            else:
                level = parent.level + 1

        self.level: int = level
        self.parent = parent
        self.refs: t.Dict[str, str] = {}
        self.loads: t.Dict[str, t.Any] = {}
        self.stores: t.Set[str] = set()

    def analyze_node(self, node: nodes.Node, **kwargs: t.Any) -> None:
        """Populate this scope by visiting ``node`` with a :class:`RootVisitor`.

        ``kwargs`` are forwarded to the visitor (e.g. ``for_branch`` for
        ``For`` nodes).
        """
        visitor = RootVisitor(self)
        visitor.visit(node, **kwargs)

    def _define_ref(
        self, name: str, load: t.Optional[t.Tuple[str, t.Optional[str]]] = None
    ) -> str:
        """Bind ``name`` to a fresh identifier in this scope.

        :param name: The template name.
        :param load: Optional load instruction to record for the identifier.
        :return: The generated identifier (``l_<level>_<name>``).
        """
        ident = f"l_{self.level}_{name}"
        self.refs[name] = ident
        if load is not None:
            self.loads[ident] = load
        return ident

    def find_load(self, target: str) -> t.Optional[t.Any]:
        """Return the load instruction for identifier ``target``, searching
        this scope and then its ancestors; None if not found."""
        if target in self.loads:
            return self.loads[target]

        if self.parent is not None:
            return self.parent.find_load(target)

        return None

    def find_ref(self, name: str) -> t.Optional[str]:
        """Return the identifier bound to ``name`` in the nearest scope that
        defines it (this scope first, then ancestors); None if unknown."""
        if name in self.refs:
            return self.refs[name]

        if self.parent is not None:
            return self.parent.find_ref(name)

        return None

    def ref(self, name: str) -> str:
        """Like :meth:`find_ref`, but ``name`` must be known.

        :raises AssertionError: If no scope in the chain defines ``name``.
        """
        rv = self.find_ref(name)
        if rv is None:
            raise AssertionError(
                "Tried to resolve a name to a reference that was"
                f" unknown to the frame ({name!r})"
            )
        return rv

    def copy(self) -> "Symbols":
        """Return a copy of this scope.

        ``refs``, ``loads`` and ``stores`` are copied so the copy can be
        mutated independently; ``parent`` and ``level`` are shared.
        """
        rv = object.__new__(self.__class__)
        rv.__dict__.update(self.__dict__)
        rv.refs = self.refs.copy()
        rv.loads = self.loads.copy()
        rv.stores = self.stores.copy()
        return rv

    def store(self, name: str) -> None:
        """Record an assignment to ``name`` in this scope."""
        self.stores.add(name)

        # If we have not seen the name referenced yet, we need to figure
        # out what to set it to.
        if name not in self.refs:
            # If there is a parent scope we check if the name has a
            # reference there. If it does it means we might have to alias
            # to a variable there.
            if self.parent is not None:
                outer_ref = self.parent.find_ref(name)
                if outer_ref is not None:
                    self._define_ref(name, load=(VAR_LOAD_ALIAS, outer_ref))
                    return

            # Otherwise we can just set it to undefined.
            self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))

    def declare_parameter(self, name: str) -> str:
        """Record ``name`` as a parameter of this scope and return its
        identifier. Parameters also count as stores."""
        self.stores.add(name)
        return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))

    def load(self, name: str) -> None:
        """Record a read of ``name``.

        If no scope in the chain knows the name yet, it is bound here with a
        ``VAR_LOAD_RESOLVE`` instruction.
        """
        if self.find_ref(name) is None:
            self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))

    def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
        """Merge the scopes of mutually exclusive branches into this scope.

        ``branch_symbols`` are expected to be copies of this scope (see
        :meth:`copy`), one per branch. A name newly stored in only some of
        the branches still needs a value on the other paths. Its identifier
        therefore gets a fallback load: an alias of the enclosing scope's
        identifier when one exists, otherwise a resolve of the name.
        """
        # Count, per newly stored name, how many branches assign it.
        stores: t.Dict[str, int] = {}
        for branch in branch_symbols:
            for target in branch.stores:
                if target in self.stores:
                    continue
                stores[target] = stores.get(target, 0) + 1

        for sym in branch_symbols:
            self.refs.update(sym.refs)
            self.loads.update(sym.loads)
            self.stores.update(sym.stores)

        for name, branch_count in stores.items():
            # Assigned on every branch: no fallback needed.
            if branch_count == len(branch_symbols):
                continue

            target = self.find_ref(name)  # type: ignore
            assert target is not None, "should not happen"

            if self.parent is not None:
                outer_target = self.parent.find_ref(name)
                if outer_target is not None:
                    self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
                    continue
            self.loads[target] = (VAR_LOAD_RESOLVE, name)

    def dump_stores(self) -> t.Dict[str, str]:
        """Map every name stored in this scope or any ancestor to the
        identifier it resolves to from this scope."""
        rv: t.Dict[str, str] = {}
        node: t.Optional["Symbols"] = self

        while node is not None:
            for name in sorted(node.stores):
                if name not in rv:
                    rv[name] = self.find_ref(name)  # type: ignore

            node = node.parent

        return rv

    def dump_param_targets(self) -> t.Set[str]:
        """Return the identifiers loaded as parameters in this scope and
        in every enclosing scope."""
        rv: t.Set[str] = set()
        node: t.Optional["Symbols"] = self

        while node is not None:
            # Inspect each scope's own loads so that walking the parent
            # chain actually picks up parameters of enclosing scopes.
            for target, (instr, _) in node.loads.items():
                if instr == VAR_LOAD_PARAMETER:
                    rv.add(target)

            node = node.parent

        return rv


class RootVisitor(NodeVisitor):
    """Visit the node a scope is created for.

    Only the parts of the node that belong to the new scope are visited.
    Symbols are recorded through a :class:`FrameSymbolVisitor`. Node types
    without a matching ``visit_*`` method raise ``NotImplementedError``.
    """

    def __init__(self, symbols: "Symbols") -> None:
        self.sym_visitor = FrameSymbolVisitor(symbols)

    def _simple_visit(self, node: nodes.Node, **kwargs: t.Any) -> None:
        # Visit every direct child of the node.
        for child in node.iter_child_nodes():
            self.sym_visitor.visit(child)

    visit_Template = _simple_visit
    visit_Block = _simple_visit
    visit_Macro = _simple_visit
    visit_FilterBlock = _simple_visit
    visit_Scope = _simple_visit
    visit_If = _simple_visit
    visit_ScopedEvalContextModifier = _simple_visit

    def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
        """Visit only the body; the target is handled by the outer scope."""
        for child in node.body:
            self.sym_visitor.visit(child)

    def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
        """Visit all children except ``call``, which the outer scope
        visits."""
        for child in node.iter_child_nodes(exclude=("call",)):
            self.sym_visitor.visit(child)

    def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
        """Visit only the body of the overlay scope."""
        for child in node.body:
            self.sym_visitor.visit(child)

    def visit_For(
        self, node: nodes.For, for_branch: str = "body", **kwargs: t.Any
    ) -> None:
        """Visit one branch of a ``For`` node.

        :param for_branch: ``"body"`` declares the loop target as parameters
            and visits the body; ``"else"`` visits the else body;
            ``"test"`` declares the loop target as parameters and visits the
            filter test, if any.
        :raises RuntimeError: For any other ``for_branch`` value.
        """
        if for_branch == "body":
            self.sym_visitor.visit(node.target, store_as_param=True)
            branch = node.body
        elif for_branch == "else":
            branch = node.else_
        elif for_branch == "test":
            self.sym_visitor.visit(node.target, store_as_param=True)
            if node.test is not None:
                self.sym_visitor.visit(node.test)
            return
        else:
            raise RuntimeError("Unknown for branch")

        if branch:
            for item in branch:
                self.sym_visitor.visit(item)

    def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
        """Visit the targets and then the body; the values are visited by
        the outer scope."""
        for target in node.targets:
            self.sym_visitor.visit(target)
        for child in node.body:
            self.sym_visitor.visit(child)

    def generic_visit(self, node: nodes.Node, *args: t.Any, **kwargs: t.Any) -> None:
        raise NotImplementedError(f"Cannot find symbols for {type(node).__name__!r}")


class FrameSymbolVisitor(NodeVisitor):
    """Record the symbols of the current frame into a :class:`Symbols`.

    Visiting stops at nodes that open their own scope (for loops, call and
    filter blocks, block assignments, scopes, blocks, overlay scopes). Only
    the parts evaluated in the current frame are visited.
    """

    def __init__(self, symbols: "Symbols") -> None:
        self.symbols = symbols

    def visit_Name(
        self, node: nodes.Name, store_as_param: bool = False, **kwargs: t.Any
    ) -> None:
        """Record a name according to its context.

        Parameter declaration when ``store_as_param`` is true or
        ``ctx == "param"``, a store for ``"store"``, a load for ``"load"``.
        Any other context is ignored.
        """
        if store_as_param or node.ctx == "param":
            self.symbols.declare_parameter(node.name)
        elif node.ctx == "store":
            self.symbols.store(node.name)
        elif node.ctx == "load":
            self.symbols.load(node.name)

    def visit_NSRef(self, node: nodes.NSRef, **kwargs: t.Any) -> None:
        """A namespace reference reads the namespace variable itself."""
        self.symbols.load(node.name)

    def visit_If(self, node: nodes.If, **kwargs: t.Any) -> None:
        """Visit each branch in its own copy of the scope, then merge them
        with :meth:`Symbols.branch_update`."""
        self.visit(node.test, **kwargs)
        original_symbols = self.symbols

        def inner_visit(branch_nodes: t.Iterable[nodes.Node]) -> "Symbols":
            # Temporarily redirect recording into a copy of the scope.
            self.symbols = rv = original_symbols.copy()

            for subnode in branch_nodes:
                self.visit(subnode, **kwargs)

            self.symbols = original_symbols
            return rv

        body_symbols = inner_visit(node.body)
        elif_symbols = inner_visit(node.elif_)
        else_symbols = inner_visit(node.else_ or ())
        self.symbols.branch_update([body_symbols, elif_symbols, else_symbols])

    def visit_Macro(self, node: nodes.Macro, **kwargs: t.Any) -> None:
        """Store the macro name; the macro body is not visited here."""
        self.symbols.store(node.name)

    def visit_Import(self, node: nodes.Import, **kwargs: t.Any) -> None:
        """Visit the children, then store the import target name."""
        self.generic_visit(node, **kwargs)
        self.symbols.store(node.target)

    def visit_FromImport(self, node: nodes.FromImport, **kwargs: t.Any) -> None:
        """Visit the children, then store each imported name.

        For ``(name, alias)`` tuples the alias is stored.
        """
        self.generic_visit(node, **kwargs)

        for name in node.names:
            if isinstance(name, tuple):
                self.symbols.store(name[1])
            else:
                self.symbols.store(name)

    def visit_Assign(self, node: nodes.Assign, **kwargs: t.Any) -> None:
        """Visit the value before the target.

        This records names read by the value expression before the target
        is stored.
        """
        self.visit(node.node, **kwargs)
        self.visit(node.target, **kwargs)

    def visit_For(self, node: nodes.For, **kwargs: t.Any) -> None:
        """Visiting stops at for blocks. However the block sequence
        is visited as part of the outer scope.
        """
        self.visit(node.iter, **kwargs)

    def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
        """Only the call expression belongs to the current frame."""
        self.visit(node.call, **kwargs)

    def visit_FilterBlock(self, node: nodes.FilterBlock, **kwargs: t.Any) -> None:
        """Only the filter expression belongs to the current frame."""
        self.visit(node.filter, **kwargs)

    def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
        """Only the values are evaluated in the current frame; the targets
        and body belong to the inner scope."""
        for value in node.values:
            self.visit(value)

    def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
        """Stop visiting at block assigns."""
        self.visit(node.target, **kwargs)

    def visit_Scope(self, node: nodes.Scope, **kwargs: t.Any) -> None:
        """Stop visiting at scopes."""

    def visit_Block(self, node: nodes.Block, **kwargs: t.Any) -> None:
        """Stop visiting at blocks."""

    def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
        """Do not visit into overlay scopes."""