• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/step_parallel_utils.h"
18 
19 #include <algorithm>
20 #include <cinttypes>
21 
22 #include <map>
23 #include <memory>
24 #include <queue>
25 #include <set>
26 #include <string>
27 #include <utility>
28 
29 #include "abstract/dshape.h"
30 #include "base/base.h"
31 #include "base/bfloat16.h"
32 #include "frontend/operator/ops.h"
33 #include "frontend/optimizer/optimizer.h"
34 #include "frontend/parallel/device_manager.h"
35 #include "frontend/parallel/dynamic_creator.h"
36 #include "frontend/parallel/graph_util/generate_graph.h"
37 #include "frontend/parallel/graph_util/graph_info.h"
38 #include "frontend/parallel/graph_util/node_info.h"
39 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
40 #include "frontend/parallel/node_check.h"
41 #include "frontend/parallel/parameter_manager.h"
42 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
43 #include "include/common/utils/comm_manager.h"
44 #include "include/common/utils/parallel_context.h"
45 #include "ir/param_info.h"
46 #include "ir/tensor.h"
47 #include "ops/array_ops.h"
48 #include "ops/framework_ops.h"
49 #include "ops/nn_ops.h"
50 #include "ops/other_ops.h"
51 #include "ops/sequence_ops.h"
52 #include "utils/parallel_node_check.h"
53 #include "utils/hash_map.h"
54 #include "utils/ms_context.h"
55 #include "utils/symbolic.h"
56 #include "utils/trace_base.h"
57 #include "mindspore/core/symbolic_shape/int_symbol.h"
58 
59 namespace mindspore {
60 namespace parallel {
61 using mindspore::tensor::Tensor;
62 size_t TOTAL_OPS = 0;
63 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
64 // it will be one item in map with key: C, and value: (B, i)
65 std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
66 
IsDynamicShapeInput(const CNodePtr & node,const AnfNodePtr & input)67 bool IsDynamicShapeInput(const CNodePtr &node, const AnfNodePtr &input) {
68   if (IsSomePrimitiveList(node, CANDIDATE_DYNAMIC_VALUE_OPS) &&
69       (IsPrimitiveCNode(input, prim::kPrimMakeTuple) || IsPrimitiveCNode(input, prim::kPrimShape))) {
70     return true;
71   }
72   if (IsPrimitiveCNode(node, prim::kPrimCast) && IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
73     BaseShapePtr base_shape_ptr = node->Shape();
74     if (base_shape_ptr == nullptr) {
75       MS_LOG(EXCEPTION) << "IsDynamicShapeInput: " << node->ToString() << " shape_ptr is nullptr, full name is "
76                         << node->fullname_with_scope();
77     }
78     auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
79     MS_EXCEPTION_IF_NULL(shape_ptr);
80     if (shape_ptr->shape().empty()) {
81       return true;
82     }
83   }
84   return false;
85 }
86 
IsSomePrimitive(const CNodePtr & cnode,const std::string & name)87 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
88   if (!cnode) {
89     return false;
90   }
91   ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
92   if (!anf_node) {
93     return false;
94   }
95   PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
96   if (!prim) {
97     return false;
98   }
99   return (prim->name() == name);
100 }
101 
IsSomePrimitiveList(const CNodePtr & cnode,const std::set<string> & check_list)102 bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list) {
103   if (!cnode) {
104     return false;
105   }
106   ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
107   if (!anf_node) {
108     return false;
109   }
110   PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
111   if (!prim) {
112     return false;
113   }
114   return std::any_of(check_list.begin(), check_list.end(), [prim](const string &in) { return prim->name() == in; });
115 }
116 
IsIgnoreSplitTensor(const CNodePtr & node,int64_t index)117 bool IsIgnoreSplitTensor(const CNodePtr &node, int64_t index) {
118   if (IsSomePrimitiveList(node, SPLIT_TENSOR_ONLY_FOR_FIRST_INPUT_OPS) && index > 0) {
119     return true;
120   }
121   return false;
122 }
123 
GetPrimName(const CNodePtr & node)124 std::string GetPrimName(const CNodePtr &node) {
125   auto prim = GetCNodePrimitive(node);
126   if (!prim) {
127     return node->DebugString();
128   }
129   return prim->name();
130 }
131 
IsTraining(const FuncGraphManagerPtr & manager)132 bool IsTraining(const FuncGraphManagerPtr &manager) {
133   for (auto &fg : manager->func_graphs()) {
134     if (fg->has_flag(kTraining)) {
135       return true;
136     }
137   }
138   return false;
139 }
140 
HasBackward(const FuncGraphPtr & root)141 bool HasBackward(const FuncGraphPtr &root) {
142   auto nodes = root->nodes();
143   for (auto &node : nodes) {
144     if (IsPrimitiveCNode(node, prim::kPrimJ)) {
145       return true;
146     }
147   }
148   return false;
149 }
150 
// Returns the TensorInfo describing how the user operator (`param_info.first`,
// a CNode with an attached OperatorInfo) expects its input at position
// `param_info.second` to be laid out.
// Raises an exception if the CNode has no OperatorInfo or the index is out of range.
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
  // The operator info must have been attached by the parallel planning pass.
  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  TensorInfo tensor_info;
  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
    // For Send, the real parameter position is recorded as a primal attribute,
    // not derivable from the CNode input index.
    auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
    tensor_info = op_info->inputs_tensor_info()[param_index];
  } else {
    // CNode input 0 is the primitive itself, so tensor-info index = input index - 1.
    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
                        << ", but the index is " << (user_input_index - 1);
    }
    tensor_info = op_info->inputs_tensor_info()[LongToSize(user_input_index - 1)];
  }
  return tensor_info;
}
172 
IsRealKernelNode(const AnfNodePtr & node)173 static bool IsRealKernelNode(const AnfNodePtr &node) {
174   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
175       IsPrimitiveCNode(node, prim::kPrimCast) || IsPrimitiveCNode(node, prim::kPrimVirtualDiv) ||
176       IsPrimitiveCNode(node, prim::kPrimReceive) || IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) ||
177       IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimInsertGradientOf)) {
178     return false;
179   }
180   return true;
181 }
182 
// Walks through "transparent" nodes (Depend/Load/Cast/..., TupleGetItem/MakeTuple
// pairs, switch-based control flow, and func-graph calls) to find the real node
// that produces the value.
// `get_item_index` carries a pending TupleGetItem index (-1 = none) that is
// resolved against a later MakeTuple; `call_node` (optional out) receives the
// first func-graph call CNode crossed; `ignore_get_item` enables looking through
// TupleGetItem. Returns the real node plus any still-unresolved tuple index.
std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node,
                                                 bool ignore_get_item) {
  if (!IsRealKernelNode(node)) {
    // Transparent wrapper op: follow its first data input.
    return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node, ignore_get_item);
  }
  if ((IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimInsertGradientOf)) &&
      ignore_get_item) {
    // Remember the extracted index so a later MakeTuple can be unpacked.
    auto cnode = node->cast<CNodePtr>();
    auto cur_get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
    auto tuple_getitem_input = cnode->input(1);
    return GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node, ignore_get_item);
  }
  if (get_item_index != -1 &&
      (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimInsertGradientOf))) {
    // Resolve the pending TupleGetItem index against this MakeTuple
    // (+1 skips the primitive at input 0), then clear the pending index.
    auto make_tuple_cnode = node->cast<CNodePtr>();
    auto make_tuple_input = make_tuple_cnode->input(LongToSize(get_item_index + 1));
    return GetRealKernelNode(make_tuple_input, -1, call_node, ignore_get_item);
  }
  if (IsControlFlowNode(node)) {
    // Switch call: input(3) of the switch is a branch graph — follow its output.
    auto switch_cnode = node->cast<CNodePtr>()->input(0)->cast<CNodePtr>();
    auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(3));
    return GetRealKernelNode(fg->output(), get_item_index, call_node, ignore_get_item);
  }
  if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0))) {
    // Func-graph call: record the outermost call node once, then look through
    // the callee's output.
    if (call_node != nullptr && *call_node == nullptr) {
      *call_node = node->cast<CNodePtr>();
    }
    auto cnode = node->cast<CNodePtr>();
    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
    auto output = GetRealKernelNode(graph->output(), get_item_index, call_node, ignore_get_item).first;
    MS_EXCEPTION_IF_NULL(output);
    if (output->isa<Parameter>()) {
      // The callee returns one of its parameters: map it back to the caller's
      // argument at the matching position and continue from there.
      auto param_graph = output->func_graph();
      auto parameter_list = param_graph->parameters();
      auto fg_used_map = param_graph->func_graph_cnodes_index();
      for (auto &cur_fg_use : fg_used_map) {
        // Only direct calls (graph at input position 0) provide arguments.
        if (cur_fg_use.first->second != 0) {
          continue;
        }
        auto cur_fg = cur_fg_use.first->first->cast<CNodePtr>();
        auto iter = std::find(parameter_list.begin(), parameter_list.end(), output);
        auto pos = std::distance(parameter_list.begin(), iter);
        auto argument = cur_fg->input(pos + 1);
        return GetRealKernelNode(argument, get_item_index, call_node, ignore_get_item);
      }
      return std::make_pair(output, get_item_index);
    }
    return std::make_pair(output, get_item_index);
  }
  return std::make_pair(node, get_item_index);
}
234 
IsWhileGraph(const FuncGraphPtr & cur_fg,const FuncGraphPtr & fg)235 static bool IsWhileGraph(const FuncGraphPtr &cur_fg, const FuncGraphPtr &fg) {
236   auto cur_fg_map = cur_fg->func_graph_cnodes_index();
237   for (auto &cur_fg_use : cur_fg_map) {
238     auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
239     MS_EXCEPTION_IF_NULL(temp_node);
240     if (temp_node->func_graph() == fg) {
241       return true;
242     }
243   }
244   return false;
245 }
246 
// Verifies that every user of `node` that carries an OperatorInfo expects the
// same input TensorInfo (i.e. the same slicing), and returns the first such
// user. With exactly one user, returns it unconditionally. Raises when two
// users disagree on the TensorInfo.
// NOTE(review): if no user has an OperatorInfo, the returned pointer is null —
// callers presumably handle that; confirm at call sites.
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  auto node_users = manager->node_users()[node];
  if (node_users.size() == 1) {
    return node_users.front().first;
  }

  bool is_first_tensor_info = true;
  TensorInfo first_tensor_info;
  AnfNodePtr first_node;
  for (auto &node_user : node_users) {
    auto user_node = node_user.first->cast<CNodePtr>();
    // Users without an OperatorInfo don't constrain the split.
    if (!user_node->has_user_data<OperatorInfo>()) {
      continue;
    }
    auto tensor_info = GetInputsTensorInfo(node_user);
    if (is_first_tensor_info) {
      // First constrained user becomes the reference.
      is_first_tensor_info = false;
      first_tensor_info = tensor_info;
      first_node = node_user.first;
      continue;
    }
    if (first_tensor_info == tensor_info) {
      continue;
    } else {
      MS_LOG(EXCEPTION) << "The node: " << node->DebugString()
                        << " has multiple users, but the TensorInfo are different";
    }
  }
  return first_node;
}
277 
// Decides whether the parallel pass should process this CNode.
// The checks are order-sensitive: pipeline-interleave Send/Receive are excluded
// first, then non-primitive calls, then ops the parallel pass never considers.
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  // In pipeline-interleave mode Send/Receive are excluded here — presumably
  // they are handled by the pipeline pass itself; confirm with that pass.
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto is_pp_interleave = parallel_context->pipeline_interleave();
  if (is_pp_interleave && (IsPrimitiveCNode(cnode, prim::kPrimSend) || IsPrimitiveCNode(cnode, prim::kPrimReceive))) {
    return false;
  }
  // Non-primitive calls (e.g. func-graph calls) are not handled here.
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (!IsParallelConsiderCNode(cnode)) {
    MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
    return false;
  }
  // get_next is not in the forward graph, we need mark the get_next as the forward node
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  // A Cast without an attached OperatorInfo has nothing for the pass to do.
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }

  // Otherwise, only forward-graph nodes are processed.
  return cnode->in_forward_flag();
}
309 
HasNestedMetaFg(const FuncGraphPtr & func_graph)310 bool HasNestedMetaFg(const FuncGraphPtr &func_graph) {
311   if (!IsPynativeParallel()) {
312     return false;
313   }
314   AnfNodePtr ret = func_graph->get_return();
315   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
316   for (auto &node : all_nodes) {
317     if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
318         IsPrimitiveCNode(node, prim::kPrimTaylor)) {
319       return true;
320     }
321   }
322   return false;
323 }
324 
IsEmbedShardNode(const FuncGraphPtr & func_graph)325 bool IsEmbedShardNode(const FuncGraphPtr &func_graph) {
326   MS_EXCEPTION_IF_NULL(func_graph);
327   AnfNodePtr ret = func_graph->get_return();
328   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
329   return std::any_of(all_nodes.begin(), all_nodes.end(), [&func_graph](const AnfNodePtr &node) {
330     return IsPrimitiveCNode(node, prim::kPrimShard) && (node->func_graph() == func_graph);
331   });
332 }
333 
GetValueListShape(const AnfNodePtr & node)334 Shapes GetValueListShape(const AnfNodePtr &node) {
335   Shapes shapes;
336   std::vector<ValuePtr> inputs_seq;
337   if (IsValueNode<ValueList>(node)) {
338     inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
339   } else if (IsValueNode<ValueTuple>(node)) {
340     inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
341   } else {
342     MS_LOG(EXCEPTION) << "node is either ValueList or ValueTuple";
343   }
344   for (auto &ele : inputs_seq) {
345     auto tensor = ele->cast<tensor::TensorPtr>();
346     if (tensor == nullptr) {
347       MS_LOG(WARNING) << "The value node is not a tensor";
348       break;
349     }
350     auto one_shape = tensor->shape();
351     shapes.push_back(one_shape);
352   }
353   return shapes;
354 }
355 
IsControlFlowNode(const AnfNodePtr & node)356 bool IsControlFlowNode(const AnfNodePtr &node) {
357   // Only switch or FuncCall nodes are control flow nodes
358   MS_EXCEPTION_IF_NULL(node);
359   if (!node->isa<CNode>()) {
360     return false;
361   }
362   auto cnode = node->cast<CNodePtr>();
363   MS_EXCEPTION_IF_NULL(cnode);
364   // func node
365   if (cnode->input(0)->isa<CNode>() && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
366     return true;
367   }
368   return false;
369 }
370 
GetTupleGetItemIndex(const CNodePtr & cnode)371 int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
372   MS_EXCEPTION_IF_NULL(cnode);
373   if (!cnode->input(TUPLE_GETITEM_INDEX_POS)->isa<ValueNode>()) {
374     MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node";
375   }
376 
377   ValuePtr tuple_index_value = GetValueNode(cnode->input(TUPLE_GETITEM_INDEX_POS));
378   MS_EXCEPTION_IF_NULL(tuple_index_value);
379   if (!tuple_index_value->isa<Int64Imm>()) {
380     MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
381   }
382   return tuple_index_value->cast<Int64ImmPtr>()->value();
383 }
384 
IsNoNeedRedistribution(const CNodePtr & use_cnode,int use_index)385 static bool IsNoNeedRedistribution(const CNodePtr &use_cnode, int use_index) {
386   return (IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && use_index != 1) || use_cnode->input(0)->isa<CNode>() ||
387          IsOneOfPrimitiveCNode(use_cnode, {prim::kPrimUpdateState, prim::kPrimSwitch, prim::kPrimShape,
388                                            prim::kPrimTensorShape, prim::kPrimDType});
389 }
390 
FuncGraphNodeUsers(const std::pair<AnfNodePtr,int> & node_pair)391 std::vector<std::pair<AnfNodePtr, int>> FuncGraphNodeUsers(const std::pair<AnfNodePtr, int> &node_pair) {
392   std::vector<std::pair<AnfNodePtr, int>> func_users_vector;
393   if (!node_pair.first->isa<CNode>()) {
394     return func_users_vector;
395   }
396   auto use_cnode = node_pair.first->cast<CNodePtr>();
397   MS_EXCEPTION_IF_NULL(use_cnode);
398   if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
399     auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
400     auto fg_parameters = fg->parameters();
401     auto param = fg_parameters[IntToSize(node_pair.second - 1)];
402     auto manager = fg->manager();
403     auto param_node_users = manager->node_users()[param];
404     for (const auto &node_user : param_node_users) {
405       auto cnode = node_user.first->cast<CNodePtr>();
406       if (IsValueNode<FuncGraph>(cnode->input(0))) {
407         auto sub_graph_users = FuncGraphNodeUsers(node_user);
408         (void)std::copy(sub_graph_users.begin(), sub_graph_users.end(), std::back_inserter(func_users_vector));
409       } else {
410         func_users_vector.emplace_back(node_user);
411       }
412     }
413   }
414   return func_users_vector;
415 }
416 
// Returns a copy of `get_item_index` with the leading -1 placeholder removed
// when real indices follow it; a single-element vector is returned unchanged.
// Fix: the original called erase(begin()) whenever size != 1, which is
// undefined behavior on an empty vector; guard with size > 1 instead
// (behavior for all non-empty inputs is unchanged).
std::vector<int> RemovePlaceholderIdx(const std::vector<int> &get_item_index) {
  std::vector<int> new_get_item_index(get_item_index);
  if (new_get_item_index.size() > 1) {
    // Remove first -1, if there is other index
    new_get_item_index.erase(new_get_item_index.begin());
  }
  return new_get_item_index;
}
426 
// Records `use_cnode` (a parallel-care user) as a redistribution successor.
// When a MakeTuple index is pending (`*make_tuple_index` != -1), it first tries
// to locate the real MakeTuple behind the relevant input; if found, the
// successor is that MakeTuple with the element position appended to the input
// index, and the pending index is consumed (reset to -1).
void RedistributionNextNodeInMakeTuple(
  const CNodePtr &use_cnode, const std::pair<std::shared_ptr<AnfNode>, int> &node_pair,
  const std::vector<int> &get_item_index, int64_t *make_tuple_index,
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes) {
  auto modified_get_item_idx = RemovePlaceholderIdx(get_item_index);
  std::vector<int> input_index = {node_pair.second};
  if (*make_tuple_index != -1) {
    // Ops supporting the new ShapeBase take the tuple at the actual input slot;
    // legacy ops always take it at input 1.
    int node_pos = IsSomePrimitiveList(use_cnode, SUPPORT_NEW_SHAPEBASE_OPS) ? node_pair.second : 1;
    auto real_node = GetRealKernelNode(use_cnode->input(node_pos), -1, nullptr);
    if (IsPrimitiveCNode(real_node.first, prim::kPrimMakeTuple)) {
      // +1 converts the tuple element position to a CNode input index.
      input_index.push_back(LongToInt((*make_tuple_index) + 1));
      next_nodes->push_back(std::make_pair(std::make_pair(real_node.first, input_index), modified_get_item_idx));
      *make_tuple_index = -1;
      return;
    }
  }
  auto modified_node_pair = std::make_pair(node_pair.first, input_index);
  next_nodes->push_back(std::make_pair(modified_node_pair, modified_get_item_idx));
}
446 
SetAnfNode(const AnfNodePtr & param,std::vector<std::pair<std::pair<AnfNodePtr,std::vector<int>>,std::vector<int>>> * next_nodes)447 void SetAnfNode(const AnfNodePtr &param,
448                 std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes) {
449   for (const auto &next_node : *next_nodes) {
450     next_node.first.first->set_user_data<AnfNode>(FUNC_PARAM, param);
451   }
452 }
453 
// Collects, into `next_nodes`, every downstream (node, input-indices) pair that
// may require tensor redistribution for the output of `node`. Looks through
// func-graph calls, MakeTuple/TupleGetItem pairs, and Return-to-caller edges,
// recursing until a parallel-care user with an OperatorInfo is reached.
// `get_item_index` accumulates the TupleGetItem path; `make_tuple_index` holds
// a pending MakeTuple element position (-1 = none).
void RedistributionNextNode(
  const AnfNodePtr &node, const FuncGraphManagerPtr &manager, const NodeUsersMap &node_users_map,
  const std::vector<int> &get_item_index, int64_t make_tuple_index,
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  if (node_users_map.count(node) == 0) {
    return;
  }
  auto node_set = node_users_map.at(node);
  for (auto &node_pair : node_set) {
    auto use_cnode = node_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(use_cnode);
    if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
      // Func-graph call: follow the value into the callee's parameter.
      auto cur_fg = use_cnode->func_graph();
      auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
      MS_EXCEPTION_IF_NULL(fg);
      // Skip cyclic (while-style) calls to avoid infinite recursion.
      if (IsWhileGraph(cur_fg, fg)) {
        continue;
      }
      auto fg_parameters = fg->parameters();
      // Input index is 1-based (input 0 is the graph), parameters are 0-based.
      auto param = fg_parameters[IntToSize(node_pair.second - 1)];
      MS_EXCEPTION_IF_NULL(param);
      if (param->has_user_data<OperatorInfo>()) {
        // The parameter itself is already parallel-planned: record the call.
        std::vector<int> input_index = {node_pair.second};
        auto modified_node_pair = std::make_pair(node_pair.first, input_index);
        next_nodes->push_back(std::make_pair(modified_node_pair, RemovePlaceholderIdx(get_item_index)));
        continue;
      }
      RedistributionNextNode(param, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      SetAnfNode(param, next_nodes);
      continue;
    }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimMakeTuple)) {
      // Remember which tuple element this value becomes (0-based position).
      make_tuple_index = node_pair.second - 1;
      RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      continue;
    }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(use_cnode, prim::kPrimListGetItem)) {
      auto temp = LongToInt(GetTupleGetItemIndex(use_cnode));
      // A getitem that extracts a different element than the pending MakeTuple
      // position does not consume this value.
      if (temp != make_tuple_index && make_tuple_index != -1) {
        continue;
      }
      // When the getitem cancels the pending MakeTuple, record -1 instead.
      temp = make_tuple_index != -1 ? -1 : temp;
      std::vector<int> new_get_item_index;
      std::copy(get_item_index.begin(), get_item_index.end(), std::back_inserter(new_get_item_index));
      new_get_item_index.push_back(temp);
      RedistributionNextNode(use_cnode, manager, node_users_map, new_get_item_index, -1, next_nodes);
      continue;
    }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimReturn)) {
      // The value escapes through Return: continue from every caller of this graph.
      auto fg = use_cnode->func_graph();
      auto fg_map = fg->func_graph_cnodes_index();
      for (auto &fg_use : fg_map) {
        auto fg_node = fg_use.first->first->cast<CNodePtr>();
        constexpr int SWITCH_LAST_INPUT_INDEX = 3;
        if (IsWhileGraph(fg, fg) && fg_use.first->second != SWITCH_LAST_INPUT_INDEX) {
          continue;
        }
        RedistributionNextNode(fg_node, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      }
    }
    // depend, auto monad and control flow op don't need to jump over
    if (IsNoNeedRedistribution(use_cnode, node_pair.second)) {
      continue;
    }
    if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
      // Reached a planned operator: record it (resolving any pending MakeTuple).
      RedistributionNextNodeInMakeTuple(use_cnode, node_pair, get_item_index, &make_tuple_index, next_nodes);
      continue;
    }
    // search recursively
    RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
  }
}
527 
// Collects, into `pre_nodes`, the upstream parallel-care node(s) that produce
// the value consumed by `cnode`, looking through Depend/Load/Cast/AllReduce
// wrappers and switch-based control flow. Func-graph calls stop the search.
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                           std::vector<AnfNodePtr> *pre_nodes) {
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    return;
  }
  if (IsControlFlowNode(cnode)) {
    auto switch_cnode = cnode->input(0)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(switch_cnode);
    // extract true branch, false branch is usually also a control flow graph
    auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(2));
    MS_EXCEPTION_IF_NULL(fg);
    auto fg_out = fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(fg_out);
    // control flow node, need enter graph to find redistribution pre node.
    RedistributionPreNode(fg_out, manager, pre_nodes);
  }
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
      IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) {
    // Transparent wrapper: follow its first data input.
    auto cnode_input = cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode_input);
    RedistributionPreNode(cnode_input, manager, pre_nodes);
  }
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
    // Found a planned operator: this is a redistribution predecessor.
    pre_nodes->push_back(cnode);
  }
}
554 
// Returns the shape(s) of `node`'s output as a list of Shape vectors:
// one entry per tuple element for sequence outputs, a single entry otherwise.
// Special encodings: Shape{0} for an empty tuple, Shape{len} for a tuple of
// scalars, Shape{-1} for a dynamic-length sequence, Shape{} for a scalar.
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (base_shape_ptr == nullptr && node->isa<ValueNode>()) {
    // A ValueNode may not have an abstract yet; derive it from its value.
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_CHECK_FAIL(value_node->value() != nullptr, "ValueNode has no value.");
    auto abstract = value_node->value()->ToAbstract();
    MS_EXCEPTION_IF_CHECK_FAIL(abstract != nullptr, "ValueNode has no Abstract.");
    node->set_abstract(abstract);
    base_shape_ptr = node->Shape();
  }
  if (node->isa<CNode>() && !IsControlFlowNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode->input(0)->isa<CNode>()) {
      // Call through a CNode-valued callee: use the first argument's shape.
      if (cnode->size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  // If node is Depend, only first input should be used.
  if (node->isa<CNode>() && IsPrimitiveCNode(node->cast<CNodePtr>(), prim::kPrimDepend)) {
    auto depend_cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_cnode->input(1));
    return GetNodeShape(depend_cnode->input(1));
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  auto tuple_shape_ptr = dyn_cast<abstract::SequenceShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    if (tuple_shape_ptr->size() == 0) {
      // Empty tuple is encoded as Shape{0}.
      shapes.push_back(Shape{0});
      return shapes;
    }
    auto tuple_shape = tuple_shape_ptr->shape();
    if (tuple_shape[0]->isa<abstract::NoShape>()) {
      // Tuple of scalars is encoded as a single Shape{length}.
      shapes.push_back(Shape{SizeToLong(tuple_shape_ptr->size())});
      return shapes;
    }
    for (auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else if (base_shape_ptr->isa<abstract::DynamicSequenceShape>()) {
    // Dynamic-length sequence is encoded as Shape{-1}.
    shapes.push_back(Shape{-1});
  } else if (base_shape_ptr->isa<abstract::Shape>()) {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  } else if (base_shape_ptr->isa<abstract::NoShape>()) {
    // Scalar is encoded as an empty Shape.
    shapes.push_back(Shape{});
  } else {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " should be Tuple/List/Tensor/Scalar, but got "
                      << base_shape_ptr->ToString() << "full name is " << node->fullname_with_scope();
  }
  return shapes;
}
619 
TransferShapesToNewShapes(const Shapes & shapes,const bool need_create_shape_list)620 NewShapes TransferShapesToNewShapes(const Shapes &shapes, const bool need_create_shape_list) {
621   NewShapes s;
622   if (!need_create_shape_list) {
623     s.emplace_back(std::make_shared<ShapeValue>(shapes[0]));
624   } else {
625     std::vector<ShapeBasePtr> shapes_list;
626     std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_list),
627                    [](const auto &shape) { return std::make_shared<ShapeValue>(shape); });
628     s.emplace_back(std::make_shared<ShapeList>(shapes_list));
629   }
630   return s;
631 }
632 
ExtractNewShapeFromShape(const abstract::BaseShapePtr & shape)633 ShapeBasePtr ExtractNewShapeFromShape(const abstract::BaseShapePtr &shape) {
634   ShapeBasePtr out_shape;
635   if (dyn_cast<abstract::Shape>(shape) != nullptr) {
636     auto casted_shape = dyn_cast<abstract::Shape>(shape);
637     std::vector<int64_t> shape_value = casted_shape->shape();
638     out_shape = std::make_shared<ShapeValue>(shape_value);
639   } else if (dyn_cast<abstract::SequenceShape>(shape) != nullptr) {
640     std::vector<ShapeBasePtr> tuple_shape;
641     auto sequence_shape = dyn_cast<abstract::SequenceShape>(shape);
642     std::transform(sequence_shape->shape().begin(), sequence_shape->shape().end(), std::back_inserter(tuple_shape),
643                    ExtractNewShapeFromShape);
644     out_shape = std::make_shared<ShapeList>(tuple_shape);
645   } else {
646     MS_LOG(EXCEPTION) << "each shape in tuple shape is not shape or sequenceshape";
647   }
648   return out_shape;
649 }
650 
// NewShapes counterpart of GetNodeShape: returns the node's output shape(s)
// as ShapeBase objects, preserving nested tuple structure via ShapeList.
// Special encodings mirror GetNodeShape: {0} for an empty tuple, {len} for a
// tuple of scalars, {-1} for a dynamic sequence, {} for a scalar.
NewShapes GetNodeNewShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  NewShapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return TransferShapesToNewShapes(GetValueListShape(node), false);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (base_shape_ptr == nullptr && node->isa<ValueNode>()) {
    // A ValueNode may not have an abstract yet; derive it from its value.
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_CHECK_FAIL(value_node->value() != nullptr, "ValueNode has no value.");
    auto abstract = value_node->value()->ToAbstract();
    MS_EXCEPTION_IF_CHECK_FAIL(abstract != nullptr, "ValueNode has no Abstract.");
    node->set_abstract(abstract);
    base_shape_ptr = node->Shape();
  }
  if (node->isa<CNode>() && !IsControlFlowNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode->input(0)->isa<CNode>()) {
      // Call through a CNode-valued callee: use the first argument's shape.
      if (cnode->size() < kSizeTwo) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  // If node is Depend, only first input should be used.
  if (node->isa<CNode>() && IsPrimitiveCNode(node->cast<CNodePtr>(), prim::kPrimDepend)) {
    auto depend_cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_cnode->input(1));
    return GetNodeNewShape(depend_cnode->input(1));
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  auto tuple_shape_ptr = dyn_cast<abstract::SequenceShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    if (tuple_shape_ptr->size() == 0) {
      // Empty tuple is encoded as {0}.
      std::vector<int64_t> shape_value = {0};
      shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
      return shapes;
    }
    auto tuple_shape = tuple_shape_ptr->shape();
    if (tuple_shape[0]->isa<abstract::NoShape>()) {
      // Tuple of scalars is encoded as a single {length}.
      std::vector<int64_t> shape_value = {SizeToLong(tuple_shape_ptr->size())};
      shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
      return shapes;
    }
    for (auto &shape : tuple_shape) {
      // Unlike GetNodeShape, nested sequences are preserved via ShapeList.
      auto each_shape = ExtractNewShapeFromShape(shape);
      shapes.emplace_back(each_shape);
    }
  } else if (base_shape_ptr->isa<abstract::DynamicSequenceShape>()) {
    // Dynamic-length sequence is encoded as {-1}.
    std::vector<int64_t> shape_value = {-1};
    shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
  } else if (base_shape_ptr->isa<abstract::Shape>()) {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    std::vector<int64_t> shape_value = shape_ptr->shape();
    shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
  } else if (base_shape_ptr->isa<abstract::NoShape>()) {
    // Scalar is encoded as an empty shape vector.
    std::vector<int64_t> shape_value = {};
    shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
  } else {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " should be Tuple/List/Tensor/Scalar, but got "
                      << base_shape_ptr->ToString() << "full name is " << node->fullname_with_scope();
  }
  return shapes;
}
719 
FindCommonMirrorGroup(const FuncGraphPtr & root)720 RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
721   auto parameters = root->parameters();
722   for (auto &parameter : parameters) {
723     auto param_ptr = parameter->cast<ParameterPtr>();
724     MS_EXCEPTION_IF_NULL(param_ptr);
725     if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
726       continue;
727     }
728     size_t allow_repeat_num = 1;
729     if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
730         (!param_ptr->param_info() || param_ptr->param_info()->parallel_optimizer())) {
731       if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
732         MS_LOG(INFO) << "The parameter :" << param_ptr->fullname_with_scope()
733                      << " is fully shard by optimizer parallel,"
734                         " thus cannot find common data parallel group for this rank";
735         return {g_device_manager->global_rank()};
736       }
737       allow_repeat_num = size_t(ParallelContext::GetInstance()->optimizer_weight_shard_size());
738     }
739     if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
740       MS_LOG(INFO) << "The parameter :" << param_ptr->fullname_with_scope()
741                    << " is fully shard, thus cannot find common data parallel group for this rank";
742       return {g_device_manager->global_rank()};
743     }
744   }
745   AnfNodePtr ret = root->get_return();
746   MS_EXCEPTION_IF_NULL(ret);
747   std::vector<int64_t> common_group_list;
748   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
749   bool is_first_group = true;
750   for (auto &node : all_nodes) {
751     if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
752         !IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
753       continue;
754     }
755     auto prim = GetCNodePrimitive(node);
756     if (!prim->HasAttr(GROUP)) {
757       MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
758     }
759     std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
760     std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
761     if (is_first_group) {
762       common_group_list = group_list;
763       is_first_group = false;
764     } else {
765       std::vector<int64_t> new_comm_group_list;
766       (void)std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(),
767                                   group_list.end(), std::back_inserter(new_comm_group_list));
768       common_group_list = new_comm_group_list;
769     }
770   }
771   MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
772   return common_group_list;
773 }
774 
CreateInstanceName(const CNodePtr & node,size_t index)775 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
776   MS_EXCEPTION_IF_NULL(node);
777   if (!IsValueNode<Primitive>(node->input(0))) {
778     MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
779   }
780   std::string name_base = node->fullname_with_scope();
781   std::string name = name_base + "_" + std::to_string(index);
782   std::string instance_name = HashInstanceName(name);
783   return instance_name;
784 }
785 
SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input)786 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
787   if (new_node_input.empty()) {
788     return;
789   }
790 
791   auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
792   auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
793   MS_EXCEPTION_IF_NULL(prim);
794 
795   auto attrs = prim->attrs();
796   auto iter = attrs.find(GROUP);
797   if (iter != attrs.end()) {
798     auto value = iter->second;
799     MS_EXCEPTION_IF_NULL(value);
800     if (value->isa<StringImm>()) {
801       std::string hash_name = value->cast<StringImmPtr>()->value();
802       MS_EXCEPTION_IF_NULL(g_device_manager);
803       std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
804       (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
805     }
806   }
807 }
808 
SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> & all_nodes)809 void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
810   for (auto &node : all_nodes) {
811     if (!node->isa<CNode>()) {
812       continue;
813     }
814     auto cnode = node->cast<CNodePtr>();
815     MS_EXCEPTION_IF_NULL(cnode);
816     if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
817       continue;
818     }
819     auto slice_prim = GetCNodePrimitive(cnode);
820     MS_EXCEPTION_IF_NULL(slice_prim);
821     if (slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
822       SetStridedSliceStrategy(cnode);
823     }
824   }
825 }
826 
827 // Check the given tensor, return nullptr if the given type is not an TensorType
CheckTensorType(const TypePtr & node_type)828 bool CheckTensorType(const TypePtr &node_type) {
829   MS_EXCEPTION_IF_NULL(node_type);
830   if (!node_type->isa<mindspore::TensorType>()) {
831     return false;
832   }
833   return true;
834 }
835 
FindReturnUser(const CNodePtr & cnode,const std::vector<AnfNodePtr> & all_nodes,std::pair<std::shared_ptr<AnfNode>,int> * queue_node)836 void FindReturnUser(const CNodePtr &cnode, const std::vector<AnfNodePtr> &all_nodes,
837                     std::pair<std::shared_ptr<AnfNode>, int> *queue_node) {
838   auto graph = cnode->func_graph();
839   auto is_target = [&](const AnfNodePtr &ele) {
840     if (ele->isa<CNode>()) {
841       auto parent_cnode = ele->cast<CNodePtr>();
842       return IsValueNode<FuncGraph>(parent_cnode->input(0)) &&
843              GetValueNode<FuncGraphPtr>(parent_cnode->input(0)) == graph;
844     }
845     return false;
846   };
847   auto it = std::find_if(all_nodes.begin(), all_nodes.end(), is_target);
848   if (it == all_nodes.end()) {
849     return;
850   }
851   *queue_node = {*it, 0};
852 }
853 
AddVisitedNode(std::queue<std::pair<std::shared_ptr<AnfNode>,int>> * visited,const NodeUsersMap & node_users_map,const AnfNodePtr & key_node)854 void AddVisitedNode(std::queue<std::pair<std::shared_ptr<AnfNode>, int>> *visited, const NodeUsersMap &node_users_map,
855                     const AnfNodePtr &key_node) {
856   if (IsPrimitiveCNode(key_node, prim::kPrimReturn)) {
857     return;
858   }
859   auto node_users = node_users_map.at(key_node);
860   for (auto &node_user : node_users) {
861     auto cnode = node_user.first->cast<CNodePtr>();
862     if (!cnode || IsSomePrimitiveList(cnode, {MAKE_TUPLE, UPDATESTATE})) {
863       continue;
864     }
865     if (node_user.first) {
866       visited->push(node_user);
867     }
868   }
869 }
870 
// BFS from `node_ptr` through its users to find the first "parallel care"
// node (one that needs distributed-operator handling). `index` is the input
// position of node_ptr at its user, used to follow TupleGetItem outputs and
// sub-graph parameter edges. Returns {nullptr, 0} when none is reachable.
std::pair<std::shared_ptr<AnfNode>, int> BFSParallelCareNode(const AnfNodePtr &node_ptr,
                                                             const NodeUsersMap &node_users_map, const int index,
                                                             const std::vector<AnfNodePtr> &all_nodes) {
  std::queue<std::pair<std::shared_ptr<AnfNode>, int>> visited;
  CNodePtr cnode = nullptr;
  AnfNodePtr node = nullptr;
  if (!node_ptr) {
    return std::make_pair(nullptr, 0);
  }
  AddVisitedNode(&visited, node_users_map, node_ptr);
  while (!visited.empty()) {
    auto queue_node = visited.front();
    visited.pop();
    cnode = queue_node.first->cast<CNodePtr>();
    if (IsParallelCareNode(cnode) || IsAutoParallelCareNode(cnode)) {
      return queue_node;
    } else if (IsValueNode<FuncGraph>(cnode->input(0))) {
      // User is a call into a sub-graph: continue the search from the
      // corresponding formal parameter inside that graph (input position
      // i maps to parameter i-1).
      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
      auto params = graph->parameters();
      auto target_param = params[queue_node.second - 1];
      auto node_set = node_users_map.at(target_param);
      for (auto &node_user : node_set) {
        cnode = node_user.first->cast<CNodePtr>();
        if (IsParallelCareNode(cnode) || IsAutoParallelCareNode(cnode)) {
          return node_user;
        } else if (IsSomePrimitiveList(cnode, {MAKE_TUPLE, UPDATESTATE})) {
          continue;
        }
        visited.push(node_user);
      }
    } else {
      if (IsSomePrimitive(cnode, RETURN)) {
        // Reached the sub-graph Return: jump to the caller of this graph.
        FindReturnUser(cnode, all_nodes, &queue_node);
      } else if (IsSomePrimitive(cnode, kTupleGetItemOpName)) {
        // Only follow the TupleGetItem that selects our output index.
        auto tuple_index = LongToSize(GetValue<int64_t>(GetValueNode(cnode->input(2))));
        if (tuple_index != IntToSize(index - 1)) {
          continue;
        }
      }
      AddVisitedNode(&visited, node_users_map, queue_node.first);
    }
  }
  return std::make_pair(nullptr, 0);
}
915 
916 // For the weight used by cast and matmul at the same time, like the followings
917 // weight1->mirror->cast1-> matmul1;
918 // weight1->add
919 // we will not insert the cast(FP32->FP16), as it will cause the input of the operator add to be changed to fp16.
// BFS through the forward-graph users of a weight. Returns the first Cast
// user found, but only if *every* user (through AllGather/Load/Reshape
// pass-throughs) is a Cast; otherwise returns nullptr, since inserting a
// cast would change the input dtype of the non-cast users too.
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
  std::queue<AnfNodePtr> visited;
  AnfNodePtr queue_node = nullptr;
  CNodePtr cnode = nullptr;
  AnfNodePtr node = nullptr;
  if (!node_ptr) {
    return nullptr;
  }
  // Seed the BFS with the weight's forward-graph users only.
  auto users = node_users_map.at(node_ptr);
  for (auto &node_user : users) {
    cnode = node_user.first->cast<CNodePtr>();
    if (!cnode || !cnode->in_forward_flag()) {
      continue;
    }
    if (node_user.first) {
      visited.push(node_user.first);
    }
  }
  while (!visited.empty()) {
    queue_node = visited.front();
    visited.pop();
    cnode = queue_node->cast<CNodePtr>();
    // MAKE_TUPLE will not appear after the load in the forward graph
    if (IsSomePrimitive(cnode, MAKE_TUPLE)) {
      continue;
    } else if (IsInAllGatherNodeList(cnode) || IsSomePrimitiveList(cnode, {LOAD, RESHAPE})) {
      // Pass-through node: keep scanning its users instead.
      auto node_set = node_users_map.at(queue_node);
      for (auto &node_user : node_set) {
        visited.push(node_user.first);
      }
    } else if (!IsSomePrimitive(cnode, CAST)) {
      // Any non-Cast user disqualifies the weight entirely.
      MS_LOG(INFO) << "The weight's users including the non cast node So "
                   << "will not insert cast for this parameter " << node_ptr->DebugString();
      return nullptr;
    } else if (!node) {
      // Remember the first Cast but keep draining the queue, so a later
      // non-Cast user can still veto the result above.
      node = queue_node;
    }
  }
  return node;
}
960 
961 // Given the cnode ptr, find its users until we find the computation node, then return the type of the
962 // computation node. This function is used to find the target type for CreateFP16Cast. Only returns the target type if
963 // it is float16, and the source node is float32. If the situation is not matched, then return the nullptr.
// Given a (user node, input index) pair for a weight, find the downstream
// Cast and return its target element type — but only when the cast source is
// float32 and the target is float16 or bfloat16, and only under pipeline
// parallelism. Returns nullptr in every other situation.
TypePtr FindChildCastWithFP32ToFP16(const std::pair<AnfNodePtr, int> &res, const NodeUsersMap &node_users_map) {
  // Only relevant when pipeline parallel is enabled.
  if (ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1) {
    return nullptr;
  }
  auto cnode_ptr = res.first->cast<CNodePtr>();
  if (!cnode_ptr) {
    return nullptr;
  }
  auto cnode_inputs = cnode_ptr->inputs();
  if (cnode_inputs.size() < TWO_INPUT_SIZE) {
    return nullptr;
  }

  AnfNodePtr node = nullptr;
  if (IsValueNode<FuncGraph>(cnode_ptr->input(kIndex0))) {
    // The user is a sub-graph call: search from the corresponding formal
    // parameter inside that graph (input position res.second maps to
    // parameter res.second - 1).
    auto graph_sub = GetValueNode<FuncGraphPtr>(cnode_ptr->input(0));
    auto parameters = graph_sub->parameters();
    auto parameter_sub = parameters[IntToSize(res.second - 1)];
    node = GetChildCastNode(parameter_sub, node_users_map);
  } else {
    // As we execute the function IsWeightValidUsed when we start to insert the mirror, so the second parameter
    // is always the parameter.
    auto weight = cnode_inputs[1];
    if (!weight->isa<Parameter>()) {
      return nullptr;
    }
    MS_LOG(INFO) << "Start to search the weight params:" << weight->DebugString();
    node = GetChildCastNode(weight, node_users_map);
  }

  if (!node) {
    return nullptr;
  }
  // get the output dtype of the operator
  auto node_type = node->Type();
  if (!CheckTensorType(node_type)) {
    return nullptr;
  }
  auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(input_element_type);
  if (!IsPrimitiveCNode(node)) {
    return nullptr;
  }
  // The cast's own input provides the source dtype.
  auto cast_input_cnode = node->cast<CNodePtr>()->input(kIndex1)->cast<CNodePtr>();
  if (!cast_input_cnode) {
    return nullptr;
  }
  auto source_node_type = cast_input_cnode->Type();
  if (!CheckTensorType(source_node_type)) {
    return nullptr;
  }
  auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(source_element_type);
  // We only add cast operation when the source is fp32 type, and the users is fp16 type.
  if ((source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) ||
      (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeBFloat16)) {
    return input_element_type;
  }
  return nullptr;
}
1024 
1025 // Create a cast node given the current node and the previous node. The target type of the the cast is from the
1026 // compute_node_type.
1027 // Return the new cast node with pre_node as the inputs.
// Build a Cast CNode that converts `pre_node` to `compute_node_type`, placed
// in the same graph/scope as `node` and flagged as a forward node. The Cast
// primitive is fetched once from mindspore.ops.functional via the Python
// adapter (static local, so the Python lookup happens only on first call).
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type) {
  const char kOpsFunctionModelName[] = "mindspore.ops.functional";
  static py::object cast_prim = python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  MS_EXCEPTION_IF_NULL(compute_node_type);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    // No primitive attached yet: wrap the Python object ourselves.
    prim = std::make_shared<PrimitivePy>(cast_prim);
  }
  // Insert cast.
  auto type_node = NewValueNode(compute_node_type);
  type_node->set_abstract(compute_node_type->ToAbstract());
  auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
  // NOTE(review): the abstract is copied from `node` (the consumer), not
  // recomputed for the cast output — presumably corrected later; confirm.
  new_node->set_abstract(node->abstract());
  new_node->set_scope(node->scope());
  new_node->set_in_forward_flag(true);
  return new_node;
}
1047 
LabelGenMaskMicro(const FuncGraphPtr & root)1048 void LabelGenMaskMicro(const FuncGraphPtr &root) {
1049   AnfNodePtr ret = root->get_return();
1050   MS_EXCEPTION_IF_NULL(ret);
1051   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
1052   for (auto &node : all_nodes) {
1053     if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
1054       auto gen_mask_node = RealInputNode(node->cast<CNodePtr>(), 2);
1055       if (gen_mask_node->isa<CNode>()) {
1056         gen_mask_node->cast<CNodePtr>()->set_primal_attrs(node->cast<CNodePtr>()->primal_attrs());
1057       }
1058     }
1059   }
1060 }
1061 
SetCastForParamNotRecompute(const std::vector<AnfNodePtr> & all_nodes)1062 void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
1063   for (const auto &node : all_nodes) {
1064     if (!IsPrimitiveCNode(node)) {
1065       continue;
1066     }
1067     auto cnode = node->cast<CNodePtr>();
1068     auto cnode_prim = GetCNodePrimitive(cnode);
1069     if (cnode_prim->HasAttr("DISABLE_MERGE_ASSIGN_ADD")) {
1070       cnode->AddPrimalAttr("DISABLE_MERGE_ASSIGN_ADD", cnode_prim->GetAttr("DISABLE_MERGE_ASSIGN_ADD"));
1071     }
1072     if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
1073       continue;
1074     }
1075     auto cast_input = RealInputNode(cnode, 1);
1076     if (cast_input->isa<Parameter>() && cast_input->cast<ParameterPtr>()->has_default()) {
1077       MS_LOG(INFO) << "Cast for parameter no needs recompute to avoid redundant trans_data operator";
1078       PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
1079       (void)prim->AddAttr("recompute", MakeValue(false));
1080     }
1081   }
1082 }
1083 
GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> & node,const string & key)1084 std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key) {
1085   if (!node) {
1086     return nullptr;
1087   }
1088   auto cnode = node->cast<CNodePtr>();
1089   auto prim = GetCNodePrimitive(cnode);
1090   if (prim && prim->HasAttr(key)) {
1091     return prim->GetAttr(key);
1092   }
1093   return nullptr;
1094 }
1095 
// Whether `op_name` has a dedicated, splittable OperatorInfo implementation
// for operator-level parallelism. Backed by a fixed whitelist; built once
// (function-local static) and looked up per call.
bool IsSplittableOperator(const std::string &op_name) {
  // clang-format off
  static const std::set<std::string> splittable_op =
    {MATMUL, TRANSPOSE, GELU, FAST_GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
     BATCH_MATMUL_EXT, MATMUL_EXT,
     FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, STACK_EXT,
     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
     EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
     EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SLICE_EXT, SELECT,
     GATHERD, UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ND_ADD, SCATTER_ND_SUB,
     TENSOR_SCATTER_UPDATE, TENSOR_SCATTER_ADD, TENSOR_SCATTER_SUB, TENSOR_SCATTER_MAX, TENSOR_SCATTER_MIN, WKV,
     TENSOR_SCATTER_MUL, TENSOR_SCATTER_DIV, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE, SORT, PAD_V3,
     MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, STANDARD_NORMAL, RESIZE_BILINEAR_V2, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU,
     BOUNDING_BOX_ENCODE, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, UNIQUE_CONSECUTIVE, SILU, INDEX_SELECT, CLAMP_SCALAR,
     RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2,
     RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
     ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
     BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, INPLACE_UPDATE,
     L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE,
     CHECK_VALID, INVERT, SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD,
     POPULATION_COUNT, IDENTITY, BESSELI0, BESSELI1, BESSELJ0, BESSELJ1, CUM_MAX, CUM_MIN, HYPOT, IGAMMA, IGAMMAC,
     LEFT_SHIFT, RIGHT_SHIFT, NEXT_AFTER, ZETA, REVERSEV2, LGAMMA, TRUNC, BETAINC, GCD, CHOLESKY, CONV3D, MAXPOOL_3D,
     AVGPOOL_3D, FILLV2, FAKE_QUANT_PER_LAYER, FAKE_QUANT_PER_CHANNEL, MIN_MAX_UPDATE_PER_LAYER, ASCEND_QUANTV2,
     MIN_MAX_UPDATE_PER_CHANNEL, FFN, FLASH_ATTENTION_SCORE, ASCEND_QUANT, ASCEND_DEQUANT, GRID_SAMPLER_2D, ANTI_QUANT,
     CONVOLUTION, LIN_SPACE_EXT, ONEHOTEXT};
  // clang-format on

  auto iter = splittable_op.find(op_name);
  return (iter != splittable_op.end());
}
1133 
IsAutoParallelCareNode(const CNodePtr & cnode)1134 bool IsAutoParallelCareNode(const CNodePtr &cnode) {
1135   MS_EXCEPTION_IF_NULL(cnode);
1136   ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
1137   if (prim_node == nullptr) {
1138     return false;
1139   }
1140   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
1141   if (prim == nullptr) {
1142     return false;
1143   }
1144   if (IsSomePrimitiveList(cnode, {SEND, RECEIVE, MAKE_TUPLE, MAKE_LIST})) {
1145     return false;
1146   }
1147   bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
1148   if (bool_result) {
1149     MS_LOG(INFO) << "For 'auto_parallel', missing the splitable implementation of OperatorInfo for: " << prim->name()
1150                  << ", default strategy will be assigned. Network training may deteriorate or malfunction";
1151   } else if (prim->name() == CAST) {
1152     if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
1153       // Do not care CASTs from optimizer
1154       return false;
1155     }
1156     return cnode->in_forward_flag();
1157   }
1158   return IsParallelCareNode(cnode);
1159 }
1160 
UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> & all_nodes)1161 void UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> &all_nodes) {
1162   for (auto &node : all_nodes) {
1163     if (!node->isa<CNode>()) {
1164       continue;
1165     }
1166     auto cnode = node->cast<CNodePtr>();
1167     MS_EXCEPTION_IF_NULL(cnode);
1168     if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
1169       continue;
1170     }
1171     auto slice_prim = GetCNodePrimitive(cnode);
1172     MS_EXCEPTION_IF_NULL(slice_prim);
1173     if (!slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
1174       continue;
1175     }
1176     if (!slice_prim->HasAttr(INTERLEAVED_NUM)) {
1177       continue;
1178     }
1179     if (GetValue<int64_t>(slice_prim->GetAttr(INTERLEAVED_NUM)) == MICRO_INTERLEAVED_SIZE) {
1180       ParallelContext::GetInstance()->set_enable_micro_interleaved(true);
1181       cnode->AddAttr(INTERLEAVED_NUM, slice_prim->GetAttr(INTERLEAVED_NUM));
1182     }
1183   }
1184 }
1185 
// Map a primitive name to its distributed OperatorInfo class name: strip a
// leading underscore (internal primitives) and append the "Info" suffix.
std::string GetDisOpName(const std::string &prim_name) {
  const bool leading_underscore = !prim_name.empty() && prim_name.front() == '_';
  const std::string base = leading_underscore ? prim_name.substr(1) : prim_name;
  return base + "Info";
}
1193 
// Create the OperatorInfo instance registered under `name` via DynCreator.
// shape_list must be {input shapes, output shapes}. Returns nullptr when the
// shape list is malformed, Custom-op parallel attrs are missing, or creation
// fails. On success the info's name is suffixed with the global op counter
// TOTAL_OPS to make it unique, and the counter is incremented.
OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs,
                                       const std::vector<Shapes> &shape_list) {
  if (shape_list.size() != 2) {
    MS_LOG(ERROR) << "The size of shape list is not 2";
    return nullptr;
  }
  if (name.length() == 0) {
    MS_LOG(EXCEPTION) << "Length of name is zero!";
  }

  // Custom ops need all four parallelization attrs to be auto-parallelized.
  if (name == "Custom" &&
      (attrs.find(KAttrAsLossDivisor) == attrs.end() || attrs.find(KAttrDevMatrixShape) == attrs.end() ||
       attrs.find(KAttrInputsTensorMap) == attrs.end() || attrs.find(KAttrOutputsTensorMap) == attrs.end())) {
    MS_LOG(WARNING) << "The attr for parallelization settings is not found in the custom op."
                    << "To enable auto parallelization, set the attrs including [" << KAttrAsLossDivisor << ", "
                    << KAttrDevMatrixShape << ", " << KAttrInputsTensorMap << ", " << KAttrOutputsTensorMap << "]";
    return nullptr;
  }
  std::string distribute_opname = GetDisOpName(name);
  OperatorInfoPtr op_info =
    (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
  if (op_info == nullptr) {
    MS_LOG(INFO) << "Create " << name << " failed";
    return nullptr;
  }
  // Suffix with the counter first, then bump it for the next instance.
  std::string origin_name = op_info->name();
  op_info->set_name(origin_name + std::to_string(TOTAL_OPS));
  MS_LOG(INFO) << "Successfully created operator " << origin_name;
  ++TOTAL_OPS;
  return op_info;
}
1225 
OperatorInstance(const PrimitivePtr & prim,const PrimitiveAttrs & attrs,const std::vector<Shapes> & shape_list)1226 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
1227                                  const std::vector<Shapes> &shape_list) {
1228   MS_EXCEPTION_IF_NULL(prim);
1229   OperatorInfoPtr op_info;
1230   if (prim->HasAttr(SELF_DEFINE_SHARD)) {
1231     auto self_define_shard_attr = prim->GetAttr(SELF_DEFINE_SHARD);
1232     if (self_define_shard_attr->cast_ptr<BoolImm>() == nullptr) {
1233       MS_LOG(EXCEPTION) << "SELF_DEFINE_SHARD attribute is not a bool";
1234     }
1235     if (GetValue<bool>(self_define_shard_attr)) {
1236       op_info = OperatorInstanceByName(SELF_DEFINE_SHARD_OP, attrs, shape_list);
1237       MS_LOG(INFO) << "Operator " << prim->name() << " has self_define_shard attribute. Create SelfDefineShardInfo";
1238       return op_info;
1239     }
1240   }
1241   op_info = OperatorInstanceByName(prim->name(), attrs, shape_list);
1242   if (op_info) {
1243     return op_info;
1244   }
1245   if (IsInBatchParallelBlackList(prim)) {
1246     op_info = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
1247     prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
1248     MS_LOG(INFO) << "Operator " << prim->name() << " is not supported yet in auto parallel mode. Use Stand Alone";
1249     return op_info;
1250   }
1251   auto input_shape = shape_list[0];
1252   auto output_shape = shape_list[1];
1253   MS_EXCEPTION_IF_NULL(g_device_manager);
1254   auto device_num = g_device_manager->stage_device_num();
1255   MS_EXCEPTION_IF_ZERO("device_num", device_num);
1256   if (input_shape.empty() || input_shape[0].empty() || input_shape[0][0] % device_num != 0 || output_shape[0].empty() ||
1257       output_shape[0][0] % device_num != 0) {
1258     MS_LOG(INFO) << "Operator " << prim->name() << " use Stand Alone, the input shape is " << input_shape
1259                  << ", the output shape is " << output_shape;
1260     op_info = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
1261     prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
1262     return op_info;
1263   }
1264   MS_LOG(INFO) << "Operator " << prim->name() << " use Batch Parallel";
1265   op_info = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
1266   prim->AddAttr(BATCH_PARALLEL, MakeValue<bool>(true));
1267   return op_info;
1268 }
1269 
GetRefKeyNodeShape(const AnfNodePtr & node,const FuncGraphPtr & func_graph)1270 static Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
1271   MS_EXCEPTION_IF_NULL(node);
1272   MS_EXCEPTION_IF_NULL(func_graph);
1273 
1274   std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(node, func_graph);
1275   if (parameters.size() != 1) {
1276     MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
1277   }
1278 
1279   Shapes input_shapes = GetNodeShape(parameters[0]);
1280   if (input_shapes.size() != 1) {
1281     MS_LOG(EXCEPTION) << "Get input shape failed";
1282   }
1283 
1284   MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]);
1285   return input_shapes;
1286 }
1287 
// Collect the input and output shapes (as NewShapes) and symbols of `node`.
// Returns {{input shapes, output shapes}, {input symbols, output symbols}}.
// Monad inputs are skipped; RefKey inputs are resolved to their parameter
// and recorded in g_RefMap; dynamic-shape inputs are skipped entirely.
std::pair<std::vector<NewShapes>, std::vector<Symbols>> ExtractNewShapeAndSymbol(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  NewShapes shape_inputs;
  NewShapes shape_outputs;
  Symbols symbol_inputs;
  Symbols symbol_outputs;
  std::vector<NewShapes> shape_all;
  std::vector<Symbols> symbol_all;
  std::vector<AnfNodePtr> all_inputs = node->inputs();
  // Set per input when the shape must be wrapped in a ShapeList; reset after
  // each input is processed.
  bool need_create_shape_list = false;

  const int min_size = 2;
  size_t inputs_size = all_inputs.size();
  // input(0) is the primitive; real inputs start at 1.
  for (size_t i = 1; i < inputs_size; ++i) {
    ShapeBasePtr input_new_shapes;
    Shapes input_shapes;
    Symbols input_symbols;
    AnfNodePtr input = all_inputs[i];
    if (HasAbstractMonad(input)) {
      continue;
    }
    if (IsValueNode<RefKey>(input)) {
      // Resolve the RefKey to its unique parameter and remember which input
      // slot of `node` refers to it.
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
      g_RefMap[parameters[0]] = node_pair;
      MS_LOG(INFO) << "Find parameter by ref key node" << node_pair.first;
      input_shapes = GetRefKeyNodeShape(input, func_graph);
      input_symbols = StaticShapesToSymbols(input_shapes);  // now the parameter can only be static shape
    } else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
               (IsValueSequence(input) &&
                (inputs_size == min_size || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)))) {
      if (IsDynamicShapeInput(node, input)) {
        MS_LOG(INFO) << "may be dynamic shape, no need to get input's shape, the node is " << node->ToString();
        continue;
      }

      // For a Shape op, use the shape of what Shape is applied to.
      if (IsPrimitiveCNode(input, prim::kPrimShape)) {
        input_shapes = GetNodeShape(input->cast<CNodePtr>()->input(1));
        input_symbols = GetNodeSymbol(input->cast<CNodePtr>()->input(1));
      } else {
        input_shapes = GetNodeShape(input);
        input_symbols = GetNodeSymbol(input);
      }
      if ((input->abstract()->isa<abstract::AbstractSequence>() || IsValueSequence(input))) {
        need_create_shape_list = true;
      }
    } else if (IsValueSequence(input)) {
      // Value sequence not covered by the branch above (multi-input node
      // whose primitive is not in INPUT_IS_TUPLE_OR_LIST_OPS).
      auto temp_input_node = input;
      if (IsPrimitiveCNode(input, prim::kPrimShape)) {
        temp_input_node = input->cast<CNodePtr>()->input(1);
      }
      need_create_shape_list = true;
      input_shapes = GetNodeShape(temp_input_node);
      input_symbols = GetNodeSymbol(temp_input_node);
    } else {
      continue;
    }
    // For normal shape
    input_new_shapes = TransferShapesToNewShapes(input_shapes, need_create_shape_list)[0];
    need_create_shape_list = false;
    shape_inputs.emplace_back(input_new_shapes);
    symbol_inputs.push_back(input_symbols[0]);
  }
  shape_all.push_back(shape_inputs);
  symbol_all.push_back(symbol_inputs);
  // extract out shape
  shape_outputs = GetNodeNewShape(node);
  symbol_outputs = GetNodeSymbol(node);
  shape_all.push_back(shape_outputs);
  symbol_all.push_back(symbol_outputs);

  return std::make_pair(shape_all, symbol_all);
}
1366 
// Extracts the (flat) input/output shapes and symbolic shapes of `node`.
// Returns {shape_all, symbol_all}; each vector has exactly two entries:
// index 0 holds the per-input data, index 1 holds the output data.
// Ops whose single operand is a tuple/list take the whole multi-shape
// result as their input shapes (see the size!=1 branch below).
std::pair<std::vector<Shapes>, std::vector<Symbols>> ExtractShapeAndSymbol(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shape_inputs;
  Shapes shape_outputs;
  Symbols symbol_inputs;
  Symbols symbol_outputs;
  std::vector<Shapes> shape_all;
  std::vector<Symbols> symbol_all;
  std::vector<AnfNodePtr> all_inputs = node->inputs();

  const int min_size = 2;  // [prim, single real input]
  size_t inputs_size = all_inputs.size();
  // Input 0 is the primitive; real operands start at index 1.
  for (size_t i = 1; i < inputs_size; ++i) {
    Shapes input_shapes;
    Symbols input_symbols;
    AnfNodePtr input = all_inputs[i];
    if (HasAbstractMonad(input)) {
      // Monad inputs (IO/U) carry no shape information.
      continue;
    }
    if (IsValueNode<RefKey>(input)) {
      // RefKey stands in for a parameter: resolve it, and record the use
      // site in the global ref map for later parameter handling.
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
      g_RefMap[parameters[0]] = node_pair;
      MS_LOG(INFO) << "Find parameter by ref key node" << node_pair.first;
      input_shapes = GetRefKeyNodeShape(input, func_graph);
      input_symbols = StaticShapesToSymbols(input_shapes);  // now the parameter can only be static shape
    } else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
               (IsValueSequence(input) &&
                (inputs_size == min_size || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)))) {
      if (IsDynamicShapeInput(node, input)) {
        MS_LOG(INFO) << "may be dynamic shape, no need to get input's shape, the node is " << node->ToString();
        continue;
      }

      if (IsPrimitiveCNode(input, prim::kPrimShape)) {
        // For a Shape op, use the shape of the tensor it inspects.
        input_shapes = GetNodeShape(input->cast<CNodePtr>()->input(1));
        input_symbols = GetNodeSymbol(input->cast<CNodePtr>()->input(1));
      } else {
        input_shapes = GetNodeShape(input);
        input_symbols = GetNodeSymbol(input);
      }
    } else {
      continue;
    }
    if (input_shapes.size() != 1) {
      // Multi-shape result: only legal when the op takes a single
      // tuple/list operand; then it becomes the whole input shape set.
      if (inputs_size == min_size || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) {
        shape_inputs = input_shapes;
        symbol_inputs = input_symbols;
        break;
      } else {
        MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
      }
    }
    shape_inputs.push_back(input_shapes[0]);
    symbol_inputs.push_back(input_symbols[0]);
  }
  shape_all.push_back(shape_inputs);
  symbol_all.push_back(symbol_inputs);
  // extract out shape
  shape_outputs = GetNodeShape(node);
  symbol_outputs = GetNodeSymbol(node);
  shape_all.push_back(shape_outputs);
  symbol_all.push_back(symbol_outputs);

  return std::make_pair(shape_all, symbol_all);
}
1438 
ExtractShape(const CNodePtr & node)1439 std::vector<Shapes> ExtractShape(const CNodePtr &node) {
1440   MS_EXCEPTION_IF_NULL(node);
1441   auto shapes_and_symbols = ExtractShapeAndSymbol(node);
1442   return shapes_and_symbols.first;
1443 }
1444 
ExtractNewShape(const CNodePtr & node)1445 std::vector<NewShapes> ExtractNewShape(const CNodePtr &node) {
1446   MS_EXCEPTION_IF_NULL(node);
1447   auto shapes_and_symbols = ExtractNewShapeAndSymbol(node);
1448   return shapes_and_symbols.first;
1449 }
1450 
ExtractRealDivisor(const CNodePtr & node)1451 std::vector<Shapes> ExtractRealDivisor(const CNodePtr &node) {
1452   MS_EXCEPTION_IF_NULL(node);
1453   auto shapes_and_symbols = ExtractShapeAndSymbol(node);
1454   std::vector<Shapes> shapes = shapes_and_symbols.first;
1455   std::vector<Symbols> symbols = shapes_and_symbols.second;
1456   if (shapes.size() != INPUT_OUTPUT_SYMBOLS_SIZE || symbols.size() != INPUT_OUTPUT_SYMBOLS_SIZE) {
1457     MS_LOG(EXCEPTION) << "the size of shapes or symbols must be " << INPUT_OUTPUT_SYMBOLS_SIZE
1458                       << ", but the size of shapes is " << shapes.size() << ", the size of symbols is "
1459                       << symbols.size();
1460   }
1461 
1462   auto inputs_shape = shapes[0];
1463   auto outputs_shape = shapes[1];
1464   auto inputs_symbol = symbols[0];
1465   auto outputs_symbol = symbols[1];
1466 
1467   Shapes in_divisor_symbols;
1468   Shapes out_divisor_symbols;
1469   MS_LOG(DEBUG) << "the node is " << node->ToString() << ", the divisor of inputs is "
1470                 << DivisorOfSymbolsToString(inputs_symbol) << ", the inputs shape is " << ShapesToString(inputs_shape);
1471   in_divisor_symbols = GetRealDivisorSymbols(inputs_shape, inputs_symbol);
1472   out_divisor_symbols = GetRealDivisorSymbols(outputs_shape, outputs_symbol);
1473 
1474   MS_LOG(DEBUG) << "the node is " << node->ToString() << ", the inputs shape is " << ShapesToString(inputs_shape)
1475                 << ", the inputs divisor is " << ShapesToString(in_divisor_symbols);
1476   MS_LOG(DEBUG) << "the node is " << node->ToString() << ", the outputs shape is " << ShapesToString(outputs_shape)
1477                 << ", the outputs divisor is " << ShapesToString(out_divisor_symbols);
1478   return {in_divisor_symbols, out_divisor_symbols};
1479 }
1480 
// Walks down a single chain of inputs starting at `node`. While the current
// node is a cnode accepted by `filter`, steps to the input index the filter
// designates; returns the first node that is not a cnode or is rejected.
AnfNodePtr GetInputNodeWithFilter(const AnfNodePtr &node,
                                  std::function<std::pair<bool, size_t>(const CNodePtr &)> filter) {
  auto current = node;
  while (current->isa<CNode>()) {
    auto current_cnode = current->cast<CNodePtr>();
    auto decision = filter(current_cnode);
    if (!decision.first) {
      break;
    }
    // The filter accepted this cnode: continue from its chosen input.
    current = current_cnode->input(decision.second);
  }
  return current;
}
1500 
GetOutputNodesWithFilter(const AnfNodePtr & node,std::function<bool (const AnfNodePtr &)> filter)1501 std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesWithFilter(const AnfNodePtr &node,
1502                                                                  std::function<bool(const AnfNodePtr &)> filter) {
1503   auto func_graph = node->func_graph();
1504   MS_EXCEPTION_IF_NULL(func_graph);
1505   auto manager = func_graph->manager();
1506   MS_EXCEPTION_IF_NULL(manager);
1507   std::vector<std::pair<AnfNodePtr, int>> res;
1508   std::queue<AnfNodePtr> anf_queue;
1509   anf_queue.push(node);
1510   while (!anf_queue.empty()) {
1511     auto queue_end = anf_queue.front();
1512     anf_queue.pop();
1513     auto user_set = manager->node_users()[queue_end];
1514     for (auto &pair : user_set) {
1515       if (filter(pair.first)) {
1516         anf_queue.push(pair.first);
1517         continue;
1518       }
1519       res.push_back(pair);
1520     }
1521   }
1522   return res;
1523 }
1524 
GetOutputNodesSkipDepend(const AnfNodePtr & node)1525 std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesSkipDepend(const AnfNodePtr &node) {
1526   auto func_graph = node->func_graph();
1527   MS_EXCEPTION_IF_NULL(func_graph);
1528   auto manager = func_graph->manager();
1529   MS_EXCEPTION_IF_NULL(manager);
1530   std::vector<std::pair<AnfNodePtr, int>> res;
1531   std::queue<AnfNodePtr> anf_queue;
1532   anf_queue.push(node);
1533   while (!anf_queue.empty()) {
1534     auto queue_end = anf_queue.front();
1535     anf_queue.pop();
1536     auto user_set = manager->node_users()[queue_end];
1537     for (auto &pair : user_set) {
1538       if (IsPrimitiveCNode(pair.first, prim::kPrimDepend)) {
1539         if (pair.second == 1) {
1540           anf_queue.push(pair.first);
1541         }
1542         continue;
1543       }
1544       res.push_back(pair);
1545     }
1546   }
1547   return res;
1548 }
1549 
CanMergeConcatSlice(const std::pair<std::shared_ptr<AnfNode>,int> & pair,const CNodePtr & concat_cnode,const ShapeVector & concat_output_shape_element,int64_t concat_axis)1550 std::pair<bool, size_t> CanMergeConcatSlice(const std::pair<std::shared_ptr<AnfNode>, int> &pair,
1551                                             const CNodePtr &concat_cnode,
1552                                             const ShapeVector &concat_output_shape_element, int64_t concat_axis) {
1553   if (!IsPrimitiveCNode(pair.first, prim::kPrimStridedSlice)) {
1554     return {false, 0};
1555   }
1556   auto slice_cnode = pair.first->cast<CNodePtr>();
1557   MS_LOG(INFO) << "concat slice cnode:" << slice_cnode->fullname_with_scope();
1558   auto begin_value = GetValueNode(slice_cnode->input(2));
1559   auto end_value = GetValueNode(slice_cnode->input(3));
1560   auto strided_value = GetValueNode(slice_cnode->input(4));
1561   if (!begin_value || !end_value || !strided_value) {
1562     return {false, 0};
1563   }
1564   auto begin = GetValue<std::vector<int64_t>>(begin_value);
1565   auto end = GetValue<std::vector<int64_t>>(end_value);
1566   auto strided = GetValue<std::vector<int64_t>>(strided_value);
1567   if (!std::all_of(strided.begin(), strided.end(), [](auto s) { return s == 1; })) {
1568     return {false, 0};
1569   }
1570   if (!IsPrimitiveCNode(concat_cnode->input(1), prim::kPrimMakeTuple)) {
1571     return {false, 0};
1572   }
1573   auto concat_input_node = concat_cnode->input(1)->cast<CNodePtr>();
1574   auto concat_input_size = concat_input_node->size();
1575   bool can_merge = false;
1576   size_t concat_input_index = 0;
1577   for (size_t i = 0; i < begin.size(); ++i) {
1578     int64_t slice_len = (end[i] - begin[i]);
1579     if (i == size_t(concat_axis)) {
1580       int64_t slice_index = begin[i] / slice_len;
1581       if (slice_len == concat_output_shape_element[i] || size_t(slice_index + 1) >= concat_input_size) {
1582         can_merge = false;
1583         break;
1584       }
1585       concat_input_index = size_t(slice_index + 1);
1586       can_merge = true;
1587     } else if (slice_len != concat_output_shape_element[i]) {
1588       can_merge = false;
1589       break;
1590     }
1591   }
1592   return {can_merge, concat_input_index};
1593 }
1594 
UpdateUpdateStateForMergeConcatSlice(const FuncGraphManagerPtr & manager,const std::vector<std::pair<AnfNodePtr,int>> & update_list,const CNodePtr & tuple_get_item_node)1595 void UpdateUpdateStateForMergeConcatSlice(const FuncGraphManagerPtr &manager,
1596                                           const std::vector<std::pair<AnfNodePtr, int>> &update_list,
1597                                           const CNodePtr &tuple_get_item_node) {
1598   for (const auto &ups_pair : update_list) {
1599     manager->SetEdge(ups_pair.first, ups_pair.second, tuple_get_item_node);
1600   }
1601 }
1602 
// Handles the Concat+Slice merge when the Concat result is returned from a
// sub-graph: the Concat's MakeTuple elements are threaded through the call
// boundary as extra parameters so the StridedSlice users inside the caller's
// callee graph can pick a single branch directly.
// `pair` is the (Return node, index) use of the Concat inside its graph.
// Returns false only when the defining graph has multiple call sites;
// NOTE(review): it returns true even if no call-site rewrite actually
// happened inside the loop — confirm callers tolerate that.
bool HandleFuncConcatSlice(const FuncGraphManagerPtr &manager, const std::pair<std::shared_ptr<AnfNode>, int> &pair,
                           const CNodePtr &concat_cnode, const ShapeVector &concat_output_shape_element,
                           int64_t concat_axis) {
  auto fg = pair.first->func_graph();
  auto fg_map = fg->func_graph_cnodes_index();
  if (fg_map.size() > 1) {
    // Multiple call sites: rewriting the graph signature is not safe.
    return false;
  }
  for (auto &fg_use : fg_map) {
    // Only direct calls (the graph appears as input 0 of the call cnode).
    if (!fg_use.first->first->isa<CNode>() || fg_use.first->second > 0) {
      continue;
    }
    auto call_cnode = fg_use.first->first->cast<CNodePtr>();
    auto func_users = manager->node_users()[call_cnode];
    std::vector<std::pair<AnfNodePtr, int>> update_list;
    size_t func_users_size = 0;
    std::pair<AnfNodePtr, int> fg_users;
    // Partition the call's users: UpdateState users are fixed up later;
    // exactly one non-UpdateState user (another graph call) is expected.
    for (auto &cur_fg_users : func_users) {
      if (IsPrimitiveCNode(cur_fg_users.first, prim::kPrimUpdateState)) {
        update_list.push_back(cur_fg_users);
        continue;
      }
      ++func_users_size;
      fg_users = cur_fg_users;
    }

    if (func_users_size > 1) {
      continue;
    }
    auto func_node_users = FuncGraphNodeUsers(fg_users);
    if (func_node_users.empty()) {
      continue;
    }
    // Decide per user whether the Concat+Slice fold applies; bail out of
    // this call site if no user can be folded at all.
    bool have_can_merge = false;
    std::vector<std::pair<bool, size_t>> input_index;
    for (const auto &new_pair : func_node_users) {
      auto can_merge = CanMergeConcatSlice(new_pair, concat_cnode, concat_output_shape_element, concat_axis);
      input_index.push_back(can_merge);
      if (can_merge.first) {
        have_can_merge = true;
      }
    }
    if (!have_can_merge) {
      continue;
    }
    // maketuple->Return
    auto concat_input_node = concat_cnode->input(1)->cast<CNodePtr>();
    manager->SetEdge(pair.first, pair.second, concat_input_node);
    // call -> tuplegetitem -> call
    auto user_func_graph = GetValueNode<FuncGraphPtr>(fg_users.first->cast<CNodePtr>()->input(0));
    auto user_graph_parameters = user_func_graph->parameters();
    auto origin_parameter = user_graph_parameters[fg_users.second - 1];
    auto new_user_graph_parameters(user_graph_parameters);
    // Remove the parameter that used to receive the whole Concat result ...
    new_user_graph_parameters.erase(new_user_graph_parameters.begin() + fg_users.second - 1);
    auto fg_users_inputs_all(fg_users.first->cast<CNodePtr>()->inputs());
    // ... and the matching call argument; both are re-expanded below.
    fg_users_inputs_all.erase(fg_users_inputs_all.begin() + fg_users.second);
    // New concat CNode in user_func_graph
    std::vector<AnfNodePtr> new_concat_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
    std::vector<AbstractBasePtr> new_maketuple_abstracts;
    bool updated_update_state = false;
    // For every MakeTuple element, thread a TupleGetItem(call, i) through the
    // call boundary as a fresh parameter of the user graph.
    for (size_t i = 0; i < concat_input_node->size() - 1; ++i) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), call_cnode,
                                                    ValuePtrToAnfNodePtr(MakeValue<int64_t>(i))};
      auto tuple_get_item_node = call_cnode->func_graph()->NewCNode(tuple_get_item_inputs);
      if (!updated_update_state) {
        // Rewire UpdateState users once, onto the first TupleGetItem.
        UpdateUpdateStateForMergeConcatSlice(manager, update_list, tuple_get_item_node);
        updated_update_state = true;
      }
      // replace fg_users->inputs(fg_users.second) to a list fg_users->inputs(fg_users.second+i)
      fg_users_inputs_all.insert(fg_users_inputs_all.begin() + fg_users.second + i, tuple_get_item_node);
      auto new_parameter = user_func_graph->add_parameter();
      new_parameter->set_abstract(concat_input_node->input(i + 1)->abstract()->Clone());
      new_maketuple_abstracts.push_back(concat_input_node->input(i + 1)->abstract()->Clone());
      new_user_graph_parameters.insert(new_user_graph_parameters.begin() + fg_users.second - 1 + i, new_parameter);
      new_concat_maketuple_inputs.push_back(new_parameter);
    }
    user_func_graph->set_parameters(new_user_graph_parameters);
    auto user_func_graph_return_cnode = user_func_graph->get_return();
    auto return_input_cnode = user_func_graph_return_cnode->input(kIndex1);
    auto new_call_cnode = fg_users.first->func_graph()->NewCNode(fg_users_inputs_all);
    new_call_cnode->set_abstract(return_input_cnode->abstract()->Clone());
    manager->Replace(fg_users.first, new_call_cnode);
    // Handle user_func_graph slice cnode
    for (size_t j = 0; j < func_node_users.size(); ++j) {
      auto new_pair = func_node_users[j];
      if (!input_index[j].first) {
        // Non-foldable user: rebuild an equivalent Concat inside the user
        // graph from the threaded parameters.
        auto new_maketuple_cnode = user_func_graph->NewCNode(new_concat_maketuple_inputs);
        new_maketuple_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(new_maketuple_abstracts));
        auto old_concat_prim = GetCNodePrimitive(concat_cnode);
        std::vector<AnfNodePtr> new_concat_inputs{NewValueNode(old_concat_prim->Clone()), new_maketuple_cnode,
                                                  NewValueNode(MakeValue<int64_t>(concat_axis))};
        auto new_concat = user_func_graph->NewCNode(new_concat_inputs);
        new_concat->set_abstract(concat_cnode->abstract()->Clone());
        auto new_concat_prim = GetCNodePrimitive(new_concat);
        // The clone must not inherit interleaving bookkeeping from the
        // original Concat.
        if (new_concat_prim->HasAttr("fine_grained_interleaved_index")) {
          new_concat_prim->EraseAttr("fine_grained_interleaved_index");
        }
        manager->SetEdge(new_pair.first, new_pair.second, new_concat);
        continue;
      }
      // Foldable user: replace the StridedSlice with the parameter holding
      // the selected MakeTuple branch.
      manager->Replace(new_pair.first, user_func_graph->parameters()[fg_users.second - 2 + input_index[j].second]);
    }
  }
  return true;
}
1708 
MergeConcatSlice(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)1709 bool MergeConcatSlice(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
1710   bool merged = false;
1711   for (const auto &node : all_nodes) {
1712     if (!IsPrimitiveCNode(node, prim::kPrimConcat)) {
1713       continue;
1714     }
1715     auto concat_cnode = node->cast<CNodePtr>();
1716     MS_EXCEPTION_IF_NULL(concat_cnode->abstract());
1717     auto concat_output_shape = concat_cnode->abstract()->BuildShape();
1718     MS_EXCEPTION_IF_NULL(concat_output_shape);
1719     MS_EXCEPTION_IF_NULL(concat_output_shape->cast<abstract::ShapePtr>());
1720     auto concat_output_shape_element = concat_output_shape->cast<abstract::ShapePtr>()->shape();
1721     auto axis_value_node = concat_cnode->input(kIndex2);
1722     auto axis_value = GetValueNode(axis_value_node);
1723     auto concat_axis = GetValue<int64_t>(axis_value);
1724     auto next_nodes = GetOutputNodesSkipDepend(node);
1725     for (const auto &pair : next_nodes) {
1726       if (IsPrimitiveCNode(pair.first, prim::kPrimReturn) && next_nodes.size() == 1) {
1727         merged = HandleFuncConcatSlice(manager, pair, concat_cnode, concat_output_shape_element, concat_axis);
1728         continue;
1729       }
1730       auto can_merge = CanMergeConcatSlice(pair, concat_cnode, concat_output_shape_element, concat_axis);
1731       if (!can_merge.first) {
1732         continue;
1733       }
1734       auto concat_input_node = concat_cnode->input(1)->cast<CNodePtr>();
1735       auto concat_real_input_node = concat_input_node->input(can_merge.second);
1736       manager->Replace(pair.first->cast<CNodePtr>(), concat_real_input_node);
1737       merged = true;
1738     }
1739   }
1740   return merged;
1741 }
1742 
// Clones `micro_mirror` into a new MirrorMicroStep cnode that mirrors
// `micro_mirror_new_input` instead, preserving the group/dev-num/mean-flag
// attributes as well as all node attrs and primal attrs of the original.
AnfNodePtr NewMicroMirrorPrimByMicroMirror(const FuncGraphPtr &func_graph, const CNodePtr &micro_mirror,
                                           const AnfNodePtr &micro_mirror_new_input) {
  auto prim_origin = GetCNodePrimitive(micro_mirror);
  MS_EXCEPTION_IF_NULL(prim_origin);  // guard: micro_mirror must be a primitive cnode
  Attr attr0 = std::make_pair(GROUP, prim_origin->GetAttr(GROUP));
  Attr attr1 = std::make_pair(DEV_NUM, prim_origin->GetAttr(DEV_NUM));
  Attr attr2 = std::make_pair(MEAN_FLAG, prim_origin->GetAttr(MEAN_FLAG));
  OperatorAttrs operator_attrs;
  operator_attrs.push_back(attr0);
  operator_attrs.push_back(attr1);
  operator_attrs.push_back(attr2);
  ValuePtr pyop_instance = CreateOpInstance(operator_attrs, MIRROR_MICRO_STEP_OPERATOR, prim_origin->instance_name());
  MS_EXCEPTION_IF_NULL(pyop_instance);
  std::vector<AnfNodePtr> mirror_inputs{NewValueNode(pyop_instance), micro_mirror_new_input,
                                        micro_mirror->input(kIndex2)};
  auto new_mirror_node = func_graph->NewCNode(mirror_inputs);
  auto prim = GetCNodePrimitive(new_mirror_node);
  MS_EXCEPTION_IF_NULL(prim);  // guard before copying attrs onto the new prim
  (void)prim->SetAttrs(prim_origin->attrs());
  // Carry over node-level and primal attrs so downstream passes treat the
  // clone exactly like the original mirror.
  new_mirror_node->set_attrs(micro_mirror->attrs());
  new_mirror_node->set_primal_attrs(micro_mirror->primal_attrs());
  return new_mirror_node;
}
1764 
// Tags `node` (or, for Load/Cast, its downstream consumers) with primal
// attrs that associate it with the backward communication op `comm_node`:
// a mirror-user id derived from `param_name` and, when `fusion_id` > 0, a
// fusion key of the form "<backward_comm_name>_<group>_<fusion_id>".
// Nothing is tagged when the comm primitive carries no GROUP attr.
void AddNodeFusionInfo(const CNodePtr &node, const CNodePtr &comm_node, const std::string &backward_comm_name,
                       const std::string &param_name, int32_t fusion_id) {
  auto comm_id = MakeValue<std::string>(param_name);
  comm_node->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
  if (GetValueNode<PrimitivePtr>(comm_node->input(0))->HasAttr(GROUP)) {
    auto comm_group = GetValue<std::string>(GetValueNode<PrimitivePtr>(comm_node->input(0))->GetAttr(GROUP));
    std::string fusion_key = backward_comm_name + "_" + comm_group + "_" + std::to_string(fusion_id);
    if (!IsPrimitiveCNode(node, prim::kPrimLoad) && !IsPrimitiveCNode(node, prim::kPrimCast)) {
      // Plain compute node: tag it directly and stop.
      if (fusion_id > 0) {
        node->AddPrimalAttr(kRelatedFusionKey, MakeValue<std::string>(fusion_key));
        node->AddPrimalAttr(kRelatedNodeId, MakeValue<std::string>(node->UniqueId()));
        node->AddAttr(kRelatedCommNodeId, MakeValue<std::string>(comm_node->UniqueId()));
      }
      node->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
      return;
    }
    // Load/Cast node: propagate the tags to the real consumers, looking
    // through the pass-through / parallel-communication ops listed here.
    auto next_nodes = GetOutputNodesWithFilter(node, [&](const AnfNodePtr &anode) {
      return IsPrimitiveCNode(anode, prim::kPrimLoad) || IsPrimitiveCNode(anode, prim::kPrimCast) ||
             IsPrimitiveCNode(anode, prim::kPrimAllGather) || IsPrimitiveCNode(anode, prim::kPrimMirror) ||
             IsPrimitiveCNode(anode, prim::kPrimMicroStepAllGather) ||
             IsPrimitiveCNode(anode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(anode, prim::kPrimMakeTuple);
    });
    for (auto &pair : next_nodes) {
      if (!IsPrimitiveCNode(pair.first)) {
        continue;
      }
      auto next_cnode = pair.first->cast<CNodePtr>();
      if (fusion_id > 0) {
        next_cnode->AddPrimalAttr(kRelatedFusionKey, MakeValue<std::string>(fusion_key));
        next_cnode->AddPrimalAttr(kRelatedNodeId, MakeValue<std::string>(node->UniqueId()));
        next_cnode->AddAttr(kRelatedCommNodeId, MakeValue<std::string>(comm_node->UniqueId()));
      }
      next_cnode->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
    }
  }
}
1801 
AddNodeMirrorInfo(const CNodePtr & cnode,const std::string & param_name)1802 void AddNodeMirrorInfo(const CNodePtr &cnode, const std::string &param_name) {
1803   auto comm_id = MakeValue<std::string>(param_name);
1804   if (IsParallelCareNode(cnode)) {
1805     cnode->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1806     return;
1807   }
1808   auto next_nodes = GetOutputNodesWithFilter(cnode, [&](const AnfNodePtr &anode) {
1809     return IsPrimitiveCNode(anode, prim::kPrimLoad) || IsPrimitiveCNode(anode, prim::kPrimCast) ||
1810            IsPrimitiveCNode(anode, prim::kPrimAllGather) || IsPrimitiveCNode(anode, prim::kPrimMirror) ||
1811            IsPrimitiveCNode(anode, prim::kPrimMicroStepAllGather) ||
1812            IsPrimitiveCNode(anode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(anode, prim::kPrimMakeTuple);
1813   });
1814   for (auto &pair : next_nodes) {
1815     if (!IsPrimitiveCNode(pair.first)) {
1816       continue;
1817     }
1818     auto next_node = pair.first->cast<CNodePtr>();
1819     next_node->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1820   }
1821 }
1822 
GetMakeTupleValue(const AnfNodePtr & node)1823 static ValuePtr GetMakeTupleValue(const AnfNodePtr &node) {
1824   auto cnode = node->cast<CNodePtr>();
1825   auto &inputs = cnode->inputs();
1826 
1827   std::vector<int64_t> value_list;
1828   for (size_t index = 1; index < inputs.size(); ++index) {
1829     if (inputs[index]->isa<ValueNode>()) {
1830       auto element = GetValueNode(inputs[index]);
1831       if (element->isa<Int64Imm>()) {
1832         int64_t value = element->cast<Int64ImmPtr>()->value();
1833         value_list.push_back(value);
1834         continue;
1835       }
1836     }
1837     value_list.push_back(-1);  // dynamic shape
1838   }
1839 
1840   MS_LOG(INFO) << "the make tuple value is " << value_list;
1841   return MakeValue(value_list);
1842 }
1843 
HasSupportedValueSequence(const CNodePtr & node)1844 bool HasSupportedValueSequence(const CNodePtr &node) {
1845   const auto &all_inputs = node->inputs();
1846   return std::any_of(all_inputs.begin() + 1, all_inputs.end(), [&node](const AnfNodePtr &input) {
1847     bool is_abs_seq = false;
1848     auto abs = input->abstract();
1849     if (abs != nullptr) {
1850       is_abs_seq = abs->isa<abstract::AbstractSequence>();
1851     }
1852     return (is_abs_seq || IsValueSequence(input)) && IsSomePrimitiveList(node, SUPPORT_NEW_SHAPEBASE_OPS);
1853   });
1854 }
1855 
// Builds the OperatorInfo for an operator with tuple/list inputs, using the
// nested NewShapes extraction path. The operator instance is created with
// empty shapes and receives the real nested shapes via set_new_shape.
OperatorInfoPtr CreateOperatorInfoForTupleShape(const CNodePtr &cnode) {
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(INFO) << prim->name() << ": has value sequence input, enter new shape logic.";
  std::pair<std::vector<NewShapes>, std::vector<Symbols>> shapes_and_symbols = ExtractNewShapeAndSymbol(cnode);
  auto shape_list = shapes_and_symbols.first;
  auto symbol_list = shapes_and_symbols.second;
  if (shape_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE || symbol_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE) {
    MS_LOG(EXCEPTION) << "the size of shapes or symbols must be " << INPUT_OUTPUT_SYMBOLS_SIZE
                      << ", but the size of shapes is " << shape_list.size() << ", the size of symbols is "
                      << symbol_list.size();
  }
  auto attrs = prim->attrs();
  // Instantiate with placeholder (empty) flat shapes; the nested shapes are
  // attached below through set_new_shape.
  std::vector<Shapes> temp_shape_list = {{}, {}};
  OperatorInfoPtr op_info = OperatorInstance(prim, attrs, temp_shape_list);
  MS_EXCEPTION_IF_NULL(op_info);

  // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  // ANF graph
  auto &inputs = cnode->inputs();
  std::vector<ValuePtr> input_value;
  for (size_t index = 1; index < inputs.size(); ++index) {
    if (inputs[index]->isa<ValueNode>() || inputs[index]->isa<tensor::Tensor>()) {
      (void)input_value.emplace_back(GetValueNode(inputs[index]));
      continue;
    } else if (IsPrimitiveCNode(inputs[index], prim::kPrimMakeTuple)) {
      // Fold constant MakeTuple elements; dynamic entries become -1.
      auto make_tuple_value = GetMakeTupleValue(inputs[index]);
      (void)input_value.emplace_back(make_tuple_value);
      continue;
    } else if (IsPrimitiveCNode(inputs[index], prim::kPrimShape)) {
      // A Shape op input is replaced by the static shape of its operand.
      auto shape_op_cnode = dyn_cast_ptr<CNode>(inputs[index]);
      auto dst_shape = GetNodeShape(shape_op_cnode->input(1));
      (void)input_value.emplace_back(MakeValue(dst_shape[0]));
      MS_LOG(INFO) << "The prim is " << prim->name() << ", the input index is " << index - 1
                   << ", is Shape op, dst shape is " << dst_shape;
      continue;
    }
    // Non-constant input: keep a placeholder so indices stay aligned.
    (void)input_value.emplace_back(nullptr);
  }
  (*op_info).set_input_value(input_value);
  (*op_info).set_outputs_dtype(cnode->Type());
  (*op_info).set_cnode(cnode);
  (*op_info).set_new_shape(shape_list);
  return op_info;
}
1901 
// Builds the OperatorInfo for `cnode`: extracts input/output shapes and
// symbols, instantiates the parallel operator, collects constant input
// values from the graph, and (for dynamic-shape graphs) attaches the
// input/output divisor constraints. Tuple-shaped operators are delegated to
// CreateOperatorInfoForTupleShape.
OperatorInfoPtr CreateOperatorInfo(const CNodePtr &cnode) {
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  if (HasSupportedValueSequence(cnode)) {
    return CreateOperatorInfoForTupleShape(cnode);
  }
  std::pair<std::vector<Shapes>, std::vector<Symbols>> shapes_and_symbols = ExtractShapeAndSymbol(cnode);
  auto shape_list = shapes_and_symbols.first;
  auto symbol_list = shapes_and_symbols.second;
  if (shape_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE || symbol_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE) {
    MS_LOG(EXCEPTION) << "the size of shapes or symbols must be " << INPUT_OUTPUT_SYMBOLS_SIZE
                      << ", but the size of shapes is " << shape_list.size() << ", the size of symbols is "
                      << symbol_list.size();
  }

  auto attrs = prim->attrs();
  OperatorInfoPtr op_info = OperatorInstance(prim, attrs, shape_list);
  MS_EXCEPTION_IF_NULL(op_info);
  MS_LOG(INFO) << "shape_list.size(): " << shape_list.size();

  // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  // ANF graph
  auto &inputs = cnode->inputs();
  std::vector<ValuePtr> input_value;
  for (size_t index = 1; index < inputs.size(); ++index) {
    if (inputs[index]->isa<ValueNode>() || inputs[index]->isa<tensor::Tensor>()) {
      (void)input_value.emplace_back(GetValueNode(inputs[index]));
      continue;
    } else if (IsPrimitiveCNode(inputs[index], prim::kPrimMakeTuple)) {
      // Fold constant MakeTuple elements; dynamic entries become -1.
      auto make_tuple_value = GetMakeTupleValue(inputs[index]);
      (void)input_value.emplace_back(make_tuple_value);
      continue;
    } else if (IsPrimitiveCNode(inputs[index], prim::kPrimShape)) {
      // A Shape op input is replaced by the static shape of its operand.
      auto shape_op_cnode = dyn_cast_ptr<CNode>(inputs[index]);
      auto dst_shape = GetNodeShape(shape_op_cnode->input(1));
      (void)input_value.emplace_back(MakeValue(dst_shape[0]));
      MS_LOG(INFO) << "The prim is " << prim->name() << ", the input index is " << index - 1
                   << ", is Shape op, dst shape is " << dst_shape;
      continue;
    }
    // Non-constant input: keep a placeholder so indices stay aligned.
    (void)input_value.emplace_back(nullptr);
  }

  (*op_info).set_input_value(input_value);
  (*op_info).set_outputs_dtype(cnode->Type());
  (*op_info).set_cnode(cnode);
  // Dynamic-shape graph: additionally record divisor constraints derived
  // from the symbolic shapes for input and output.
  if (InDynamicGraph(cnode) && IsDynamicShapesList(shape_list)) {
    Shapes in_real_divisors;
    Shapes out_real_divisors;
    in_real_divisors = GetRealDivisorSymbols(shape_list[INPUT_SYMBOLS_INDEX], symbol_list[INPUT_SYMBOLS_INDEX]);
    out_real_divisors = GetRealDivisorSymbols(shape_list[OUTPUT_SYMBOLS_INDEX], symbol_list[OUTPUT_SYMBOLS_INDEX]);
    (*op_info).set_dynamic_shape_flag(True);
    (*op_info).set_inputs_divisor(in_real_divisors);
    (*op_info).set_outputs_divisor(out_real_divisors);
    MS_LOG(DEBUG) << (*op_info).name() << ": inputs-shape: " << ShapesToString(shape_list[0])
                  << ", inputs_d_symbol: " << ShapesToString(in_real_divisors);
    MS_LOG(DEBUG) << (*op_info).name() << ": outputs-shape: " << ShapesToString(shape_list[1])
                  << ", outputs_d_symbol: " << ShapesToString(out_real_divisors);
  }
  return op_info;
}
1963 
ExtendInputArgsAbstractShape(const AbstractBasePtr & args_abstract_item,size_t index)1964 void ExtendInputArgsAbstractShape(const AbstractBasePtr &args_abstract_item, size_t index) {
1965   auto args_abstract_item_shape = args_abstract_item->BuildShape();
1966   auto shape_ptr = dyn_cast<abstract::Shape>(args_abstract_item_shape);
1967   if (shape_ptr == nullptr) {
1968     MS_LOG(WARNING) << "The input " << index << " is not a tensor.";
1969     return;
1970   }
1971   auto shape_value = parallel::ToFullShape(shape_ptr->shape(), index);
1972   auto new_shape_item = std::make_shared<abstract::Shape>(shape_value);
1973   args_abstract_item->set_shape(new_shape_item);
1974 }
1975 
// Restores the full (global) shape of input 'index' from its per-device shape.
// Without a configured dataset strategy: dim0 is multiplied by the
// data-parallel world size (device_num / pipeline stages) unless full_batch
// is set. With a dataset strategy: every dim is multiplied by its shard
// number. Dynamic dims (value <= 0) are left unchanged in both paths.
ShapeVector ToFullShape(const ShapeVector &input_shape, size_t index) {
  if (input_shape.empty()) {
    return input_shape;
  }
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  if (ParallelContext::GetInstance()->dataset_strategy().empty()) {
    auto shape_value = input_shape;
    if (!parallel::ParallelContext::GetInstance()->full_batch()) {
      auto comm_info = parallel::GetCommInfo();
      // Data-parallel size = all devices divided among pipeline stages.
      auto world_rank_size = comm_info.device_num / ParallelContext::GetInstance()->pipeline_stage_split_num();
      if (shape_value[0] > 0) {
        shape_value[0] = shape_value[0] * SizeToLong(world_rank_size);  // only for static shape
      }
    }
    return shape_value;
  }
  auto dataset_strategy = ParallelContext::GetInstance()->dataset_strategy();
  if (index >= dataset_strategy.size()) {
    MS_LOG(EXCEPTION) << "The input shapes size is not equal to dataset strategy size " << dataset_strategy.size();
  }
  auto dataset_strategy_item = dataset_strategy[index];
  if (input_shape.size() != dataset_strategy_item.size()) {
    MS_LOG(EXCEPTION) << "The input_shapes[" << index << "]'s size" << input_shape.size()
                      << " is not equal to dataset_strategy[" << index << "]'s size " << dataset_strategy_item.size();
  }
  ShapeVector shape_value;
  for (size_t i = 0; i < dataset_strategy_item.size(); ++i) {
    if (input_shape[i] > 0) {
      shape_value.push_back(input_shape[i] * dataset_strategy_item[i]);
    } else {
      shape_value.push_back(input_shape[i]);  // dynamic shape, shape is still -1
    }
  }
  return shape_value;
}
2011 
// Collects communication info: device num, global rank, world group name and
// communication backend (HCCL for Ascend/Davinci, NCCL for GPU). When
// device_num / global_rank were not set in the parallel context they are
// queried from the communication layer, and the resolved global rank is
// written back into the context.
CommInfo GetCommInfo() {
  int64_t device_num = ParallelContext::GetInstance()->device_num();
  int64_t global_rank = ParallelContext::GetInstance()->global_rank();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  std::string world_group;
  std::string communication_backend;
  if (backend == kAscendDevice || backend == kDavinciDevice) {
    world_group = HCCL_WORLD_GROUP;
    communication_backend = HCCL_BACKEND;
  } else if (backend == kGPUDevice) {
    world_group = NCCL_WORLD_GROUP;
    communication_backend = NCCL_BACKEND;
  } else {
    MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend
                      << " for semi_auto_parallel/auto_parallel mode,"
                         " currently only support Ascend/GPU backend.";
  }
  uint32_t world_rank_size = 0;
  if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
    MS_LOG(EXCEPTION) << "Get rank size failed";
  }

  // User did not set device_num: trust the communication layer's rank size.
  if (!ParallelContext::GetInstance()->device_num_is_set()) {
    device_num = UintToInt(world_rank_size);
    MS_LOG(INFO) << "Get device num from communication model, the device num is  " << device_num;
  }
#if (!defined(_WIN32) && !defined(__APPLE__) && !(defined(ENABLE_TESTCASES) || defined(ENABLE_TEST)))
  // NOTE(review): 'world_rank_size != device_num' compares uint32_t with
  // int64_t via implicit conversion; fine for realistic device counts.
  if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
      !ParallelContext::GetInstance()->hccl_test_available()) {
    // hccl_test_available is used when we compile graphs in real ascend card environment, but with hccl_test.
    MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consist with "
                      << world_rank_size << " devices you have"
                      << ". Please check your rank_table file(for Ascend) or host file(for GPU).";
  }
#endif
  uint32_t rank_id = 0;
  // Resolve the global rank lazily and cache it back into the context.
  if (!ParallelContext::GetInstance()->global_rank_is_set()) {
    if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
      MS_LOG(EXCEPTION) << "Get rank id failed";
    }
    global_rank = UintToInt(rank_id);
    ParallelContext::GetInstance()->set_global_rank(global_rank);
    MS_LOG(INFO) << "Get global rank from communication model, the global rank is  " << global_rank;
  }
  CommInfo comm_info{device_num, global_rank, world_group, communication_backend};
  return comm_info;
}
2061 
IsPynativeParallel()2062 bool IsPynativeParallel() {
2063   auto parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2064   auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
2065   return (execution_mode == kPynativeMode) && (parallel_mode == kSemiAutoParallel || parallel_mode == kAutoParallel);
2066 }
2067 
IsAutoParallelCareGraph(const FuncGraphPtr & func_graph)2068 bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) {
2069   // compile graph order:
2070   // 1, ParallelParameterContextRestoreShape
2071   // 2, PipelineSplit: insert virtual dataset
2072   // 3, StepAutoParallel
2073   // 4, StepParallel
2074   // if IsParallel() is true, it maybe has some graphs that we now care, so need to check
2075   // 'sharded' or 'has_shard' flag
2076   MS_EXCEPTION_IF_NULL(func_graph);
2077   if (func_graph->has_flag(kSkipAutoParallelCompile)) {
2078     return false;
2079   }
2080 
2081   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2082   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2083   if (parallel_mode != kAutoParallel && parallel_mode != kSemiAutoParallel) {
2084     return false;
2085   }
2086 
2087   if (IsPynativeParallel() && !func_graph->has_flag(kHasShard) && !(func_graph->has_flag(kSharded))) {
2088     return false;
2089   }
2090   return true;
2091 }
2092 
// If *cnode is a call to a sub-FuncGraph, rewrites *cnode in place to the real
// producer inside that graph: Depend wrappers around the graph output are
// skipped, and for a MakeTuple output the element selected by 'out_index' is
// followed. No-op when *cnode is not a FuncGraph call.
void FindPreNodeCrossFuncGraph(CNodePtr *cnode, int64_t out_index) {
  if (IsValueNode<FuncGraph>((*cnode)->input(0))) {
    auto graph = GetValueNode<FuncGraphPtr>((*cnode)->input(0));
    auto output = graph->output();
    MS_EXCEPTION_IF_NULL(output);
    // Depend(input, monad/state): the real value is the first input.
    while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
      auto output_cnode = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      output = output_cnode->input(1);
    }
    // MakeTuple output: select the element matching the requested out_index
    // (+1 skips the primitive input).
    while (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
      auto make_tuple_cnode = output->cast<CNodePtr>();
      output = make_tuple_cnode->input(out_index + 1);
    }
    *cnode = output->cast<CNodePtr>();
  }
}
2110 
// If 'input' is a formal parameter of node's func_graph, returns the actual
// argument bound to that parameter at a call site of the graph found in
// 'all_nodes'; otherwise (not a parameter, or no call site found) returns
// 'input' unchanged.
AnfNodePtr FindRealInputByFormalParameter(const CNodePtr &node, const AnfNodePtr &input,
                                          const std::vector<AnfNodePtr> &all_nodes) {
  auto graph = node->func_graph();
  auto params = graph->parameters();
  auto pos = std::find(params.begin(), params.end(), input);
  if (pos == params.end()) {
    return input;
  }
  int64_t param_index = pos - params.begin();
  for (const auto &candidate : all_nodes) {
    if (!candidate->isa<CNode>()) {
      continue;
    }
    auto caller = candidate->cast<CNodePtr>();
    // A cnode whose first input is this func_graph is a call site; its
    // (param_index + 1)-th input is the actual argument.
    if (IsValueNode<FuncGraph>(caller->input(0)) && GetValueNode<FuncGraphPtr>(caller->input(0)) == graph) {
      return caller->input(param_index + 1);
    }
  }
  return input;
}
2136 
CrossInterNode(CNodePtr * prev_cnode,ValueNodePtr * prev_prim_anf_node,PrimitivePtr * prev_prim)2137 bool CrossInterNode(CNodePtr *prev_cnode, ValueNodePtr *prev_prim_anf_node, PrimitivePtr *prev_prim) {
2138   if ((*prev_cnode == nullptr) ||
2139       !(IsValueNode<Primitive>((*prev_cnode)->input(0)) || IsValueNode<FuncGraph>((*prev_cnode)->input(0)))) {
2140     return true;
2141   }
2142   if (!IsValueNode<FuncGraph>((*prev_cnode)->input(0))) {
2143     *prev_prim_anf_node = (*prev_cnode)->input(0)->cast<ValueNodePtr>();
2144     *prev_prim = (*prev_prim_anf_node)->value()->cast<PrimitivePtr>();
2145   }
2146   return false;
2147 }
2148 
IsCarePrevCNode(const CNodePtr & prev_cnode,const PrimitivePtr & prev_prim)2149 bool IsCarePrevCNode(const CNodePtr &prev_cnode, const PrimitivePtr &prev_prim) {
2150   return (IsValueNode<FuncGraph>(prev_cnode->input(0))) || (prev_prim->name() == kTupleGetItemOpName) ||
2151          (prev_prim->name() == kDependOpName) || (prev_prim->name() == kMakeListOpName) ||
2152          (prev_prim->name() == kLoadOpName) || (prev_prim->name() == kMakeTupleOpName) ||
2153          (prev_prim->name() == kShapeOpName) || IsAutoParallelCareNode(prev_cnode);
2154 }
2155 
IsCrossedCNode(std::string prev_prim_name)2156 bool IsCrossedCNode(std::string prev_prim_name) {
2157   const std::set<std::string> crossed_cnode_list = {kDependOpName, kLoadOpName, kShapeOpName};
2158   return crossed_cnode_list.find(prev_prim_name) != crossed_cnode_list.end();
2159 }
2160 
// Needed by rec_parser
// Returns the UniqueIds of this node's real producers, preceded by the node's
// own UniqueId. For each input, transparent wrappers (tuple_getitem, depend,
// load, shape, make_tuple/list) and sub-graph calls are walked through until a
// parallel-care node (its UniqueId is recorded) or an uncrossable node (the
// current node's UniqueId is recorded) is reached.
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node, const std::vector<AnfNodePtr> &all_nodes) {
  std::vector<std::string> name_inputs;
  std::vector<AnfNodePtr> all_inputs = node->inputs();
  std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};

  std::string node_id = node->UniqueId();
  name_inputs.push_back(node_id);
  for (auto &input : node_inputs) {
    AnfNodePtr prev_node = input;
    if (input->isa<Parameter>()) {
      // Formal parameter: resolve to the actual argument at the call site.
      prev_node = FindRealInputByFormalParameter(node, input, all_nodes);
      if (prev_node->UniqueId() == input->UniqueId()) {
        name_inputs.push_back(input->UniqueId());
        continue;
      }
    }
    auto prev_cnode = prev_node->cast<CNodePtr>();
    PrimitivePtr prev_prim;
    ValueNodePtr prev_prim_anf_node;

    bool is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
    if (is_cross) {
      name_inputs.push_back(input->UniqueId());
      continue;
    }

    // output_index tracks which tuple element we came through (tuple_getitem).
    size_t output_index = 0;
    while (IsCarePrevCNode(prev_cnode, prev_prim)) {
      if (IsValueNode<FuncGraph>(prev_cnode->input(0))) {
        // Sub-graph call: continue from the sub-graph's output node.
        auto graph = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
        auto output = graph->output();
        MS_EXCEPTION_IF_NULL(output);
        prev_cnode = output->cast<CNodePtr>();
        (void)CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
      } else if (IsAutoParallelCareNode(prev_cnode)) {
        // Found a real producer with an OperatorInfo: record and stop.
        name_inputs.push_back(prev_cnode->UniqueId());
        break;
      } else if (prev_prim->name() == kTupleGetItemOpName) {
        // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
        // this 'tuple_getitem'
        output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(INDEX_TWO))));
        prev_node = prev_cnode->input(1);
        prev_cnode = prev_node->cast<CNodePtr>();

        if (prev_cnode != nullptr && common::AnfAlgo::GetCNodeName(prev_cnode) == kTupleGetItemOpName) {
          continue;
        }

        is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
        if (is_cross) {
          name_inputs.push_back(prev_node->UniqueId());
          break;
        }

        // In dynamic shape scenarios, the situation op1->Shape->TupleGetItem->op2 will occur.
        // The incoming operator of op2 should be op1 instead of Shape,
        // so the Shape operator is skipped when looking for the incoming operator.
        if (prev_prim->name() == kShapeOpName) {
          continue;
        }

        if (!IsAutoParallelCareNode(prev_cnode) && !IsValueNode<FuncGraph>(prev_cnode->input(0))) {
          MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
        }
      } else if (prev_prim->name() == kMakeTupleOpName) {
        // Follow the tuple element previously selected by tuple_getitem.
        prev_node = prev_cnode->input(output_index + 1);
        prev_cnode = prev_node->cast<CNodePtr>();
        output_index = 0;
        is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
        if (is_cross) {
          name_inputs.push_back(prev_node->UniqueId());
          break;
        }
      } else if (IsCrossedCNode(prev_prim->name())) {
        // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
        // this 'depend'
        prev_node = prev_cnode->input(1);
        prev_cnode = prev_node->cast<CNodePtr>();
        is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
        if (is_cross) {
          name_inputs.push_back(prev_node->UniqueId());
          break;
        }
      }
    }
  }

  return name_inputs;
}
2251 
GetDistributeOperator(const CNodePtr & node)2252 OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
2253   MS_EXCEPTION_IF_NULL(node);
2254   if (!IsParallelCareNode(node)) {
2255     return nullptr;
2256   }
2257   OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
2258   return distribute_operator;
2259 }
2260 
StrategyFound(const mindspore::HashMap<std::string,ValuePtr> & attrs)2261 bool StrategyFound(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
2262   auto iter = attrs.find(IN_STRATEGY);
2263   return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
2264 }
2265 
AttrFound(const mindspore::HashMap<std::string,ValuePtr> & attrs,const std::string & target)2266 bool AttrFound(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::string &target) {
2267   auto iter = attrs.find(target);
2268   return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
2269 }
2270 
IsCommunicationOp(const PrimitivePtr & prim)2271 bool IsCommunicationOp(const PrimitivePtr &prim) {
2272   MS_EXCEPTION_IF_NULL(prim);
2273   return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end());
2274 }
2275 
ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> & all_nodes)2276 void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
2277   for (auto &node : all_nodes) {
2278     MS_EXCEPTION_IF_NULL(node);
2279     if (!node->isa<CNode>()) {
2280       continue;
2281     }
2282     auto cnode = node->cast<CNodePtr>();
2283     if (!IsValueNode<Primitive>(cnode->input(0))) {
2284       continue;
2285     }
2286     ValueNodePtr prim_value_node = cnode->input(0)->cast<ValueNodePtr>();
2287     MS_EXCEPTION_IF_NULL(prim_value_node);
2288     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_value_node);
2289     MS_EXCEPTION_IF_NULL(prim);
2290 
2291     if (IsCommunicationOp(prim) && cnode->in_forward_flag()) {
2292       MS_EXCEPTION_IF_NULL(prim_value_node->scope());
2293       MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2294       std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2295       MS_LOG(EXCEPTION) << "If the parallel mode is semi_auto_parallel or auto_parallel, the graph can not contain "
2296                            "communication op, the parallel mode is "
2297                         << parallel_mode << ", and the graph has communication op : " << prim->name()
2298                         << ", scope name is " << prim_value_node->scope()->name();
2299     }
2300   }
2301 }
2302 
MirrorOpName()2303 std::string MirrorOpName() {
2304   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
2305   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2306   std::string mirror_op_name;
2307   if (split_stage_num > 1 || grad_accumulation_step > 1) {
2308     mirror_op_name = MIRROR_MICRO_STEP_OPERATOR;
2309   } else {
2310     mirror_op_name = MIRROR_OPERATOR;
2311   }
2312   return mirror_op_name;
2313 }
2314 
CheckStrategyWithTupleInTuple(const std::vector<ValuePtr> & elements)2315 bool CheckStrategyWithTupleInTuple(const std::vector<ValuePtr> &elements) {
2316   bool has_tuple_in_tuple = false;
2317   for (size_t i = 0; i < elements.size(); ++i) {
2318     if (elements[i]->isa<ValueSequence>()) {
2319       auto value_tuple = elements[i]->cast<ValueTuplePtr>();
2320       std::vector<ValuePtr> value_vector = value_tuple->value();
2321       auto local_tuple_in_tuple = std::any_of(value_vector.begin(), value_vector.end(),
2322                                               [](const ValuePtr &value) { return value->isa<ValueSequence>(); });
2323       has_tuple_in_tuple = has_tuple_in_tuple || local_tuple_in_tuple;
2324     } else {
2325       MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
2326     }
2327   }
2328   MS_LOG(INFO) << "CheckStrategyWithTupleInTuple: has_tuple_in_tuple = " << has_tuple_in_tuple << ".";
2329   return has_tuple_in_tuple;
2330 }
2331 
ExtractDimensions(const ValuePtr & stra)2332 NewDimensions ExtractDimensions(const ValuePtr &stra) {
2333   auto value_tuple = stra->cast<ValueTuplePtr>();
2334   std::vector<ValuePtr> value_vector = value_tuple->value();
2335   bool has_tuple_in_tuple = std::any_of(value_vector.begin(), value_vector.end(),
2336                                         [](const ValuePtr &value) { return value->isa<ValueSequence>(); });
2337   if (has_tuple_in_tuple) {
2338     std::vector<NewDimensions> dim;
2339     (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
2340                          [](const ValuePtr &value) { return ExtractDimensions(value); });
2341     return std::make_shared<ShapeList>(dim);
2342   }
2343   Dimensions dim;
2344   (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
2345                        [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
2346   return std::make_shared<ShapeValue>(dim);
2347 }
2348 
ExtractNewStrategy(const std::vector<ValuePtr> & elements,const int64_t & stage_id)2349 StrategyPtr ExtractNewStrategy(const std::vector<ValuePtr> &elements, const int64_t &stage_id) {
2350   NewStrategies strategy;
2351   for (uint64_t index = 0; index < elements.size(); ++index) {
2352     if (elements[index]->isa<ValueSequence>()) {
2353       auto dim = ExtractDimensions(elements[index]);
2354       strategy.emplace_back(dim);
2355     } else {
2356       MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
2357     }
2358   }
2359   if (strategy.empty()) {
2360     MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
2361   }
2362   StrategyPtr strategyPtr = NewStrategy(stage_id, strategy);
2363   return strategyPtr;
2364 }
2365 
// Parses a user-configured strategy value into a StrategyPtr for the current
// stage. 'stra' must be a ValueTuple whose elements are ValueSequences (one
// per operator input); nested tuple-in-tuple elements route to the extended
// strategy path. Returns nullptr when 'stra' is null or not a ValueTuple, and
// a default (null) StrategyPtr when the tuple is empty.
StrategyPtr ExtractStrategy(const ValuePtr &stra) {
  if (stra == nullptr) {
    return nullptr;
  }

  auto var = stra->cast<ValueTuplePtr>();
  if (var == nullptr) {
    return nullptr;
  }

  StrategyPtr strategyPtr;
  int64_t stage_id = g_device_manager->stage_id();

  MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
  if (var->size() > 0) {
    std::vector<ValuePtr> elements = var->value();
    // Tuple-in-tuple strategies (per-dim layouts) use the extended parser.
    if (CheckStrategyWithTupleInTuple(elements)) {
      return ExtractNewStrategy(elements, stage_id);
    }
    Strategies strategy;
    for (uint64_t index = 0; index < elements.size(); ++index) {
      Dimensions dim;
      if (elements[index]->isa<ValueSequence>()) {
        auto value_tuple = elements[index]->cast<ValueTuplePtr>();
        std::vector<ValuePtr> value_vector = value_tuple->value();
        (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
                             [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
        strategy.push_back(dim);
      } else {
        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
      }
    }
    if (strategy.empty()) {
      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
    }
    strategyPtr = NewStrategy(stage_id, strategy);
  }
  return strategyPtr;
}
2405 
GetLayoutFromAttrValue(const ValuePtr & layout_item,std::vector<int64_t> * device_matrix_vector,std::vector<std::vector<int64_t>> * tensor_map_vector,bool * interleaved_parallel)2406 Status GetLayoutFromAttrValue(const ValuePtr &layout_item, std::vector<int64_t> *device_matrix_vector,
2407                               std::vector<std::vector<int64_t>> *tensor_map_vector, bool *interleaved_parallel) {
2408   auto layout_dict_value = layout_item->cast<ValueDictionaryPtr>();
2409   if (!layout_dict_value) {
2410     MS_LOG(ERROR) << "The layout item configured for node is unreasonable";
2411     return FAILED;
2412   }
2413   auto layout_dict = layout_dict_value->value();
2414   ValuePtr device_matrix_value = nullptr;
2415   ValuePtr tensor_map_value = nullptr;
2416   ValuePtr interleaved_parallel_value = nullptr;
2417   for (const auto &value_pair : layout_dict) {
2418     if ((*value_pair.first) == (*MakeValue<std::string>(DEVICE_MATRIX))) {
2419       device_matrix_value = value_pair.second;
2420     }
2421     if ((*value_pair.first) == (*MakeValue<std::string>(TENSOR_MAP))) {
2422       tensor_map_value = value_pair.second;
2423     }
2424     if ((*value_pair.first) == (*MakeValue<std::string>(INTERLEAVED_PARALLEL))) {
2425       interleaved_parallel_value = value_pair.second;
2426     }
2427   }
2428   if (!device_matrix_value || !tensor_map_value || !interleaved_parallel_value) {
2429     MS_LOG(ERROR) << "The layout item configured for node is unreasonable";
2430     return FAILED;
2431   }
2432   *device_matrix_vector = GetValue<std::vector<int64_t>>(device_matrix_value);
2433   *interleaved_parallel = GetValue<bool>(interleaved_parallel_value);
2434   auto tensor_map_value_tuple = tensor_map_value->cast<ValueTuplePtr>();
2435   std::vector<ValuePtr> tensor_map_value_tuple_vector = tensor_map_value_tuple->value();
2436   for (const auto &tensor_map_item : tensor_map_value_tuple_vector) {
2437     if (tensor_map_item->isa<ValueSequence>()) {
2438       auto tensor_map_item_v = GetValue<std::vector<int64_t>>(tensor_map_item);
2439       tensor_map_vector->push_back(tensor_map_item_v);
2440       continue;
2441     }
2442     auto tensor_map_item_i = GetValue<int64_t>(tensor_map_item);
2443     tensor_map_vector->push_back({tensor_map_item_i});
2444   }
2445   return SUCCESS;
2446 }
2447 
ExtractUserConfigLayout(const mindspore::HashMap<std::string,ValuePtr> & prim_attrs,const Shapes & inputs_shape,const Shapes & outputs_shape,std::vector<std::shared_ptr<TensorLayout>> * in_tensor_layouts,std::vector<std::shared_ptr<TensorLayout>> * out_tensor_layouts)2448 Status ExtractUserConfigLayout(const mindspore::HashMap<std::string, ValuePtr> &prim_attrs, const Shapes &inputs_shape,
2449                                const Shapes &outputs_shape,
2450                                std::vector<std::shared_ptr<TensorLayout>> *in_tensor_layouts,
2451                                std::vector<std::shared_ptr<TensorLayout>> *out_tensor_layouts) {
2452   if (prim_attrs.count(IN_LAYOUT) > 0) {
2453     auto layout_value = prim_attrs.at(IN_LAYOUT);
2454     if (!layout_value->isa<ValueSequence>()) {
2455       MS_LOG(ERROR) << "The in_layout configured for node is not a tuple";
2456       return FAILED;
2457     }
2458     auto layout_value_tuple = layout_value->cast<ValueTuplePtr>();
2459     std::vector<ValuePtr> layout_value_vector = layout_value_tuple->value();
2460     if (inputs_shape.size() != layout_value_vector.size()) {
2461       MS_LOG(ERROR) << "The in_layout configured for node is not equal to its input nums";
2462       return FAILED;
2463     }
2464 
2465     for (size_t i = 0; i < layout_value_vector.size(); ++i) {
2466       auto layout_item = layout_value_vector[i];
2467       std::vector<int64_t> device_matrix_vector;
2468       std::vector<std::vector<int64_t>> tensor_map_vector;
2469       bool interleaved_parallel;
2470       if (GetLayoutFromAttrValue(layout_item, &device_matrix_vector, &tensor_map_vector, &interleaved_parallel) !=
2471           SUCCESS) {
2472         return FAILED;
2473       }
2474       auto in_layout = std::make_shared<TensorLayout>();
2475       if (in_layout->InitFromExtendVector(device_matrix_vector, tensor_map_vector, inputs_shape[i],
2476                                           interleaved_parallel) != SUCCESS) {
2477         MS_LOG(ERROR) << "The in_layout configured incorrect, device_matrix:" << device_matrix_vector
2478                       << ", tensor_map:" << tensor_map_vector;
2479         return FAILED;
2480       }
2481       in_tensor_layouts->push_back(in_layout);
2482     }
2483   }
2484   if (prim_attrs.count(OUT_LAYOUT) > 0) {
2485     auto layout_value = prim_attrs.at(OUT_LAYOUT);
2486     if (!layout_value->isa<ValueSequence>()) {
2487       MS_LOG(EXCEPTION) << "The in_layout configured for node is not a tuple";
2488     }
2489     auto layout_value_tuple = layout_value->cast<ValueTuplePtr>();
2490     std::vector<ValuePtr> layout_value_vector = layout_value_tuple->value();
2491     if (outputs_shape.size() != layout_value_vector.size()) {
2492       MS_LOG(EXCEPTION) << "The out_layout configured for node is not equal to its output nums";
2493     }
2494     for (size_t i = 0; i < layout_value_vector.size(); ++i) {
2495       auto layout_item = layout_value_vector[i];
2496       std::vector<int64_t> device_matrix_vector;
2497       std::vector<std::vector<int64_t>> tensor_map_vector;
2498       bool interleaved_parallel;
2499       if (GetLayoutFromAttrValue(layout_item, &device_matrix_vector, &tensor_map_vector, &interleaved_parallel) !=
2500           SUCCESS) {
2501         return FAILED;
2502       }
2503       auto out_layout = std::make_shared<TensorLayout>();
2504       if (out_layout->InitFromExtendVector(device_matrix_vector, tensor_map_vector, outputs_shape[i],
2505                                            interleaved_parallel) != SUCCESS) {
2506         MS_LOG(ERROR) << "The out_layout configured incorrect, device_matrix:" << device_matrix_vector
2507                       << ", tensor_map:" << tensor_map_vector;
2508         return FAILED;
2509       }
2510       out_tensor_layouts->push_back(out_layout);
2511     }
2512   }
2513   return SUCCESS;
2514 }
2515 
IsCohesiveNode(const CNodePtr & cnode)2516 static bool IsCohesiveNode(const CNodePtr &cnode) {
2517   return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
2518          IsPrimitiveCNode(cnode, prim::kPrimDepend) || IsPrimitiveCNode(cnode, prim::kPrimAllGather) ||
2519          IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMirrorMicroStep) ||
2520          IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMirror) ||
2521          IsPrimitiveCNode(cnode, prim::kPrimMirrorMiniStep) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDiv);
2522 }
2523 
NodeParameterName(const CNodePtr & node,int64_t index,size_t curr_depth)2524 ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
2525   if (curr_depth > MAX_RECURSIVE_DEPTH) {
2526     MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
2527                     << MAX_RECURSIVE_DEPTH;
2528     return {};
2529   }
2530   bool only_trainable_params = ParallelContext::GetInstance()->stra_file_only_trainable_params();
2531   std::vector<AnfNodePtr> node_inputs{node->inputs()};
2532   ParameterMap param_names;
2533   for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
2534     int64_t idx = index > i ? index : i;
2535     auto input = node_inputs[LongToSize(i)];
2536     if (input->isa<Parameter>()) {
2537       auto input_parameter = input->cast<ParameterPtr>();
2538       if (input_parameter->has_default() && (!only_trainable_params || ParameterRequireGrad(input_parameter))) {
2539         (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
2540         continue;
2541       }
2542       auto actual_param_node = RefParameterToActualParameter(input_parameter);
2543       if (!actual_param_node) {
2544         continue;
2545       }
2546       auto actual_param = actual_param_node->cast<ParameterPtr>();
2547       if (!only_trainable_params || ParameterRequireGrad(actual_param)) {
2548         (void)param_names.emplace_back(std::make_pair(actual_param->name(), actual_param));
2549       }
2550     } else if (input->isa<CNode>()) {
2551       CNodePtr cnode = input->cast<CNodePtr>();
2552       if (!IsValueNode<Primitive>(cnode->input(0))) {
2553         continue;
2554       }
2555       if (IsCohesiveNode(cnode) && cnode->size() >= 1) {
2556         auto input_param_names = NodeParameterName(cnode, idx, 0);
2557         (void)param_names.insert(param_names.cend(), input_param_names.cbegin(), input_param_names.cend());
2558       }
2559     }
2560   }
2561   return param_names;
2562 }
2563 
ParallelInit(size_t rank_id,const size_t devices)2564 Status ParallelInit(size_t rank_id, const size_t devices) {
2565   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2566   int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2567 
2568   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2569   if (split_stage_num <= 0) {
2570     MS_LOG(ERROR) << "The parameter 'split_stage_num' must be a positive number, but got the value : "
2571                   << split_stage_num;
2572     return FAILED;
2573   }
2574   int64_t device_num;
2575   int64_t global_rank;
2576   std::string backend;
2577   if (devices == 0) {
2578     auto comm_info = GetCommInfo();
2579     device_num = comm_info.device_num;
2580     global_rank = comm_info.global_rank;
2581     backend = comm_info.communication_backend;
2582   } else {
2583     device_num = devices;
2584     global_rank = rank_id;
2585     backend = HCCL_BACKEND;
2586   }
2587 
2588   if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
2589     MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be positive, "
2590                      "but got the value of device_num: "
2591                   << device_num;
2592     return FAILED;
2593   }
2594 
2595   // the device_num maybe get from communication interface
2596   if (device_num % split_stage_num != 0) {
2597     MS_LOG(ERROR) << "The parameter 'device_num' must be divided by 'split_stage_num', but got the device_num : "
2598                   << device_num << "and the split_stage_num : " << split_stage_num;
2599     return FAILED;
2600   }
2601 
2602   int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
2603   if (ParallelContext::GetInstance()->enable_parallel_optimizer() && optimizer_weight_shard_size > 0 &&
2604       device_num < optimizer_weight_shard_size) {
2605     MS_LOG(ERROR) << "When parallel_optimizer is enabled, the optimizer_weight_shard_size "
2606                   << optimizer_weight_shard_size << " should not exceed the device num " << device_num << ".";
2607     return FAILED;
2608   }
2609 
2610   if ((global_rank < 0) || (global_rank >= device_num)) {
2611     MS_LOG(ERROR) << "The parameter 'global_rank' must be  greater than 0 and less equal 'device num', "
2612                      "but got the global_rank : "
2613                   << global_rank << "and the device_num : " << device_num;
2614     return FAILED;
2615   }
2616 
2617   std::vector<int64_t> stages;
2618   for (int i = 0; i < split_stage_num; i++) {
2619     stages.push_back(device_num / split_stage_num);
2620   }
2621 
2622   bool use_rec = (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming);
2623   bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
2624                 (ParallelContext::GetInstance()->sharding_propagation());
2625   if ((split_stage_num > 1) && (parallel_mode == kAutoParallel) && !(use_sp || use_rec)) {
2626     MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << kSemiAutoParallel << " or "
2627                   << kAutoParallel << " with " << kShardingPropagation << " or " << kRecursiveProgramming;
2628     return FAILED;
2629   }
2630 
2631   if (!InitDevice(device_num, global_rank, backend, stages)) {
2632     MS_LOG(ERROR) << "Init device failed";
2633     return FAILED;
2634   }
2635 
2636   MS_LOG(INFO) << "The parallel context: device_num: " << device_num << ", global_rank: "
2637                << global_rank
2638                //               << ", communication_backend: " << comm_info.communication_backend
2639                << ", communication_backend: " << HCCL_BACKEND
2640                << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
2641                << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
2642   return SUCCESS;
2643 }
2644 
// only used for FindCNode
static CNodePtr SkipTrivialNodesMoveDown(const FuncGraphManagerPtr &manager, CNodePtr node) {
  // Walk downwards (towards users) past "trivial" nodes and Load nodes until a
  // non-trivial consumer is reached, and return that consumer.
  // NOTE(review): each step follows the *first* user only and assumes every
  // trivial node has at least one user that is a CNode; a dangling trivial
  // node would dereference begin() of an empty user set -- confirm callers
  // only pass nodes still connected to the graph.
  MS_EXCEPTION_IF_NULL(node);
  while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) {
    node = manager->node_users()[node].begin()->first->cast<CNodePtr>();
  }
  return node;
}
2653 
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
                                    size_t max_depth) {
  // Search the direct users of `anode` for a primitive CNode named `name` that
  // consumes `anode` as its first input and lives in `func_graph`.
  // Returns {true, cnode} when found in the same graph; {false, nullptr}
  // otherwise. When the parallel optimizer is enabled, the search continues
  // through AllGather-like users (recursing with an incremented depth).
  MS_EXCEPTION_IF_NULL(anode);
  MS_EXCEPTION_IF_NULL(anode->func_graph());
  FuncGraphManagerPtr manager = anode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
  }
  AnfNodeIndexSet node_set = manager->node_users()[anode];
  bool result = false;
  CNodePtr cnode_return = nullptr;
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    // Look through trivial/Load wrappers to the real consumer.
    use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    // node_pair.second == 1 restricts matches to users that take `anode` as
    // their first data input.
    if (node_prim->name() == name && node_pair.second == 1) {
      if (use_apply->func_graph() == func_graph) {
        result = true;
        cnode_return = use_apply;
        MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph";
        continue;
      }
      MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph";
    }
    // With the parallel optimizer, an AllGather may sit between `anode` and the
    // target primitive; search the AllGather's users instead.
    if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(use_apply)) {
      return FindCNode(node_pair.first, name, func_graph, max_depth + 1);
    }
  }
  return std::make_pair(result, cnode_return);
}
2694 
SetSharedParameterFlag(const FuncGraphPtr & root,const AnfNodePtr & parameter)2695 void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter) {
2696   MS_EXCEPTION_IF_NULL(root);
2697   MS_EXCEPTION_IF_NULL(parameter);
2698   FuncGraphManagerPtr manager = root->manager();
2699   MS_EXCEPTION_IF_NULL(manager);
2700   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
2701   if (parameter_ptr == nullptr) {
2702     MS_LOG(INFO) << parameter->ToString() << ": cast to ptr failed. it may not be a parameter";
2703     return;
2704   }
2705   auto user_set = manager->node_users()[parameter];
2706   int32_t user_count = 0;
2707   for (auto &param_pair : user_set) {
2708     CNodePtr cnode = param_pair.first->cast<CNodePtr>();
2709     MS_EXCEPTION_IF_NULL(cnode);
2710     if (cnode->in_forward_flag()) {
2711       user_count++;
2712     }
2713   }
2714   if (user_count > 1) {
2715     auto tensor_layout = parameter_ptr->user_data<TensorLayout>();
2716     tensor_layout->set_is_shared_param(true);
2717   }
2718 }
2719 
GenerateBatchParallelStrategy(const OperatorInfoPtr operator_,const PrimitivePtr prim)2720 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
2721   MS_EXCEPTION_IF_NULL(operator_);
2722   MS_EXCEPTION_IF_NULL(prim);
2723   if (!operator_->inputs_shape_new().empty()) {
2724     MS_LOG(EXCEPTION) << "Currently, tuple in tuple input does not support GenerateBatchParallelStrategy, please set "
2725                          "strategy in python side";
2726   }
2727   StrategyPtr strategyPtr;
2728   std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategiesWithCheck();
2729   MS_EXCEPTION_IF_NULL(strategy_v_ptr);
2730   auto stage_id = g_device_manager->stage_id();
2731   strategyPtr = NewStrategy(stage_id, *strategy_v_ptr);
2732   std::vector<ValuePtr> elements;
2733   for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
2734     elements.push_back(MakeValue((*strategy_v_ptr)[i]));
2735   }
2736   ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
2737   // display the strategy generated by batch parallel
2738   auto attrs = prim->attrs();
2739   attrs[GEN_STRATEGY] = strategy;
2740   (void)prim->SetAttrs(attrs);
2741   MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << attrs[GEN_STRATEGY]->ToString();
2742   return strategyPtr;
2743 }
2744 
GenerateStandAloneStrategy(const Shapes & inputs_shape)2745 StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape) {
2746   Strategies strategy_v;
2747   for (size_t i = 0; i != inputs_shape.size(); i++) {
2748     if (inputs_shape[i].empty()) {
2749       MS_LOG(INFO) << "Elements of shapes is empty.";
2750       Dimensions empty_element;
2751       strategy_v.push_back(empty_element);
2752     } else {
2753       Dimensions element(inputs_shape[i].size(), 1);
2754       strategy_v.push_back(element);
2755     }
2756   }
2757   auto stage_id = g_device_manager->stage_id();
2758   auto stra_ptr = NewStrategy(stage_id, strategy_v);
2759   return stra_ptr;
2760 }
2761 
ShapeBasePtr GenerateStra(const ShapeBasePtr &shape) {
  // Recursively mirror the (possibly nested) structure of `shape` with an
  // all-ones strategy: lists map to lists, leaf shapes map to dimension
  // vectors of 1s of the same rank.
  if (shape->is_list()) {
    std::vector<ShapeBasePtr> nested;
    for (size_t idx = 0; idx < shape->size(); ++idx) {
      nested.emplace_back(GenerateStra(shape->GetElement(SizeToLong(idx))));
    }
    return std::make_shared<ShapeList>(nested);
  }
  if (shape->empty()) {
    MS_LOG(INFO) << "Elements of shapes is empty.";
    return std::make_shared<ShapeValue>(Dimensions());
  }
  return std::make_shared<ShapeValue>(Dimensions(shape->size(), 1));
}
2783 
GenerateStandAloneStrategyForNewShapes(const NewShapes & inputs_shape)2784 StrategyPtr GenerateStandAloneStrategyForNewShapes(const NewShapes &inputs_shape) {
2785   NewStrategies strategy_v;
2786   for (size_t i = 0; i != inputs_shape.size(); i++) {
2787     strategy_v.emplace_back(GenerateStra(inputs_shape[i]));
2788   }
2789   auto stage_id = g_device_manager->stage_id();
2790   auto stra_ptr = NewStrategy(stage_id, strategy_v);
2791   return stra_ptr;
2792 }
2793 
IsInsertVirtualOutput(const FuncGraphPtr & root)2794 bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
2795   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2796   auto comm_info = GetCommInfo();
2797   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2798   int64_t per_stage_device_num = comm_info.device_num / split_stage_num;
2799   int64_t current_stage = comm_info.global_rank / per_stage_device_num;
2800   MS_LOG(INFO) << "The current stage is: " << current_stage;
2801   if (!root->has_flag(kTraining) && !ParallelContext::GetInstance()->dataset_strategy().empty()) {
2802     MS_LOG(WARNING) << "In eval/predict net, the output parallel strategy would not follow "
2803                        "the input parallel strategy when using context.set_auto_parallel_context(dataset_strategy)"
2804                        " to configure the input strategy.";
2805   }
2806   return ((!root->has_flag(kTraining) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
2807            current_stage == split_stage_num - 1) ||
2808           IsPynativeParallel());
2809 }
2810 
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair, const int &make_tuple_index) {
  // Fetch the tensor layout of input number `node_pair.second` of the operator
  // wrapped by `node_pair.first`. The index is 1-based (input 0 of a CNode is
  // the primitive), hence the `index - 1` accesses below.
  // `make_tuple_index` selects an element (1-based) when the tensor info of
  // that input is a list; -1 means the input is a plain tensor.
  CNodePtr cnode = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  int64_t index = node_pair.second;
  TensorLayout tensorlayout_in;
  if (distribute_operator->inputs_tensor_info_new().empty()) {
    // Legacy (flat) tensor-info path.
    // NOTE(review): only the upper bound is checked; index is assumed >= 1
    // here (an index of 0 would wrap in LongToSize(index - 1)) -- confirm
    // callers always pass a data-input index.
    if (index > SizeToLong(distribute_operator->inputs_tensor_info().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is  " << (index - 1)
                        << ", the vector size is  " << distribute_operator->inputs_tensor_info().size();
    }
    TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
    tensorlayout_in = tensorinfo_in.tensor_layout();
  } else {
    // Nested (new) tensor-info path: entries may be lists of tensor infos.
    if (index > SizeToLong(distribute_operator->inputs_tensor_info_new().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is  " << (index - 1)
                        << ", the vector size is  " << distribute_operator->inputs_tensor_info_new().size();
    }
    auto tensorinfo_in = distribute_operator->inputs_tensor_info_new()[LongToSize(index - 1)];
    if (tensorinfo_in->is_list() && make_tuple_index != -1) {
      // List input fed through MakeTuple: pick the tuple element's layout.
      auto new_tensorinfo_in = tensorinfo_in->GetElement(make_tuple_index - 1);
      tensorlayout_in = new_tensorinfo_in->GetValue().tensor_layout();
    } else if (!tensorinfo_in->is_list() && make_tuple_index == -1) {
      // Plain tensor input.
      tensorlayout_in = tensorinfo_in->GetValue().tensor_layout();
    } else {
      // List-ness of the info and presence of a MakeTuple must agree.
      MS_LOG(EXCEPTION) << "tensorinfo_in does not match with make_tuple_index: make_tuple_index is "
                        << make_tuple_index << ", node is " << node_pair.first->DebugString();
    }
  }
  return tensorlayout_in;
}
2843 
IsCellReuseForwardGraph(const FuncGraphPtr & graph)2844 bool IsCellReuseForwardGraph(const FuncGraphPtr &graph) { return graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE); }
2845 
FuncGraphPtr GetCellReuseBackwardGraph(const FuncGraphPtr &forward_graph) {
  // Locate the backward graph attached to a cell-reuse forward graph by
  // walking the fixed pattern: Return(input 1) -> MakeTuple(input 2) ->
  // Partial(input 1) -> ValueNode<FuncGraph>. Returns nullptr as soon as the
  // pattern does not match.
  AnfNodePtr cursor = forward_graph->get_return();
  const std::vector<std::pair<PrimitivePtr, int64_t>> walk_path = {
    {prim::kPrimReturn, kIndex1}, {prim::kPrimMakeTuple, kIndex2}, {prim::kPrimPartial, kIndex1}};
  for (const auto &[expected_prim, input_index] : walk_path) {
    auto cnode = cursor->cast<CNodePtr>();
    if (cnode == nullptr || !IsPrimitiveCNode(cnode, expected_prim)) {
      return nullptr;
    }
    if (input_index >= SizeToLong(cnode->size())) {
      return nullptr;
    }
    cursor = cnode->input(input_index);
  }
  return GetValueNode<FuncGraphPtr>(cursor);
}
2863 
mirror_group_list(const TensorLayoutPtr & layout)2864 Shape mirror_group_list(const TensorLayoutPtr &layout) {
2865   int64_t rank = g_device_manager->global_rank();
2866   auto stage_dev_list = g_device_manager->GetDeviceListInThisStage();
2867   DeviceMatrix dev_matrix(rank, stage_dev_list, layout->device_arrangement().array());
2868   RankList group_devices;
2869   if (dev_matrix.GetDevicesByTensorMap(layout->tensor_map().array(), &group_devices) != SUCCESS) {
2870     MS_LOG(EXCEPTION) << "For layout:" << layout->ToString() << ", infer mirror failed";
2871   }
2872   return group_devices;
2873 }
2874 
ChangeAllGatherGroup(const CNodePtr & ag_cnode,const RankList & new_group_ranks)2875 void ChangeAllGatherGroup(const CNodePtr &ag_cnode, const RankList &new_group_ranks) {
2876   Group new_group;
2877   if (g_device_manager->CreateGroup(new_group_ranks, &new_group) != SUCCESS) {
2878     MS_LOG(EXCEPTION) << ": Create communication group failed, the rank_list is: " << new_group_ranks;
2879   }
2880   auto ag_prim = GetCNodePrimitive(ag_cnode);
2881   ag_prim->AddAttr(GROUP, MakeValue(new_group.name()));
2882   ag_prim->AddAttr(GROUP_RANKS, MakeValue(g_device_manager->FindRankListNameByHashName(new_group.name())));
2883   ag_prim->AddAttr(RANK_SIZE, MakeValue<int64_t>(new_group_ranks.size()));
2884 }
2885 
std::vector<CNodePtr> InterleavedReplacedConcatNodes(const std::vector<CNodePtr> &ag_vector) {
  // For each AllGather, look through Split/TupleGetItem/MakeTuple users and
  // collect the single downstream Concat when (a) the AllGather feeds exactly
  // one distinct consumer, (b) that consumer is a Concat, and (c) the Concat
  // was inserted by redistribution (instance name contains REDISTRIBUTION_OP).
  std::vector<CNodePtr> replace_nodes;
  for (const auto &ag : ag_vector) {
    auto ag_next_nodes = GetOutputNodesWithFilter(ag, [&](const AnfNodePtr &anode) {
      return IsPrimitiveCNode(anode, prim::kPrimSplit) || IsPrimitiveCNode(anode, prim::kPrimTupleGetItem) ||
             IsPrimitiveCNode(anode, prim::kPrimMakeTuple);
    });
    // Deduplicate users: the same consumer may appear once per edge.
    std::set<AnfNodePtr> next_nodes_set;
    std::transform(ag_next_nodes.begin(), ag_next_nodes.end(), std::inserter(next_nodes_set, next_nodes_set.begin()),
                   [](auto pair) { return pair.first; });
    if (!(next_nodes_set.size() == kSizeOne && IsPrimitiveCNode(ag_next_nodes.front().first, prim::kPrimConcat))) {
      continue;
    }
    auto concat_cnode = ag_next_nodes.front().first->cast<CNodePtr>();
    // NOTE(review): concat_prim is assumed non-null here since the node just
    // matched IsPrimitiveCNode(kPrimConcat).
    auto concat_prim = GetCNodePrimitive(concat_cnode);
    if (concat_prim->instance_name().find(REDISTRIBUTION_OP) != std::string::npos) {
      replace_nodes.push_back(concat_cnode);
    }
  }
  return replace_nodes;
}
2907 
std::vector<std::vector<CNodePtr>> CreateInterleavedNeedReplaceOpLists(const CNodePtr &virtual_converter_end,
                                                                       const PrimitivePtr &r_prim) {
  // For every interleaved branch feeding `virtual_converter_end` (inputs 1..n),
  // walk the producer chain upwards through input(1) until the matching
  // VirtualConverterBegin is reached, collecting every node that matches
  // `r_prim`. Result: one row per branch, nodes ordered from the branch output
  // towards VirtualConverterBegin.
  std::vector<std::vector<CNodePtr>> need_replace_op_lists;
  for (size_t j = 1; j < virtual_converter_end->size(); ++j) {
    auto current_node = virtual_converter_end->input(j)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(current_node);
    std::vector<CNodePtr> need_replace_op_list;
    while (!IsPrimitiveCNode(current_node, prim::kPrimVirtualConverterBegin)) {
      if (IsPrimitiveCNode(current_node, r_prim)) {
        need_replace_op_list.push_back(current_node);
      }
      // NOTE(review): the walk assumes a single-input chain (always follows
      // input 1); a non-CNode producer trips the null check below.
      current_node = current_node->input(kIndex1)->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(current_node);
    }
    need_replace_op_lists.push_back(need_replace_op_list);
  }
  return need_replace_op_lists;
}
2926 
CNodePtr ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const std::vector<CNodePtr> &ag_vector,
                                             const std::vector<std::vector<int64_t>> &new_group_ranks_vector,
                                             size_t independent_size) {
  // Fuse the per-branch AllGathers of an interleaved pattern into a single
  // Concat. When independent_size == 1 the gathers are fully redundant across
  // branches, so the Concat consumes their *inputs* directly and the gathers
  // disappear; otherwise the gathers are kept (regrouped onto the new,
  // smaller communication groups) and the Concat consumes their outputs.
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple->Clone())};
  std::transform(ag_vector.begin(), ag_vector.end(), std::back_inserter(make_tuple_inputs),
                 [&](auto node) { return independent_size == 1 ? node->input(kIndex1) : node; });
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  auto replace_nodes = InterleavedReplacedConcatNodes(ag_vector);
  bool replace_concat = (!replace_nodes.empty() && independent_size == 1);
  // Reuse the axis of an existing redistribution Concat when one exists;
  // default to axis 0 otherwise.
  AnfNodePtr axis = NewValueNode(MakeValue<int64_t>(0));
  if (replace_concat) {
    axis = replace_nodes.front()->input(kIndex2);
  }
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(prim::kPrimConcat->Clone()), make_tuple, axis};
  auto concat = func_graph->NewCNode(concat_inputs);
  concat->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
  auto manager = func_graph->manager();

  for (size_t i = 0; i < ag_vector.size(); ++i) {
    auto ag = ag_vector[i];
    if (independent_size != 1) {
      // set allgather attrs
      ChangeAllGatherGroup(ag, new_group_ranks_vector[i]);
    }
    if (!replace_concat) {
      // No downstream redistribution Concat: the new Concat takes the place of
      // every AllGather directly.
      (void)manager->Replace(ag, concat);
    }
  }
  if (!replace_concat) {
    return concat;
  }
  // Redirect the users of the existing redistribution Concats to the fused one.
  for (size_t i = 0; i < replace_nodes.size(); ++i) {
    (void)manager->Replace(replace_nodes[i], concat);
  }
  return concat;
}
2963 
void MergeOpBeforeInterleaveSlice(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end) {
  // For interleaved branches that each contain a StridedSlice at the same
  // position, merge the producers feeding those slices: when every branch
  // slices a *different* region (all begin/end/stride triples distinct), all
  // slices are rewired to read from branch 0's producer, so the duplicated
  // upstream computation can be eliminated.
  std::vector<std::vector<CNodePtr>> need_replace_op_lists =
    CreateInterleavedNeedReplaceOpLists(virtual_converter_end, prim::kPrimStridedSlice);
  auto manager = func_graph->manager();
  if (need_replace_op_lists.empty()) {
    return;
  }
  // Every branch must contain the same number of slices for the positions to
  // correspond column-wise.
  auto col_size = need_replace_op_lists.front().size();
  for (size_t i = 0; i < need_replace_op_lists.size(); ++i) {
    if (need_replace_op_lists[i].size() != col_size) {
      MS_LOG(INTERNAL_EXCEPTION) << "Slice redistribution infer failed.";
    }
  }
  for (size_t col = 0; col < col_size; ++col) {
    // Gather each branch's (begin, end, strides) value triple (inputs 2..4).
    std::set<std::vector<std::vector<int64_t>>> slice_value_list_set;
    for (size_t row = 0; row < need_replace_op_lists.size(); ++row) {
      auto slice_cnode = need_replace_op_lists[row][col];
      std::vector<std::vector<int64_t>> slice_value_list;
      for (size_t i = 2; i < kSizeFive; ++i) {
        ValuePtr slice_value = GetValueNode(slice_cnode->input(i));
        MS_EXCEPTION_IF_NULL(slice_value);
        auto value_vector = GetValue<std::vector<int64_t>>(slice_value);
        slice_value_list.push_back(value_vector);
      }
      slice_value_list_set.insert(slice_value_list);
    }
    // Skip the column unless all branches slice pairwise-distinct regions.
    if (slice_value_list_set.size() != need_replace_op_lists.size()) {
      continue;
    }
    // merge nodes before multi slice
    auto slice_input = need_replace_op_lists[kIndex0][col]->input(kIndex1);
    need_replace_op_lists[kIndex0][col]->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
    for (size_t row = 1; row < need_replace_op_lists.size(); ++row) {
      auto slice_cnode = need_replace_op_lists[row][col];
      slice_cnode->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
      (void)manager->SetEdge(slice_cnode, kIndex1, slice_input);
    }
  }
}
3003 
void ConvertInterleaveAllGatherToConcat(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end,
                                        const std::vector<std::vector<std::vector<int64_t>>> &ag_group_ranks_vectors) {
  // Change communication rank_list && Create communication group
  // Replace AllConcat to Concat
  //
  // With fine-grained micro interleaving, ranks of the original AllGather
  // groups collapse (g_rank / interleaved_num) onto fewer devices. Per
  // column (AllGather position shared by all branches):
  //   - if the collapsed ranks stay distinct, the AllGather is kept and only
  //     regrouped (ChangeAllGatherGroup);
  //   - otherwise the branches gather duplicated data and the whole column of
  //     AllGathers is fused into a single Concat
  //     (ReplaceInterleavedAllGatherToConcat).
  std::vector<std::vector<CNodePtr>> need_replace_op_lists =
    CreateInterleavedNeedReplaceOpLists(virtual_converter_end, prim::kPrimAllGather);
  MergeOpBeforeInterleaveSlice(func_graph, virtual_converter_end);
  if (need_replace_op_lists.size() != ag_group_ranks_vectors.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "AllGather redistribution infer failed.";
  }
  if (need_replace_op_lists.empty()) {
    return;
  }
  // Rows (branches) must agree column-wise with the provided group ranks.
  auto col_size = need_replace_op_lists.front().size();
  for (size_t i = 0; i < need_replace_op_lists.size(); ++i) {
    if (need_replace_op_lists[i].size() != col_size || ag_group_ranks_vectors[i].size() != col_size) {
      MS_LOG(INTERNAL_EXCEPTION) << "AllGather redistribution infer failed.";
    }
  }
  auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
  for (size_t col = 0; col < col_size; ++col) {
    std::vector<std::vector<int64_t>> new_group_ranks_vector;
    std::vector<CNodePtr> ag_vector;
    size_t independent_size = 0;
    for (size_t row = 0; row < need_replace_op_lists.size(); ++row) {
      auto group_ranks = ag_group_ranks_vectors[row][col];
      // Collapse interleaved ranks onto physical devices.
      std::vector<int64_t> new_group_ranks;
      std::set<int64_t> new_group_ranks_set;
      for (const auto &g_rank : group_ranks) {
        new_group_ranks_set.insert(int64_t(g_rank / interleaved_num));
        new_group_ranks.push_back(int64_t(g_rank / interleaved_num));
      }
      if (new_group_ranks_set.size() == new_group_ranks.size()) {
        // No duplication after collapsing: keep the AllGather, just regroup.
        // set allgather attrs
        ChangeAllGatherGroup(need_replace_op_lists[row][col], new_group_ranks);
        continue;
      }
      // Duplicated ranks: this AllGather becomes part of a fused Concat.
      std::vector<int64_t> new_group_ranks_no_repeat;
      std::copy(new_group_ranks_set.begin(), new_group_ranks_set.end(), std::back_inserter(new_group_ranks_no_repeat));
      std::sort(new_group_ranks_no_repeat.begin(), new_group_ranks_no_repeat.end());
      new_group_ranks_vector.push_back(new_group_ranks_no_repeat);
      if (independent_size > 0 && new_group_ranks_no_repeat.size() != independent_size) {
        MS_LOG(INTERNAL_EXCEPTION) << "The concat group in micro interleaved is wrong!";
      }
      independent_size = new_group_ranks_no_repeat.size();
      ag_vector.push_back(need_replace_op_lists[row][col]);
    }
    if (new_group_ranks_vector.empty()) {
      continue;
    }

    // Check whether all branch needing be replace
    if (new_group_ranks_vector.size() < need_replace_op_lists.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The concat group in micro interleaved is wrong!";
    }

    // replace allgathers to one concat.
    auto replaced_concat =
      ReplaceInterleavedAllGatherToConcat(func_graph, ag_vector, new_group_ranks_vector, independent_size);
    auto manager = func_graph->manager();
    auto replaced_concat_users =
      GetOutputNodesWithFilter(replaced_concat, [&](const AnfNodePtr &anode) { return false; });
    if (replaced_concat_users.size() == kSizeOne) {
      continue;
    }
    // If every user is an interleaved-parallel StridedSlice (produced by
    // MergeOpBeforeInterleaveSlice), the branches are still independent.
    if (std::all_of(replaced_concat_users.begin(), replaced_concat_users.end(),
                    [](const std::pair<AnfNodePtr, int> &pair) {
                      return IsPrimitiveCNode(pair.first, prim::kPrimStridedSlice) &&
                             pair.first->cast<CNodePtr>()->HasAttr(INTERLEAVED_PARALLEL);
                    })) {
      continue;
    }
    // merge the nodes afer the interleaved parallel concat.
    auto virtual_end_input1 = virtual_converter_end->input(kIndex1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(virtual_end_input1);
    auto new_virtual_converter_end = CreateVirtualConverterEndNode(func_graph, {virtual_end_input1});

    (void)manager->Replace(virtual_converter_end, new_virtual_converter_end);
  }
}
3084 
bool IsDuplicatedVirtualConverterBegin(const CNodePtr &virtual_converter_begin) {
  // Return true when a VirtualConverterBegin duplicates work across its
  // interleaved branches: its input is an ordinary (not parallel-care, not
  // VirtualConverterEnd/UpdateState) CNode, it has multiple users, and at
  // least two user StridedSlices carry identical (begin, end, strides)
  // parameters -- i.e. the branches slice overlapping regions.
  auto virtual_converter_begin_input = virtual_converter_begin->input(kSizeOne);
  if (IsPrimitiveCNode(virtual_converter_begin_input, prim::kPrimVirtualConverterEnd)) {
    return false;
  }
  if (!IsPrimitiveCNode(virtual_converter_begin_input) ||
      IsPrimitiveCNode(virtual_converter_begin_input, prim::kPrimUpdateState)) {
    return false;
  }
  auto virtual_converter_begin_input_cnode = virtual_converter_begin_input->cast<CNodePtr>();
  if (IsParallelCareNode(virtual_converter_begin_input_cnode)) {
    return false;
  }
  auto virtual_converter_begin_users = GetOutputNodesWithFilter(
    virtual_converter_begin, [&](const AnfNodePtr &anode) { return IsPrimitiveCNode(anode, prim::kPrimTupleGetItem); });
  if (virtual_converter_begin_users.size() <= kSizeOne) {
    return false;
  }
  // Collect each user StridedSlice's (begin, end, strides) triple (inputs 2..4).
  std::set<std::vector<std::vector<int64_t>>> slice_value_list_set;
  for (const auto &user_pair : virtual_converter_begin_users) {
    if (!IsPrimitiveCNode(user_pair.first, prim::kPrimStridedSlice)) {
      continue;
    }
    auto slice = user_pair.first->cast<CNodePtr>();
    std::vector<std::vector<int64_t>> slice_value_list;
    for (size_t i = 2; i < kSizeFive; ++i) {
      ValuePtr slice_value = GetValueNode(slice->input(i));
      MS_EXCEPTION_IF_NULL(slice_value);
      auto value_vector = GetValue<std::vector<int64_t>>(slice_value);
      slice_value_list.push_back(value_vector);
    }
    slice_value_list_set.insert(slice_value_list);
  }
  // All distinct => each branch reads its own region => no duplication.
  if (slice_value_list_set.size() == virtual_converter_begin_users.size()) {
    return false;
  }
  return true;
}
3123 
GetOrderOfTwoAnode(const std::pair<AnfNodePtr,int> & pair1,const std::pair<AnfNodePtr,int> & pair2)3124 bool GetOrderOfTwoAnode(const std::pair<AnfNodePtr, int> &pair1, const std::pair<AnfNodePtr, int> &pair2) {
3125   int number1 = pair1.second;
3126   int number2 = pair2.second;
3127   auto pair1_input_node = pair1.first->cast<CNodePtr>()->input(pair1.second);
3128   auto pair2_input_node = pair2.first->cast<CNodePtr>()->input(pair2.second);
3129   if (IsPrimitiveCNode(pair1_input_node, prim::kPrimTupleGetItem)) {
3130     number1 = LongToInt(GetTupleGetItemIndex(pair1_input_node->cast<CNodePtr>()));
3131   }
3132   if (IsPrimitiveCNode(pair2_input_node, prim::kPrimTupleGetItem)) {
3133     number2 = LongToInt(GetTupleGetItemIndex(pair2_input_node->cast<CNodePtr>()));
3134   }
3135   return number1 < number2;
3136 }
3137 
std::vector<CNodePtr> DoSplitForNotParallelCareOpsInterleaved(const FuncGraphManagerPtr &manager,
                                                              const CNodePtr &virtual_converter_begin) {
  // Push a VirtualConverterBegin through a non-parallel-care producer op:
  // each CNode input of the producer gets its own VirtualConverterBegin, and
  // for every interleaved branch a clone of the producer is created that
  // consumes TupleGetItem(branch_index) of those new begins. The clone is
  // wired in place of the original user edge. Returns the newly created
  // VirtualConverterBegin nodes for further processing.
  auto virtual_converter_begin_input = virtual_converter_begin->input(kSizeOne);
  auto virtual_converter_begin_users = GetOutputNodesWithFilter(
    virtual_converter_begin, [&](const AnfNodePtr &anode) { return IsPrimitiveCNode(anode, prim::kPrimTupleGetItem); });
  // Keep branch order stable: sort users by their TupleGetItem/input index.
  std::sort(virtual_converter_begin_users.begin(), virtual_converter_begin_users.end(),
            [](const auto &pair1, const auto &pair2) { return GetOrderOfTwoAnode(pair1, pair2); });
  auto virtual_converter_begin_input_cnode = virtual_converter_begin_input->cast<CNodePtr>();
  std::vector<AnfNodePtr> new_inputs;
  std::vector<CNodePtr> new_virtual_converter_begin_vector;
  for (size_t i = 1; i < virtual_converter_begin_input_cnode->size(); ++i) {
    // Non-CNode inputs (values/parameters) and UpdateState are shared across
    // branches as-is; CNode inputs get a per-input VirtualConverterBegin.
    if (!IsPrimitiveCNode(virtual_converter_begin_input_cnode->input(i)) ||
        IsPrimitiveCNode(virtual_converter_begin_input_cnode->input(i), prim::kPrimUpdateState)) {
      new_inputs.push_back(virtual_converter_begin_input_cnode->input(i));
      continue;
    }
    auto new_virtual_converter_begin = CreateVirtualConverterBeginNode(
      virtual_converter_begin_input_cnode->input(i)->cast<CNodePtr>(), virtual_converter_begin_users.size());
    new_inputs.push_back(new_virtual_converter_begin);
    new_virtual_converter_begin_vector.push_back(new_virtual_converter_begin);
  }

  // Build one clone of the producer per interleaved branch.
  for (size_t interleveaved_index = 0; interleveaved_index < virtual_converter_begin_users.size();
       ++interleveaved_index) {
    std::vector<AnfNodePtr> splited_node_inputs = {virtual_converter_begin_input_cnode->input(kIndex0)};
    for (size_t i = 0; i < new_inputs.size(); ++i) {
      if (!IsPrimitiveCNode(new_inputs[i])) {
        splited_node_inputs.push_back(new_inputs[i]);
        continue;
      }
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), new_inputs[i],
                                                    CreatInt64Imm(UlongToLong(interleveaved_index))};
      auto tuple_get_item_cnode = virtual_converter_begin_input_cnode->func_graph()->NewCNode(tuple_get_item_inputs);
      splited_node_inputs.push_back(tuple_get_item_cnode);
    }
    auto splited_node = virtual_converter_begin_input_cnode->func_graph()->NewCNode(splited_node_inputs);
    manager->SetEdge(virtual_converter_begin_users[interleveaved_index].first,
                     virtual_converter_begin_users[interleveaved_index].second, splited_node);
  }
  return new_virtual_converter_begin_vector;
}
3179 
SplitNotParallelCareOpsInterleaved(const FuncGraphPtr & root)3180 void SplitNotParallelCareOpsInterleaved(const FuncGraphPtr &root) {
3181   AnfNodePtr ret_after = root->get_return();
3182   MS_EXCEPTION_IF_NULL(ret_after);
3183   auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
3184   auto manager = root->manager();
3185   auto node_users = manager->node_users();
3186   for (const auto &node : all_nodes) {
3187     if (!IsPrimitiveCNode(node, prim::kPrimVirtualConverterBegin)) {
3188       continue;
3189     }
3190     std::queue<CNodePtr> visited;
3191     visited.push(node->cast<CNodePtr>());
3192     while (!visited.empty()) {
3193       auto virtual_converter_begin = visited.front();
3194       visited.pop();
3195       if (!IsDuplicatedVirtualConverterBegin(virtual_converter_begin)) {
3196         continue;
3197       }
3198       // Need to split the input
3199       auto new_virtual_converter_begins = DoSplitForNotParallelCareOpsInterleaved(manager, virtual_converter_begin);
3200       for (auto &new_virtual_converter_begin : new_virtual_converter_begins) {
3201         visited.push(new_virtual_converter_begin);
3202       }
3203     }
3204   }
3205 }
3206 
// Removes the VirtualConverterBegin/VirtualConverterEnd marker pairs from the graph,
// reconnecting producers to consumers directly:
//   TupleGetItem(VirtualConverterBegin(VirtualConverterEnd(x1..xn)), i)  ->  x(i+1)
// A Begin node without a matching End simply forwards its input to all of its users.
// A second pass then bypasses any remaining single-input VirtualConverterEnd.
void EraseVirtualConverter(const FuncGraphPtr &root) {
  AnfNodePtr ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(ret_after);
  auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
  auto manager = root->manager();
  // Snapshot of the user map taken before any Replace below mutates the graph.
  auto node_users = manager->node_users();
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualConverterBegin)) {
      continue;
    }
    auto virtual_converter_begin = node->cast<CNodePtr>();
    if (!IsPrimitiveCNode(virtual_converter_begin->input(kIndex1), prim::kPrimVirtualConverterEnd)) {
      // No matching End: forward the Begin's raw input to every user recorded in the snapshot.
      MS_LOG(INFO) << "The VirtualConverterBegin input is not VirtualConverterEnd, it is "
                   << virtual_converter_begin->input(kIndex1)->fullname_with_scope();
      auto virtual_converter_begin_input_node = virtual_converter_begin->input(kIndex1);
      for (const auto &v_user_pair : node_users.at(virtual_converter_begin)) {
        (void)manager->Replace(v_user_pair.first, virtual_converter_begin_input_node);
      }
      continue;
    }
    auto virtual_converter_end = virtual_converter_begin->input(kIndex1)->cast<CNodePtr>();
    // Re-query the manager here (not the snapshot): earlier replacements may have rewired users.
    auto virtual_converter_begin_users = manager->node_users()[virtual_converter_begin];
    // End has size() == data inputs + 1 (input 0 is the primitive), so users must equal size()-1.
    if (virtual_converter_begin_users.size() != virtual_converter_end->size() - 1) {
      MS_LOG(INTERNAL_EXCEPTION)
        << "The VirtualConverterBegin users nums is not equal to VirtualConverterEnd inputs nums";
    }
    for (const auto &node_pair : virtual_converter_begin_users) {
      if (!IsPrimitiveCNode(node_pair.first, prim::kPrimTupleGetItem)) {
        MS_LOG(INTERNAL_EXCEPTION) << "The VirtualConverterBegin user should be tuple_get_item.";
      }
      auto tuple_get_item = node_pair.first->cast<CNodePtr>();
      auto tuple_get_item_index_value = GetValueNode(tuple_get_item->input(kIndex2));
      MS_EXCEPTION_IF_NULL(tuple_get_item_index_value);
      auto get_item_index = GetValue<int64_t>(tuple_get_item_index_value);
      // TupleGetItem(begin, i) is replaced by End's (i+1)-th input (+1 skips the primitive).
      (void)manager->Replace(tuple_get_item, virtual_converter_end->input(get_item_index + 1));
    }
  }
  // Second pass over the rewritten graph: any VirtualConverterEnd left with exactly one
  // data input is redundant — wire its input directly to its users.
  AnfNodePtr new_ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(new_ret_after);
  auto new_all_nodes = TopoSort(new_ret_after, SuccDeeperSimple);
  for (const auto &node : new_all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimVirtualConverterEnd)) {
      auto virtual_converter_end_cnode = node->cast<CNodePtr>();
      if (virtual_converter_end_cnode->size() != kSizeTwo) {
        MS_LOG(INTERNAL_EXCEPTION) << "The VirtualConverterEnd nums is not equal to VirtualConverterBegin nums.";
      }
      auto virtual_converter_end_input = virtual_converter_end_cnode->input(kIndex1);
      (void)manager->Replace(virtual_converter_end_cnode, virtual_converter_end_input);
    }
  }
}
3258 
// Returns the English ordinal form of `number` for human-readable messages:
// 1 -> "1st", 2 -> "2nd", 3 -> "3rd", 4 -> "4th", 11..13 -> "11th".."13th", 21 -> "21st".
// The previous implementation matched only the exact values 1/2/3, so 21 came out as "21th";
// the standard last-digit rule (with the 11/12/13 teen exception) fixes that and is
// backward-compatible for the small indices that were already correct.
std::string GetSerialNumberString(size_t number) {
  std::string suffix = "th";
  // 11th/12th/13th are the exceptions to the last-digit rule.
  const size_t last_two = number % 100;
  if (last_two < 11 || last_two > 13) {
    const size_t last = number % 10;
    if (last == 1) {
      suffix = "st";
    } else if (last == 2) {
      suffix = "nd";
    } else if (last == 3) {
      suffix = "rd";
    }
  }
  std::ostringstream oss;
  oss << number << suffix;
  return oss.str();
}
3272 
3273 // Get single device capacity in Go
GetDeviceCapacity()3274 size_t GetDeviceCapacity() {
3275   auto context = MsContext::GetInstance();
3276   MS_EXCEPTION_IF_NULL(context);
3277   size_t size_from_context;
3278   auto max_device_memory = context->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
3279   float total_device_memory = 32.0f;
3280   if (context->ascend_soc_version() == kAscendVersion910b || context->ascend_soc_version() == kAscendVersion910c) {
3281     total_device_memory = 64.0f;
3282   }
3283   if (max_device_memory <= total_device_memory) {
3284     MS_LOG(DEBUG) << "context max_device_memory:" << max_device_memory;
3285     size_from_context = FloatToSize(max_device_memory * kGBToByte);
3286   } else {
3287     auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
3288     if (variable_memory_max_size == "0") {
3289       return 0;
3290     }
3291     MS_LOG(DEBUG) << "context variable_memory_max_size:" << variable_memory_max_size;
3292     auto pos = variable_memory_max_size.find('*');
3293     if (pos == std::string::npos) {
3294       MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
3295     }
3296     auto gb_str = variable_memory_max_size.substr(0, pos);
3297     auto gb_var = std::stoull(gb_str);
3298     MS_LOG(DEBUG) << "variable_memory_max_size(GB):" << gb_var;
3299     size_from_context = gb_var * kGBToByte;
3300   }
3301   return size_from_context;
3302 }
3303 
GenerateAbsByOpInfer(const PrimitivePtr & primitive,const AnfNodePtrList & input_list)3304 abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) {
3305   MS_EXCEPTION_IF_NULL(primitive);
3306   std::vector<AbstractBasePtr> input_args;
3307   (void)std::for_each(input_list.begin(), input_list.end(),
3308                       [&input_args](const auto &input) { (void)input_args.emplace_back(input->abstract()); });
3309   auto abs_opt = abstract::TryInferAbstract(primitive, input_args);
3310   if (!abs_opt.has_value()) {
3311     MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered.";
3312   }
3313   auto abs = abs_opt.value();
3314   MS_EXCEPTION_IF_NULL(abs);
3315   MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString();
3316   return abs;
3317 }
3318 }  // namespace parallel
3319 }  // namespace mindspore
3320