/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
#include <algorithm>
#include <functional>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "frontend/parallel/allreduce_fusion/allreduce_node.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "utils/log_adapter.h"
24
25 namespace mindspore {
26 namespace parallel {
AddNode(const CNodePtr & node,const AnfNodePtr & para)27 Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) {
28 AllreduceNodePtr arnode;
29 auto cnode_emplace_return = cnode_set_.emplace(node);
30 if (!cnode_emplace_return.second) {
31 MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
32 auto cnode_arnode_pair = cnode_arnode_map_.find(node);
33 if (cnode_arnode_pair == cnode_arnode_map_.end()) {
34 MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
35 }
36 arnode = cnode_arnode_pair->second;
37 } else {
38 arnode = std::make_shared<AllreduceNode>(AllreduceNode());
39 }
40
41 if (arnode->Init(node) != SUCCESS) {
42 MS_LOG(ERROR) << "AllreduceNode Init failed";
43 return FAILED;
44 }
45 if (arnode->AddPara(para) != SUCCESS) {
46 MS_LOG(ERROR) << "AllreduceNode AddPara failed";
47 return FAILED;
48 }
49 cnode_arnode_map_[node] = arnode;
50
51 auto arnode_emplace_return = arnode_set_.insert(arnode);
52 if (!arnode_emplace_return.second) {
53 MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
54 }
55 cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
56 if (!cnode_emplace_return.second) {
57 MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
58 << "'s cnodeset!";
59 }
60 auto para_emplace_return = cnode_paraset_map_[node].emplace(para);
61 if (!para_emplace_return.second) {
62 MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString()
63 << "'s paraset!";
64 }
65 return SUCCESS;
66 }
67
AddEdge(const CNodePtr & from,const CNodePtr & to,double dist)68 Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
69 auto from_arnode_iter = cnode_arnode_map_.find(from);
70 if (from_arnode_iter == cnode_arnode_map_.end()) {
71 MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
72 PrintCNodeSet();
73 return FAILED;
74 }
75 auto to_arnode_iter = cnode_arnode_map_.find(to);
76 if (to_arnode_iter == cnode_arnode_map_.end()) {
77 MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added";
78 PrintCNodeSet();
79 return FAILED;
80 }
81 auto from_arnode = from_arnode_iter->second;
82 auto to_arnode = to_arnode_iter->second;
83 if (from_arnode->AddNext(to_arnode) != SUCCESS) {
84 MS_LOG(ERROR) << "from_arnode AddNext failed";
85 return FAILED;
86 }
87 if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
88 MS_LOG(ERROR) << "to_arnode AddPrev failed";
89 return FAILED;
90 }
91 max_ = std::max(max_, to_arnode->depend_feat_size());
92 MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString();
93 MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size()
94 << ", to depend_feat_size: " << to_arnode->depend_feat_size();
95 return SUCCESS;
96 }
97
NodeInGraph(const CNodePtr & node) const98 bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
99 auto cnode_iter = cnode_set_.find(node);
100 return !(cnode_iter == cnode_set_.end());
101 }
102
GetParaByCost(double from,double to)103 std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
104 std::vector<AnfNodePtr> nodes;
105 for (auto &cnode_arnode : cnode_arnode_map_) {
106 MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
107 << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
108 << " curr_para_size: " << cnode_arnode.second->curr_para_size();
109 if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) {
110 (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(),
111 cnode_paraset_map_[cnode_arnode.first].end());
112 }
113 }
114 return nodes;
115 }
116
GetParaByParaSize(double to,double para_size)117 std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) {
118 std::vector<AnfNodePtr> nodes;
119 double cur_para_size = 0;
120 double from = to;
121 for (auto &arnode : arnode_vec_) {
122 if ((arnode.depend_feat_size() - max_ <= EPS) && arnode.depend_feat_size() >= to) {
123 continue;
124 }
125 if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
126 return std::make_pair(nodes, from);
127 }
128 (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end());
129 cur_para_size += arnode.curr_para_size();
130 from = arnode.depend_feat_size();
131 }
132 MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size
133 << " cur_para_size: " << cur_para_size << " from: " << from;
134 return std::make_pair(nodes, from);
135 }
136
PrintCNodeSet() const137 void AllreduceGraph::PrintCNodeSet() const {
138 MS_LOG(INFO) << "CNodeSet:";
139 for (auto &cnode : cnode_set_) {
140 MS_LOG(INFO) << cnode->DebugString();
141 }
142 }
143
PrintAllredueGraphInfo() const144 void AllreduceGraph::PrintAllredueGraphInfo() const {
145 MS_LOG(INFO) << "max: " << max_;
146 for (auto &cnode_arnode : cnode_arnode_map_) {
147 MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
148 MS_LOG(INFO) << "arnode info: ";
149 cnode_arnode.second->ToString();
150 }
151 }
152
PrintArnodeVec() const153 void AllreduceGraph::PrintArnodeVec() const {
154 MS_LOG(INFO) << "ArnodeVec:";
155 for (auto &arnode : arnode_vec_) {
156 arnode.ToString();
157 }
158 }
159
PrintArnodeSet() const160 void AllreduceGraph::PrintArnodeSet() const {
161 MS_LOG(INFO) << "ArnodeSet:";
162 for (auto &arnode : arnode_set_) {
163 arnode->ToString();
164 }
165 }
166
SortArnode()167 void AllreduceGraph::SortArnode() {
168 arnode_vec_.clear();
169 for (auto &node : arnode_set_) {
170 arnode_vec_.emplace_back(*node);
171 }
172 std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
173 }
174
RemoveExtraParas()175 Status AllreduceGraph::RemoveExtraParas() {
176 std::unordered_set<AnfNodePtr> para_map;
177 for (auto &node : arnode_vec_) {
178 for (auto ¶ : node.paras()) {
179 auto emplac_result = para_map.emplace(para);
180 if (!emplac_result.second) {
181 MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
182 if (node.RemovePara(para) != SUCCESS) {
183 MS_LOG(ERROR) << "remove para failed";
184 return FAILED;
185 }
186 }
187 }
188 }
189 return SUCCESS;
190 }
191
set_head_cnode(const CNodePtr & node)192 Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
193 auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
194 if (arnode->Init(node) != SUCCESS) {
195 MS_LOG(ERROR) << "AllreduceNode Init failed";
196 }
197 head_cnode_ = node;
198 cnode_arnode_map_[node] = arnode;
199 auto arnode_emplace_return = arnode_set_.insert(arnode);
200 if (!arnode_emplace_return.second) {
201 MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!";
202 }
203 auto cnode_emplace_return = cnode_set_.emplace(node);
204 if (!cnode_emplace_return.second) {
205 MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!";
206 }
207 return SUCCESS;
208 }
209 } // namespace parallel
210 } // namespace mindspore
211