• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree nodes manager."""
import ast
from typing import Dict, List, Optional, Union

from .node import Node
from .node_topological_manager import TopoManager
from ..api.node_type import NodeType
from ..api.scoped_value import ScopedValue


class NodeManager:
    """
    NodeManager saves nodes and manages nodes' topological relationship.

    Nodes are held in two structures that insert/erase keep in sync: a linked list (walked from `_head` through
    `get_next()`) that defines node order, and the `_nodes` dict that maps node names to nodes.
    """
    def __init__(self):
        """Initializer of NodeManager."""
        self._topo_mgr = TopoManager()
        # node name -> node; used by `get_node` and for duplicate-name detection
        self._nodes: Dict[str, Node] = {}
        # namer used by `insert_node` to assign names to new nodes, set by `set_manager_node_namer`
        self._manager_node_namer = None
        # record all tree nodes, which is used when generating codes
        self._tree_nodes: List[Node] = []
        # head node always points to the first node of nodes
        self._head: Optional[Node] = None
        # tail node always points to the last node of nodes
        self._tail: Optional[Node] = None
        # nodes of Input type
        self._inputs: List[Node] = []
        # nodes of Output type
        self._returns: List[Node] = []
        # ast of node manager
        # SymbolTree    -> ast.FunctionDef
        # CallFunction  -> ast.FunctionDef
        # ControlFlow   -> list
        # CellContainer -> ast.Assign
        self._node_manager_ast: Optional[Union[ast.AST, list]] = None
        # name of manager
        self._manager_name = "OriginNodeManager"

    @property
    def node_list(self):
        """Get a list of nodes in linked-list order, starting from the head node."""
        return list(self._iter_linked_nodes())

    @property
    def node_count(self):
        """Number of nodes reachable from the head node."""
        return sum(1 for _ in self._iter_linked_nodes())

    def insert_node(self, new_node: Node, base_node: Node, before_node: bool):
        """
        Insert a node before or after base_node.

        The node is first renamed by the manager's node namer, then linked into the node list, registered by name,
        reported to the topological manager, and recorded as an Input / Output / Tree node according to its type.

        Args:
            new_node (Node): Node to be inserted.
            base_node (Node): New node will be inserted before or after base_node. Can be None only when the manager
                has no node yet, in which case new_node becomes both head and tail.
            before_node (bool): Indicate whether new node is inserted before base_node.

        Raises:
            ValueError: If base_node is None while the manager already has nodes.
            ValueError: If the name given to new_node is already used by a node of this manager.
        """
        # update node name
        new_node_name = self._manager_node_namer.get_name(new_node)
        new_node.set_name(new_node_name)
        if isinstance(new_node, NodeManager):
            new_node.set_manager_name(new_node_name)
        # Validate everything before touching the linked list, so that a rejected node does not leave the list
        # (and head/tail) modified while the node is missing from `_nodes`.
        if base_node is None and self._nodes:
            raise ValueError("base_node cannot be None when node inserted is not the first node.")
        self._check_node_name_unique(new_node.get_name())
        # insert node to list table
        if base_node is None:
            self._head = new_node
            self._tail = new_node
        elif before_node:
            base_node.insert_before(new_node)
            if self._head == base_node:
                self._head = new_node
        else:
            base_node.insert_after(new_node)
            if self._tail == base_node:
                self._tail = new_node
        self._add_node_to_nodes(new_node)
        self._topo_mgr.on_insert_node(new_node)
        new_node.set_node_manager(self)
        # record Input nodes, Output nodes and tree nodes
        node_type = new_node.get_node_type()
        if node_type == NodeType.Output:
            self._returns.append(new_node)
        elif node_type == NodeType.Input:
            self._inputs.append(new_node)
        elif node_type == NodeType.Tree:
            self._tree_nodes.append(new_node)

    def erase_node(self, node: Node):
        """
        Erase a node from nodes.

        The topological manager is always notified. If `node` is held by this manager, head/tail are moved off it,
        it is removed from the name table and isolated; otherwise nothing else happens.

        Args:
            node (Node): Node to be erased.
        """
        self._topo_mgr.on_erase_node(node)
        # Matched by identity, not by looking up the node's current name as a key.
        for key, value in self._nodes.items():
            if value is node:
                # update self._head and self._tail
                if self._head == node:
                    self._head = node.get_next()
                if self._tail == node:
                    self._tail = node.get_prev()
                # erase node; safe inside the loop because iteration stops right after
                self._nodes.pop(key)
                value.isolate()
                break

    def nodes(self):
        """
        Get nodes.

        Returns:
            A list of nodes in linked-list order, starting from the head node.
        """
        # Return a snapshot list instead of a live iterator: iterating the linked list directly may get stuck
        # when node topology is modified during iteration.
        return list(self._iter_linked_nodes())

    def get_node(self, node_name: str) -> Optional[Node]:
        """
        Get node of current NodeManager by `node_name`.

        Args:
            node_name (str): A str represents name of node as key of query.

        Returns:
            An instance of Node if found else None.
        """
        return self._nodes.get(node_name)

    def append_python_node(self, new_node: Node):
        """Append python node after the current tail node."""
        # Calls NodeManager.insert_node explicitly, bypassing any subclass override of insert_node.
        NodeManager.insert_node(self, new_node, self._tail, False)

    def get_head(self):
        """Get head node of nodes."""
        return self._head

    def get_tail(self):
        """Get tail node of nodes."""
        return self._tail

    def reg_observer(self, observer):
        """
        Register observer to monitor code changes.

        The observer is registered on this manager's topological manager and, recursively, on every child node
        that is itself a NodeManager and on the symbol tree of every Tree-type node.
        """
        self._topo_mgr.reg_observer(observer)
        for node in self.nodes():
            if isinstance(node, NodeManager):
                node.reg_observer(observer)
            if node.get_node_type() == NodeType.Tree:
                node.symbol_tree.reg_observer(observer)

    def get_tree_nodes(self):
        """Get tree nodes inserted into symbol tree, include nodes later erased by user."""
        tree_nodes = list(self._tree_nodes)
        for node in self.nodes():
            if isinstance(node, NodeManager):
                tree_nodes.extend(node.get_tree_nodes())
        return tree_nodes

    def set_manager_ast(self, node_manager_ast: Union[ast.AST, list]):
        """Set _node_manager_ast."""
        self._node_manager_ast = node_manager_ast

    def get_manager_ast(self):
        """Get _node_manager_ast."""
        return self._node_manager_ast

    def get_input_nodes(self):
        """Get _inputs."""
        return self._inputs

    def get_returns(self):
        """Get _returns."""
        return self._returns

    def set_manager_name(self, name: str):
        """Set _manager_name."""
        self._manager_name = name

    def get_manager_name(self):
        """Get _manager_name."""
        return self._manager_name

    def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue):
        """
        Update node topological when node arg is modified.
        """
        self._topo_mgr.on_update_arg(node, arg_idx, old_arg, new_arg)

    def on_update_arg_by_node(self, dst_node: Node, arg_idx: int, src_node: Node, out_idx: int):
        """
        Update node topological when node arg is modified by another node.
        """
        self._topo_mgr.on_update_arg_by_node(dst_node, arg_idx, src_node, out_idx)

    def dump(self, title="") -> str:
        """
        Dump topological relation.

        Args:
            title (str): A title placed before the dumped table. Default: "".

        Returns:
            str, a table listing node type, name, source code, arg providers and target users of every node,
            preceded by `title`. An empty string if the optional `tabulate` package is not installed.
        """
        try:
            from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
        except ImportError:
            return ""
        dump_str = f"\n[{title}]\n"
        node_specs = [[
            n.get_node_type(),
            n.get_name(),
            n.get_source_code(),
            [[key, ((value[0].get_name(), value[1]) if value else ())]
             for key, value in n.get_arg_providers().items()],
            [[
                key,
                [(val[0].get_name(), val[1]) if val else ()
                 for val in value] if value else []
            ] for key, value in n.get_target_users().items()]
        ] for n in NodeManager.nodes(self)]
        dump_str += tabulate(node_specs, headers=['node type', 'name', 'codes', 'arg providers', 'target users'])
        dump_str += '\n'
        return dump_str

    def get_top_manager(self) -> 'NodeManager':
        """
        Get the top node_manager with type of no-method CallFunction or SymbolTree this
        node_manager belongs to.
        """
        # Imported locally, presumably to avoid an import cycle with these modules.
        from .call_function import CallFunction
        from ..symbol_tree import SymbolTree
        if isinstance(self, SymbolTree):
            return self
        if isinstance(self, CallFunction) and not self.is_method():
            return self
        return self.get_node_manager().get_top_manager()

    def set_manager_node_namer(self, node_namer):
        """Set manager node namer."""
        self._manager_node_namer = node_namer

    def _iter_linked_nodes(self):
        """Yield nodes by following `get_next()` from the head node until None is reached."""
        node = self._head
        while node is not None:
            yield node
            node = node.get_next()

    def _check_node_name_unique(self, node_name: str):
        """
        Check that `node_name` is not used by any node of this manager yet.

        Args:
            node_name (str): Name to check.

        Raises:
            ValueError: If name of the node is duplicated.
        """
        if node_name in self._nodes:
            owner = self.get_name() if isinstance(self, Node) else 'construct'
            raise ValueError(f"Duplicated node name: {node_name} in {owner}")

    def _add_node_to_nodes(self, node: Node):
        """
        Add `node` to `_nodes` dict.

        Args:
            node (Node): A Node to be added into `_nodes`.

        Raises:
            ValueError: If name of the node is duplicated.
        """
        node_name = node.get_name()
        self._check_node_name_unique(node_name)
        self._nodes[node_name] = node