• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <iterator>
18 #include <memory>
19 #include <list>
20 #include <set>
21 #include <algorithm>
22 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
23 #include "frontend/parallel/graph_util/generate_graph.h"
24 #include "base/core_ops.h"
25 #include "ir/value.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "frontend/parallel/device_manager.h"
28 #include "frontend/parallel/context.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/graph_util/node_info.h"
31 #include "utils/parallel_node_check.h"
32 
33 namespace mindspore {
34 namespace parallel {
35 const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
36   prim::kPrimDepend,    prim::kPrimTupleGetItem, prim::kPrimAdd,    prim::kPrimSoftmaxCrossEntropyWithLogits,
37   prim::kPrimMakeTuple, prim::kPrimUpdateState,  prim::kPrimReshape};
38 
IsInEndNodeBlackList(const CNodePtr & cnode)39 static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
40   MS_EXCEPTION_IF_NULL(cnode);
41   if (!IsValueNode<Primitive>(cnode->input(0))) {
42     return true;
43   }
44   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
45   if (IsInParallelBlackList(prim)) {
46     return true;
47   }
48   for (auto &prim_node : END_NODE_BLACK_LIST) {
49     if (IsPrimitiveCNode(cnode, prim_node)) {
50       return true;
51     }
52   }
53   return false;
54 }
55 
FindAccuGrad(const CNodePtr & cnode)56 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
57   auto pre_node = cnode->input(1);
58   size_t depth = 0;
59   while (true) {
60     if (depth > MAX_RECURSIVE_DEPTH) {
61       return nullptr;
62     }
63     depth += 1;
64     if (pre_node->isa<Parameter>()) {
65       return pre_node;
66     } else {
67       if (pre_node->isa<CNode>()) {
68         auto pre_cnode = pre_node->cast<CNodePtr>();
69         pre_node = pre_cnode->input(1);
70       } else {
71         return nullptr;
72       }
73     }
74   }
75   return nullptr;
76 }
77 
IsLastStage()78 bool IsLastStage() {
79   MS_EXCEPTION_IF_NULL(g_device_manager);
80   auto stage_num = g_device_manager->stage_num();
81   auto stage_id = g_device_manager->stage_id();
82   return ((stage_num - 1) == stage_id);
83 }
84 
SetStridedSliceStrategy(const AnfNodePtr & node)85 void SetStridedSliceStrategy(const AnfNodePtr &node) {
86   MS_EXCEPTION_IF_NULL(node);
87   if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
88     return;
89   }
90   auto cnode = node->cast<CNodePtr>();
91   MS_EXCEPTION_IF_NULL(cnode);
92   std::vector<Shapes> shape_list = ExtractShape(cnode);
93   if (shape_list.empty()) {
94     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " failed to extract shape";
95   }
96   std::vector<ValuePtr> elements;
97   for (size_t i = 0; i < shape_list[0].size(); i++) {
98     if (shape_list[0][i].empty()) {
99       MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
100     }
101     Dimensions input_strategy;
102     for (size_t j = 0; j < shape_list[0][i].size(); j++) {
103       input_strategy.push_back(1);
104     }
105     elements.push_back(MakeValue(input_strategy));
106   }
107   ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
108   cnode->AddPrimalAttr(STRATEGY, strategy);
109 }
110 
InsertVirtualAssignAdd(const std::pair<AnfNodePtr,int> & node_user,const FuncGraphManagerPtr & manager,const AnfNodePtr & accu_parameter)111 void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
112                             const AnfNodePtr &accu_parameter) {
113   auto cnode = node_user.first->cast<CNodePtr>();
114   if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
115     return;
116   }
117   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
118   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
119   if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
120     return;
121   }
122   auto prim = GetCNodePrimitive(cnode);
123   if (prim == nullptr) {
124     MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
125     return;
126   }
127   OperatorAttrs attrs;
128   auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
129   auto value_node = NewValueNode(py_instance);
130   std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
131   auto graph = cnode->func_graph();
132   auto virtual_node = graph->NewCNode(virtual_node_input);
133   manager->SetEdge(cnode, node_user.second, virtual_node);
134 }
135 
InsertVirtualAccuGrad(const AnfNodePtr & recv,const FuncGraphManagerPtr & manager,const AnfNodePtr & param)136 void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param) {
137   auto cnode = recv->cast<CNodePtr>();
138   MS_EXCEPTION_IF_NULL(cnode);
139   OperatorAttrs attrs;
140   auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
141   auto value_node = NewValueNode(py_instance);
142   std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
143   auto graph = cnode->func_graph();
144   MS_EXCEPTION_IF_NULL(graph);
145   auto virtual_node = graph->NewCNode(virtual_node_input);
146   (void)manager->Replace(recv, virtual_node);
147 }
148 
FindGradAccuParameter(const std::vector<AnfNodePtr> & parameters,const std::string & name)149 AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
150   for (auto &parameter : parameters) {
151     auto param_ptr = parameter->cast<ParameterPtr>();
152     MS_EXCEPTION_IF_NULL(param_ptr);
153     if (param_ptr->name() == name) {
154       continue;
155     }
156     auto expect_name = "accu_grads." + name;
157     if (param_ptr->name() == expect_name) {
158       return parameter;
159     }
160   }
161   return nullptr;
162 }
163 
HandleReceiveParam(const FuncGraphPtr & root,const std::vector<AnfNodePtr> & all_nodes)164 void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
165   auto parameters = root->parameters();
166   auto node_users_map = root->manager()->node_users();
167   for (auto &node : all_nodes) {
168     if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
169       continue;
170     }
171     auto cnode = node->cast<CNodePtr>();
172     if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
173       continue;
174     }
175     auto parameter_ptr = cnode->input(1)->cast<ParameterPtr>();
176     MS_EXCEPTION_IF_NULL(parameter_ptr);
177     auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
178     if (!accu_parameter) {
179       continue;
180     }
181     auto node_users = node_users_map[node];
182     for (auto &temp_user : node_users) {
183       auto temp_node = temp_user.first;
184       // Micro virtual operator might be inserted after cast
185       if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
186         temp_node = node_users_map[temp_node].begin()->first;
187       }
188       if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
189           IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
190         auto node_set = node_users_map[temp_node];
191         for (auto &node_user : node_set) {
192           InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
193         }
194       } else {
195         InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
196       }
197     }
198     InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
199   }
200 }
201 
AddVirtualAssignAdd(const FuncGraphPtr & root)202 void AddVirtualAssignAdd(const FuncGraphPtr &root) {
203   auto parameters = root->parameters();
204   auto node_users_map = root->manager()->node_users();
205   for (auto &parameter : parameters) {
206     auto parameter_ptr = parameter->cast<ParameterPtr>();
207     auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
208     if (!accu_parameter) {
209       continue;
210     }
211     auto node_users = node_users_map[parameter];
212     for (auto &temp_user : node_users) {
213       auto temp_node = temp_user.first;
214       // Micro virtual operator might be inserted after cast
215       if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
216         temp_node = node_users_map[temp_node].begin()->first;
217       }
218       if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
219           IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
220         auto node_set = node_users_map[temp_node];
221         for (auto &node_user : node_set) {
222           InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
223         }
224       } else {
225         InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
226       }
227     }
228   }
229 }
230 
CompFunc(const AnfNodePtr & node1,const AnfNodePtr & node2)231 bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
232   MS_EXCEPTION_IF_NULL(node1);
233   MS_EXCEPTION_IF_NULL(node2);
234   auto cnode1 = node1->cast<CNodePtr>();
235   auto cnode2 = node2->cast<CNodePtr>();
236   MS_EXCEPTION_IF_NULL(cnode1);
237   MS_EXCEPTION_IF_NULL(cnode2);
238   auto micro1 = cnode1->GetPrimalAttr(MICRO);
239   auto micro2 = cnode2->GetPrimalAttr(MICRO);
240   MS_EXCEPTION_IF_NULL(micro1);
241   MS_EXCEPTION_IF_NULL(micro2);
242   auto micro1_value = GetValue<int64_t>(micro1);
243   auto micro2_value = GetValue<int64_t>(micro2);
244   if (micro1_value == micro2_value) {
245     auto prim1 = GetCNodePrimitive(cnode1);
246     auto prim2 = GetCNodePrimitive(cnode2);
247     MS_EXCEPTION_IF_NULL(prim1);
248     MS_EXCEPTION_IF_NULL(prim2);
249     auto rank_tag1 = prim1->GetAttr(SRC_RANK);
250     auto rank_tag2 = prim2->GetAttr(SRC_RANK);
251     if (rank_tag1 == nullptr) {
252       rank_tag1 = prim1->GetAttr(DEST_RANK);
253     }
254     if (rank_tag2 == nullptr) {
255       rank_tag2 = prim2->GetAttr(DEST_RANK);
256     }
257     if (!rank_tag1 || !rank_tag2) {
258       return false;
259     }
260     auto rank1_value = GetValue<int64_t>(rank_tag1);
261     auto rank2_value = GetValue<int64_t>(rank_tag2);
262     if (rank1_value == rank2_value) {
263       auto sr_tag1 = prim1->GetAttr(SR_TAG);
264       auto sr_tag2 = prim2->GetAttr(SR_TAG);
265       MS_EXCEPTION_IF_NULL(sr_tag1);
266       MS_EXCEPTION_IF_NULL(sr_tag2);
267       auto sr1_value = GetValue<int64_t>(sr_tag1);
268       auto sr2_value = GetValue<int64_t>(sr_tag2);
269       return sr1_value < sr2_value;
270     }
271     return rank1_value < rank2_value;
272   }
273   return micro1_value < micro2_value;
274 }
275 
InsertDepend(const AnfNodePtr & prior_node,const AnfNodePtr & post_node,const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)276 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
277                   const FuncGraphPtr &root) {
278   MS_EXCEPTION_IF_NULL(prior_node);
279   MS_EXCEPTION_IF_NULL(post_node);
280   auto post_cnode = post_node->cast<CNodePtr>();
281   MS_EXCEPTION_IF_NULL(post_cnode);
282   std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node};
283   auto depend_node = root->NewCNode(depend_input);
284   manager->SetEdge(post_node, 1, depend_node);
285 }
286 
ReorderForForward(const std::vector<AnfNodePtr> & forward_start,const std::vector<AnfNodePtr> & forward_end,const FuncGraphPtr & root)287 void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
288                        const FuncGraphPtr &root) {
289   MS_EXCEPTION_IF_NULL(g_device_manager);
290   MS_EXCEPTION_IF_NULL(root);
291   auto manager = root->manager();
292   MS_EXCEPTION_IF_NULL(manager);
293   auto stage_num = g_device_manager->stage_num();
294   auto stage_id = g_device_manager->stage_id();
295   for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) {
296     auto prior_node = forward_end[i - 1];
297     auto post_node = forward_start[i];
298     InsertDepend(prior_node, post_node, manager, root);
299   }
300 }
301 
ReorderForBackward(const PipelinePair & forward_start_pair,const PipelinePair & forward_end_pair,const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,const PipelinePair & forward_end_before_pair,const FuncGraphPtr & root)302 void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
303                         const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
304                         const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) {
305   MS_EXCEPTION_IF_NULL(g_device_manager);
306   MS_EXCEPTION_IF_NULL(root);
307   auto manager = root->manager();
308   MS_EXCEPTION_IF_NULL(manager);
309   auto stage_num = g_device_manager->stage_num();
310   auto stage_id = g_device_manager->stage_id();
311   for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) {
312     auto prior_node1 = forward_end_before_pair.second[i];
313     auto post_node1 = backward_start_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id + 1)];
314     InsertDepend(prior_node1, post_node1, manager, root);
315     auto prior_node2 = backward_end_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
316     auto post_node2 = forward_start_pair.first[i];
317     InsertDepend(prior_node2, post_node2, manager, root);
318   }
319   for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) {
320     if (!IsLastStage()) {
321       auto prior_node3 = backward_start_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
322       auto post_node3 = forward_end_pair.first[i - 1];
323       InsertDepend(prior_node3, post_node3, manager, root);
324       auto prior_node4 = forward_end_pair.second[i - 1];
325       auto post_node4 = backward_end_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id)];
326       InsertDepend(prior_node4, post_node4, manager, root);
327     }
328   }
329   for (size_t j = LongToSize(SizeToLong(backward_start_pair.first.size()) - stage_num + stage_id + 1);
330        j < backward_start_pair.first.size(); ++j) {
331     auto prior_node5 = backward_end_pair.second[j - 1];
332     auto post_node5 = backward_start_pair.first[j];
333     InsertDepend(prior_node5, post_node5, manager, root);
334   }
335   if (!IsLastStage()) {
336     auto prior_node6 = forward_end_before_pair.second[LongToSize(stage_num - 1 - stage_id)];
337     auto post_node6 = backward_start_pair.first[0];
338     InsertDepend(prior_node6, post_node6, manager, root);
339   }
340 }
341 
ReorderForParams(const std::vector<AnfNodePtr> & backward_params,const std::vector<AnfNodePtr> & forward_params,const std::vector<AnfNodePtr> & allreduce_params,const PipelinePair & forward_params_pair,const PipelinePair & backward_params_pair,const std::vector<AnfNodePtr> & backward_end,const PipelinePair & forward_start_pair,const FuncGraphPtr & root)342 void ReorderForParams(const std::vector<AnfNodePtr> &backward_params, const std::vector<AnfNodePtr> &forward_params,
343                       const std::vector<AnfNodePtr> &allreduce_params, const PipelinePair &forward_params_pair,
344                       const PipelinePair &backward_params_pair, const std::vector<AnfNodePtr> &backward_end,
345                       const PipelinePair &forward_start_pair, const FuncGraphPtr &root) {
346   auto manager = root->manager();
347   MS_EXCEPTION_IF_NULL(manager);
348   if (!forward_params.empty()) {
349     auto prior_node = forward_params_pair.second[0];
350     auto post_node = forward_start_pair.first[0];
351     InsertDepend(prior_node, post_node, manager, root);
352   }
353   if (!backward_params.empty()) {
354     if (!allreduce_params.empty()) {
355       for (auto &node : allreduce_params) {
356         auto post_node1 = backward_params_pair.first[0];
357         InsertDepend(node, post_node1, manager, root);
358       }
359     }
360     auto prior_node2 = backward_end.back();
361     auto post_node2 = backward_params[0];
362     InsertDepend(prior_node2, post_node2, manager, root);
363   }
364 }
365 
GetMicroBatch(const AnfNodePtr & node)366 int64_t GetMicroBatch(const AnfNodePtr &node) {
367   MS_EXCEPTION_IF_NULL(node);
368   auto cnode = node->cast<CNodePtr>();
369   MS_EXCEPTION_IF_NULL(cnode);
370   auto micro_value = cnode->GetPrimalAttr(MICRO);
371   MS_EXCEPTION_IF_NULL(micro_value);
372   return GetValue<int64_t>(micro_value);
373 }
374 
Deduplicate(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max)375 PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max) {
376   std::vector<AnfNodePtr> temp_vec;
377   std::vector<AnfNodePtr> out_vec_begin;
378   std::vector<AnfNodePtr> out_vec_end;
379   auto manager = root->manager();
380   for (int64_t i = 0; i <= micro_max; ++i) {
381     temp_vec.clear();
382     if (!root->has_flag(TRAINING)) {
383       temp_vec = node_vector;
384     } else {
385       for (auto &node : node_vector) {
386         auto node_micro = GetMicroBatch(node);
387         if (node_micro == i) {
388           temp_vec.push_back(node);
389         }
390       }
391     }
392     if (temp_vec.size() <= 1) {
393       MS_LOG(INFO) << "No Duplicate MicroBatch.";
394       continue;
395     }
396     std::sort(temp_vec.begin(), temp_vec.end(), CompFunc);
397     for (size_t j = 0; j < temp_vec.size() - 1; ++j) {
398       auto prior_node = temp_vec[j];
399       auto post_node = temp_vec[j + 1];
400       InsertDepend(prior_node, post_node, manager, root);
401     }
402     if (!temp_vec.empty()) {
403       out_vec_begin.push_back(temp_vec.front());
404       out_vec_end.push_back(temp_vec.back());
405     }
406   }
407   if (out_vec_begin.empty()) {
408     return std::make_pair(node_vector, node_vector);
409   }
410   return std::make_pair(out_vec_begin, out_vec_end);
411 }
412 
BroadCastMicroBatch(const CNodePtr & node,NodeUsersMap * node_users_map,const ValuePtr & value,size_t max_depth)413 void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth) {
414   auto node_users = (*node_users_map)[node];
415   if (max_depth > MAX_RECURSIVE_DEPTH) {
416     MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
417   }
418   for (auto &node_pair : node_users) {
419     auto user_node = node_pair.first->cast<CNodePtr>();
420     if (user_node->HasPrimalAttr(MICRO)) {
421       continue;
422     }
423     user_node->AddPrimalAttr(MICRO, value);
424     BroadCastMicroBatch(user_node, node_users_map, value, max_depth + 1);
425   }
426 }
427 
BroadCastNeedGrad(const AnfNodePtr & node,NodeUsersMap * node_user_map,const FuncGraphPtr & root)428 void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root) {
429   auto node_users = (*node_user_map)[node];
430   for (auto &node_user : node_users) {
431     auto cnode = node_user.first->cast<CNodePtr>();
432     MS_EXCEPTION_IF_NULL(cnode);
433     if (cnode->HasPrimalAttr(NEED_GRAD)) {
434       continue;
435     }
436     if (cnode->func_graph() == root) {
437       continue;
438     }
439     cnode->AddPrimalAttr(NEED_GRAD, MakeValue(1));
440     BroadCastNeedGrad(cnode, node_user_map, root);
441   }
442 }
443 
444 // Label node that need backpropagation
LabelNeedGrad(const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)445 void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root) {
446   auto parameters = root->parameters();
447   auto node_user_map = manager->node_users();
448   for (auto &parameter : parameters) {
449     if (!ParameterRequireGrad(parameter)) {
450       continue;
451     }
452     auto param_ptr = parameter->cast<ParameterPtr>();
453     MS_EXCEPTION_IF_NULL(param_ptr);
454     if (param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
455       continue;
456     }
457     BroadCastNeedGrad(parameter, &node_user_map, root);
458   }
459 }
460 
// Breadth-first search from `node` through the data inputs (inputs 1..n) of each CNode for the first CNode that
// is not rejected by IsInEndNodeBlackList() and carries the NEED_GRAD primal attribute.
// Returns that CNode. Raises when `node` is not a CNode or when no such node is found.
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> node_queue = {node};
  // Advance a cursor instead of erasing the front element: same visiting order without the O(n) erase per step.
  for (size_t head = 0; head < node_queue.size(); ++head) {
    if (node_queue[head] == nullptr) {
      continue;
    }
    // Take a copy of the pointer: the insert below may reallocate node_queue.
    auto cur_node = node_queue[head]->cast<CNodePtr>();
    if (!cur_node) {
      continue;
    }
    if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
      MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
      return cur_node;
    }
    const auto &cur_inputs = cur_node->inputs();
    if (cur_inputs.empty()) {
      continue;
    }
    (void)node_queue.insert(node_queue.end(), cur_inputs.begin() + 1, cur_inputs.end());
  }
  MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
}
480 
LastStageEndNode(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)481 void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
482                       const FuncGraphPtr &root) {
483   if (!IsLastStage()) {
484     return;
485   }
486   LabelNeedGrad(manager, root);
487   for (auto &node : all_nodes) {
488     if (!node->isa<CNode>()) {
489       continue;
490     }
491     auto cnode = node->cast<CNodePtr>();
492     if (!cnode->HasPrimalAttr(MICRO)) {
493       continue;
494     }
495     auto prim = GetCNodePrimitive(node);
496     if (prim && prim->HasAttr(PIPELINE_END)) {
497       for (auto &temp_node : cnode->inputs()) {
498         if (!temp_node->isa<CNode>()) {
499           continue;
500         }
501         auto temp_prim = GetCNodePrimitive(temp_node);
502         if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
503           continue;
504         }
505         auto end_node = GetPreNode(temp_node);
506         MS_EXCEPTION_IF_NULL(end_node);
507         auto end_cnode = end_node->cast<CNodePtr>();
508         MS_EXCEPTION_IF_NULL(end_cnode);
509         auto end_prim = GetCNodePrimitive(end_node);
510         OperatorAttrs attrs_;
511         auto op = CreatOpInstance(attrs_, end_prim->name(), "");
512         auto value_node = NewValueNode(op);
513         auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
514         (void)new_prim->SetAttrs(end_prim->attrs());
515         manager->SetEdge(end_node, 0, value_node);
516         end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
517       }
518     }
519   }
520 }
521 
Micro(const CNodePtr & cnode,NodeUsersMap * node_users_map,size_t max_depth)522 ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map, size_t max_depth) {
523   if (max_depth > MAX_RECURSIVE_DEPTH) {
524     MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
525   }
526   if (cnode->HasPrimalAttr(MICRO)) {
527     return cnode->GetPrimalAttr(MICRO);
528   }
529   auto node_users = (*node_users_map)[cnode];
530   for (auto &node_pair : node_users) {
531     auto user_node = node_pair.first->cast<CNodePtr>();
532     auto micro = Micro(user_node, node_users_map, max_depth + 1);
533     if (micro) {
534       return micro;
535     }
536   }
537   return nullptr;
538 }
539 
ParameterStartNode(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)540 void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
541   auto node_users_map = manager->node_users();
542   for (auto &node : all_nodes) {
543     if (!node->isa<CNode>()) {
544       continue;
545     }
546     auto cnode = node->cast<CNodePtr>();
547     auto prim = GetCNodePrimitive(node);
548     if (prim && prim->HasAttr(PARAMETER_START)) {
549       auto micro = Micro(cnode, &node_users_map, 0);
550       cnode->AddPrimalAttr(MICRO, micro);
551       cnode->AddPrimalAttr(PARAMETER_START, micro);
552     }
553   }
554 }
555 
HandleMicroBatch(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)556 void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
557   auto node_users_map = manager->node_users();
558   for (auto &node : all_nodes) {
559     if (!node->isa<CNode>()) {
560       continue;
561     }
562     auto cnode = node->cast<CNodePtr>();
563     if (!cnode->HasPrimalAttr(MICRO)) {
564       continue;
565     }
566     auto micro = cnode->GetPrimalAttr(MICRO);
567     MS_EXCEPTION_IF_NULL(micro);
568     BroadCastMicroBatch(cnode, &node_users_map, micro, 0);
569   }
570 }
571 
// Look through a single Depend node: returns input 1 of `node` when it is a Depend CNode, otherwise `node`.
AnfNodePtr GetActualOp(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
    return node;
  }
  auto depend_cnode = node->cast<CNodePtr>();
  return depend_cnode->input(1);
}
579 
GetBorderNode(std::vector<AnfNodePtr> * forward_start,std::vector<AnfNodePtr> * forward_end,std::vector<AnfNodePtr> * backward_start,std::vector<AnfNodePtr> * backward_end,std::vector<AnfNodePtr> * forward_params,std::vector<AnfNodePtr> * backward_params,std::vector<AnfNodePtr> * allreduce_params,const FuncGraphPtr & root)580 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
581                    std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
582                    std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
583                    std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root) {
584   std::list<ValuePtr> name_list = {};
585   auto stage_id = g_device_manager->stage_id();
586   for (auto &node : root->nodes()) {
587     if (!node->isa<CNode>()) {
588       continue;
589     }
590     if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimZerosLike)) {
591       continue;
592     }
593     auto prim = GetCNodePrimitive(node);
594     auto cnode = node->cast<CNodePtr>();
595     if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
596       auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName);
597       if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) {
598         continue;
599       }
600       name_list.push_back(forward_node_name);
601       if (cnode->HasPrimalAttr(PIPELINE_END)) {
602         backward_start->push_back(node);
603       }
604       if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
605         backward_end->push_back(node);
606       }
607       if (cnode->HasPrimalAttr(PARAMETER_START)) {
608         backward_end->push_back(node);
609       }
610       if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
611         backward_params->push_back(node);
612       }
613       if (prim->HasAttr(PARAMETER_MICRO)) {
614         allreduce_params->push_back(node);
615       }
616     } else {
617       if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
618         if (stage_id != 0 && IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
619           continue;
620         }
621         forward_start->push_back(node);
622       }
623       if (cnode->HasPrimalAttr(PIPELINE_END)) {
624         forward_end->push_back(node);
625       }
626       if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
627         forward_params->push_back(node);
628       }
629     }
630   }
631   std::sort((*backward_start).begin(), (*backward_start).end(), CompFunc);
632   std::sort((*backward_end).begin(), (*backward_end).end(), CompFunc);
633   std::sort((*forward_start).begin(), (*forward_start).end(), CompFunc);
634   std::sort((*forward_end).begin(), (*forward_end).end(), CompFunc);
635   std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc);
636   std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc);
637 }
638 
CheckBorderNode(const PipelinePair & forward_start_pair,const PipelinePair & forward_end_pair,const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,size_t micro_size)639 void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
640                      const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
641                      size_t micro_size) {
642   micro_size = micro_size + 1;
643   if (forward_start_pair.first.size() != micro_size) {
644     MS_LOG(EXCEPTION) << "forward_node's size:" << forward_start_pair.first.size()
645                       << "is not equal to micro size:" << micro_size;
646   }
647   if (forward_end_pair.first.size() != micro_size) {
648     MS_LOG(EXCEPTION) << "forward_node's size:" << forward_end_pair.first.size()
649                       << "is not equal to micro size:" << micro_size;
650   }
651   if (backward_start_pair.first.size() != micro_size) {
652     MS_LOG(EXCEPTION) << "backward_node's size:" << backward_start_pair.first.size()
653                       << "is not equal to micro size:" << micro_size;
654   }
655   if (backward_end_pair.first.size() != micro_size) {
656     MS_LOG(EXCEPTION) << "backward_node's size:" << backward_end_pair.first.size()
657                       << "is not equal to micro size:" << micro_size;
658   }
659 }
660 
Reorder(const FuncGraphPtr & root)661 void Reorder(const FuncGraphPtr &root) {
662   std::vector<AnfNodePtr> forward_start;
663   std::vector<AnfNodePtr> forward_end;
664   std::vector<AnfNodePtr> forward_params;
665   std::vector<AnfNodePtr> backward_start;
666   std::vector<AnfNodePtr> backward_end;
667   std::vector<AnfNodePtr> backward_params;
668   std::vector<AnfNodePtr> allreduce_params;
669   GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
670                 &allreduce_params, root);
671   int64_t micro_max = 0;
672   if (root->has_flag(TRAINING)) {
673     auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
674     auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
675     MS_EXCEPTION_IF_NULL(micro_size);
676     micro_max = GetValue<int64_t>(micro_size);
677   }
678   auto backward_start_pair = Deduplicate(backward_start, root, micro_max);
679   auto backward_end_pair = Deduplicate(backward_end, root, micro_max);
680   auto forward_start_pair = Deduplicate(forward_start, root, micro_max);
681   auto forward_end_pair = Deduplicate(forward_end, root, micro_max);
682   auto forward_params_pair = Deduplicate(forward_params, root, micro_max);
683   auto backward_params_pair = Deduplicate(backward_params, root, micro_max);
684   CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, LongToSize(micro_max));
685   PipelinePair forward_end_before_pair;
686   if (!IsLastStage()) {
687     for (auto &node : forward_end_pair.first) {
688       auto cnode = node->cast<CNodePtr>();
689       auto temp_node = GetActualOp(cnode->input(1));
690       MS_EXCEPTION_IF_NULL(temp_node);
691       forward_end_before_pair.first.push_back(temp_node);
692     }
693     for (auto &node : forward_end_pair.second) {
694       auto cnode = node->cast<CNodePtr>();
695       auto temp_node = GetActualOp(cnode->input(1));
696       MS_EXCEPTION_IF_NULL(temp_node);
697       forward_end_before_pair.second.push_back(temp_node);
698     }
699   } else {
700     forward_end_before_pair = forward_end_pair;
701   }
702   ReorderForForward(forward_start_pair.first, forward_end_pair.second, root);
703   ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair,
704                      forward_end_before_pair, root);
705   ReorderForParams(backward_params, forward_params, allreduce_params, forward_params_pair, backward_params_pair,
706                    backward_end, forward_start_pair, root);
707 }
708 
ReorderForPredict(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)709 void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
710   std::vector<AnfNodePtr> forward_end;
711   std::vector<AnfNodePtr> forward_start;
712   std::vector<AnfNodePtr> forward_params;
713   for (auto &node : root->nodes()) {
714     if (!node->isa<CNode>()) {
715       continue;
716     }
717     auto cnode = node->cast<CNodePtr>();
718     if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
719       forward_start.push_back(node);
720     }
721     if (cnode->HasPrimalAttr(PIPELINE_END)) {
722       forward_end.push_back(node);
723     }
724     if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
725       forward_params.push_back(node);
726     }
727   }
728   std::sort(forward_start.begin(), forward_start.end(), CompFunc);
729   std::sort(forward_end.begin(), forward_end.end(), CompFunc);
730   std::sort(forward_params.begin(), forward_params.end(), CompFunc);
731   auto forward_start_pair = Deduplicate(forward_start, root, 0);
732   auto forward_end_pair = Deduplicate(forward_end, root, 0);
733   auto forward_params_pair = Deduplicate(forward_params, root, 0);
734   if (!forward_end.empty() && !forward_params.empty()) {
735     InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root);
736   }
737   if (!forward_start.empty() && !forward_params.empty()) {
738     InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root);
739   }
740 }
741 }  // namespace parallel
742 }  // namespace mindspore
743