• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""SymbolTree nodes topological relationship manager."""
16from typing import Tuple
17from mindspore import log as logger
18from .node import Node
19from ..api.scoped_value import ScopedValue
20from ..common.observable import Observable
21from ..common.event import Event
22
23
class TopoManager(Observable):
    """
    SymbolTree topological-relationship manager.

    Keeps two pieces of per-node bookkeeping consistent while the SymbolTree is edited:

    - target users: for each target index of a node, a list of ``(user_node, arg_index)`` pairs consuming that target;
    - arg providers: for each argument index of a node, the ``(provider_node, target_index)`` pair producing that
      argument, or an empty tuple / missing entry when no node in the tree produces it.

    Providers are resolved by name: an argument is provided by the nearest previous node (walking ``get_prev()``)
    that has a target equal to the argument.
    """

    @staticmethod
    def on_update_target(node: Node, index: int, old_target: ScopedValue, new_target: ScopedValue):
        """
        Update node's dicts while updating target of node.

        The current users of target ``index`` are handed back to the nearest previous node providing ``old_target``
        (or lose their provider if there is none). Then ``node`` takes over the users of ``new_target`` that follow
        it. Being a static method, it does not call ``topo_changed``.

        Args:
            node (Node): An instance of Node whose target being updated.
            index (int): An int indicates which target of node being updated.
            old_target (ScopedValue): An instance of ScopedValue represents old target.
            new_target (ScopedValue): An instance of ScopedValue represents new target.
        """
        target_users = node.get_target_users(index)
        # Re-link current users to the previous provider of old_target; `_get_value_provider` returns () when there
        # is none, which is exactly the "no provider" value stored on the users in that case.
        old_provider = TopoManager._get_value_provider(node, old_target)
        for user in target_users:
            if old_provider:
                old_provider[0].append_target_users(old_provider[1], user)
            user[0].set_arg_providers(user[1], old_provider)
        # Update new_target node's target_users dict & new user nodes' arg_providers dict
        target_users.clear()
        provider = TopoManager._get_value_provider(node, new_target)
        if provider:
            TopoManager._update_target_users_by_node(node, index, provider)
        else:
            TopoManager._update_target_users_by_value(node, index, new_target)

    @staticmethod
    def _update_target_users_by_value(node, index, value: ScopedValue):
        """
        Update node's _target_users by ScopedValue when insert a new node.
        This function is called when target is not found in previous nodes, which means a new target name is set.

        Walks forward from ``node``; every argument equal to ``value`` becomes a user of ``node``'s target ``index``.
        The walk stops after the first later node that also has ``value`` as a target. That node's own arguments are
        still linked, because arguments are checked before targets.
        """
        search_node = node.get_next()
        while search_node is not None:
            normalized_args = search_node.get_normalized_args()
            if normalized_args is not None:
                for arg_index, arg in enumerate(normalized_args.values()):
                    if arg == value:
                        node.append_target_users(index, (search_node, arg_index))
                        search_node.set_arg_providers(arg_index, (node, index))
            targets = search_node.get_targets()
            if targets is not None and any(target == value for target in targets):
                # value is redefined here; later uses belong to search_node, not to node.
                return
            search_node = search_node.get_next()

    @staticmethod
    def _update_target_users_by_node(node, index, provider: Tuple[Node, int]):
        """
        Update node's _target_users by previous node when insert a new node.
        This function is called when target is found in previous nodes, which means a repeat target name is set.

        Users of ``provider`` located after ``node`` are moved to ``node``'s target ``index``. Users located between
        ``provider`` and ``node`` (including ``node`` itself) stay with ``provider``.
        """
        # Args of nodes which are between node and provider should not be changed
        # [last provider] -> no change args -> [insert node] -> need change args -> [next provider] -> no change args
        nodes_before_insert = []
        search_node = provider[0].get_next()
        while search_node is not None:
            nodes_before_insert.append(search_node)
            if search_node == node:
                break
            search_node = search_node.get_next()
        provider_target_users = provider[0].get_target_users(provider[1])
        # Iterate over a snapshot: moved users are removed from the provider's list in place.
        for user in list(provider_target_users):
            if user[0] not in nodes_before_insert:
                node.append_target_users(index, user)
                provider_target_users.remove(user)
                user[0].set_arg_providers(user[1], (node, index))

    @staticmethod
    def _get_value_provider(node, value: ScopedValue):
        """
        Find the nearest node before ``node`` that has a target equal to ``value``.

        Args:
            node (Node): Node to start searching from (exclusive).
            value (ScopedValue): Target value to look for.

        Returns:
            Tuple[Node, int]: ``(provider_node, target_index)`` of the match, or an empty tuple if no previous node
            has such a target.
        """
        prev_node = node.get_prev()
        while prev_node is not None:
            targets = prev_node.get_targets()
            if targets is not None:
                for target_index, target in enumerate(targets):
                    if target == value:
                        return (prev_node, target_index)
            prev_node = prev_node.get_prev()
        return ()

    @staticmethod
    def _detach_user(provider, user):
        """Remove ``user`` (a ``(node, arg_index)`` pair) from ``provider``'s target-user list, if present."""
        if not provider:
            return
        users = provider[0].get_target_users(provider[1])
        if user in users:
            users.remove(user)

    def topo_changed(self):
        """
        Raise ``Event.TopologicalChangeEvent`` through ``Observable.changed``.

        Called at the end of every instance-level topology update of this class.
        """
        self.changed(Event.TopologicalChangeEvent)

    def on_insert_node(self, node: Node):
        """
        Update provider dict and consumer dict while inserting node into SymbolTree and update inputs of node by updated
        provider dict and consumer dict.

        Args:
            node (Node): An instance of Node which been inserted into SymbolTree.
        """
        # Link each argument of the new node to its nearest previous provider.
        normalized_args = node.get_normalized_args()
        if normalized_args is not None:
            for index, arg in enumerate(normalized_args.values()):
                provider = TopoManager._get_value_provider(node, arg)
                if provider:
                    node.set_arg_providers(index, provider)
                    provider[0].append_target_users(provider[1], (node, index))
        # Let each target of the new node take over the later users of that value.
        targets = node.get_targets()
        if targets is not None:
            for index, target in enumerate(targets):
                provider = TopoManager._get_value_provider(node, target)
                if provider:
                    TopoManager._update_target_users_by_node(node, index, provider)
                else:
                    TopoManager._update_target_users_by_value(node, index, target)
        self.topo_changed()

    def on_erase_node(self, node: Node):
        """
        Update provider dict and consumer dict while erasing node from SymbolTree.

        Users of each of ``node``'s targets are re-linked to the nearest previous node providing the same target. If
        no such node exists, their arg-provider entry is dropped and a warning is logged. ``node`` is also removed
        from the user lists of the nodes that provide its arguments.

        Args:
            node (Node): An instance of Node which been erased from SymbolTree.
        """
        # Find previous node with same target of current node.
        prev_providers = {}
        for index, target_users in node.get_target_users().items():
            if not target_users:
                continue
            prev_provider = TopoManager._get_value_provider(node, node.get_targets()[index])
            if not prev_provider:
                logger.warning(f"Node {node.get_name()}'s target {index}({node.get_targets()[index]}) is used in node "
                               f"{target_users[0][0].get_name()}'s arg {target_users[0][1]}, "
                               f"no other node provides this target if node {node.get_name()} is erased.")
                prev_providers[index] = None
            else:
                prev_providers[index] = prev_provider
        # Update targets topological of nodes
        for index, prev_provider in prev_providers.items():
            for target_user in node.get_target_users(index):
                if prev_provider is None:
                    target_user[0].get_arg_providers().pop(target_user[1], None)
                else:
                    prev_provider[0].append_target_users(prev_provider[1], target_user)
                    target_user[0].set_arg_providers(target_user[1], prev_provider)
        # Update arguments topological of nodes: drop every (node, *) entry from the providers of node's arguments.
        for arg_provider in node.get_arg_providers().values():
            if not arg_provider:
                continue
            provider_target_users = arg_provider[0].get_target_users(arg_provider[1])
            # Rebuild via slice assignment so the provider keeps the same list object, without mutating while iterating.
            provider_target_users[:] = [user for user in provider_target_users if user[0] != node]
        self.topo_changed()

    def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue):
        """
        Update provider dict and consumer dict while updating argument of node and update inputs of node by updated
        provider dict and consumer dict.

        Args:
            node (Node): An instance of Node whose arguments being updated.
            arg_idx (int): An int indicates which argument of node being updated.
            old_arg (ScopedValue): An instance of ScopedValue represents original argument.
            new_arg (ScopedValue): An instance of ScopedValue represents new argument.
        """
        # Detach (node, arg_idx) from its current provider. Prefer the provider recorded in arg_providers, as
        # on_update_arg_by_node does. It is correct even when the argument was wired by node rather than by name.
        # Fall back to a name lookup of old_arg when nothing is recorded.
        old_provider = node.get_arg_providers().get(arg_idx) or TopoManager._get_value_provider(node, old_arg)
        TopoManager._detach_user(old_provider, (node, arg_idx))
        # Update new arg's provider node's target_users.
        provider = TopoManager._get_value_provider(node, new_arg)
        if provider:
            provider[0].append_target_users(provider[1], (node, arg_idx))
        # Update current node's arg_providers (an empty tuple when new_arg has no provider).
        node.set_arg_providers(arg_idx, provider)
        self.topo_changed()

    def on_update_arg_by_node(self, dst_node: Node, arg_idx: int, src_node: Node, out_idx: int):
        """
        Update argument of 'dst_node' by another Node.

        Args:
            dst_node (Node): Node to be modified.
            arg_idx (int): Indicate which input being modified.
            src_node (Node): Node as new input.
            out_idx (int): Indicate which output of 'src_node' as new input of 'dst_node'.
        """
        # Update old arg's provider node's target_users.
        TopoManager._detach_user(dst_node.get_arg_providers().get(arg_idx), (dst_node, arg_idx))
        # Update new arg's provider node's target_users.
        src_node.append_target_users(out_idx, (dst_node, arg_idx))
        # Update current node's arg_providers.
        dst_node.set_arg_providers(arg_idx, (src_node, out_idx))
        self.topo_changed()
224