# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree nodes topological relationship manager."""
from typing import Tuple
from mindspore import log as logger
from .node import Node
from ..api.scoped_value import ScopedValue
from ..common.observable import Observable
from ..common.event import Event


class TopoManager(Observable):
    """
    SymbolTree topological-relationship manager.

    Keeps two kinds of per-node bookkeeping consistent while the SymbolTree is edited:

    - target users: for each target index of a node, the list of ``(user_node, arg_index)`` pairs reading it;
    - arg providers: for each argument index of a node, the ``(provider_node, target_index)`` pair producing it
      (``()`` or a missing entry when no provider is known).

    The provider of a value is the nearest previous node (walking ``get_prev()``) that has a target equal to it.
    """

    @staticmethod
    def on_update_target(node: Node, index: int, old_target: ScopedValue, new_target: ScopedValue):
        """
        Update node's dicts while updating target of node.

        Users of the old target are handed back to the nearest previous node that also produces `old_target`
        (or left with an empty provider ``()`` if there is none). Afterwards, users of `new_target` that follow
        `node` are attached to `node`.

        Args:
            node (Node): An instance of Node whose target is being updated.
            index (int): An int indicating which target of node is being updated.
            old_target (ScopedValue): An instance of ScopedValue representing the old target.
            new_target (ScopedValue): An instance of ScopedValue representing the new target.
        """
        # NOTE(review): unlike the instance-level on_* handlers, this staticmethod does not call topo_changed();
        # confirm that callers notify observers themselves.
        # Hand the old target's users over to the previous provider of old_target; `()` means "no provider".
        old_provider = TopoManager._get_value_provider(node, old_target)
        for user in node.get_target_users(index):
            if old_provider:
                old_provider[0].append_target_users(old_provider[1], user)
            user[0].set_arg_providers(user[1], old_provider)
        node.get_target_users(index).clear()
        # Attach users of new_target that come after node.
        provider = TopoManager._get_value_provider(node, new_target)
        if provider:
            TopoManager._update_target_users_by_node(node, index, provider)
        else:
            TopoManager._update_target_users_by_value(node, index, new_target)

    @staticmethod
    def _update_target_users_by_value(node, index, value: ScopedValue):
        """
        Attach to `node`'s target `index` every later argument equal to `value`.

        Called when no previous node provides `value`, i.e. a new target name is set. Walks forward from
        `node` and stops after the first node that has `value` among its targets. That node's own arguments are
        still checked first, because they read the value before it is redefined.
        """
        search_node = node.get_next()
        while search_node is not None:
            args = search_node.get_normalized_args()
            if args is not None:
                for arg_index, arg in enumerate(args.values()):
                    if arg == value:
                        node.append_target_users(index, (search_node, arg_index))
                        search_node.set_arg_providers(arg_index, (node, index))
            targets = search_node.get_targets()
            if targets is not None and any(target == value for target in targets):
                return
            search_node = search_node.get_next()

    @staticmethod
    def _update_target_users_by_node(node, index, provider: Tuple[Node, int]):
        """
        Move the users of `provider` that come after `node` to `node`'s target `index`.

        Called when a previous node already provides the same value, i.e. a repeated target name is set:

            [last provider] -> no change args -> [insert node] -> need change args -> [next provider] -> no change args

        Users located between `provider` and `node` keep `provider`. This includes `node` itself, which reads
        the value before redefining it.
        """
        nodes_before_insert = []
        search_node = provider[0].get_next()
        while search_node is not None:
            nodes_before_insert.append(search_node)
            if search_node == node:
                break
            search_node = search_node.get_next()
        provider_target_users = provider[0].get_target_users(provider[1])
        for user in provider_target_users[:]:  # iterate over a copy so items can be removed from the original
            if user[0] not in nodes_before_insert:
                node.append_target_users(index, user)
                provider_target_users.remove(user)
                user[0].set_arg_providers(user[1], (node, index))

    @staticmethod
    def _get_value_provider(node, value: ScopedValue):
        """
        Find the nearest node before `node` that has a target equal to `value`.

        Returns:
            ``(provider_node, target_index)`` for the first match walking ``get_prev()``, or ``()`` if none.
        """
        search_node = node.get_prev()
        while search_node is not None:
            targets = search_node.get_targets()
            if targets is not None:
                for target_index, target in enumerate(targets):
                    if target == value:
                        return (search_node, target_index)
            search_node = search_node.get_prev()
        return ()

    def topo_changed(self):
        """
        Report a topology change by calling ``self.changed(Event.TopologicalChangeEvent)`` (from ``Observable``).

        Invoked at the end of every instance-level ``on_*`` handler of this class.
        """
        self.changed(Event.TopologicalChangeEvent)

    def on_insert_node(self, node: Node):
        """
        Update provider dict and consumer dict while inserting node into SymbolTree and update inputs of node by
        updated provider dict and consumer dict.

        Args:
            node (Node): An instance of Node which has been inserted into SymbolTree.
        """
        # Link each argument of node to its provider, if any.
        args = node.get_normalized_args()
        if args is not None:
            for index, arg in enumerate(args.values()):
                provider = TopoManager._get_value_provider(node, arg)
                if provider:
                    node.set_arg_providers(index, provider)
                    provider[0].append_target_users(provider[1], (node, index))
        # Make node the provider of later users of each of its targets.
        targets = node.get_targets()
        if targets is not None:
            for index, target in enumerate(targets):
                provider = TopoManager._get_value_provider(node, target)
                if provider:
                    TopoManager._update_target_users_by_node(node, index, provider)
                else:
                    TopoManager._update_target_users_by_value(node, index, target)
        self.topo_changed()

    def on_erase_node(self, node: Node):
        """
        Update provider dict and consumer dict while erasing node from SymbolTree.

        For each of node's targets:

        - If a previous node provides the same value, that node becomes the provider of the target's users.
        - Otherwise a warning is logged and the users' arg-provider entries are removed.

        Node is also removed from the target users of every node that provided one of its arguments.

        Args:
            node (Node): An instance of Node which has been erased from SymbolTree.
        """
        # Update targets topological of nodes: re-attach users to a previous provider of the same value.
        # _get_value_provider only reads previous nodes' targets, so lookups and updates can be interleaved.
        for index, target_users in node.get_target_users().items():
            if not target_users:
                continue
            target = node.get_targets()[index]
            prev_provider = TopoManager._get_value_provider(node, target)
            if not prev_provider:
                logger.warning(f"Node {node.get_name()}'s target {index}({target}) is used in node "
                               f"{target_users[0][0].get_name()}'s arg {target_users[0][1]}, "
                               f"no other node provides this target if node {node.get_name()} is erased.")
            for target_user in node.get_target_users(index):
                if prev_provider:
                    prev_provider[0].append_target_users(prev_provider[1], target_user)
                    target_user[0].set_arg_providers(target_user[1], prev_provider)
                else:
                    target_user[0].get_arg_providers().pop(target_user[1], None)
        # Update arguments topological of nodes: drop node from its providers' target users.
        for arg_provider in node.get_arg_providers().values():
            if not arg_provider:
                continue
            provider_target_users = arg_provider[0].get_target_users(arg_provider[1])
            # Filter via slice assignment so the provider keeps the same list object.
            provider_target_users[:] = [user for user in provider_target_users if user[0] != node]
        self.topo_changed()

    def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue):
        """
        Update provider dict and consumer dict while updating argument of node and update inputs of node by updated
        provider dict and consumer dict.

        Args:
            node (Node): An instance of Node whose arguments are being updated.
            arg_idx (int): An int indicating which argument of node is being updated.
            old_arg (ScopedValue): An instance of ScopedValue representing the original argument.
            new_arg (ScopedValue): An instance of ScopedValue representing the new argument.
        """
        # Update old arg's provider node's target_users.
        old_provider = TopoManager._get_value_provider(node, old_arg)
        if old_provider:
            old_provider_target_users = old_provider[0].get_target_users(old_provider[1])
            if (node, arg_idx) in old_provider_target_users:
                old_provider_target_users.remove((node, arg_idx))
        # Update new arg's provider node's target_users.
        provider = TopoManager._get_value_provider(node, new_arg)
        if provider:
            provider[0].append_target_users(provider[1], (node, arg_idx))
        # Update current node's arg_providers; `()` records that no provider was found.
        node.set_arg_providers(arg_idx, provider)
        self.topo_changed()

    def on_update_arg_by_node(self, dst_node: Node, arg_idx: int, src_node: Node, out_idx: int):
        """
        Update argument of 'dst_node' by another Node.

        Args:
            dst_node (Node): Node to be modified.
            arg_idx (int): Indicate which input being modified.
            src_node (Node): Node as new input.
            out_idx (int): Indicate which output of 'src_node' as new input of 'dst_node'.
        """
        # Update old arg's provider node's target_users (the recorded provider may be missing or empty).
        arg_provider = dst_node.get_arg_providers().get(arg_idx)
        if arg_provider:
            provider_target_users = arg_provider[0].get_target_users(arg_provider[1])
            if (dst_node, arg_idx) in provider_target_users:
                provider_target_users.remove((dst_node, arg_idx))
        # Update new arg's provider node's target_users.
        src_node.append_target_users(out_idx, (dst_node, arg_idx))
        # Update current node's arg_providers.
        dst_node.set_arg_providers(arg_idx, (src_node, out_idx))
        self.topo_changed()