/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include "ir/anf.h"
#include "frontend/parallel/allreduce_fusion/allreduce_node.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
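// Register an allreduce cnode and one of its parameters: create (or reuse) the AllreduceNode wrapping
// the cnode, initialize it, and record the cnode<->parameter relations in the bookkeeping maps.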
Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) {
  AllreduceNodePtr arnode;
  auto cnode_emplace_return = cnode_set_.emplace(node);
  if (!cnode_emplace_return.second) {
    MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
    auto cnode_arnode_pair = cnode_arnode_map_.find(node);
    if (cnode_arnode_pair == cnode_arnode_map_.end()) {
      MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
    }
    arnode = cnode_arnode_pair->second;
  } else {
    arnode = std::make_shared<AllreduceNode>();
  }

  MS_EXCEPTION_IF_NULL(arnode);
  if (arnode->Init(node) != SUCCESS) {
    MS_LOG(ERROR) << "AllreduceNode Init failed";
    return FAILED;
  }
  if (arnode->AddPara(para) != SUCCESS) {
    MS_LOG(ERROR) << "AllreduceNode AddPara failed";
    return FAILED;
  }
  cnode_arnode_map_[node] = arnode;

  auto arnode_emplace_return = arnode_set_.insert(arnode);
  if (!arnode_emplace_return.second) {
    MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
  }
  cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
  if (!cnode_emplace_return.second) {
    MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
                 << "'s cnodeset!";
  }
  auto para_emplace_return = cnode_paraset_map_[node].emplace(para);
  if (!para_emplace_return.second) {
    MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString()
                 << "'s paraset!";
  }
  return SUCCESS;
}

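// Add a directed edge of weight dist from cnode `from` to cnode `to`; both must already have been
// registered via AddNode. Updates max_ with the destination arnode's accumulated depend_feat_size.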
Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
  auto from_arnode_iter = cnode_arnode_map_.find(from);
  if (from_arnode_iter == cnode_arnode_map_.end()) {
    MS_LOG(ERROR) << "cnode from: " << from->DebugString() << " has not been added";
    PrintCNodeSet();
    return FAILED;
  }
  auto to_arnode_iter = cnode_arnode_map_.find(to);
  if (to_arnode_iter == cnode_arnode_map_.end()) {
    MS_LOG(ERROR) << "cnode to: " << to->DebugString() << " has not been added";
    PrintCNodeSet();
    return FAILED;
  }
  auto from_arnode = from_arnode_iter->second;
  auto to_arnode = to_arnode_iter->second;
  if (from_arnode->AddNext(to_arnode) != SUCCESS) {
    MS_LOG(ERROR) << "from_arnode AddNext failed";
    return FAILED;
  }
  if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
    MS_LOG(ERROR) << "to_arnode AddPrev failed";
    return FAILED;
  }
  max_ = std::max(max_, to_arnode->depend_feat_size());
  MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString();
  MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size()
                << ", to depend_feat_size: " << to_arnode->depend_feat_size();
  return SUCCESS;
}

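// Return true if the cnode has already been added to this graph.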
bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
  auto cnode_iter = cnode_set_.find(node);
  return cnode_iter != cnode_set_.end();
}

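// Collect the parameters of every cnode whose accumulated depend_feat_size lies in (from, to].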
std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
  std::vector<AnfNodePtr> nodes;
  for (auto &cnode_arnode : cnode_arnode_map_) {
    MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
                  << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
                  << " curr_para_size: " << cnode_arnode.second->curr_para_size();
    if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) {
      (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(),
                         cnode_paraset_map_[cnode_arnode.first].end());
    }
  }
  return nodes;
}

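// Walk the sorted arnode vector, skipping arnodes at or above the `to` boundary, and gather their
// parameters until at least `para_size` of parameter size has been accumulated. Returns the gathered
// parameters together with the depend_feat_size reached, which can serve as the next boundary.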
std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) {
  std::vector<AnfNodePtr> nodes;
  double cur_para_size = 0;
  double from = to;
  for (auto &arnode : arnode_vec_) {
    if ((arnode.depend_feat_size() - max_ <= EPS) && arnode.depend_feat_size() >= to) {
      continue;
    }
    if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
      return std::make_pair(nodes, from);
    }
    (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end());
    cur_para_size += arnode.curr_para_size();
    from = arnode.depend_feat_size();
  }
  MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size
               << " cur_para_size: " << cur_para_size << " from: " << from;
  return std::make_pair(nodes, from);
}

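// Log every cnode currently registered in the graph.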
void AllreduceGraph::PrintCNodeSet() const {
  MS_LOG(INFO) << "CNodeSet:";
  for (auto &cnode : cnode_set_) {
    MS_LOG(INFO) << cnode->DebugString();
  }
}

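// Log the current max depend_feat_size and, for every registered cnode, its arnode info.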
void AllreduceGraph::PrintAllredueGraphInfo() const {
  MS_LOG(INFO) << "max: " << max_;
  for (auto &cnode_arnode : cnode_arnode_map_) {
    MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
    MS_LOG(INFO) << "arnode info: ";
    cnode_arnode.second->ToString();
  }
}

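// Log every arnode in the sorted vector produced by SortArnode.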
void AllreduceGraph::PrintArnodeVec() const {
  MS_LOG(INFO) << "ArnodeVec:";
  for (auto &arnode : arnode_vec_) {
    arnode.ToString();
  }
}

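// Log every arnode in the arnode set.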
void AllreduceGraph::PrintArnodeSet() const {
  MS_LOG(INFO) << "ArnodeSet:";
  for (auto &arnode : arnode_set_) {
    arnode->ToString();
  }
}

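// Rebuild arnode_vec_ from arnode_set_ and sort it in descending order.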
void AllreduceGraph::SortArnode() {
  arnode_vec_.clear();
  for (auto &node : arnode_set_) {
    (void)arnode_vec_.emplace_back(*node);
  }
  std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
}

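// Ensure each parameter is kept by only one arnode: walk the sorted arnodes and remove a parameter
// from an arnode if it has already been seen in an earlier one.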
Status AllreduceGraph::RemoveExtraParas() {
  mindspore::HashSet<AnfNodePtr> para_map;
  for (auto &node : arnode_vec_) {
    for (auto &para : node.paras()) {
      auto emplace_result = para_map.emplace(para);
      if (!emplace_result.second) {
        MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << " is already in another arnode";
        if (node.RemovePara(para) != SUCCESS) {
          MS_LOG(ERROR) << "remove para failed";
          return FAILED;
        }
      }
    }
  }
  return SUCCESS;
}

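// Record `node` as the head cnode of the graph and register a matching arnode; an Init failure is
// logged but does not abort the registration.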
Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
  auto arnode = std::make_shared<AllreduceNode>();
  if (arnode->Init(node) != SUCCESS) {
    MS_LOG(ERROR) << "AllreduceNode Init failed";
  }
  head_cnode_ = node;
  cnode_arnode_map_[node] = arnode;
  auto arnode_emplace_return = arnode_set_.insert(arnode);
  if (!arnode_emplace_return.second) {
    MS_LOG(DEBUG) << "node: " << node->DebugString() << "'s arnode has already been added!";
  }
  auto cnode_emplace_return = cnode_set_.emplace(node);
  if (!cnode_emplace_return.second) {
    MS_LOG(DEBUG) << "node: " << node->DebugString() << " has already been added!";
  }
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore