# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree nodes manager."""
from typing import Dict, List, Optional, Union
import ast
from .node import Node
from .node_topological_manager import TopoManager
from ..api.node_type import NodeType
from ..api.scoped_value import ScopedValue


class NodeManager:
    """
    NodeManager saves nodes and manages nodes' topological relationship.

    Nodes are kept both in a dict keyed by node name and in a linked list
    (walked via `Node.get_next()`) whose two ends are tracked by `_head` and `_tail`.
    """
    def __init__(self):
        """Initializer of NodeManager"""
        self._topo_mgr = TopoManager()
        # nodes of this manager, keyed by node name
        self._nodes: Dict[str, Node] = {}
        # namer used by insert_node() to name inserted nodes; it is None until
        # set_manager_node_namer() is called
        self._manager_node_namer = None
        # record all tree nodes, which is used when generating codes
        self._tree_nodes: List[Node] = []
        # head node is always point to the first node of nodes
        self._head: Optional[Node] = None
        # tail node is always point to the last node of nodes
        self._tail: Optional[Node] = None
        # nodes of Input type
        self._inputs: List[Node] = []
        # nodes of Output type
        self._returns: List[Node] = []
        # ast of node manager
        #   SymbolTree    -> ast.FunctionDef
        #   CallFunction  -> ast.FunctionDef
        #   ControlFlow   -> list
        #   CellContainer -> ast.Assign
        self._node_manager_ast: Union[ast.AST, list, None] = None
        # name of manager
        self._manager_name = "OriginNodeManager"

    @property
    def node_list(self):
        """
        Get node list.

        Returns:
            A new list of the nodes of this manager, ordered from head to tail.
        """
        # Call the base implementation explicitly so that the result does not
        # depend on `nodes` being overridden by a subclass.
        return NodeManager.nodes(self)

    @property
    def node_count(self):
        """Number of nodes reachable from the head node."""
        return len(NodeManager.nodes(self))

    def insert_node(self, new_node: Node, base_node: Node, before_node: bool):
        """
        Insert a node before or after base_node.

        Args:
            new_node (Node): Node to be inserted.
            base_node (Node): New node will be inserted before or after base_node. Can be None
                only when the manager has no node yet.
            before_node (bool): Indicate whether new node is inserted before base_node.

        Raises:
            ValueError: If `base_node` is None while the manager already has nodes.
            ValueError: If the name assigned to `new_node` already exists in the manager.
        """
        # update node name
        new_node_name = self._manager_node_namer.get_name(new_node)
        new_node.set_name(new_node_name)
        if isinstance(new_node, NodeManager):
            new_node.set_manager_name(new_node_name)
        # Validate and register the node before touching the linked list, so that
        # a failure leaves head/tail and the node links unchanged.
        if base_node is None and self._nodes:
            raise ValueError("base_node cannot be None when node inserted is not the first node.")
        self._add_node_to_nodes(new_node)
        # insert node to list table
        if base_node is None:
            self._head = new_node
            self._tail = new_node
        elif before_node:
            base_node.insert_before(new_node)
            if self._head == base_node:
                self._head = new_node
        else:
            base_node.insert_after(new_node)
            if self._tail == base_node:
                self._tail = new_node
        self._topo_mgr.on_insert_node(new_node)
        new_node.set_node_manager(self)
        # record Input nodes, Output nodes and tree nodes
        node_type = new_node.get_node_type()
        if node_type == NodeType.Output:
            self._returns.append(new_node)
        elif node_type == NodeType.Input:
            self._inputs.append(new_node)
        elif node_type == NodeType.Tree:
            self._tree_nodes.append(new_node)

    def erase_node(self, node: Node):
        """
        Erase a node from nodes.

        The topological manager is notified first; then, if `node` is found in this
        manager (compared by identity), head/tail are updated, the node is removed
        from the name dict and isolated. `_inputs`, `_returns` and `_tree_nodes` are
        not modified here.

        Args:
            node (Node): Node to be erased.
        """
        self._topo_mgr.on_erase_node(node)
        for key, value in self._nodes.items():
            if value is not node:
                continue
            # update self._head and self._tail
            if self._head == node:
                self._head = node.get_next()
            if self._tail == node:
                self._tail = node.get_prev()
            # erase node; the loop is left right after, so mutating the dict here is safe
            self._nodes.pop(key)
            value.isolate()
            break

    def nodes(self):
        """
        Get nodes.

        Returns:
            A list of nodes.
        """
        # If iterating nodes directly without new list, iteration may stuck caused
        # by node topology being modified during iteration.
        nodes = []
        node = self._head
        while node is not None:
            nodes.append(node)
            node = node.get_next()
        return nodes

    def get_node(self, node_name: str) -> Optional[Node]:
        """
        Get node of current NodeManager by `node_name`.

        Args:
            node_name (str): A str represents name of node as key of query.

        Returns:
            An instance of Node if found else None.
        """
        return self._nodes.get(node_name)

    def append_python_node(self, new_node: Node):
        """
        Append python node after the current tail node.

        Args:
            new_node (Node): Node to be appended.
        """
        NodeManager.insert_node(self, new_node, self._tail, False)

    def get_head(self):
        """Get head node of nodes"""
        return self._head

    def get_tail(self):
        """Get tail node of nodes"""
        return self._tail

    def reg_observer(self, observer):
        """
        Register observer to monitor code changes.

        The observer is registered on this manager's topological manager, on every
        node which is itself a NodeManager, and on the symbol tree of every node of
        Tree type.

        Args:
            observer: Observer to be registered.
        """
        self._topo_mgr.reg_observer(observer)
        for node in self.nodes():
            if isinstance(node, NodeManager):
                node.reg_observer(observer)
            if node.get_node_type() == NodeType.Tree:
                node.symbol_tree.reg_observer(observer)

    def get_tree_nodes(self):
        """
        Get tree nodes inserted into symbol tree, include nodes later erased by user.

        Returns:
            A new list with the tree nodes recorded by this manager followed by the
            tree nodes of every node which is itself a NodeManager.
        """
        tree_nodes = list(self._tree_nodes)
        for node in self.nodes():
            if isinstance(node, NodeManager):
                tree_nodes.extend(node.get_tree_nodes())
        return tree_nodes

    def set_manager_ast(self, node_manager_ast: Union[ast.AST, list]):
        """Set _node_manager_ast."""
        self._node_manager_ast = node_manager_ast

    def get_manager_ast(self):
        """Get _node_manager_ast."""
        return self._node_manager_ast

    def get_input_nodes(self):
        """Get _inputs (the internal list itself, not a copy)."""
        return self._inputs

    def get_returns(self):
        """Get _returns (the internal list itself, not a copy)."""
        return self._returns

    def set_manager_name(self, name: str):
        """Set _manager_name"""
        self._manager_name = name

    def get_manager_name(self):
        """Get _manager_name"""
        return self._manager_name

    def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue):
        """
        Update node topological when node arg is modified.

        Args:
            node (Node): Node whose argument is modified.
            arg_idx (int): Index of the modified argument.
            old_arg (ScopedValue): Argument before modification.
            new_arg (ScopedValue): Argument after modification.
        """
        self._topo_mgr.on_update_arg(node, arg_idx, old_arg, new_arg)

    def on_update_arg_by_node(self, dst_node: Node, arg_idx: int, src_node: Node, out_idx: int):
        """
        Update node topological when node arg is modified by another node.

        Args:
            dst_node (Node): Node whose argument is modified.
            arg_idx (int): Index of the modified argument of `dst_node`.
            src_node (Node): Node providing the new argument.
            out_idx (int): Index of the output of `src_node` used as the new argument.
        """
        self._topo_mgr.on_update_arg_by_node(dst_node, arg_idx, src_node, out_idx)

    def dump(self, title="") -> str:
        """
        Dump topological relation.

        Args:
            title (str): A string as a title will be printed before dumping topological relation.

        Returns:
            A str containing the title and a table with one row per node, or an empty
            str when the `tabulate` package cannot be imported.
        """
        try:
            from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
        except ImportError:
            return ""
        node_specs = []
        for n in NodeManager.nodes(self):
            node_type = n.get_node_type()
            node_name = n.get_name()
            source_code = n.get_source_code()
            # each provider is shown as (provider node name, index), or () when absent
            arg_providers = [
                [key, ((value[0].get_name(), value[1]) if value else ())]
                for key, value in n.get_arg_providers().items()
            ]
            # each target maps to a list of (user node name, index) entries
            target_users = [
                [key, [(val[0].get_name(), val[1]) if val else () for val in value] if value else []]
                for key, value in n.get_target_users().items()
            ]
            node_specs.append([node_type, node_name, source_code, arg_providers, target_users])
        dump_str = f"\n[{title}]\n"
        dump_str += tabulate(node_specs, headers=['node type', 'name', 'codes', 'arg providers', 'target users'])
        dump_str += '\n'
        return dump_str

    def get_top_manager(self) -> 'NodeManager':
        """
        Get the top node_manager with type of no-method CallFunction or SymbolTree this
        node_manager belongs to.

        Returns:
            This manager if it is a SymbolTree or a CallFunction which is not a method,
            otherwise the top manager of the manager returned by `get_node_manager()`.
        """
        # NOTE(review): imported locally, presumably to avoid circular imports - confirm.
        from .call_function import CallFunction
        from ..symbol_tree import SymbolTree
        if isinstance(self, SymbolTree):
            return self
        if isinstance(self, CallFunction) and not self.is_method():
            return self
        # `get_node_manager` is not defined in this class; it is expected to be
        # provided by the concrete class of `self`.
        return self.get_node_manager().get_top_manager()

    def set_manager_node_namer(self, node_namer):
        """Set manager node namer"""
        self._manager_node_namer = node_namer

    def _add_node_to_nodes(self, node: Node):
        """
        Add `node` to `_nodes` dict.

        Args:
            node (Node): A Node to be added into `_nodes`.

        Raises:
            ValueError: If name of the node is duplicated.
        """
        node_name = node.get_name()
        if self._nodes.get(node_name) is not None:
            owner_name = self.get_name() if isinstance(self, Node) else 'construct'
            raise ValueError(f"Duplicated node name: {node_name} in {owner_name}")
        self._nodes[node_name] = node