1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/step_parallel.h"
18 
19 #include <inttypes.h>
20 #include <sys/time.h>
21 #include <algorithm>
22 
23 #include <map>
24 #include <memory>
25 #include <set>
26 #include <string>
27 #include <unordered_map>
28 #include <utility>
29 
30 #include "base/core_ops.h"
31 #include "frontend/operator/ops.h"
32 #include "frontend/optimizer/optimizer.h"
33 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
34 #include "frontend/parallel/context.h"
35 #include "frontend/parallel/device_manager.h"
36 #include "frontend/parallel/dynamic_creator.h"
37 #include "frontend/parallel/graph_util/generate_graph.h"
38 #include "frontend/parallel/graph_util/graph_info.h"
39 #include "frontend/parallel/graph_util/node_info.h"
40 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
41 #include "frontend/parallel/node_check.h"
42 #include "frontend/parallel/parameter_manager.h"
43 #include "frontend/parallel/ops_info/matmul_info.h"
44 #include "ir/param_info.h"
45 #include "ir/tensor.h"
46 #include "utils/trace_base.h"
47 #include "utils/comm_manager.h"
48 #include "utils/ms_context.h"
49 #include "utils/symbolic.h"
50 #include "mindspore/core/utils/parallel_node_check.h"
51 #if ((defined ENABLE_CPU) && (!defined _WIN32))
52 #include "ps/util.h"
53 #include "ps/ps_context.h"
54 #endif
55 
56 using mindspore::tensor::Tensor;
57 
58 namespace mindspore {
59 namespace parallel {
60 static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
61 static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
62 static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
63 // g_RefMap: if input i of CNode B is a RefKey referring to Parameter C,
64 // the map stores one entry with key C and value (B, i).
65 std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
66 
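// Sets the DO_MIRROR attribute on the primitive held in new_node_input[0].
// Roughly: under gradient accumulation the mirror is skipped on accumulation
// steps, so DO_MIRROR is set to the negation of accu_flag.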
67 void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
68   if (new_node_input.empty()) {
69     return;
70   }
71   auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
72   auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
73   MS_EXCEPTION_IF_NULL(prim);
74 
75   auto attrs = prim->attrs();
76   attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
77   prim->SetAttrs(attrs);
78 }
79 
80 void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
81   if (new_node_input.empty()) {
82     return;
83   }
84 
85   auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
86   auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
87   MS_EXCEPTION_IF_NULL(prim);
88   auto attrs = prim->attrs();
89 
90   auto anf_node = node->input(0)->cast<ValueNodePtr>();
91   auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
92   MS_EXCEPTION_IF_NULL(prim_node);
93   auto node_attrs = prim_node->attrs();
94   if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
95     attrs[RECOMPUTE] = MakeValue<bool>(false);
96     prim->SetAttrs(attrs);
97     MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
98   }
99 }
100 
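// Builds the input list for a new CNode that applies operator `op` to `node`:
// the primitive's value node comes first, then `node`, then any constant
// operator params spliced in at their recorded positions. A rough sketch of
// the result: {NewValueNode(prim), node, NewValueNode(param), ...}.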
101 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
102   MS_EXCEPTION_IF_NULL(node);
103   OperatorArgs arg_forward = op.second;
104   ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name);
105   MS_EXCEPTION_IF_NULL(pyop_instance);
106   OperatorParams params = arg_forward.second;
107 
108   std::vector<AnfNodePtr> new_node_input = {NewValueNode(pyop_instance), node};
109   if (!params.empty()) {
110     for (auto &param : params) {
111       AnfNodePtr val = NewValueNode(param.first.second);
112       MS_EXCEPTION_IF_NULL(val);
113       int64_t position = param.second;
114       (void)new_node_input.insert(new_node_input.begin() + position, val);
115     }
116   }
117 
118   // if the op has a 'group' attr, set the rank list name for the op
119   SetCommunicationOpGroupLabel(new_node_input);
120   return new_node_input;
121 }
122 
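// Variant of CreateInput for mirror/all-gather operators under gradient
// accumulation or pipeline parallelism: it looks up the cloned `accu_grads`
// parameter that matches `weight_name` in `root` and, for the mini/micro-step
// operators, appends it as an extra input of the new node; if no such
// parameter exists, the mini-step mirror falls back to the plain mirror op.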
123 std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
124                                           const std::string &instance_name, const std::string &weight_name) {
125   MS_EXCEPTION_IF_NULL(root);
126   MS_EXCEPTION_IF_NULL(node);
127   MS_EXCEPTION_IF_NULL(root->manager());
128 
129   AnfNodePtr grad_accu = nullptr;
130   std::string op_name = op.first;
131   OperatorArgs arg_forward = op.second;
132 
133   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
134   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
135 
136   if (grad_accumulation_step > 1 || split_stage_num > 1) {
137     auto parameters = root->parameters();
138     bool find_grad_accu_node = false;
139     for (auto &param : parameters) {
140       if (!ParameterIsCloned(param)) {
141         continue;
142       }
143 
144       auto param_ptr = param->cast<ParameterPtr>();
145       MS_EXCEPTION_IF_NULL(param_ptr);
146       if (param_ptr->name().find(weight_name) != std::string::npos &&
147           param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
148         find_grad_accu_node = true;
149         grad_accu = param;
150         MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
151         break;
152       }
153     }
154 
155     if (!find_grad_accu_node) {
156       if (op_name == MIRROR_MINI_STEP_OPERATOR) {
157         op_name = MIRROR_OPERATOR;
158         arg_forward.first.pop_back();
159       } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR ||
160                  op_name == MICRO_STEP_ALL_GATHER) {
161         MS_LOG(EXCEPTION) << "You should define `accu_grads` when using " << op_name << " for parameter: " << weight_name;
162       }
163     }
164   }
165 
166   ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
167   MS_EXCEPTION_IF_NULL(pyop_instance);
168   OperatorParams params = arg_forward.second;
169 
170   std::vector<AnfNodePtr> new_node_input;
171   if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
172       op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) {
173     new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
174     MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
175   } else {
176     new_node_input = {NewValueNode(pyop_instance), node};
177   }
178 
179   if (!params.empty()) {
180     for (auto &param : params) {
181       AnfNodePtr val = NewValueNode(param.first.second);
182       MS_EXCEPTION_IF_NULL(val);
183       int64_t position = param.second;
184       (void)new_node_input.insert(new_node_input.begin() + position, val);
185     }
186   }
187 
188   // if the op has a 'group' attr, set the rank list name for the op
189   SetCommunicationOpGroupLabel(new_node_input);
190   // gradient accumulation
191   if (grad_accumulation_step > 1) {
192     SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
193   }
194   return new_node_input;
195 }
196 
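// Creates op(pre_node) and rewires input `index` of `node` to the new CNode,
// i.e. (sketch) the edge pre_node -> node becomes pre_node -> new_op -> node.
// When `root` and `param_name` are given, the inputs are built by
// CreateMirrorInput instead of CreateInput.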
197 void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
198                 const FuncGraphPtr &func_graph, const std::string &instance_name, const std::string &param_name = "",
199                 const FuncGraphPtr &root = nullptr) {
200   // insert new node before the node
201   FuncGraphManagerPtr manager = func_graph->manager();
202   MS_EXCEPTION_IF_NULL(manager);
203   ScopePtr scope = node->scope();
204   MS_EXCEPTION_IF_NULL(scope);
205   std::vector<AnfNodePtr> node_input;
206   if (root && !param_name.empty()) {
207     node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
208   } else {
209     node_input = CreateInput(op, pre_node, instance_name);
210   }
211   CNodePtr new_node = func_graph->NewCNode(node_input);
212   MS_EXCEPTION_IF_NULL(new_node);
213   if (instance_name.find(SPLIT_SENS) == std::string::npos) {
214     new_node->set_in_forward_flag(true);  // mark forward flag
215   }
216   auto new_node_value = node_input[0]->cast<ValueNodePtr>();
217   MS_EXCEPTION_IF_NULL(new_node_value);
218   PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
219   new_node_prim->set_instance_name(instance_name);
220   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
221   if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
222     new_node_prim->set_attr("recompute", MakeValue(false));
223   }
224   new_node->set_scope(scope);
225   node_input[0]->set_scope(scope);
226   manager->SetEdge(node, SizeToInt(index), new_node);
227   MS_LOG(INFO) << "Insert " << instance_name << " success";
228 }
229 
230 // Replace pre_node with pre_node->op
231 static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
232                             const std::string &instance_name, const std::string &param_name = "",
233                             const FuncGraphPtr &root = nullptr) {
234   // insert new node before the node
235   FuncGraphManagerPtr manager = func_graph->manager();
236   MS_EXCEPTION_IF_NULL(manager);
237   ScopePtr scope = pre_node->scope();
238   MS_EXCEPTION_IF_NULL(scope);
239   std::vector<AnfNodePtr> node_input;
240   if (root && !param_name.empty()) {
241     node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
242   } else {
243     node_input = CreateInput(op, pre_node, instance_name);
244   }
245   CNodePtr new_node = func_graph->NewCNode(node_input);
246   MS_EXCEPTION_IF_NULL(new_node);
247   if (instance_name.find(SPLIT_SENS) == std::string::npos) {
248     new_node->set_in_forward_flag(true);  // mark forward flag
249   }
250   auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
251   new_node_prim->set_instance_name(instance_name);
252   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
253   if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
254     new_node_prim->set_attr("recompute", MakeValue(false));
255   }
256   new_node->set_scope(scope);
257   node_input[0]->set_scope(scope);
258   manager->Replace(pre_node, new_node);
259   MS_LOG(INFO) << "Insert " << instance_name << " success";
260   return new_node;
261 }
262 
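// Inserts the forward communication operators for `node`: each entry of
// forward_op (processed in reverse order) is wrapped around `node`, or around
// its single TupleGetItem user when the node has a tuple output.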
263 void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
264   MS_EXCEPTION_IF_NULL(node);
265   // step1:get graph manager distribute_operator
266   FuncGraphPtr func_graph = node->func_graph();
267   MS_EXCEPTION_IF_NULL(func_graph);
268   FuncGraphManagerPtr manager = func_graph->manager();
269   MS_EXCEPTION_IF_NULL(manager);
270   auto uses_set = manager->node_users()[node];
271   CNodePtr node_to_insert = node;
272   for (auto &uses_pair : uses_set) {
273     auto uses_cnode = uses_pair.first->cast<CNodePtr>();
274     MS_EXCEPTION_IF_NULL(uses_cnode);
275     if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
276       break;
277     }
278     PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
279     MS_EXCEPTION_IF_NULL(value_node_prim);
280     if (value_node_prim->name() == prim::kTupleGetItem) {
281       if (uses_set.size() > 1) {
282         MS_LOG(EXCEPTION) << "Currently only one output is supported, but got " << uses_set.size();
283       }
284       node_to_insert = uses_cnode;
285     }
286   }
287   MS_EXCEPTION_IF_NULL(node_to_insert);
288   std::reverse(forward_op.begin(), forward_op.end());
289 
290   // step2:traverse op_list and insert node
291   for (size_t index = 0; index < forward_op.size(); ++index) {
292     std::string instance_name_base = FORWARD_OP;
293     std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
294     std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
295     SetAllReduceRecomputeFlag(forward_input, node_to_insert);
296     CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
297     MS_EXCEPTION_IF_NULL(forward_node);
298     ScopePtr scope = node->scope();
299     MS_EXCEPTION_IF_NULL(scope);
300     forward_node->set_scope(scope);
301     forward_node->set_in_forward_flag(true);
302     forward_input[0]->set_scope(scope);
303     (void)manager->Replace(node_to_insert, forward_node);  // using Replace function to insert node
304   }
305 }
306 
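// Replaces `prev`, a node with `num` outputs, by (sketch)
// MakeTuple(TupleGetItem(prev, 0), ..., TupleGetItem(prev, num - 1)).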
307 CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint64_t num, const FuncGraphPtr &func_graph) {
308   MS_EXCEPTION_IF_NULL(prev);
309   MS_EXCEPTION_IF_NULL(func_graph);
310   std::vector<AnfNodePtr> make_tuple_inputs;
311   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
312   for (uint64_t i = 0; i < num; i++) {
313     std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev,
314                                                   CreatInt64Imm(UlongToLong(i))};
315     auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs);
316     MS_EXCEPTION_IF_NULL(tuple_get_item);
317     make_tuple_inputs.push_back(tuple_get_item);
318   }
319   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
320   MS_EXCEPTION_IF_NULL(make_tuple);
321   FuncGraphManagerPtr manager = func_graph->manager();
322   MS_EXCEPTION_IF_NULL(manager);
323   (void)manager->Replace(prev, make_tuple);
324   return make_tuple;
325 }
326 
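// Inserts the inferred redistribution operator chain on input `pos` of `node`.
// Each operator is placed with InsertNode; entries whose output info is marked
// as multi-output are additionally expanded with InsertMakeTuple.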
327 void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
328                           const FuncGraphPtr &func_graph, int64_t pos, const CNodePtr &pre_node) {
329   MS_EXCEPTION_IF_NULL(node);
330   MS_EXCEPTION_IF_NULL(pre_node);
331   MS_EXCEPTION_IF_NULL(func_graph);
332   FuncGraphManagerPtr manager = func_graph->manager();
333   MS_EXCEPTION_IF_NULL(manager);
334   if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) {
335     MS_LOG(EXCEPTION) << "size of OperatorVector and OutPutInfoVector must be the same!";
336   }
337   for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) {
338     if (pos >= SizeToLong(node->inputs().size())) {
339       MS_LOG(EXCEPTION) << "InsertRedistribution: pos can't be larger than the size of node's inputs";
340     }
341     // Create new node
342     AnfNodePtr target_node = node->input(LongToSize(pos));
343     MS_EXCEPTION_IF_NULL(target_node);
344     // Create instance_name
345     auto op = (redistribution_oplist_ptr->first)[index];
346     std::string op_name = (redistribution_oplist_ptr->first)[index].first;
347     std::string instance_name_base = REDISTRIBUTION_OP;
348     std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
349     auto prim_out = GetCNodePrimitive(node);
350     auto prim_in = GetCNodePrimitive(pre_node);
351     if (prim_out != nullptr && prim_in != nullptr) {
352       auto prim_out_attr = prim_out->attrs();
353       auto prim_in_attr = prim_in->attrs();
354       if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end() &&
355           !GetValue<bool>(prim_out_attr[RECOMPUTE_COMM_OP]) &&
356           prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end() &&
357           !GetValue<bool>(prim_in_attr[RECOMPUTE_COMM_OP]) &&
358           COMMUNICATION_OPS.find(op_name) != COMMUNICATION_OPS.end()) {
359         MS_LOG(INFO) << "The redistribution node would not be recomputed.";
360         instance_name = instance_name + "_" + NOT_RECOMPUTE;
361       }
362     }
363     InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name);
364     if ((redistribution_oplist_ptr->second)[index].first) {
365       target_node = node->input(LongToSize(pos));
366       MS_EXCEPTION_IF_NULL(target_node);
367       (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph);
368     }
369   }
370 }
371 
372 void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int64_t pos,
373                             const std::string &instance_name) {
374   if (func_graph == nullptr) {
375     MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name;
376   }
377 
378   FuncGraphManagerPtr manager = func_graph->manager();
379   MS_EXCEPTION_IF_NULL(manager);
380   if (pos >= SizeToLong(node->inputs().size())) {
381     MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than the size of node's inputs, the instance name is "
382                       << instance_name;
383   }
384   // Create new node
385   AnfNodePtr pre_node = node->input(LongToSize(pos));
386   MS_EXCEPTION_IF_NULL(pre_node);
387   InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
388 }
389 
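// Returns the output tensor layout of `distribute_operator` that feeds
// `middle_node`: for a TupleGetItem the layout is chosen by the getitem index,
// otherwise output 0 is used.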
390 TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim,
391                                const OperatorInfoPtr &distribute_operator) {
392   TensorInfo tensorinfo_in;
393   if (middle_prim->name() == prim::kTupleGetItem) {
394     auto value_node = middle_node->input(2)->cast<ValueNodePtr>();
395     MS_EXCEPTION_IF_NULL(value_node);
396     size_t index_s = LongToSize(GetValue<int64_t>(value_node->value()));
397     if (index_s >= distribute_operator->outputs_tensor_info().size()) {
398       MS_LOG(EXCEPTION) << "The index is out of range, index: " << index_s
399                         << ", vector size: " << distribute_operator->outputs_tensor_info().size();
400     }
401     tensorinfo_in = distribute_operator->outputs_tensor_info()[index_s];
402   } else {
403     if (distribute_operator->outputs_tensor_info().empty()) {
404       MS_LOG(EXCEPTION) << "The outputs tensor info is empty";
405     }
406     tensorinfo_in = distribute_operator->outputs_tensor_info()[0];
407   }
408   return tensorinfo_in.tensor_layout();
409 }
410 
411 std::string GetPrimName(const CNodePtr &node) {
412   auto prim = GetCNodePrimitive(node);
413   MS_EXCEPTION_IF_NULL(prim);
414   return prim->name();
415 }
416 
417 OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
418   MS_EXCEPTION_IF_NULL(node);
419   if (!IsParallelCareNode(node)) {
420     return nullptr;
421   }
422   OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
423   if (distribute_operator == nullptr) {
424     MS_LOG(EXCEPTION) << "Distribute operator is nullptr, the prim is " << GetPrimName(node);
425   }
426   return distribute_operator;
427 }
428 
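// Inserts a redistribution between the producer `middle_node` (laid out by
// `distribute_operator`) and the consumer in `node_pair`: the producer's
// output layout and the consumer's input layout are fed to
// TensorRedistribution, and the inferred operator list (if any) is inserted
// with InsertRedistribution.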
429 void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const OperatorInfoPtr &distribute_operator,
430                     const CNodePtr &middle_node, int64_t index, TensorRedistribution tensor_redistribution,
431                     const CNodePtr &pre_node) {
432   FuncGraphPtr func_graph = middle_node->func_graph();
433   if (func_graph == nullptr) {
434     MS_LOG(EXCEPTION) << "Redistribution:get graph failed";
435   }
436   CNodePtr next_node = node_pair.first->cast<CNodePtr>();
437   MS_EXCEPTION_IF_NULL(next_node);
438   auto middle_value = middle_node->input(0)->cast<ValueNodePtr>();
439   MS_EXCEPTION_IF_NULL(middle_value);
440   PrimitivePtr middle_prim = middle_value->value()->cast<PrimitivePtr>();
441   MS_EXCEPTION_IF_NULL(middle_prim);
442   OperatorInfoPtr next_distribute_operator = GetDistributeOperator(next_node);
443   if (next_distribute_operator == nullptr) {
444     MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
445   }
446   RankList dev_list = distribute_operator->stage_device_list();
447   std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
448   MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
449   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
450   // extract tensor layout in and out
451   if (distribute_operator->outputs_tensor_info().empty()) {
452     MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
453     return;
454   }
455 
456   if (LongToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
457     MS_LOG(WARNING) << "The index is out of range, the index is " << (index - 1) << ", the vector size is "
458                     << next_distribute_operator->inputs_tensor_info().size() << ", next operator name is "
459                     << next_distribute_operator->name();
460     return;
461   }
462   TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
463   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
464   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
465   if (IsPrimitiveCNode(middle_node, prim::kPrimReceive)) {
466     tensorlayout_in = *(middle_node->user_data<TensorLayout>());
467   }
468   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
469     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
470     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
471                   << next_node->ToString();
472     DumpGraph(func_graph, "redistribution_error");
473     MS_LOG(EXCEPTION) << "Failure:tensor_redistribution init failed";
474   }
475   RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
476   if (redistribution_oplist_ptr == nullptr) {
477     MS_LOG(EXCEPTION) << "Failure:InferTensorRedistribution failed";
478   }
479   MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size();
480   if (!redistribution_oplist_ptr->first.empty()) {
481     // insert node before next node
482     InsertRedistribution(redistribution_oplist_ptr, next_node, func_graph, node_pair.second, pre_node);
483   }
484 }
485 
486 bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs) {
487   auto iter = attrs.find(STRATEGY);
488   return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
489 }
490 
491 bool HasStrategy(const FuncGraphPtr &root) {
492   AnfNodePtr ret = root->get_return();
493   MS_EXCEPTION_IF_NULL(ret);
494   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
495 
496   for (auto &node : all_nodes) {
497     auto cnode = node->cast<CNodePtr>();
498     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
499       continue;
500     }
501 
502     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
503     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
504     auto attrs = prim->attrs();
505     if (StrategyFound(attrs)) {
506       return true;
507     }
508   }
509 
510   return false;
511 }
512 
513 bool IsCommunicationOp(const PrimitivePtr &prim) {
514   MS_EXCEPTION_IF_NULL(prim);
515   return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end());
516 }
517 
518 bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
519   for (auto &node : all_nodes) {
520     MS_EXCEPTION_IF_NULL(node);
521     if (!node->isa<CNode>()) {
522       continue;
523     }
524     auto cnode = node->cast<CNodePtr>();
525     if (!IsValueNode<Primitive>(cnode->input(0))) {
526       continue;
527     }
528     ValueNodePtr prim_value_node = cnode->input(0)->cast<ValueNodePtr>();
529     MS_EXCEPTION_IF_NULL(prim_value_node);
530     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_value_node);
531     MS_EXCEPTION_IF_NULL(prim);
532 
533     if (IsCommunicationOp(prim) && cnode->in_forward_flag()) {
534       MS_EXCEPTION_IF_NULL(prim_value_node->scope());
535       MS_LOG(INFO) << "The graph contain communication op: " << prim->name() << ", scope name is "
536                    << prim_value_node->scope()->name();
537       return true;
538     }
539   }
540   return false;
541 }
542 
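// Recursively walks the users of `node` and calls Redistribution for every
// user that is a parallel-care CNode carrying an OperatorInfo; it returns
// early for Send/MakeTuple/MakeList nodes and skips UpdateState users as well
// as Depend users that consume `node` on a non-first input.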
543 void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
544                         const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) {
545   MS_EXCEPTION_IF_NULL(node->func_graph());
546   FuncGraphManagerPtr manager = node->func_graph()->manager();
547   MS_EXCEPTION_IF_NULL(manager);
548   AnfNodeIndexSet node_set = manager->node_users()[node];
549   CNodePtr insert_node_new;
550 
551   if (IsPrimitiveCNode(node, prim::kPrimSend)) {
552     return;
553   }
554   if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
555     MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
556     return;
557   }
558   if (IsValueNode<Primitive>(node->input(0))) {
559     auto current_value = node->input(0)->cast<ValueNodePtr>();
560     MS_EXCEPTION_IF_NULL(current_value);
561     PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
562     MS_EXCEPTION_IF_NULL(current_prim);
563     insert_node_new = ((current_prim->name() == prim::kTupleGetItem) ? node : insert_node);
564   } else {
565     insert_node_new = insert_node;
566   }
567   MS_EXCEPTION_IF_NULL(insert_node_new);
568   for (auto &node_pair : node_set) {
569     CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
570     MS_EXCEPTION_IF_NULL(use_cnode);
571     if (!IsValueNode<Primitive>(use_cnode->input(0))) {
572       StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node);
573     } else {
574       ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
575       MS_EXCEPTION_IF_NULL(prim_anf_node);
576       PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
577       MS_EXCEPTION_IF_NULL(node_prim);
578       if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == UPDATESTATE) {
579         continue;
580       }
581       if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
582         Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
583                        pre_node);
584       } else {
585         StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node);
586       }
587     }
588   }
589 }
590 
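// Splits the tensor input `node` of `next_node` according to the consumer's
// input tensor layout by inserting a _GetTensorSlice operator; shapes [] and
// [1] are left untouched.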
591 void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int64_t index) {
592   MS_EXCEPTION_IF_NULL(node);
593   MS_EXCEPTION_IF_NULL(next_node);
594   OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
595   MS_EXCEPTION_IF_NULL(op_info);
596 
597   // If the shape of tensor is [] or [1], no need to split it.
598   Shapes shapes = GetNodeShape(node);
599   if (shapes.size() != 1) {
600     MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name()
601                       << ": GetNodeShape for tensor_node, output size is not 1";
602   }
603   Shape shape = shapes[0];
604   std::string shape_str = ShapeToString(shape);
605   if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) {
606     MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str
607                  << ", no need to split it.";
608     return;
609   }
610 
611   MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str;
612 
613   // extract tensor layout
614   if (LongToSize(index - 1) >= op_info->inputs_tensor_info().size()) {
615     MS_LOG(EXCEPTION) << "The index is out of range, index is  " << (index - 1) << ", vector size is  "
616                       << op_info->inputs_tensor_info().size();
617   }
618   TensorInfo tensor_info = op_info->inputs_tensor_info()[LongToSize(index - 1)];
619   TensorLayout tensor_layout = tensor_info.tensor_layout();
620 
621   // Use _GetTensorSlice operator to split the tensor
622   FuncGraphPtr func_graph = next_node->func_graph();  // only cnode can get the graph
623   MS_EXCEPTION_IF_NULL(func_graph);
624   Operator op = CreateGetTensorSliceOp(tensor_layout);
625   InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
626   if (!op_info->sub_ops().empty()) {
627     auto sub_ops = op_info->sub_ops();
628     for (size_t i = 0; i < sub_ops.size(); i++) {
629       if (!sub_ops.at(i).empty()) {
630         InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
631       }
632     }
633   }
634 }
635 
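// Variant of SplitTensor for a ValueTuple/ValueList input: each element gets
// its own _GetTensorSlice node and the original value node is replaced by a
// MakeTuple of the sliced elements.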
636 void SplitTensorList(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
637   MS_EXCEPTION_IF_NULL(node);
638   MS_EXCEPTION_IF_NULL(next_node);
639   if (next_node->inputs().size() != 2 || index != 1) {
640     MS_LOG(INFO) << next_node->fullname_with_scope() << " must have only one input, but got "
641                  << (next_node->inputs().size() - 1) << "; index should be 1, but got " << index;
642     return;
643   }
644   OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
645   MS_EXCEPTION_IF_NULL(op_info);
646 
647   std::vector<ValuePtr> inputs_values;
648   if (IsValueNode<ValueList>(node)) {
649     inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
650   } else {
651     inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
652   }
653   if (inputs_values.size() != op_info->inputs_tensor_info().size()) {
654     MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << ", is not equal to inputs shape size "
655                       << op_info->inputs_tensor_info().size();
656   }
657   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
658   FuncGraphPtr func_graph = next_node->func_graph();
659   MS_EXCEPTION_IF_NULL(func_graph);
660   FuncGraphManagerPtr manager = func_graph->manager();
661   MS_EXCEPTION_IF_NULL(manager);
662   ScopePtr scope = next_node->scope();
663   MS_EXCEPTION_IF_NULL(scope);
664   for (size_t i = 0; i < inputs_values.size(); ++i) {
665     auto value_ptr = inputs_values[i];
666     auto tensor = value_ptr->cast<tensor::TensorPtr>();
667     MS_EXCEPTION_IF_NULL(tensor);
668     TensorInfo tensor_info = op_info->inputs_tensor_info()[i];
669     TensorLayout tensor_layout = tensor_info.tensor_layout();
670     auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
671     Operator op = CreateGetTensorSliceOp(tensor_layout);
672     std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
673     CNodePtr new_node = func_graph->NewCNode(node_input);
674     new_node->set_in_forward_flag(true);
675     auto new_node_value = node_input[0]->cast<ValueNodePtr>();
676     MS_EXCEPTION_IF_NULL(new_node_value);
677     PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
678     new_node_prim->set_instance_name(SPLIT_TENSOR);
679     new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
680     new_node->set_scope(scope);
681     node_input[0]->set_scope(scope);
682     make_tuple_inputs.push_back(new_node);
683   }
684   CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
685   manager->Replace(node, make_tuple);
686 }
687 
688 void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
689   MS_EXCEPTION_IF_NULL(node);
690   MS_EXCEPTION_IF_NULL(manager);
691   AnfNodeIndexSet node_set = manager->node_users()[node];
692   for (auto &node_pair : node_set) {
693     CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
694     if (use_cnode == nullptr || !IsValueNode<Primitive>(use_cnode->input(0))) {
695       continue;
696     }
697     ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
698     MS_EXCEPTION_IF_NULL(prim_anf_node);
699     PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
700     MS_EXCEPTION_IF_NULL(use_cnode_prim);
701     if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
702         NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
703       continue;
704     }
705     if (IsParallelCareNode(use_cnode)) {
706       if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
707         SplitTensorList(node, use_cnode, node_pair.second);
708       } else {
709         SplitTensor(node, use_cnode, node_pair.second);
710       }
711     }
712   }
713 }
714 
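// Replaces `node` with the operator chain in `replace_op`: the last operator
// inherits the node's OperatorInfo and primal attrs, and entries flagged in
// replace_op_info are expanded with InsertMakeTuple.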
715 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
716   // step1:get graph manager distribute_operator
717   OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
718   if (distribute_operator == nullptr) {
719     MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
720   }
721   FuncGraphPtr func_graph = node->func_graph();
722   MS_EXCEPTION_IF_NULL(func_graph);
723   FuncGraphManagerPtr manager = func_graph->manager();
724   if (manager == nullptr) {
725     MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
726   }
727   // step2:traverse op_list and insert node
728   std::reverse(replace_op.begin(), replace_op.end());
729   auto replace_op_info = distribute_operator->replace_op_info();
730   std::reverse(replace_op_info.begin(), replace_op_info.end());
731   if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) {
732     MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!";
733   }
734   bool replace_op_info_flag = !replace_op_info.empty();
735   for (size_t index = 0; index < replace_op.size(); ++index) {
736     std::string instance_name = CreateInstanceName(node, index);
737     std::vector<AnfNodePtr> replace_input;
738     if (index != replace_op.size() - 1) {
739       replace_input = CreateInput(replace_op[index], node, instance_name);
740     } else {
741       replace_input = ReplaceOpInput(replace_op[index], instance_name, node);
742     }
743     CNodePtr replace_node = func_graph->NewCNode(replace_input);
744     MS_EXCEPTION_IF_NULL(replace_node);
745     ScopePtr scope = node->scope();
746     MS_EXCEPTION_IF_NULL(scope);
747     replace_node->set_scope(scope);
748     PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
749     PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
750     SetUserAttrs(origin_prim->attrs(), prim);
751     auto origin_prim_attrs = origin_prim->attrs();
752     if (origin_prim_attrs.find(RECOMPUTE_COMM_OP) != origin_prim_attrs.end() &&
753         !GetValue<bool>(origin_prim_attrs[RECOMPUTE_COMM_OP]) &&
754         COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()) {
755       MS_LOG(INFO) << "The redistribution node in reshape would not be recomputed.";
756       prim->set_attr("recompute", MakeValue(false));
757     }
758     if (index == replace_op.size() - 1) {
759       replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
760       replace_node->set_primal_attrs(node->primal_attrs());
761     }
762     replace_node->set_in_forward_flag(true);
763     replace_input[0]->set_scope(scope);
764     if (replace_op_info_flag && replace_op_info[index].first) {
765       auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph);
766       new_cnode->set_primal_attrs(node->primal_attrs());
767       (void)manager->Replace(node, new_cnode);  // using Replace function to insert node
768     } else {
769       (void)manager->Replace(node, replace_node);  // using Replace function to insert node
770     }
771   }
772   MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name();
773 }
774 
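// Replaces `node` with a prebuilt subgraph: replace_graph->first maps the
// subgraph's input nodes to input indices of `node`, and replace_graph->second
// is the subgraph's output, which takes the node's place in the graph.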
775 void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) {
776   MS_EXCEPTION_IF_NULL(replace_graph);
777   MS_EXCEPTION_IF_NULL(node);
778   MS_EXCEPTION_IF_NULL(replace_graph->second);
779   FuncGraphPtr func_graph = node->func_graph();
780   MS_EXCEPTION_IF_NULL(func_graph);
781   FuncGraphManagerPtr manager = func_graph->manager();
782   if (manager == nullptr) {
783     MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
784   }
785   // Solve the input order
786   // For example, input_node: {segment_sum:1, segment_sum:2, gather:2}
787   // The original code here would bind all operations to the first input of these operators.
788   // However, the segment_sum operation needs two inputs. To solve this,
789   // we maintain a dict that counts the occurrences of the same operation,
790   // and bind the inputs according to how many times the op has appeared.
791   std::unordered_map<AnfNodePtr, int> input_map = {};
792   static int appear_count = 0;
793   for (auto &replace_input : replace_graph->first) {
794     auto pre_node = node->input(LongToSize(replace_input.second));
795 
796     auto it = input_map.find(replace_input.first);
797     if (it != input_map.end()) {
798       appear_count = 1 + it->second;
799     } else {
800       appear_count = 1;
801     }
802     input_map[replace_input.first] = appear_count;
803     manager->SetEdge(replace_input.first, appear_count, pre_node);
804   }
805   //  "(void)manager->Replace(replace_graph->first, pre_node);" cannot be called
806   auto replace_output = replace_graph->second->cast<CNodePtr>();
807   MS_EXCEPTION_IF_NULL(replace_output);
808   replace_output->set_primal_attrs(node->primal_attrs());
809   (void)manager->Replace(node, replace_output);
810 }
811 
812 int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
813   MS_EXCEPTION_IF_NULL(cnode);
814   if (cnode->inputs().size() != 3) {
815     MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3";
816   }
817 
818   if (!cnode->input(2)->isa<ValueNode>()) {
819     MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node";
820   }
821 
822   ValuePtr tuple_index_value = GetValueNode(cnode->input(2));
823   MS_EXCEPTION_IF_NULL(tuple_index_value);
824   if (!tuple_index_value->isa<Int64Imm>()) {
825     MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
826   }
827   return tuple_index_value->cast<Int64ImmPtr>()->value();
828 }
829 
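// Inserts the virtual div operators on every tensor (CNode or Parameter,
// non-monad) input of `node`; for DropoutDoMask only input[0] is handled.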
830 void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
831   MS_EXCEPTION_IF_NULL(node);
832   size_t node_size = node->inputs().size();
833   FuncGraphPtr func_graph = node->func_graph();
834   MS_EXCEPTION_IF_NULL(func_graph);
835   FuncGraphManagerPtr manager = func_graph->manager();
836   MS_EXCEPTION_IF_NULL(manager);
837 
838   if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
839     MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
840     node_size = 2;
841   }
842 
843   for (size_t index = 1; index < node_size; ++index) {
844     AnfNodePtr input = node->input(index);
845     MS_EXCEPTION_IF_NULL(input);
846     // if it is not a tensor, continue
847     if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
848       MS_LOG(INFO) << "insert div op: the input at index " << index << " is not a tensor, skip";
849       continue;
850     }
851 
852     for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) {
853       std::string instance_name = CreateInstanceName(node, pos);
854       InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name);
855     }
856     MS_LOG(INFO) << "insert div op for input index  " << index << "  of node";
857   }
858 }
859 
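// For eval/predict graphs: finds the last forward nodes and inserts a
// VirtualOutput operator on each of those output edges, copying the output's
// abstract and shape onto the inserted node.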
860 void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
861   vector<std::string> last_forward_node_ids;
862   vector<size_t> last_indexs;
863   FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
864   MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
865   for (auto &node : all_nodes) {
866     // here insert virtualoutput node
867     auto cnode = node->cast<CNodePtr>();
868     if (cnode == nullptr) {
869       continue;
870     }
871     auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
872     if (last_node_iter == last_forward_node_ids.end()) {
873       continue;
874     }
875     for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
876       if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
877         continue;
878       }
879       MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
880                    << cnode->input(last_indexs[last_node_index])->fullname_with_scope();
881       if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
882         FuncGraphManagerPtr manager = cnode->func_graph()->manager();
883         MS_EXCEPTION_IF_NULL(manager);
884         auto node_pair = manager->node_users()[cnode].front();
885         if (!node_pair.first->isa<CNode>()) {
886           MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
887         }
888         cnode = node_pair.first->cast<CNodePtr>();
889         last_indexs[last_node_index] = IntToSize(node_pair.second);
890       }
891       auto pre_node = cnode->input(last_indexs[last_node_index]);
892       Shapes shape_outputs = GetNodeShape(pre_node);
893       if (shape_outputs[0].empty()) {
894         continue;
895       }
896       FuncGraphPtr func_graph = node->func_graph();
897       MS_EXCEPTION_IF_NULL(func_graph);
898       OperatorParams params;
899       OperatorAttrs attrs;
900       OperatorArgs args = std::make_pair(attrs, params);
901       Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
902       InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
903       auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
904       AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
905       std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
906       virtual_output_abstract->set_shape(virtual_output_shape);
907       virtual_output_node->set_abstract(virtual_output_abstract);
908     }
909   }
910 }
911 
912 static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
913   if (IsValueNode<RefKey>(node)) {
914     std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
915     if (param_v.size() != 1) {
916       MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, the returned vector size must be 1, but got "
917                         << param_v.size();
918     }
919     auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
920     if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
921       return std::make_pair(nullptr, true);
922     }
923     return std::make_pair(node, true);
924   }
925   return std::make_pair(nullptr, false);
926 }
927 
928 static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
929   auto param_ptr = node->user_data<parallel::TensorLayout>();
930   if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
931     return std::make_pair(nullptr, false);
932   }
933   return std::make_pair(node, false);
934 }
935 
936 // Only used for InsertMirrorOps
937 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
938   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
939     return std::make_pair(nullptr, false);
940   }
941 
942   if (node->isa<Parameter>()) {
943     return FindParameterByParameter(node);
944   }
945 
946   if (node->isa<ValueNode>()) {
947     return FindParameterByValueNode(node, func_graph);
948   }
949 
950   CNodePtr cnode = node->cast<CNodePtr>();
951   MS_EXCEPTION_IF_NULL(cnode);
952   if (!IsValueNode<Primitive>(cnode->input(0))) {
953     for (size_t index = 0; index < cnode->inputs().size(); ++index) {
954       auto res = FindParameter(cnode->input(index), func_graph);
955       if (!res.first) {
956         continue;
957       }
958       return res;
959     }
960   }
961 
962   // When opt shard is not fully used, both allgather and mirror would be inserted.
963   // Skip allgather here and find parameter recursively.
964   if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
965     return std::make_pair(nullptr, false);
966   }
967 
968   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
969   MS_EXCEPTION_IF_NULL(prim_anf_node);
970   for (size_t index = 0; index < cnode->inputs().size(); ++index) {
971     PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
972     MS_EXCEPTION_IF_NULL(prim);
973     if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
974       continue;
975     }
976     auto res = FindParameter(cnode->input(index), func_graph);
977     if (!res.first) {
978       continue;
979     }
980     return res;
981   }
982   return std::make_pair(nullptr, false);
983 }
984 
985 // only used for FindCNode
986 CNodePtr SkipTrivialNodesMoveDown(const FuncGraphManagerPtr &manager, CNodePtr node) {
987   MS_EXCEPTION_IF_NULL(node);
988   while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) {
989     node = manager->node_users()[node].begin()->first->cast<CNodePtr>();
990   }
991   return node;
992 }
993 
994 std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
995                                     size_t max_depth) {
996   MS_EXCEPTION_IF_NULL(anode);
997   MS_EXCEPTION_IF_NULL(anode->func_graph());
998   FuncGraphManagerPtr manager = anode->func_graph()->manager();
999   MS_EXCEPTION_IF_NULL(manager);
1000   if (max_depth > MAX_RECURSIVE_DEPTH) {
1001     MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
1002   }
1003   AnfNodeIndexSet node_set = manager->node_users()[anode];
1004   bool result = false;
1005   CNodePtr cnode_return = nullptr;
1006   for (auto &node_pair : node_set) {
1007     CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
1008     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
1009       continue;
1010     }
1011     if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
1012       use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
1013     }
1014     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
1015     MS_EXCEPTION_IF_NULL(prim_anf_node);
1016     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
1017     MS_EXCEPTION_IF_NULL(node_prim);
1018     if (node_prim->name() == name && node_pair.second == 1) {
1019       if (use_apply->func_graph() == func_graph) {
1020         result = true;
1021         cnode_return = use_apply;
1022         MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph";
1023         continue;
1024       }
1025       MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph";
1026     }
1027     if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(use_apply)) {
1028       return FindCNode(node_pair.first, name, func_graph, max_depth + 1);
1029     }
1030   }
1031   return std::make_pair(result, cnode_return);
1032 }
1033 
1034 bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
1035   // return true only if gradient_fp32_sync is true, the pre node is a Cast and its type is not float32
1036   if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
1037     return false;
1038   }
1039   auto pre_node = node->input(index);
1040   MS_EXCEPTION_IF_NULL(pre_node);
1041   auto cnode = pre_node->cast<CNodePtr>();
1042   if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
1043     return false;
1044   }
1045   if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(cnode)) {
1046     pre_node = cnode->input(1);
1047   }
1048   if (!IsPrimitiveCNode(pre_node, prim::kPrimCast)) {
1049     return false;
1050   }
1051   auto node_type = pre_node->Type();
1052   MS_EXCEPTION_IF_NULL(node_type);
1053   if (!node_type->isa<mindspore::TensorType>()) {
1054     MS_LOG(EXCEPTION) << "Unknown type.";
1055   }
1056   auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
1057   MS_EXCEPTION_IF_NULL(input_element_type);
1058   auto type_id = input_element_type->type_id();
1059 
1060   return (type_id != kNumberTypeFloat32);
1061 }
1062 
1063 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
1064   if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1065     return true;
1066   }
1067   if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
1068     MS_LOG(INFO) << "Input is ValueList, skip it.";
1069     return false;
1070   }
1071 
1072   if ((node->inputs().size() == 2) &&
1073       (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
1074     MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has been handled by the make_tuple node";
1075     return false;
1076   }
1077 
1078   if (mirror_ops.size() != node_size - 1) {
1079     MS_LOG(EXCEPTION) << "The size of mirror_ops is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
1080                       << (node_size - 1);
1081   }
1082   return true;
1083 }
1084 
1085 // only used for InsertMirrorOps
1086 CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) {
1087   MS_EXCEPTION_IF_NULL(node);
1088   while (!IsSomePrimitive(node, LOAD)) {
1089     if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
1090       node = node->input(1)->cast<CNodePtr>();
1091     }
1092   }
1093   auto prev_node = node->input(1)->cast<CNodePtr>();
1094   if (prev_node != nullptr) {
1095     if (IsSomePrimitive(prev_node, DEPEND)) {
1096       auto prev_prev_node = prev_node->input(1)->cast<CNodePtr>();
1097       if (IsSomePrimitive(node, LOAD)) {
1098         node = prev_prev_node;
1099         MS_LOG(INFO) << "Moving to the Load node before Depend node.";
1100       }
1101     }
1102   }
1103   return node;
1104 }
1105 
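// Chooses the mirror operator name from the parallel context, roughly:
//   grad_accumulation_step > 1    -> MIRROR_MINI_STEP_OPERATOR
//   pipeline_stage_split_num > 1  -> MIRROR_MICRO_STEP_OPERATOR
//   otherwise                     -> MIRROR_OPERATOR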
1106 std::string MirrorOpName() {
1107   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1108   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1109   std::string mirror_op_name;
1110   if (grad_accumulation_step > 1) {
1111     mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
1112   } else if (split_stage_num > 1) {
1113     mirror_op_name = MIRROR_MICRO_STEP_OPERATOR;
1114   } else {
1115     mirror_op_name = MIRROR_OPERATOR;
1116   }
1117   return mirror_op_name;
1118 }
1119 
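// For each tensor input of `node` that traces back to a trainable parameter,
// inserts (or reuses) the matching mirror operator from `mirror_ops`,
// handling the cast-before-mirror, shared-parameter and optimizer-shard
// cases, and tags the new communication op with fusion and parameter flags.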
1120 void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
1121   MS_EXCEPTION_IF_NULL(node);
1122   size_t node_size = node->inputs().size();
1123   FuncGraphPtr func_graph = node->func_graph();
1124   MS_EXCEPTION_IF_NULL(func_graph);
1125   FuncGraphManagerPtr manager = func_graph->manager();
1126   MS_EXCEPTION_IF_NULL(manager);
1127   for (auto input : node->inputs()) {
1128     if (HasAbstractMonad(input)) {
1129       node_size--;
1130     }
1131   }
1132 
1133   if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) {
1134     return;
1135   }
1136 
1137   for (size_t index = 1; index < node_size; ++index) {
1138     OperatorVector backward_op = mirror_ops[index - 1];
1139     if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1140       auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
1141       backward_op = mirror_ops[IntToSize(param_index)];
1142     }
1143     if (backward_op.empty()) {
1144       continue;
1145     }
1146     std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(node->input(index), func_graph);
1147     if (!param_node_pair.first) {
1148       continue;
1149     }
1150 
1151     auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
1152     std::string param_name;
1153     bool is_shared_param = false;
1154     if (param_ptr) {
1155       param_name = param_ptr->name();
1156       if (!param_ptr->param_info() || !param_ptr->param_info()->requires_grad()) {
1157         MS_LOG(INFO) << param_name << " does not need a gradient. Skip inserting mirror.";
1158         continue;
1159       }
1160       std::string opt_shard_mirror_group;
1161       if (param_ptr->user_data<TensorLayout>()) {
1162         opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
1163         is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
1164       }
1165       if (!opt_shard_mirror_group.empty()) {
1166         // mirror ops are covered by the case where opt shard is not fully used
1167         backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
1168       }
1169     }
1170     // not a RefKey
1171     std::string mirror_op_name = MirrorOpName();
1172     AnfNodePtr pre_node = node->input(index);
1173     if (!param_node_pair.second) {
1174       auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph, 0);
1175       // if there is already a MirrorOp in the same graph, use the MirrorOp CNode as an input instead
1176       if (next_cnode.first) {
1177         MS_EXCEPTION_IF_NULL(next_cnode.second);
1178         // assume Load is inserted next to parameter
1179         // skip Load moving up and insert mirror next to the parameter
1180         if (pre_node->cast<CNodePtr>()) {
1181           CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>());
1182           manager->SetEdge(load_node, 1, next_cnode.second);
1183         } else {
1184           manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
1185         }
1186         MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1187                      << " and share the mirror.";
1188         continue;
1189       }
1190     }
1191     // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp
1192     // only one MirrorOp in backward_op
1193     if (backward_op.size() != 1) {
1194       MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
1195     }
1196     auto op = backward_op[0];
1197     if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) {
1198       // assume Load is inserted next to parameter
1199       // skip Load moving up and insert mirror next to the parameter
1200       CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>());
1201       InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
1202       auto comm_op = load_node->input(1)->cast<CNodePtr>();
1203       // add fusion flag
1204       AddCommOpFusionType(comm_op, param_node_pair.first);
1205       MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1206                    << " and insert mirror before Load";
1207       AddCommOpParamFlag(comm_op);
1208       continue;
1209     }
1210     InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
1211     MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1212                  << " and insert mirror before the node";
1213     auto comm_op = node->input(index)->cast<CNodePtr>();
1214     // add fusion flag
1215     // pipeline mirror would not be set, which should be supported later
1216     AddCommOpFusionType(comm_op, param_node_pair.first);
1217     AddCommOpParamFlag(comm_op);
1218   }
1219 }
1220 
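// Insert the backward communication operators for a node: mirror operators for its parameter inputs and,
// when the node is a loss node in the last pipeline stage, the virtual div operator. Receive nodes are skipped.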
1221 void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
1222                            const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
1223   MS_EXCEPTION_IF_NULL(distribute_operator);
1224   MS_EXCEPTION_IF_NULL(node);
1225 
1226   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
1227     return;
1228   }
1229   bool is_loss_cnode =
1230     std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
1231                 [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
1232 
1233   MirrorOps mirror_ops = distribute_operator->mirror_ops();
1234   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
1235   // insert mirror op
1236   if (!mirror_ops.empty()) {
1237     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
1238     InsertMirrorOps(root, mirror_ops, node);
1239   }
1240   // insert virtual div op
1241   if (!virtual_div_op.empty() && is_loss_cnode && IsLastStage()) {
1242     MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
1243     InsertVirtualDivOp(virtual_div_op, node);
1244   }
1245 }
1246 
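// Map a primitive name to its distributed operator name by stripping a leading underscore and appending "Info".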
1247 std::string GetDisOpName(const std::string &prim_name) {
1248   std::string op_name = prim_name;
1249   if (!prim_name.empty() && (prim_name[0] == '_')) {
1250     op_name = prim_name.substr(1);
1251   }
1252   return op_name + "Info";
1253 }
1254 
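// Create an OperatorInfo instance via the dynamic creator using the distributed operator name; GatherV2 uses
// the sharded "PInfo" variant unless its data_parallel attribute is true. The created operator gets a unique
// name suffixed with a running counter.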
1255 OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs,
1256                                        const std::vector<Shapes> &shape_list) {
1257   if (shape_list.size() != 2) {
1258     MS_LOG(ERROR) << "The size of shape list is not 2";
1259     return nullptr;
1260   }
1261   if (name.length() == 0) {
1262     MS_LOG(EXCEPTION) << "Length of name is zero!";
1263   }
1264   std::string distribute_opname = GetDisOpName(name);
1265   if (name == GATHERV2) {
1266     distribute_opname = name + "PInfo";
1267     auto data_parallel_iter = attrs.find(DATA_PARALLEL);
1268     if (data_parallel_iter != attrs.end()) {
1269       MS_EXCEPTION_IF_NULL(data_parallel_iter->second);
1270       if (!data_parallel_iter->second->isa<BoolImm>()) {
1271         MS_LOG(EXCEPTION) << "The data_parallel flag's type is not a bool.";
1272       }
1273       bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value();
1274       if (data_parallel) {
1275         distribute_opname = name + "Info";
1276       }
1277     }
1278   }
1279   OperatorInfoPtr operator_ =
1280     (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
1281   if (operator_ == nullptr) {
1282     MS_LOG(INFO) << "Create " << name << " failed";
1283     return nullptr;
1284   }
1285   std::string origin_name = operator_->name();
1286   operator_->set_name(origin_name + std::to_string(TOTAL_OPS));
1287   MS_LOG(INFO) << "Successfully created operator " << origin_name;
1288   ++TOTAL_OPS;
1289   return operator_;
1290 }
1291 
1292 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
1293                                  const std::vector<Shapes> &shape_list) {
1294   MS_EXCEPTION_IF_NULL(prim);
1295   OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
1296   if (operator_ == nullptr) {
1297     if (IsInBatchParallelBlackList(prim)) {
1298       MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
1299     }
1300     MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
1301     operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
1302     MS_EXCEPTION_IF_NULL(operator_);
1303   }
1304   return operator_;
1305 }
1306 
1307 OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
1308                                     std::vector<Shapes> shape_list) {
1309   OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
1310   for (size_t i = 0; i < shape_list[0].size(); ++i) {
1311     MS_LOG(INFO) << "No:  " << i << "  input's shape: " << ShapeToString(shape_list[0][i]);
1312   }
1313   return operator_;
1314 }
1315 
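// Parse a strategy value, e.g. a tuple like ((4, 1), (1, 2)), into a StrategyPtr for the current stage;
// every element must be a value sequence, otherwise an exception is thrown.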
1316 StrategyPtr ExtractStrategy(const ValuePtr &stra) {
1317   ValueTuplePtr var = stra->cast<ValueTuplePtr>();
1318   StrategyPtr strategyPtr;
1319   int64_t stage_id = g_device_manager->stage_id();
1320 
1321   MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
1322   if (var == nullptr) {
1323     MS_LOG(EXCEPTION) << "Strategy value is nullptr";
1324   }
1325   if (var->size() > 0) {
1326     std::vector<ValuePtr> elements = var->value();
1327     Strategys strategy;
1328     for (uint64_t index = 0; index < elements.size(); ++index) {
1329       Dimensions dim;
1330       if (elements[index]->isa<ValueSequeue>()) {
1331         ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
1332         std::vector<ValuePtr> value_vector = value_tuple->value();
1333         (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
1334                              [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
1335         strategy.push_back(dim);
1336       } else {
1337         MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
1338       }
1339     }
1340     if (strategy.empty()) {
1341       MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
1342     }
1343     strategyPtr = NewStrategy(stage_id, strategy);
1344   }
1345 
1346   return strategyPtr;
1347 }
1348 
1349 Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
1350   MS_EXCEPTION_IF_NULL(node);
1351   MS_EXCEPTION_IF_NULL(func_graph);
1352 
1353   std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(node, func_graph);
1354   if (parameters.size() != 1) {
1355     MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
1356   }
1357 
1358   Shapes input_shapes;
1359   input_shapes = GetNodeShape(parameters[0]);
1360   if (input_shapes.size() != 1) {
1361     MS_LOG(EXCEPTION) << "Get input shape failed";
1362   }
1363 
1364   MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]);
1365   return input_shapes;
1366 }
1367 
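// Collect the input shapes and the output shape of a cnode; RefKey inputs are resolved to their parameters
// and recorded in g_RefMap, and monad inputs are skipped.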
1368 std::vector<Shapes> ExtractShape(const CNodePtr &node) {
1369   MS_EXCEPTION_IF_NULL(node);
1370   Shapes shape_inputs, shape_outputs;
1371   std::vector<Shapes> shape_all;
1372   std::vector<AnfNodePtr> all_inputs = node->inputs();
1373 
1374   size_t inputs_size = all_inputs.size();
1375   for (size_t i = 1; i < inputs_size; ++i) {
1376     Shapes input_shapes;
1377     AnfNodePtr input = all_inputs[i];
1378     if (HasAbstractMonad(input)) {
1379       continue;
1380     }
1381     if (IsValueNode<RefKey>(input)) {
1382       auto func_graph = node->func_graph();
1383       MS_EXCEPTION_IF_NULL(func_graph);
1384       std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
1385       if (parameters.size() != 1) {
1386         MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
1387       }
1388       std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
1389       g_RefMap[parameters[0]] = node_pair;
1390       input_shapes = GetRefKeyNodeShape(input, func_graph);
1391     } else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
1392                ((IsValueNode<ValueList>(input) || IsValueNode<ValueTuple>(input)) && (inputs_size == 2))) {
1393       input_shapes = GetNodeShape(input);
1394     } else {
1395       continue;
1396     }
1397     if (input_shapes.size() != 1) {
1398       if (inputs_size == 2) {  // like concat
1399         shape_inputs = input_shapes;
1400         break;
1401       } else {
1402         MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
1403       }
1404     }
1405     shape_inputs.push_back(input_shapes[0]);
1406   }
1407   shape_all.push_back(shape_inputs);
1408   // extract out shape
1409   shape_outputs = GetNodeShape(node);
1410   shape_all.push_back(shape_outputs);
1411   return shape_all;
1412 }
1413 
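// Search the users of a node, recursing up to RECURSION_LIMIT levels, for a parallel-care cnode that already
// holds an OperatorInfo; Send/Receive users and non-data Depend edges are skipped.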
1414 std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) {
1415   if (recursion_num >= RECURSION_LIMIT) {
1416     return std::make_pair(nullptr, 0);
1417   }
1418 
1419   MS_EXCEPTION_IF_NULL(node);
1420   FuncGraphPtr func_graph = node->func_graph();
1421   MS_EXCEPTION_IF_NULL(func_graph);
1422   FuncGraphManagerPtr manager = func_graph->manager();
1423   MS_EXCEPTION_IF_NULL(manager);
1424   AnfNodeIndexSet node_set = manager->node_users()[node];
1425   for (auto &node_pair : node_set) {
1426     CNodePtr cnode = node_pair.first->cast<CNodePtr>();
1427     MS_EXCEPTION_IF_NULL(cnode);
1428     if (!IsValueNode<Primitive>(cnode->input(0))) {
1429       continue;
1430     }
1431     ValueNodePtr prim_node_anf = cnode->input(0)->cast<ValueNodePtr>();
1432     MS_EXCEPTION_IF_NULL(prim_node_anf);
1433     PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
1434     MS_EXCEPTION_IF_NULL(node_prim);
1435     if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) ||
1436         IsPrimitiveCNode(cnode, prim::kPrimSend)) {
1437       continue;
1438     }
1439     if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
1440       return node_pair;
1441     } else {
1442       auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1);
1443       if (tmp_pair.first != nullptr) {
1444         return tmp_pair;
1445       }
1446     }
1447   }
1448   return std::make_pair(nullptr, 0);
1449 }
1450 
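// Find the parallel-care user of a parameter; if none is found directly, follow calls into sub-graphs and
// continue the search from the corresponding formal parameter of the called graph.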
1451 std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
1452   MS_EXCEPTION_IF_NULL(graph);
1453   MS_EXCEPTION_IF_NULL(parameter);
1454   FuncGraphManagerPtr manager = graph->manager();
1455   MS_EXCEPTION_IF_NULL(manager);
1456   std::pair<AnfNodePtr, int64_t> prim_anf_node_pair = FindParallelCareNode(parameter, 0);
1457   if (prim_anf_node_pair.first != nullptr) {
1458     return prim_anf_node_pair;
1459   } else {
1460     AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
1461     for (auto &param_pair : param_sub_set) {
1462       CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
1463       AnfNodePtr graph_value_node;
1464       if (param_cnode->input(0)->isa<CNode>()) {
1465         graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
1466       } else {
1467         graph_value_node = param_cnode->input(0);
1468       }
1469       if (!IsValueNode<FuncGraph>(graph_value_node)) {
1470         continue;
1471       }
1472       FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
1473       auto parameters = graph_sub->parameters();
1474       if (LongToSize(param_pair.second - 1) >= parameters.size()) {
1475         MS_LOG(EXCEPTION) << "The index is out of range, index is: " << (param_pair.second - 1) << ", vector size is "
1476                           << parameters.size();
1477       }
1478       std::pair<AnfNodePtr, int64_t> res = FindSubGraph(graph_sub, parameters[LongToSize(param_pair.second - 1)]);
1479       if (res.first != nullptr) {
1480         return res;
1481       }
1482     }
1483   }
1484   return std::make_pair(nullptr, 0);
1485 }
1486 
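// If the node (after skipping a Load to its single user) is a Cast whose output type is not float32,
// return that Cast; otherwise return nullptr.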
1487 CNodePtr InsertAllGatherAfterCast(const CNodePtr &cnode) {
1488   MS_EXCEPTION_IF_NULL(cnode);
1489   auto graph = cnode->func_graph();
1490   MS_EXCEPTION_IF_NULL(graph);
1491   auto manager = graph->manager();
1492   MS_EXCEPTION_IF_NULL(manager);
1493   // skip Load moving down and assume it only has one node user
1494   CNodePtr res = cnode;
1495   if (IsSomePrimitive(res, LOAD)) {
1496     res = manager->node_users()[cnode].begin()->first->cast<CNodePtr>();
1497   }
1498   // return true only if cnode is Cast from fp32 to fp16
1499   if (!IsSomePrimitive(res, CAST)) {
1500     return nullptr;
1501   }
1502   auto node_type = res->Type();
1503   MS_EXCEPTION_IF_NULL(node_type);
1504   if (!node_type->isa<mindspore::TensorType>()) {
1505     MS_LOG(EXCEPTION) << "Unknown type.";
1506   }
1507   auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
1508   MS_EXCEPTION_IF_NULL(input_element_type);
1509   auto type_id = input_element_type->type_id();
1510 
1511   if (type_id != kNumberTypeFloat32) {
1512     return res;
1513   } else {
1514     return nullptr;
1515   }
1516 }
1517 
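// Insert an AllGather (or the mini-step / micro-step variant selected by op_name) for an optimizer-sharded
// parameter, then set the fusion type and gradients-mean flag on the inserted communication op.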
1518 static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
1519                               const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
1520   MS_EXCEPTION_IF_NULL(res.first);
1521   MS_EXCEPTION_IF_NULL(node);
1522   auto cnode = res.first->cast<CNodePtr>();
1523   auto graph = cnode->func_graph();
1524   MS_EXCEPTION_IF_NULL(graph);
1525   auto manager = graph->manager();
1526   MS_EXCEPTION_IF_NULL(manager);
1527   auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1528   MS_EXCEPTION_IF_NULL(cnode_prim);
1529   Operator op;
1530   CNodePtr allgather;
1531   auto param_name = node->cast<ParameterPtr>()->name();
1532   if (op_name == MINI_STEP_ALL_GATHER) {
1533     op = CreateMiniStepAllGatherOp(group);
1534   } else if (op_name == MICRO_STEP_ALL_GATHER) {
1535     op = CreateMicroStepAllGatherOp(group);
1536   } else {
1537     op = CreateAllGatherOp(group);
1538   }
1539   CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
1540   if (!is_shared_param && cast_node) {
1541     allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
1542     MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
1543   } else {
1544     InsertNode(op, cnode, IntToSize(res.second), node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name,
1545                root);
1546     allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>();
1547     MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name;
1548   }
1549   // add fusion flag
1550   AddCommOpFusionType(allgather, node);
1551   // add gradients mean
1552   AddCommOpMeanFlag(allgather);
1553 }
1554 
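// When the parameter has a non-empty opt-shard group, insert an AllGather between the sharded parameter and
// each of its forward users; after the first insertion the remaining users share the same AllGather.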
1555 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
1556                                     const std::string &opt_shard_group) {
1557   if (opt_shard_group.empty()) {
1558     return;
1559   }
1560   FuncGraphManagerPtr manager = root->manager();
1561   MS_EXCEPTION_IF_NULL(parameter);
1562   MS_EXCEPTION_IF_NULL(manager);
1563   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1564   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1565   std::string op_name;
1566   if (grad_accumulation_step > 1) {
1567     op_name = MINI_STEP_ALL_GATHER;
1568   } else if (split_stage_num > 1) {
1569     op_name = MICRO_STEP_ALL_GATHER;
1570   } else {
1571     op_name = ALL_GATHER;
1572   }
1573   auto param_sub_set = manager->node_users()[parameter];
1574   bool insert_flag = false;
1575   for (auto &param_pair : param_sub_set) {
1576     auto cnode = param_pair.first->cast<CNodePtr>();
1577     MS_EXCEPTION_IF_NULL(cnode);
1578     if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
1579         !IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
1580       OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
1581       if (distribute_operator == nullptr) {
1582         MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
1583       } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
1584         MS_LOG(EXCEPTION) << "The index is out of range, index is  " << (param_pair.second - 1) << ", vector size is  "
1585                           << distribute_operator->inputs_tensor_info().size();
1586       }
1587       if (insert_flag) {
1588         // if there are multiple node users, they share one same allgather
1589         auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph(), 0);
1590         if (next_cnode.first) {
1591           manager->SetEdge(cnode, param_pair.second, next_cnode.second);
1592           MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
1593                        << GetPrimName(cnode);
1594         } else {
1595           MS_LOG(ERROR) << "Cannot find the shared AllGather with multiple node users.";
1596         }
1597       } else {
1598         // insert allgather operator between shard parameter and cnode
1599         auto param_ptr = parameter->cast<ParameterPtr>();
1600         MS_EXCEPTION_IF_NULL(param_ptr);
1601         bool is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
1602         InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name, is_shared_param);
1603         insert_flag = true;
1604       }
1605     }
1606   }
1607 }
1608 
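// Decide whether the parallel optimizer (weight sharding) applies to the parameter; if the sharded slice shape
// can be generated, create the communication group for the AllGather and return its name, otherwise return an
// empty string.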
1609 static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
1610                                     const OperatorInfoPtr &distribute_operator) {
1611   std::string opt_shard_group;
1612   if (!ParameterRequireGrad(parameter)) {
1613     // only trainable parameters need parallel optimizer
1614     MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
1615   } else if (parameter->cast<ParameterPtr>()->param_info() &&
1616              !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
1617     MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
1618   } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
1619     // get the shard tensor slice shape if the weight is repeated on devices
1620     // and the shape of the first dimension could be divided
1621     // apply parallel optimizer on parameters
1622     // create communication group for allgather operator
1623     std::vector<Group> dev_group;
1624     if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
1625         !dev_group.empty()) {
1626       opt_shard_group = dev_group[0].name();
1627       MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
1628     } else {
1629       MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
1630     }
1631   } else {
1632     MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
1633                     << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
1634   }
1635   return opt_shard_group;
1636 }
1637 
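// Mark the parameter's tensor layout as a shared parameter when it has more than one user in the forward graph.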
1638 void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter) {
1639   MS_EXCEPTION_IF_NULL(root);
1640   MS_EXCEPTION_IF_NULL(parameter);
1641   FuncGraphManagerPtr manager = root->manager();
1642   MS_EXCEPTION_IF_NULL(manager);
1643   auto parameter_ptr = parameter->cast<ParameterPtr>();
1644   if (!parameter_ptr) {
1645     MS_LOG(INFO) << parameter->ToString() << " is not a parameter";
1646     return;
1647   }
1648   auto param_sub_set = manager->node_users()[parameter];
1649   int32_t users_count = 0;
1650   for (auto &param_pair : param_sub_set) {
1651     auto cnode = param_pair.first->cast<CNodePtr>();
1652     MS_EXCEPTION_IF_NULL(cnode);
1653     if (cnode->in_forward_flag()) users_count++;
1654   }
1655   if (users_count > 1) {
1656     auto tensor_layout = parameter_ptr->user_data<TensorLayout>();
1657     tensor_layout->set_is_shared_param(true);
1658     MS_LOG(WARNING) << "There are multiple users for " << parameter->ToString()
1659                     << ". Mixed precision optimization is not valid here.";
1660   }
1661 }
1662 
1663 // When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
1664 std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
1665   MS_EXCEPTION_IF_NULL(parameter);
1666   AbstractBasePtr abstract = parameter->abstract();
1667   MS_EXCEPTION_IF_NULL(abstract);
1668   MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
1669   CNodePtr cnode = res.first->cast<CNodePtr>();
1670   MS_EXCEPTION_IF_NULL(cnode);
1671   OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
1672   if (distribute_operator == nullptr) {
1673     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
1674   }
1675   if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
1676     MS_LOG(EXCEPTION) << "The index is out of range, index is  " << (res.second - 1) << ", vector size is  "
1677                       << distribute_operator->inputs_tensor_info().size();
1678   }
1679   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
1680   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
1681   Shape slice_shape = tensor_layout.slice_shape().array();
1682   std::string opt_shard_group;
1683   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1684   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
1685   if (enable_parallel_optimizer) {
1686     opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
1687   }
1688   if (!opt_shard_group.empty()) {
1689     slice_shape = tensor_layout.opt_shard_slice_shape();
1690   }
1691   MS_LOG(INFO) << "SetParallelShape slice_shape  " << parameter->ToString() << "  shape "
1692                << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
1693   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
1694   MS_EXCEPTION_IF_NULL(parallel_shape);
1695   // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
1696   auto cloned_abstract = abstract->Clone();
1697   MS_EXCEPTION_IF_NULL(cloned_abstract);
1698   cloned_abstract->set_shape(parallel_shape);
1699   parameter->set_abstract(cloned_abstract);
1700   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
1701   MS_EXCEPTION_IF_NULL(parameter_ptr);
1702   parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1703   return opt_shard_group;
1704 }
1705 
1706 void CoverSliceShape(const FuncGraphPtr &root) {
1707   MS_EXCEPTION_IF_NULL(root);
1708   auto parameters = root->parameters();
1709   for (auto &parameter : parameters) {
1710     MS_EXCEPTION_IF_NULL(parameter->Shape());
1711     auto iter = g_RefMap.find(parameter);
1712     if (iter != g_RefMap.end()) {
1713       std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
1714       // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1715       SetSharedParameterFlag(root, parameter);
1716       ApplyParallelOptOnParam(root, parameter, group);
1717       continue;
1718     }
1719     std::pair<AnfNodePtr, int64_t> res = FindSubGraph(root, parameter);
1720     if (res.first == nullptr) {
1721       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " does not need to set parallel shape";
1722     } else {
1723       std::string group = SetParallelShape(parameter, res);
1724       // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1725       SetSharedParameterFlag(root, parameter);
1726       ApplyParallelOptOnParam(root, parameter, group);
1727       MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
1728     }
1729   }
1730   g_RefMap.clear();
1731 }
1732 
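// For VirtualDataset / VirtualOutput nodes, fill in the strategy attribute: either from the user-configured
// dataset strategy, or by splitting the first dimension of each input across the stage devices (using 1 when
// full_batch is set or the dimension is not divisible).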
1733 void SetVirtualDatasetStrategy(const CNodePtr &node) {
1734   MS_EXCEPTION_IF_NULL(node);
1735   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1736   bool full_batch = ParallelContext::GetInstance()->full_batch();
1737 
1738   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
1739   MS_EXCEPTION_IF_NULL(prim);
1740   if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
1741     CheckGlobalDeviceManager();
1742     auto attrs_temp = prim->attrs();
1743     if (!ParallelContext::GetInstance()->dataset_strategy().empty() && prim->name() == VIRTUAL_DATA_SET) {
1744       std::vector<ValuePtr> elements;
1745       auto dataset_strategy = ParallelContext::GetInstance()->dataset_strategy();
1746       (void)std::transform(dataset_strategy.begin(), dataset_strategy.end(), std::back_inserter(elements),
1747                            [](auto input_stra) { return MakeValue(input_stra); });
1748       ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1749       attrs_temp[STRATEGY] = strategy;
1750       (void)prim->SetAttrs(attrs_temp);
1751       return;
1752     }
1753     int64_t dev_num;
1754     if (full_batch) {
1755       dev_num = 1;
1756     } else {
1757       dev_num = g_device_manager->stage_device_num();
1758     }
1759     if (dev_num == 0) {
1760       MS_LOG(EXCEPTION) << "Device Num must be larger than 0, but got 0.";
1761     }
1762     std::vector<Shapes> shape_list = ExtractShape(node);
1763     if (shape_list.empty()) {
1764       MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
1765     }
1766     std::vector<ValuePtr> elements;
1767     for (size_t i = 0; i < shape_list[0].size(); i++) {
1768       if (shape_list[0][i].empty()) {
1769         MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
1770       }
1771       Dimensions input_strategy;
1772       if (!shape_list[0][i].empty() && shape_list[0][i][0] % dev_num == 0) {
1773         input_strategy.push_back(dev_num);
1774       } else if (!shape_list[0][i].empty()) {
1775         input_strategy.push_back(1);
1776       }
1777       for (size_t j = 1; j < shape_list[0][i].size(); j++) {
1778         input_strategy.push_back(1);
1779       }
1780       elements.push_back(MakeValue(input_strategy));
1781     }
1782     ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1783     attrs_temp[STRATEGY] = strategy;
1784     (void)prim->SetAttrs(attrs_temp);
1785   }
1786 }
1787 
1788 // find previous parallel care node's next node.
1789 bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes, size_t curr_depth) {
1790   if (curr_depth > MAX_RECURSIVE_DEPTH) {
1791     MS_LOG(WARNING) << "When finding the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
1792     return false;
1793   }
1794   MS_EXCEPTION_IF_NULL(unique_ids);
1795   MS_EXCEPTION_IF_NULL(indexes);
1796   if (!node->isa<CNode>()) {
1797     return false;
1798   }
1799   CNodePtr pre_cnode = node->cast<CNodePtr>();
1800   if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
1801     return false;
1802   }
1803   bool find = false;
1804   for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
1805     auto next_node = pre_cnode->inputs()[index];
1806     if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
1807       return false;
1808     }
1809     CNodePtr cnode = next_node->cast<CNodePtr>();
1810     if (!IsValueNode<Primitive>(cnode->input(0))) {
1811       return false;
1812     }
1813     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1814     PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
1815     if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
1816       unique_ids->push_back(pre_cnode->UniqueId());
1817       indexes->push_back(index);
1818       find = true;
1819       continue;
1820     }
1821     if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) {
1822       find = true;
1823       continue;
1824     }
1825   }
1826   return find;
1827 }
1828 
1829 void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
1830                            std::vector<size_t> *indexes) {
1831   MS_EXCEPTION_IF_NULL(unique_ids);
1832   CNodePtr cnode = root->get_return();
1833   if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
1834     MS_LOG(WARNING) << "Cannot find the last parallel care node in the eval graph";
1835   }
1836 }
1837 
1838 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
1839   MS_EXCEPTION_IF_NULL(operator_);
1840   MS_EXCEPTION_IF_NULL(prim);
1841   StrategyPtr strategyPtr;
1842   std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
1843   MS_EXCEPTION_IF_NULL(strategy_v_ptr);
1844   strategyPtr = NewStrategy(0, *strategy_v_ptr);
1845   std::vector<ValuePtr> elements;
1846   for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
1847     elements.push_back(MakeValue((*strategy_v_ptr)[i]));
1848   }
1849   ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1850   // display the strategy generated by batch parallel
1851   auto attrs = prim->attrs();
1852   attrs[GEN_STRATEGY] = strategy;
1853   (void)prim->SetAttrs(attrs);
1854   MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << attrs[GEN_STRATEGY]->ToString();
1855   return strategyPtr;
1856 }
1857 
1858 static bool CheckExtractInfomation(const CNodePtr &cnode) {
1859   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
1860     return false;
1861   }
1862 
1863   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1864   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1865   if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
1866     return false;
1867   }
1868 
1869   if (!IsParallelCareNode(cnode)) {
1870     return false;
1871   }
1872   return true;
1873 }
1874 
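// Create and initialize an OperatorInfo for every parallel-care cnode: extract its shapes, collect constant
// input values, then take the strategy from the primal attribute, the primitive attribute, the strategy
// checkpoint, or a generated batch-parallel strategy. Reshape nodes only get the OperatorInfo attached here;
// their layouts are set later in ReshapeInit.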
1875 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
1876   // load strategy map from checkpoint
1877   StrategyMap stra_map;
1878   if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
1879       (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
1880     MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
1881   }
1882 
1883   for (auto &node : all_nodes) {
1884     auto cnode = node->cast<CNodePtr>();
1885     if (!CheckExtractInfomation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
1886       continue;
1887     }
1888 
1889     SetVirtualDatasetStrategy(cnode);
1890     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1891     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1892 
1893     auto attrs = prim->attrs();
1894     MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
1895 
1896     std::vector<Shapes> shape_list = ExtractShape(cnode);
1897     if (shape_list.empty()) {
1898       MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
1899     }
1900     OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
1901     MS_EXCEPTION_IF_NULL(operator_);
1902 
1903     auto &inputs = cnode->inputs();
1904     std::vector<ValuePtr> input_value;
1905     for (size_t index = 1; index < inputs.size(); ++index) {
1906       if (inputs[index]->isa<ValueNode>()) {
1907         input_value.push_back(GetValueNode(inputs[index]));
1908         continue;
1909       }
1910       input_value.emplace_back(nullptr);
1911     }
1912     StrategyPtr strategyPtr = nullptr;
1913     (*operator_).set_input_value(input_value);
1914     (*operator_).set_outputs_dtype(cnode->Type());
1915     (*operator_).set_cnode(cnode);
1916     if (prim->name() == RESHAPE) {
1917       cnode->set_user_data<OperatorInfo>(operator_);
1918       continue;
1919     }
1920     // load strategy checkpoint
1921     // key of strategy map
1922     std::string strategy_key_name = "";
1923     auto param_names = NodeParameterName(cnode, -1, 0);
1924     if (!param_names.empty()) {
1925       strategy_key_name = prim->name() + "_" + param_names[0].first;
1926     }
1927     bool load_strategy_from_ckpt =
1928       StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
1929     if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(STRATEGY)) {
1930       MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
1931                    << " is empty, using batch parallel";
1932       strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
1933     } else if (cnode->HasPrimalAttr(STRATEGY)) {
1934       strategyPtr = ExtractStrategy(cnode->GetPrimalAttr(STRATEGY));
1935     } else if (StrategyFound(attrs)) {
1936       strategyPtr = ExtractStrategy(attrs[STRATEGY]);
1937     } else {
1938       strategyPtr = stra_map[strategy_key_name];
1939     }
1940 
1941     MS_EXCEPTION_IF_NULL(strategyPtr);
1942     if (operator_->Init(strategyPtr) == FAILED) {
1943       MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"
1944                         << " trace: " << trace::DumpSourceLines(cnode);
1945     }
1946     cnode->set_user_data<OperatorInfo>(operator_);
1947   }
1948 }
1949 
1950 TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair) {
1951   CNodePtr cnode = node_pair.first->cast<CNodePtr>();
1952   MS_EXCEPTION_IF_NULL(cnode);
1953   OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
1954   MS_EXCEPTION_IF_NULL(distribute_operator);
1955   int64_t index = node_pair.second;
1956   if (index > SizeToLong(distribute_operator->inputs_tensor_info().size())) {
1957     MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is  " << (index - 1)
1958                       << ", the vector size is  " << distribute_operator->inputs_tensor_info().size();
1959   }
1960   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
1961   TensorLayout tensorlayout_in = tensorinfo_in.tensor_layout();
1962   return tensorlayout_in;
1963 }
1964 
1965 // if reshape's output connect to several primitive, return the first layout found
1966 std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
1967   MS_EXCEPTION_IF_NULL(cnode);
1968   MS_EXCEPTION_IF_NULL(cnode->func_graph());
1969   FuncGraphManagerPtr manager = cnode->func_graph()->manager();
1970   MS_EXCEPTION_IF_NULL(manager);
1971   AnfNodeIndexSet node_set = manager->node_users()[cnode];
1972   for (auto &node_pair : node_set) {
1973     CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
1974     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
1975       continue;
1976     }
1977     if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
1978       *next_is_reshape = true;
1979       continue;
1980     }
1981     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
1982     MS_EXCEPTION_IF_NULL(prim_anf_node);
1983     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
1984     MS_EXCEPTION_IF_NULL(node_prim);
1985     MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
1986     if (node_prim->name() == DEPEND && node_pair.second != 1) {
1987       continue;
1988     }
1989     if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
1990       MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
1991       *next_is_reshape = false;
1992       auto layout = GetInputLayoutFromCNode(node_pair);
1993       return std::make_shared<TensorLayout>(layout);
1994     }
1995     MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << "  " << IsParallelCareNode(use_apply)
1996                   << "   " << use_apply->has_user_data<OperatorInfo>();
1997 
1998     auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
1999     if (layout_ptr) {
2000       return layout_ptr;
2001     }
2002   }
2003   MS_LOG(WARNING) << "FindNextLayout returns nullptr; if reshape is not the last primitive, there must be some error";
2004   return nullptr;
2005 }
2006 
2007 std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
2008   MS_EXCEPTION_IF_NULL(cnode);
2009   OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2010   MS_EXCEPTION_IF_NULL(distribute_operator);
2011   if (distribute_operator->outputs_tensor_info().size() <= output_index) {
2012     MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->outputs_tensor_info().size()
2013                       << ", must be greater than output_index  " << output_index;
2014   }
2015   TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
2016   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
2017   return std::make_shared<TensorLayout>(tensorlayout_out);
2018 }
2019 
2020 std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) {
2021   if (!node->isa<CNode>()) {
2022     return nullptr;
2023   }
2024   CNodePtr cnode = node->cast<CNodePtr>();
2025   if (!IsValueNode<Primitive>(cnode->input(0))) {
2026     return nullptr;
2027   }
2028   if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
2029     auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
2030     if (!layout_ptr) {
2031       MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
2032     }
2033     return layout_ptr;
2034   }
2035   return nullptr;
2036 }
2037 
2038 std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) {
2039   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2040     MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: "
2041                     << MAX_RECURSIVE_DEPTH;
2042     return nullptr;
2043   }
2044   FuncGraphManagerPtr manager = node->func_graph()->manager();
2045   MS_EXCEPTION_IF_NULL(manager);
2046   AnfNodeIndexSet node_set = manager->node_users()[node];
2047   for (auto &node_pair : node_set) {
2048     if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
2049       auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth);
2050       if (!layout_param) {
2051         continue;
2052       }
2053       return layout_param;
2054     }
2055     CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
2056     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
2057       continue;
2058     }
2059     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
2060     MS_EXCEPTION_IF_NULL(prim_anf_node);
2061     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
2062     MS_EXCEPTION_IF_NULL(node_prim);
2063     if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
2064       continue;
2065     }
2066     if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
2067       auto layout = GetInputLayoutFromCNode(node_pair);
2068       return std::make_shared<TensorLayout>(layout);
2069     }
2070   }
2071   return nullptr;
2072 }
2073 
2074 std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
2075   // Create DataParallel tensor layout for parameter(support WideDeep).
2076   auto next_layout = FindParameterNextLayout(node, 0);
2077   if (next_layout != nullptr) {
2078     return next_layout;
2079   }
2080   CheckGlobalDeviceManager();
2081   int64_t dev_num = g_device_manager->stage_device_num();
2082   TensorLayout input_tensor_layout;
2083   // create input_shape
2084   Shapes inputs_shape = GetNodeShape(node);
2085   Shape input_shape_array = inputs_shape[0];
2086   if (input_shape_array.empty()) {
2087     MS_LOG(EXCEPTION) << "Reshaping a scalar parameter is not supported.";
2088   }
2089   // create tensor_map
2090   size_t shape_size = input_shape_array.size();
2091   TensorMap input_tensor_map_array(SizeToLong(shape_size) - 1, -1);
2092   input_tensor_map_array.insert(input_tensor_map_array.begin(), 0);
2093   // create dev_matrix
2094   Shape dev_matrix_array = {dev_num};
2095   if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
2096     MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
2097   }
2098   return std::make_shared<TensorLayout>(input_tensor_layout);
2099 }
2100 
2101 RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
2102   MS_EXCEPTION_IF_NULL(node);
2103   TensorRedistribution tensor_redistribution;
2104   // create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
2105   CheckGlobalDeviceManager();
2106   int64_t dev_num = g_device_manager->stage_device_num();
2107   TensorLayout stand_alone_layout;
2108   Shapes inputs_shape = GetNodeShape(node);
2109   if (inputs_shape.empty()) {
2110     MS_LOG(EXCEPTION) << "InferSensRedistribution failed because the inputs shape is empty.";
2111   }
2112   Shape input_shape_array = inputs_shape[0];
2113   if (input_shape_array.empty()) {
2114     MS_LOG(INFO) << "No need to redistribution for sens.";
2115     return nullptr;
2116   }
2117   // TensorMap
2118   TensorMap stand_alone_tensor_map_array(SizeToLong(input_shape_array.size()), -1);
2119   // Dev_matrix
2120   Shape dev_matrix_array = {dev_num};
2121   if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
2122     MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
2123   }
2124 
2125   // Infer Redistribution op list for stand alone and loss layout.
2126   RankList dev_list = g_device_manager->GetDeviceListInThisStage();
2127   if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
2128     MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
2129   }
2130   RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
2131   MS_EXCEPTION_IF_NULL(sens_redistribution_list);
2132 
2133   return sens_redistribution_list;
2134 }
2135 
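// Find the tensor layout that feeds a reshape: a parameter gets a data-parallel layout, a Receive node carries
// its own layout, a parallel-care cnode provides its output layout, and tuple_getitem / depend are looked through.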
2136 std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
2137   if (node->isa<Parameter>()) {
2138     return CreateParameterLayout(node);
2139   }
2140   if (!node->isa<CNode>()) {
2141     return nullptr;
2142   }
2143   CNodePtr cnode = node->cast<CNodePtr>();
2144   if (!IsValueNode<Primitive>(cnode->input(0))) {
2145     return nullptr;
2146   }
2147   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
2148     return cnode->user_data<TensorLayout>();
2149   }
2150   if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
2151       !IsPrimitiveCNode(node, prim::kPrimReshape)) {
2152     auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
2153     if (!layout_ptr) {
2154       MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
2155     }
2156     return layout_ptr;
2157   }
2158   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2159   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
2160   if (prim->name() == prim::kTupleGetItem) {
2161     auto tuple_index = GetTupleGetItemIndex(cnode);
2162     auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
2163     if (!layout_ptr) {
2164       MS_LOG(EXCEPTION)
2165         << "Failure: FindPrevLayout failed, tuple_getitem before reshape, but there does not exist a parallel care node "
2166            "before tuple_getitem!";
2167     }
2168     return layout_ptr;
2169   }
2170   for (size_t index = 0; index < cnode->inputs().size(); ++index) {
2171     if (prim->name() == DEPEND && index != 1) {
2172       continue;
2173     }
2174     auto layout_ptr = FindPrevLayout(cnode->inputs()[index]);
2175     if (!layout_ptr) {
2176       continue;
2177     }
2178     return layout_ptr;
2179   }
2180   MS_LOG(WARNING) << "FindPrevLayout returns nullptr; if reshape is not the first primitive, there must be some error";
2181   return nullptr;
2182 }
2183 
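// For every Reshape cnode, set its input layout from the previous parallel-care node and its output layout from
// the next one (or reuse the input layout when the next node is also a reshape), then initialize the ReshapeInfo.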
2184 void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
2185   for (auto &node : all_nodes) {
2186     auto cnode = node->cast<CNodePtr>();
2187     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2188       continue;
2189     }
2190     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2191     if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
2192       continue;
2193     }
2194     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
2195     MS_EXCEPTION_IF_NULL(prim);
2196     OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
2197     if (operator_info == nullptr) {
2198       MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
2199     }
2200     if (prim->name() != RESHAPE) {
2201       continue;
2202     }
2203     auto attrs = prim->attrs();
2204     if (StrategyFound(attrs)) {
2205       MS_LOG(EXCEPTION) << "Setting a strategy for Reshape has no effect!";
2206     }
2207     MS_ASSERT(cnode->inputs().size() == 3);
2208     auto prev_layout_ptr = FindPrevLayout(cnode->input(1));
2209     if (prev_layout_ptr) {
2210       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2211       reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
2212     }
2213     bool is_next_reshape = false;
2214     auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape);
2215     if (next_layout_ptr) {
2216       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2217       reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
2218     } else if (is_next_reshape && prev_layout_ptr != nullptr) {
2219       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2220       reshape_info_ptr->SetOutputLayout(*prev_layout_ptr);
2221     }
2222     if (operator_info->Init(nullptr) == FAILED) {
2223       MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed";
2224     }
2225   }
2226 }
2227 
2228 CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
2229   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2230     MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the maximum recursion depth: "
2231                     << MAX_RECURSIVE_DEPTH;
2232     return nullptr;
2233   }
2234   // Handle return->depend->loss
2235   if (IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
2236       (IsPrimitiveCNode(cnode, prim::kPrimCast) && !cnode->has_user_data<OperatorInfo>())) {
2237     auto depend_before = cnode->input(1)->cast<CNodePtr>();
2238     MS_EXCEPTION_IF_NULL(depend_before);
2239     return HandleDependLoss(depend_before, ++curr_depth);
2240   }
2241   return cnode;
2242 }
2243 
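// Locate the loss cnode starting from the graph's return node, looking through Depend, Cast, Switch and
// tuple_getitem; a make_tuple output is not supported and leaves the loss node empty.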
2244 LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
2245   if (max_depth > MAX_RECURSIVE_DEPTH) {
2246     MS_LOG(EXCEPTION) << "The recursion depth is larger than 100000.";
2247   }
2248   LossNodeInfo loss_node_info;
2249   MS_EXCEPTION_IF_NULL(func_graph);
2250   CNodePtr return_node = func_graph->get_return();
2251   MS_EXCEPTION_IF_NULL(return_node);
2252   if (return_node->size() < 2) {
2253     MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2";
2254   }
2255   AnfNodePtr pre_node = return_node->input(1);
2256   MS_EXCEPTION_IF_NULL(pre_node);
2257   auto pre_cnode = pre_node->cast<CNodePtr>();
2258   pre_cnode = HandleDependLoss(pre_cnode, 0);
2259   if (pre_cnode->input(0)->isa<CNode>()) {
2260     auto switch_cnode = pre_cnode->input(0)->cast<CNodePtr>();
2261     if (IsPrimitiveCNode(switch_cnode, prim::kPrimSwitch)) {
2262       MS_EXCEPTION_IF_NULL(switch_cnode);
2263       auto switch_graph = GetValueNode<FuncGraphPtr>(switch_cnode->input(2));
2264       return FindLossCNode(switch_graph, max_depth + 1);
2265     }
2266   }
2267   if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
2268     return loss_node_info;
2269   }
2270   if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
2271     MS_LOG(DEBUG) << "pre_cnode:" << pre_cnode->ToString();
2272     return loss_node_info;
2273   }
2274   auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
2275   // notice: the GetNext op has no input
2276   if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
2277     MS_LOG(INFO) << "The loss is: " << current_prim->name();
2278     loss_node_info.loss_node = pre_cnode;
2279     return loss_node_info;
2280   }
2281 
2282   // size of common cnode is larger than 1
2283   if (pre_cnode->size() < 2) {
2284     MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
2285   }
2286 
2287   // return -> tuple_getitem -> loss
2288   if (current_prim->name() == prim::kTupleGetItem) {
2289     auto tuple_index = GetTupleGetItemIndex(pre_cnode);
2290     AnfNodePtr pre_pre_node = pre_cnode->input(1);
2291     MS_EXCEPTION_IF_NULL(pre_pre_node);
2292 
2293     auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
2294     loss_node_info.has_tuple_getitem = true;
2295     loss_node_info.dout_index = tuple_index;
2296     loss_node_info.loss_node = pre_pre_cnode;
2297     return loss_node_info;
2298   }
2299 
2300   // return -> make_tuple
2301   if (current_prim->name() == MAKE_TUPLE) {
2302     MS_LOG(WARNING) << "The loss has make_tuple, which is not supported";
2303     return loss_node_info;
2304   }
2305 
2306   // return -> loss
2307   loss_node_info.loss_node = pre_cnode;
2308   MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
2309   return loss_node_info;
2310 }
2311 
2312 TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
2313   TensorLayouts ret;
2314   auto loss_cnode = node_info.loss_node;
2315   MS_EXCEPTION_IF_NULL(loss_cnode);
2316 
2317   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
2318   MS_EXCEPTION_IF_NULL(prim_anf_node);
2319   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
2320   MS_EXCEPTION_IF_NULL(prim);
2321   if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
2322     MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
2323     return ret;
2324   }
2325 
2326   OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
2327   MS_EXCEPTION_IF_NULL(operator_info);
2328   TensorInfo loss_grad_tensor_info;
2329   size_t op_output_size = operator_info->outputs_tensor_info().size();
2330   MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has_tuple_getitem flag is "
2331                << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is "
2332                << node_info.dout_index;
2333 
2334   if ((op_output_size == 0) || (op_output_size <= LongToSize(node_info.dout_index))) {
2335     MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size;
2336   }
2337 
2338   if (!node_info.has_tuple_getitem && (op_output_size > 1)) {
2339     MS_LOG(EXCEPTION) << "Currently, a tuple sens is not supported.";
2340   }
2341 
2342   loss_grad_tensor_info = operator_info->outputs_tensor_info()[LongToSize(node_info.dout_index)];
2343   ret.push_back(loss_grad_tensor_info.tensor_layout());
2344   return ret;
2345 }
2346 
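// Splits the sens (gradient seed) input of the grad node according to the loss output layout:
// a scalar/[1] sens only records the layout, a Parameter sens gets a sliced abstract shape and
// layout, a CNode sens is redistributed, and a constant Tensor sens is sliced by _GetTensorSlice.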
2347 void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
2348   MS_EXCEPTION_IF_NULL(grad_sens_node);
2349   if (grad_sens_node->size() <= 1) {
2350     MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
2351   }
2352   AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
2353   MS_EXCEPTION_IF_NULL(sens_tensor_node);
2354   Shapes sens_shapes = GetNodeShape(sens_tensor_node);
2355   if (sens_shapes.size() != 1) {
2356     MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
2357   }
2358   // If the shape of sens tensor is [] or [1], no need to split it.
2359   Shape sens_shape = sens_shapes[0];
2360   if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) {
2361     if (sens_tensor_node->isa<Parameter>()) {
2362       auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
2363       MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
2364       sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
2365     }
2366     MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
2367     return;
2368   }
2369   auto loss_shape = loss_grad_layout.tensor_shape().array();
2370   if (loss_shape != sens_shape) {
2371     MS_LOG(EXCEPTION) << "The shape of sens is not equal to loss output, it is unsupported now. Sens shape is "
2372                       << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape);
2373   }
2374   MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it.";
2375 
2376   if (!IsValueNode<Tensor>(sens_tensor_node)) {
2377     if (sens_tensor_node->isa<Parameter>()) {
2378       MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
2379       AbstractBasePtr abstract = sens_tensor_node->abstract();
2380       MS_EXCEPTION_IF_NULL(abstract);
2381       auto slice_shape = loss_grad_layout.slice_shape().array();
2382       std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
2383       MS_EXCEPTION_IF_NULL(parallel_shape);
2384       auto cloned_abstract = abstract->Clone();
2385       MS_EXCEPTION_IF_NULL(cloned_abstract);
2386       cloned_abstract->set_shape(parallel_shape);
2387       sens_tensor_node->set_abstract(cloned_abstract);
2388       auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
2389       sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
2390       return;
2391     }
2392     if (sens_tensor_node->isa<CNode>()) {
2393       auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
2394       if (op_list_ptr == nullptr) {
2395         return;
2396       }
2397       auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
2398       auto func_graph = grad_sens_node->func_graph();
2399       MS_EXCEPTION_IF_NULL(func_graph);
2400       InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode);
2401       return;
2402     }
2403     MS_LOG(EXCEPTION) << "The sens node is not a Tensor, Parameter or CNode, which is unsupported now.";
2404   }
2405 
2406   // Use _GetTensorSlice operator to split the sens tensor
2407   FuncGraphPtr func_graph = grad_sens_node->func_graph();  // only cnode can get the graph
2408   MS_EXCEPTION_IF_NULL(func_graph);
2409   Operator op = CreateGetTensorSliceOp(loss_grad_layout);
2410   InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
2411 }
2412 
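// Inserts the forward communication operators required by the operator's parallel strategy
// for this cnode; Receive nodes are skipped.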
2413 void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
2414   MS_EXCEPTION_IF_NULL(distribute_operator);
2415   MS_EXCEPTION_IF_NULL(cnode);
2416   if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
2417     return;
2418   }
2419   OperatorVector forward_op = distribute_operator->forward_op();
2420   if (!forward_op.empty()) {
2421     MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
2422     ForwardCommunication(forward_op, cnode);
2423   }
2424 }
2425 
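// Applies the operator's replacement: replace_op rewrites the cnode in place, while
// replace_graph substitutes a whole sub-graph; only one of the two may be set.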
2426 void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
2427   MS_EXCEPTION_IF_NULL(distribute_operator);
2428   MS_EXCEPTION_IF_NULL(cnode);
2429   // StepReplaceOp
2430   OperatorVector replace_op = distribute_operator->replace_op();
2431   if (!replace_op.empty()) {
2432     MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString();
2433     StepReplaceOp(replace_op, cnode);
2434   }
2435 
2436   // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore.
2437   ReplaceGraphPtr replace_graph = distribute_operator->replace_graph(cnode);
2438   if (!replace_op.empty() && replace_graph) {
2439     MS_LOG(EXCEPTION) << "Only one of replace_op or replace_graph can be used";
2440   }
2441   if (replace_graph) {
2442     MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString();
2443     StepReplaceGraph(replace_graph, cnode);
2444   }
2445 }
2446 
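// Collects the forward graphs by scanning the root nodes for J(forward_graph) calls, together
// with every sub-graph used by each forward graph.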
2447 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
2448   // J->CNode->Graph
2449   std::set<FuncGraphPtr> graph_set;
2450   for (auto &node : root_all_nodes) {
2451     MS_EXCEPTION_IF_NULL(node);
2452     if (!node->isa<CNode>()) {
2453       continue;
2454     }
2455 
2456     auto cnode = node->cast<CNodePtr>();
2457     if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
2458       continue;
2459     }
2460     auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2461     if (expect_j_prim->name() != J) {
2462       continue;
2463     }
2464     if (IsValueNode<FuncGraph>(cnode->input(1))) {
2465       auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
2466       MS_LOG(DEBUG) << "Find the forward graph success";
2467       graph_set.insert(graph);
2468       auto manager = graph->manager();
2469       MS_EXCEPTION_IF_NULL(manager);
2470       auto graph_used = manager->func_graphs_used_total(graph);
2471       for (auto &sub_graph : graph_used) {
2472         graph_set.insert(sub_graph);
2473       }
2474     }
2475   }
2476   return graph_set;
2477 }
2478 
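// Splits the sens node of one (sens, loss) pair according to the layout of the loss output.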
2479 void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
2480   CNodePtr sens_node = sens_loss_pair.first;
2481   auto loss_node = sens_loss_pair.second;
2482   auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
2483   if (!loss_grad_layout.empty()) {
2484     SplitSens(sens_node, loss_grad_layout[0]);
2485   }
2486 }
2487 
2488 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
2489 std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
2490   MS_EXCEPTION_IF_NULL(root);
2491   std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
2492   for (auto &node : root->nodes()) {
2493     if (!node->isa<CNode>()) {
2494       continue;
2495     }
2496 
2497     // cnode(sens)-->cnode(tuple_getitem)
2498     auto sens_cnode = node->cast<CNodePtr>();
2499     AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
2500     MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
2501     if (!expect_tuple_getitem->isa<CNode>()) {
2502       continue;
2503     }
2504 
2505     auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
2506     if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kTupleGetItem)) {
2507       continue;
2508     }
2509 
2510     // cnode(sens)-->cnode(tuple_getitem)-->cnode
2511     AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
2512     MS_EXCEPTION_IF_NULL(expect_anonymous);
2513     if (!expect_anonymous->isa<CNode>()) {
2514       continue;
2515     }
2516 
2517     // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
2518     auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
2519     AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
2520     MS_EXCEPTION_IF_NULL(expect_j);
2521     if (!expect_j->isa<CNode>()) {
2522       continue;
2523     }
2524     auto expect_j_cnode = expect_j->cast<CNodePtr>();
2525     if (!IsSomePrimitive(expect_j_cnode, J)) {
2526       continue;
2527     }
2528 
2529     if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
2530       MS_LOG(EXCEPTION) << "Can not find the forward graph corresponding to the sens node.";
2531     }
2532     auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
2533     auto loss_node_info = FindLossCNode(func_graph, 0);
2534     if (loss_node_info.loss_node == nullptr) {
2535       MS_LOG(WARNING) << "Can not find the loss cnode";
2536       continue;
2537     }
2538     std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
2539     sens_loss_pairs.push_back(sens_loss_pair);
2540   }
2541   return sens_loss_pairs;
2542 }
2543 
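// The main insertion pass: splits sens on the last stage, then for every parallel-care cnode
// inserts forward communication, redistribution and backward communication operators, splits
// constant tensors, and finally applies each operator's replace_op/replace_graph.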
2544 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
2545                            const FuncGraphManagerPtr &manager) {
2546   MS_EXCEPTION_IF_NULL(root);
2547   MS_EXCEPTION_IF_NULL(manager);
2548   TensorRedistribution tensor_redistribution;
2549 
2550   std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
2551   bool has_backward = !sens_loss_pairs.empty();
2552   // splitting sens must be done before inserting the operators.
2553   for (auto &pair : sens_loss_pairs) {
2554     // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
2555     // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
2556     if (IsLastStage()) {
2557       StepSplitSens(pair);
2558     }
2559   }
2560 
2561   for (auto &node : all_nodes) {
2562     MS_EXCEPTION_IF_NULL(node);
2563     if (node->isa<CNode>()) {
2564       auto cnode = node->cast<CNodePtr>();
2565       // the make_tuple is a parallel care node, but it may not have operator info
2566       if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
2567         continue;
2568       }
2569 
2570       OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2571       MS_EXCEPTION_IF_NULL(distribute_operator);
2572 
2573       // skip Send Receive
2574       if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
2575         // insert forward ops
2576         InsertForwardOps(distribute_operator, cnode);
2577 
2578         // insert redistribution ops
2579         StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
2580       }
2581       // insert backward ops
2582       if (has_backward) {
2583         BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
2584       }
2585 
2586       distribute_operator->ReplaceNodeInputOrAttrs();
2587     } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
2588       StepSplitTensor(node, manager);
2589     }
2590   }
2591 
2592   for (auto &node : all_nodes) {
2593     MS_EXCEPTION_IF_NULL(node);
2594     if (node->isa<CNode>()) {
2595       auto cnode = node->cast<CNodePtr>();
2596       if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) ||
2597           IsSomePrimitive(cnode, SEND)) {
2598         continue;
2599       }
2600 
2601       OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2602       MS_EXCEPTION_IF_NULL(distribute_operator);
2603       // StepReplace
2604       StepReplace(distribute_operator, cnode);
2605     }
2606   }
2607 }
2608 
2609 bool IsCohesiveNode(const CNodePtr &cnode) {
2610   return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
2611          IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) ||
2612          IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather);
2613 }
2614 
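// Collects the (name, parameter) pairs of the trainable parameters feeding this cnode, looking
// through cohesive nodes (Cast/Load/AllGather variants), up to MAX_RECURSIVE_DEPTH.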
2615 ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
2616   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2617     MS_LOG(WARNING) << "When finding the parameter names of an operator, exceeded the maximum recursion depth: "
2618                     << MAX_RECURSIVE_DEPTH;
2619     return {};
2620   }
2621   std::vector<AnfNodePtr> node_inputs{node->inputs()};
2622   ParameterMap param_names;
2623   for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
2624     int64_t idx = index > i ? index : i;
2625     auto input = node_inputs[LongToSize(i)];
2626     if (input->isa<Parameter>()) {
2627       auto input_parameter = input->cast<ParameterPtr>();
2628       if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
2629         (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
2630       }
2631     } else if (input->isa<CNode>()) {
2632       CNodePtr cnode = input->cast<CNodePtr>();
2633       if (!IsValueNode<Primitive>(cnode->input(0))) {
2634         continue;
2635       }
2636       if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) {
2637         auto input_param_names = NodeParameterName(cnode, idx, 0);
2638         param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end());
2639       }
2640     }
2641   }
2642   return param_names;
2643 }
2644 
2645 bool IsGatherPInfo(const std::string &name) {
2646   std::vector<std::string> gather_p_info_names = {"GatherPInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
2647   for (std::string info_name : gather_p_info_names) {
2648     if (name.find(info_name) != std::string::npos) {
2649       return true;
2650     }
2651   }
2652   return false;
2653 }
2654 
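// Saves the operator strategies, the parameter tensor layouts and the manual-split shapes of
// GatherP-like operators into the strategy checkpoint.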
2655 void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
2656   StrategyMap stra_map;
2657   TensorInfoMap tensor_info_map;
2658   ManualShapeMap manual_shape_map;
2659   for (auto &node : all_nodes) {
2660     MS_EXCEPTION_IF_NULL(node);
2661     auto cnode = node->cast<CNodePtr>();
2662     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2663       continue;
2664     }
2665     auto param_names = NodeParameterName(cnode, -1, 0);
2666     if (param_names.empty()) {
2667       continue;
2668     }
2669     string param_name = param_names[0].first;
2670     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2671     MS_EXCEPTION_IF_NULL(prim);
2672     OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
2673     if (operator_info) {
2674       if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
2675         continue;
2676       }
2677       std::string strategy_key_name = prim->name() + "_" + param_name;
2678       stra_map[strategy_key_name] = operator_info->strategy();
2679       for (auto param_name_pair : param_names) {
2680         tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
2681       }
2682       if (IsGatherPInfo(operator_info->name())) {
2683         auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
2684         auto param_split_shapes = gatherv2_info->param_split_shapes();
2685         auto index_offsets = gatherv2_info->index_offsets();
2686         if (param_split_shapes.size() != index_offsets.size()) {
2687           MS_LOG(EXCEPTION) << "In manual split, the lengths of param_split_shapes and index_offsets should be the same.";
2688         }
2689         std::vector<std::pair<int64_t, int64_t>> manual_shape;
2690         for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
2691           manual_shape.emplace_back(std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
2692         }
2693         manual_shape_map[param_name] = manual_shape;
2694       }
2695     }
2696   }
2697   for (auto &cloned_parameter_node : root->parameters()) {
2698     MS_EXCEPTION_IF_NULL(cloned_parameter_node);
2699     auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
2700     MS_EXCEPTION_IF_NULL(cloned_parameter);
2701 
2702     if (!ParameterIsCloned(cloned_parameter_node)) {
2703       continue;
2704     }
2705     std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
2706     auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
2707     if (cloned_param_layout == nullptr) {
2708       continue;
2709     }
2710     tensor_info_map[cloned_param_name] = cloned_param_layout;
2711   }
2712   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
2713     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
2714   }
2715 }
2716 
2717 void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
2718   for (auto &node : all_nodes) {
2719     MS_EXCEPTION_IF_NULL(node);
2720     if (!node->isa<CNode>()) {
2721       continue;
2722     }
2723     auto cnode = node->cast<CNodePtr>();
2724     if (!IsValueNode<Primitive>(cnode->input(0))) {
2725       continue;
2726     }
2727 
2728     // CNode is globally unique.
2729     MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << ".";
2730     cnode->set_in_forward_flag(true);
2731   }
2732 }
2733 
2734 void SetForwardFlag(const AnfNodeSet &all_nodes) {
2735   for (auto &node : all_nodes) {
2736     MS_EXCEPTION_IF_NULL(node);
2737     if (!node->isa<CNode>()) {
2738       continue;
2739     }
2740     auto cnode = node->cast<CNodePtr>();
2741     if (!IsValueNode<Primitive>(cnode->input(0))) {
2742       continue;
2743     }
2744 
2745     // CNode is globally unique.
2746     cnode->set_in_forward_flag(true);
2747   }
2748 }
2749 
2750 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
2751   MS_EXCEPTION_IF_NULL(root);
2752   const auto &all_nodes = root->nodes();
2753   std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
2754   return graph_set;
2755 }
2756 
2757 std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
2758   MS_EXCEPTION_IF_NULL(graph);
2759   std::vector<AnfNodePtr> root_forward_nodes;
2760   auto loss_cnode = FindLossCNode(graph, 0).loss_node;
2761   if (loss_cnode == nullptr) {
2762     MS_LOG(WARNING) << "Can not find the loss cnode";
2763     return root_forward_nodes;
2764   }
2765 
2766   auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
2767   for (auto &node : all_nodes) {
2768     MS_EXCEPTION_IF_NULL(node);
2769     if (!node->isa<CNode>()) {
2770       continue;
2771     }
2772     auto cnode = node->cast<CNodePtr>();
2773     auto root_node_id = node->UniqueIdThroughCopy();
2774     if (loss_cnode_id == root_node_id) {
2775       root_forward_nodes = DeepLinkedGraphSearch(cnode);
2776       break;
2777     }
2778   }
2779   return root_forward_nodes;
2780 }
2781 
2782 void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
2783   // the shape op doesn't have params or attrs.
2784   OperatorParams params;
2785   OperatorAttrs attrs;
2786   auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>();
2787   MS_EXCEPTION_IF_NULL(shape_value);
2788   auto shape = shape_value->value();
2789   if (shape.empty()) {
2790     return;
2791   }
2792   OperatorArgs args = std::make_pair(attrs, params);
2793   Operator op = std::make_pair(SHAPE_OP, args);
2794   InsertNode(op, node, 2, pre_node, root, "shape");
2795 }
2796 
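// Searches the cnode's inputs for the EnvGetItem node carrying the gradient, following the
// first CNode input at each level, up to MAX_RECURSIVE_DEPTH.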
2797 static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
2798   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2799     MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
2800     return nullptr;
2801   }
2802   for (auto &node : cnode->inputs()) {
2803     if (!node->isa<CNode>()) {
2804       continue;
2805     }
2806     if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
2807       return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
2808     } else {
2809       return node;
2810     }
2811   }
2812   return nullptr;
2813 }
2814 
2815 void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
2816   // If the root graph has a reshape op, find the corresponding parameter.
2817   // Reshape's shape is the shape of the parameter.
2818   auto executor = pipeline::GraphExecutorPy::GetInstance();
2819   for (auto &node : all_nodes) {
2820     if (!node->isa<CNode>()) {
2821       continue;
2822     }
2823     auto cnode = node->cast<CNodePtr>();
2824     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2825       continue;
2826     }
2827     if (cnode->in_forward_flag()) {
2828       // Save strategy in executor
2829       OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>();
2830       if (op_info) {
2831         auto stra_ptr = op_info->strategy();
2832         if (stra_ptr) {
2833           auto strategy = stra_ptr->GetInputDim();
2834           // fullname with scope should be found in step parallel end ir
2835           executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy);
2836         }
2837       }
2838       continue;
2839     }
2840 
2841     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2842     if (prim->name() != RESHAPE) {
2843       continue;
2844     }
2845     auto root = node->func_graph();
2846     auto grad_node = FindGrad(cnode, 0);
2847     if (grad_node) {
2848       InsertShapeOp(cnode, grad_node, root);
2849     }
2850   }
2851 }
2852 
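// Marks the forward cnodes: if a forward graph (under J) is found, its nodes and the matching
// nodes in the root graph are flagged; otherwise all cnodes of the root graph are flagged.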
2853 void MarkForwardCNode(const FuncGraphPtr &root) {
2854   MS_EXCEPTION_IF_NULL(root);
2855   auto all_nodes = root->nodes();
2856   auto graph_set = FindForwardGraphByRootNodes(all_nodes);
2857 
2858   if (graph_set.empty()) {
2859     MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
2860     SetForwardFlag(all_nodes);
2861   } else {
2862     for (auto &func_graph : graph_set) {
2863       MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
2864       auto return_node = func_graph->get_return();
2865       MS_EXCEPTION_IF_NULL(return_node);
2866       auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
2867       SetForwardFlag(all_dfs_nodes);
2868       auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes);
2869       if (root_forward_nodes.empty()) {
2870         continue;
2871       }
2872       // Mark forward flag for the nodes in root graph.
2873       SetForwardFlag(root_forward_nodes);
2874     }
2875   }
2876 }
2877 
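// Collects the device num, global rank, world group and communication backend (HCCL on Ascend,
// NCCL on GPU), falling back to the communication interface when they are not set in the context.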
2878 CommInfo GetCommInfo() {
2879   int64_t device_num = ParallelContext::GetInstance()->device_num();
2880   int64_t global_rank = ParallelContext::GetInstance()->global_rank();
2881   auto ms_context = MsContext::GetInstance();
2882   MS_EXCEPTION_IF_NULL(ms_context);
2883   std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
2884   std::string world_group;
2885   std::string communication_backend;
2886   if (backend == kAscendDevice || backend == kDavinciDevice) {
2887     world_group = HCCL_WORLD_GROUP;
2888     communication_backend = HCCL_BACKEND;
2889   } else if (backend == kGPUDevice) {
2890     world_group = NCCL_WORLD_GROUP;
2891     communication_backend = NCCL_BACKEND;
2892   } else {
2893     MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
2894   }
2895   uint32_t world_rank_size = 0;
2896   if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
2897     MS_LOG(EXCEPTION) << "Get rank size failed";
2898   }
2899 
2900   if (!ParallelContext::GetInstance()->device_num_is_set()) {
2901     device_num = UintToInt(world_rank_size);
2902     MS_LOG(INFO) << "Get device num from the communication interface, the device num is " << device_num;
2903   }
2904 #if defined(ENABLE_GPU)
2905   if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
2906     if (world_rank_size != device_num) {
2907       MS_LOG(EXCEPTION) << "The device_num " << device_num
2908                         << " set in the context is not consistent with the world group size " << world_rank_size;
2909     }
2910   }
2911 #endif
2912 
2913   uint32_t rank_id = 0;
2914   if (!ParallelContext::GetInstance()->global_rank_is_set()) {
2915     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
2916       MS_LOG(EXCEPTION) << "Get rank id failed";
2917     }
2918     global_rank = UintToInt(rank_id);
2919     MS_LOG(INFO) << "Get global rank from the communication interface, the global rank is " << global_rank;
2920   }
2921   CommInfo comm_info{device_num, global_rank, world_group, communication_backend};
2922   return comm_info;
2923 }
2924 
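// Checks the device num, global rank and pipeline stage configuration, then initializes the
// device manager with an even number of devices per stage.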
2925 Status ParallelInit() {
2926   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2927   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2928   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2929   if (split_stage_num <= 0) {
2930     MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
2931     return FAILED;
2932   }
2933   auto comm_info = GetCommInfo();
2934   int64_t device_num = comm_info.device_num;
2935   int64_t global_rank = comm_info.global_rank;
2936   if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
2937     MS_LOG(ERROR) << "Invalid device num " << device_num;
2938     return FAILED;
2939   }
2940 
2941   // the device_num may be obtained from the communication interface
2942   if (device_num % split_stage_num != 0) {
2943     MS_LOG(ERROR) << "Device num " << device_num << " can't be evenly divided by stage num " << split_stage_num;
2944     return FAILED;
2945   }
2946 
2947   if ((global_rank < 0) || (global_rank >= device_num)) {
2948     MS_LOG(ERROR) << "Global rank " << global_rank << " is out of range, the device num is " << device_num;
2949     return FAILED;
2950   }
2951 
2952   std::vector<int64_t> stages;
2953   for (int i = 0; i < split_stage_num; i++) {
2954     stages.push_back(device_num / split_stage_num);
2955   }
2956 
2957   if ((split_stage_num > 1) && (parallel_mode != SEMI_AUTO_PARALLEL)) {
2958     MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
2959     return FAILED;
2960   }
2961 
2962   if (!InitDevice(device_num, global_rank, comm_info.communication_backend, stages)) {
2963     MS_LOG(ERROR) << "Init device failed";
2964     return FAILED;
2965   }
2966 
2967   MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
2968                << ", communication_backend: " << comm_info.communication_backend
2969                << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
2970                << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
2971 
2972   return SUCCESS;
2973 }
2974 
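// Propagates the operator info of a forward MakeTuple/MakeList's single parallel-care user
// back to the MakeTuple/MakeList node itself.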
2975 void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes) {
2976   for (auto &node : all_nodes) {
2977     if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) {
2978       continue;
2979     }
2980 
2981     auto cnode = node->cast<CNodePtr>();
2982     MS_EXCEPTION_IF_NULL(cnode);
2983     if (!cnode->in_forward_flag()) {
2984       continue;
2985     }
2986 
2987     FuncGraphManagerPtr manager = cnode->func_graph()->manager();
2988     MS_EXCEPTION_IF_NULL(manager);
2989     std::string op_type = AnfNodeIsPrimitive(node, MAKE_TUPLE) ? MAKE_TUPLE : MAKE_LIST;
2990 
2991     auto &make_tuple_list_user = manager->node_users()[cnode];
2992     if (make_tuple_list_user.size() != 1) {
2993       MS_LOG(EXCEPTION) << "Now the " << op_type << " must have exactly 1 user, but got " << make_tuple_list_user.size();
2994     }
2995     CNodePtr make_tuple_list_next_cnode = make_tuple_list_user.front().first->cast<CNodePtr>();
2996     MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode);
2997 
2998     std::string make_tuple_list_user_prim_name = GetPrimName(make_tuple_list_next_cnode);
2999     if (!IsParallelCareNode(make_tuple_list_next_cnode)) {
3000       MS_LOG(INFO) << "The " << op_type << "'s user is " << make_tuple_list_user_prim_name
3001                    << ", no need to set operator info";
3002       continue;
3003     }
3004     if (make_tuple_list_next_cnode->inputs().size() != 2) {
3005       MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user only supports 1 input, but got "
3006                         << (make_tuple_list_next_cnode->inputs().size() - 1);
3007     }
3008 
3009     MS_LOG(INFO) << "Set the " << op_type << "'s operator info, and the op name is " << make_tuple_list_user_prim_name;
3010     OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_list_next_cnode);
3011     MS_EXCEPTION_IF_NULL(op_info);
3012     cnode->set_user_data<OperatorInfo>(op_info);
3013   }
3014 }
3015 
3016 bool CreateGroupsByCkptFile(const std::string &file) {
3017   GroupInfoMap group_info_map;
3018   if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
3019     return false;
3020   }
3021 
3022   if (CreateGroups(group_info_map) != SUCCESS) {
3023     return false;
3024   }
3025   MS_LOG(INFO) << "Create groups by checkpoint file success";
3026   return true;
3027 }
3028 
3029 void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages) {
3030   if (!root->has_flag(BACKWARD) && pipeline_stages > 1) {
3031     root->set_flag(BACKWARD, true);
3032     if (root->has_flag(TRAINING)) {
3033       Reorder(root);
3034     } else {
3035       ReorderForPredict(root, manager);
3036     }
3037   }
3038 }
3039 
3040 bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
3041   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
3042   auto comm_info = GetCommInfo();
3043   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
3044   int64_t per_stage_device_num = comm_info.device_num / split_stage_num;
3045   int64_t current_stage = comm_info.global_rank / per_stage_device_num;
3046   MS_LOG(INFO) << "The current stage is: " << current_stage;
3047   if (!root->has_flag(TRAINING) && !ParallelContext::GetInstance()->dataset_strategy().empty()) {
3048     MS_LOG(WARNING) << "In eval/predict net, the output parallel strategy would not follow "
3049                        "the input parallel strategy when using context.set_auto_parallel_context(dataset_strategy)"
3050                        " to configure the input strategy.";
3051   }
3052   return (!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
3053           current_stage == split_stage_num - 1);
3054 }
3055 
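// Entry of the step_parallel pass: initializes the parallel context, extracts shapes and
// strategies, inserts forward/backward communication and redistribution operators, handles
// pipeline stages and saves strategy/group checkpoints; it runs only once per graph.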
3056 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
3057 #if ((defined ENABLE_CPU) && (!defined _WIN32))
3058   if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
3059     return false;
3060   }
3061 #endif
3062   MS_EXCEPTION_IF_NULL(root);
3063   MS_EXCEPTION_IF_NULL(optimizer);
3064   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
3065   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
3066   pipeline::ResourceBasePtr res = optimizer->resource();
3067   MS_EXCEPTION_IF_NULL(res);
3068   FuncGraphManagerPtr manager = res->manager();
3069   MS_EXCEPTION_IF_NULL(manager);
3070   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3071   // assume no change to graph
3072   bool changes = false;
3073   // controls whether the model parallel mode is used
3074   if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
3075       (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
3076     if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
3077       if (HasStrategy(root)) {
3078         MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
3079                      << ", set_strategy() only valid in [semi_]auto_parallel.";
3080       }
3081       root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
3082     }
3083     ReorderForPipelineSplit(root, manager, pipeline_stages);
3084 
3085     return changes;
3086   }
3087 
3088   struct timeval start_time, end_time;
3089   (void)gettimeofday(&start_time, nullptr);
3090 
3091   MS_LOG(INFO) << "Now entering step parallel";
3092   DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));
3093   AnfNodePtr ret = root->get_return();
3094   MS_EXCEPTION_IF_NULL(ret);
3095   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
3096   std::reverse(all_nodes.begin(), all_nodes.end());
3097   if (parallel_mode != AUTO_PARALLEL) {
3098     TOTAL_OPS = 0;
3099     if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
3100       MS_LOG(EXCEPTION) << "Parallel init failed";
3101     }
3102 
3103     if (pipeline_stages > 1) {
3104       HandleMicroBatch(all_nodes, manager);
3105       ParameterStartNode(all_nodes, manager);
3106       LastStageEndNode(all_nodes, manager, root);
3107     }
3108 
3109     // mark the forward cnodes, parallel only cares about these nodes
3110     MarkForwardCNode(root);
3111 
3112     if (FindCommunicationOp(all_nodes)) {
3113       MS_LOG(EXCEPTION) << "The graph contains a communication op";
3114     }
3115 
3116     if (IsInsertVirtualOutput(root)) {
3117       InsertVirtualOutput(root, all_nodes);
3118       AnfNodePtr ret_after = root->get_return();
3119       MS_EXCEPTION_IF_NULL(ret_after);
3120       all_nodes = DeepScopedGraphSearch(ret_after);
3121       std::reverse(all_nodes.begin(), all_nodes.end());
3122     }
3123 
3124     // extract shape and strategy, set operator_info
3125     ExtractInformation(all_nodes);
3126     ReshapeInit(all_nodes);
3127   }
3128 
3129   HandleRootReshapeAndSaveStrategy(all_nodes);
3130 
3131   HandleForwardMakeTupleAndMakeList(all_nodes);
3132 
3133   // if the input or parameter has multiple users, check whether its split strategies are consistent.
3134   CheckParameterSplit(all_nodes);
3135 
3136   HandleSymbolicKeyInstance(root, all_nodes);
3137 
3138   // cover Parallel shape
3139   CoverSliceShape(root);
3140 
3141   // handle inputs that are not used
3142   HandleNoUsedParameter(root);
3143 
3144   // set the shape for optimizer's clone tensor
3145   SetClonedTensorShapeForOptimizer(root);
3146 
3147   HandleAdaFactorOpt(root);
3148 
3149   // save strategy as checkpoint for multi-train
3150   if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
3151     CheckpointStrategy(all_nodes, root);
3152   }
3153   // ForwardCommunication BackwardCommunication TensorRedistribution
3154   ParallelCommunication(root, all_nodes, manager);
3155 
3156   if (pipeline_stages > 1) {
3157     AddVirtualAssignAdd(root);
3158     HandleReceiveParam(root, all_nodes);
3159   }
3160 
3161   auto group_info = g_device_manager->group_info();
3162   if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
3163       StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
3164     MS_LOG(EXCEPTION) << "Save group info failed";
3165   }
3166 
3167   // handle fully split parameters in grad accumulation, not including optimizer-sharding's parameters
3168   HandleFullySplitParameters(root);
3169 
3170   DumpGraph(root, std::string(STEP_PARALLEL_END));
3171 
3172   // step parallel only run once
3173   root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
3174   res->results()[pipeline::kStepParallelGraph] = root;
3175 
3176   // in auto parallel mode, no need to check if strategies are set
3177   root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
3178 
3179   (void)gettimeofday(&end_time, nullptr);
3180   uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
3181   time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
3182 
3183   MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us";
3184   return changes;
3185 }
3186 
3187 // Needed by rec_parser
3188 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node) {
3189   std::vector<std::string> name_inputs;
3190   std::vector<AnfNodePtr> all_inputs = node->inputs();
3191   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
3192 
3193   std::string node_id = node->UniqueId();
3194   name_inputs.push_back(node_id);
3195   for (auto &input : node_inputs) {
3196     std::string name = input->UniqueId();
3197     name_inputs.push_back(name);
3198   }
3199 
3200   return name_inputs;
3201 }
3202 }  // namespace parallel
3203 }  // namespace mindspore
3204