/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
18 #include <memory>
19 #include <queue>
20 #include <string>
21 #include <unordered_set>
22 #include "ir/func_graph.h"
23 #include "frontend/parallel/costmodel_context.h"
24 #include "frontend/parallel/graph_util/node_info.h"
25 #include "frontend/parallel/status.h"
26 #include "frontend/parallel/step_parallel.h"
27 #include "utils/log_adapter.h"
28
29 namespace mindspore {
30 namespace parallel {
FindCNodesWithPara(const AnfNodePtr & para,uint64_t recursive_times=0)31 std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint64_t recursive_times = 0) {
32 if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
33 MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
34 << MAX_RECURSIVE_CALL_TIMES;
35 }
36 MS_EXCEPTION_IF_NULL(para);
37 MS_EXCEPTION_IF_NULL(para->func_graph());
38 FuncGraphManagerPtr manager = para->func_graph()->manager();
39 MS_EXCEPTION_IF_NULL(manager);
40 auto node_set = manager->node_users()[para];
41 std::unordered_set<CNodePtr> cnode_set;
42 for (auto &node_pair : node_set) {
43 auto cnode = node_pair.first->cast<CNodePtr>();
44 MS_EXCEPTION_IF_NULL(cnode);
45 if (!IsValueNode<Primitive>(cnode->input(0))) {
46 continue;
47 }
48 auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
49 MS_EXCEPTION_IF_NULL(node_prim);
50 if (node_prim->name() == DEPEND && node_pair.second != 1) {
51 continue;
52 }
53 if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
54 (void)cnode_set.emplace(cnode);
55 } else {
56 auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
57 for (auto &cnode_sub : cnode_set_sub) {
58 (void)cnode_set.emplace(cnode_sub);
59 }
60 }
61 }
62 return cnode_set;
63 }
64
AddNodeToGraph()65 Status AllreduceFusion::AddNodeToGraph() {
66 const auto ¶meters = root_graph_->parameters();
67 for (auto ¶meter : parameters) {
68 if (!ParameterRequireGrad(parameter)) {
69 continue;
70 }
71 auto cnode_set = FindCNodesWithPara(parameter);
72 if (cnode_set.empty()) {
73 continue;
74 }
75 for (auto &cnode : cnode_set) {
76 MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
77 if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
78 MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
79 return FAILED;
80 }
81 }
82 }
83 return SUCCESS;
84 }
85
FindCNode(const AnfNodePtr & from,uint64_t recursive_times) const86 CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint64_t recursive_times) const {
87 if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
88 MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
89 << MAX_RECURSIVE_CALL_TIMES;
90 }
91 MS_EXCEPTION_IF_NULL(from);
92 std::unordered_map<CNodePtr, double> cnode_dist;
93 if (!from->isa<CNode>()) {
94 return cnode_dist;
95 }
96 auto cnode = from->cast<CNodePtr>();
97 if (!IsValueNode<Primitive>(cnode->input(0))) {
98 return cnode_dist;
99 }
100
101 auto operator_info = cnode->user_data<OperatorInfo>();
102 MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
103 << " operator_info: " << (operator_info != nullptr);
104
105 if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
106 auto cost = operator_info->GetForwardMemoryCostFromCNode();
107 MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;
108
109 if (allreduce_graph_.NodeInGraph(cnode)) {
110 cnode_dist[cnode] = cost;
111 return cnode_dist;
112 } else {
113 auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
114 for (auto &ele_next : cnode_dist_next) {
115 cnode_dist[ele_next.first] = cost + ele_next.second;
116 }
117 }
118 } else {
119 auto cnode_dist_next = FindNextCNodes(cnode);
120 for (auto &ele : cnode_dist_next) {
121 cnode_dist[ele.first] = ele.second;
122 }
123 }
124 return cnode_dist;
125 }
126
FindNextCNodes(const CNodePtr & from,uint64_t recursive_times) const127 CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint64_t recursive_times) const {
128 if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
129 MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
130 << MAX_RECURSIVE_CALL_TIMES;
131 }
132 const auto &from_inputs = from->inputs();
133 std::unordered_map<CNodePtr, double> dist_map;
134 MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
135 for (auto &input_node : from_inputs) {
136 auto cnode_dist = FindCNode(input_node, recursive_times + 1);
137 for (auto &ele : cnode_dist) {
138 (void)dist_map.emplace(ele);
139 }
140 }
141 return dist_map;
142 }
143
AddEdgeToGraph()144 Status AllreduceFusion::AddEdgeToGraph() {
145 std::unordered_map<CNodePtr, int64_t> cnode_state_map;
146 const auto &cnodes = allreduce_graph_.cnode_set();
147 for (auto &cnode : cnodes) {
148 cnode_state_map[cnode] = 0;
149 }
150 const auto &head_cnode = allreduce_graph_.head_cnode();
151 std::queue<CNodePtr> cnode_queue;
152 cnode_queue.emplace(head_cnode);
153 cnode_state_map[head_cnode] = 1;
154
155 while (!cnode_queue.empty()) {
156 const auto cur_cnode = cnode_queue.front();
157 cnode_queue.pop();
158 cnode_state_map[cur_cnode] = 2;
159 auto next = FindNextCNodes(cur_cnode);
160 for (auto &ele : next) {
161 auto &cnode = ele.first;
162 auto &dist = ele.second;
163 if (cnode_state_map[cnode] == 0) {
164 cnode_queue.emplace(cnode);
165 cnode_state_map[cnode] = 1;
166 }
167 if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) {
168 MS_LOG(ERROR) << "AddEdge error";
169 return FAILED;
170 }
171 MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist;
172 }
173 }
174 return SUCCESS;
175 }
176
FindMirror(const AnfNodePtr & para,uint64_t recursive_times=0)177 std::vector<CNodePtr> FindMirror(const AnfNodePtr ¶, uint64_t recursive_times = 0) {
178 if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
179 MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
180 << MAX_RECURSIVE_CALL_TIMES;
181 }
182 MS_EXCEPTION_IF_NULL(para);
183 MS_EXCEPTION_IF_NULL(para->func_graph());
184 FuncGraphManagerPtr manager = para->func_graph()->manager();
185 MS_EXCEPTION_IF_NULL(manager);
186 AnfNodeIndexSet node_set = manager->node_users()[para];
187 std::vector<CNodePtr> cnode_list;
188 for (auto &node_pair : node_set) {
189 auto cnode = node_pair.first->cast<CNodePtr>();
190 MS_EXCEPTION_IF_NULL(cnode);
191 if (!IsValueNode<Primitive>(cnode->input(0))) {
192 continue;
193 }
194 auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
195 MS_EXCEPTION_IF_NULL(node_prim);
196 if (node_prim->name() == CAST) {
197 auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1);
198 if (mirror_cnodes.empty()) {
199 MS_LOG(WARNING) << "mirror node after cast not found";
200 continue;
201 }
202 if (mirror_cnodes.size() > 1) {
203 MS_LOG(EXCEPTION) << "mirror node after cast number is not 1";
204 }
205 cnode_list.emplace_back(mirror_cnodes[0]);
206 }
207 if (node_prim->name() == MIRROR_OPERATOR) {
208 cnode_list.emplace_back(cnode);
209 }
210 }
211 return cnode_list;
212 }
213
SetMirrorFusion(const CNodePtr & mirror_cnode,int64_t fusion,const std::string & parameter_name)214 void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string ¶meter_name) {
215 MS_EXCEPTION_IF_NULL(mirror_cnode);
216 MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
217 auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
218 auto old_value_ptr = node_prim->GetAttr(FUSION);
219 if (old_value_ptr != nullptr) {
220 if (old_value_ptr->isa<Int64Imm>()) {
221 int64_t old_value = old_value_ptr->cast<Int64ImmPtr>()->value();
222 if (old_value < fusion) {
223 return;
224 }
225 }
226 }
227 (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared<Int64Imm>(fusion)));
228 (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
229 }
230
FindMirrorAndSetFusion(const AnfNodePtr & para,int64_t fusion)231 Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int64_t fusion) {
232 auto mirror_cnodes = FindMirror(para);
233 if (mirror_cnodes.empty()) {
234 MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
235 return SUCCESS;
236 }
237 if (mirror_cnodes.size() > 2) {
238 for (auto &mirror_cnode_1 : mirror_cnodes) {
239 MS_EXCEPTION_IF_NULL(mirror_cnode_1);
240 MS_LOG(INFO) << mirror_cnode_1->DebugString();
241 }
242 MS_EXCEPTION_IF_NULL(para);
243 MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. " << mirror_cnodes.size()
244 << "Mirror CNode found.";
245 return FAILED;
246 }
247 for (auto &mirror_cnode : mirror_cnodes) {
248 auto parameter_name = ParameterName(para);
249 SetMirrorFusion(mirror_cnode, fusion, parameter_name);
250 }
251 return SUCCESS;
252 }
253
FindMirrorAndSetFusion(const std::vector<AnfNodePtr> & paras,int64_t fusion)254 Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> ¶s, int64_t fusion) {
255 for (auto ¶m_node : paras) {
256 if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
257 MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
258 return FAILED;
259 }
260 }
261 return SUCCESS;
262 }
263
SetFusion(const std::vector<double> & cost_map)264 Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
265 if (cost_map.size() < 2) {
266 MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
267 return FAILED;
268 }
269 int64_t fusion = 1;
270 for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) {
271 auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter);
272 if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
273 MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
274 return FAILED;
275 }
276 fusion++;
277 }
278 return SUCCESS;
279 }
280
GenerateCostMap(int64_t fusion_times,double tail_percent) const281 std::vector<double> AllreduceFusion::GenerateCostMap(int64_t fusion_times, double tail_percent) const {
282 double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1);
283 MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset;
284 std::vector<double> cost_map;
285 double begin = 0;
286 for (auto i = 0; i < fusion_times - 1; i++) {
287 cost_map.push_back(begin);
288 begin += offset;
289 }
290 cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent));
291 cost_map.push_back(allreduce_graph_.max());
292 MS_LOG(DEBUG) << "cost_map = " << cost_map;
293 return cost_map;
294 }
295
SetFusionByBackwardCompTime()296 Status AllreduceFusion::SetFusionByBackwardCompTime() {
297 auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times();
298 if (fusion_times < 2) {
299 MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion";
300 return SUCCESS;
301 }
302 auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent();
303 if (tail_percent < 0 || tail_percent >= 1) {
304 MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent
305 << ". Bypass ProcessAllreduceFusion";
306 return SUCCESS;
307 }
308 const auto cost_map = GenerateCostMap(fusion_times, tail_percent);
309 MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed.";
310 if (SetFusion(cost_map) != SUCCESS) {
311 MS_LOG(ERROR) << "SetFusion failed.";
312 return FAILED;
313 }
314 MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed.";
315 return SUCCESS;
316 }
317
GetSetFusionByBackwardCompAndAllreduceTimeParams()318 Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() {
319 tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time();
320 if (tail_time_ <= 0) {
321 MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion";
322 return FAILED;
323 }
324 allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time();
325 if (allreduce_inherent_time_ <= 0) {
326 MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
327 << ". Bypass ProcessAllreduceFusion";
328 return FAILED;
329 }
330 if (tail_time_ <= allreduce_inherent_time_) {
331 MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_
332 << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
333 << ".tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion";
334 return FAILED;
335 }
336 allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth();
337 if (allreduce_bandwidth_ <= 0) {
338 MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_
339 << ". Bypass ProcessAllreduceFusion";
340 return FAILED;
341 }
342 computation_time_parameter_ =
343 CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter();
344 if (computation_time_parameter_ <= 0) {
345 MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_
346 << ". Bypass ProcessAllreduceFusion";
347 return FAILED;
348 }
349 return SUCCESS;
350 }
351
SetFusionByBackwardCompAndAllreduceTime()352 Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() {
353 if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) {
354 MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!";
355 return FAILED;
356 }
357 allreduce_graph_.SortArnode();
358 if (allreduce_graph_.RemoveExtraParas() != SUCCESS) {
359 MS_LOG(ERROR) << "RemoveExtraParas failed!";
360 return FAILED;
361 }
362 double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_;
363 double to_cost = allreduce_graph_.max();
364 int64_t fusion = 1;
365 while (to_cost != 0) {
366 MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size;
367 auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size);
368 MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second;
369 auto paras = node_cost_pair.first;
370 if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
371 MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
372 return FAILED;
373 }
374 fusion++;
375 para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) /
376 allreduce_bandwidth_;
377 to_cost = node_cost_pair.second;
378 }
379 MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed.";
380 return SUCCESS;
381 }
382
SetFusionByAlgorithm(int64_t algorithm)383 Status AllreduceFusion::SetFusionByAlgorithm(int64_t algorithm) {
384 if (algorithm == 1) {
385 return SetFusionByBackwardCompTime();
386 }
387 return SetFusionByBackwardCompAndAllreduceTime();
388 }
389
ProcessAllreduceFusion(const CNodePtr & ret)390 Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
391 if (ret == nullptr) {
392 MS_LOG(ERROR) << "ret is nullptr.";
393 return FAILED;
394 }
395 auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm();
396 if (algorithm < 1 || algorithm > 2) {
397 MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". Bypass ProcessAllreduceFusion";
398 return SUCCESS;
399 }
400 ret_ = ret;
401 root_graph_ = ret_->func_graph();
402 MS_EXCEPTION_IF_NULL(root_graph_);
403 auto graph_set = ForwardGraph(root_graph_);
404 if (graph_set.size() > 1) {
405 MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now.";
406 return SUCCESS;
407 }
408 auto forward_graph = *(graph_set.begin());
409 MS_EXCEPTION_IF_NULL(forward_graph);
410 forward_ret_ = forward_graph->get_return();
411 MS_EXCEPTION_IF_NULL(forward_ret_);
412
413 if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
414 MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
415 return FAILED;
416 }
417 MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed.";
418 if (AddNodeToGraph() != SUCCESS) {
419 MS_LOG(ERROR) << "AddNodeToGraph failed.";
420 return FAILED;
421 }
422 MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed.";
423 if (AddEdgeToGraph() != SUCCESS) {
424 MS_LOG(ERROR) << "AddNodeToGraph failed.";
425 return FAILED;
426 }
427 MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed.";
428 if (SetFusionByAlgorithm(algorithm) != SUCCESS) {
429 MS_LOG(ERROR) << "SetFusionByAlgorithm failed.";
430 return FAILED;
431 }
432 MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed.";
433 return SUCCESS;
434 }
435 } // namespace parallel
436 } // namespace mindspore
437