/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ir/func_graph.h"
#include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
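// Returns the set of CNodes that consume `para` and are parallel-care operators carrying
// OperatorInfo. Non-qualifying users are traversed recursively; the recursion depth is capped
// at MAX_RECURSIVE_CALL_TIMES.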
std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint64_t recursive_times = 0) {
  if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
    MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
                      << MAX_RECURSIVE_CALL_TIMES;
  }
  MS_EXCEPTION_IF_NULL(para);
  MS_EXCEPTION_IF_NULL(para->func_graph());
  FuncGraphManagerPtr manager = para->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto node_set = manager->node_users()[para];
  std::unordered_set<CNodePtr> cnode_set;
  for (auto &node_pair : node_set) {
    auto cnode = node_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(node_prim);
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
      (void)cnode_set.emplace(cnode);
    } else {
      auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
      for (auto &cnode_sub : cnode_set_sub) {
        (void)cnode_set.emplace(cnode_sub);
      }
    }
  }
  return cnode_set;
}

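// Adds an allreduce-graph node for every (CNode, parameter) pair, covering each root-graph
// parameter that requires gradients and is consumed by at least one parallel-care CNode.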
Status AllreduceFusion::AddNodeToGraph() {
  const auto &parameters = root_graph_->parameters();
  for (auto &parameter : parameters) {
    if (!ParameterRequireGrad(parameter)) {
      continue;
    }
    auto cnode_set = FindCNodesWithPara(parameter);
    if (cnode_set.empty()) {
      continue;
    }
    for (auto &cnode : cnode_set) {
      MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
      if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
        MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

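// Maps each allreduce-graph CNode reachable from `from` (through its inputs) to the forward
// memory cost accumulated over the parallel-care CNodes visited on the way. The search stops
// as soon as a node already registered in allreduce_graph_ is reached.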
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint64_t recursive_times) const {
  if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
    MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
                      << MAX_RECURSIVE_CALL_TIMES;
  }
  MS_EXCEPTION_IF_NULL(from);
  std::unordered_map<CNodePtr, double> cnode_dist;
  if (!from->isa<CNode>()) {
    return cnode_dist;
  }
  auto cnode = from->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return cnode_dist;
  }

  auto operator_info = cnode->user_data<OperatorInfo>();
  MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
                << " operator_info: " << (operator_info != nullptr);

  if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
    auto cost = operator_info->GetForwardMemoryCostFromCNode();
    MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;

    if (allreduce_graph_.NodeInGraph(cnode)) {
      cnode_dist[cnode] = cost;
      return cnode_dist;
    } else {
      auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
      for (auto &ele_next : cnode_dist_next) {
        cnode_dist[ele_next.first] = cost + ele_next.second;
      }
    }
  } else {
    auto cnode_dist_next = FindNextCNodes(cnode);
    for (auto &ele : cnode_dist_next) {
      cnode_dist[ele.first] = ele.second;
    }
  }
  return cnode_dist;
}

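// Applies FindCNode to every input of `from` and merges the per-input cost maps into one map.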
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint64_t recursive_times) const {
  if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
    MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
                      << MAX_RECURSIVE_CALL_TIMES;
  }
  const auto &from_inputs = from->inputs();
  std::unordered_map<CNodePtr, double> dist_map;
  MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
  for (auto &input_node : from_inputs) {
    auto cnode_dist = FindCNode(input_node, recursive_times + 1);
    for (auto &ele : cnode_dist) {
      (void)dist_map.emplace(ele);
    }
  }
  return dist_map;
}

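// Breadth-first traversal from the head CNode: for every visited node, each (successor, dist)
// pair returned by FindNextCNodes becomes an edge in allreduce_graph_. cnode_state_map tracks
// unvisited (0), queued (1) and processed (2) nodes.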
Status AllreduceFusion::AddEdgeToGraph() {
  std::unordered_map<CNodePtr, int64_t> cnode_state_map;
  const auto &cnodes = allreduce_graph_.cnode_set();
  for (auto &cnode : cnodes) {
    cnode_state_map[cnode] = 0;
  }
  const auto &head_cnode = allreduce_graph_.head_cnode();
  std::queue<CNodePtr> cnode_queue;
  cnode_queue.emplace(head_cnode);
  cnode_state_map[head_cnode] = 1;

  while (!cnode_queue.empty()) {
    const auto cur_cnode = cnode_queue.front();
    cnode_queue.pop();
    cnode_state_map[cur_cnode] = 2;
    auto next = FindNextCNodes(cur_cnode);
    for (auto &ele : next) {
      auto &cnode = ele.first;
      auto &dist = ele.second;
      if (cnode_state_map[cnode] == 0) {
        cnode_queue.emplace(cnode);
        cnode_state_map[cnode] = 1;
      }
      if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) {
        MS_LOG(ERROR) << "AddEdge error";
        return FAILED;
      }
      MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist;
    }
  }
  return SUCCESS;
}

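// Collects the mirror operator CNodes attached to `para`. A Cast user is looked through
// recursively and is expected to lead to exactly one mirror node.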
std::vector<CNodePtr> FindMirror(const AnfNodePtr &para, uint64_t recursive_times = 0) {
  if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
    MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
                      << MAX_RECURSIVE_CALL_TIMES;
  }
  MS_EXCEPTION_IF_NULL(para);
  MS_EXCEPTION_IF_NULL(para->func_graph());
  FuncGraphManagerPtr manager = para->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[para];
  std::vector<CNodePtr> cnode_list;
  for (auto &node_pair : node_set) {
    auto cnode = node_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto node_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(node_prim);
    if (node_prim->name() == CAST) {
      auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1);
      if (mirror_cnodes.empty()) {
        MS_LOG(WARNING) << "mirror node after cast not found";
        continue;
      }
      if (mirror_cnodes.size() > 1) {
        MS_LOG(EXCEPTION) << "mirror node after cast number is not 1";
      }
      cnode_list.emplace_back(mirror_cnodes[0]);
    }
    if (node_prim->name() == MIRROR_OPERATOR) {
      cnode_list.emplace_back(cnode);
    }
  }
  return cnode_list;
}

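// Tags a mirror CNode with the fusion group index and the owning parameter name. An existing
// smaller fusion index is kept, so a mirror is never moved to a later group.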
void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string &parameter_name) {
  MS_EXCEPTION_IF_NULL(mirror_cnode);
  MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
  auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
  auto old_value_ptr = node_prim->GetAttr(FUSION);
  if (old_value_ptr != nullptr) {
    if (old_value_ptr->isa<Int64Imm>()) {
      int64_t old_value = old_value_ptr->cast<Int64ImmPtr>()->value();
      if (old_value < fusion) {
        return;
      }
    }
  }
  (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared<Int64Imm>(fusion)));
  (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
}

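// Finds the mirror CNodes of one parameter and assigns them the given fusion index. Zero
// mirrors is tolerated with a warning; more than two is treated as an error.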
Status FindMirrorAndSetFusion(const AnfNodePtr &para, int64_t fusion) {
  auto mirror_cnodes = FindMirror(para);
  if (mirror_cnodes.empty()) {
    MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
    return SUCCESS;
  }
  if (mirror_cnodes.size() > 2) {
    for (auto &mirror_cnode_1 : mirror_cnodes) {
      MS_EXCEPTION_IF_NULL(mirror_cnode_1);
      MS_LOG(INFO) << mirror_cnode_1->DebugString();
    }
    MS_EXCEPTION_IF_NULL(para);
    MS_LOG(ERROR) << para->ToString() << ": FindMirror found more than 2 Mirror CNodes (" << mirror_cnodes.size()
                  << " found).";
    return FAILED;
  }
  for (auto &mirror_cnode : mirror_cnodes) {
    auto parameter_name = ParameterName(para);
    SetMirrorFusion(mirror_cnode, fusion, parameter_name);
  }
  return SUCCESS;
}

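// Vector overload: applies FindMirrorAndSetFusion to every parameter with the same fusion index.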
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> &paras, int64_t fusion) {
  for (auto &param_node : paras) {
    if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
      MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
      return FAILED;
    }
  }
  return SUCCESS;
}

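// Walks the cost boundaries from the tail of cost_map toward the head; each adjacent pair of
// boundaries selects a group of parameters via GetParaByCost, which receives an increasing
// fusion index starting from 1.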
Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
  if (cost_map.size() < 2) {
    MS_LOG(ERROR) << "cost_map must have at least 2 items, cost_map size is " << cost_map.size();
    return FAILED;
  }
  int64_t fusion = 1;
  for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) {
    auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter);
    if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
      MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
      return FAILED;
    }
    fusion++;
  }
  return SUCCESS;
}

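// Builds fusion_times + 1 cost boundaries over [0, allreduce_graph_.max()]: the last interval
// covers the tail_percent fraction of the total cost, and the remaining cost is divided evenly
// among the first fusion_times - 1 intervals.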
std::vector<double> AllreduceFusion::GenerateCostMap(int64_t fusion_times, double tail_percent) const {
  double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1);
  MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset;
  std::vector<double> cost_map;
  double begin = 0;
  for (auto i = 0; i < fusion_times - 1; i++) {
    cost_map.push_back(begin);
    begin += offset;
  }
  cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent));
  cost_map.push_back(allreduce_graph_.max());
  MS_LOG(DEBUG) << "cost_map = " << cost_map;
  return cost_map;
}

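// Fusion algorithm 1: builds the cost boundaries from the configured fusion times and tail
// percent, then assigns fusion indices with SetFusion. Invalid configuration values bypass
// the fusion.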
Status AllreduceFusion::SetFusionByBackwardCompTime() {
  auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times();
  if (fusion_times < 2) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion";
    return SUCCESS;
  }
  auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent();
  if (tail_percent < 0 || tail_percent >= 1) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent
                 << ". Bypass ProcessAllreduceFusion";
    return SUCCESS;
  }
  const auto cost_map = GenerateCostMap(fusion_times, tail_percent);
  MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed.";
  if (SetFusion(cost_map) != SUCCESS) {
    MS_LOG(ERROR) << "SetFusion failed.";
    return FAILED;
  }
  MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed.";
  return SUCCESS;
}

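// Reads the cost-model parameters used by fusion algorithm 2 (tail time, allreduce inherent
// time, allreduce bandwidth, computation time parameter) and rejects non-positive or
// inconsistent values.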
Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() {
  tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time();
  if (tail_time_ <= 0) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion";
    return FAILED;
  }
  allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time();
  if (allreduce_inherent_time_ <= 0) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
                 << ". Bypass ProcessAllreduceFusion";
    return FAILED;
  }
  if (tail_time_ <= allreduce_inherent_time_) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_time' is " << tail_time_
                 << ", 'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_
                 << ". tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion";
    return FAILED;
  }
  allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth();
  if (allreduce_bandwidth_ <= 0) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_
                 << ". Bypass ProcessAllreduceFusion";
    return FAILED;
  }
  computation_time_parameter_ =
    CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter();
  if (computation_time_parameter_ <= 0) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_
                 << ". Bypass ProcessAllreduceFusion";
    return FAILED;
  }
  return SUCCESS;
}

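// Fusion algorithm 2: para_size is, presumably, the parameter size the current time budget can
// cover. Starting from the full cost range, it repeatedly asks GetParaByParaSize for a group of
// parameters, assigns that group the next fusion index, recomputes para_size from the remaining
// backward computation time, and continues until to_cost reaches 0.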
Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() {
  if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) {
    MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!";
    return FAILED;
  }
  allreduce_graph_.SortArnode();
  if (allreduce_graph_.RemoveExtraParas() != SUCCESS) {
    MS_LOG(ERROR) << "RemoveExtraParas failed!";
    return FAILED;
  }
  double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_;
  double to_cost = allreduce_graph_.max();
  int64_t fusion = 1;
  while (to_cost != 0) {
    MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size;
    auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size);
    MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second;
    auto paras = node_cost_pair.first;
    if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) {
      MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
      return FAILED;
    }
    fusion++;
    para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) /
                allreduce_bandwidth_;
    to_cost = node_cost_pair.second;
  }
  MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed.";
  return SUCCESS;
}

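// Dispatches to algorithm 1 (backward computation time only) or algorithm 2 (backward
// computation plus allreduce time).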
Status AllreduceFusion::SetFusionByAlgorithm(int64_t algorithm) {
  if (algorithm == 1) {
    return SetFusionByBackwardCompTime();
  }
  return SetFusionByBackwardCompAndAllreduceTime();
}

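// Entry point: validates the configured algorithm, builds the allreduce graph from the forward
// graph (head CNode, nodes, edges) and then sets the mirror fusion indices.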
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
  if (ret == nullptr) {
    MS_LOG(ERROR) << "ret is nullptr.";
    return FAILED;
  }
  auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm();
  if (algorithm < 1 || algorithm > 2) {
    MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". Bypass ProcessAllreduceFusion";
    return SUCCESS;
  }
  ret_ = ret;
  root_graph_ = ret_->func_graph();
  MS_EXCEPTION_IF_NULL(root_graph_);
  auto graph_set = ForwardGraph(root_graph_);
  if (graph_set.size() > 1) {
    MS_LOG(WARNING) << "AllReduce fusion does not support multiple subgraphs now.";
    return SUCCESS;
  }
  auto forward_graph = *(graph_set.begin());
  MS_EXCEPTION_IF_NULL(forward_graph);
  forward_ret_ = forward_graph->get_return();
  MS_EXCEPTION_IF_NULL(forward_ret_);

  if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
    MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed.";
    return FAILED;
  }
  MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed.";
  if (AddNodeToGraph() != SUCCESS) {
    MS_LOG(ERROR) << "AddNodeToGraph failed.";
    return FAILED;
  }
  MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed.";
  if (AddEdgeToGraph() != SUCCESS) {
    MS_LOG(ERROR) << "AddEdgeToGraph failed.";
    return FAILED;
  }
  MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed.";
  if (SetFusionByAlgorithm(algorithm) != SUCCESS) {
    MS_LOG(ERROR) << "SetFusionByAlgorithm failed.";
    return FAILED;
  }
  MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore