• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
18 #include <algorithm>
19 #include <functional>
20 #include "ir/anf.h"
21 #include "frontend/parallel/allreduce_fusion/allreduce_node.h"
22 #include "frontend/parallel/ops_info/ops_utils.h"
23 #include "utils/log_adapter.h"
24 
25 namespace mindspore {
26 namespace parallel {
AddNode(const CNodePtr & node,const AnfNodePtr & para)27 Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) {
28   AllreduceNodePtr arnode;
29   auto cnode_emplace_return = cnode_set_.emplace(node);
30   if (!cnode_emplace_return.second) {
31     MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
32     auto cnode_arnode_pair = cnode_arnode_map_.find(node);
33     if (cnode_arnode_pair == cnode_arnode_map_.end()) {
34       MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
35     }
36     arnode = cnode_arnode_pair->second;
37   } else {
38     arnode = std::make_shared<AllreduceNode>(AllreduceNode());
39   }
40 
41   MS_EXCEPTION_IF_NULL(arnode);
42   if (arnode->Init(node) != SUCCESS) {
43     MS_LOG(ERROR) << "AllreduceNode Init failed";
44     return FAILED;
45   }
46   if (arnode->AddPara(para) != SUCCESS) {
47     MS_LOG(ERROR) << "AllreduceNode AddPara failed";
48     return FAILED;
49   }
50   cnode_arnode_map_[node] = arnode;
51 
52   auto arnode_emplace_return = arnode_set_.insert(arnode);
53   if (!arnode_emplace_return.second) {
54     MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
55   }
56   cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
57   if (!cnode_emplace_return.second) {
58     MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
59                  << "'s cnodeset!";
60   }
61   auto para_emplace_return = cnode_paraset_map_[node].emplace(para);
62   if (!para_emplace_return.second) {
63     MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString()
64                  << "'s paraset!";
65   }
66   return SUCCESS;
67 }
68 
AddEdge(const CNodePtr & from,const CNodePtr & to,double dist)69 Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
70   auto from_arnode_iter = cnode_arnode_map_.find(from);
71   if (from_arnode_iter == cnode_arnode_map_.end()) {
72     MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
73     PrintCNodeSet();
74     return FAILED;
75   }
76   auto to_arnode_iter = cnode_arnode_map_.find(to);
77   if (to_arnode_iter == cnode_arnode_map_.end()) {
78     MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added";
79     PrintCNodeSet();
80     return FAILED;
81   }
82   auto from_arnode = from_arnode_iter->second;
83   auto to_arnode = to_arnode_iter->second;
84   if (from_arnode->AddNext(to_arnode) != SUCCESS) {
85     MS_LOG(ERROR) << "from_arnode AddNext failed";
86     return FAILED;
87   }
88   if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
89     MS_LOG(ERROR) << "to_arnode AddPrev failed";
90     return FAILED;
91   }
92   max_ = std::max(max_, to_arnode->depend_feat_size());
93   MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString();
94   MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size()
95                 << ", to depend_feat_size: " << to_arnode->depend_feat_size();
96   return SUCCESS;
97 }
98 
NodeInGraph(const CNodePtr & node) const99 bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
100   auto cnode_iter = cnode_set_.find(node);
101   return !(cnode_iter == cnode_set_.end());
102 }
103 
GetParaByCost(double from,double to)104 std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
105   std::vector<AnfNodePtr> nodes;
106   for (auto &cnode_arnode : cnode_arnode_map_) {
107     MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
108                   << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
109                   << " curr_para_size: " << cnode_arnode.second->curr_para_size();
110     if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) {
111       (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(),
112                          cnode_paraset_map_[cnode_arnode.first].end());
113     }
114   }
115   return nodes;
116 }
117 
GetParaByParaSize(double to,double para_size)118 std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) {
119   std::vector<AnfNodePtr> nodes;
120   double cur_para_size = 0;
121   double from = to;
122   for (auto &arnode : arnode_vec_) {
123     if ((arnode.depend_feat_size() - max_ <= EPS) && arnode.depend_feat_size() >= to) {
124       continue;
125     }
126     if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
127       return std::make_pair(nodes, from);
128     }
129     (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end());
130     cur_para_size += arnode.curr_para_size();
131     from = arnode.depend_feat_size();
132   }
133   MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size
134                << " cur_para_size: " << cur_para_size << " from: " << from;
135   return std::make_pair(nodes, from);
136 }
137 
PrintCNodeSet() const138 void AllreduceGraph::PrintCNodeSet() const {
139   MS_LOG(INFO) << "CNodeSet:";
140   for (auto &cnode : cnode_set_) {
141     MS_LOG(INFO) << cnode->DebugString();
142   }
143 }
144 
PrintAllredueGraphInfo() const145 void AllreduceGraph::PrintAllredueGraphInfo() const {
146   MS_LOG(INFO) << "max: " << max_;
147   for (auto &cnode_arnode : cnode_arnode_map_) {
148     MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
149     MS_LOG(INFO) << "arnode info: ";
150     cnode_arnode.second->ToString();
151   }
152 }
153 
PrintArnodeVec() const154 void AllreduceGraph::PrintArnodeVec() const {
155   MS_LOG(INFO) << "ArnodeVec:";
156   for (auto &arnode : arnode_vec_) {
157     arnode.ToString();
158   }
159 }
160 
PrintArnodeSet() const161 void AllreduceGraph::PrintArnodeSet() const {
162   MS_LOG(INFO) << "ArnodeSet:";
163   for (auto &arnode : arnode_set_) {
164     arnode->ToString();
165   }
166 }
167 
SortArnode()168 void AllreduceGraph::SortArnode() {
169   arnode_vec_.clear();
170   for (auto &node : arnode_set_) {
171     (void)arnode_vec_.emplace_back(*node);
172   }
173   std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
174 }
175 
RemoveExtraParas()176 Status AllreduceGraph::RemoveExtraParas() {
177   mindspore::HashSet<AnfNodePtr> para_map;
178   for (auto &node : arnode_vec_) {
179     for (auto &para : node.paras()) {
180       auto emplac_result = para_map.emplace(para);
181       if (!emplac_result.second) {
182         MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
183         if (node.RemovePara(para) != SUCCESS) {
184           MS_LOG(ERROR) << "remove para failed";
185           return FAILED;
186         }
187       }
188     }
189   }
190   return SUCCESS;
191 }
192 
set_head_cnode(const CNodePtr & node)193 Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
194   auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
195   if (arnode->Init(node) != SUCCESS) {
196     MS_LOG(ERROR) << "AllreduceNode Init failed";
197   }
198   head_cnode_ = node;
199   cnode_arnode_map_[node] = arnode;
200   auto arnode_emplace_return = arnode_set_.insert(arnode);
201   if (!arnode_emplace_return.second) {
202     MS_LOG(DEBUG) << "node: " << node->DebugString() << "'s arnode has already been added!";
203   }
204   auto cnode_emplace_return = cnode_set_.emplace(node);
205   if (!cnode_emplace_return.second) {
206     MS_LOG(DEBUG) << "node: " << node->DebugString() << " has already been added!";
207   }
208   return SUCCESS;
209 }
210 }  // namespace parallel
211 }  // namespace mindspore
212