/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel.h"

#include <cinttypes>
#include <algorithm>
#include <chrono>
#include <map>
#include <unordered_map>
#include <memory>
#include <set>
#include <string>
#include <queue>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/hash_map.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/ops_info/gather_info.h"
#include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/tensor_layout/prime_generator.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/fold_pipeline_split_utils.h"
#include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
#include "frontend/parallel/graph_util/grad_accumulation_utils.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/silent_check/silent_check.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/ops_info/matmul_info.h"
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include "frontend/parallel/tensor_layout/tensor_transform.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/parallel_optimizer/opt_param_mgr.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"

#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif

using mindspore::tensor::Tensor;

namespace mindspore {
namespace parallel {
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL, STANDARD_NORMAL};
const uint32_t MAX_BFS_DEPTH = 7;
const char kSilentCheckEnvEnable[] = "1";

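// If the user set recompute_comm_op=false on the compute primitive of `node`, mark the newly created
// communication primitive (new_node_input[0]) with recompute=false, so the inserted forward
// communication op is excluded from recomputation.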
static void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
  if (new_node_input.empty()) {
    return;
  }

  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto attrs = prim->attrs();

  auto anf_node = node->input(0)->cast<ValueNodePtr>();
  auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
  MS_EXCEPTION_IF_NULL(prim_node);
  auto node_attrs = prim_node->attrs();
  if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
    attrs[RECOMPUTE] = MakeValue<bool>(false);
    (void)prim->SetAttrs(attrs);
    MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
  }
}

// Replace pre_node with op(pre_node): insert the new op before the node and redirect pre_node's users.
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
                            const std::string &instance_name, const std::string &param_name = "",
                            const FuncGraphPtr &root = nullptr) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  ScopePtr scope = pre_node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input;
  if (root && !param_name.empty()) {
    node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
  } else {
    node_input = CreateInput(op, pre_node, instance_name);
  }
  CNodePtr new_node = func_graph->NewCNode(node_input);
  MS_EXCEPTION_IF_NULL(new_node);
  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
    new_node->set_in_forward_flag(true);  // mark forward flag
  }
  auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
  MS_EXCEPTION_IF_NULL(new_node_prim);
  new_node_prim->set_instance_name(instance_name);
  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
    new_node_prim->set_attr("recompute", MakeValue(false));
  } else if (instance_name.find(RECOMPUTE) != std::string::npos) {
    new_node_prim->set_attr("recompute", MakeValue(true));
  }
  new_node->set_scope(scope);
  node_input[0]->set_scope(scope);
  (void)manager->Replace(pre_node, new_node);
  MS_LOG(INFO) << "Insert " << instance_name << " success";
  return new_node;
}

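// Insert forward communication ops for an operator with multiple outputs (e.g. GMM). Each output is
// consumed through a TupleGetItem, so forward_op[i] is inserted after the i-th collected TupleGetItem
// user instead of after the node itself.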
void ForwardCommunicationForMultiOut(OperatorVector forward_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // step1: get graph manager and distribute_operator
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto uses_set = manager->node_users()[node];
  // For GMM, the outputs are always consumed through TupleGetItem, so we need to find the real users of GMM
  std::vector<CNodePtr> node_to_insert = {};
  for (auto &uses_pair : uses_set) {
    auto uses_cnode = uses_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(uses_cnode);
    if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
      break;
    }
    PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
    MS_EXCEPTION_IF_NULL(value_node_prim);
    if (value_node_prim->name() == prim::kPrimTupleGetItem->name()) {
      node_to_insert.push_back(uses_cnode);
    }
  }
  if (node_to_insert.empty()) {
    MS_LOG(ERROR) << "The output of " << node->DebugString()
                  << " does not have a TupleGetItem node. Forward communication cannot be inserted, so the "
                     "correctness of the current op cannot be ensured.";
    return;
  }
  std::reverse(forward_op.begin(), forward_op.end());

  // step2: traverse op_list and insert nodes
  for (size_t index = 0; index < forward_op.size(); ++index) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert[index], instance_name);
    SetAllReduceRecomputeFlag(forward_input, node_to_insert[index]);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    forward_node->set_scope(scope);
    forward_node->set_in_forward_flag(true);
    forward_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(forward_node->UniqueId()));
    if (node_to_insert[index]->HasPrimalAttr(MICRO)) {
      forward_node->AddPrimalAttr(MICRO, node_to_insert[index]->GetPrimalAttr(MICRO));
    }
    forward_input[0]->set_scope(scope);
    (void)manager->Replace(node_to_insert[index], forward_node);  // using Replace function to insert node
  }
}

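// Insert the forward communication ops (e.g. AllReduce) required by a parallelized operator on its
// output; all users of `node` are redirected to the inserted ops. If the output is consumed through a
// single TupleGetItem, the insertion point moves to that TupleGetItem.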
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (dyn_cast<abstract::SequenceShape>(node->Shape()) != nullptr) {
    // For ops like GMM that have multiple outputs
    MS_LOG(INFO) << "The input node " << node->DebugString()
                 << " has multiple outputs, enter ForwardCommunicationForMultiOut";
    ForwardCommunicationForMultiOut(forward_op, node);
    return;
  }
  // step1: get graph manager and distribute_operator
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto uses_set = manager->node_users()[node];
  CNodePtr node_to_insert = node;
  for (auto &uses_pair : uses_set) {
    auto uses_cnode = uses_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(uses_cnode);
    if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
      break;
    }
    PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
    MS_EXCEPTION_IF_NULL(value_node_prim);
    if (value_node_prim->name() == prim::kPrimTupleGetItem->name()) {
      if (uses_set.size() > 1) {
        MS_LOG(EXCEPTION) << "Currently only one output is supported, but got " << uses_set.size();
      }
      node_to_insert = uses_cnode;
    }
  }
  MS_EXCEPTION_IF_NULL(node_to_insert);
  std::reverse(forward_op.begin(), forward_op.end());

  // step2: traverse op_list and insert nodes
  for (size_t index = 0; index < forward_op.size(); ++index) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
    SetAllReduceRecomputeFlag(forward_input, node_to_insert);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    forward_node->set_scope(scope);
    forward_node->set_in_forward_flag(true);
    forward_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(forward_node->UniqueId()));
    if (node_to_insert->HasPrimalAttr(MICRO)) {
      forward_node->AddPrimalAttr(MICRO, node_to_insert->GetPrimalAttr(MICRO));
    }
    forward_input[0]->set_scope(scope);
    (void)manager->Replace(node_to_insert, forward_node);  // using Replace function to insert node
  }
}

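// Re-pack the `num` outputs of `prev` into a tuple and redirect prev's users to it. A sketch of the
// transformation (assuming num == 2):
//   prev  ==>  MakeTuple(TupleGetItem(prev, 0), TupleGetItem(prev, 1))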
static CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint64_t num, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(prev);
  MS_EXCEPTION_IF_NULL(func_graph);
  ScopeGuard scope_guard(prev->scope());
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (uint64_t i = 0; i < num; i++) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev,
                                                  CreatInt64Imm(UlongToLong(i))};
    auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs);
    MS_EXCEPTION_IF_NULL(tuple_get_item);
    make_tuple_inputs.push_back(tuple_get_item);
  }
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(prev, make_tuple);
  return make_tuple;
}

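// Insert the inferred redistribution op list before node's input at `pos`: each op in
// redistribution_oplist_ptr->first is chained onto that input, and when the matching entry of
// redistribution_oplist_ptr->second requires it, the op's tuple output is re-packed via InsertMakeTuple.
// The instance name also encodes the recompute_comm_op setting taken from the pre/next primitive.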
static void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
                                 const FuncGraphPtr &func_graph, int64_t pos, const CNodePtr &pre_node,
                                 const TensorRedistributionPtr &tensor_redistribution) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(pre_node);
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) {
    MS_LOG(EXCEPTION) << "The size of OperatorVector and OutPutInfoVector must be the same!";
  }

  for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) {
    if (pos >= SizeToLong(node->size())) {
      MS_LOG(EXCEPTION) << "InsertRedistribution: pos can't be larger than the size of node's inputs";
    }
    // Create new node
    AnfNodePtr target_node = node->input(LongToSize(pos));
    MS_EXCEPTION_IF_NULL(target_node);
    // Create instance_name
    auto op = (redistribution_oplist_ptr->first)[index];
    std::string op_name = (redistribution_oplist_ptr->first)[index].first;
    std::string instance_name_base = REDISTRIBUTION_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
    auto prim_out = GetCNodePrimitive(node);
    auto prim_in = GetCNodePrimitive(pre_node);
    if (prim_out != nullptr && prim_in != nullptr) {
      auto prim_out_attr = prim_out->attrs();
      auto prim_in_attr = prim_in->attrs();
      std::string recompute_str = "";
      if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end()) {
        recompute_str = GetValue<bool>(prim_out_attr[RECOMPUTE_COMM_OP]) ? RECOMPUTE : NOT_RECOMPUTE;
      }
      if (recompute_str.empty() && prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end()) {
        recompute_str = GetValue<bool>(prim_in_attr[RECOMPUTE_COMM_OP]) ? RECOMPUTE : NOT_RECOMPUTE;
      }
      instance_name = instance_name + "_" + recompute_str;
    }
    InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name, "", nullptr, tensor_redistribution);
    if ((redistribution_oplist_ptr->second)[index].first) {
      target_node = node->input(LongToSize(pos));
      MS_EXCEPTION_IF_NULL(target_node);
      (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph);
    }
  }
}

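// Insert a single slicing op (e.g. one created by CreateGetTensorSliceOp) before node's input at `pos`.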
static void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph,
                                   int64_t pos, const std::string &instance_name) {
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name;
  }

  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (pos >= SizeToLong(node->size())) {
    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than the size of node's inputs, "
                         "the instance name is "
                      << instance_name;
  }
  // Create new node
  AnfNodePtr pre_node = node->input(LongToSize(pos));
  MS_EXCEPTION_IF_NULL(pre_node);
  InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
}

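// Get the output tensor layout of pre_node when its operator info provides the new (nested) tensor
// info. get_item_index is a path of TupleGetItem indexes: the first entry selects the output tensor
// info (-1 means output 0), the remaining entries walk into nested tuple elements.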
TensorLayout GetTensorInLayoutForNewShape(const AnfNodePtr &pre_node, std::vector<int> get_item_index) {
  TensorLayout tensorinfo_in_layout;
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto distribute_operator = GetDistributeOperator(pre_cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  TensorInfoBasePtr tensorinfo_in;
  auto tensor_info_pos = get_item_index.front();
  get_item_index.erase(get_item_index.begin());
  if (tensor_info_pos != -1) {
    if (tensor_info_pos >= SizeToInt(distribute_operator->outputs_tensor_info_new().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range. Node: " << pre_node->DebugString()
                        << " index: " << tensor_info_pos << " outputs_tensor_info's size: "
                        << distribute_operator->outputs_tensor_info_new().size();
    }
    tensorinfo_in = distribute_operator->outputs_tensor_info_new()[IntToSize(tensor_info_pos)];
  } else {
    tensorinfo_in = distribute_operator->outputs_tensor_info_new()[0];
  }
  for (const auto &index : get_item_index) {
    tensorinfo_in = tensorinfo_in->GetElement(IntToLong(index));
  }
  tensorinfo_in_layout = tensorinfo_in->GetValue().tensor_layout();
  return tensorinfo_in_layout;
}

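// Get the output tensor layout of pre_node. Falls back to GetTensorInLayoutForNewShape when nested
// tensor info is available; otherwise get_item_index must hold exactly one entry (-1 selects output 0).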
TensorLayout GetTensorInLayout(const AnfNodePtr &pre_node, std::vector<int> get_item_index) {
  TensorLayout tensorinfo_in_layout;
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto distribute_operator = GetDistributeOperator(pre_cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  if (!distribute_operator->outputs_tensor_info_new().empty()) {
    return GetTensorInLayoutForNewShape(pre_node, get_item_index);
  }
  if (get_item_index.size() != 1) {
    // If the operator does not have outputs_tensor_info_new, the outputs have only one tensor info,
    // thus get_item_index must hold exactly one value
    MS_LOG(EXCEPTION) << "The get_item_index size is not 1, the size is " << get_item_index.size();
  }
  if (get_item_index[get_item_index.size() - 1] != -1) {
    if (get_item_index[get_item_index.size() - 1] >= SizeToInt(distribute_operator->outputs_tensor_info().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range. Node: " << pre_node->DebugString()
                        << " index: " << get_item_index
                        << " outputs_tensor_info's size: " << distribute_operator->outputs_tensor_info().size();
    }
    auto tensorinfo_in =
      distribute_operator->outputs_tensor_info()[IntToSize(get_item_index[get_item_index.size() - 1])];
    tensorinfo_in_layout = tensorinfo_in.tensor_layout();
  } else {
    if (distribute_operator->outputs_tensor_info().empty()) {
      MS_LOG(EXCEPTION) << "The outputs tensor info is empty. Node:" << pre_node->DebugString();
    }
    auto tensorinfo_in = distribute_operator->outputs_tensor_info()[0];
    tensorinfo_in_layout = tensorinfo_in.tensor_layout();
  }
  return tensorinfo_in_layout;
}

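// Get the layout that the next node expects for its input (or, when the consumer is a func-graph call
// whose parameter carries operator info, the corresponding output layout). node_pair.second holds the
// 1-based input index path, hence the "- 1". Returns FAILED if an index is out of range.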
Status ObtainOutputTensorLayout(const OperatorInfoPtr &next_distribute_operator,
                                const std::pair<AnfNodePtr, std::vector<int>> &node_pair, const CNodePtr &next_cnode,
                                const bool &using_func_param_op_info, TensorLayout *tensorlayout_out) {
  bool next_dist_op_has_tuple = !next_distribute_operator->inputs_tensor_info_new().empty();
  if (next_dist_op_has_tuple) {
    auto next_inputs_tensor_info = using_func_param_op_info ? next_distribute_operator->outputs_tensor_info_new()
                                                            : next_distribute_operator->inputs_tensor_info_new();
    auto it = std::find_if(node_pair.second.begin(), node_pair.second.end(), [&](const auto &input_idx) {
      return LongToSize(input_idx - 1) >= next_inputs_tensor_info.size();
    });
    if (it != node_pair.second.end()) {
      MS_LOG(INFO) << "The index is out of range, the index is " << (*it - 1) << ", the vector size is "
                   << next_inputs_tensor_info.size() << ", next node is " << next_cnode->DebugString();
      return FAILED;
    }
    auto tensorinfo_out_ptr = next_inputs_tensor_info[LongToSize(node_pair.second[0] - 1)];
    if (tensorinfo_out_ptr->is_list()) {
      for (size_t i = 1; i < node_pair.second.size(); ++i) {
        tensorinfo_out_ptr = tensorinfo_out_ptr->GetElement(LongToSize(node_pair.second[i] - 1));
      }
    }
    TensorInfo tensorinfo_out = tensorinfo_out_ptr->GetValue();
    *tensorlayout_out = tensorinfo_out.tensor_layout();
    return SUCCESS;
  }
  auto next_inputs_tensor_info = using_func_param_op_info ? next_distribute_operator->outputs_tensor_info()
                                                          : next_distribute_operator->inputs_tensor_info();
  size_t out_layout_index = LongToSize(node_pair.second[node_pair.second.size() - 1] - 1);
  if (out_layout_index >= next_inputs_tensor_info.size()) {
    MS_LOG(INFO) << "The index is out of range, the index is " << out_layout_index << ", the vector size is "
                 << next_inputs_tensor_info.size() << ", next node is " << next_cnode->DebugString();
    return FAILED;
  }
  TensorInfo tensorinfo_out = next_inputs_tensor_info[out_layout_index];
  *tensorlayout_out = tensorinfo_out.tensor_layout();
  return SUCCESS;
}

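// Redistribution under micro interleaving: split the producer with VirtualConverterBegin, insert one
// inferred redistribution op list per virtual graph branch, merge the branches with
// VirtualConverterEnd, and finally convert the interleaved AllGathers to Concat using the collected
// allgather group_ranks.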
void InsertRedistributionForMicroInterleaved(const TensorRedistributionPtr &tensor_redistribution,
                                             const std::pair<AnfNodePtr, int64_t> &node_pair,
                                             const FuncGraphPtr &func_graph, const CNodePtr &attr_cnode,
                                             const CNodePtr &real_pre_node) {
  auto redistribution_oplist_ptr_vector = tensor_redistribution->InferTensorRedistributionOperatorVirtualGraphs();
  auto next_cnode = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(next_cnode);
  auto next_cnode_index = node_pair.second;
  // create VirtualConverterBeginNode
  MS_EXCEPTION_IF_NULL(real_pre_node);
  auto virtual_converter_begin =
    CreateVirtualConverterBeginNode(real_pre_node, redistribution_oplist_ptr_vector.size());
  std::vector<CNodePtr> tuple_get_item_vector;
  for (size_t i = 0; i < redistribution_oplist_ptr_vector.size(); ++i) {
    if (redistribution_oplist_ptr_vector[i]->first.empty()) {
      return;
    }
    // create tuple_get_item
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), virtual_converter_begin,
                                                  CreatInt64Imm(UlongToLong(i))};
    auto tuple_get_item_cnode = func_graph->NewCNode(tuple_get_item_inputs);
    tuple_get_item_vector.push_back(tuple_get_item_cnode);
  }
  // create VirtualConverterEndNode
  auto virtual_converter_end = CreateVirtualConverterEndNode(func_graph, tuple_get_item_vector);
  auto manager = func_graph->manager();
  (void)manager->SetEdge(next_cnode, next_cnode_index, virtual_converter_end);
  // add recompute_comm_op attrs
  auto prim_out = GetCNodePrimitive(next_cnode);
  if (prim_out != nullptr && prim_out->HasAttr(RECOMPUTE_COMM_OP)) {
    auto out_recompute_comm_op_attr = prim_out->GetAttr(RECOMPUTE_COMM_OP);
    auto virtual_converter_end_prim = GetCNodePrimitive(virtual_converter_end);
    virtual_converter_end_prim->AddAttr(RECOMPUTE_COMM_OP, out_recompute_comm_op_attr);
  }
  std::vector<std::vector<std::vector<int64_t>>> ag_group_ranks_vectors;

  for (size_t i = 0; i < redistribution_oplist_ptr_vector.size(); ++i) {
    auto redistribution_oplist_ptr = redistribution_oplist_ptr_vector[i];
    if (!tensor_redistribution->IsAssembledStaticShape()) {
      redistribution_oplist_ptr = TensorTransform::GetInstance()->OptimizeTensorRedistributionOperatorList(
        redistribution_oplist_ptr, tensor_redistribution->input_shape());
    }
    // Get allgather group_ranks attr in redistribution_oplist_ptr
    std::vector<std::vector<int64_t>> ag_group_ranks_vector;
    for (size_t findex = 0; findex < (redistribution_oplist_ptr->first).size(); ++findex) {
      // Create instance_name
      auto index = (redistribution_oplist_ptr->first).size() - 1 - findex;
      auto op = (redistribution_oplist_ptr->first)[index];
      std::string op_name = (redistribution_oplist_ptr->first)[index].first;
      if (op_name == ALL_GATHER) {
        auto group_ranks_attr = (redistribution_oplist_ptr->first)[index].second.first[1].second;
        auto group_ranks = GetValue<std::vector<int64_t>>(group_ranks_attr);
        ag_group_ranks_vector.push_back(group_ranks);
      }
    }
    ag_group_ranks_vectors.push_back(ag_group_ranks_vector);
    InsertRedistribution(redistribution_oplist_ptr, virtual_converter_end, func_graph, i + 1, attr_cnode,
                         tensor_redistribution);
  }
  ConvertInterleaveAllGatherToConcat(func_graph, virtual_converter_end, ag_group_ranks_vectors);
}

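// Insert tensor redistribution between pre_node and the consumer described by node_pair: derive the
// producer's output layout and the consumer's expected input layout, init a TensorRedistribution with
// them, then insert the inferred op list (micro-interleaved layouts are dispatched to
// InsertRedistributionForMicroInterleaved).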
static void Redistribution(const std::pair<AnfNodePtr, std::vector<int>> &node_pair, const AnfNodePtr &pre_node,
                           const std::vector<int> &get_item_index) {
  MS_LOG(DEBUG) << "Do Redistribution for " << node_pair.first->fullname_with_scope();
  auto next_cnode = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(next_cnode);
  auto func_graph = next_cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto distribute_operator = GetDistributeOperator(pre_cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  auto dev_list = distribute_operator->stage_device_list();
  OperatorInfoPtr next_distribute_operator;
  bool using_func_param_op_info = false;
  if (IsValueNode<FuncGraph>(next_cnode->input(0))) {
    auto fg = GetValueNode<FuncGraphPtr>(next_cnode->input(0));
    auto fg_parameters = fg->parameters();
    auto param = fg_parameters[IntToSize(node_pair.second[node_pair.second.size() - 1] - 1)];
    if (param->has_user_data<OperatorInfo>()) {
      MS_LOG(INFO) << "Func call node:" << next_cnode->DebugString() << " has operator info.";
      next_distribute_operator = param->user_data<OperatorInfo>();
      using_func_param_op_info = true;
    } else {
      next_distribute_operator = GetDistributeOperator(next_cnode);
    }
  } else {
    next_distribute_operator = GetDistributeOperator(next_cnode);
  }
  MS_EXCEPTION_IF_NULL(next_distribute_operator);
  MS_LOG(DEBUG) << "Redistribution for pre_node: " << pre_cnode->DebugString()
                << " next_node: " << next_cnode->DebugString();

  auto tensor_redistribution = next_distribute_operator->CreateTensorRedistribution();
  tensor_redistribution->SetPreAndNextCNode(pre_cnode, next_cnode);

  // extract tensor layout in and out
  if (distribute_operator->outputs_tensor_info().empty() && distribute_operator->outputs_tensor_info_new().empty()) {
    MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
    return;
  }
  TensorLayout tensorlayout_out;
  auto status = ObtainOutputTensorLayout(next_distribute_operator, node_pair, next_cnode, using_func_param_op_info,
                                         &tensorlayout_out);
  if (status != SUCCESS) {
    return;
  }
  TensorLayout tensorlayout_in = GetTensorInLayout(pre_node, get_item_index);
  if (IsPrimitiveCNode(pre_node, prim::kPrimReceive)) {
    tensorlayout_in = *(pre_node->user_data<TensorLayout>());
  }

  if (tensor_redistribution->Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
    MS_LOG(ERROR) << "Redistribution: pre_node " << pre_cnode->DebugString() << " next_node "
                  << next_cnode->DebugString();
    DumpGraph(func_graph, "redistribution_error");
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed";
  }
  if (tensorlayout_in.GetVirtualRank().size() > 1 || tensorlayout_out.GetVirtualRank().size() > 1) {
    auto real_pre_node = next_cnode->input(node_pair.second[node_pair.second.size() - 1])->cast<CNodePtr>();
    InsertRedistributionForMicroInterleaved(tensor_redistribution,
                                            {node_pair.first, node_pair.second[node_pair.second.size() - 1]},
                                            func_graph, pre_cnode, real_pre_node);
    return;
  }
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution->InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Infer tensor redistribution failed.";
  }
  if (!tensor_redistribution->IsAssembledStaticShape()) {
    redistribution_oplist_ptr = TensorTransform::GetInstance()->OptimizeTensorRedistributionOperatorList(
      redistribution_oplist_ptr, tensor_redistribution->input_shape());
  }

  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: InferTensorRedistribution failed";
  }
  MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size();
  if (!redistribution_oplist_ptr->first.empty()) {
    // the last one is the pos of node in maketuple
    tensor_redistribution->CreateAssembledDynamicMapping(next_cnode, pre_cnode, func_graph,
                                                         node_pair.second[node_pair.second.size() - 1]);
    // insert node before next node
    InsertRedistribution(redistribution_oplist_ptr, next_cnode, func_graph,
                         node_pair.second[node_pair.second.size() - 1], pre_cnode, tensor_redistribution);
  }
  // Rollback to dynamic shape.
  if (tensor_redistribution->IsAssembledStaticShape() &&
      tensor_redistribution->ResetLayoutTransfer() != Status::SUCCESS) {
    MS_LOG(WARNING) << "Failed to reset layout transfer.";
  }
}

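// For one cnode, collect its redistribution producer and all consumers, then run Redistribution() for
// every (pre_node, next_node) pair. The producer's operator info is also forwarded to a func-graph
// parameter when the consumer lives in a different graph.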
static void StepRedistribution(const CNodePtr &cnode, const NodeUsersMap &node_users_map) {
  MS_LOG(DEBUG) << "Do StepRedistribution for " << cnode->fullname_with_scope();
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // In pipeline parallel mode, redistribution is inserted after receive, not send.
  if (IsPrimitiveCNode(cnode, prim::kPrimSend) || IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) ||
      IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    return;
  }
  // Find Redistribution next_nodes
  // next_node.first.second = (pos in next node's inputs (no -1 needed), pos in tuple (-1 needed))
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> next_nodes;
  RedistributionNextNode(cnode, manager, node_users_map, {-1}, -1, &next_nodes);
  if (next_nodes.empty()) {
    return;
  }

  // Find Redistribution pre_nodes
  std::vector<AnfNodePtr> pre_nodes;
  RedistributionPreNode(cnode, manager, &pre_nodes);
  if (pre_nodes.size() > 1) {
    MS_LOG(EXCEPTION) << "Redistribution with multiple pre_nodes is not supported.";
  }

  // Insert Redistribution nodes between pre_nodes and next_nodes
  for (auto &pre_node : pre_nodes) {
    for (auto &next_node : next_nodes) {
      MS_LOG(INFO) << "===========Do Redistribution start============" << std::endl
                   << pre_node->fullname_with_scope() << "->" << next_node.first.first->fullname_with_scope() << "("
                   << next_node.first.second << ")";
      Redistribution(next_node.first, pre_node, next_node.second);
      MS_LOG(INFO) << "===========Do Redistribution end  ============";
    }
    for (const auto &next_node : next_nodes) {
      if (!next_node.first.first->has_user_data(FUNC_PARAM)) {
        continue;
      }
      if (pre_node->func_graph() == next_node.first.first->func_graph()) {
        continue;
      }
      auto param = next_node.first.first->user_data<AnfNode>(FUNC_PARAM);
      auto distribute_operator = GetDistributeOperator(pre_node->cast<CNodePtr>());
      param->set_user_data<OperatorInfo>(distribute_operator);
      break;
    }
  }
}

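// Split a constant tensor feeding next_node (at 1-based input `index`) according to the operator's
// input tensor layout by inserting a _GetTensorSlice op; scalar and shape-[1] tensors are left intact.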
static void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int64_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(next_node);
  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
  if (!op_info) {
    return;
  }

  if (op_info->name().find(FILLV2) != std::string::npos) {
    MS_LOG(INFO) << "FillV2 operator info no need to split tensor";
    return;
  }

  if (op_info->name().find(STAND_ALONE) != std::string::npos) {
    MS_LOG(INFO) << "Stand alone operator info no need to split tensor";
    return;
  }

  // If the shape of tensor is [] or [1], no need to split it.
  Shapes shapes = GetNodeShape(node);
  if (shapes.size() != 1) {
    MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name()
                      << ": GetNodeShape for tensor_node, output size is not 1";
  }
  Shape shape = shapes[0];
  std::string shape_str = ShapeToString(shape);
  if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) {
    MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str
                 << ", no need to split it.";
    return;
  }

  MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str;

  // extract tensor layout
  TensorLayout tensor_layout;
  auto inputs_info_size = op_info->inputs_tensor_info_new().empty() ? op_info->inputs_tensor_info().size()
                                                                    : op_info->inputs_tensor_info_new().size();
  if (LongToSize(index - 1) >= inputs_info_size) {
    if (IsIgnoreSplitTensor(next_node, index - 1)) {
      MS_LOG(INFO) << op_info->name() << ": no need to split tensor for index " << (index - 1);
      return;
    }
    MS_LOG(EXCEPTION) << op_info->name() << ": The index is out of range, index is " << (index - 1)
                      << ", vector size is " << inputs_info_size;
  }
  if (op_info->inputs_tensor_info_new().empty()) {
    TensorInfo tensor_info = op_info->inputs_tensor_info()[LongToSize(index - 1)];
    tensor_layout = tensor_info.tensor_layout();
  } else {
    auto tensor_info = op_info->inputs_tensor_info_new()[LongToSize(index - 1)];
    tensor_layout = tensor_info->GetValue().tensor_layout();
  }

  // Use _GetTensorSlice operator to split the tensor
  FuncGraphPtr func_graph = next_node->func_graph();  // only cnode can get the graph
  MS_EXCEPTION_IF_NULL(func_graph);
  Operator op = CreateGetTensorSliceOp(tensor_layout);
  InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
  if (!op_info->sub_ops().empty()) {
    auto sub_ops = op_info->sub_ops();
    for (size_t i = 0; i < sub_ops.size(); i++) {
      if (!sub_ops.at(i).empty()) {
        InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
      }
    }
  }
}

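// Split a ValueTuple/ValueList input of next_node: slice every element tensor with its own layout and
// replace the original value node with a MakeTuple of the sliced elements.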
static void SplitTensorList(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(next_node);
  if (((next_node->size() != kSizeTwo) && !IsSomePrimitiveList(next_node, SUPPORT_NEW_SHAPEBASE_OPS)) || index != 1) {
    MS_LOG(INFO) << next_node->fullname_with_scope() << " should have exactly one input, but got "
                 << (next_node->size() - 1) << "; index should be 1, but got " << index;
    return;
  }
  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  std::vector<ValuePtr> inputs_values;
  if (IsValueNode<ValueList>(node)) {
    inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else {
    inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  FuncGraphPtr func_graph = next_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (op_info->inputs_tensor_info_new().empty()) {
    if (inputs_values.size() != op_info->inputs_tensor_info().size()) {
      MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << ", is not equal to inputs shape size "
                        << op_info->inputs_tensor_info().size();
    }
    ScopePtr scope = next_node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    for (size_t i = 0; i < inputs_values.size(); ++i) {
      auto value_ptr = inputs_values[i];
      auto tensor = value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      TensorInfo tensor_info = op_info->inputs_tensor_info()[i];
      TensorLayout tensor_layout = tensor_info.tensor_layout();
      auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
      Operator op = CreateGetTensorSliceOp(tensor_layout);
      std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
      CNodePtr new_node = func_graph->NewCNode(node_input);
      MS_EXCEPTION_IF_NULL(new_node);
      new_node->set_in_forward_flag(true);
      auto new_node_value = node_input[0]->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(new_node_value);
      PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
      new_node_prim->set_instance_name(SPLIT_TENSOR);
      new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
      new_node->set_scope(scope);
      node_input[0]->set_scope(scope);
      make_tuple_inputs.push_back(new_node);
    }
  } else {
    if (inputs_values.size() != op_info->inputs_tensor_info_new()[index - 1]->size()) {
      MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << ", is not equal to inputs shape size "
                        << op_info->inputs_tensor_info_new()[index - 1]->size();
    }
    auto corresponding_tensor_info = op_info->inputs_tensor_info_new()[index - 1];
    ScopePtr scope = next_node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    for (size_t i = 0; i < inputs_values.size(); ++i) {
      auto value_ptr = inputs_values[i];
      auto tensor = value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      TensorInfo tensor_info = corresponding_tensor_info->GetElement(SizeToLong(i))->GetValue();
      TensorLayout tensor_layout = tensor_info.tensor_layout();
      auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
      Operator op = CreateGetTensorSliceOp(tensor_layout);
      std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
      CNodePtr new_node = func_graph->NewCNode(node_input);
      MS_EXCEPTION_IF_NULL(new_node);
      new_node->set_in_forward_flag(true);
      auto new_node_value = node_input[0]->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(new_node_value);
      PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
      new_node_prim->set_instance_name(SPLIT_TENSOR);
      new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
      new_node->set_scope(scope);
      node_input[0]->set_scope(scope);
      make_tuple_inputs.push_back(new_node);
    }
  }
  CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  (void)manager->Replace(node, make_tuple);
}

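// For every parallel-care user of `node`, split the tensor (or tensor tuple/list) it feeds into that
// user.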
static void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  for (auto &node_pair : node_set) {
    CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
    if (use_cnode == nullptr || !IsValueNode<Primitive>(use_cnode->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(use_cnode_prim);
    if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
        NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
      continue;
    }
    if (IsParallelCareNode(use_cnode)) {
      if (IsPrimitiveCNode(use_cnode, prim::kPrimReceive)) {
        continue;
      }
      if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
        SplitTensorList(node, use_cnode, node_pair.second);
      } else {
        SplitTensor(node, use_cnode, node_pair.second);
      }
    }
  }
}

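// Replace `node` with the operator list `replace_op` generated by its OperatorInfo (e.g. the
// redistribution sequence of Reshape); the last op in the list inherits the node's operator info and
// primal attrs.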
static void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
  MS_LOG(INFO) << "Start StepReplaceOp for " << node->fullname_with_scope();
  // step1: get graph manager and distribute_operator
  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since distribute_operator is nullptr";
  }
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr";
  }

  // For Reshape(bool), insert Cast ops at the beginning and end of op_list to avoid AllGather(bool).
  auto reshape_type_str = node->abstract()->BuildType()->ToString();
  auto replace_op_info = distribute_operator->replace_op_info();
  if (IsPrimitiveCNode(node, prim::kPrimReshape) && reshape_type_str.find(BOOL) != std::string::npos) {
    auto cast_int = CreateCastOp(kInt32);
    auto cast_bool = CreateCastOp(kBool);
    (void)replace_op.insert(replace_op.cbegin(), cast_int);
    (void)replace_op.insert(replace_op.cend(), cast_bool);
    (void)replace_op_info.insert(replace_op_info.cbegin(), {false, 1});
    (void)replace_op_info.insert(replace_op_info.cend(), {false, 1});
  }

  // step2: traverse op_list and insert nodes
  std::reverse(replace_op.begin(), replace_op.end());
  std::reverse(replace_op_info.begin(), replace_op_info.end());
  if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) {
    MS_LOG(EXCEPTION) << "replace_op_info is not empty and its size is not equal to replace_op!";
  }
  bool replace_op_info_flag = !replace_op_info.empty();
  for (size_t index = 0; index < replace_op.size(); ++index) {
    std::string instance_name = CreateInstanceName(node, index);
    std::string full_inst_name = std::string(REDISTRIBUTION_OP) + "_" + instance_name;
    std::vector<AnfNodePtr> replace_input;
    if (index != replace_op.size() - 1) {
      replace_input = CreateInput(replace_op[index], node, full_inst_name, node);
    } else {
      replace_input = ReplaceOpInput(replace_op[index], full_inst_name, node);
    }
    CNodePtr replace_node = func_graph->NewCNode(replace_input);
    MS_EXCEPTION_IF_NULL(replace_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    replace_node->set_scope(scope);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
    PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
    SetUserAttrs(origin_prim->attrs(), prim);
    auto origin_prim_attrs = origin_prim->attrs();
    if (origin_prim_attrs.find(RECOMPUTE_COMM_OP) != origin_prim_attrs.end()) {
      auto do_recompute = GetValue<bool>(origin_prim_attrs[RECOMPUTE_COMM_OP]);
      MS_LOG(INFO) << "The redistribution node in reshape would not be recomputed.";
      prim->set_attr(RECOMPUTE, MakeValue(do_recompute));
    }
    if (prim->name() == GET_NEXT && origin_prim_attrs.find(SYMBOLS) != origin_prim_attrs.end()) {
      prim->set_attr(SYMBOLS, origin_prim_attrs[SYMBOLS]);
    }
    if (index == replace_op.size() - 1) {
      replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
      replace_node->set_primal_attrs(node->primal_attrs());
    }
    replace_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(replace_node->UniqueId()));
    if (node->HasPrimalAttr(MICRO)) {
      replace_node->AddPrimalAttr(MICRO, node->GetPrimalAttr(MICRO));
    }
    replace_node->set_in_forward_flag(true);
    replace_input[0]->set_scope(scope);
    if (replace_op_info_flag && replace_op_info[index].first) {
      auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph);
      new_cnode->set_primal_attrs(node->primal_attrs());
      (void)manager->Replace(node, new_cnode);  // using Replace function to insert node
    } else {
      (void)manager->Replace(node, replace_node);  // using Replace function to insert node
    }
  }
  MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name();
}

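// Replace `node` with a prepared subgraph: wire node's inputs into the subgraph's dangling input slots
// (replace_graph->first) and replace node with the subgraph's output node (replace_graph->second).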
static void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node,
                             const OperatorInfoPtr &op_info) {
  MS_EXCEPTION_IF_NULL(replace_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(replace_graph->second);
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr";
  }
  // Solve the input order.
  // For example, input_node: {segment_sum:1, segment_sum:2, gather:2}.
  // The original code bound all operations to the first input of these operators. However, operations
  // such as segment_sum need two inputs. To solve this, we maintain a dict that counts how many times
  // the same operation has appeared, and bind the inputs according to that count.
  mindspore::HashMap<AnfNodePtr, int> input_map = {};
  static int appear_count = 0;
  for (auto &replace_input : replace_graph->first) {
    auto pre_node = node->input(LongToSize(replace_input.second));

    auto it = input_map.find(replace_input.first);
    if (it != input_map.end()) {
      appear_count = 1 + it->second;
    } else {
      appear_count = 1;
    }
    auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(replace_input_cnode);
    replace_input_cnode->set_user_data<OperatorInfo>(op_info);
    size_t inputs_size = replace_input_cnode->size();
    while (IntToSize(appear_count) < inputs_size && replace_input_cnode->input(appear_count)->func_graph() != nullptr) {
      ++appear_count;
    }
    if (IntToSize(appear_count) >= inputs_size) {
      MS_LOG(EXCEPTION) << "No replaceable virtual_input_node";
    }
    input_map[replace_input.first] = appear_count;
    replace_input_cnode->set_in_forward_flag(true);
    manager->SetEdge(replace_input.first, appear_count, pre_node);
  }
  //  "(void)manager->Replace(replace_graph->first, pre_node);" cannot be used here
  auto replace_output = replace_graph->second->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(replace_output);
  replace_output->set_in_forward_flag(true);
  replace_output->set_primal_attrs(node->primal_attrs());
  (void)manager->Replace(node, replace_output);
}

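// Insert the virtual div ops (used to rescale gradients under data parallelism) in front of every
// tensor input of `node`; for DropoutDoMask only input[0] is handled.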
static void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
    MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
    node_size = 2;
  }

  for (size_t index = 1; index < node_size; ++index) {
    AnfNodePtr input = node->input(index);
    MS_EXCEPTION_IF_NULL(input);
    // if it is not a tensor, continue
    if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
      MS_LOG(INFO) << "insert div op: input index " << index << " is not a tensor, skip";
      continue;
    }

    for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) {
      std::string instance_name = CreateInstanceName(node, pos);
      InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name);
    }
    MS_LOG(INFO) << "insert div op for input index " << index << " of node";
  }
}

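// Insert a real div by `scale` in front of every tensor input of `node`.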
static void InsertRealDivOpToNodeInput(const CNodePtr &node, int64_t scale, const string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  if (scale == 0) {
    MS_LOG(EXCEPTION) << "The scale value is 0, you should check the mirror operator's group size.";
  }
  size_t node_size = node->size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  // instantiate the real div operator
  Operator div_op = CreateDivOp(LongToFloat(scale));

  // Insert it as the input of the node
  for (size_t index = 1; index < node_size; ++index) {
    AnfNodePtr input = node->input(index);
    MS_EXCEPTION_IF_NULL(input);
    // if it is not a tensor, continue
    if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
      continue;
    }
    InsertNode(div_op, node, index, node->input(index), func_graph, instance_name);
  }
}

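// Insert an AllReduce(sum) over `group` in front of every tensor input of `node`.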
static void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group,
                                       const std::string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  // instantiate the allreduce operator
  CheckGlobalDeviceManager();
  Operator allreduce_op = CreateAllReduceOp(REDUCE_OP_SUM, group);

  // Insert it as the input of the node
  for (size_t index = 1; index < node_size; ++index) {
    AnfNodePtr input = node->input(index);
    MS_EXCEPTION_IF_NULL(input);
    // if it is not a tensor, continue
    if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
      continue;
    }

    InsertNode(allreduce_op, node, index, node->input(index), func_graph, instance_name);
  }
}

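// Return the func graph wrapped by a Shard primitive when one exists in all_nodes; otherwise return
// root.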
static FuncGraphPtr PynativeParallelGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  FuncGraphPtr real_graph = root;
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto expect_shard_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (expect_shard_prim->name() != SHARD) {
      continue;
    }
    real_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
  }
  return real_graph;
}

// find previous parallel care node's next node.
static bool FindPreNodes(const AnfNodePtr &node, std::vector<std::string> *unique_ids, std::vector<size_t> *indexes,
                         size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
    return false;
  }
  MS_EXCEPTION_IF_NULL(unique_ids);
  MS_EXCEPTION_IF_NULL(indexes);
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr pre_cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
    return false;
  }
  bool find = false;
  for (size_t index = 1; index < pre_cnode->size(); ++index) {
    if (IsPrimitiveCNode(pre_cnode, prim::kPrimDepend) && index > 1) {
      // For Depend, only the first input will be output.
      break;
    }
    auto next_node = pre_cnode->inputs()[index];
    if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
      return false;
    }
    CNodePtr cnode = next_node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      return false;
    }
    if (IsParallelCareNode(cnode) && !IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) &&
        !IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
      unique_ids->push_back(pre_cnode->UniqueId());
      indexes->push_back(index);
      find = true;
      continue;
    }
    if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) {
      find = true;
    }
  }
  return find;
}

1079 void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
1080   auto real_graph = PynativeParallelGraph(root, all_nodes);
1081   auto out_pair = GetRealKernelNode(real_graph->output(), -1, nullptr, false);
1082   auto out_node = out_pair.first;
1083   MS_EXCEPTION_IF_NULL(out_node);
1084   OperatorParams params;
1085   OperatorAttrs attrs;
1086   OperatorArgs args = std::make_pair(attrs, params);
1087   Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
1088   if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
1089     auto tuple = out_node->cast<CNodePtr>();
1090     MS_EXCEPTION_IF_NULL(tuple);
1091     for (size_t i = 1; i < tuple->size(); ++i) {
1092       auto cur_input = tuple->input(i);
1093       Shapes shape_outputs = GetNodeShape(cur_input);
1094       if (shape_outputs[0].empty()) {
1095         continue;
1096       }
1097       InsertNode(op, tuple, i, cur_input, tuple->func_graph(), VIRTUAL_OUTPUT);
1098       auto virtual_output_abstract = cur_input->abstract()->Clone();
1099       std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
1100       virtual_output_abstract->set_shape(virtual_output_shape);
1101       auto virtual_output_node = tuple->input(i);
1102       virtual_output_node->set_abstract(virtual_output_abstract);
1103     }
1104   } else {
1105     Shapes shape_outputs = GetNodeShape(out_node);
1106     if (shape_outputs[0].empty() || out_node->isa<Parameter>()) {
1107       return;
1108     }
1109     auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);
1110     auto cur_graph = out_node->cast<CNodePtr>()->func_graph();
1111     MS_EXCEPTION_IF_NULL(cur_graph);
1112     auto new_node = cur_graph->NewCNode(node_input);
1113     auto manager = cur_graph->manager();
1114     (void)manager->Replace(out_node, new_node);
1115     auto virtual_output_abstract = out_node->abstract()->Clone();
1116     std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
1117     virtual_output_abstract->set_shape(virtual_output_shape);
1118     new_node->set_abstract(virtual_output_abstract);
1119   }
1120 }
1121 
1122 bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
1123   // insert the mirror before the Cast only if the pre node is a Cast, and either gradient_fp32_sync is enabled or the Cast output type is float32
1124   bool is_gradient_fp32_sync = ParallelContext::GetInstance()->gradient_fp32_sync();
1125   auto pre_node = node->input(index);
1126   MS_EXCEPTION_IF_NULL(pre_node);
1127   auto cnode = pre_node->cast<CNodePtr>();
1128   if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
1129     return false;
1130   }
1131   if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(cnode)) {
1132     pre_node = cnode->input(1);
1133   }
1134   if (!IsPrimitiveCNode(pre_node, prim::kPrimCast)) {
1135     return false;
1136   }
1137   auto node_type = pre_node->Type();
1138   MS_EXCEPTION_IF_NULL(node_type);
1139   if (!node_type->isa<mindspore::TensorType>()) {
1140     MS_LOG(EXCEPTION) << "Unknown type.";
1141   }
1142   auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
1143   MS_EXCEPTION_IF_NULL(input_element_type);
1144   auto type_id = input_element_type->type_id();
1145   if (!is_gradient_fp32_sync && type_id != kNumberTypeFloat32) {
1146     return false;
1147   }
1148 
1149   return true;
1150 }
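// Sketch of the pattern InsertMirrorBeforeCast recognizes (illustrative): with
// gradient_fp32_sync enabled, the mirror should sit on the fp32 side of a Cast
// so the gradient all-reduce runs in float32:
//
//   param(fp32) -> Cast(fp16) -> MatMul            // original forward chain
//   param(fp32) -> Mirror -> Cast(fp16) -> MatMul  // mirror inserted before Cast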
1151 
1152 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
1153   if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1154     return true;
1155   }
1156   constexpr size_t kSingleArgCNodeSize = 2;
1157   if ((node->size() == kSingleArgCNodeSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
1158       (IsValueNode<ValueSequence>(node->input(1)))) {
1159     MS_LOG(INFO) << "Input is ValueList, skip it.";
1160     return false;
1161   }
1162 
1163   if ((node->size() == kSingleArgCNodeSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
1164       (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
1165     MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
1166     return false;
1167   }
1168   return true;
1169 }
1170 
1171 // only used for InsertMirrorOps
1172 static CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) {
1173   MS_EXCEPTION_IF_NULL(node);
1174   while (true) {
1175     if (IsPrimitiveCNode(node, prim::kPrimLoad) || IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
1176       if (IsPrimitiveCNode(node->input(1), prim::kPrimMicroStepAllGather)) {
1177         return node;
1178       }
1179       if (node->input(1)->isa<Parameter>()) {
1180         return node;
1181       }
1182       node = node->input(1)->cast<CNodePtr>();
1183     } else {
1184       MS_LOG(EXCEPTION) << "The node " << node->fullname_with_scope()
1185                         << " is an abnormal node when inserting the mirror node.";
1186     }
1187   }
1188 }
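// Example of the chain this helper walks (illustrative): starting from the
// cnode that consumes the parameter, it moves up through trivial nodes (Load,
// nodes in the trivial/allgather lists) until input(1) is the Parameter or a
// MicroStepAllGather, and returns that last trivial node so the mirror can be
// placed directly next to the parameter:
//
//   Parameter -> Load -> Cast -> MatMul
//   SkipTrivialNodesMoveUp(cast_cnode);  // returns the Load node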
1189 
1190 static void CreateMirrorForParam(const ParameterPtr param_ptr, OperatorVector *backward_op, bool *is_shared_param) {
1191   std::string opt_shard_mirror_group;
1192   if (param_ptr->user_data<TensorLayout>()) {
1193     opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
1194     *is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
1195   }
1196   if (!opt_shard_mirror_group.empty()) {
1197     // mirror ops are still needed when the optimizer shard does not fully cover the parameter
1198     uint32_t group_rank_size = 0;
1199     if (!CommManager::GetInstance().GetRankSize(opt_shard_mirror_group, &group_rank_size)) {
1200       MS_LOG(EXCEPTION) << "Got the group size from the group " << opt_shard_mirror_group << " failed";
1201     }
1202     *backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(group_rank_size));
1203   }
1204 }
1205 
1206 static void DoInsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
1207   FuncGraphPtr func_graph = node->func_graph();
1208   MS_EXCEPTION_IF_NULL(func_graph);
1209   FuncGraphManagerPtr manager = func_graph->manager();
1210   MS_EXCEPTION_IF_NULL(manager);
1211   auto mirror_size = mirror_ops.size();
1212   if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1213     mirror_size = 1;
1214   }
1215 
1216   for (size_t index = 1; index <= mirror_size; ++index) {
1217     OperatorVector backward_op = mirror_ops[index - 1];
1218     if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1219       auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
1220       backward_op = mirror_ops[IntToSize(param_index)];
1221     }
1222     if (backward_op.empty()) {
1223       continue;
1224     }
1225     std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(node->input(index), func_graph);
1226     if (!param_node_pair.first) {
1227       continue;
1228     }
1229 
1230     auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
1231     std::string param_name;
1232     bool is_shared_param = false;
1233     if (param_ptr) {
1234       param_name = param_ptr->name();
1235       if (!param_ptr->param_info() || !param_ptr->param_info()->requires_grad()) {
1236         MS_LOG(INFO) << param_name << " do not need gradient. Skip inserting mirror.";
1237         continue;
1238       }
1239       CreateMirrorForParam(param_ptr, &backward_op, &is_shared_param);
1240     }
1241     // not a RefKey
1242     std::string mirror_op_name = MirrorOpName();
1243     AnfNodePtr pre_node = node->input(index);
1244     if (!param_node_pair.second) {
1245       auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph, 0);
1246       // if there is already a MirrorOp in the same graph, use the MirrorOp CNode as an input instead
1247       if (next_cnode.first) {
1248         MS_EXCEPTION_IF_NULL(next_cnode.second);
1249         // assume Load is inserted next to parameter
1250         // skip Load moving up and insert mirror next to the parameter
1251         if (pre_node->cast<CNodePtr>()) {
1252           CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>());
1253           manager->SetEdge(load_node, 1, next_cnode.second);
1254         } else {
1255           manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
1256         }
1257         MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1258                      << " and share the mirror.";
1259         AddNodeMirrorInfo(node->cast<CNodePtr>(), param_name);
1260         continue;
1261       }
1262     }
1263     // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp
1264     // only one MirrorOp in backward_op
1265     if (backward_op.size() != 1) {
1266       MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
1267     }
1268     auto op = backward_op[0];
1269     if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param ||
1270                                        IsPrimitiveCNode(pre_node, prim::kPrimMirrorSilentCheck))) {
1271       // assume Load is inserted next to parameter
1272       // skip Load moving up and insert mirror next to the parameter
1273       CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>());
1274       InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
1275       auto comm_op = load_node->input(1)->cast<CNodePtr>();
1276       // add fusion flag
1277       auto fusion_id = AddCommOpFusionType(comm_op, param_node_pair.first);
1278       MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1279                    << " and insert mirror before Load";
1280       AddCommOpParamFlag(comm_op);
1281       AddNodeFusionInfo(node, comm_op, "all_reduce", param_name, fusion_id);
1282       continue;
1283     }
1284     InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
1285     MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1286                  << " and insert mirror before the node";
1287     auto comm_op = node->input(index)->cast<CNodePtr>();
1288     // add fusion flag
1289     // pipeline mirror would not be set, which should be supported later
1290     auto fusion_id = AddCommOpFusionType(comm_op, param_node_pair.first);
1291     AddCommOpParamFlag(comm_op);
1292     AddNodeFusionInfo(node, comm_op, "all_reduce", param_name, fusion_id);
1293   }
1294 }
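// Net effect of DoInsertMirrorOps as a before/after sketch (names illustrative,
// assuming the usual Load-next-to-parameter structure): for each input that
// traces back to a trainable parameter, a mirror operator is inserted (or
// shared, if one already exists in the same graph), so the backward pass
// all-reduces that parameter's gradient across its mirror group:
//
//   before:  out = MatMul(x, Load(w))
//   after:   out = MatMul(x, Load(Mirror(w)))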
1295 
1296 static void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
1297   MS_EXCEPTION_IF_NULL(node);
1298   if (!CheckInsertMirrorOps(mirror_ops, node)) {
1299     return;
1300   }
1301 
1302   DoInsertMirrorOps(root, mirror_ops, node);
1303 }
1304 
1305 static void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator,
1306                                   const CNodePtr &node,
1307                                   const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
1308   MS_EXCEPTION_IF_NULL(distribute_operator);
1309   MS_EXCEPTION_IF_NULL(node);
1310 
1311   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
1312     return;
1313   }
1314   bool is_loss_cnode =
1315     std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
1316                 [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
1317 
1318   MirrorOps mirror_ops = distribute_operator->mirror_ops();
1319   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
1320   // insert mirror op
1321   if (!mirror_ops.empty()) {
1322     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
1323     InsertMirrorOps(root, mirror_ops, node);
1324   }
1325   // insert virtual div op
1326   if (!virtual_div_op.empty() && is_loss_cnode && IsLastStage()) {
1327     MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
1328     InsertVirtualDivOp(virtual_div_op, node);
1329   }
1330 }
1331 
1332 static std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) {
1333   if (recursion_num >= RECURSION_LIMIT) {
1334     return std::make_pair(nullptr, 0);
1335   }
1336 
1337   MS_EXCEPTION_IF_NULL(node);
1338   FuncGraphPtr func_graph = node->func_graph();
1339   MS_EXCEPTION_IF_NULL(func_graph);
1340   FuncGraphManagerPtr manager = func_graph->manager();
1341   MS_EXCEPTION_IF_NULL(manager);
1342   AnfNodeIndexSet node_set = manager->node_users()[node];
1343   for (auto &node_pair : node_set) {
1344     CNodePtr cnode = node_pair.first->cast<CNodePtr>();
1345     MS_EXCEPTION_IF_NULL(cnode);
1346     if (!IsValueNode<Primitive>(cnode->input(0))) {
1347       continue;
1348     }
1349     if (IsPrimitiveCNode(cnode, prim::kPrimMirrorSilentCheck) && node_pair.second != 1) {
1350       continue;
1351     }
1352     ValueNodePtr prim_node_anf = cnode->input(0)->cast<ValueNodePtr>();
1353     MS_EXCEPTION_IF_NULL(prim_node_anf);
1354     PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
1355     MS_EXCEPTION_IF_NULL(node_prim);
1356     if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) ||
1357         IsPrimitiveCNode(cnode, prim::kPrimSend)) {
1358       continue;
1359     }
1360     if (node_prim->name() == UPDATESTATE && node_pair.second > 0) {
1361       continue;
1362     }
1363     if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
1364       return node_pair;
1365     } else {
1366       auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1);
1367       if (tmp_pair.first != nullptr) {
1368         return tmp_pair;
1369       }
1370     }
1371   }
1372   return std::make_pair(nullptr, 0);
1373 }
1374 
1375 static std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
1376   MS_EXCEPTION_IF_NULL(graph);
1377   MS_EXCEPTION_IF_NULL(parameter);
1378   FuncGraphManagerPtr manager = graph->manager();
1379   MS_EXCEPTION_IF_NULL(manager);
1380   std::pair<AnfNodePtr, int64_t> prim_anf_node_pair = FindParallelCareNode(parameter, 0);
1381   if (prim_anf_node_pair.first != nullptr) {
1382     return prim_anf_node_pair;
1383   } else {
1384     AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
1385     for (auto &param_pair : param_sub_set) {
1386       CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
1387       AnfNodePtr graph_value_node;
1388       if (param_cnode->input(0)->isa<CNode>()) {
1389         graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
1390       } else {
1391         graph_value_node = param_cnode->input(0);
1392       }
1393       if (!IsValueNode<FuncGraph>(graph_value_node)) {
1394         continue;
1395       }
1396       FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
1397       auto parameters = graph_sub->parameters();
1398       if (LongToSize(param_pair.second - 1) >= parameters.size()) {
1399         MS_LOG(EXCEPTION) << "The index is out of range, index is: " << (param_pair.second - 1) << ", vector size is "
1400                           << parameters.size();
1401       }
1402       std::pair<AnfNodePtr, int64_t> res = FindSubGraph(graph_sub, parameters[LongToSize(param_pair.second - 1)]);
1403       if (res.first != nullptr) {
1404         return res;
1405       }
1406     }
1407   }
1408   return std::make_pair(nullptr, 0);
1409 }
1410 
1411 static CNodePtr InsertAllGatherAfterCast(const std::pair<AnfNodePtr, int> &node_pair) {
1412   if (ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1) {
1413     return nullptr;
1414   }
1415   auto cnode = node_pair.first->cast<CNodePtr>();
1416   MS_EXCEPTION_IF_NULL(cnode);
1417   auto graph = cnode->func_graph();
1418   MS_EXCEPTION_IF_NULL(graph);
1419   auto manager = graph->manager();
1420   MS_EXCEPTION_IF_NULL(manager);
1421   // skip Load moving down and assume it only has one node user
1422   CNodePtr res = cnode;
1423   if (IsSomePrimitive(res, LOAD)) {
1424     res = manager->node_users()[cnode].begin()->first->cast<CNodePtr>();
1425   }
1426   // return the Cast node only if it casts from fp32 to a lower precision (e.g. fp16)
1427   if (!IsSomePrimitive(res, CAST)) {
1428     return nullptr;
1429   }
1430   auto node_type = res->Type();
1431   MS_EXCEPTION_IF_NULL(node_type);
1432   if (!node_type->isa<mindspore::TensorType>()) {
1433     MS_LOG(EXCEPTION) << "Unknown type.";
1434   }
1435   auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
1436   MS_EXCEPTION_IF_NULL(input_element_type);
1437   auto type_id = input_element_type->type_id();
1438 
1439   if (type_id != kNumberTypeFloat32) {
1440     return res;
1441   } else {
1442     return nullptr;
1443   }
1444 }
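// Rationale sketch (illustrative): when the parameter is consumed through a
// Cast away from fp32, returning the Cast node lets the optimizer-shard
// AllGather run on the low-precision tensor, roughly halving communication:
//
//   w_slice(fp32) -> Cast(fp16) -> AllGather -> MatMul   // gather fp16 slices
// rather than
//   w_slice(fp32) -> AllGather -> Cast(fp16) -> MatMul   // gather fp32 slices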
1445 
1446 void AddAllGatherAttrs(const CNodePtr &allgather, const CNodePtr &cnode, const AnfNodePtr &node,
1447                        const std::string &op_name, bool add_accu, bool is_with_mirror, bool grad_accumulation_shard) {
1448   // add fusion flag
1449   auto fusion_id = AddCommOpFusionType(allgather, node);
1450   auto param_ptr = node->cast<ParameterPtr>();
1451   auto param_name = param_ptr->name();
1452   AddNodeFusionInfo(cnode, allgather, "reduce_scatter", param_name, fusion_id);
1453   // add gradients mean
1454   AddCommOpMeanFlag(allgather);
1455   AddCNodePrimAttr(allgather, "with_mirror_operator", MakeValue<bool>(is_with_mirror));
1456   if (op_name == MICRO_STEP_ALL_GATHER) {
1457     // When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step
1458     // so no need to do backward for the micro_step_allgather
1459     AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue<bool>(!grad_accumulation_shard));
1460   } else if (op_name == MINI_STEP_ALL_GATHER) {
1461     // We need to manually set add_accu to false if its parent node is MirrorMiniStep
1462     AddCNodePrimAttr(allgather, ADD_ACCU, MakeValue<bool>(!add_accu && !is_with_mirror));
1463     AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue<bool>(!grad_accumulation_shard || !add_accu));
1464   }
1465 }
1466 
1467 static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
1468                               const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
1469   MS_EXCEPTION_IF_NULL(res.first);
1470   MS_EXCEPTION_IF_NULL(node);
1471   bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
1472   auto cnode = res.first->cast<CNodePtr>();
1473   auto graph = cnode->func_graph();
1474   MS_EXCEPTION_IF_NULL(graph);
1475   auto manager = graph->manager();
1476   MS_EXCEPTION_IF_NULL(manager);
1477   Operator op;
1478   CNodePtr allgather;
1479   auto param_name = node->cast<ParameterPtr>()->name();
1480   if (op_name == MICRO_STEP_ALL_GATHER) {
1481     op = CreateMicroStepAllGatherOp(group);
1482   } else {
1483     op = CreateAllGatherOp(group);
1484   }
1485   CNodePtr cast_node = InsertAllGatherAfterCast(res);
1486   auto param_ptr = node->cast<ParameterPtr>();
1487   MS_EXCEPTION_IF_NULL(param_ptr);
1488   bool is_with_mirror = false;
1489   if (param_ptr->user_data<TensorLayout>()) {
1490     auto opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
1491     is_with_mirror = !opt_shard_mirror_group.empty();
1492     if (!param_ptr->param_info()->parallel_optimizer()) {
1493       auto mirror_group = mirror_group_list(param_ptr->user_data<TensorLayout>());
1494       is_with_mirror = mirror_group.size() > 1;
1495     }
1496   }
1497   if (!is_shared_param && cast_node) {
1498     allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
1499     MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
1500   } else {
1501     auto pre_node = node;
1502     AnfNodePtr pre_node_ = node;
1503     auto &node_user_map = manager->node_users();
1504     TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(res, node_user_map);
1505     if (next_node_dtype) {
1506       MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving"
1507                    << " communication.";
1508       pre_node_ = CreateFP16Cast(cnode, pre_node, next_node_dtype);
1509     }
1510     InsertNode(op, cnode, IntToSize(res.second), pre_node_, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name,
1511                root);
1512     allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>();
1513     MS_LOG(INFO) << "Parallel optimizer is applied before " << cnode->DebugString() << " for " << param_name;
1514   }
1515   bool add_accu = root->has_flag(kAccumulation);
1516   AddAllGatherAttrs(allgather, cnode, node, op_name, add_accu, is_with_mirror, grad_accumulation_shard);
1517 }
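// Resulting structure (illustrative): with optimizer sharding, each rank keeps
// a 1/rank_size slice of the parameter, and the forward graph reassembles it
// at the point of use:
//
//   before:  y = MatMul(x, w)                    // w fully replicated
//   after:   y = MatMul(x, AllGather(w_slice))   // w_slice is this rank's shard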
1518 
1519 bool IsForwardCNode(const CNodePtr &cnode) {
1520   if (cnode->in_forward_flag()) {
1521     return true;
1522   }
1523   if (cnode->input(0) && IsValueNode<FuncGraph>(cnode->input(0))) {
1524     auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
1525     auto orders = func_graph->GetOrderedCnodes();
1526     return std::any_of(orders.begin(), orders.end(), [](const auto &c_node) { return c_node->in_forward_flag(); });
1527   }
1528   return false;
1529 }
1530 
1531 void InsertParallelOpt(const FuncGraphPtr &root, const AnfNodePtr &parameter, const std::string &opt_shard_group,
1532                        const std::string &op_name) {
1533   // insert all gather
1534   FuncGraphManagerPtr manager = root->manager();
1535   MS_EXCEPTION_IF_NULL(manager);
1536   auto param_sub_set = manager->node_users()[parameter];
1537   bool insert_flag = false;
1538   for (auto &param_pair : param_sub_set) {
1539     auto cnode = param_pair.first->cast<CNodePtr>();
1540     MS_EXCEPTION_IF_NULL(cnode);
1541     if (IsForwardCNode(cnode) && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
1542         !(IsPrimitiveCNode(cnode, prim::kPrimDepend) && param_pair.second == INDEX_TWO)) {
1543       if (insert_flag) {
1544         // if there are multiple node users, they share the same allgather
1545         auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph(), 0);
1546         if (next_cnode.first) {
1547           manager->SetEdge(cnode, param_pair.second, next_cnode.second);
1548           auto param_ptr = parameter->cast<ParameterPtr>();
1549           MS_EXCEPTION_IF_NULL(param_ptr);
1550           AddNodeMirrorInfo(cnode, param_ptr->name());
1551           MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
1552                        << GetPrimName(cnode);
1553         } else {
1554           MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users.";
1555         }
1556       } else {
1557         // insert allgather operator between shard parameter and cnode
1558         auto param_ptr = parameter->cast<ParameterPtr>();
1559         MS_EXCEPTION_IF_NULL(param_ptr);
1560         bool is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
1561         InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name, is_shared_param);
1562         insert_flag = true;
1563       }
1564     }
1565   }
1566 }
1567 
1568 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
1569                                     const std::string &opt_shard_group) {
1570   auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer();
1571   if (!enable_opt_shard) {
1572     return;
1573   }
1574   MS_EXCEPTION_IF_NULL(parameter);
1575   if (ParameterIsCloned(parameter)) {
1576     return;
1577   }
1578 
1579   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1580   if (opt_shard_group.empty() &&
1581       (split_stage_num <= 1 || !ParameterRequireGrad(parameter) || !root->has_flag(kTraining))) {
1582     return;
1583   }
1584 
1585   // set all gather type
1586   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1587   std::string op_name = ALL_GATHER;
1588   if (root->has_flag(kTraining)) {
1589     if ((grad_accumulation_step > 1 || split_stage_num > 1) && ParameterRequireGrad(parameter)) {
1590       op_name = MICRO_STEP_ALL_GATHER;
1591     }
1592   }
1593 
1594   // insert all gather
1595   InsertParallelOpt(root, parameter, opt_shard_group, op_name);
1596 }
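// Gather-op selection above, summarized from the code: during training, when
// gradient accumulation (grad_accumulation_step > 1) or pipeline parallelism
// (split_stage_num > 1) is active and the parameter requires a gradient, the
// MICRO_STEP_ALL_GATHER variant is used so the gather can be scheduled once per
// micro step; otherwise a plain ALL_GATHER is inserted.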
1597 
1598 // When this function returns a non-empty string, the parallel optimizer is applied to this parameter.
1599 static std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res,
1600                                     const FuncGraphPtr &root, const int &idx) {
1601   // check null for param and cnode
1602   MS_EXCEPTION_IF_NULL(parameter);
1603   auto param_shape = parameter->Shape();
1604 
1605   MS_EXCEPTION_IF_NULL(param_shape);
1606 
1607   CNodePtr cnode = res.first->cast<CNodePtr>();
1608   MS_EXCEPTION_IF_NULL(cnode);
1609 
1610   // get slice_shape
1611   OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
1612   if (distribute_operator == nullptr) {
1613     MS_LOG(EXCEPTION) << "node " << cnode->ToString() << " 's distribute_operator is nullptr";
1614   }
1615   TensorLayout tensor_layout;
1616   if (distribute_operator->inputs_tensor_info_new().empty()) {
1617     if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
1618       MS_LOG(EXCEPTION) << "The parameter index is not in inputs_tensor_info. index = " << (res.second - 1)
1619                         << ", inputs_tensor_info size = " << distribute_operator->inputs_tensor_info().size();
1620     }
1621     TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
1622     tensor_layout = tensorinfo_in.tensor_layout();
1623   } else {
1624     TensorInfoBasePtr tensorinfo_in;
1625     if (idx == -1) {
1626       tensorinfo_in = distribute_operator->inputs_tensor_info_new()[LongToSize(res.second - 1)];
1627     } else {
1628       // idx != -1, input is maketuple
1629       tensorinfo_in = distribute_operator->inputs_tensor_info_new()[LongToSize(idx)];
1630     }
1631     if (tensorinfo_in->is_list()) {
1632       if (idx == -1) {
1633         MS_LOG(EXCEPTION) << "The input of " << distribute_operator->name() << " is a list, but idx is -1.";
1634       }
1635       tensor_layout = tensorinfo_in->GetElement(res.second - 1)->GetValue().tensor_layout();
1636     } else {
1637       tensor_layout = tensorinfo_in->GetValue().tensor_layout();
1638     }
1639   }
1640   Shape slice_shape = tensor_layout.base_slice_shape().array();
1641 
1642   // generate shard group
1643   std::string opt_shard_group;
1644   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1645   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
1646   if (enable_parallel_optimizer) {
1647     std::unique_ptr<OptParamMgr> apOptParamMgr = createOptParamMgr(root);
1648     opt_shard_group = apOptParamMgr->ShardOptGroup(parameter, &tensor_layout, distribute_operator);
1649     // set the shape of parameter to sliced shape
1650     if (!opt_shard_group.empty()) {
1651       slice_shape = tensor_layout.opt_shard_slice_shape();
1652     }
1653     MS_LOG(INFO) << "the shape of " << parameter->ToString() << "(original: " << param_shape->ToString() << ")"
1654                  << " will be sliced into " << MakeValue(slice_shape)->ToString() << " in op "
1655                  << distribute_operator->name();
1656   }
1657 
1658   AbstractBasePtr abstract = parameter->abstract();
1659   if (abstract == nullptr) {
1660     MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract is nullptr";
1661   }
1662 
1663   AbstractBasePtr cloned_abstract = abstract->Clone();
1664   if (cloned_abstract == nullptr) {
1665     MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract clone failed";
1666   }
1667 
1668   cloned_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
1669   parameter->set_abstract(cloned_abstract);
1670   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
1671   MS_EXCEPTION_IF_NULL(parameter_ptr);
1672   if (tensor_layout.IsInterleavedParallel()) {
1673     MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << " can not set to interleaved parallel";
1674   }
1675   parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1676   if (ParallelContext::GetInstance()->direct_split() && parameter_ptr->has_default()) {
1677     auto layout = parameter_ptr->user_data<TensorLayout>();
1678     MS_LOG(INFO) << "parameter: " << parameter->ToString() << parameter->Shape()->ToString()
1679                  << "parameter_ptr->default_param()" << parameter_ptr->default_param() << "LAYOUT"
1680                  << layout->ToString();
1681     SliceTensorObj(parameter_ptr, layout);
1682   }
1683   return opt_shard_group;
1684 }
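// Worked example (illustrative numbers): a parameter of shape [64, 128]
// consumed by an operator whose input layout shards dim 0 over 8 devices gets
// base_slice_shape = [8, 128]; if the parameter is additionally covered by an
// optimizer-shard group, opt_shard_slice_shape shrinks the local slice further,
// e.g. to [2, 128] for a 4-rank group. The parameter's abstract is rewritten to
// the sliced shape, and the returned group name triggers AllGather insertion.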
1685 
1686 int ObtainActualInputIdxForSupportedOps(const AnfNodeIndexSet &node_set) {
1687   int idx = 0;
1688   for (const auto &node_pair : node_set) {
1689     auto use_cnode = node_pair.first->cast<CNodePtr>();
1690     if (IsSomePrimitiveList(use_cnode, SUPPORT_NEW_SHAPEBASE_OPS)) {
1691       idx = node_pair.second;
1692     }
1693   }
1694   return idx;
1695 }
1696 
1697 static void CoverSliceShape(const FuncGraphPtr &root) {
1698   MS_EXCEPTION_IF_NULL(root);
1699   auto parameters = root->parameters();
1700   FuncGraphManagerPtr manager = root->manager();
1701   MS_EXCEPTION_IF_NULL(manager);
1702   const auto &node_users_map = manager->node_users();
1703   for (auto &parameter : parameters) {
1704     MS_EXCEPTION_IF_NULL(parameter->Shape());
1705     auto iter = g_RefMap.find(parameter);
1706     if (iter != g_RefMap.cend()) {
1707       auto node_set = node_users_map.at(g_RefMap[parameter].first);
1708       auto idx = ObtainActualInputIdxForSupportedOps(node_set);
1709       std::string group = SetParallelShape(parameter, g_RefMap[parameter], root, idx - 1);
1710       // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1711       SetSharedParameterFlag(root, parameter);
1712       ApplyParallelOptOnParam(root, parameter, group);
1713       continue;
1714     }
1715 
1716     std::pair<AnfNodePtr, int64_t> res = FindSubGraph(root, parameter);
1717     if (res.first == nullptr) {
1718       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " is not in graph, thus no need to set parallel shape";
1719       if (parameter->has_user_data<TensorLayout>()) {
1720         auto param_abstract = parameter->abstract()->Clone();
1721         auto tensor_layout = parameter->user_data<TensorLayout>();
1722         Shape slice_shape = tensor_layout->base_slice_shape().array();
1723         param_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
1724         parameter->set_abstract(param_abstract);
1725       }
1726     } else {
1727       auto node_set = node_users_map.at(res.first);
1728       auto idx = ObtainActualInputIdxForSupportedOps(node_set);
1729       std::string group = SetParallelShape(parameter, res, root, idx - 1);
1730       // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1731       SetSharedParameterFlag(root, parameter);
1732       ApplyParallelOptOnParam(root, parameter, group);
1733       MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
1734     }
1735   }
1736   g_RefMap.clear();
1737 }
1738 
1739 static void PreProcessActualSeqLenInputForFlashAttentionScore(const FuncGraphPtr &root,
1740                                                               const std::vector<AnfNodePtr> &all_nodes) {
1741   auto manager = root->manager();
1742   MS_EXCEPTION_IF_NULL(manager);
1743   for (auto node : all_nodes) {
1744     if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
1745       auto fa_cnode = node->cast<CNodePtr>();
1746       MS_EXCEPTION_IF_NULL(fa_cnode);
1747       auto fa_inputs = fa_cnode->inputs();
1748       for (size_t index = ops::kFlashAttentionScoreInputActualSeqQlenIndex;
1749            index <= ops::kFlashAttentionScoreInputActualSeqKVlenIndex; ++index) {
1750         auto input = fa_inputs.at(index + 1);
1751         if (IsValueNode<None>(input)) {
1752           continue;
1753         }
1754         // Transfer Tuple to Tensor
1755         if (IsPrimitiveCNode(input, prim::kPrimTensorToTuple)) {
1756           // Eliminate TensorToTuple
1757           manager->SetEdge(fa_cnode, index + 1, input->cast<CNodePtr>()->input(kIndex1));
1758           MS_LOG(DEBUG) << "Eliminate TensorToTuple for " << fa_cnode->fullname_with_scope() << ", index is "
1759                         << index + 1;
1760         } else {
1761           auto dtype = NewValueNode(MakeValue<int64_t>(kInt64->type_id()));
1762           dtype->set_abstract(abstract::FromValue((int64_t)(kInt64->type_id())));
1763           auto tuple_to_tensor_cnode =
1764             fa_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleToTensor), input, dtype});
1765           auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor_cnode), {input, dtype});
1766           tuple_to_tensor_cnode->set_abstract(abs);
1767           manager->SetEdge(fa_cnode, index + 1, tuple_to_tensor_cnode);
1768           MS_LOG(DEBUG) << "Insert TupleToTensor for " << fa_cnode->fullname_with_scope() << ", index is " << index + 1;
1769         }
1770       }
1771     }
1772   }
1773 }
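// Example of the pre-processing above (illustrative values): the
// actual_seq_qlen/actual_seq_kvlen inputs arrive as int64 tuples, but the
// parallel handling works on tensors, so
//
//   FlashAttentionScore(..., MakeTuple(512, 1024), ...)
// becomes
//   FlashAttentionScore(..., TupleToTensor(MakeTuple(512, 1024), kInt64), ...)
//
// and an existing TensorToTuple(t) input is short-circuited back to t. The
// post-processing pass below undoes this once parallel handling is finished.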
1774 
1775 static void PostProcessActualSeqLenInputForFlashAttentionScore(const FuncGraphPtr &root,
1776                                                                const std::vector<AnfNodePtr> &all_nodes) {
1777   auto manager = root->manager();
1778   MS_EXCEPTION_IF_NULL(manager);
1779   for (auto node : all_nodes) {
1780     if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
1781       auto fa_cnode = node->cast<CNodePtr>();
1782       MS_EXCEPTION_IF_NULL(fa_cnode);
1783       auto fa_inputs = fa_cnode->inputs();
1784       for (size_t index = ops::kFlashAttentionScoreInputActualSeqQlenIndex;
1785            index <= ops::kFlashAttentionScoreInputActualSeqKVlenIndex; ++index) {
1786         auto input = fa_inputs.at(index + 1);
1787         auto input_abs = input->abstract();
1788         if (IsValueNode<None>(input)) {
1789           continue;
1790         }
1791 
1792         if (IsPrimitiveCNode(input, prim::kPrimTupleToTensor)) {
1793           // Eliminate TupleToTensor
1794           manager->SetEdge(fa_cnode, index + 1, input->cast<CNodePtr>()->input(kIndex1));
1795           MS_LOG(DEBUG) << "Eliminate TensorToTuple for " << fa_cnode->fullname_with_scope() << ", index is "
1796                         << index + 1;
1797         } else {
1798           // Transfer Tensor to Tuple
1799           auto tensor_to_tuple_cnode =
1800             fa_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTensorToTuple), input});
1801           manager->SetEdge(fa_cnode, index + 1, tensor_to_tuple_cnode);
1802           MS_LOG(DEBUG) << "Insert TensorToTuple for " << fa_cnode->fullname_with_scope() << ", index is " << index + 1;
1803         }
1804       }
1805     }
1806   }
1807 }
1808 
1809 ValuePtr ObtainStrategyForNewShapes(const ShapeBasePtr &shape, const int64_t &dev_num) {
1810   ValuePtr stra_value_ptr;
1811   if (shape->is_list()) {
1812     std::vector<ValuePtr> elements;
1813     for (size_t i = 0; i < shape->size(); ++i) {
1814       auto value_stra = ObtainStrategyForNewShapes(shape->GetElement(SizeToLong(i)), dev_num);
1815       elements.emplace_back(value_stra);
1816     }
1817     stra_value_ptr = std::make_shared<ValueTuple>(elements);
1818   } else {
1819     Dimensions stra;
1820     stra.push_back(dev_num);
1821     for (size_t j = 1; j < shape->size(); ++j) {
1822       stra.push_back(1);
1823     }
1824     stra_value_ptr = MakeValue(stra);
1825   }
1826   return stra_value_ptr;
1827 }
1828 
1829 void ObtainElementsForStrategyNewShape(const std::vector<NewShapes> &new_shape_list, const int64_t &dev_num,
1830                                        std::vector<ValuePtr> *elements) {
1831   for (size_t i = 0; i < new_shape_list[0].size(); i++) {
1832     if (new_shape_list[0][i]->empty()) {
1833       (void)elements->emplace_back(MakeValue(Dimensions()));
1834       continue;
1835     }
1836     auto input_strategy = ObtainStrategyForNewShapes(new_shape_list[0][i], dev_num);
1837     (void)elements->emplace_back(MakeValue(input_strategy));
1838   }
1839 }
1840 
1841 void ObtainElementsForStrategy(const std::vector<Shapes> &shape_list, const int64_t &dev_num,
1842                                std::vector<ValuePtr> *elements) {
1843   for (size_t i = 0; i < shape_list[0].size(); i++) {
1844     if (shape_list[0][i].empty()) {
1845       (void)elements->emplace_back(MakeValue(Dimensions()));
1846       continue;
1847     }
1848     Dimensions input_strategy;
1849     input_strategy.push_back(dev_num);
1850     if (shape_list[0][i][0] > 0 && shape_list[0][i][0] % dev_num != 0) {
1851       MS_LOG(EXCEPTION) << "The shapes of dataset is " << shape_list[0]
1852                         << ", the batch dim can not be evenly div by dev_num " << dev_num;
1853     }
1854     for (size_t j = 1; j < shape_list[0][i].size(); j++) {
1855       input_strategy.push_back(1);
1856     }
1857     (void)elements->emplace_back(MakeValue(input_strategy));
1858   }
1859 }
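// Worked example (illustrative): for dataset shapes {[32, 3, 224, 224], [32]}
// and dev_num = 8, the generated strategy elements are ((8, 1, 1, 1), (8)),
// i.e. only the batch dimension is split across devices; a static batch
// dimension that is not divisible by dev_num raises an exception.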
1860 
1861 std::pair<std::vector<Shapes>, std::vector<NewShapes>> ObtainShape(const CNodePtr &node) {
1862   std::vector<Shapes> shape_list;
1863   std::vector<NewShapes> new_shape_list;
1864   if (HasSupportedValueSequence(node)) {
1865     new_shape_list = ExtractNewShape(node);
1866   } else {
1867     shape_list = ExtractShape(node);
1868   }
1869   return std::make_pair(shape_list, new_shape_list);
1870 }
1871 
1872 void SetVirtualDatasetStrategy(const CNodePtr &node) {
1873   MS_EXCEPTION_IF_NULL(node);
1874   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1875   bool full_batch = ParallelContext::GetInstance()->full_batch();
1876 
1877   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
1878   MS_EXCEPTION_IF_NULL(prim);
1879   if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
1880     CheckGlobalDeviceManager();
1881     auto attrs_temp = prim->attrs();
1882     if (!ParallelContext::GetInstance()->dataset_strategy().empty() && prim->name() == VIRTUAL_DATA_SET) {
1883       std::vector<ValuePtr> elements;
1884       auto dataset_strategy = ParallelContext::GetInstance()->dataset_strategy();
1885       (void)std::transform(dataset_strategy.begin(), dataset_strategy.end(), std::back_inserter(elements),
1886                            [](auto input_stra) { return MakeValue(input_stra); });
1887       ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1888       attrs_temp[IN_STRATEGY] = strategy;
1889       (void)prim->SetAttrs(attrs_temp);
1890       if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) {
1891         ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);
1892         MS_LOG(INFO) << "dataset repeat dim is right";
1893       }
1894       return;
1895     }
1896     int64_t dev_num;
1897     if (full_batch) {
1898       dev_num = 1;
1899     } else {
1900       dev_num = g_device_manager->stage_device_num();
1901     }
1902     if (dev_num == 0) {
1903       MS_LOG(EXCEPTION) << "Device Num must be larger than 0, but got 0.";
1904     }
1905     std::vector<Shapes> shape_list;
1906     std::vector<NewShapes> new_shape_list;
1907     if (InDynamicGraph(node)) {
1908       shape_list = ExtractRealDivisor(node);
1909       MS_LOG(INFO) << "The node is in dynamic shape graph, the real divisor is " << ShapesToString(shape_list[0]);
1910     } else {
1911       std::tie(shape_list, new_shape_list) = ObtainShape(node);
1912     }
1913     if (shape_list.empty() && new_shape_list.empty()) {
1914       MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
1915     }
1916     std::vector<ValuePtr> elements;
1917     if (new_shape_list.empty()) {
1918       ObtainElementsForStrategy(shape_list, dev_num, &elements);
1919     } else {
1920       ObtainElementsForStrategyNewShape(new_shape_list, dev_num, &elements);
1921     }
1922     ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1923     attrs_temp[IN_STRATEGY] = strategy;
1924     (void)prim->SetAttrs(attrs_temp);
1925   }
1926 }
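// Note on full_batch (from the code above): with full_batch enabled, dev_num is
// forced to 1, so the derived VirtualDataset strategy performs no splitting and
// every rank consumes the complete batch; a user-configured dataset_strategy,
// when present, takes precedence over the derived one.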
1927 
1928 static bool CheckExtractInformation(const CNodePtr &cnode) {
1929   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
1930     return false;
1931   }
1932 
1933   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1934   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1935   if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
1936     return false;
1937   }
1938 
1939   return IsParallelCareNode(cnode);
1940 }
1941 
1942 StrategyPtr GenerateStandAloneStra(const OperatorInfoPtr &op_info) {
1943   StrategyPtr in_strategy;
1944   if (op_info->inputs_shape_new().empty()) {
1945     in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
1946   } else {
1947     in_strategy = GenerateStandAloneStrategyForNewShapes(op_info->inputs_shape_new());
1948   }
1949   return in_strategy;
1950 }
1951 
1952 void CheckStrategyAndShape(const StrategyPtr &in_strategy, const OperatorInfoPtr &op_info) {
1953   MS_EXCEPTION_IF_NULL(in_strategy);
1954   auto has_tuple_stra = in_strategy->HasTupleInTupleStrategy();
1955   auto has_new_shape = !op_info->inputs_shape_new().empty();
1956   if (has_tuple_stra != has_new_shape) {
1957     MS_LOG(EXCEPTION)
1958       << "One of the strategy or input shape have tuple in tuple input, but the other does not; in_strategy is "
1959       << has_tuple_stra << ", input shape is " << has_new_shape;
1960   }
1961 }
1962 
1963 static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &prim, const OperatorInfoPtr &op_info) {
1964   StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
1965   auto attrs = prim->attrs();
1966 
1967   // load strategy map from checkpoint
1968   StrategyMap stra_map;
1969   if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
1970       (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
1971     MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
1972   }
1973 
1974   std::string strategy_key_name = "";
1975   auto param_names = NodeParameterName(cnode, -1, 0);
1976   if (!param_names.empty()) {
1977     strategy_key_name = prim->name() + "_" + param_names[0].first;
1978   }
1979   std::vector<std::shared_ptr<TensorLayout>> in_tensor_layouts;
1980   std::vector<std::shared_ptr<TensorLayout>> out_tensor_layouts;
1981   if (ExtractUserConfigLayout(attrs, op_info->inputs_shape(), op_info->outputs_shape(), &in_tensor_layouts,
1982                               &out_tensor_layouts) != SUCCESS) {
1983     MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " extract configured layout failed"
1984                       << trace::DumpSourceLines(cnode);
1985   }
1986   if (in_tensor_layouts.empty() && out_tensor_layouts.empty()) {
1987     bool load_strategy_from_ckpt =
1988       StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
1989     if (!prim->HasAttr(STAND_ALONE)) {
1990       if (((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) ||
1991           prim->HasAttr(BATCH_PARALLEL)) {
1992         MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
1993                      << " is empty, using batch parallel";
1994         in_strategy = GenerateBatchParallelStrategy(op_info, prim);
1995       } else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
1996         in_strategy = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
1997         out_strategy = ExtractStrategy(cnode->GetPrimalAttr(OUT_STRATEGY));
1998       } else if (StrategyFound(attrs)) {
1999         in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
2000         out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
2001       } else {
2002         in_strategy = stra_map[strategy_key_name];
2003       }
2004     } else {
2005       in_strategy = GenerateStandAloneStra(op_info);
2006     }
2007     CheckStrategyAndShape(in_strategy, op_info);
2008   }
2009   if (op_info->Init(in_strategy, out_strategy, in_tensor_layouts, out_tensor_layouts) == FAILED) {
2010     MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed" << trace::DumpSourceLines(cnode);
2011   }
2012 }
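// Strategy resolution implemented above, in order: user-configured layouts
// (extracted by ExtractUserConfigLayout) win outright; else a STAND_ALONE attr
// yields a stand-alone strategy; else a BATCH_PARALLEL attr, or finding no
// strategy anywhere, yields a batch-parallel strategy; else the cnode's primal
// IN_STRATEGY/OUT_STRATEGY attrs are used, then the primitive's own attrs, and
// finally the checkpoint entry keyed by "<prim_name>_<first_param_name>".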
2013 
2014 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
2015   SetStridedSliceSplitStrategy(all_nodes);
2016   for (auto &node : all_nodes) {
2017     auto cnode = node->cast<CNodePtr>();
2018     if (!CheckExtractInformation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
2019       continue;
2020     }
2021 
2022     SetVirtualDatasetStrategy(cnode);
2023     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2024     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
2025 
2026     OperatorInfoPtr operator_ = CreateOperatorInfo(cnode);
2027     MS_EXCEPTION_IF_NULL(operator_);
2028 
2029     if (prim->name() == RESHAPE) {
2030       cnode->set_user_data<OperatorInfo>(operator_);
2031       continue;
2032     }
2033 
2034     ExtractStrategyAndInit(cnode, prim, operator_);
2035     cnode->set_user_data<OperatorInfo>(operator_);
2036   }
2037 }
2038 
2039 // if reshape's output connects to several primitives, return the first layout found
2040 static std::shared_ptr<TensorLayout> FindNextLayout(const AnfNodePtr &cnode, bool *next_is_reshape,
2041                                                     mindspore::HashSet<AnfNodePtr> *visit, int make_tuple_index,
2042                                                     int tuple_get_index,
2043                                                     const std::shared_ptr<TensorLayout> &pre_layout) {
2044   MS_EXCEPTION_IF_NULL(cnode);
2045   MS_EXCEPTION_IF_NULL(next_is_reshape);
2046   MS_EXCEPTION_IF_NULL(visit);
2047   MS_EXCEPTION_IF_NULL(cnode->func_graph());
2048   FuncGraphManagerPtr manager = cnode->func_graph()->manager();
2049   MS_EXCEPTION_IF_NULL(manager);
2050   AnfNodeIndexSet node_set = manager->node_users()[cnode];
2051   for (auto &node_pair : node_set) {
2052     auto use_apply = node_pair.first->cast<CNodePtr>();
2053     if (visit->find(use_apply) != visit->end()) {
2054       continue;
2055     }
2056     (void)(visit->insert(use_apply));
2057 
2058     if (IsPrimitiveCNode(use_apply, prim::kPrimPrint) || IsPrimitiveCNode(use_apply, prim::kPrimTensorDump)) {
2059       return pre_layout;
2060     }
2061 
2062     if (IsValueNode<FuncGraph>(use_apply->input(0))) {
2063       auto fg = GetValueNode<FuncGraphPtr>(use_apply->input(0));
2064       MS_EXCEPTION_IF_NULL(fg);
2065       auto fg_parameters = fg->parameters();
2066       auto param = fg_parameters[IntToSize(node_pair.second - 1)];
2067       auto next_layout = FindNextLayout(param, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2068       if (next_layout != nullptr) {
2069         return next_layout;
2070       }
2071     }
2072 
2073     if (IsPrimitiveCNode(use_apply, prim::kPrimReturn)) {
2074       auto fg = use_apply->func_graph();
2075       auto fg_map = fg->func_graph_cnodes_index();
2076       for (auto &fg_use : fg_map) {
2077         auto fg_node = fg_use.first->first->cast<CNodePtr>();
2078         MS_EXCEPTION_IF_NULL(fg_node);
2079         auto next_layout =
2080           FindNextLayout(fg_node, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2081         if (next_layout != nullptr) {
2082           return next_layout;
2083         }
2084       }
2085     }
2086 
2087     if (IsPrimitiveCNode(use_apply, prim::kPrimTupleGetItem)) {
2088       auto temp = LongToInt(GetTupleGetItemIndex(use_apply));
2089       if (temp != make_tuple_index - 1 && make_tuple_index > 0) {
2090         continue;
2091       }
2092       temp = make_tuple_index > 0 ? -1 : temp;
2093       auto next_layout = FindNextLayout(use_apply, next_is_reshape, visit, temp, -1, pre_layout);
2094       if (next_layout != nullptr) {
2095         return next_layout;
2096       }
2097     }
2098 
2099     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
2100       continue;
2101     }
2102     if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
2103       *next_is_reshape = true;
2104       continue;
2105     }
2106     if (IsOneOfPrimitiveCNode(use_apply, {prim::kPrimDepend, prim::kPrimUpdateState}) && node_pair.second != 1) {
2107       continue;
2108     }
2109     if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
2110       make_tuple_index = node_pair.second;
2111       auto next_layout =
2112         FindNextLayout(use_apply, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2113       if (next_layout != nullptr) {
2114         return next_layout;
2115       }
2116     }
2117     if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>() &&
2118         IsSomePrimitiveList(use_apply, SUPPORT_NEW_SHAPEBASE_OPS)) {
2119       MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString() << ", in support new shapebase ops";
2120       *next_is_reshape = false;
2121       auto layout = GetInputLayoutFromCNode(node_pair, make_tuple_index);
2122       return std::make_shared<TensorLayout>(layout);
2123     }
2124     if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
2125       if (make_tuple_index > 0) {
2126         node_pair.second = make_tuple_index;
2127       }
2128       MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString();
2129       *next_is_reshape = false;
2130       auto layout = GetInputLayoutFromCNode(node_pair, -1);
2131       return std::make_shared<TensorLayout>(layout);
2132     }
2133     MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << "  " << IsParallelCareNode(use_apply)
2134                   << "   " << use_apply->has_user_data<OperatorInfo>();
2135 
2136     auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2137     if (layout_ptr) {
2138       return layout_ptr;
2139     }
2140   }
2141   return nullptr;
2142 }
2143 
2144 static std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
2145   MS_EXCEPTION_IF_NULL(cnode);
2146   OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2147   MS_EXCEPTION_IF_NULL(distribute_operator);
2148   TensorLayout tensorlayout_out;
2149   if (distribute_operator->outputs_tensor_info_new().empty()) {
2150     if (distribute_operator->outputs_tensor_info().size() <= output_index) {
2151       MS_LOG(EXCEPTION) << "outputs_tensor_info size is  " << distribute_operator->outputs_tensor_info().size()
2152                         << ", must be greater than output_index  " << output_index;
2153     }
2154     TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
2155     tensorlayout_out = tensorinfo_out.tensor_layout();
2156   } else {
2157     if (distribute_operator->outputs_tensor_info_new().size() <= output_index) {
2158       MS_LOG(EXCEPTION) << "outputs_tensor_info size is  " << distribute_operator->outputs_tensor_info_new().size()
2159                         << ", must be greater than output_index  " << output_index;
2160     }
2161     auto tensorinfo_out = distribute_operator->outputs_tensor_info_new()[output_index];
2162     if (tensorinfo_out->is_list()) {
2163       MS_LOG(EXCEPTION) << "For " << cnode->DebugString() << ": the " << output_index
2164                         << " out tensorinfo is a list, which does not support yet";
2165     }
2166     tensorlayout_out = tensorinfo_out->GetValue().tensor_layout();
2167   }
2168   return std::make_shared<TensorLayout>(tensorlayout_out);
2169 }
2170 
2171 static std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) {
2172   if (!node->isa<CNode>()) {
2173     return nullptr;
2174   }
2175   CNodePtr cnode = node->cast<CNodePtr>();
2176   if (!IsValueNode<Primitive>(cnode->input(0))) {
2177     return nullptr;
2178   }
2179   if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
2180     auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
2181     if (!layout_ptr) {
2182       MS_LOG(EXCEPTION) << "Failure: GetOutputLayoutFromCNode failed";
2183     }
2184     return layout_ptr;
2185   }
2186   return nullptr;
2187 }
2188 
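// Infers the redistribution op list that reshards the sens tensor from a stand-alone
// layout (no dim split: every tensor dim maps to -1, dev_matrix = [stage_device_num])
// to the loss layout. For example, with 4 devices in the stage and a sens of shape
// [8, 16], the source layout is dev_matrix [4] with tensor_map [-1, -1], i.e. a full
// replica on every device; the returned ops then reshard it to match the loss.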
2189 static RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
2190   MS_EXCEPTION_IF_NULL(node);
2191   TensorRedistribution tensor_redistribution;
2192   // Create a stand-alone layout: tensor_map: [all -1], dev_matrix: [dev_num].
2193   CheckGlobalDeviceManager();
2194   int64_t dev_num = g_device_manager->stage_device_num();
2195   TensorLayout stand_alone_layout;
2196   Shapes inputs_shape = GetNodeShape(node);
2197   if (inputs_shape.empty()) {
2198     MS_LOG(EXCEPTION) << "InferSensRedistribution failed because the inputs shape is empty.";
2199   }
2200   Shape input_shape_array = inputs_shape[0];
2201   if (input_shape_array.empty()) {
2202     MS_LOG(INFO) << "No need to do redistribution for sens.";
2203     return nullptr;
2204   }
2205   // TensorMap
2206   TensorMap stand_alone_tensor_map_array(SizeToLong(input_shape_array.size()), -1);
2207   // Dev_matrix
2208   Shape dev_matrix_array = {dev_num};
2209   if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
2210     MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
2211   }
2212 
2213   // Infer Redistribution op list for stand alone and loss layout.
2214   RankList dev_list = g_device_manager->GetDeviceListInThisStage();
2215   if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
2216     MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
2217   }
2218   RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
2219   MS_EXCEPTION_IF_NULL(sens_redistribution_list);
2220 
2221   return sens_redistribution_list;
2222 }
2223 
2224 // reshape1 ---> depend ---> call @sub_graph(x, y, z)
2225 // sub_graph(x, y, z): reshape2(y)
2226 // find the reshape1 through y
2227 static AnfNodePtr RefParameterToActualNode(const AnfNodePtr &node) {
2228   if (!node->isa<Parameter>()) {
2229     return nullptr;
2230   }
2231   auto node_param_ptr = node->cast<ParameterPtr>();
2232   if (node_param_ptr->has_default()) {
2233     return node;
2234   }
2235   auto sub_func_graph = node_param_ptr->func_graph();
2236   auto call_cnodes_map = sub_func_graph->func_graph_cnodes_index();
2237   auto sub_graph_parameters = sub_func_graph->parameters();
2238   auto curr_param_iter = std::find(sub_graph_parameters.begin(), sub_graph_parameters.end(), node);
2239   if (curr_param_iter == sub_graph_parameters.end()) {
2240     MS_LOG(EXCEPTION) << "Cannot find param " << node_param_ptr->DebugString() << " in current sub_graph";
2241   }
2242   size_t curr_param_index = static_cast<size_t>(curr_param_iter - sub_graph_parameters.begin());
2243   for (const auto &node_pair : call_cnodes_map) {
2244     if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
2245       continue;
2246     }
2247     auto cnode = node_pair.first->first->cast<CNodePtr>();
2248     auto cnode_input = cnode->input(curr_param_index + 1);
2249     auto pre_cnode = GetInputNodeWithFilter(cnode_input, [&](const CNodePtr &cnode) {
2250       bool filter = IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
2251                     IsPrimitiveCNode(cnode, prim::kPrimDepend);
2252       return std::make_pair(filter, 1);
2253     });
2254     if (pre_cnode) {
2255       return pre_cnode;
2256     }
2257   }
2258   return nullptr;
2259 }
2260 
2261 static bool IsCommonOp(const AnfNodePtr &node) {
2262   CNodePtr cnode = node->cast<CNodePtr>();
2263   bool is_common_op =
2264     IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() && !IsPrimitiveCNode(node, prim::kPrimReshape);
2265   return is_common_op;
2266 }
2267 
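// Walks backwards from a Reshape input to find the layout it consumes: a parameter
// with a default value yields a freshly created parameter layout (and sets
// *is_input_param); a sub-graph parameter is traced to its actual producer; a call
// node is traced into the callee's output; Receive carries its layout as user data;
// a common parallel-care op yields its 0-th output layout; TupleGetItem resolves the
// indexed output; otherwise the inputs are searched recursively (Depend via input 1 only).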
2268 static std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node, bool *is_input_param) {
2269   if (node->isa<Parameter>()) {
2270     auto node_param_ptr = node->cast<ParameterPtr>();
2271     if (node_param_ptr->has_default()) {
2272       // Only when the real input of Reshape is a parameter with a default value will the strategy of Reshape
2273       // be assigned to this parameter.
2274       *is_input_param = true;
2275       return CreateParameterLayout(node);
2276     }
2277 
2278     // the node is parameter of sub-graph
2279     auto actual_node = RefParameterToActualNode(node);
2280     if (actual_node) {
2281       return FindPrevLayout(actual_node, is_input_param);
2282     }
2283     return nullptr;
2284   }
2285   if (!node->isa<CNode>()) {
2286     return nullptr;
2287   }
2288   CNodePtr cnode = node->cast<CNodePtr>();
2289   if (IsValueNode<FuncGraph>(cnode->input(0))) {
2290     auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
2291     auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
2292     if (!pre_node) {
2293       return nullptr;
2294     }
2295     return FindPrevLayout(pre_node, is_input_param);
2296   }
2297   if (!IsValueNode<Primitive>(cnode->input(0))) {
2298     return nullptr;
2299   }
2300   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
2301     return cnode->user_data<TensorLayout>();
2302   }
2303   if (IsCommonOp(node)) {
2304     auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
2305     if (!layout_ptr) {
2306       MS_LOG(EXCEPTION) << "Failure: GetOutputLayoutFromCNode failed";
2307     }
2308     return layout_ptr;
2309   }
2310   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2311   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
2312   if (prim->name() == prim::kPrimTupleGetItem->name()) {
2313     auto tuple_index = GetTupleGetItemIndex(cnode);
2314     auto tuple_getitem_input = cnode->input(1)->cast<CNodePtr>();
2315     if (IsValueNode<FuncGraph>(tuple_getitem_input->input(0))) {
2316       auto fg = GetValueNode<FuncGraphPtr>(tuple_getitem_input->input(0));
2317       auto pre_node = GetRealKernelNode(fg->output(), tuple_index, nullptr).first;
2318       if (!pre_node) {
2319         return nullptr;
2320       }
2321       return FindPrevLayout(pre_node, is_input_param);
2322     }
2323     auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
2324     if (!layout_ptr) {
2325       MS_LOG(EXCEPTION) << "Failure: FindPrevLayout failed, there is a tuple_getitem before the reshape, "
2326                            "but there does not exist a parallel care node before the tuple_getitem!";
2328     }
2329     return layout_ptr;
2330   }
2331   for (size_t index = 0; index < cnode->size(); ++index) {
2332     if (prim->name() == DEPEND && index != 1) {
2333       continue;
2334     }
2335     auto layout_ptr = FindPrevLayout(cnode->inputs()[index], is_input_param);
2336     if (!layout_ptr) {
2337       continue;
2338     }
2339     return layout_ptr;
2340   }
2341   return nullptr;
2342 }
2343 
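// For every Reshape in the graph, stitches its input layout (FindPrevLayout on input 1)
// and output layout (FindNextLayout on its users) into the ReshapeInfo and then inits
// the operator. If the next op is also a reshape, the previous layout is reused as the
// output layout.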
2344 static void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
2345   MS_LOG(DEBUG) << "=============Do ReshapeInit start=============";
2346   for (auto &node : all_nodes) {
2347     auto cnode = node->cast<CNodePtr>();
2348     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2349       continue;
2350     }
2351     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2352     if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
2353       continue;
2354     }
2355     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
2356     MS_EXCEPTION_IF_NULL(prim);
2357     OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
2358     if (operator_info == nullptr) {
2359       MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
2360     }
2361     if (prim->name() != RESHAPE) {
2362       continue;
2363     }
2364 
2365     bool is_input_param = false;
2366     auto prev_layout_ptr = FindPrevLayout(cnode->input(1), &is_input_param);
2367     if (prev_layout_ptr) {
2368       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2369       reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
2370     } else {
2371       MS_LOG(WARNING)
2372         << "FindPrevLayout returned nullptr; if the reshape is not the first primitive, there must be some error";
2373     }
2374     auto attrs = prim->attrs();
2375     if (StrategyFound(attrs) && !is_input_param) {
2376       MS_LOG(EXCEPTION) << "Setting a strategy for Reshape has no effect!";
2377     }
2378     MS_ASSERT(cnode->size() == RESHAPE_INPUT_SIZE);
2379 
2380     bool is_next_reshape = false;
2381     mindspore::HashSet<AnfNodePtr> visit;
2382     auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape, &visit, -1, -1, prev_layout_ptr);
2383     if (next_layout_ptr == nullptr) {
2384       std::string is_reshape = is_next_reshape ? "true" : "false";
2385       MS_LOG(WARNING) << "FindNextLayout for " << cnode->fullname_with_scope()
2386                       << " returned nullptr, and is_next_reshape is " << is_reshape
2387                       << ". If the reshape is not the last primitive, there must be some error.";
2388     }
2389     if (next_layout_ptr) {
2390       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2391       reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
2392     } else if (is_next_reshape && prev_layout_ptr != nullptr) {
2393       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2394       reshape_info_ptr->SetOutputLayout(*prev_layout_ptr);
2395     }
2396     if (operator_info->Init(nullptr, nullptr) == FAILED) {
2397       MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed";
2398     }
2399   }
2400   MS_LOG(DEBUG) << "=============Do ReshapeInit end=============";
2401 }
2402 
2403 static CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
2404   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2405     MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the max recursive depth: "
2406                     << MAX_RECURSIVE_DEPTH;
2407     return nullptr;
2408   }
2409   // Handle return->depend->loss
2410   if (IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
2411       (IsPrimitiveCNode(cnode, prim::kPrimCast) && !cnode->has_user_data<OperatorInfo>())) {
2412     auto depend_before = cnode->input(1)->cast<CNodePtr>();
2413     MS_EXCEPTION_IF_NULL(depend_before);
2414     return HandleDependLoss(depend_before, ++curr_depth);
2415   }
2416   return cnode;
2417 }
2418 
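// Locates the loss cnode of a forward graph by walking back from the return node.
// Handled patterns: return -> loss; return -> tuple_getitem -> loss (records the dout
// index); return -> make_tuple (no single loss found); and the INVALID_LOSS_OPS set
// (e.g. GetNext), which is recorded but later skipped when splitting sens.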
2419 static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
2420   LossNodeInfo loss_node_info;
2421   MS_EXCEPTION_IF_NULL(func_graph);
2422   CNodePtr return_node = func_graph->get_return();
2423   MS_EXCEPTION_IF_NULL(return_node);
2424   if (return_node->size() < 2) {
2425     MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2";
2426   }
2427   auto pre_node_pair = GetRealKernelNode(return_node->input(1), -1, nullptr);
2428   auto pre_node = pre_node_pair.first;
2429   MS_EXCEPTION_IF_NULL(pre_node);
2430   auto pre_cnode = pre_node->cast<CNodePtr>();
2431 
2432   if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
2433     MS_LOG(DEBUG) << "The node before return is null or not a primitive cnode";
2434     return loss_node_info;
2435   }
2439   auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
2440   // Notice: the GetNext op has no input.
2441   if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
2442     MS_LOG(INFO) << "The loss is: " << current_prim->name();
2443     loss_node_info.loss_node = pre_cnode;
2444     return loss_node_info;
2445   }
2446 
2447   // return -> tuple_getitem -> loss
2448   if (pre_node_pair.second != -1) {
2449     loss_node_info.has_tuple_getitem = true;
2450     loss_node_info.dout_index = pre_node_pair.second;
2451     loss_node_info.loss_node = pre_cnode;
2452     return loss_node_info;
2453   }
2454 
2455   // return -> make_tuple
2456   if (current_prim->name() == MAKE_TUPLE) {
2457     return loss_node_info;
2458   }
2459 
2460   // return -> loss
2461   loss_node_info.loss_node = pre_cnode;
2462   MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
2463   return loss_node_info;
2464 }
2465 
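// Derives the layout that the sens (dout) tensor must take: looks up the loss
// operator's output tensor info at dout_index and returns its layout. The result is
// empty for invalid loss ops or a missing operator info; a multi-output loss without
// tuple_getitem is rejected since a tuple sens is not supported.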
2466 static TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
2467   TensorLayouts ret;
2468   auto loss_cnode = node_info.loss_node;
2469   MS_EXCEPTION_IF_NULL(loss_cnode);
2470 
2471   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
2472   MS_EXCEPTION_IF_NULL(prim_anf_node);
2473   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
2474   MS_EXCEPTION_IF_NULL(prim);
2475   if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
2476     MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
2477     return ret;
2478   }
2479 
2480   OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
2481   if (!operator_info) {
2482     return ret;
2483   }
2484   MS_EXCEPTION_IF_NULL(operator_info);
2485   TensorInfo loss_grad_tensor_info;
2486   size_t op_output_size = operator_info->outputs_tensor_info().size();
2487   MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", has_tuple_getitem is "
2488                << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is "
2489                << node_info.dout_index;
2490 
2491   if ((op_output_size == 0) || (op_output_size <= LongToSize(node_info.dout_index))) {
2492     MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size;
2493   }
2494 
2495   if (!node_info.has_tuple_getitem && (op_output_size > 1)) {
2496     MS_LOG(EXCEPTION) << "Currently, a tuple sens is not supported.";
2497   }
2498 
2499   loss_grad_tensor_info = operator_info->outputs_tensor_info()[LongToSize(node_info.dout_index)];
2500   ret.push_back(loss_grad_tensor_info.tensor_layout());
2501   return ret;
2502 }
2503 
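// Shards the sens tensor fed to the grad node so that it matches the loss layout.
// Cases handled below: a scalar/[1] sens only records the layout; a Parameter sens
// gets a sliced abstract shape plus the layout as user data; a CNode sens (static
// shape) gets redistribution ops inserted; a Tensor value is sliced in place with the
// _GetTensorSlice operator.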
2504 static void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
2505   MS_EXCEPTION_IF_NULL(grad_sens_node);
2506   if (grad_sens_node->size() <= 1) {
2507     MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
2508   }
2509   AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
2510   MS_EXCEPTION_IF_NULL(sens_tensor_node);
2511   Shapes sens_shapes = GetNodeShape(sens_tensor_node);
2512   if (sens_shapes.size() != 1) {
2513     MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
2514   }
2515   // If the shape of sens tensor is [] or [1], no need to split it.
2516   Shape sens_shape = sens_shapes[0];
2517   if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) {
2518     if (sens_tensor_node->isa<Parameter>()) {
2519       auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
2520       MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
2521       sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
2522     }
2523     MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
2524     return;
2525   }
2526   auto loss_shape = loss_grad_layout.tensor_shape().array();
2527   auto loss_tensor_map = loss_grad_layout.tensor_map_before();
2528   bool multi_split = std::any_of(loss_tensor_map.begin(), loss_tensor_map.end(),
2529                                  [](const auto &tensor_map) { return tensor_map.size() != 1; });
2530   if ((loss_shape != sens_shape) && !multi_split) {
2531     MS_LOG(EXCEPTION) << "The shape of sens is not equal to the loss output's, which is not supported yet. Sens shape is "
2532                       << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape);
2533   }
2534   MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it.";
2535 
2536   if (!IsValueNode<Tensor>(sens_tensor_node)) {
2537     if (sens_tensor_node->isa<Parameter>()) {
2538       MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
2539       AbstractBasePtr abstract = sens_tensor_node->abstract();
2540       MS_EXCEPTION_IF_NULL(abstract);
2541       auto slice_shape = loss_grad_layout.slice_shape().array();
2542       std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
2543       MS_EXCEPTION_IF_NULL(parallel_shape);
2544       auto cloned_abstract = abstract->Clone();
2545       MS_EXCEPTION_IF_NULL(cloned_abstract);
2546       cloned_abstract->set_shape(parallel_shape);
2547       sens_tensor_node->set_abstract(cloned_abstract);
2548       auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
2549       sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
2550       return;
2551     }
2552     bool is_dynamic = InDynamicGraph(sens_tensor_node->cast<CNodePtr>());
2553     if (sens_tensor_node->isa<CNode>() && !is_dynamic) {
2554       auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
2555       if (op_list_ptr == nullptr) {
2556         return;
2557       }
2558       auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
2559       auto func_graph = grad_sens_node->func_graph();
2560       MS_EXCEPTION_IF_NULL(func_graph);
2561       TensorRedistributionPtr tensor_redistribution = std::make_shared<TensorRedistribution>();
2562       InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode, tensor_redistribution);
2563       return;
2564     }
2565     if (is_dynamic) {
2566       return;
2567     }
2568     MS_LOG(EXCEPTION) << "The sens node is not a Tensor, Parameter or CNode, which is not supported yet.";
2569   }
2570 
2571   // Use _GetTensorSlice operator to split the sens tensor
2572   FuncGraphPtr func_graph = grad_sens_node->func_graph();  // only cnode can get the graph
2573   MS_EXCEPTION_IF_NULL(func_graph);
2574   Operator op = CreateGetTensorSliceOp(loss_grad_layout);
2575   InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
2576 }
2577 
2578 static void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
2579   MS_EXCEPTION_IF_NULL(distribute_operator);
2580   MS_EXCEPTION_IF_NULL(cnode);
2581   if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
2582     return;
2583   }
2584   OperatorVector forward_op = distribute_operator->forward_op();
2585   // For gmm, its MakeTuple will inherit its op info,
2586   // which would lead to inserting an AllReduce for the MakeTuple.
2587   if (!forward_op.empty() && !IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
2588     MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
2589     ForwardCommunication(forward_op, cnode);
2590   }
2591 }
2592 
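// Applies each operator's replacement to the graph: replace_op swaps the primitive
// sequence in place, replace_graph swaps in a whole subgraph (the two are mutually
// exclusive), and an interleaved-parallel Reshape additionally gets redistribution
// ops inserted for its micro-interleaved branches.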
2593 static void StepReplace(const std::vector<AnfNodePtr> &all_nodes) {
2594   for (auto &node : all_nodes) {
2595     MS_EXCEPTION_IF_NULL(node);
2596     if (node->isa<CNode>()) {
2597       auto cnode = node->cast<CNodePtr>();
2598       if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) ||
2599           IsSomePrimitive(cnode, SEND)) {
2600         continue;
2601       }
2602 
2603       OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2604       // StepReplace
2605       MS_EXCEPTION_IF_NULL(distribute_operator);
2606       auto replace_op = distribute_operator->replace_op();
2607       if (!replace_op.empty()) {
2608         MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString();
2609         StepReplaceOp(replace_op, cnode);
2610       }
2611 
2612       // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore.
2613       auto replace_graph = distribute_operator->replace_graph(cnode);
2614       if (!replace_op.empty() && replace_graph) {
2615         MS_LOG(EXCEPTION) << "Only one of replace_op or replace_graph can be used";
2616       }
2617       if (replace_graph) {
2618         MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString();
2619         StepReplaceGraph(replace_graph, cnode, distribute_operator);
2620       }
2621       if (distribute_operator->name().find(RESHAPEINFO) != std::string::npos) {
2622         auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(distribute_operator);
2623         if (!reshape_info->InterleavedParallel()) {
2624           continue;
2625         }
2626         auto reshape_redis = reshape_info->ReshapeRedistribution();
2627         InsertRedistributionForMicroInterleaved(reshape_redis, {cnode, 1}, cnode->func_graph(), cnode,
2628                                                 cnode->input(kIndex1)->cast<CNodePtr>());
2629         if (!IsPrimitiveCNode(cnode->input(kIndex1), prim::kPrimVirtualConverterEnd)) {
2630           continue;
2631         }
2632         auto virtual_converter_end = cnode->input(kIndex1)->cast<CNodePtr>();
2633         auto func_graph = cnode->func_graph();
2634         MS_EXCEPTION_IF_NULL(func_graph);
2635         auto manager = func_graph->manager();
2636         MS_EXCEPTION_IF_NULL(manager);
2637         manager->Replace(cnode, virtual_converter_end);
2638       }
2639     }
2640   }
2641 }
2642 
2643 static void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
2644   CNodePtr sens_node = sens_loss_pair.first;
2645   auto loss_node = sens_loss_pair.second;
2646   auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
2647   if (!loss_grad_layout.empty()) {
2648     SplitSens(sens_node, loss_grad_layout[0]);
2649   }
2650 }
2651 
2652 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
2653 static std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
2654   MS_EXCEPTION_IF_NULL(root);
2655   std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
2656   for (auto &node : root->nodes()) {
2657     if (!node->isa<CNode>()) {
2658       continue;
2659     }
2660 
2661     // cnode(sens)-->cnode(tuple_getitem)
2662     auto sens_cnode = node->cast<CNodePtr>();
2663     AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
2664     MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
2665     if (!expect_tuple_getitem->isa<CNode>()) {
2666       continue;
2667     }
2668 
2669     auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
2670     if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem->name())) {
2671       continue;
2672     }
2673 
2674     // cnode(sens)-->cnode(tuple_getitem)-->cnode
2675     AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
2676     MS_EXCEPTION_IF_NULL(expect_anonymous);
2677     if (!expect_anonymous->isa<CNode>()) {
2678       continue;
2679     }
2680 
2681     // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
2682     auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
2683     AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
2684     MS_EXCEPTION_IF_NULL(expect_j);
2685     if (!expect_j->isa<CNode>()) {
2686       continue;
2687     }
2688     auto expect_j_cnode = expect_j->cast<CNodePtr>();
2689     if (!IsSomePrimitive(expect_j_cnode, J)) {
2690       continue;
2691     }
2692 
2693     if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
2694       MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
2695     }
2696     auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
2697     auto loss_node_info = FindLossCNode(func_graph);
2698     if (loss_node_info.loss_node == nullptr) {
2699       MS_LOG(WARNING) << "Cannot find the loss cnode";
2700       continue;
2701     }
2702     std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
2703     sens_loss_pairs.push_back(sens_loss_pair);
2704   }
2705   return sens_loss_pairs;
2706 }
2707 
2708 static void HandleSens(const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
2709   // Splitting sens must happen before inserting the operators.
2710   for (auto &pair : sens_loss_pairs) {
2711     // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
2712     // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
2713     if (IsLastStage()) {
2714       StepSplitSens(pair);
2715     }
2716   }
2717   return;
2718 }
2719 
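// The main communication-insertion pass: splits sens first, then for every
// parallel-care cnode inserts forward communication ops, redistribution ops and
// backward (mirror) communication, splits tensor value nodes, and finally runs
// StepReplace over all nodes.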
2720 static void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
2721                                   const FuncGraphManagerPtr &manager) {
2722   MS_EXCEPTION_IF_NULL(root);
2723   MS_EXCEPTION_IF_NULL(manager);
2724 
2725   std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
2726   auto has_backward = HasBackward(root);
2727   // Splitting sens must happen before inserting the operators.
2728   HandleSens(sens_loss_pairs);
2729 
2730   const auto &node_users_map = manager->node_users();
2731   for (auto &node : all_nodes) {
2732     MS_EXCEPTION_IF_NULL(node);
2733     if (node->isa<CNode>()) {
2734       auto cnode = node->cast<CNodePtr>();
2735       if (IsValueNode<FuncGraph>(cnode->input(0))) {
2736         StepRedistribution(cnode, node_users_map);
2737         continue;
2738       }
2739       // The make_tuple is a parallel care node, but it may not have operator info.
2740       if ((!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) && !IsControlFlowNode(cnode)) {
2741         continue;
2742       }
2743       OperatorInfoPtr distribute_operator = nullptr;
2744       if (!IsControlFlowNode(cnode)) {
2745         distribute_operator = GetDistributeOperator(cnode);
2746         MS_EXCEPTION_IF_NULL(distribute_operator);
2747       }
2748 
2749       // skip Send Receive
2750       auto parallel_context = parallel::ParallelContext::GetInstance();
2751       MS_EXCEPTION_IF_NULL(parallel_context);
2752       auto is_pp_interleave = parallel_context->pipeline_interleave();
2753       if (!cnode->HasPrimalAttr(PIPELINE_PARAM) || is_pp_interleave) {
2754         // insert forward ops
2755         if (!IsControlFlowNode(cnode)) {
2756           InsertForwardOps(distribute_operator, cnode);
2757         }
2758 
2759         // insert redistribution ops
2760         StepRedistribution(cnode, node_users_map);
2761       }
2762       // insert backward ops
2763       if (!IsControlFlowNode(cnode) && (has_backward || IsPynativeParallel())) {
2764         BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
2765       }
2766       if (!IsControlFlowNode(cnode)) {
2767         distribute_operator->ReplaceNodeInputOrAttrs();
2768       }
2769     } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
2770       StepSplitTensor(node, manager);
2771     }
2772   }
2773   // StepReplace
2774   StepReplace(all_nodes);
2775 }
2776 
2777 static bool IsGatherInfo(const std::string &name) {
2778   std::vector<std::string> gather_info_names = {"GatherInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
2779   for (std::string info_name : gather_info_names) {
2780     if (name.find(info_name) != std::string::npos) {
2781       return true;
2782     }
2783   }
2784   return false;
2785 }
2786 
2787 void AssignStrategyMap(const StrategyPtr &stra, const std::string &strategy_key_name, StrategyMap *stra_map) {
2788   if (stra) {
2789     (*stra_map)[strategy_key_name] = stra;
2790   } else {
2791     Strategies new_stra_v;
2792     StrategyPtr new_stra = std::make_shared<Strategy>(g_device_manager->stage_id(), new_stra_v);
2793     (*stra_map)[strategy_key_name] = new_stra;
2794   }
2795 }
2796 
2797 void AssignManualShapeMapForGather(const OperatorInfoPtr &operator_info, const std::string &param_name,
2798                                    ManualShapeMap *manual_shape_map) {
2799   if (IsGatherInfo(operator_info->name())) {
2800     auto gather_info = std::dynamic_pointer_cast<GatherInfo>(operator_info);
2801     auto param_split_shapes = gather_info->param_split_shapes();
2802     auto index_offsets = gather_info->index_offsets();
2803     if (param_split_shapes.size() != index_offsets.size()) {
2804       MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
2805     }
2806     std::vector<std::pair<int64_t, int64_t>> manual_shape;
2807     for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
2808       (void)manual_shape.emplace_back(std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
2809     }
2810     (*manual_shape_map)[param_name] = manual_shape;
2811   }
2812 }
2813 
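// Persists the parallel strategies into the strategy checkpoint: for every operator
// that touches a parameter, saves its strategy (for Reshape, the recorded input shard
// strategy), each parameter's TensorLayout, and the manual-split shapes of Gather-like
// ops; layouts of cloned parameters are saved as well.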
2814 static void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
2815   if (!StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
2816     return;
2817   }
2818 
2819   StrategyMap stra_map;
2820   TensorInfoMap tensor_info_map;
2821   ManualShapeMap manual_shape_map;
2822   for (auto &node : all_nodes) {
2823     MS_EXCEPTION_IF_NULL(node);
2824     auto cnode = node->cast<CNodePtr>();
2825     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2826       continue;
2827     }
2828     auto param_names = NodeParameterName(cnode, -1, 0);
2829     if (param_names.empty()) {
2830       continue;
2831     }
2832     string param_name = param_names[0].first;
2833     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2834     MS_EXCEPTION_IF_NULL(prim);
2835     OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
2836     if (operator_info) {
2837       std::string strategy_key_name = prim->name() + "_" + param_name;
2838       StrategyPtr stra;
2839       if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
2840         auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2841         stra = reshape_info->get_input_shard_strategy();
2842         if (stra == nullptr) {
2843           MS_LOG(INFO) << "Reshape has no input strategy, skipped";
2844           continue;
2845         }
2846       } else {
2847         stra = operator_info->strategy();
2848       }
2849       AssignStrategyMap(stra, strategy_key_name, &stra_map);
2850 
2851       for (auto param_name_pair : param_names) {
2852         tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
2853       }
2854       AssignManualShapeMapForGather(operator_info, param_name, &manual_shape_map);
2855     }
2856   }
2857   for (auto &cloned_parameter_node : root->parameters()) {
2858     MS_EXCEPTION_IF_NULL(cloned_parameter_node);
2859     auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
2860     MS_EXCEPTION_IF_NULL(cloned_parameter);
2861 
2862     if (!ParameterIsCloned(cloned_parameter_node) && !IsStrategySaved(cloned_parameter_node)) {
2863       continue;
2864     }
2865     std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
2866     auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
2867     if (cloned_param_layout == nullptr) {
2868       continue;
2869     }
2870     tensor_info_map[cloned_param_name] = cloned_param_layout;
2871   }
2872   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, manual_shape_map) != SUCCESS) {
2873     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
2874   }
2875 }
2876 
2877 static void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
2878   for (auto &node : all_nodes) {
2879     MS_EXCEPTION_IF_NULL(node);
2880     if (!node->isa<CNode>()) {
2881       continue;
2882     }
2883     auto cnode = node->cast<CNodePtr>();
2884     if (!IsValueNode<Primitive>(cnode->input(0))) {
2885       continue;
2886     }
2887 
2888     // CNode is globally unique.
2889     MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << ".";
2890     cnode->set_in_forward_flag(true);
2891   }
2892 }
2893 
2894 static void SetForwardFlag(const AnfNodeSet &all_nodes) {
2895   for (auto &node : all_nodes) {
2896     MS_EXCEPTION_IF_NULL(node);
2897     if (!node->isa<CNode>()) {
2898       continue;
2899     }
2900     auto cnode = node->cast<CNodePtr>();
2901     if (!IsValueNode<Primitive>(cnode->input(0))) {
2902       continue;
2903     }
2904 
2905     // CNode is globally unique.
2906     cnode->set_in_forward_flag(true);
2907   }
2908 }
2909 
2910 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
2911   MS_EXCEPTION_IF_NULL(root);
2912   auto ret = root->get_return();
2913   MS_EXCEPTION_IF_NULL(ret);
2914   auto all_nodes = TopoSort(ret, SuccDeeperSimple);
2915   std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
2916   return graph_set;
2917 }
2918 
2919 static std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph,
2920                                                     const std::vector<AnfNodePtr> &all_nodes) {
2921   MS_EXCEPTION_IF_NULL(graph);
2922   std::vector<AnfNodePtr> root_forward_nodes;
2923   auto loss_cnode = FindLossCNode(graph).loss_node;
2924   if (loss_cnode == nullptr) {
2925     return root_forward_nodes;
2926   }
2927 
2928   auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
2929   for (auto &node : all_nodes) {
2930     MS_EXCEPTION_IF_NULL(node);
2931     if (!node->isa<CNode>()) {
2932       continue;
2933     }
2934     auto cnode = node->cast<CNodePtr>();
2935     auto root_node_id = node->UniqueIdThroughCopy();
2936     if (loss_cnode_id == root_node_id) {
2937       root_forward_nodes = DeepLinkedGraphSearch(cnode);
2938       break;
2939     }
2940   }
2941   return root_forward_nodes;
2942 }
2943 
2944 static void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
2945   // shape op doesn't have params and attrs.
2946   OperatorParams params;
2947   OperatorAttrs attrs;
2948   auto shape_value = GetValueNode(node->input(2))->cast<ValueSequencePtr>();
2949   MS_EXCEPTION_IF_NULL(shape_value);
2950   auto shape = shape_value->value();
2951   if (shape.empty()) {
2952     return;
2953   }
2954   OperatorArgs args = std::make_pair(attrs, params);
2955   Operator op = std::make_pair(SHAPE_OP, args);
2956   InsertNode(op, node, 2, pre_node, root, "shape");
2957 }
2958 
2959 static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
2960   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2961     MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
2962     return nullptr;
2963   }
2964   for (auto &node : cnode->inputs()) {
2965     if (!node->isa<CNode>()) {
2966       continue;
2967     }
2968     if (!IsPrimitiveCNode(node, prim::kPrimEnvironGet)) {
2969       return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
2970     } else {
2971       return node;
2972     }
2973   }
2974   return nullptr;
2975 }
2976 
2977 static void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
2978   // If the root graph has a reshape op, find the corresponding parameter:
2979   // Reshape's shape is the shape of the parameter.
2980   auto executor = pipeline::GraphExecutorPy::GetInstance();
2981   for (auto &node : all_nodes) {
2982     if (!node->isa<CNode>()) {
2983       continue;
2984     }
2985     auto cnode = node->cast<CNodePtr>();
2986     if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
2987       continue;
2988     }
2989     if (cnode->in_forward_flag()) {
2990       // Save strategy in executor
2991       OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>();
2992       if (op_info) {
2993         auto stra_ptr = op_info->strategy();
2994         if (stra_ptr) {
2995           auto strategy = stra_ptr->GetInputDim();
2996           // fullname with scope should be found in step parallel end ir
2997           executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy);
2998         }
2999       }
3000       continue;
3001     }
3002 
3003     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
3004     if (prim->name() != RESHAPE) {
3005       continue;
3006     }
3007 
3008     Shape origin_dst_shape = GetValue<std::vector<int64_t>>(cnode->input(2)->cast<ValueNodePtr>()->value());
3009     if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
3010       continue;
3011     }
3012     auto root = node->func_graph();
3013     auto grad_node = FindGrad(cnode, 0);
3014     if (grad_node) {
3015       InsertShapeOp(cnode, grad_node, root);
3016     }
3017   }
3018 }
3019 
3020 void MarkForwardCNode(const FuncGraphPtr &root) {
3021   MS_EXCEPTION_IF_NULL(root);
3022   auto ret = root->get_return();
3023   MS_EXCEPTION_IF_NULL(ret);
3024   auto all_nodes = TopoSort(ret, SuccDeeperSimple);
3025   auto graph_set = FindForwardGraphByRootNodes(all_nodes);
3026 
3027   if (graph_set.empty()) {
3028     MS_LOG(INFO) << "Cannot find the forward graph, so mark the ops in the root graph";
3029     auto fgs = root->manager()->func_graphs();
3030     for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
3031       SetForwardFlag((*fg)->nodes());
3032     }
3033   } else {
3034     for (auto func_graph = graph_set.cbegin(); func_graph != graph_set.cend(); ++func_graph) {
3035       MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
3036       auto return_node = (*func_graph)->get_return();
3037       MS_EXCEPTION_IF_NULL(return_node);
3038       auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
3039       SetForwardFlag(all_dfs_nodes);
3040       auto root_forward_nodes = FindRootForwardCNode(*func_graph, all_nodes);
3041       if (root_forward_nodes.empty()) {
3042         continue;
3043       }
3044       // Mark forward flag for the nodes in root graph.
3045       SetForwardFlag(root_forward_nodes);
3046     }
3047   }
3048 }
3049 
3050 OperatorInfoPtr set_make_list_for_ifa(CNodePtr make_list, const CNodePtr &next_node) {
3051   ValueNodePtr anf_node = next_node->input(0)->cast<ValueNodePtr>();
3052   if (!anf_node) {
3053     return nullptr;
3054   }
3055   PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
3056   if (!prim) {
3057     return nullptr;
3058   }
3059   if (prim->name() != INCRE_FLASH_ATTENTION) {
3060     return nullptr;
3061   }
3062 
3063   int kv_index = 1;
3064   OperatorInfoPtr operator_make_list = CreateOperatorInfo(make_list);
3065   auto make_list_prim = GetValueNode<PrimitivePtr>(make_list->input(0));
3066   if (make_list_prim->HasAttr(STAND_ALONE)) {
3067     (void)make_list_prim->DelAttr(STAND_ALONE);
3068   }
3069   OperatorInfoPtr next_operator = next_node->user_data<OperatorInfo>();
3070   StrategyPtr next_node_strategy = next_operator->strategy();
3071   Strategies key_value_strategies;
3072   Dimensions key_value_dim = next_node_strategy->GetInputDim().at(kv_index);
3073   key_value_strategies.push_back(key_value_dim);
3074   auto make_list_stage = next_node_strategy->GetInputStage();
3075   auto make_list_new_in_stra = NewStrategy(make_list_stage, key_value_strategies);
3076   operator_make_list->set_strategy(make_list_new_in_stra);
3077 
3078   std::vector<TensorInfo> kv_in_tensor_info(1, next_operator->inputs_tensor_info()[kv_index]);
3079   operator_make_list->set_inputs_tensor_info(kv_in_tensor_info);
3080   return operator_make_list;
3081 }
3082 
3083 static void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes) {
3084   for (auto &node : all_nodes) {
3085     if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) {
3086       continue;
3087     }
3088 
3089     auto cnode = node->cast<CNodePtr>();
3090     MS_EXCEPTION_IF_NULL(cnode);
3091     if (!cnode->in_forward_flag()) {
3092       continue;
3093     }
3094 
3095     FuncGraphManagerPtr manager = cnode->func_graph()->manager();
3096     MS_EXCEPTION_IF_NULL(manager);
3097 
3098     // MakeTuple has multiple users, and each user's TensorInfo must be the same.
3099     auto make_tuple_list_next_node = CheckMakeTupleSplit(node, manager);
3100     if (make_tuple_list_next_node == nullptr) {
3101       continue;
3102     }
3103     auto make_tuple_list_next_cnode = make_tuple_list_next_node->cast<CNodePtr>();
3104     MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode);
3105     if (!IsSomePrimitiveList(make_tuple_list_next_cnode, INPUT_IS_TUPLE_OR_LIST_OPS)) {
3106       continue;
3107     }
3108 
3109     OperatorInfoPtr op_info = set_make_list_for_ifa(cnode, make_tuple_list_next_cnode);
3110     if (op_info == nullptr) {
3111       op_info = GetDistributeOperator(make_tuple_list_next_cnode);
3112     }
3113     MS_EXCEPTION_IF_NULL(op_info);
3114     cnode->set_user_data<OperatorInfo>(op_info);
3115   }
3116 }
3117 
3118 bool CreateGroupsByCkptFile(const std::string &file) {
3119   GroupInfoMap group_info_map;
3120   if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
3121     return false;
3122   }
3123 
3124   if (CreateGroups(group_info_map) != SUCCESS) {
3125     return false;
3126   }
3127   MS_LOG(INFO) << "Create groups by checkpoint file success";
3128   return true;
3129 }
3130 
3131 static void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager,
3132                                     int64_t pipeline_stages) {
3133   auto parallel_context = parallel::ParallelContext::GetInstance();
3134   MS_EXCEPTION_IF_NULL(parallel_context);
3135   auto is_pp_interleave = parallel_context->pipeline_interleave();
3136   if (is_pp_interleave) {
3137     return;
3138   }
3139   if (!root->has_flag(kSkipAutoParallelCompile) && !root->has_flag(BACKWARD) && pipeline_stages > 1) {
3140     root->set_flag(BACKWARD, true);
3141     if (IsTraining(manager)) {
3142       if (parallel_context->enable_fold_pipeline()) {
3143         MS_LOG(INFO) << "Begin Fold Pipeline Reorder. ";
3144         FoldPipelineReorder(root);
3145       } else {
3146         Reorder(root);
3147       }
3148     } else {
3149       ReorderForPredict(root, manager);
3150     }
3151   }
3152 }
3153 
3154 static void ReorderForGradAccumulation(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
3155   if (!root->has_flag(kSkipAutoParallelCompile) && !root->has_flag(BACKWARD) &&
3156       ParallelContext::GetInstance()->grad_accumulation_step() > 1) {
3157     root->set_flag(BACKWARD, true);
3158     auto context = MsContext::GetInstance();
3159     MS_EXCEPTION_IF_NULL(context);
3160     const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
3161     DumpGraph(root, "before_reorder");
3162     if (IsTraining(manager)) {
3163       if (cell_reuse) {
3164         TagMicroBatchBpEndInCellShare(root, manager);
3165       }
3166       std::unordered_map<int64_t, std::vector<CNodePtr>> forward_start;
3167       std::unordered_map<int64_t, std::vector<CNodePtr>> backward_end;
3168       ExtractMicroBatchBorderNodes(root, &forward_start, &backward_end);
3169       ReorderGradAccumulation(root, forward_start, backward_end);
3170       DumpGraph(root, "after_reorder");
3171     } else {
3172       MS_LOG(EXCEPTION) << "Predict with grad accumulation is not supported currently";
3173     }
3174   }
3175 }
3176 
3177 static void HandleDataParallel() {
3178   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
3179   if (parallel_mode == kDataParallel) {
3180     auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
3181     if (!group_info_save_path.empty()) {
3182       std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info;
3183       int64_t device_num = GetCommInfo().device_num;
3184       RankList comm_group;
3185       for (size_t i = 0; i < size_t(device_num); ++i) {
3186         comm_group.push_back(i);
3187       }
3188       ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
3189       if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
3190         MS_LOG(EXCEPTION) << "Save group info failed";
3191       }
3192     }
3193   }
3194 }
3195 
3196 static void MicroBatchPreProcess(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager,
3197                                  const std::vector<AnfNodePtr> &all_nodes) {
3198   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3199   if (pipeline_stages > 1) {
3200     HandleMicroBatch(all_nodes, manager);
3201     ParameterStartNode(all_nodes, manager);
3202     LastStageEndNode(all_nodes, manager, root);
3203     return;
3204   }
3205   TagMicroBatchStart(manager, all_nodes);
3206   TagMicroBatchEnd(manager, all_nodes);
3207   auto context = MsContext::GetInstance();
3208   MS_EXCEPTION_IF_NULL(context);
3209   const auto no_cell_reuse = context->CellReuseLevel() == CellReuseLevel::kNoCellReuse;
3210   bool enable_grad_accu = ParallelContext::GetInstance()->grad_accumulation_step() > 1;
3211   if (no_cell_reuse && enable_grad_accu) {
3212     TagMicroBatchBpEndPrim(root);
3213     TagMicroBatchBpEnd(root);
3214   }
3215 }
3216 
3217 static void MicroBatchPostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
3218   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3219   if (pipeline_stages > 1) {
3220     AddVirtualAssignAdd(root);
3221     HandleReceiveParam(root);
3222     LabelGenMaskMicro(root);
3223     return;
3224   }
3225   if (ParallelContext::GetInstance()->grad_accumulation_step() > 1) {
3226     AddVirtualAssignAdd(root);
3227     LabelGenMaskMicro(root);
3228   }
3229 }
3230 
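// Starting from the expand_dims node of a per-parameter squared norm, walks the users
// up to the sqrt node (at most MAX_BFS_DEPTH hops) and inserts an AllReduce over the
// devices of the current stage, plus a second AllReduce across stages under pipeline
// parallelism, so the global norm sums the contributions of all ranks.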
3231 static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
3232   auto cnode = res_node->cast<CNodePtr>();
3233   auto graphs = res_node->func_graph();
3234   MS_EXCEPTION_IF_NULL(graphs);
3235   auto manager = graphs->manager();
3236   MS_EXCEPTION_IF_NULL(manager);
3237   auto &node_user_map = manager->node_users();
3238   if (!IsSomePrimitive(cnode, EXPAND_DIMS)) {
3239     MS_LOG(ERROR) << "Expected the operator expand_dims, but found " << GetPrimName(cnode)
3240                   << ". This may cause the calculation of the global norm to be incorrect";
3241     return;
3242   }
3243   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3244   auto find_node = res_node;
3245   uint32_t limits = 0;
3246   while (!IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT) && limits < MAX_BFS_DEPTH) {
3247     auto users = node_user_map.at(find_node);
3248     if (users.empty()) {
3249       return;
3250     }
3251     find_node = users.front().first;
3252     ++limits;
3253   }
3254   if (!find_node || !IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT)) {
3255     return;
3256   }
3257   auto anf_node = find_node->cast<CNodePtr>();
3258   if (anf_node->size() > 1 && IsSomePrimitive(anf_node->input(1)->cast<CNodePtr>(), ALL_REDUCE)) {
3259     return;
3260   }
3261   auto sqrt_node = find_node;
3262   auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
3263   Group cur_stage_device_list;
3264   if (g_device_manager->CreateGroup(cur_stage_rank_list, &cur_stage_device_list) != SUCCESS) {
3265     MS_LOG(EXCEPTION) << "Create the communication group for allreduce in calculating global norm failed, "
3266                          "the rank_list is: "
3267                       << cur_stage_rank_list;
3268   }
3269   InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), cur_stage_device_list.name(), PARALLEL_GLOBALNORM);
3270   MS_LOG(INFO) << "Inserting the AllReduce for the global norm value within the stage succeeded.";
3271   if (pipeline_stages > 1) {
3272     MS_LOG(INFO) << "Insert the AllReduce for the global norm value between stages.";
3273     auto ranks_between_stages = g_device_manager->GetDeviceListBetweenStage();
3274     Group group_between_stages;
3275     if (g_device_manager->CreateGroup(ranks_between_stages, &group_between_stages) != SUCCESS) {
3276       MS_LOG(EXCEPTION) << "Create the communication group for allreduce in calculating global norm "
3277                            "with pipeline parallel failed, the rank_list is: "
3278                         << cur_stage_rank_list;
3279     }
3280     InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), group_between_stages.name(), PARALLEL_GLOBALNORM_BETWEEN);
3281   }
3282 }
3283 
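// Breadth-first search through the gradient-scaling subgraph (EnvironGet, Mul, Square,
// ReduceSum, Depend, Cast, Load, ...) for an ExpandDims node carrying the GRAD_SCALE
// attribute, giving up after `limits` BFS levels.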
3284 static AnfNodePtr FindExpandDimsWithGradScale(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map,
3285                                               uint32_t limits) {
3286   std::queue<AnfNodePtr> visited;
3287   AnfNodePtr queue_node = nullptr;
3288   CNodePtr cnode = nullptr;
3289   AnfNodePtr last_node = nullptr;
3290   uint32_t depth = 0;
3291   if (!node_ptr) {
3292     return nullptr;
3293   }
3294   visited.push(node_ptr);
3295   while (!visited.empty()) {
3296     queue_node = visited.front();
3297     visited.pop();
3298     cnode = queue_node->cast<CNodePtr>();
3299     // MAKE_TUPLE will not appear after the load in the forward graph
3300     if (IsSomePrimitive(cnode, EXPAND_DIMS)) {
3301       auto value = GetAttrsFromAnfNode(queue_node, GRAD_SCALE);
3302       if (!value || !GetValue<bool>(value)) {
3303         continue;
3304       }
3305       return queue_node;
3306     }
3307     if (!IsSomePrimitiveList(
3308           cnode, {ENVIRONGET, MUL, SQUARE, REDUCE_SUM, EXPAND_DIMS, DEPEND, CAST, REF_TO_EMBED, EMBED, LOAD})) {
3309       continue;
3310     }
3311     auto node_set = node_users_map.at(queue_node);
3312     for (auto &node_user : node_set) {
3313       visited.push(node_user.first);
3314     }
3315     if (!last_node || last_node == queue_node) {
3316       if (++depth == limits) {
3317         break;
3318       }
3319       last_node = visited.back();
3320     }
3321   }
3322   return nullptr;
3323 }
3324 
3325 static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter,
3326                                          uint32_t dev_num) {
3327   auto params_user_set = node_user_map.at(parameter);
3328   for (auto &param_pair : params_user_set) {
3329     auto cnode = param_pair.first->cast<CNodePtr>();
3330     MS_EXCEPTION_IF_NULL(cnode);
3331     if (cnode->in_forward_flag()) {
3332       continue;
3333     }
3334     constexpr size_t bfs_depth = 10;
3335     auto expand_dims_node = FindExpandDimsWithGradScale(cnode, node_user_map, bfs_depth);
3336     if (!expand_dims_node) {
3337       continue;
3338     }
3339     auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
3340     if (!value || !GetValue<bool>(value)) {
3341       continue;
3342     }
3343     if (dev_num > 0) {
3344       InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
3345       MS_LOG(INFO) << "Inserting the realdiv with " << dev_num << " for the parameter "
3346                    << parameter->fullname_with_scope() << " succeeded!";
3347     }
3348     // If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted.
3349     InsertAllReduceForNormValue(expand_dims_node);
3350   }
3351 }
3352 
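// Finds the mirror operator attached to a parameter's forward users, skipping trivial
// nodes (Load, MicroStepAllGather, AllGather, ...) and returning the first
// Mirror/MirrorMicroStep/MirrorMiniStep user encountered.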
3353 static AnfNodePtr GetMirrorOp(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter) {
3354   auto params_user_set = node_user_map.at(parameter);
3355   for (auto &param_pair : params_user_set) {
3356     auto cnode = param_pair.first->cast<CNodePtr>();
3357     std::vector<AnfNodePtr> candidate = {cnode};
3358     if (!cnode->in_forward_flag()) {
3359       continue;
3360     }
3361     while (IsInTrivialNodeList(cnode) || IsSomePrimitive(cnode, LOAD) ||
3362            IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) || IsPrimitiveCNode(cnode, prim::kPrimAllGather)) {
3363       auto load_users = node_user_map.at(cnode);
3364       cnode = node_user_map.at(cnode).front().first->cast<CNodePtr>();
3365       MS_EXCEPTION_IF_NULL(cnode);
3366       (void)std::transform(load_users.begin(), load_users.end(), std::back_inserter(candidate),
3367                            [](const auto &v) { return v.first; });
3368     }
3369     for (auto &node : candidate) {
3370       auto local_cnode = node->cast<CNodePtr>();
3371       if (!IsPrimitiveCNode(local_cnode, prim::kPrimMirror) &&
3372           !IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMicroStep) &&
3373           !IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMiniStep)) {
3374         continue;
3375       }
3376       return node;
3377     }
3378   }
3379   return nullptr;
3380 }
3381 
3382 static void HandleGlobalNormScale(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
3383   auto parameters = root->parameters();
3384   const auto &node_user_map = manager->node_users();
3385   MS_LOG(INFO) << "Start to process the global norm";
3386 
3387   for (auto &parameter : parameters) {
3388     int64_t dev_num = 0;
3389     if (!ParameterRequireGrad(parameter)) {
3390       continue;
3391     }
3392     auto mirror_node = GetMirrorOp(node_user_map, parameter);
3393     auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
3394     if (device_num_ptr && device_num_ptr->isa<Int64Imm>()) {
3395       dev_num = GetValue<int64_t>(device_num_ptr);
3396     }
3397     InsertDivAndAllReduceForNorm(node_user_map, parameter, LongToUint(dev_num));
3398   }
3399 }
3400 
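// If a MirrorMicroStep sits inside a called sub-graph and every call site
// feeds the same node into the corresponding parameter, hoist the mirror out
// to the caller and remove it from the sub-graph, so the argument is mirrored
// once instead of per call.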
static void MoveMicroMirrorOutCallFunc(const FuncGraphPtr &root) {
  AnfNodePtr ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(ret_after);
  auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep)) {
      continue;
    }
    auto micro_mirror = node->cast<CNodePtr>();
    auto param_anf_node = GetInputNodeWithFilter(micro_mirror, [&](const CNodePtr &cnode) {
      bool filter = IsPrimitiveCNode(cnode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
                    IsPrimitiveCNode(cnode, prim::kPrimDepend);
      return std::make_pair(filter, 1);
    });
    if (!param_anf_node->isa<Parameter>()) {
      continue;
    }
    auto param = param_anf_node->cast<ParameterPtr>();
    if (param->has_default()) {
      continue;
    }
    auto sub_func_graph = param_anf_node->func_graph();
    MS_EXCEPTION_IF_NULL(sub_func_graph);
    auto call_cnodes_map = sub_func_graph->func_graph_cnodes_index();
    auto sub_graph_parameters = sub_func_graph->parameters();
    auto curr_param_iter = std::find(sub_graph_parameters.begin(), sub_graph_parameters.end(), param_anf_node);
    if (curr_param_iter == sub_graph_parameters.end()) {
      MS_LOG(EXCEPTION) << "Cannot find param " << param_anf_node->DebugString() << " in current sub_graph";
    }
    size_t curr_param_index = static_cast<size_t>(curr_param_iter - sub_graph_parameters.begin());
    AnfNodePtr call_nodes_common_param_input = nullptr;
    FuncGraphPtr call_nodes_func_graph = nullptr;
    // Check that every call site passes the same node for this parameter.
    for (const auto &node_pair : call_cnodes_map) {
      if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
        continue;
      }
      auto cnode = node_pair.first->first->cast<CNodePtr>();
      call_nodes_func_graph = cnode->func_graph();
      auto cnode_input = cnode->input(curr_param_index + 1);
      if (!call_nodes_common_param_input) {
        call_nodes_common_param_input = cnode_input;
      }
      if (call_nodes_common_param_input != cnode_input) {
        call_nodes_common_param_input = nullptr;
        break;
      }
    }
    if (!call_nodes_common_param_input || !call_nodes_func_graph) {
      continue;
    }
    // Insert a new MirrorMicroStep in the caller graph.
    if (!IsPrimitiveCNode(call_nodes_common_param_input, prim::kPrimMirrorMicroStep)) {
      auto new_mirror_node =
        NewMicroMirrorPrimByMicroMirror(call_nodes_func_graph, micro_mirror, call_nodes_common_param_input);
      for (const auto &node_pair : call_cnodes_map) {
        if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
          continue;
        }
        manager->SetEdge(node_pair.first->first, curr_param_index + 1, new_mirror_node);
      }
    }

    // Remove the MirrorMicroStep inside the called graph.
    (void)manager->Replace(micro_mirror, micro_mirror->input(kIndex1));
  }
}

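// A parameter shared by several users may end up with one MirrorMicroStep per
// use; group the mirrors by parameter and replace all extras with the first
// one so each parameter is mirrored only once.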
static void MergeMicroMirrorForSharedParameter(const FuncGraphPtr &root) {
  AnfNodePtr ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(ret_after);
  auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::unordered_map<ParameterPtr, std::vector<CNodePtr>> param_mirror_map;
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep)) {
      continue;
    }
    auto micro_mirror = node->cast<CNodePtr>();
    auto param_anf_node = GetInputNodeWithFilter(micro_mirror, [&](const CNodePtr &cnode) {
      bool filter = IsPrimitiveCNode(cnode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
                    IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
                    IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather);
      return std::make_pair(filter, 1);
    });
    if (!param_anf_node->isa<Parameter>()) {
      continue;
    }
    auto param = param_anf_node->cast<ParameterPtr>();
    param_mirror_map[param].push_back(micro_mirror);
  }
  for (const auto &parm_pair : param_mirror_map) {
    if (parm_pair.second.size() <= 1) {
      continue;
    }
    MS_LOG(INFO) << "Parameter " << parm_pair.first->name()
                 << " still has multiple mirror users; merging those mirrors.";
    auto mirror0 = parm_pair.second.front();
    for (size_t i = 1; i < parm_pair.second.size(); ++i) {
      (void)manager->Replace(parm_pair.second[i], mirror0);
    }
  }
}

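// Broadcast a tuple output across pipeline stages: unpack the tuple with
// TupleGetItem, AllReduce (sum) each element over `group`, chain elements
// with Depend to keep ordering and defeat CSE, then repack with MakeTuple.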
static void BroadcastMultiOutputs(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, const Group &group) {
  auto output = root->get_return()->input(1)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(output);
  auto output_abstract = output->abstract();
  MS_EXCEPTION_IF_NULL(output_abstract);
  auto abstract_tuple = output_abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  auto abstract_list = abstract_tuple->elements();

  AnfNodePtrList make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  for (size_t i = 0; i < abstract_list.size(); i++) {
    auto abstract = abstract_list[i];
    MS_EXCEPTION_IF_NULL(abstract);

    // TupleGetItem
    auto idx = NewValueNode(SizeToLong(i));
    CNodePtr tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    tuple_getitem->set_abstract(abstract);

    // Depend: prevent disorder and CSE
    if (i > 0) {
      tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimDepend), tuple_getitem, make_tuple_input[i]});
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      tuple_getitem->set_abstract(abstract);
    }

    // AllReduce
    CNodePtr allreduce = root->NewCNode({NewValueNode(prim::kPrimAllReduce), tuple_getitem});
    MS_EXCEPTION_IF_NULL(allreduce);
    allreduce->set_abstract(abstract);
    common::AnfAlgo::SetNodeAttr(OP, MakeValue(REDUCE_OP_SUM), allreduce);
    common::AnfAlgo::SetNodeAttr(GROUP, MakeValue(group.name()), allreduce);
    // Disable GE allreduce fusion.
    common::AnfAlgo::SetNodeAttr(FUSION, MakeValue(static_cast<int64_t>(0)), allreduce);

    make_tuple_input.push_back(allreduce);
  }

  CNodePtr make_tuple_node = root->NewCNode(make_tuple_input);
  MS_EXCEPTION_IF_NULL(make_tuple_node);
  make_tuple_node->set_abstract(abstract_tuple);
  (void)manager->Replace(output, make_tuple_node);
}

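// In inference, when pipeline parallelism is enabled and result broadcast is
// requested, create a communication group spanning all stages and AllReduce
// the final output so every stage holds the result.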
static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  auto pipeline_result_broadcast = parallel::ParallelContext::GetInstance()->pipeline_result_broadcast();
  if (IsTraining(manager) || stage_num <= 1 || !pipeline_result_broadcast) {
    return;
  }

  std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage();
  Group group;
  if (g_device_manager->CreateGroup(rank_list, &group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create communication group between all pipeline stages failed, the rank_list is: "
                      << rank_list;
  }

  auto return_node = root->get_return();
  const auto &abstract = return_node->abstract();
  if (abstract->isa<abstract::AbstractTuple>()) {
    return BroadcastMultiOutputs(root, manager, group);
  }

  InsertAllReduceToNodeInput(return_node, group.name(), PARALLEL_RESULT_BROADCAST);
  return_node->input(1)->set_abstract(abstract);
}

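// Record the original input/output shapes of FLOPs-relevant ops (Conv2D,
// MatMul, BatchMatMul, FlashAttentionScore) as primal attrs, so that FLOPs
// statistics can later be computed from the unsliced shapes.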
static void RecordFlopsOriginShape(const FuncGraphManagerPtr &mng) {
  for (const auto &each_graph : mng->func_graphs()) {
    std::list<CNodePtr> graph_orders = each_graph->GetOrderedCnodes();
    std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
    for (const auto &node : origin_nodes_topological) {
      if (IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimBatchMatMul) ||
          IsPrimitiveCNode(node, prim::kPrimMatMul)) {
        node->AddPrimalAttr(kAttrOriginOutputShape, MakeValue(node->abstract()->GetShapeTrack()->GetShapeVector()));
        node->AddPrimalAttr(
          kAttrOriginInputShapes,
          MakeValue<std::vector<ShapeVector>>({node->input(kIndex1)->abstract()->GetShapeTrack()->GetShapeVector(),
                                               node->input(kIndex2)->abstract()->GetShapeTrack()->GetShapeVector()}));
      } else if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
        node->AddPrimalAttr(
          kAttrOriginInputShapes,
          MakeValue<std::vector<ShapeVector>>({node->input(kIndex1)->abstract()->GetShapeTrack()->GetShapeVector(),
                                               node->input(kIndex2)->abstract()->GetShapeTrack()->GetShapeVector()}));
      }
    }
  }
}

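// Return true if any input of a VirtualDataset node has a dynamic dimension
// (-1 in its shape vector).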
bool IsVirtualDatasetDynamicShape(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto all_nodes = TopoSort(func_graph->get_return());
  for (const auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr) {
      continue;
    }
    if (prim->name() == VIRTUAL_DATA_SET) {
      MS_LOG(INFO) << "VIRTUAL_DATA_SET: " << cnode->DebugString();
      for (size_t i = 1; i < cnode->inputs().size(); ++i) {
        auto input_node = cnode->input(i);
        auto base_shape = input_node->Shape();
        MS_EXCEPTION_IF_NULL(base_shape);
        std::vector<int64_t> shape_vec = base_shape->GetShapeVector();
        MS_LOG(INFO) << "VIRTUAL_DATA_SET: " << node->fullname_with_scope() << ", shape:" << shape_vec;
        if (std::find(shape_vec.begin(), shape_vec.end(), -1) != shape_vec.end()) {
          return true;
        }
      }
    }
  }
  return false;
}

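// When the NPU_ASD_ENABLE environment variable is "1", run the silent-check
// pass: fetch the loss scale and rewrite the graph's silent-check ops.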
static void HandleSilentCheck(const FuncGraphPtr &root, const FuncGraphManagerPtr &mng) {
  auto env = common::GetEnv(NPU_ASD_ENABLE);
  if (env != kSilentCheckEnvEnable) {
    return;
  }
  auto sdc = std::make_shared<SilentCheck>(root, mng);
  if (sdc == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to create the SilentCheck instance.";
  }
  sdc->GetLossScale();
  sdc->ModifySilentCheckOps();
}

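// The core parallel rewrite: initialize Reshape ops, check strategy
// consistency, slice shapes, insert forward/backward communication, run
// pipeline post-processing, and save strategies as a checkpoint.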
static void ParallelPartProcess(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root,
                                const FuncGraphManagerPtr &manager) {
  ReshapeInit(all_nodes);

  SetCastForParamNotRecompute(all_nodes);

  HandleRootReshapeAndSaveStrategy(all_nodes);

  HandleForwardMakeTupleAndMakeList(all_nodes);

  // If an input or parameter has multiple users, check whether its split strategies are consistent.
  CheckParameterSplit(all_nodes);

  HandleSymbolicKeyInstance(root, all_nodes);

  // Cover the parallel slice shapes.
  CoverSliceShape(root);

  // Handle parameters that are not used.
  HandleNoUsedParameter(root);

  // Set the shape for the optimizer's cloned tensors.
  SetClonedTensorShapeForOptimizer(root);

  HandleCameAndAdaFactorOpt(root, all_nodes, manager);

  InsertUniformRealForTaggedNodes(manager, all_nodes);

  auto adasum_param_tensor_layout_map = AdaSumParamTensorLayout(root);
  bool is_apply_adasum = HandleAdaSum(root, all_nodes, &adasum_param_tensor_layout_map);

  if (MergeEntireShapeForDynamic(root) != Status::SUCCESS) {
    MS_LOG(EXCEPTION) << "Merge entire shape for dynamic shape failed.";
  }

  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto is_pp_interleave = parallel_context->pipeline_interleave();
  std::shared_ptr<PipelinePostProcess> pipeline_processor;
  auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (pipeline_stages > 1 && is_pp_interleave) {
    pipeline_processor =
      std::make_shared<PipelinePostProcess>(manager, g_device_manager->stage_id(), pipeline_stages, root);
    pipeline_processor->Init(all_nodes);
    pipeline_processor->ModifySendRecvAttr(all_nodes);
  }
  // ForwardCommunication, BackwardCommunication, TensorRedistribution.
  ParallelCommunication(root, all_nodes, manager);
  SplitNotParallelCareOpsInterleaved(root);
  EraseVirtualConverter(root);
  if (is_apply_adasum) {
    HandleMirrorInAdaSum(root, &adasum_param_tensor_layout_map);
  }

  if (pipeline_stages > 1 && is_pp_interleave) {
    MS_EXCEPTION_IF_NULL(pipeline_processor);
    pipeline_processor->GraphPartition(all_nodes);
    pipeline_processor->ElimGraphStage();
    pipeline_processor->ModifyParameterList();
  }

  // Save strategies as a checkpoint for multi-train.
  auto all_nodes_after_pp = TopoSort(root->get_return(), SuccDeeperSimple);
  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
    CheckpointStrategy(all_nodes_after_pp, root);
  }
  auto comm_group = FindCommonMirrorGroup(root);
  StrategyCheckpoint::GetInstance().set_common_mirror_group(comm_group);
  MoveMicroMirrorOutCallFunc(root);
  HandleGlobalNormScale(root, manager);
  if (pipeline_stages > 1 && is_pp_interleave) {
    pipeline_processor->HandleSendParam();
    MarkForwardCNode(root);
  }
  MergeMicroMirrorForSharedParameter(root);
  // Insert TensorToTuple for FlashAttentionScore if the actual_seq_len input is a tensor.
  PostProcessActualSeqLenInputForFlashAttentionScore(root, all_nodes);
}

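// Pass entry for step parallel: skipped on PS server/scheduler roles and on
// graphs that are not parallel-care or were already processed; otherwise runs
// the full (semi-)auto parallel transformation on `root` and returns the
// `changes` flag consumed by the optimizer framework.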
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return false;
  }
#endif
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  HandleDataParallel();
  FuncGraphManagerPtr manager;
  pipeline::ResourceBasePtr res;
  if (optimizer == nullptr) {
    manager = root->manager();
    res = std::make_shared<pipeline::Resource>();
    res->set_manager(manager);
  } else {
    res = optimizer->resource();
    MS_EXCEPTION_IF_NULL(res);
    manager = res->manager();
  }

  MS_EXCEPTION_IF_NULL(manager);
  auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (IsTraining(manager)) {
    root->set_flag(kTraining, true);
  }
  // Assume no change to the graph.
  bool changes = false;
  // Controls whether model-parallel mode is used.
  if (!IsAutoParallelCareGraph(root) || (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY)) || HasNestedMetaFg(root)) {
    if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
      MS_LOG(INFO) << "Strategies would be ignored in " << parallel_mode
                   << ", shard() only valid in [semi_]auto_parallel.";
      root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
    }
    ReorderForPipelineSplit(root, manager, pipeline_stages);
    ReorderForGradAccumulation(root, manager);
    return changes;
  }

  MSLogTime msTime;
  msTime.Start();
  DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));
  RecordFlopsOriginShape(manager);
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  std::reverse(all_nodes.begin(), all_nodes.end());
  bool merged = MergeConcatSlice(all_nodes, manager);
  if (merged) {
    all_nodes = TopoSort(ret, SuccDeeperSimple);
  }
  if (pipeline_stages <= 1 && parallel_mode != kAutoParallel && ParallelInit() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed";
  }

  // Insert TupleToTensor for FlashAttentionScore if the actual_seq_len input is a tuple.
  PreProcessActualSeqLenInputForFlashAttentionScore(root, all_nodes);

  MicroBatchPreProcess(root, manager, all_nodes);
  // Mark the forward cnodes; parallel only cares about these nodes.
  MarkForwardCNode(root);
  HandleSilentCheck(root, manager);
  // Tag dynamic-shape func graphs.
  TagDynamicShapeFuncGraph(root);
  UpdateMicroBatchInterleavedStatus(all_nodes);
  if (parallel_mode != kAutoParallel) {
    TOTAL_OPS = 0;
    ExceptionIfHasCommunicationOp(all_nodes);

    if (IsInsertVirtualOutput(root)) {
      InsertVirtualOutput(root, all_nodes);
      AnfNodePtr ret_after = root->get_return();
      MS_EXCEPTION_IF_NULL(ret_after);
      all_nodes = TopoSort(ret_after, SuccDeeperSimple);
    }

    // Extract shapes and strategies, set operator_info.
    ExtractInformation(all_nodes);
  }

  ParallelPartProcess(all_nodes, root, manager);
  BroadcastLastResult(root, manager);
  MicroBatchPostProcess(root, all_nodes);
  UpdateParamSymbolicShape(root);
  DumpGraph(root, std::string(STEP_PARALLEL_END));

  // Step parallel only runs once.
  root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  // Keep all func graphs for parallel before saving the result.
  SetReserved(root);
  res->SetResult(pipeline::kStepParallelGraph, root);

  // In auto parallel mode, there is no need to check whether strategies are set.
  root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);

  msTime.End();
  uint64_t time = msTime.GetRunTimeUS();
  MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us";
  return changes;
}
}  // namespace parallel
}  // namespace mindspore