1 /**
2 * Copyright 2021-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/step_parallel_utils.h"
18
19 #include <algorithm>
20 #include <cinttypes>
21
22 #include <map>
23 #include <memory>
24 #include <queue>
25 #include <set>
26 #include <string>
27 #include <utility>
28
29 #include "abstract/dshape.h"
30 #include "base/base.h"
31 #include "base/bfloat16.h"
32 #include "frontend/operator/ops.h"
33 #include "frontend/optimizer/optimizer.h"
34 #include "frontend/parallel/device_manager.h"
35 #include "frontend/parallel/dynamic_creator.h"
36 #include "frontend/parallel/graph_util/generate_graph.h"
37 #include "frontend/parallel/graph_util/graph_info.h"
38 #include "frontend/parallel/graph_util/node_info.h"
39 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
40 #include "frontend/parallel/node_check.h"
41 #include "frontend/parallel/parameter_manager.h"
42 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
43 #include "include/common/utils/comm_manager.h"
44 #include "include/common/utils/parallel_context.h"
45 #include "ir/param_info.h"
46 #include "ir/tensor.h"
47 #include "ops/array_ops.h"
48 #include "ops/framework_ops.h"
49 #include "ops/nn_ops.h"
50 #include "ops/other_ops.h"
51 #include "ops/sequence_ops.h"
52 #include "utils/parallel_node_check.h"
53 #include "utils/hash_map.h"
54 #include "utils/ms_context.h"
55 #include "utils/symbolic.h"
56 #include "utils/trace_base.h"
57 #include "mindspore/core/symbolic_shape/int_symbol.h"
58
59 namespace mindspore {
60 namespace parallel {
61 using mindspore::tensor::Tensor;
62 size_t TOTAL_OPS = 0;
63 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
64 // it will be one item in map with key: C, and value: (B, i)
65 std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
66
IsDynamicShapeInput(const CNodePtr & node,const AnfNodePtr & input)67 bool IsDynamicShapeInput(const CNodePtr &node, const AnfNodePtr &input) {
68 if (IsSomePrimitiveList(node, CANDIDATE_DYNAMIC_VALUE_OPS) &&
69 (IsPrimitiveCNode(input, prim::kPrimMakeTuple) || IsPrimitiveCNode(input, prim::kPrimShape))) {
70 return true;
71 }
72 if (IsPrimitiveCNode(node, prim::kPrimCast) && IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
73 BaseShapePtr base_shape_ptr = node->Shape();
74 if (base_shape_ptr == nullptr) {
75 MS_LOG(EXCEPTION) << "IsDynamicShapeInput: " << node->ToString() << " shape_ptr is nullptr, full name is "
76 << node->fullname_with_scope();
77 }
78 auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
79 MS_EXCEPTION_IF_NULL(shape_ptr);
80 if (shape_ptr->shape().empty()) {
81 return true;
82 }
83 }
84 return false;
85 }
86
IsSomePrimitive(const CNodePtr & cnode,const std::string & name)87 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
88 if (!cnode) {
89 return false;
90 }
91 ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
92 if (!anf_node) {
93 return false;
94 }
95 PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
96 if (!prim) {
97 return false;
98 }
99 return (prim->name() == name);
100 }
101
IsSomePrimitiveList(const CNodePtr & cnode,const std::set<string> & check_list)102 bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list) {
103 if (!cnode) {
104 return false;
105 }
106 ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
107 if (!anf_node) {
108 return false;
109 }
110 PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
111 if (!prim) {
112 return false;
113 }
114 return std::any_of(check_list.begin(), check_list.end(), [prim](const string &in) { return prim->name() == in; });
115 }
116
IsIgnoreSplitTensor(const CNodePtr & node,int64_t index)117 bool IsIgnoreSplitTensor(const CNodePtr &node, int64_t index) {
118 if (IsSomePrimitiveList(node, SPLIT_TENSOR_ONLY_FOR_FIRST_INPUT_OPS) && index > 0) {
119 return true;
120 }
121 return false;
122 }
123
GetPrimName(const CNodePtr & node)124 std::string GetPrimName(const CNodePtr &node) {
125 auto prim = GetCNodePrimitive(node);
126 if (!prim) {
127 return node->DebugString();
128 }
129 return prim->name();
130 }
131
IsTraining(const FuncGraphManagerPtr & manager)132 bool IsTraining(const FuncGraphManagerPtr &manager) {
133 for (auto &fg : manager->func_graphs()) {
134 if (fg->has_flag(kTraining)) {
135 return true;
136 }
137 }
138 return false;
139 }
140
HasBackward(const FuncGraphPtr & root)141 bool HasBackward(const FuncGraphPtr &root) {
142 auto nodes = root->nodes();
143 for (auto &node : nodes) {
144 if (IsPrimitiveCNode(node, prim::kPrimJ)) {
145 return true;
146 }
147 }
148 return false;
149 }
150
// Returns the TensorInfo of the operator input referred to by `param_info`.
//
// `param_info` is a (user cnode, cnode input index) pair; the cnode must
// already carry an OperatorInfo. For Send nodes the input slot is recovered
// from the PARAM_INDEX primal attribute; otherwise the 1-based cnode input
// index is mapped to the 0-based inputs_tensor_info() index.
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  TensorInfo tensor_info;
  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
    // Send records the producing parameter's slot in its PARAM_INDEX primal
    // attribute instead of using the cnode input position.
    auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
    tensor_info = op_info->inputs_tensor_info()[param_index];
  } else {
    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
    // cnode inputs are 1-based (input 0 is the primitive), hence the -1.
    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
                        << ", but the index is " << (user_input_index - 1);
    }
    tensor_info = op_info->inputs_tensor_info()[LongToSize(user_input_index - 1)];
  }
  return tensor_info;
}
172
IsRealKernelNode(const AnfNodePtr & node)173 static bool IsRealKernelNode(const AnfNodePtr &node) {
174 if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
175 IsPrimitiveCNode(node, prim::kPrimCast) || IsPrimitiveCNode(node, prim::kPrimVirtualDiv) ||
176 IsPrimitiveCNode(node, prim::kPrimReceive) || IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) ||
177 IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimInsertGradientOf)) {
178 return false;
179 }
180 return true;
181 }
182
// Resolves `node` to the node that actually produces its value, skipping
// "transparent" wrappers (Depend/Load/Cast/..., TupleGetItem/MakeTuple pairs,
// sub-graph calls and Switch branches).
//
// `get_item_index` is the pending TupleGetItem index (-1 means none) and is
// returned alongside the resolved node when it could not be consumed.
// `call_node` (may be nullptr) receives the outermost sub-graph call seen.
// `ignore_get_item` enables traversal through TupleGetItem nodes.
std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node,
                                                 bool ignore_get_item) {
  // Pass-through node: follow its data input (input 1).
  if (!IsRealKernelNode(node)) {
    return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node, ignore_get_item);
  }
  if ((IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimInsertGradientOf)) &&
      ignore_get_item) {
    // Record the getitem index and continue from the tuple producer.
    // NOTE(review): InsertGradientOf is filtered by IsRealKernelNode above,
    // so the InsertGradientOf half of this condition looks unreachable —
    // confirm before relying on it.
    auto cnode = node->cast<CNodePtr>();
    auto cur_get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
    auto tuple_getitem_input = cnode->input(1);
    return GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node, ignore_get_item);
  }
  if (get_item_index != -1 &&
      (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimInsertGradientOf))) {
    // A pending getitem index selects one element of this MakeTuple
    // (+1 converts the tuple index to the 1-based cnode input index).
    auto make_tuple_cnode = node->cast<CNodePtr>();
    auto make_tuple_input = make_tuple_cnode->input(LongToSize(get_item_index + 1));
    return GetRealKernelNode(make_tuple_input, -1, call_node, ignore_get_item);
  }
  if (IsControlFlowNode(node)) {
    // Switch call: resolve through the branch graph held in input 3 of the
    // Switch cnode.
    auto switch_cnode = node->cast<CNodePtr>()->input(0)->cast<CNodePtr>();
    auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(3));
    return GetRealKernelNode(fg->output(), get_item_index, call_node, ignore_get_item);
  }
  if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0))) {
    // Direct sub-graph call: remember the outermost call site, then resolve
    // the called graph's output.
    if (call_node != nullptr && *call_node == nullptr) {
      *call_node = node->cast<CNodePtr>();
    }
    auto cnode = node->cast<CNodePtr>();
    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
    auto output = GetRealKernelNode(graph->output(), get_item_index, call_node, ignore_get_item).first;
    MS_EXCEPTION_IF_NULL(output);
    if (output->isa<Parameter>()) {
      // The graph returns one of its own parameters: map it back to the
      // argument passed at a direct call site (cnode-index 0 means the graph
      // itself is the callee, not an argument).
      auto param_graph = output->func_graph();
      auto parameter_list = param_graph->parameters();
      auto fg_used_map = param_graph->func_graph_cnodes_index();
      for (auto &cur_fg_use : fg_used_map) {
        if (cur_fg_use.first->second != 0) {
          continue;
        }
        auto cur_fg = cur_fg_use.first->first->cast<CNodePtr>();
        auto iter = std::find(parameter_list.begin(), parameter_list.end(), output);
        auto pos = std::distance(parameter_list.begin(), iter);
        // +1: parameter position -> 1-based call-site argument index.
        auto argument = cur_fg->input(pos + 1);
        return GetRealKernelNode(argument, get_item_index, call_node, ignore_get_item);
      }
      return std::make_pair(output, get_item_index);
    }
    return std::make_pair(output, get_item_index);
  }
  return std::make_pair(node, get_item_index);
}
234
IsWhileGraph(const FuncGraphPtr & cur_fg,const FuncGraphPtr & fg)235 static bool IsWhileGraph(const FuncGraphPtr &cur_fg, const FuncGraphPtr &fg) {
236 auto cur_fg_map = cur_fg->func_graph_cnodes_index();
237 for (auto &cur_fg_use : cur_fg_map) {
238 auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
239 MS_EXCEPTION_IF_NULL(temp_node);
240 if (temp_node->func_graph() == fg) {
241 return true;
242 }
243 }
244 return false;
245 }
246
// Verifies that every user of `node` carrying an OperatorInfo agrees on the
// input TensorInfo, and returns the first such user.
//
// A single user is returned directly without any consistency check. When no
// user has an OperatorInfo the function returns nullptr; callers must handle
// that. Raises when two users disagree on the TensorInfo.
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  auto node_users = manager->node_users()[node];
  if (node_users.size() == 1) {
    return node_users.front().first;
  }

  bool is_first_tensor_info = true;
  TensorInfo first_tensor_info;
  AnfNodePtr first_node;
  for (auto &node_user : node_users) {
    auto user_node = node_user.first->cast<CNodePtr>();
    // Fix: a user may not be a CNode, in which case the cast yields nullptr;
    // the old code dereferenced it unconditionally.
    if (user_node == nullptr || !user_node->has_user_data<OperatorInfo>()) {
      continue;
    }
    auto tensor_info = GetInputsTensorInfo(node_user);
    if (is_first_tensor_info) {
      is_first_tensor_info = false;
      first_tensor_info = tensor_info;
      first_node = node_user.first;
      continue;
    }
    if (first_tensor_info == tensor_info) {
      continue;
    } else {
      MS_LOG(EXCEPTION) << "The node: " << node->DebugString()
                        << " has multiple users, but the TensorInfo are different";
    }
  }
  return first_node;
}
277
// Returns true when `cnode` should be handled by the parallel pass.
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Not skip Send Receive in pp interleave
  // NOTE(review): with pipeline interleave enabled, Send/Receive return false
  // here (i.e. they are treated as not parallel-care) — confirm this matches
  // the comment above.
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto is_pp_interleave = parallel_context->pipeline_interleave();
  if (is_pp_interleave && (IsPrimitiveCNode(cnode, prim::kPrimSend) || IsPrimitiveCNode(cnode, prim::kPrimReceive))) {
    return false;
  }
  // Non-primitive calls (e.g. sub-graph calls) are never parallel-care.
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (!IsParallelConsiderCNode(cnode)) {
    MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
    return false;
  }
  // get_next is not in the forward graph, we need mark the get_next as the forward node
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  // A Cast without an OperatorInfo is an auxiliary/inserted cast, not user compute.
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }

  return cnode->in_forward_flag();
}
309
HasNestedMetaFg(const FuncGraphPtr & func_graph)310 bool HasNestedMetaFg(const FuncGraphPtr &func_graph) {
311 if (!IsPynativeParallel()) {
312 return false;
313 }
314 AnfNodePtr ret = func_graph->get_return();
315 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
316 for (auto &node : all_nodes) {
317 if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
318 IsPrimitiveCNode(node, prim::kPrimTaylor)) {
319 return true;
320 }
321 }
322 return false;
323 }
324
IsEmbedShardNode(const FuncGraphPtr & func_graph)325 bool IsEmbedShardNode(const FuncGraphPtr &func_graph) {
326 MS_EXCEPTION_IF_NULL(func_graph);
327 AnfNodePtr ret = func_graph->get_return();
328 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
329 return std::any_of(all_nodes.begin(), all_nodes.end(), [&func_graph](const AnfNodePtr &node) {
330 return IsPrimitiveCNode(node, prim::kPrimShard) && (node->func_graph() == func_graph);
331 });
332 }
333
GetValueListShape(const AnfNodePtr & node)334 Shapes GetValueListShape(const AnfNodePtr &node) {
335 Shapes shapes;
336 std::vector<ValuePtr> inputs_seq;
337 if (IsValueNode<ValueList>(node)) {
338 inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
339 } else if (IsValueNode<ValueTuple>(node)) {
340 inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
341 } else {
342 MS_LOG(EXCEPTION) << "node is either ValueList or ValueTuple";
343 }
344 for (auto &ele : inputs_seq) {
345 auto tensor = ele->cast<tensor::TensorPtr>();
346 if (tensor == nullptr) {
347 MS_LOG(WARNING) << "The value node is not a tensor";
348 break;
349 }
350 auto one_shape = tensor->shape();
351 shapes.push_back(one_shape);
352 }
353 return shapes;
354 }
355
IsControlFlowNode(const AnfNodePtr & node)356 bool IsControlFlowNode(const AnfNodePtr &node) {
357 // Only switch or FuncCall nodes are control flow nodes
358 MS_EXCEPTION_IF_NULL(node);
359 if (!node->isa<CNode>()) {
360 return false;
361 }
362 auto cnode = node->cast<CNodePtr>();
363 MS_EXCEPTION_IF_NULL(cnode);
364 // func node
365 if (cnode->input(0)->isa<CNode>() && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
366 return true;
367 }
368 return false;
369 }
370
GetTupleGetItemIndex(const CNodePtr & cnode)371 int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
372 MS_EXCEPTION_IF_NULL(cnode);
373 if (!cnode->input(TUPLE_GETITEM_INDEX_POS)->isa<ValueNode>()) {
374 MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node";
375 }
376
377 ValuePtr tuple_index_value = GetValueNode(cnode->input(TUPLE_GETITEM_INDEX_POS));
378 MS_EXCEPTION_IF_NULL(tuple_index_value);
379 if (!tuple_index_value->isa<Int64Imm>()) {
380 MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
381 }
382 return tuple_index_value->cast<Int64ImmPtr>()->value();
383 }
384
IsNoNeedRedistribution(const CNodePtr & use_cnode,int use_index)385 static bool IsNoNeedRedistribution(const CNodePtr &use_cnode, int use_index) {
386 return (IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && use_index != 1) || use_cnode->input(0)->isa<CNode>() ||
387 IsOneOfPrimitiveCNode(use_cnode, {prim::kPrimUpdateState, prim::kPrimSwitch, prim::kPrimShape,
388 prim::kPrimTensorShape, prim::kPrimDType});
389 }
390
FuncGraphNodeUsers(const std::pair<AnfNodePtr,int> & node_pair)391 std::vector<std::pair<AnfNodePtr, int>> FuncGraphNodeUsers(const std::pair<AnfNodePtr, int> &node_pair) {
392 std::vector<std::pair<AnfNodePtr, int>> func_users_vector;
393 if (!node_pair.first->isa<CNode>()) {
394 return func_users_vector;
395 }
396 auto use_cnode = node_pair.first->cast<CNodePtr>();
397 MS_EXCEPTION_IF_NULL(use_cnode);
398 if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
399 auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
400 auto fg_parameters = fg->parameters();
401 auto param = fg_parameters[IntToSize(node_pair.second - 1)];
402 auto manager = fg->manager();
403 auto param_node_users = manager->node_users()[param];
404 for (const auto &node_user : param_node_users) {
405 auto cnode = node_user.first->cast<CNodePtr>();
406 if (IsValueNode<FuncGraph>(cnode->input(0))) {
407 auto sub_graph_users = FuncGraphNodeUsers(node_user);
408 (void)std::copy(sub_graph_users.begin(), sub_graph_users.end(), std::back_inserter(func_users_vector));
409 } else {
410 func_users_vector.emplace_back(node_user);
411 }
412 }
413 }
414 return func_users_vector;
415 }
416
// Drops the leading placeholder index from a TupleGetItem index path.
//
// The first element is a placeholder (-1) recorded before any real getitem
// index exists; once more than one entry is present the placeholder is
// removed. Single-element input is returned unchanged.
// Fix: the old `size() != 1` guard also matched an EMPTY vector, making
// erase(begin()) undefined behavior; empty input is now returned as-is.
std::vector<int> RemovePlaceholderIdx(const std::vector<int> &get_item_index) {
  std::vector<int> new_get_item_index(get_item_index);
  if (new_get_item_index.size() > 1) {
    // Remove first -1, if there is other index
    new_get_item_index.erase(new_get_item_index.begin());
  }
  return new_get_item_index;
}
426
// Appends to `next_nodes` the redistribution successor entry for a
// parallel-care consumer `use_cnode`.
//
// Entries have the form ((node, input-index path), getitem path). When a
// MakeTuple index is pending (`*make_tuple_index` != -1) and the real
// producer of the inspected input is a MakeTuple, the pending index is
// appended to the input path and consumed (reset to -1); otherwise the
// consumer itself is recorded.
void RedistributionNextNodeInMakeTuple(
  const CNodePtr &use_cnode, const std::pair<std::shared_ptr<AnfNode>, int> &node_pair,
  const std::vector<int> &get_item_index, int64_t *make_tuple_index,
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes) {
  auto modified_get_item_idx = RemovePlaceholderIdx(get_item_index);
  std::vector<int> input_index = {node_pair.second};
  if (*make_tuple_index != -1) {
    // Ops supporting the new ShapeBase layout keep the actual input position;
    // all other ops are inspected at input 1.
    int node_pos = IsSomePrimitiveList(use_cnode, SUPPORT_NEW_SHAPEBASE_OPS) ? node_pair.second : 1;
    auto real_node = GetRealKernelNode(use_cnode->input(node_pos), -1, nullptr);
    if (IsPrimitiveCNode(real_node.first, prim::kPrimMakeTuple)) {
      // +1: tuple element index -> 1-based cnode input index.
      input_index.push_back(LongToInt((*make_tuple_index) + 1));
      next_nodes->push_back(std::make_pair(std::make_pair(real_node.first, input_index), modified_get_item_idx));
      *make_tuple_index = -1;
      return;
    }
  }
  auto modified_node_pair = std::make_pair(node_pair.first, input_index);
  next_nodes->push_back(std::make_pair(modified_node_pair, modified_get_item_idx));
}
446
SetAnfNode(const AnfNodePtr & param,std::vector<std::pair<std::pair<AnfNodePtr,std::vector<int>>,std::vector<int>>> * next_nodes)447 void SetAnfNode(const AnfNodePtr ¶m,
448 std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes) {
449 for (const auto &next_node : *next_nodes) {
450 next_node.first.first->set_user_data<AnfNode>(FUNC_PARAM, param);
451 }
452 }
453
// Collects, into `next_nodes`, the redistribution successors of `node`: the
// parallel-care consumers reached by walking forward through sub-graph
// calls/returns, MakeTuple/TupleGetItem pairs and other transparent nodes.
//
// `get_item_index` accumulates the TupleGetItem path taken so far;
// `make_tuple_index` is the pending MakeTuple element index (-1 when none).
void RedistributionNextNode(
  const AnfNodePtr &node, const FuncGraphManagerPtr &manager, const NodeUsersMap &node_users_map,
  const std::vector<int> &get_item_index, int64_t make_tuple_index,
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  if (node_users_map.count(node) == 0) {
    return;
  }
  auto node_set = node_users_map.at(node);
  for (auto &node_pair : node_set) {
    auto use_cnode = node_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(use_cnode);
    if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
      // Sub-graph call: follow the value into the callee's parameter.
      auto cur_fg = use_cnode->func_graph();
      auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
      MS_EXCEPTION_IF_NULL(fg);
      if (IsWhileGraph(cur_fg, fg)) {
        continue;  // do not descend into while-loop cycles
      }
      auto fg_parameters = fg->parameters();
      // Argument i of the call binds to parameter i-1 (input 0 is the graph).
      auto param = fg_parameters[IntToSize(node_pair.second - 1)];
      MS_EXCEPTION_IF_NULL(param);
      if (param->has_user_data<OperatorInfo>()) {
        // The parameter itself already carries layout info: record the call.
        std::vector<int> input_index = {node_pair.second};
        auto modified_node_pair = std::make_pair(node_pair.first, input_index);
        next_nodes->push_back(std::make_pair(modified_node_pair, RemovePlaceholderIdx(get_item_index)));
        continue;
      }
      RedistributionNextNode(param, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      SetAnfNode(param, next_nodes);
      continue;
    }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimMakeTuple)) {
      // Remember which tuple slot this value occupies, then follow the tuple.
      make_tuple_index = node_pair.second - 1;
      RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      continue;
    }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(use_cnode, prim::kPrimListGetItem)) {
      auto temp = LongToInt(GetTupleGetItemIndex(use_cnode));
      // A getitem that selects a different element than the pending MakeTuple
      // slot does not consume this value.
      if (temp != make_tuple_index && make_tuple_index != -1) {
        continue;
      }
      // When the getitem cancels the pending MakeTuple, record a placeholder.
      temp = make_tuple_index != -1 ? -1 : temp;
      std::vector<int> new_get_item_index;
      std::copy(get_item_index.begin(), get_item_index.end(), std::back_inserter(new_get_item_index));
      new_get_item_index.push_back(temp);
      RedistributionNextNode(use_cnode, manager, node_users_map, new_get_item_index, -1, next_nodes);
      continue;
    }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimReturn)) {
      // The value escapes through the graph's return: continue from every
      // call site of the graph.
      auto fg = use_cnode->func_graph();
      auto fg_map = fg->func_graph_cnodes_index();
      for (auto &fg_use : fg_map) {
        auto fg_node = fg_use.first->first->cast<CNodePtr>();
        constexpr int SWITCH_LAST_INPUT_INDEX = 3;
        // NOTE(review): IsWhileGraph is called with fg twice (self-recursion
        // check); confirm this is intended rather than IsWhileGraph(cur, fg).
        if (IsWhileGraph(fg, fg) && fg_use.first->second != SWITCH_LAST_INPUT_INDEX) {
          continue;
        }
        RedistributionNextNode(fg_node, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      }
    }
    // depend, auto monad and control flow op don't need to jump over
    if (IsNoNeedRedistribution(use_cnode, node_pair.second)) {
      continue;
    }
    if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
      RedistributionNextNodeInMakeTuple(use_cnode, node_pair, get_item_index, &make_tuple_index, next_nodes);
      continue;
    }
    // search recursively
    RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
  }
}
527
// Collects, into `pre_nodes`, the parallel-care producer(s) found by walking
// backwards from `cnode` through pass-through ops (Depend/Load/Cast/
// AllReduce) and Switch-based control flow.
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                           std::vector<AnfNodePtr> *pre_nodes) {
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    // Sub-graph call: stop, producers inside the callee are handled elsewhere.
    return;
  }
  if (IsControlFlowNode(cnode)) {
    auto switch_cnode = cnode->input(0)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(switch_cnode);
    // extract true branch, false branch is usually also a control flow graph
    auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(2));
    MS_EXCEPTION_IF_NULL(fg);
    auto fg_out = fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(fg_out);
    // control flow node, need enter graph to find redistribution pre node.
    RedistributionPreNode(fg_out, manager, pre_nodes);
  }
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
      IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) {
    // Pass-through op: keep searching from its data input (input 1).
    auto cnode_input = cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode_input);
    RedistributionPreNode(cnode_input, manager, pre_nodes);
  }
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
    pre_nodes->push_back(cnode);
  }
}
554
GetNodeShape(const AnfNodePtr & node)555 Shapes GetNodeShape(const AnfNodePtr &node) {
556 MS_EXCEPTION_IF_NULL(node);
557 Shapes shapes;
558 if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
559 return GetValueListShape(node);
560 }
561 BaseShapePtr base_shape_ptr = node->Shape();
562 if (base_shape_ptr == nullptr && node->isa<ValueNode>()) {
563 auto value_node = node->cast<ValueNodePtr>();
564 MS_EXCEPTION_IF_CHECK_FAIL(value_node->value() != nullptr, "ValueNode has no value.");
565 auto abstract = value_node->value()->ToAbstract();
566 MS_EXCEPTION_IF_CHECK_FAIL(abstract != nullptr, "ValueNode has no Abstract.");
567 node->set_abstract(abstract);
568 base_shape_ptr = node->Shape();
569 }
570 if (node->isa<CNode>() && !IsControlFlowNode(node)) {
571 auto cnode = node->cast<CNodePtr>();
572 if (cnode->input(0)->isa<CNode>()) {
573 if (cnode->size() < 2) {
574 MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
575 }
576 base_shape_ptr = cnode->input(1)->Shape();
577 }
578 }
579 // If node is Depend, only first input should be used.
580 if (node->isa<CNode>() && IsPrimitiveCNode(node->cast<CNodePtr>(), prim::kPrimDepend)) {
581 auto depend_cnode = node->cast<CNodePtr>();
582 MS_EXCEPTION_IF_NULL(depend_cnode->input(1));
583 return GetNodeShape(depend_cnode->input(1));
584 }
585 if (base_shape_ptr == nullptr) {
586 MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
587 << node->fullname_with_scope();
588 }
589 auto tuple_shape_ptr = dyn_cast<abstract::SequenceShape>(base_shape_ptr);
590 if (tuple_shape_ptr != nullptr) {
591 if (tuple_shape_ptr->size() == 0) {
592 shapes.push_back(Shape{0});
593 return shapes;
594 }
595 auto tuple_shape = tuple_shape_ptr->shape();
596 if (tuple_shape[0]->isa<abstract::NoShape>()) {
597 shapes.push_back(Shape{SizeToLong(tuple_shape_ptr->size())});
598 return shapes;
599 }
600 for (auto &shape : tuple_shape) {
601 auto each_shape = dyn_cast<abstract::Shape>(shape);
602 MS_EXCEPTION_IF_NULL(each_shape);
603 shapes.push_back(each_shape->shape());
604 }
605 } else if (base_shape_ptr->isa<abstract::DynamicSequenceShape>()) {
606 shapes.push_back(Shape{-1});
607 } else if (base_shape_ptr->isa<abstract::Shape>()) {
608 auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
609 MS_EXCEPTION_IF_NULL(shape_ptr);
610 shapes.push_back(shape_ptr->shape());
611 } else if (base_shape_ptr->isa<abstract::NoShape>()) {
612 shapes.push_back(Shape{});
613 } else {
614 MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " should be Tuple/List/Tensor/Scalar, but got "
615 << base_shape_ptr->ToString() << "full name is " << node->fullname_with_scope();
616 }
617 return shapes;
618 }
619
TransferShapesToNewShapes(const Shapes & shapes,const bool need_create_shape_list)620 NewShapes TransferShapesToNewShapes(const Shapes &shapes, const bool need_create_shape_list) {
621 NewShapes s;
622 if (!need_create_shape_list) {
623 s.emplace_back(std::make_shared<ShapeValue>(shapes[0]));
624 } else {
625 std::vector<ShapeBasePtr> shapes_list;
626 std::transform(shapes.begin(), shapes.end(), std::back_inserter(shapes_list),
627 [](const auto &shape) { return std::make_shared<ShapeValue>(shape); });
628 s.emplace_back(std::make_shared<ShapeList>(shapes_list));
629 }
630 return s;
631 }
632
ExtractNewShapeFromShape(const abstract::BaseShapePtr & shape)633 ShapeBasePtr ExtractNewShapeFromShape(const abstract::BaseShapePtr &shape) {
634 ShapeBasePtr out_shape;
635 if (dyn_cast<abstract::Shape>(shape) != nullptr) {
636 auto casted_shape = dyn_cast<abstract::Shape>(shape);
637 std::vector<int64_t> shape_value = casted_shape->shape();
638 out_shape = std::make_shared<ShapeValue>(shape_value);
639 } else if (dyn_cast<abstract::SequenceShape>(shape) != nullptr) {
640 std::vector<ShapeBasePtr> tuple_shape;
641 auto sequence_shape = dyn_cast<abstract::SequenceShape>(shape);
642 std::transform(sequence_shape->shape().begin(), sequence_shape->shape().end(), std::back_inserter(tuple_shape),
643 ExtractNewShapeFromShape);
644 out_shape = std::make_shared<ShapeList>(tuple_shape);
645 } else {
646 MS_LOG(EXCEPTION) << "each shape in tuple shape is not shape or sequenceshape";
647 }
648 return out_shape;
649 }
650
GetNodeNewShape(const AnfNodePtr & node)651 NewShapes GetNodeNewShape(const AnfNodePtr &node) {
652 MS_EXCEPTION_IF_NULL(node);
653 NewShapes shapes;
654 if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
655 return TransferShapesToNewShapes(GetValueListShape(node), false);
656 }
657 BaseShapePtr base_shape_ptr = node->Shape();
658 if (base_shape_ptr == nullptr && node->isa<ValueNode>()) {
659 auto value_node = node->cast<ValueNodePtr>();
660 MS_EXCEPTION_IF_CHECK_FAIL(value_node->value() != nullptr, "ValueNode has no value.");
661 auto abstract = value_node->value()->ToAbstract();
662 MS_EXCEPTION_IF_CHECK_FAIL(abstract != nullptr, "ValueNode has no Abstract.");
663 node->set_abstract(abstract);
664 base_shape_ptr = node->Shape();
665 }
666 if (node->isa<CNode>() && !IsControlFlowNode(node)) {
667 auto cnode = node->cast<CNodePtr>();
668 if (cnode->input(0)->isa<CNode>()) {
669 if (cnode->size() < kSizeTwo) {
670 MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
671 }
672 base_shape_ptr = cnode->input(1)->Shape();
673 }
674 }
675 // If node is Depend, only first input should be used.
676 if (node->isa<CNode>() && IsPrimitiveCNode(node->cast<CNodePtr>(), prim::kPrimDepend)) {
677 auto depend_cnode = node->cast<CNodePtr>();
678 MS_EXCEPTION_IF_NULL(depend_cnode->input(1));
679 return GetNodeNewShape(depend_cnode->input(1));
680 }
681 if (base_shape_ptr == nullptr) {
682 MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
683 << node->fullname_with_scope();
684 }
685 auto tuple_shape_ptr = dyn_cast<abstract::SequenceShape>(base_shape_ptr);
686 if (tuple_shape_ptr != nullptr) {
687 if (tuple_shape_ptr->size() == 0) {
688 std::vector<int64_t> shape_value = {0};
689 shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
690 return shapes;
691 }
692 auto tuple_shape = tuple_shape_ptr->shape();
693 if (tuple_shape[0]->isa<abstract::NoShape>()) {
694 std::vector<int64_t> shape_value = {SizeToLong(tuple_shape_ptr->size())};
695 shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
696 return shapes;
697 }
698 for (auto &shape : tuple_shape) {
699 auto each_shape = ExtractNewShapeFromShape(shape);
700 shapes.emplace_back(each_shape);
701 }
702 } else if (base_shape_ptr->isa<abstract::DynamicSequenceShape>()) {
703 std::vector<int64_t> shape_value = {-1};
704 shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
705 } else if (base_shape_ptr->isa<abstract::Shape>()) {
706 auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
707 MS_EXCEPTION_IF_NULL(shape_ptr);
708 std::vector<int64_t> shape_value = shape_ptr->shape();
709 shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
710 } else if (base_shape_ptr->isa<abstract::NoShape>()) {
711 std::vector<int64_t> shape_value = {};
712 shapes.emplace_back(std::make_shared<ShapeValue>(shape_value));
713 } else {
714 MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " should be Tuple/List/Tensor/Scalar, but got "
715 << base_shape_ptr->ToString() << "full name is " << node->fullname_with_scope();
716 }
717 return shapes;
718 }
719
FindCommonMirrorGroup(const FuncGraphPtr & root)720 RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
721 auto parameters = root->parameters();
722 for (auto ¶meter : parameters) {
723 auto param_ptr = parameter->cast<ParameterPtr>();
724 MS_EXCEPTION_IF_NULL(param_ptr);
725 if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
726 continue;
727 }
728 size_t allow_repeat_num = 1;
729 if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
730 (!param_ptr->param_info() || param_ptr->param_info()->parallel_optimizer())) {
731 if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
732 MS_LOG(INFO) << "The parameter :" << param_ptr->fullname_with_scope()
733 << " is fully shard by optimizer parallel,"
734 " thus cannot find common data parallel group for this rank";
735 return {g_device_manager->global_rank()};
736 }
737 allow_repeat_num = size_t(ParallelContext::GetInstance()->optimizer_weight_shard_size());
738 }
739 if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
740 MS_LOG(INFO) << "The parameter :" << param_ptr->fullname_with_scope()
741 << " is fully shard, thus cannot find common data parallel group for this rank";
742 return {g_device_manager->global_rank()};
743 }
744 }
745 AnfNodePtr ret = root->get_return();
746 MS_EXCEPTION_IF_NULL(ret);
747 std::vector<int64_t> common_group_list;
748 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
749 bool is_first_group = true;
750 for (auto &node : all_nodes) {
751 if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
752 !IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
753 continue;
754 }
755 auto prim = GetCNodePrimitive(node);
756 if (!prim->HasAttr(GROUP)) {
757 MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
758 }
759 std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
760 std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
761 if (is_first_group) {
762 common_group_list = group_list;
763 is_first_group = false;
764 } else {
765 std::vector<int64_t> new_comm_group_list;
766 (void)std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(),
767 group_list.end(), std::back_inserter(new_comm_group_list));
768 common_group_list = new_comm_group_list;
769 }
770 }
771 MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
772 return common_group_list;
773 }
774
CreateInstanceName(const CNodePtr & node,size_t index)775 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
776 MS_EXCEPTION_IF_NULL(node);
777 if (!IsValueNode<Primitive>(node->input(0))) {
778 MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
779 }
780 std::string name_base = node->fullname_with_scope();
781 std::string name = name_base + "_" + std::to_string(index);
782 std::string instance_name = HashInstanceName(name);
783 return instance_name;
784 }
785
SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input)786 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
787 if (new_node_input.empty()) {
788 return;
789 }
790
791 auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
792 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
793 MS_EXCEPTION_IF_NULL(prim);
794
795 auto attrs = prim->attrs();
796 auto iter = attrs.find(GROUP);
797 if (iter != attrs.end()) {
798 auto value = iter->second;
799 MS_EXCEPTION_IF_NULL(value);
800 if (value->isa<StringImm>()) {
801 std::string hash_name = value->cast<StringImmPtr>()->value();
802 MS_EXCEPTION_IF_NULL(g_device_manager);
803 std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
804 (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
805 }
806 }
807 }
808
SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> & all_nodes)809 void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
810 for (auto &node : all_nodes) {
811 if (!node->isa<CNode>()) {
812 continue;
813 }
814 auto cnode = node->cast<CNodePtr>();
815 MS_EXCEPTION_IF_NULL(cnode);
816 if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
817 continue;
818 }
819 auto slice_prim = GetCNodePrimitive(cnode);
820 MS_EXCEPTION_IF_NULL(slice_prim);
821 if (slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
822 SetStridedSliceStrategy(cnode);
823 }
824 }
825 }
826
827 // Check the given tensor, return nullptr if the given type is not an TensorType
CheckTensorType(const TypePtr & node_type)828 bool CheckTensorType(const TypePtr &node_type) {
829 MS_EXCEPTION_IF_NULL(node_type);
830 if (!node_type->isa<mindspore::TensorType>()) {
831 return false;
832 }
833 return true;
834 }
835
FindReturnUser(const CNodePtr & cnode,const std::vector<AnfNodePtr> & all_nodes,std::pair<std::shared_ptr<AnfNode>,int> * queue_node)836 void FindReturnUser(const CNodePtr &cnode, const std::vector<AnfNodePtr> &all_nodes,
837 std::pair<std::shared_ptr<AnfNode>, int> *queue_node) {
838 auto graph = cnode->func_graph();
839 auto is_target = [&](const AnfNodePtr &ele) {
840 if (ele->isa<CNode>()) {
841 auto parent_cnode = ele->cast<CNodePtr>();
842 return IsValueNode<FuncGraph>(parent_cnode->input(0)) &&
843 GetValueNode<FuncGraphPtr>(parent_cnode->input(0)) == graph;
844 }
845 return false;
846 };
847 auto it = std::find_if(all_nodes.begin(), all_nodes.end(), is_target);
848 if (it == all_nodes.end()) {
849 return;
850 }
851 *queue_node = {*it, 0};
852 }
853
AddVisitedNode(std::queue<std::pair<std::shared_ptr<AnfNode>,int>> * visited,const NodeUsersMap & node_users_map,const AnfNodePtr & key_node)854 void AddVisitedNode(std::queue<std::pair<std::shared_ptr<AnfNode>, int>> *visited, const NodeUsersMap &node_users_map,
855 const AnfNodePtr &key_node) {
856 if (IsPrimitiveCNode(key_node, prim::kPrimReturn)) {
857 return;
858 }
859 auto node_users = node_users_map.at(key_node);
860 for (auto &node_user : node_users) {
861 auto cnode = node_user.first->cast<CNodePtr>();
862 if (!cnode || IsSomePrimitiveList(cnode, {MAKE_TUPLE, UPDATESTATE})) {
863 continue;
864 }
865 if (node_user.first) {
866 visited->push(node_user);
867 }
868 }
869 }
870
// Breadth-first search from node_ptr through its users (node_users_map) until a
// parallel-care node is found; returns that user pair {node, input_index}.
// `index` is the (1-based) input position at which node_ptr feeds its consumer,
// used to follow only the matching TupleGetItem edge.
// Returns {nullptr, 0} when no parallel-care user is reachable.
std::pair<std::shared_ptr<AnfNode>, int> BFSParallelCareNode(const AnfNodePtr &node_ptr,
                                                            const NodeUsersMap &node_users_map, const int index,
                                                            const std::vector<AnfNodePtr> &all_nodes) {
  std::queue<std::pair<std::shared_ptr<AnfNode>, int>> visited;
  CNodePtr cnode = nullptr;
  AnfNodePtr node = nullptr;
  if (!node_ptr) {
    return std::make_pair(nullptr, 0);
  }
  AddVisitedNode(&visited, node_users_map, node_ptr);
  while (!visited.empty()) {
    auto queue_node = visited.front();
    visited.pop();
    cnode = queue_node.first->cast<CNodePtr>();
    if (IsParallelCareNode(cnode) || IsAutoParallelCareNode(cnode)) {
      return queue_node;
    } else if (IsValueNode<FuncGraph>(cnode->input(0))) {
      // The user is a call into a sub-graph: continue the search from the
      // corresponding formal parameter (queue_node.second is 1-based, input 0
      // being the called graph itself).
      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
      auto params = graph->parameters();
      auto target_param = params[queue_node.second - 1];
      auto node_set = node_users_map.at(target_param);
      for (auto &node_user : node_set) {
        cnode = node_user.first->cast<CNodePtr>();
        if (IsParallelCareNode(cnode) || IsAutoParallelCareNode(cnode)) {
          return node_user;
        } else if (IsSomePrimitiveList(cnode, {MAKE_TUPLE, UPDATESTATE})) {
          continue;
        }
        visited.push(node_user);
      }
    } else {
      if (IsSomePrimitive(cnode, RETURN)) {
        // Reached the graph's Return: jump out to the caller of this graph.
        FindReturnUser(cnode, all_nodes, &queue_node);
      } else if (IsSomePrimitive(cnode, kTupleGetItemOpName)) {
        // Only follow the TupleGetItem that extracts our element (index is
        // 1-based, the stored tuple index is 0-based).
        auto tuple_index = LongToSize(GetValue<int64_t>(GetValueNode(cnode->input(2))));
        if (tuple_index != IntToSize(index - 1)) {
          continue;
        }
      }
      AddVisitedNode(&visited, node_users_map, queue_node.first);
    }
  }
  return std::make_pair(nullptr, 0);
}
915
// For the weight used by cast and matmul at the same time, like the followings
// weight1->mirror->cast1-> matmul1;
// weight1->add
// we will not insert the cast(FP32->FP16), as it will cause the input of the operator add to be changed to fp16.
//
// BFS through the forward-flag users of node_ptr (looking through AllGather /
// Load / Reshape). Returns the first Cast user found, or nullptr as soon as
// any non-Cast consumer is encountered (see the rationale above).
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
  std::queue<AnfNodePtr> visited;
  AnfNodePtr queue_node = nullptr;
  CNodePtr cnode = nullptr;
  AnfNodePtr node = nullptr;
  if (!node_ptr) {
    return nullptr;
  }
  // Seed the queue with the forward-graph users only.
  auto users = node_users_map.at(node_ptr);
  for (auto &node_user : users) {
    cnode = node_user.first->cast<CNodePtr>();
    if (!cnode || !cnode->in_forward_flag()) {
      continue;
    }
    if (node_user.first) {
      visited.push(node_user.first);
    }
  }
  while (!visited.empty()) {
    queue_node = visited.front();
    visited.pop();
    cnode = queue_node->cast<CNodePtr>();
    // MAKE_TUPLE will not appear after the load in the forward graph
    if (IsSomePrimitive(cnode, MAKE_TUPLE)) {
      continue;
    } else if (IsInAllGatherNodeList(cnode) || IsSomePrimitiveList(cnode, {LOAD, RESHAPE})) {
      // Pass-through nodes: keep searching their users.
      auto node_set = node_users_map.at(queue_node);
      for (auto &node_user : node_set) {
        visited.push(node_user.first);
      }
    } else if (!IsSomePrimitive(cnode, CAST)) {
      // A non-Cast consumer means inserting a cast would change its input type.
      MS_LOG(INFO) << "The weight's users including the non cast node So "
                   << "will not insert cast for this parameter " << node_ptr->DebugString();
      return nullptr;
    } else if (!node) {
      // Remember the first Cast found but keep scanning for non-Cast users.
      node = queue_node;
    }
  }
  return node;
}
960
961 // Given the cnode ptr, find its users until we find the computation node, then return the type of the
962 // computation node. This function is used to find the target type for CreateFP16Cast. Only returns the target type if
963 // it is float16, and the source node is float32. If the situation is not matched, then return the nullptr.
FindChildCastWithFP32ToFP16(const std::pair<AnfNodePtr,int> & res,const NodeUsersMap & node_users_map)964 TypePtr FindChildCastWithFP32ToFP16(const std::pair<AnfNodePtr, int> &res, const NodeUsersMap &node_users_map) {
965 if (ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1) {
966 return nullptr;
967 }
968 auto cnode_ptr = res.first->cast<CNodePtr>();
969 if (!cnode_ptr) {
970 return nullptr;
971 }
972 auto cnode_inputs = cnode_ptr->inputs();
973 if (cnode_inputs.size() < TWO_INPUT_SIZE) {
974 return nullptr;
975 }
976
977 AnfNodePtr node = nullptr;
978 if (IsValueNode<FuncGraph>(cnode_ptr->input(kIndex0))) {
979 auto graph_sub = GetValueNode<FuncGraphPtr>(cnode_ptr->input(0));
980 auto parameters = graph_sub->parameters();
981 auto parameter_sub = parameters[IntToSize(res.second - 1)];
982 node = GetChildCastNode(parameter_sub, node_users_map);
983 } else {
984 // As we execute the function IsWeightValidUsed when we start to insert the mirror, so the second parameter
985 // is always the parameter.
986 auto weight = cnode_inputs[1];
987 if (!weight->isa<Parameter>()) {
988 return nullptr;
989 }
990 MS_LOG(INFO) << "Start to search the weight params:" << weight->DebugString();
991 node = GetChildCastNode(weight, node_users_map);
992 }
993
994 if (!node) {
995 return nullptr;
996 }
997 // get the output dtype of the operator
998 auto node_type = node->Type();
999 if (!CheckTensorType(node_type)) {
1000 return nullptr;
1001 }
1002 auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
1003 MS_EXCEPTION_IF_NULL(input_element_type);
1004 if (!IsPrimitiveCNode(node)) {
1005 return nullptr;
1006 }
1007 auto cast_input_cnode = node->cast<CNodePtr>()->input(kIndex1)->cast<CNodePtr>();
1008 if (!cast_input_cnode) {
1009 return nullptr;
1010 }
1011 auto source_node_type = cast_input_cnode->Type();
1012 if (!CheckTensorType(source_node_type)) {
1013 return nullptr;
1014 }
1015 auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
1016 MS_EXCEPTION_IF_NULL(source_element_type);
1017 // We only add cast operation when the source is fp32 type, and the users is fp16 type.
1018 if ((source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) ||
1019 (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeBFloat16)) {
1020 return input_element_type;
1021 }
1022 return nullptr;
1023 }
1024
// Create a cast node given the current node and the previous node. The target type of the cast is from the
// compute_node_type.
// Return the new cast node with pre_node as the inputs.
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type) {
  const char kOpsFunctionModelName[] = "mindspore.ops.functional";
  // Fetch the python-side "cast" primitive once and cache it across calls.
  static py::object cast_prim = python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  MS_EXCEPTION_IF_NULL(compute_node_type);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    // No C++ primitive attached yet: wrap the python object directly.
    prim = std::make_shared<PrimitivePy>(cast_prim);
  }
  // Insert cast.
  auto type_node = NewValueNode(compute_node_type);
  type_node->set_abstract(compute_node_type->ToAbstract());
  auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
  // Reuse abstract/scope of the consumer so later passes see a consistent node.
  new_node->set_abstract(node->abstract());
  new_node->set_scope(node->scope());
  new_node->set_in_forward_flag(true);
  return new_node;
}
1047
LabelGenMaskMicro(const FuncGraphPtr & root)1048 void LabelGenMaskMicro(const FuncGraphPtr &root) {
1049 AnfNodePtr ret = root->get_return();
1050 MS_EXCEPTION_IF_NULL(ret);
1051 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
1052 for (auto &node : all_nodes) {
1053 if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
1054 auto gen_mask_node = RealInputNode(node->cast<CNodePtr>(), 2);
1055 if (gen_mask_node->isa<CNode>()) {
1056 gen_mask_node->cast<CNodePtr>()->set_primal_attrs(node->cast<CNodePtr>()->primal_attrs());
1057 }
1058 }
1059 }
1060 }
1061
// Two post-processing steps over all primitive cnodes:
//  1. Mirrors the "DISABLE_MERGE_ASSIGN_ADD" primitive attr into the cnode's
//     primal attrs so it survives later graph transformations.
//  2. Marks Cast-of-weight nodes as non-recomputable, avoiding a redundant
//     trans_data operator during recomputation.
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto cnode_prim = GetCNodePrimitive(cnode);
    if (cnode_prim->HasAttr("DISABLE_MERGE_ASSIGN_ADD")) {
      cnode->AddPrimalAttr("DISABLE_MERGE_ASSIGN_ADD", cnode_prim->GetAttr("DISABLE_MERGE_ASSIGN_ADD"));
    }
    if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
      continue;
    }
    // A cast whose real input is a parameter with a default value is a weight cast.
    auto cast_input = RealInputNode(cnode, 1);
    if (cast_input->isa<Parameter>() && cast_input->cast<ParameterPtr>()->has_default()) {
      MS_LOG(INFO) << "Cast for parameter no needs recompute to avoid redundant trans_data operator";
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
      (void)prim->AddAttr("recompute", MakeValue(false));
    }
  }
}
1083
GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> & node,const string & key)1084 std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key) {
1085 if (!node) {
1086 return nullptr;
1087 }
1088 auto cnode = node->cast<CNodePtr>();
1089 auto prim = GetCNodePrimitive(cnode);
1090 if (prim && prim->HasAttr(key)) {
1091 return prim->GetAttr(key);
1092 }
1093 return nullptr;
1094 }
1095
// Whether op_name has a distributed (splittable) OperatorInfo implementation.
// Note: a few names appear more than once in the initializer (e.g. FAST_GELU,
// IOU); harmless for a std::set, but worth cleaning up at some point.
bool IsSplittableOperator(const std::string &op_name) {
  // clang-format off
  static const std::set<std::string> splittable_op =
    {MATMUL, TRANSPOSE, GELU, FAST_GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
     BATCH_MATMUL_EXT, MATMUL_EXT,
     FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, STACK_EXT,
     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
     EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
     EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SLICE_EXT, SELECT,
     GATHERD, UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ND_ADD, SCATTER_ND_SUB,
     TENSOR_SCATTER_UPDATE, TENSOR_SCATTER_ADD, TENSOR_SCATTER_SUB, TENSOR_SCATTER_MAX, TENSOR_SCATTER_MIN, WKV,
     TENSOR_SCATTER_MUL, TENSOR_SCATTER_DIV, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE, SORT, PAD_V3,
     MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, STANDARD_NORMAL, RESIZE_BILINEAR_V2, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU,
     BOUNDING_BOX_ENCODE, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, UNIQUE_CONSECUTIVE, SILU, INDEX_SELECT, CLAMP_SCALAR,
     RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2,
     RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
     ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
     BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, INPLACE_UPDATE,
     L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE,
     CHECK_VALID, INVERT, SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD,
     POPULATION_COUNT, IDENTITY, BESSELI0, BESSELI1, BESSELJ0, BESSELJ1, CUM_MAX, CUM_MIN, HYPOT, IGAMMA, IGAMMAC,
     LEFT_SHIFT, RIGHT_SHIFT, NEXT_AFTER, ZETA, REVERSEV2, LGAMMA, TRUNC, BETAINC, GCD, CHOLESKY, CONV3D, MAXPOOL_3D,
     AVGPOOL_3D, FILLV2, FAKE_QUANT_PER_LAYER, FAKE_QUANT_PER_CHANNEL, MIN_MAX_UPDATE_PER_LAYER, ASCEND_QUANTV2,
     MIN_MAX_UPDATE_PER_CHANNEL, FFN, FLASH_ATTENTION_SCORE, ASCEND_QUANT, ASCEND_DEQUANT, GRID_SAMPLER_2D, ANTI_QUANT,
     CONVOLUTION, LIN_SPACE_EXT, ONEHOTEXT};
  // clang-format on

  auto iter = splittable_op.find(op_name);
  return (iter != splittable_op.end());
}
1133
// Whether auto-parallel should process this cnode.
// Excludes Send/Receive/MakeTuple/MakeList; Cast is handled specially: casts
// inserted by the optimizer are ignored, forward-graph casts are accepted.
bool IsAutoParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
  if (prim == nullptr) {
    return false;
  }
  if (IsSomePrimitiveList(cnode, {SEND, RECEIVE, MAKE_TUPLE, MAKE_LIST})) {
    return false;
  }
  // Parallel-care but without a splittable OperatorInfo: warn, then fall
  // through to the generic IsParallelCareNode result (default strategy).
  bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
  if (bool_result) {
    MS_LOG(INFO) << "For 'auto_parallel', missing the splitable implementation of OperatorInfo for: " << prim->name()
                 << ", default strategy will be assigned. Network training may deteriorate or malfunction";
  } else if (prim->name() == CAST) {
    if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
      // Do not care CASTs from optimizer
      return false;
    }
    return cnode->in_forward_flag();
  }
  return IsParallelCareNode(cnode);
}
1160
UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> & all_nodes)1161 void UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> &all_nodes) {
1162 for (auto &node : all_nodes) {
1163 if (!node->isa<CNode>()) {
1164 continue;
1165 }
1166 auto cnode = node->cast<CNodePtr>();
1167 MS_EXCEPTION_IF_NULL(cnode);
1168 if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
1169 continue;
1170 }
1171 auto slice_prim = GetCNodePrimitive(cnode);
1172 MS_EXCEPTION_IF_NULL(slice_prim);
1173 if (!slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
1174 continue;
1175 }
1176 if (!slice_prim->HasAttr(INTERLEAVED_NUM)) {
1177 continue;
1178 }
1179 if (GetValue<int64_t>(slice_prim->GetAttr(INTERLEAVED_NUM)) == MICRO_INTERLEAVED_SIZE) {
1180 ParallelContext::GetInstance()->set_enable_micro_interleaved(true);
1181 cnode->AddAttr(INTERLEAVED_NUM, slice_prim->GetAttr(INTERLEAVED_NUM));
1182 }
1183 }
1184 }
1185
// Maps a primitive name to its distributed OperatorInfo class name:
// strip a single leading underscore (internal-primitive prefix) and append "Info".
std::string GetDisOpName(const std::string &prim_name) {
  const bool has_underscore_prefix = !prim_name.empty() && prim_name.front() == '_';
  return (has_underscore_prefix ? prim_name.substr(1) : prim_name) + "Info";
}
1193
OperatorInstanceByName(const std::string & name,const PrimitiveAttrs & attrs,const std::vector<Shapes> & shape_list)1194 OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs,
1195 const std::vector<Shapes> &shape_list) {
1196 if (shape_list.size() != 2) {
1197 MS_LOG(ERROR) << "The size of shape list is not 2";
1198 return nullptr;
1199 }
1200 if (name.length() == 0) {
1201 MS_LOG(EXCEPTION) << "Length of name is zero!";
1202 }
1203
1204 if (name == "Custom" &&
1205 (attrs.find(KAttrAsLossDivisor) == attrs.end() || attrs.find(KAttrDevMatrixShape) == attrs.end() ||
1206 attrs.find(KAttrInputsTensorMap) == attrs.end() || attrs.find(KAttrOutputsTensorMap) == attrs.end())) {
1207 MS_LOG(WARNING) << "The attr for parallelization settings is not found in the custom op."
1208 << "To enable auto parallelization, set the attrs including [" << KAttrAsLossDivisor << ", "
1209 << KAttrDevMatrixShape << ", " << KAttrInputsTensorMap << ", " << KAttrOutputsTensorMap << "]";
1210 return nullptr;
1211 }
1212 std::string distribute_opname = GetDisOpName(name);
1213 OperatorInfoPtr op_info =
1214 (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
1215 if (op_info == nullptr) {
1216 MS_LOG(INFO) << "Create " << name << " failed";
1217 return nullptr;
1218 }
1219 std::string origin_name = op_info->name();
1220 op_info->set_name(origin_name + std::to_string(TOTAL_OPS));
1221 MS_LOG(INFO) << "Successfully created operator " << origin_name;
1222 ++TOTAL_OPS;
1223 return op_info;
1224 }
1225
// Creates the OperatorInfo for a primitive, trying in order:
//  1. SelfDefineShardInfo when the SELF_DEFINE_SHARD attr is true;
//  2. the OperatorInfo registered for the primitive's own name;
//  3. StandAloneInfo for black-listed ops or shapes not divisible by the
//     stage device number;
//  4. BatchParallelInfo otherwise.
// Side effect: tags the primitive with STAND_ALONE / BATCH_PARALLEL attrs.
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
                                 const std::vector<Shapes> &shape_list) {
  MS_EXCEPTION_IF_NULL(prim);
  OperatorInfoPtr op_info;
  if (prim->HasAttr(SELF_DEFINE_SHARD)) {
    auto self_define_shard_attr = prim->GetAttr(SELF_DEFINE_SHARD);
    if (self_define_shard_attr->cast_ptr<BoolImm>() == nullptr) {
      MS_LOG(EXCEPTION) << "SELF_DEFINE_SHARD attribute is not a bool";
    }
    if (GetValue<bool>(self_define_shard_attr)) {
      op_info = OperatorInstanceByName(SELF_DEFINE_SHARD_OP, attrs, shape_list);
      MS_LOG(INFO) << "Operator " << prim->name() << " has self_define_shard attribute. Create SelfDefineShardInfo";
      return op_info;
    }
  }
  op_info = OperatorInstanceByName(prim->name(), attrs, shape_list);
  if (op_info) {
    return op_info;
  }
  // No registered OperatorInfo: fall back to stand-alone or batch parallel.
  if (IsInBatchParallelBlackList(prim)) {
    op_info = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
    prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
    MS_LOG(INFO) << "Operator " << prim->name() << " is not supported yet in auto parallel mode. Use Stand Alone";
    return op_info;
  }
  auto input_shape = shape_list[0];
  auto output_shape = shape_list[1];
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto device_num = g_device_manager->stage_device_num();
  MS_EXCEPTION_IF_ZERO("device_num", device_num);
  // Batch parallel requires the leading (batch) dimension of both input and
  // output to divide evenly across the stage devices.
  if (input_shape.empty() || input_shape[0].empty() || input_shape[0][0] % device_num != 0 || output_shape[0].empty() ||
      output_shape[0][0] % device_num != 0) {
    MS_LOG(INFO) << "Operator " << prim->name() << " use Stand Alone, the input shape is " << input_shape
                 << ", the output shape is " << output_shape;
    op_info = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
    prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
    return op_info;
  }
  MS_LOG(INFO) << "Operator " << prim->name() << " use Batch Parallel";
  op_info = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
  prim->AddAttr(BATCH_PARALLEL, MakeValue<bool>(true));
  return op_info;
}
1269
GetRefKeyNodeShape(const AnfNodePtr & node,const FuncGraphPtr & func_graph)1270 static Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
1271 MS_EXCEPTION_IF_NULL(node);
1272 MS_EXCEPTION_IF_NULL(func_graph);
1273
1274 std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(node, func_graph);
1275 if (parameters.size() != 1) {
1276 MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
1277 }
1278
1279 Shapes input_shapes = GetNodeShape(parameters[0]);
1280 if (input_shapes.size() != 1) {
1281 MS_LOG(EXCEPTION) << "Get input shape failed";
1282 }
1283
1284 MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]);
1285 return input_shapes;
1286 }
1287
// Extracts input/output shapes (as NewShapes, which can represent nested
// tuple shapes) together with their symbols for `node`.
// Returns {shapes, symbols} where each element is {inputs, outputs}.
// Side effect: RefKey inputs are resolved and recorded in the global g_RefMap.
std::pair<std::vector<NewShapes>, std::vector<Symbols>> ExtractNewShapeAndSymbol(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  NewShapes shape_inputs;
  NewShapes shape_outputs;
  Symbols symbol_inputs;
  Symbols symbol_outputs;
  std::vector<NewShapes> shape_all;
  std::vector<Symbols> symbol_all;
  std::vector<AnfNodePtr> all_inputs = node->inputs();
  bool need_create_shape_list = false;

  const int min_size = 2;
  size_t inputs_size = all_inputs.size();
  // Input 0 is the primitive/graph; real operands start at index 1.
  for (size_t i = 1; i < inputs_size; ++i) {
    ShapeBasePtr input_new_shapes;
    Shapes input_shapes;
    Symbols input_symbols;
    AnfNodePtr input = all_inputs[i];
    if (HasAbstractMonad(input)) {
      // Monad inputs carry no shape information.
      continue;
    }
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      // Remember which node/input consumes this parameter for later passes.
      std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
      g_RefMap[parameters[0]] = node_pair;
      MS_LOG(INFO) << "Find parameter by ref key node" << node_pair.first;
      input_shapes = GetRefKeyNodeShape(input, func_graph);
      input_symbols = StaticShapesToSymbols(input_shapes);  // now the parameter can only be static shape
    } else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
               (IsValueSequence(input) &&
                (inputs_size == min_size || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)))) {
      if (IsDynamicShapeInput(node, input)) {
        MS_LOG(INFO) << "may be dynamic shape, no need to get input's shape, the node is " << node->ToString();
        continue;
      }

      if (IsPrimitiveCNode(input, prim::kPrimShape)) {
        // For Shape ops, take the shape of the node whose shape is queried.
        input_shapes = GetNodeShape(input->cast<CNodePtr>()->input(1));
        input_symbols = GetNodeSymbol(input->cast<CNodePtr>()->input(1));
      } else {
        input_shapes = GetNodeShape(input);
        input_symbols = GetNodeSymbol(input);
      }
      if ((input->abstract()->isa<abstract::AbstractSequence>() || IsValueSequence(input))) {
        // Sequence inputs need a nested ShapeList wrapper.
        need_create_shape_list = true;
      }
    } else if (IsValueSequence(input)) {
      auto temp_input_node = input;
      if (IsPrimitiveCNode(input, prim::kPrimShape)) {
        temp_input_node = input->cast<CNodePtr>()->input(1);
      }
      need_create_shape_list = true;
      input_shapes = GetNodeShape(temp_input_node);
      input_symbols = GetNodeSymbol(temp_input_node);
    } else {
      continue;
    }
    // For normal shape
    input_new_shapes = TransferShapesToNewShapes(input_shapes, need_create_shape_list)[0];
    need_create_shape_list = false;  // reset per-input flag
    shape_inputs.emplace_back(input_new_shapes);
    symbol_inputs.push_back(input_symbols[0]);
  }
  shape_all.push_back(shape_inputs);
  symbol_all.push_back(symbol_inputs);
  // extract out shape
  shape_outputs = GetNodeNewShape(node);
  symbol_outputs = GetNodeSymbol(node);
  shape_all.push_back(shape_outputs);
  symbol_all.push_back(symbol_outputs);

  return std::make_pair(shape_all, symbol_all);
}
1366
// Legacy (flat-Shapes) counterpart of ExtractNewShapeAndSymbol: extracts
// input/output shapes and symbols for `node`.
// Returns {shapes, symbols} where each element is {inputs, outputs}.
// Side effect: RefKey inputs are resolved and recorded in the global g_RefMap.
std::pair<std::vector<Shapes>, std::vector<Symbols>> ExtractShapeAndSymbol(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shape_inputs;
  Shapes shape_outputs;
  Symbols symbol_inputs;
  Symbols symbol_outputs;
  std::vector<Shapes> shape_all;
  std::vector<Symbols> symbol_all;
  std::vector<AnfNodePtr> all_inputs = node->inputs();

  const int min_size = 2;
  size_t inputs_size = all_inputs.size();
  // Input 0 is the primitive/graph; real operands start at index 1.
  for (size_t i = 1; i < inputs_size; ++i) {
    Shapes input_shapes;
    Symbols input_symbols;
    AnfNodePtr input = all_inputs[i];
    if (HasAbstractMonad(input)) {
      // Monad inputs carry no shape information.
      continue;
    }
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      // Remember which node/input consumes this parameter for later passes.
      std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
      g_RefMap[parameters[0]] = node_pair;
      MS_LOG(INFO) << "Find parameter by ref key node" << node_pair.first;
      input_shapes = GetRefKeyNodeShape(input, func_graph);
      input_symbols = StaticShapesToSymbols(input_shapes);  // now the parameter can only be static shape
    } else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
               (IsValueSequence(input) &&
                (inputs_size == min_size || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)))) {
      if (IsDynamicShapeInput(node, input)) {
        MS_LOG(INFO) << "may be dynamic shape, no need to get input's shape, the node is " << node->ToString();
        continue;
      }

      if (IsPrimitiveCNode(input, prim::kPrimShape)) {
        // For Shape ops, take the shape of the node whose shape is queried.
        input_shapes = GetNodeShape(input->cast<CNodePtr>()->input(1));
        input_symbols = GetNodeSymbol(input->cast<CNodePtr>()->input(1));
      } else {
        input_shapes = GetNodeShape(input);
        input_symbols = GetNodeSymbol(input);
      }
    } else {
      continue;
    }
    if (input_shapes.size() != 1) {
      // Multi-shape inputs are only allowed for single-operand nodes or ops
      // that take a whole tuple/list as their input.
      if (inputs_size == min_size || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) {
        shape_inputs = input_shapes;
        symbol_inputs = input_symbols;
        break;
      } else {
        MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
      }
    }
    shape_inputs.push_back(input_shapes[0]);
    symbol_inputs.push_back(input_symbols[0]);
  }
  shape_all.push_back(shape_inputs);
  symbol_all.push_back(symbol_inputs);
  // extract out shape
  shape_outputs = GetNodeShape(node);
  symbol_outputs = GetNodeSymbol(node);
  shape_all.push_back(shape_outputs);
  symbol_all.push_back(symbol_outputs);

  return std::make_pair(shape_all, symbol_all);
}
1438
ExtractShape(const CNodePtr & node)1439 std::vector<Shapes> ExtractShape(const CNodePtr &node) {
1440 MS_EXCEPTION_IF_NULL(node);
1441 auto shapes_and_symbols = ExtractShapeAndSymbol(node);
1442 return shapes_and_symbols.first;
1443 }
1444
ExtractNewShape(const CNodePtr & node)1445 std::vector<NewShapes> ExtractNewShape(const CNodePtr &node) {
1446 MS_EXCEPTION_IF_NULL(node);
1447 auto shapes_and_symbols = ExtractNewShapeAndSymbol(node);
1448 return shapes_and_symbols.first;
1449 }
1450
ExtractRealDivisor(const CNodePtr & node)1451 std::vector<Shapes> ExtractRealDivisor(const CNodePtr &node) {
1452 MS_EXCEPTION_IF_NULL(node);
1453 auto shapes_and_symbols = ExtractShapeAndSymbol(node);
1454 std::vector<Shapes> shapes = shapes_and_symbols.first;
1455 std::vector<Symbols> symbols = shapes_and_symbols.second;
1456 if (shapes.size() != INPUT_OUTPUT_SYMBOLS_SIZE || symbols.size() != INPUT_OUTPUT_SYMBOLS_SIZE) {
1457 MS_LOG(EXCEPTION) << "the size of shapes or symbols must be " << INPUT_OUTPUT_SYMBOLS_SIZE
1458 << ", but the size of shapes is " << shapes.size() << ", the size of symbols is "
1459 << symbols.size();
1460 }
1461
1462 auto inputs_shape = shapes[0];
1463 auto outputs_shape = shapes[1];
1464 auto inputs_symbol = symbols[0];
1465 auto outputs_symbol = symbols[1];
1466
1467 Shapes in_divisor_symbols;
1468 Shapes out_divisor_symbols;
1469 MS_LOG(DEBUG) << "the node is " << node->ToString() << ", the divisor of inputs is "
1470 << DivisorOfSymbolsToString(inputs_symbol) << ", the inputs shape is " << ShapesToString(inputs_shape);
1471 in_divisor_symbols = GetRealDivisorSymbols(inputs_shape, inputs_symbol);
1472 out_divisor_symbols = GetRealDivisorSymbols(outputs_shape, outputs_symbol);
1473
1474 MS_LOG(DEBUG) << "the node is " << node->ToString() << ", the inputs shape is " << ShapesToString(inputs_shape)
1475 << ", the inputs divisor is " << ShapesToString(in_divisor_symbols);
1476 MS_LOG(DEBUG) << "the node is " << node->ToString() << ", the outputs shape is " << ShapesToString(outputs_shape)
1477 << ", the outputs divisor is " << ShapesToString(out_divisor_symbols);
1478 return {in_divisor_symbols, out_divisor_symbols};
1479 }
1480
// Walk down a chain of inputs starting from `node`. For each visited cnode the
// `filter` decides whether to keep descending (pair.first) and which input
// index to follow next (pair.second). Returns the first node where the walk
// stops: either a non-CNode or a cnode rejected by the filter.
AnfNodePtr GetInputNodeWithFilter(const AnfNodePtr &node,
                                  std::function<std::pair<bool, size_t>(const CNodePtr &)> filter) {
  AnfNodePtr current = node;
  for (;;) {
    if (!current->isa<CNode>()) {
      return current;
    }
    auto current_cnode = current->cast<CNodePtr>();
    auto descend = filter(current_cnode);
    if (!descend.first) {
      return current;
    }
    current = current_cnode->input(descend.second);
  }
}
1500
GetOutputNodesWithFilter(const AnfNodePtr & node,std::function<bool (const AnfNodePtr &)> filter)1501 std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesWithFilter(const AnfNodePtr &node,
1502 std::function<bool(const AnfNodePtr &)> filter) {
1503 auto func_graph = node->func_graph();
1504 MS_EXCEPTION_IF_NULL(func_graph);
1505 auto manager = func_graph->manager();
1506 MS_EXCEPTION_IF_NULL(manager);
1507 std::vector<std::pair<AnfNodePtr, int>> res;
1508 std::queue<AnfNodePtr> anf_queue;
1509 anf_queue.push(node);
1510 while (!anf_queue.empty()) {
1511 auto queue_end = anf_queue.front();
1512 anf_queue.pop();
1513 auto user_set = manager->node_users()[queue_end];
1514 for (auto &pair : user_set) {
1515 if (filter(pair.first)) {
1516 anf_queue.push(pair.first);
1517 continue;
1518 }
1519 res.push_back(pair);
1520 }
1521 }
1522 return res;
1523 }
1524
GetOutputNodesSkipDepend(const AnfNodePtr & node)1525 std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesSkipDepend(const AnfNodePtr &node) {
1526 auto func_graph = node->func_graph();
1527 MS_EXCEPTION_IF_NULL(func_graph);
1528 auto manager = func_graph->manager();
1529 MS_EXCEPTION_IF_NULL(manager);
1530 std::vector<std::pair<AnfNodePtr, int>> res;
1531 std::queue<AnfNodePtr> anf_queue;
1532 anf_queue.push(node);
1533 while (!anf_queue.empty()) {
1534 auto queue_end = anf_queue.front();
1535 anf_queue.pop();
1536 auto user_set = manager->node_users()[queue_end];
1537 for (auto &pair : user_set) {
1538 if (IsPrimitiveCNode(pair.first, prim::kPrimDepend)) {
1539 if (pair.second == 1) {
1540 anf_queue.push(pair.first);
1541 }
1542 continue;
1543 }
1544 res.push_back(pair);
1545 }
1546 }
1547 return res;
1548 }
1549
CanMergeConcatSlice(const std::pair<std::shared_ptr<AnfNode>,int> & pair,const CNodePtr & concat_cnode,const ShapeVector & concat_output_shape_element,int64_t concat_axis)1550 std::pair<bool, size_t> CanMergeConcatSlice(const std::pair<std::shared_ptr<AnfNode>, int> &pair,
1551 const CNodePtr &concat_cnode,
1552 const ShapeVector &concat_output_shape_element, int64_t concat_axis) {
1553 if (!IsPrimitiveCNode(pair.first, prim::kPrimStridedSlice)) {
1554 return {false, 0};
1555 }
1556 auto slice_cnode = pair.first->cast<CNodePtr>();
1557 MS_LOG(INFO) << "concat slice cnode:" << slice_cnode->fullname_with_scope();
1558 auto begin_value = GetValueNode(slice_cnode->input(2));
1559 auto end_value = GetValueNode(slice_cnode->input(3));
1560 auto strided_value = GetValueNode(slice_cnode->input(4));
1561 if (!begin_value || !end_value || !strided_value) {
1562 return {false, 0};
1563 }
1564 auto begin = GetValue<std::vector<int64_t>>(begin_value);
1565 auto end = GetValue<std::vector<int64_t>>(end_value);
1566 auto strided = GetValue<std::vector<int64_t>>(strided_value);
1567 if (!std::all_of(strided.begin(), strided.end(), [](auto s) { return s == 1; })) {
1568 return {false, 0};
1569 }
1570 if (!IsPrimitiveCNode(concat_cnode->input(1), prim::kPrimMakeTuple)) {
1571 return {false, 0};
1572 }
1573 auto concat_input_node = concat_cnode->input(1)->cast<CNodePtr>();
1574 auto concat_input_size = concat_input_node->size();
1575 bool can_merge = false;
1576 size_t concat_input_index = 0;
1577 for (size_t i = 0; i < begin.size(); ++i) {
1578 int64_t slice_len = (end[i] - begin[i]);
1579 if (i == size_t(concat_axis)) {
1580 int64_t slice_index = begin[i] / slice_len;
1581 if (slice_len == concat_output_shape_element[i] || size_t(slice_index + 1) >= concat_input_size) {
1582 can_merge = false;
1583 break;
1584 }
1585 concat_input_index = size_t(slice_index + 1);
1586 can_merge = true;
1587 } else if (slice_len != concat_output_shape_element[i]) {
1588 can_merge = false;
1589 break;
1590 }
1591 }
1592 return {can_merge, concat_input_index};
1593 }
1594
UpdateUpdateStateForMergeConcatSlice(const FuncGraphManagerPtr & manager,const std::vector<std::pair<AnfNodePtr,int>> & update_list,const CNodePtr & tuple_get_item_node)1595 void UpdateUpdateStateForMergeConcatSlice(const FuncGraphManagerPtr &manager,
1596 const std::vector<std::pair<AnfNodePtr, int>> &update_list,
1597 const CNodePtr &tuple_get_item_node) {
1598 for (const auto &ups_pair : update_list) {
1599 manager->SetEdge(ups_pair.first, ups_pair.second, tuple_get_item_node);
1600 }
1601 }
1602
// Merge a Concat whose result is returned from a sub-graph with StridedSlice
// consumers reached through the (unique) caller chain:
//   sub-graph: MakeTuple -> Concat -> Return
//   caller:    call -> TupleGetItem(s) -> user graph (StridedSlice users)
// The concat's MakeTuple elements are forwarded through the call boundary as
// separate arguments; every slice that selects exactly one element is replaced
// by the matching new parameter, and other users get a rebuilt Concat inside
// the user graph. Returns false when the sub-graph has multiple call sites
// (nothing is changed in that case); returns true otherwise.
bool HandleFuncConcatSlice(const FuncGraphManagerPtr &manager, const std::pair<std::shared_ptr<AnfNode>, int> &pair,
                           const CNodePtr &concat_cnode, const ShapeVector &concat_output_shape_element,
                           int64_t concat_axis) {
  auto fg = pair.first->func_graph();
  auto fg_map = fg->func_graph_cnodes_index();
  // Only a single call site is supported: the rewrite below would otherwise
  // have to patch several callers consistently.
  if (fg_map.size() > 1) {
    return false;
  }
  for (auto &fg_use : fg_map) {
    // Only direct calls (index 0 means the graph sits in call position).
    if (!fg_use.first->first->isa<CNode>() || fg_use.first->second > 0) {
      continue;
    }
    auto call_cnode = fg_use.first->first->cast<CNodePtr>();
    auto func_users = manager->node_users()[call_cnode];
    std::vector<std::pair<AnfNodePtr, int>> update_list;
    size_t func_users_size = 0;
    std::pair<AnfNodePtr, int> fg_users;
    // Separate UpdateState users (kept for later edge fixup) from the single
    // real consumer of the call result.
    for (auto &cur_fg_users : func_users) {
      if (IsPrimitiveCNode(cur_fg_users.first, prim::kPrimUpdateState)) {
        update_list.push_back(cur_fg_users);
        continue;
      }
      ++func_users_size;
      fg_users = cur_fg_users;
    }

    // More than one non-UpdateState consumer: too complex, skip this call.
    if (func_users_size > 1) {
      continue;
    }
    auto func_node_users = FuncGraphNodeUsers(fg_users);
    if (func_node_users.empty()) {
      continue;
    }
    // Pre-compute, for every user inside the called user graph, whether it is
    // a slice that folds to a single concat input.
    bool have_can_merge = false;
    std::vector<std::pair<bool, size_t>> input_index;
    for (const auto &new_pair : func_node_users) {
      auto can_merge = CanMergeConcatSlice(new_pair, concat_cnode, concat_output_shape_element, concat_axis);
      input_index.push_back(can_merge);
      if (can_merge.first) {
        have_can_merge = true;
      }
    }
    if (!have_can_merge) {
      continue;
    }
    // maketuple->Return: make the sub-graph return the tuple of concat inputs
    // instead of the concatenated tensor.
    auto concat_input_node = concat_cnode->input(1)->cast<CNodePtr>();
    manager->SetEdge(pair.first, pair.second, concat_input_node);
    // call -> tuplegetitem -> call: unpack each tuple element at the call site
    // and pass it on to the user graph as a fresh parameter.
    auto user_func_graph = GetValueNode<FuncGraphPtr>(fg_users.first->cast<CNodePtr>()->input(0));
    auto user_graph_parameters = user_func_graph->parameters();
    // NOTE(review): origin_parameter is unused below — presumably kept for
    // debugging; confirm before removing.
    auto origin_parameter = user_graph_parameters[fg_users.second - 1];
    auto new_user_graph_parameters(user_graph_parameters);
    new_user_graph_parameters.erase(new_user_graph_parameters.begin() + fg_users.second - 1);
    auto fg_users_inputs_all(fg_users.first->cast<CNodePtr>()->inputs());
    fg_users_inputs_all.erase(fg_users_inputs_all.begin() + fg_users.second);
    // New concat CNode in user_func_graph
    std::vector<AnfNodePtr> new_concat_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
    std::vector<AbstractBasePtr> new_maketuple_abstracts;
    bool updated_update_state = false;
    for (size_t i = 0; i < concat_input_node->size() - 1; ++i) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), call_cnode,
                                                    ValuePtrToAnfNodePtr(MakeValue<int64_t>(i))};
      auto tuple_get_item_node = call_cnode->func_graph()->NewCNode(tuple_get_item_inputs);
      // The UpdateState edges only need to be moved once, onto the first
      // TupleGetItem created.
      if (!updated_update_state) {
        UpdateUpdateStateForMergeConcatSlice(manager, update_list, tuple_get_item_node);
        updated_update_state = true;
      }
      // replace fg_users->inputs(fg_users.second) to a list fg_users->inputs(fg_users.second+i)
      fg_users_inputs_all.insert(fg_users_inputs_all.begin() + fg_users.second + i, tuple_get_item_node);
      auto new_parameter = user_func_graph->add_parameter();
      new_parameter->set_abstract(concat_input_node->input(i + 1)->abstract()->Clone());
      new_maketuple_abstracts.push_back(concat_input_node->input(i + 1)->abstract()->Clone());
      new_user_graph_parameters.insert(new_user_graph_parameters.begin() + fg_users.second - 1 + i, new_parameter);
      new_concat_maketuple_inputs.push_back(new_parameter);
    }
    user_func_graph->set_parameters(new_user_graph_parameters);
    auto user_func_graph_return_cnode = user_func_graph->get_return();
    auto return_input_cnode = user_func_graph_return_cnode->input(kIndex1);
    // Rebuild the user call with the expanded argument list.
    auto new_call_cnode = fg_users.first->func_graph()->NewCNode(fg_users_inputs_all);
    new_call_cnode->set_abstract(return_input_cnode->abstract()->Clone());
    manager->Replace(fg_users.first, new_call_cnode);
    // Handle user_func_graph slice cnode
    for (size_t j = 0; j < func_node_users.size(); ++j) {
      auto new_pair = func_node_users[j];
      if (!input_index[j].first) {
        // Non-mergeable user: rebuild an equivalent Concat inside the user
        // graph from the new parameters.
        auto new_maketuple_cnode = user_func_graph->NewCNode(new_concat_maketuple_inputs);
        new_maketuple_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(new_maketuple_abstracts));
        auto old_concat_prim = GetCNodePrimitive(concat_cnode);
        std::vector<AnfNodePtr> new_concat_inputs{NewValueNode(old_concat_prim->Clone()), new_maketuple_cnode,
                                                  NewValueNode(MakeValue<int64_t>(concat_axis))};
        auto new_concat = user_func_graph->NewCNode(new_concat_inputs);
        new_concat->set_abstract(concat_cnode->abstract()->Clone());
        auto new_concat_prim = GetCNodePrimitive(new_concat);
        if (new_concat_prim->HasAttr("fine_grained_interleaved_index")) {
          new_concat_prim->EraseAttr("fine_grained_interleaved_index");
        }
        manager->SetEdge(new_pair.first, new_pair.second, new_concat);
        continue;
      }
      // Mergeable slice: use the parameter carrying the selected concat input.
      manager->Replace(new_pair.first, user_func_graph->parameters()[fg_users.second - 2 + input_index[j].second]);
    }
  }
  return true;
}
1708
MergeConcatSlice(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)1709 bool MergeConcatSlice(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
1710 bool merged = false;
1711 for (const auto &node : all_nodes) {
1712 if (!IsPrimitiveCNode(node, prim::kPrimConcat)) {
1713 continue;
1714 }
1715 auto concat_cnode = node->cast<CNodePtr>();
1716 MS_EXCEPTION_IF_NULL(concat_cnode->abstract());
1717 auto concat_output_shape = concat_cnode->abstract()->BuildShape();
1718 MS_EXCEPTION_IF_NULL(concat_output_shape);
1719 MS_EXCEPTION_IF_NULL(concat_output_shape->cast<abstract::ShapePtr>());
1720 auto concat_output_shape_element = concat_output_shape->cast<abstract::ShapePtr>()->shape();
1721 auto axis_value_node = concat_cnode->input(kIndex2);
1722 auto axis_value = GetValueNode(axis_value_node);
1723 auto concat_axis = GetValue<int64_t>(axis_value);
1724 auto next_nodes = GetOutputNodesSkipDepend(node);
1725 for (const auto &pair : next_nodes) {
1726 if (IsPrimitiveCNode(pair.first, prim::kPrimReturn) && next_nodes.size() == 1) {
1727 merged = HandleFuncConcatSlice(manager, pair, concat_cnode, concat_output_shape_element, concat_axis);
1728 continue;
1729 }
1730 auto can_merge = CanMergeConcatSlice(pair, concat_cnode, concat_output_shape_element, concat_axis);
1731 if (!can_merge.first) {
1732 continue;
1733 }
1734 auto concat_input_node = concat_cnode->input(1)->cast<CNodePtr>();
1735 auto concat_real_input_node = concat_input_node->input(can_merge.second);
1736 manager->Replace(pair.first->cast<CNodePtr>(), concat_real_input_node);
1737 merged = true;
1738 }
1739 }
1740 return merged;
1741 }
1742
// Build a new MirrorMicroStep operator equivalent to `micro_mirror` but with
// `micro_mirror_new_input` as its data input. The group/dev_num/mean_flag
// attributes and the node/primal attrs are cloned from the original mirror.
AnfNodePtr NewMicroMirrorPrimByMicroMirror(const FuncGraphPtr &func_graph, const CNodePtr &micro_mirror,
                                           const AnfNodePtr &micro_mirror_new_input) {
  auto prim_origin = GetCNodePrimitive(micro_mirror);
  OperatorAttrs operator_attrs{std::make_pair(GROUP, prim_origin->GetAttr(GROUP)),
                               std::make_pair(DEV_NUM, prim_origin->GetAttr(DEV_NUM)),
                               std::make_pair(MEAN_FLAG, prim_origin->GetAttr(MEAN_FLAG))};
  ValuePtr pyop_instance = CreateOpInstance(operator_attrs, MIRROR_MICRO_STEP_OPERATOR, prim_origin->instance_name());
  MS_EXCEPTION_IF_NULL(pyop_instance);
  // Keep the second input (the accumulation/monad edge) of the old mirror.
  std::vector<AnfNodePtr> mirror_inputs{NewValueNode(pyop_instance), micro_mirror_new_input,
                                        micro_mirror->input(kIndex2)};
  auto new_mirror_node = func_graph->NewCNode(mirror_inputs);
  auto new_prim = GetCNodePrimitive(new_mirror_node);
  (void)new_prim->SetAttrs(prim_origin->attrs());
  new_mirror_node->set_attrs(micro_mirror->attrs());
  new_mirror_node->set_primal_attrs(micro_mirror->primal_attrs());
  return new_mirror_node;
}
1764
AddNodeFusionInfo(const CNodePtr & node,const CNodePtr & comm_node,const std::string & backward_comm_name,const std::string & param_name,int32_t fusion_id)1765 void AddNodeFusionInfo(const CNodePtr &node, const CNodePtr &comm_node, const std::string &backward_comm_name,
1766 const std::string ¶m_name, int32_t fusion_id) {
1767 auto comm_id = MakeValue<std::string>(param_name);
1768 comm_node->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1769 if (GetValueNode<PrimitivePtr>(comm_node->input(0))->HasAttr(GROUP)) {
1770 auto comm_group = GetValue<std::string>(GetValueNode<PrimitivePtr>(comm_node->input(0))->GetAttr(GROUP));
1771 std::string fusion_key = backward_comm_name + "_" + comm_group + "_" + std::to_string(fusion_id);
1772 if (!IsPrimitiveCNode(node, prim::kPrimLoad) && !IsPrimitiveCNode(node, prim::kPrimCast)) {
1773 if (fusion_id > 0) {
1774 node->AddPrimalAttr(kRelatedFusionKey, MakeValue<std::string>(fusion_key));
1775 node->AddPrimalAttr(kRelatedNodeId, MakeValue<std::string>(node->UniqueId()));
1776 node->AddAttr(kRelatedCommNodeId, MakeValue<std::string>(comm_node->UniqueId()));
1777 }
1778 node->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1779 return;
1780 }
1781 auto next_nodes = GetOutputNodesWithFilter(node, [&](const AnfNodePtr &anode) {
1782 return IsPrimitiveCNode(anode, prim::kPrimLoad) || IsPrimitiveCNode(anode, prim::kPrimCast) ||
1783 IsPrimitiveCNode(anode, prim::kPrimAllGather) || IsPrimitiveCNode(anode, prim::kPrimMirror) ||
1784 IsPrimitiveCNode(anode, prim::kPrimMicroStepAllGather) ||
1785 IsPrimitiveCNode(anode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(anode, prim::kPrimMakeTuple);
1786 });
1787 for (auto &pair : next_nodes) {
1788 if (!IsPrimitiveCNode(pair.first)) {
1789 continue;
1790 }
1791 auto next_cnode = pair.first->cast<CNodePtr>();
1792 if (fusion_id > 0) {
1793 next_cnode->AddPrimalAttr(kRelatedFusionKey, MakeValue<std::string>(fusion_key));
1794 next_cnode->AddPrimalAttr(kRelatedNodeId, MakeValue<std::string>(node->UniqueId()));
1795 next_cnode->AddAttr(kRelatedCommNodeId, MakeValue<std::string>(comm_node->UniqueId()));
1796 }
1797 next_cnode->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1798 }
1799 }
1800 }
1801
AddNodeMirrorInfo(const CNodePtr & cnode,const std::string & param_name)1802 void AddNodeMirrorInfo(const CNodePtr &cnode, const std::string ¶m_name) {
1803 auto comm_id = MakeValue<std::string>(param_name);
1804 if (IsParallelCareNode(cnode)) {
1805 cnode->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1806 return;
1807 }
1808 auto next_nodes = GetOutputNodesWithFilter(cnode, [&](const AnfNodePtr &anode) {
1809 return IsPrimitiveCNode(anode, prim::kPrimLoad) || IsPrimitiveCNode(anode, prim::kPrimCast) ||
1810 IsPrimitiveCNode(anode, prim::kPrimAllGather) || IsPrimitiveCNode(anode, prim::kPrimMirror) ||
1811 IsPrimitiveCNode(anode, prim::kPrimMicroStepAllGather) ||
1812 IsPrimitiveCNode(anode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(anode, prim::kPrimMakeTuple);
1813 });
1814 for (auto &pair : next_nodes) {
1815 if (!IsPrimitiveCNode(pair.first)) {
1816 continue;
1817 }
1818 auto next_node = pair.first->cast<CNodePtr>();
1819 next_node->AddPrimalAttr(kPrimalAttrMirrorUserId, comm_id);
1820 }
1821 }
1822
GetMakeTupleValue(const AnfNodePtr & node)1823 static ValuePtr GetMakeTupleValue(const AnfNodePtr &node) {
1824 auto cnode = node->cast<CNodePtr>();
1825 auto &inputs = cnode->inputs();
1826
1827 std::vector<int64_t> value_list;
1828 for (size_t index = 1; index < inputs.size(); ++index) {
1829 if (inputs[index]->isa<ValueNode>()) {
1830 auto element = GetValueNode(inputs[index]);
1831 if (element->isa<Int64Imm>()) {
1832 int64_t value = element->cast<Int64ImmPtr>()->value();
1833 value_list.push_back(value);
1834 continue;
1835 }
1836 }
1837 value_list.push_back(-1); // dynamic shape
1838 }
1839
1840 MS_LOG(INFO) << "the make tuple value is " << value_list;
1841 return MakeValue(value_list);
1842 }
1843
HasSupportedValueSequence(const CNodePtr & node)1844 bool HasSupportedValueSequence(const CNodePtr &node) {
1845 const auto &all_inputs = node->inputs();
1846 return std::any_of(all_inputs.begin() + 1, all_inputs.end(), [&node](const AnfNodePtr &input) {
1847 bool is_abs_seq = false;
1848 auto abs = input->abstract();
1849 if (abs != nullptr) {
1850 is_abs_seq = abs->isa<abstract::AbstractSequence>();
1851 }
1852 return (is_abs_seq || IsValueSequence(input)) && IsSomePrimitiveList(node, SUPPORT_NEW_SHAPEBASE_OPS);
1853 });
1854 }
1855
// Build an OperatorInfo for operators whose inputs include tuples/lists
// (the "new shape" path). The operator is instantiated with empty shape lists
// on purpose; the real new-style shapes are attached via set_new_shape().
OperatorInfoPtr CreateOperatorInfoForTupleShape(const CNodePtr &cnode) {
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(INFO) << prim->name() << ": has value sequence input, enter new shape logic.";
  std::pair<std::vector<NewShapes>, std::vector<Symbols>> shapes_and_symbols = ExtractNewShapeAndSymbol(cnode);
  auto shape_list = shapes_and_symbols.first;
  auto symbol_list = shapes_and_symbols.second;
  if (shape_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE || symbol_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE) {
    MS_LOG(EXCEPTION) << "the size of shapes or symbols must be " << INPUT_OUTPUT_SYMBOLS_SIZE
                      << ", but the size of shapes is " << shape_list.size() << ", the size of symbols is "
                      << symbol_list.size();
  }
  auto attrs = prim->attrs();
  // Placeholder shapes: the new-style shapes are set through set_new_shape() below.
  std::vector<Shapes> temp_shape_list = {{}, {}};
  OperatorInfoPtr op_info = OperatorInstance(prim, attrs, temp_shape_list);
  MS_EXCEPTION_IF_NULL(op_info);

  // When the 'inputs' contains numerical values for some operators, these values should be extracted from
  // ANF graph
  auto &inputs = cnode->inputs();
  std::vector<ValuePtr> input_value;
  for (size_t index = 1; index < inputs.size(); ++index) {
    if (inputs[index]->isa<ValueNode>() || inputs[index]->isa<tensor::Tensor>()) {
      // Constant input: record its value directly.
      (void)input_value.emplace_back(GetValueNode(inputs[index]));
      continue;
    } else if (IsPrimitiveCNode(inputs[index], prim::kPrimMakeTuple)) {
      // MakeTuple input: fold the constant elements (-1 for dynamic ones).
      auto make_tuple_value = GetMakeTupleValue(inputs[index]);
      (void)input_value.emplace_back(make_tuple_value);
      continue;
    } else if (IsPrimitiveCNode(inputs[index], prim::kPrimShape)) {
      // Shape op input: use the statically known shape of its argument.
      auto shape_op_cnode = dyn_cast_ptr<CNode>(inputs[index]);
      auto dst_shape = GetNodeShape(shape_op_cnode->input(1));
      (void)input_value.emplace_back(MakeValue(dst_shape[0]));
      MS_LOG(INFO) << "The prim is " << prim->name() << ", the input index is " << index - 1
                   << ", is Shape op, dst shape is " << dst_shape;
      continue;
    }
    // Non-constant input: no compile-time value is available.
    (void)input_value.emplace_back(nullptr);
  }
  (*op_info).set_input_value(input_value);
  (*op_info).set_outputs_dtype(cnode->Type());
  (*op_info).set_cnode(cnode);
  (*op_info).set_new_shape(shape_list);
  return op_info;
}
1901
CreateOperatorInfo(const CNodePtr & cnode)1902 OperatorInfoPtr CreateOperatorInfo(const CNodePtr &cnode) {
1903 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1904 MS_EXCEPTION_IF_NULL(prim);
1905 if (HasSupportedValueSequence(cnode)) {
1906 return CreateOperatorInfoForTupleShape(cnode);
1907 }
1908 std::pair<std::vector<Shapes>, std::vector<Symbols>> shapes_and_symbols = ExtractShapeAndSymbol(cnode);
1909 auto shape_list = shapes_and_symbols.first;
1910 auto symbol_list = shapes_and_symbols.second;
1911 if (shape_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE || symbol_list.size() != INPUT_OUTPUT_SYMBOLS_SIZE) {
1912 MS_LOG(EXCEPTION) << "the size of shapes or symbols must be " << INPUT_OUTPUT_SYMBOLS_SIZE
1913 << ", but the size of shapes is " << shape_list.size() << ", the size of symbols is "
1914 << symbol_list.size();
1915 }
1916
1917 auto attrs = prim->attrs();
1918 OperatorInfoPtr op_info = OperatorInstance(prim, attrs, shape_list);
1919 MS_EXCEPTION_IF_NULL(op_info);
1920 MS_LOG(INFO) << "shape_list.size(): " << shape_list.size();
1921
1922 // When the 'inputs' contains numerical values for some operators, these values should be extracted from
1923 // ANF graph
1924 auto &inputs = cnode->inputs();
1925 std::vector<ValuePtr> input_value;
1926 for (size_t index = 1; index < inputs.size(); ++index) {
1927 if (inputs[index]->isa<ValueNode>() || inputs[index]->isa<tensor::Tensor>()) {
1928 (void)input_value.emplace_back(GetValueNode(inputs[index]));
1929 continue;
1930 } else if (IsPrimitiveCNode(inputs[index], prim::kPrimMakeTuple)) {
1931 auto make_tuple_value = GetMakeTupleValue(inputs[index]);
1932 (void)input_value.emplace_back(make_tuple_value);
1933 continue;
1934 } else if (IsPrimitiveCNode(inputs[index], prim::kPrimShape)) {
1935 auto shape_op_cnode = dyn_cast_ptr<CNode>(inputs[index]);
1936 auto dst_shape = GetNodeShape(shape_op_cnode->input(1));
1937 (void)input_value.emplace_back(MakeValue(dst_shape[0]));
1938 MS_LOG(INFO) << "The prim is " << prim->name() << ", the input index is " << index - 1
1939 << ", is Shape op, dst shape is " << dst_shape;
1940 continue;
1941 }
1942 (void)input_value.emplace_back(nullptr);
1943 }
1944
1945 (*op_info).set_input_value(input_value);
1946 (*op_info).set_outputs_dtype(cnode->Type());
1947 (*op_info).set_cnode(cnode);
1948 if (InDynamicGraph(cnode) && IsDynamicShapesList(shape_list)) {
1949 Shapes in_real_divisors;
1950 Shapes out_real_divisors;
1951 in_real_divisors = GetRealDivisorSymbols(shape_list[INPUT_SYMBOLS_INDEX], symbol_list[INPUT_SYMBOLS_INDEX]);
1952 out_real_divisors = GetRealDivisorSymbols(shape_list[OUTPUT_SYMBOLS_INDEX], symbol_list[OUTPUT_SYMBOLS_INDEX]);
1953 (*op_info).set_dynamic_shape_flag(True);
1954 (*op_info).set_inputs_divisor(in_real_divisors);
1955 (*op_info).set_outputs_divisor(out_real_divisors);
1956 MS_LOG(DEBUG) << (*op_info).name() << ": inputs-shape: " << ShapesToString(shape_list[0])
1957 << ", inputs_d_symbol: " << ShapesToString(in_real_divisors);
1958 MS_LOG(DEBUG) << (*op_info).name() << ": outputs-shape: " << ShapesToString(shape_list[1])
1959 << ", outputs_d_symbol: " << ShapesToString(out_real_divisors);
1960 }
1961 return op_info;
1962 }
1963
ExtendInputArgsAbstractShape(const AbstractBasePtr & args_abstract_item,size_t index)1964 void ExtendInputArgsAbstractShape(const AbstractBasePtr &args_abstract_item, size_t index) {
1965 auto args_abstract_item_shape = args_abstract_item->BuildShape();
1966 auto shape_ptr = dyn_cast<abstract::Shape>(args_abstract_item_shape);
1967 if (shape_ptr == nullptr) {
1968 MS_LOG(WARNING) << "The input " << index << " is not a tensor.";
1969 return;
1970 }
1971 auto shape_value = parallel::ToFullShape(shape_ptr->shape(), index);
1972 auto new_shape_item = std::make_shared<abstract::Shape>(shape_value);
1973 args_abstract_item->set_shape(new_shape_item);
1974 }
1975
// Map a (possibly sharded) input shape back to the full global shape.
// Without an explicit dataset strategy, the batch dimension is multiplied by
// the data-parallel world size (unless full_batch is set). With a dataset
// strategy, each dimension is multiplied by its strategy factor; dynamic
// dimensions (-1) are kept unchanged.
ShapeVector ToFullShape(const ShapeVector &input_shape, size_t index) {
  if (input_shape.empty()) {
    return input_shape;
  }
  auto parallel_context = ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  if (parallel_context->dataset_strategy().empty()) {
    ShapeVector shape_value = input_shape;
    if (!parallel_context->full_batch()) {
      auto comm_info = parallel::GetCommInfo();
      auto world_rank_size = comm_info.device_num / parallel_context->pipeline_stage_split_num();
      if (shape_value[0] > 0) {
        shape_value[0] = shape_value[0] * SizeToLong(world_rank_size);  // only for static shape
      }
    }
    return shape_value;
  }
  auto dataset_strategy = parallel_context->dataset_strategy();
  if (index >= dataset_strategy.size()) {
    MS_LOG(EXCEPTION) << "The input shapes size is not equal to dataset strategy size " << dataset_strategy.size();
  }
  const auto &dataset_strategy_item = dataset_strategy[index];
  if (input_shape.size() != dataset_strategy_item.size()) {
    MS_LOG(EXCEPTION) << "The input_shapes[" << index << "]'s size" << input_shape.size()
                      << " is not equal to dataset_strategy[" << index << "]'s size " << dataset_strategy_item.size();
  }
  ShapeVector full_shape;
  full_shape.reserve(dataset_strategy_item.size());
  for (size_t i = 0; i < dataset_strategy_item.size(); ++i) {
    if (input_shape[i] > 0) {
      full_shape.push_back(input_shape[i] * dataset_strategy_item[i]);
    } else {
      full_shape.push_back(input_shape[i]);  // dynamic shape, shape is still -1
    }
  }
  return full_shape;
}
2011
// Collect communication info (device num, global rank, world group and
// communication backend) from the parallel context and the communication
// manager. When device_num / global_rank were not explicitly set, they are
// queried from the communication layer; the queried global rank is cached
// back into the ParallelContext.
CommInfo GetCommInfo() {
  int64_t device_num = ParallelContext::GetInstance()->device_num();
  int64_t global_rank = ParallelContext::GetInstance()->global_rank();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  std::string world_group;
  std::string communication_backend;
  // Pick the world group / backend matching the device target.
  if (backend == kAscendDevice || backend == kDavinciDevice) {
    world_group = HCCL_WORLD_GROUP;
    communication_backend = HCCL_BACKEND;
  } else if (backend == kGPUDevice) {
    world_group = NCCL_WORLD_GROUP;
    communication_backend = NCCL_BACKEND;
  } else {
    MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend
                      << " for semi_auto_parallel/auto_parallel mode,"
                         " currently only support Ascend/GPU backend.";
  }
  uint32_t world_rank_size = 0;
  if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
    MS_LOG(EXCEPTION) << "Get rank size failed";
  }

  // device_num was not set by the user: take it from the communication layer.
  if (!ParallelContext::GetInstance()->device_num_is_set()) {
    device_num = UintToInt(world_rank_size);
    MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
  }
#if (!defined(_WIN32) && !defined(__APPLE__) && !(defined(ENABLE_TESTCASES) || defined(ENABLE_TEST)))
  // On real device environments, a user-set device_num must match the actual
  // world size reported by the communication layer.
  if (ParallelContext::GetInstance()->device_num_is_set() && world_rank_size != device_num &&
      !ParallelContext::GetInstance()->hccl_test_available()) {
    // hccl_test_available is used when we compile graphs in real ascend card environment, but with hccl_test.
    MS_LOG(EXCEPTION) << "The device_num " << device_num << " set in the context is not consist with "
                      << world_rank_size << " devices you have"
                      << ". Please check your rank_table file(for Ascend) or host file(for GPU).";
  }
#endif
  uint32_t rank_id = 0;
  // global rank was not set: query it and cache it back into the context.
  if (!ParallelContext::GetInstance()->global_rank_is_set()) {
    if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
      MS_LOG(EXCEPTION) << "Get rank id failed";
    }
    global_rank = UintToInt(rank_id);
    ParallelContext::GetInstance()->set_global_rank(global_rank);
    MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
  }
  CommInfo comm_info{device_num, global_rank, world_group, communication_backend};
  return comm_info;
}
2061
IsPynativeParallel()2062 bool IsPynativeParallel() {
2063 auto parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2064 auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
2065 return (execution_mode == kPynativeMode) && (parallel_mode == kSemiAutoParallel || parallel_mode == kAutoParallel);
2066 }
2067
IsAutoParallelCareGraph(const FuncGraphPtr & func_graph)2068 bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) {
2069 // compile graph order:
2070 // 1, ParallelParameterContextRestoreShape
2071 // 2, PipelineSplit: insert virtual dataset
2072 // 3, StepAutoParallel
2073 // 4, StepParallel
2074 // if IsParallel() is true, it maybe has some graphs that we now care, so need to check
2075 // 'sharded' or 'has_shard' flag
2076 MS_EXCEPTION_IF_NULL(func_graph);
2077 if (func_graph->has_flag(kSkipAutoParallelCompile)) {
2078 return false;
2079 }
2080
2081 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2082 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2083 if (parallel_mode != kAutoParallel && parallel_mode != kSemiAutoParallel) {
2084 return false;
2085 }
2086
2087 if (IsPynativeParallel() && !func_graph->has_flag(kHasShard) && !(func_graph->has_flag(kSharded))) {
2088 return false;
2089 }
2090 return true;
2091 }
2092
FindPreNodeCrossFuncGraph(CNodePtr * cnode,int64_t out_index)2093 void FindPreNodeCrossFuncGraph(CNodePtr *cnode, int64_t out_index) {
2094 if (IsValueNode<FuncGraph>((*cnode)->input(0))) {
2095 auto graph = GetValueNode<FuncGraphPtr>((*cnode)->input(0));
2096 auto output = graph->output();
2097 MS_EXCEPTION_IF_NULL(output);
2098 while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
2099 auto output_cnode = output->cast<CNodePtr>();
2100 MS_EXCEPTION_IF_NULL(output_cnode);
2101 output = output_cnode->input(1);
2102 }
2103 while (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
2104 auto make_tuple_cnode = output->cast<CNodePtr>();
2105 output = make_tuple_cnode->input(out_index + 1);
2106 }
2107 *cnode = output->cast<CNodePtr>();
2108 }
2109 }
2110
AnfNodePtr FindRealInputByFormalParameter(const CNodePtr &node, const AnfNodePtr &input,
                                          const std::vector<AnfNodePtr> &all_nodes) {
  // When 'input' is a formal parameter of node's func_graph, resolve it to the
  // actual argument passed at a call site of that graph; otherwise return it as is.
  auto graph = node->func_graph();
  const auto &params = graph->parameters();
  int64_t param_index = -1;
  for (size_t pos = 0; pos < params.size(); ++pos) {
    if (params[pos] == input) {
      param_index = SizeToLong(pos);
    }
  }
  if (param_index < 0) {
    // Not a formal parameter of this graph.
    return input;
  }
  // Find a call node whose callee is this graph and take the matching argument.
  for (const auto &candidate : all_nodes) {
    if (!candidate->isa<CNode>()) {
      continue;
    }
    auto caller = candidate->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(caller->input(0)) && GetValueNode<FuncGraphPtr>(caller->input(0)) == graph) {
      return caller->input(param_index + 1);
    }
  }
  return input;
}
2136
CrossInterNode(CNodePtr * prev_cnode,ValueNodePtr * prev_prim_anf_node,PrimitivePtr * prev_prim)2137 bool CrossInterNode(CNodePtr *prev_cnode, ValueNodePtr *prev_prim_anf_node, PrimitivePtr *prev_prim) {
2138 if ((*prev_cnode == nullptr) ||
2139 !(IsValueNode<Primitive>((*prev_cnode)->input(0)) || IsValueNode<FuncGraph>((*prev_cnode)->input(0)))) {
2140 return true;
2141 }
2142 if (!IsValueNode<FuncGraph>((*prev_cnode)->input(0))) {
2143 *prev_prim_anf_node = (*prev_cnode)->input(0)->cast<ValueNodePtr>();
2144 *prev_prim = (*prev_prim_anf_node)->value()->cast<PrimitivePtr>();
2145 }
2146 return false;
2147 }
2148
IsCarePrevCNode(const CNodePtr & prev_cnode,const PrimitivePtr & prev_prim)2149 bool IsCarePrevCNode(const CNodePtr &prev_cnode, const PrimitivePtr &prev_prim) {
2150 return (IsValueNode<FuncGraph>(prev_cnode->input(0))) || (prev_prim->name() == kTupleGetItemOpName) ||
2151 (prev_prim->name() == kDependOpName) || (prev_prim->name() == kMakeListOpName) ||
2152 (prev_prim->name() == kLoadOpName) || (prev_prim->name() == kMakeTupleOpName) ||
2153 (prev_prim->name() == kShapeOpName) || IsAutoParallelCareNode(prev_cnode);
2154 }
2155
IsCrossedCNode(std::string prev_prim_name)2156 bool IsCrossedCNode(std::string prev_prim_name) {
2157 const std::set<std::string> crossed_cnode_list = {kDependOpName, kLoadOpName, kShapeOpName};
2158 return crossed_cnode_list.find(prev_prim_name) != crossed_cnode_list.end();
2159 }
2160
2161 // Needed by rec_parser
// Extract the UniqueId of 'node' followed by the UniqueId of the real producer
// of each of its inputs, resolving formal parameters, sub-graph calls and
// transparent nodes (TupleGetItem/Depend/Load/MakeTuple/Shape).
// The returned vector is [node_id, input_0_id, input_1_id, ...].
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node, const std::vector<AnfNodePtr> &all_nodes) {
  std::vector<std::string> name_inputs;
  std::vector<AnfNodePtr> all_inputs = node->inputs();
  // Drop input(0), which is the primitive/func-graph being called.
  std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};

  std::string node_id = node->UniqueId();
  name_inputs.push_back(node_id);
  for (auto &input : node_inputs) {
    AnfNodePtr prev_node = input;
    if (input->isa<Parameter>()) {
      // Formal parameter: resolve to the caller's actual argument when possible.
      prev_node = FindRealInputByFormalParameter(node, input, all_nodes);
      if (prev_node->UniqueId() == input->UniqueId()) {
        name_inputs.push_back(input->UniqueId());
        continue;
      }
    }
    auto prev_cnode = prev_node->cast<CNodePtr>();
    PrimitivePtr prev_prim;
    ValueNodePtr prev_prim_anf_node;

    // 'is_cross' means traversal cannot continue past prev_cnode.
    bool is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
    if (is_cross) {
      name_inputs.push_back(input->UniqueId());
      continue;
    }

    // Walk backwards through transparent nodes until the real producer is found.
    size_t output_index = 0;
    while (IsCarePrevCNode(prev_cnode, prev_prim)) {
      if (IsValueNode<FuncGraph>(prev_cnode->input(0))) {
        // Step into the called sub-graph and continue from its output.
        auto graph = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
        auto output = graph->output();
        MS_EXCEPTION_IF_NULL(output);
        prev_cnode = output->cast<CNodePtr>();
        (void)CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
      } else if (IsAutoParallelCareNode(prev_cnode)) {
        // Found the producer that auto-parallel cares about.
        name_inputs.push_back(prev_cnode->UniqueId());
        break;
      } else if (prev_prim->name() == kTupleGetItemOpName) {
        // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
        // this 'tuple_getitem'
        output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(INDEX_TWO))));
        prev_node = prev_cnode->input(1);
        prev_cnode = prev_node->cast<CNodePtr>();

        // Chained tuple_getitems: keep unwrapping.
        if (prev_cnode != nullptr && common::AnfAlgo::GetCNodeName(prev_cnode) == kTupleGetItemOpName) {
          continue;
        }

        is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
        if (is_cross) {
          name_inputs.push_back(prev_node->UniqueId());
          break;
        }

        // In dynamic shape scenarios, the situation op1->Shape->TupleGetItem->op2 will occur.
        // The incoming operator of op2 should be op1 instead of Shape,
        // so the Shape operator is skipped when looking for the incoming operator.
        if (prev_prim->name() == kShapeOpName) {
          continue;
        }

        if (!IsAutoParallelCareNode(prev_cnode) && !IsValueNode<FuncGraph>(prev_cnode->input(0))) {
          MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
        }
      } else if (prev_prim->name() == kMakeTupleOpName) {
        // Follow the MakeTuple element selected by the last TupleGetItem.
        prev_node = prev_cnode->input(output_index + 1);
        prev_cnode = prev_node->cast<CNodePtr>();
        output_index = 0;
        is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
        if (is_cross) {
          name_inputs.push_back(prev_node->UniqueId());
          break;
        }
      } else if (IsCrossedCNode(prev_prim->name())) {
        // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
        // this 'depend'
        prev_node = prev_cnode->input(1);
        prev_cnode = prev_node->cast<CNodePtr>();
        is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
        if (is_cross) {
          name_inputs.push_back(prev_node->UniqueId());
          break;
        }
      }
    }
  }

  return name_inputs;
}
2251
GetDistributeOperator(const CNodePtr & node)2252 OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
2253 MS_EXCEPTION_IF_NULL(node);
2254 if (!IsParallelCareNode(node)) {
2255 return nullptr;
2256 }
2257 OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
2258 return distribute_operator;
2259 }
2260
StrategyFound(const mindspore::HashMap<std::string,ValuePtr> & attrs)2261 bool StrategyFound(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
2262 auto iter = attrs.find(IN_STRATEGY);
2263 return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
2264 }
2265
AttrFound(const mindspore::HashMap<std::string,ValuePtr> & attrs,const std::string & target)2266 bool AttrFound(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::string &target) {
2267 auto iter = attrs.find(target);
2268 return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
2269 }
2270
IsCommunicationOp(const PrimitivePtr & prim)2271 bool IsCommunicationOp(const PrimitivePtr &prim) {
2272 MS_EXCEPTION_IF_NULL(prim);
2273 return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end());
2274 }
2275
ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> & all_nodes)2276 void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
2277 for (auto &node : all_nodes) {
2278 MS_EXCEPTION_IF_NULL(node);
2279 if (!node->isa<CNode>()) {
2280 continue;
2281 }
2282 auto cnode = node->cast<CNodePtr>();
2283 if (!IsValueNode<Primitive>(cnode->input(0))) {
2284 continue;
2285 }
2286 ValueNodePtr prim_value_node = cnode->input(0)->cast<ValueNodePtr>();
2287 MS_EXCEPTION_IF_NULL(prim_value_node);
2288 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_value_node);
2289 MS_EXCEPTION_IF_NULL(prim);
2290
2291 if (IsCommunicationOp(prim) && cnode->in_forward_flag()) {
2292 MS_EXCEPTION_IF_NULL(prim_value_node->scope());
2293 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2294 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2295 MS_LOG(EXCEPTION) << "If the parallel mode is semi_auto_parallel or auto_parallel, the graph can not contain "
2296 "communication op, the parallel mode is "
2297 << parallel_mode << ", and the graph has communication op : " << prim->name()
2298 << ", scope name is " << prim_value_node->scope()->name();
2299 }
2300 }
2301 }
2302
MirrorOpName()2303 std::string MirrorOpName() {
2304 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
2305 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2306 std::string mirror_op_name;
2307 if (split_stage_num > 1 || grad_accumulation_step > 1) {
2308 mirror_op_name = MIRROR_MICRO_STEP_OPERATOR;
2309 } else {
2310 mirror_op_name = MIRROR_OPERATOR;
2311 }
2312 return mirror_op_name;
2313 }
2314
CheckStrategyWithTupleInTuple(const std::vector<ValuePtr> & elements)2315 bool CheckStrategyWithTupleInTuple(const std::vector<ValuePtr> &elements) {
2316 bool has_tuple_in_tuple = false;
2317 for (size_t i = 0; i < elements.size(); ++i) {
2318 if (elements[i]->isa<ValueSequence>()) {
2319 auto value_tuple = elements[i]->cast<ValueTuplePtr>();
2320 std::vector<ValuePtr> value_vector = value_tuple->value();
2321 auto local_tuple_in_tuple = std::any_of(value_vector.begin(), value_vector.end(),
2322 [](const ValuePtr &value) { return value->isa<ValueSequence>(); });
2323 has_tuple_in_tuple = has_tuple_in_tuple || local_tuple_in_tuple;
2324 } else {
2325 MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
2326 }
2327 }
2328 MS_LOG(INFO) << "CheckStrategyWithTupleInTuple: has_tuple_in_tuple = " << has_tuple_in_tuple << ".";
2329 return has_tuple_in_tuple;
2330 }
2331
ExtractDimensions(const ValuePtr & stra)2332 NewDimensions ExtractDimensions(const ValuePtr &stra) {
2333 auto value_tuple = stra->cast<ValueTuplePtr>();
2334 std::vector<ValuePtr> value_vector = value_tuple->value();
2335 bool has_tuple_in_tuple = std::any_of(value_vector.begin(), value_vector.end(),
2336 [](const ValuePtr &value) { return value->isa<ValueSequence>(); });
2337 if (has_tuple_in_tuple) {
2338 std::vector<NewDimensions> dim;
2339 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
2340 [](const ValuePtr &value) { return ExtractDimensions(value); });
2341 return std::make_shared<ShapeList>(dim);
2342 }
2343 Dimensions dim;
2344 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
2345 [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
2346 return std::make_shared<ShapeValue>(dim);
2347 }
2348
ExtractNewStrategy(const std::vector<ValuePtr> & elements,const int64_t & stage_id)2349 StrategyPtr ExtractNewStrategy(const std::vector<ValuePtr> &elements, const int64_t &stage_id) {
2350 NewStrategies strategy;
2351 for (uint64_t index = 0; index < elements.size(); ++index) {
2352 if (elements[index]->isa<ValueSequence>()) {
2353 auto dim = ExtractDimensions(elements[index]);
2354 strategy.emplace_back(dim);
2355 } else {
2356 MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
2357 }
2358 }
2359 if (strategy.empty()) {
2360 MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
2361 }
2362 StrategyPtr strategyPtr = NewStrategy(stage_id, strategy);
2363 return strategyPtr;
2364 }
2365
ExtractStrategy(const ValuePtr & stra)2366 StrategyPtr ExtractStrategy(const ValuePtr &stra) {
2367 if (stra == nullptr) {
2368 return nullptr;
2369 }
2370
2371 auto var = stra->cast<ValueTuplePtr>();
2372 if (var == nullptr) {
2373 return nullptr;
2374 }
2375
2376 StrategyPtr strategyPtr;
2377 int64_t stage_id = g_device_manager->stage_id();
2378
2379 MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
2380 if (var->size() > 0) {
2381 std::vector<ValuePtr> elements = var->value();
2382 if (CheckStrategyWithTupleInTuple(elements)) {
2383 return ExtractNewStrategy(elements, stage_id);
2384 }
2385 Strategies strategy;
2386 for (uint64_t index = 0; index < elements.size(); ++index) {
2387 Dimensions dim;
2388 if (elements[index]->isa<ValueSequence>()) {
2389 auto value_tuple = elements[index]->cast<ValueTuplePtr>();
2390 std::vector<ValuePtr> value_vector = value_tuple->value();
2391 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
2392 [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
2393 strategy.push_back(dim);
2394 } else {
2395 MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
2396 }
2397 }
2398 if (strategy.empty()) {
2399 MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
2400 }
2401 strategyPtr = NewStrategy(stage_id, strategy);
2402 }
2403 return strategyPtr;
2404 }
2405
GetLayoutFromAttrValue(const ValuePtr & layout_item,std::vector<int64_t> * device_matrix_vector,std::vector<std::vector<int64_t>> * tensor_map_vector,bool * interleaved_parallel)2406 Status GetLayoutFromAttrValue(const ValuePtr &layout_item, std::vector<int64_t> *device_matrix_vector,
2407 std::vector<std::vector<int64_t>> *tensor_map_vector, bool *interleaved_parallel) {
2408 auto layout_dict_value = layout_item->cast<ValueDictionaryPtr>();
2409 if (!layout_dict_value) {
2410 MS_LOG(ERROR) << "The layout item configured for node is unreasonable";
2411 return FAILED;
2412 }
2413 auto layout_dict = layout_dict_value->value();
2414 ValuePtr device_matrix_value = nullptr;
2415 ValuePtr tensor_map_value = nullptr;
2416 ValuePtr interleaved_parallel_value = nullptr;
2417 for (const auto &value_pair : layout_dict) {
2418 if ((*value_pair.first) == (*MakeValue<std::string>(DEVICE_MATRIX))) {
2419 device_matrix_value = value_pair.second;
2420 }
2421 if ((*value_pair.first) == (*MakeValue<std::string>(TENSOR_MAP))) {
2422 tensor_map_value = value_pair.second;
2423 }
2424 if ((*value_pair.first) == (*MakeValue<std::string>(INTERLEAVED_PARALLEL))) {
2425 interleaved_parallel_value = value_pair.second;
2426 }
2427 }
2428 if (!device_matrix_value || !tensor_map_value || !interleaved_parallel_value) {
2429 MS_LOG(ERROR) << "The layout item configured for node is unreasonable";
2430 return FAILED;
2431 }
2432 *device_matrix_vector = GetValue<std::vector<int64_t>>(device_matrix_value);
2433 *interleaved_parallel = GetValue<bool>(interleaved_parallel_value);
2434 auto tensor_map_value_tuple = tensor_map_value->cast<ValueTuplePtr>();
2435 std::vector<ValuePtr> tensor_map_value_tuple_vector = tensor_map_value_tuple->value();
2436 for (const auto &tensor_map_item : tensor_map_value_tuple_vector) {
2437 if (tensor_map_item->isa<ValueSequence>()) {
2438 auto tensor_map_item_v = GetValue<std::vector<int64_t>>(tensor_map_item);
2439 tensor_map_vector->push_back(tensor_map_item_v);
2440 continue;
2441 }
2442 auto tensor_map_item_i = GetValue<int64_t>(tensor_map_item);
2443 tensor_map_vector->push_back({tensor_map_item_i});
2444 }
2445 return SUCCESS;
2446 }
2447
ExtractUserConfigLayout(const mindspore::HashMap<std::string,ValuePtr> & prim_attrs,const Shapes & inputs_shape,const Shapes & outputs_shape,std::vector<std::shared_ptr<TensorLayout>> * in_tensor_layouts,std::vector<std::shared_ptr<TensorLayout>> * out_tensor_layouts)2448 Status ExtractUserConfigLayout(const mindspore::HashMap<std::string, ValuePtr> &prim_attrs, const Shapes &inputs_shape,
2449 const Shapes &outputs_shape,
2450 std::vector<std::shared_ptr<TensorLayout>> *in_tensor_layouts,
2451 std::vector<std::shared_ptr<TensorLayout>> *out_tensor_layouts) {
2452 if (prim_attrs.count(IN_LAYOUT) > 0) {
2453 auto layout_value = prim_attrs.at(IN_LAYOUT);
2454 if (!layout_value->isa<ValueSequence>()) {
2455 MS_LOG(ERROR) << "The in_layout configured for node is not a tuple";
2456 return FAILED;
2457 }
2458 auto layout_value_tuple = layout_value->cast<ValueTuplePtr>();
2459 std::vector<ValuePtr> layout_value_vector = layout_value_tuple->value();
2460 if (inputs_shape.size() != layout_value_vector.size()) {
2461 MS_LOG(ERROR) << "The in_layout configured for node is not equal to its input nums";
2462 return FAILED;
2463 }
2464
2465 for (size_t i = 0; i < layout_value_vector.size(); ++i) {
2466 auto layout_item = layout_value_vector[i];
2467 std::vector<int64_t> device_matrix_vector;
2468 std::vector<std::vector<int64_t>> tensor_map_vector;
2469 bool interleaved_parallel;
2470 if (GetLayoutFromAttrValue(layout_item, &device_matrix_vector, &tensor_map_vector, &interleaved_parallel) !=
2471 SUCCESS) {
2472 return FAILED;
2473 }
2474 auto in_layout = std::make_shared<TensorLayout>();
2475 if (in_layout->InitFromExtendVector(device_matrix_vector, tensor_map_vector, inputs_shape[i],
2476 interleaved_parallel) != SUCCESS) {
2477 MS_LOG(ERROR) << "The in_layout configured incorrect, device_matrix:" << device_matrix_vector
2478 << ", tensor_map:" << tensor_map_vector;
2479 return FAILED;
2480 }
2481 in_tensor_layouts->push_back(in_layout);
2482 }
2483 }
2484 if (prim_attrs.count(OUT_LAYOUT) > 0) {
2485 auto layout_value = prim_attrs.at(OUT_LAYOUT);
2486 if (!layout_value->isa<ValueSequence>()) {
2487 MS_LOG(EXCEPTION) << "The in_layout configured for node is not a tuple";
2488 }
2489 auto layout_value_tuple = layout_value->cast<ValueTuplePtr>();
2490 std::vector<ValuePtr> layout_value_vector = layout_value_tuple->value();
2491 if (outputs_shape.size() != layout_value_vector.size()) {
2492 MS_LOG(EXCEPTION) << "The out_layout configured for node is not equal to its output nums";
2493 }
2494 for (size_t i = 0; i < layout_value_vector.size(); ++i) {
2495 auto layout_item = layout_value_vector[i];
2496 std::vector<int64_t> device_matrix_vector;
2497 std::vector<std::vector<int64_t>> tensor_map_vector;
2498 bool interleaved_parallel;
2499 if (GetLayoutFromAttrValue(layout_item, &device_matrix_vector, &tensor_map_vector, &interleaved_parallel) !=
2500 SUCCESS) {
2501 return FAILED;
2502 }
2503 auto out_layout = std::make_shared<TensorLayout>();
2504 if (out_layout->InitFromExtendVector(device_matrix_vector, tensor_map_vector, outputs_shape[i],
2505 interleaved_parallel) != SUCCESS) {
2506 MS_LOG(ERROR) << "The out_layout configured incorrect, device_matrix:" << device_matrix_vector
2507 << ", tensor_map:" << tensor_map_vector;
2508 return FAILED;
2509 }
2510 out_tensor_layouts->push_back(out_layout);
2511 }
2512 }
2513 return SUCCESS;
2514 }
2515
IsCohesiveNode(const CNodePtr & cnode)2516 static bool IsCohesiveNode(const CNodePtr &cnode) {
2517 return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
2518 IsPrimitiveCNode(cnode, prim::kPrimDepend) || IsPrimitiveCNode(cnode, prim::kPrimAllGather) ||
2519 IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMirrorMicroStep) ||
2520 IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMirror) ||
2521 IsPrimitiveCNode(cnode, prim::kPrimMirrorMiniStep) || IsPrimitiveCNode(cnode, prim::kPrimVirtualDiv);
2522 }
2523
// Collect (name, parameter) pairs for the parameters feeding 'node', either
// directly or through "cohesive" nodes (Cast/Load/Depend/mirror/all-gather).
// 'index' is carried through recursion; 'curr_depth' bounds recursion depth.
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
                    << MAX_RECURSIVE_DEPTH;
    return {};
  }
  // Optionally restrict the result to trainable parameters only.
  bool only_trainable_params = ParallelContext::GetInstance()->stra_file_only_trainable_params();
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  ParameterMap param_names;
  for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
    int64_t idx = index > i ? index : i;
    auto input = node_inputs[LongToSize(i)];
    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      if (input_parameter->has_default() && (!only_trainable_params || ParameterRequireGrad(input_parameter))) {
        (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
        continue;
      }
      // A parameter without a default may be a ref across graphs; resolve it
      // to the actual parameter it refers to.
      auto actual_param_node = RefParameterToActualParameter(input_parameter);
      if (!actual_param_node) {
        continue;
      }
      auto actual_param = actual_param_node->cast<ParameterPtr>();
      if (!only_trainable_params || ParameterRequireGrad(actual_param)) {
        (void)param_names.emplace_back(std::make_pair(actual_param->name(), actual_param));
      }
    } else if (input->isa<CNode>()) {
      CNodePtr cnode = input->cast<CNodePtr>();
      if (!IsValueNode<Primitive>(cnode->input(0))) {
        continue;
      }
      // NOTE(review): 'cnode->size() >= 1' always holds for a primitive cnode
      // (input(0) is the primitive itself); presumably '> 1' was intended — confirm.
      // Also note the recursive call resets depth to 0 rather than curr_depth + 1.
      if (IsCohesiveNode(cnode) && cnode->size() >= 1) {
        auto input_param_names = NodeParameterName(cnode, idx, 0);
        (void)param_names.insert(param_names.cend(), input_param_names.cbegin(), input_param_names.cend());
      }
    }
  }
  return param_names;
}
2563
ParallelInit(size_t rank_id,const size_t devices)2564 Status ParallelInit(size_t rank_id, const size_t devices) {
2565 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2566 int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2567
2568 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2569 if (split_stage_num <= 0) {
2570 MS_LOG(ERROR) << "The parameter 'split_stage_num' must be a positive number, but got the value : "
2571 << split_stage_num;
2572 return FAILED;
2573 }
2574 int64_t device_num;
2575 int64_t global_rank;
2576 std::string backend;
2577 if (devices == 0) {
2578 auto comm_info = GetCommInfo();
2579 device_num = comm_info.device_num;
2580 global_rank = comm_info.global_rank;
2581 backend = comm_info.communication_backend;
2582 } else {
2583 device_num = devices;
2584 global_rank = rank_id;
2585 backend = HCCL_BACKEND;
2586 }
2587
2588 if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
2589 MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be positive, "
2590 "but got the value of device_num: "
2591 << device_num;
2592 return FAILED;
2593 }
2594
2595 // the device_num maybe get from communication interface
2596 if (device_num % split_stage_num != 0) {
2597 MS_LOG(ERROR) << "The parameter 'device_num' must be divided by 'split_stage_num', but got the device_num : "
2598 << device_num << "and the split_stage_num : " << split_stage_num;
2599 return FAILED;
2600 }
2601
2602 int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
2603 if (ParallelContext::GetInstance()->enable_parallel_optimizer() && optimizer_weight_shard_size > 0 &&
2604 device_num < optimizer_weight_shard_size) {
2605 MS_LOG(ERROR) << "When parallel_optimizer is enabled, the optimizer_weight_shard_size "
2606 << optimizer_weight_shard_size << " should not exceed the device num " << device_num << ".";
2607 return FAILED;
2608 }
2609
2610 if ((global_rank < 0) || (global_rank >= device_num)) {
2611 MS_LOG(ERROR) << "The parameter 'global_rank' must be greater than 0 and less equal 'device num', "
2612 "but got the global_rank : "
2613 << global_rank << "and the device_num : " << device_num;
2614 return FAILED;
2615 }
2616
2617 std::vector<int64_t> stages;
2618 for (int i = 0; i < split_stage_num; i++) {
2619 stages.push_back(device_num / split_stage_num);
2620 }
2621
2622 bool use_rec = (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming);
2623 bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
2624 (ParallelContext::GetInstance()->sharding_propagation());
2625 if ((split_stage_num > 1) && (parallel_mode == kAutoParallel) && !(use_sp || use_rec)) {
2626 MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << kSemiAutoParallel << " or "
2627 << kAutoParallel << " with " << kShardingPropagation << " or " << kRecursiveProgramming;
2628 return FAILED;
2629 }
2630
2631 if (!InitDevice(device_num, global_rank, backend, stages)) {
2632 MS_LOG(ERROR) << "Init device failed";
2633 return FAILED;
2634 }
2635
2636 MS_LOG(INFO) << "The parallel context: device_num: " << device_num << ", global_rank: "
2637 << global_rank
2638 // << ", communication_backend: " << comm_info.communication_backend
2639 << ", communication_backend: " << HCCL_BACKEND
2640 << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
2641 << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
2642 return SUCCESS;
2643 }
2644
2645 // only used for FindCNode
SkipTrivialNodesMoveDown(const FuncGraphManagerPtr & manager,CNodePtr node)2646 static CNodePtr SkipTrivialNodesMoveDown(const FuncGraphManagerPtr &manager, CNodePtr node) {
2647 MS_EXCEPTION_IF_NULL(node);
2648 while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) {
2649 node = manager->node_users()[node].begin()->first->cast<CNodePtr>();
2650 }
2651 return node;
2652 }
2653
// Search the users of 'anode' for a cnode whose primitive name is 'name' and
// which consumes 'anode' as its first real input (node_pair.second == 1).
// Trivial nodes and Load are skipped moving down; AllGather nodes inserted by
// the parallel optimizer are looked through recursively.
// Returns {found-in-'func_graph', matching cnode or nullptr}.
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
                                    size_t max_depth) {
  MS_EXCEPTION_IF_NULL(anode);
  MS_EXCEPTION_IF_NULL(anode->func_graph());
  FuncGraphManagerPtr manager = anode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
  }
  AnfNodeIndexSet node_set = manager->node_users()[anode];
  bool result = false;
  CNodePtr cnode_return = nullptr;
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    // Skip over trivial nodes (e.g. Load) to the real consumer.
    use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    if (node_prim->name() == name && node_pair.second == 1) {
      if (use_apply->func_graph() == func_graph) {
        // Keep scanning: a later match in the same graph also sets the result.
        result = true;
        cnode_return = use_apply;
        MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph";
        continue;
      }
      MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph";
    }
    // With the parallel optimizer enabled, look through inserted AllGather nodes.
    if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(use_apply)) {
      return FindCNode(node_pair.first, name, func_graph, max_depth + 1);
    }
  }
  return std::make_pair(result, cnode_return);
}
2694
SetSharedParameterFlag(const FuncGraphPtr & root,const AnfNodePtr & parameter)2695 void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr ¶meter) {
2696 MS_EXCEPTION_IF_NULL(root);
2697 MS_EXCEPTION_IF_NULL(parameter);
2698 FuncGraphManagerPtr manager = root->manager();
2699 MS_EXCEPTION_IF_NULL(manager);
2700 ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
2701 if (parameter_ptr == nullptr) {
2702 MS_LOG(INFO) << parameter->ToString() << ": cast to ptr failed. it may not be a parameter";
2703 return;
2704 }
2705 auto user_set = manager->node_users()[parameter];
2706 int32_t user_count = 0;
2707 for (auto ¶m_pair : user_set) {
2708 CNodePtr cnode = param_pair.first->cast<CNodePtr>();
2709 MS_EXCEPTION_IF_NULL(cnode);
2710 if (cnode->in_forward_flag()) {
2711 user_count++;
2712 }
2713 }
2714 if (user_count > 1) {
2715 auto tensor_layout = parameter_ptr->user_data<TensorLayout>();
2716 tensor_layout->set_is_shared_param(true);
2717 }
2718 }
2719
GenerateBatchParallelStrategy(const OperatorInfoPtr operator_,const PrimitivePtr prim)2720 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
2721 MS_EXCEPTION_IF_NULL(operator_);
2722 MS_EXCEPTION_IF_NULL(prim);
2723 if (!operator_->inputs_shape_new().empty()) {
2724 MS_LOG(EXCEPTION) << "Currently, tuple in tuple input does not support GenerateBatchParallelStrategy, please set "
2725 "strategy in python side";
2726 }
2727 StrategyPtr strategyPtr;
2728 std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategiesWithCheck();
2729 MS_EXCEPTION_IF_NULL(strategy_v_ptr);
2730 auto stage_id = g_device_manager->stage_id();
2731 strategyPtr = NewStrategy(stage_id, *strategy_v_ptr);
2732 std::vector<ValuePtr> elements;
2733 for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
2734 elements.push_back(MakeValue((*strategy_v_ptr)[i]));
2735 }
2736 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
2737 // display the strategy generated by batch parallel
2738 auto attrs = prim->attrs();
2739 attrs[GEN_STRATEGY] = strategy;
2740 (void)prim->SetAttrs(attrs);
2741 MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << attrs[GEN_STRATEGY]->ToString();
2742 return strategyPtr;
2743 }
2744
GenerateStandAloneStrategy(const Shapes & inputs_shape)2745 StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape) {
2746 Strategies strategy_v;
2747 for (size_t i = 0; i != inputs_shape.size(); i++) {
2748 if (inputs_shape[i].empty()) {
2749 MS_LOG(INFO) << "Elements of shapes is empty.";
2750 Dimensions empty_element;
2751 strategy_v.push_back(empty_element);
2752 } else {
2753 Dimensions element(inputs_shape[i].size(), 1);
2754 strategy_v.push_back(element);
2755 }
2756 }
2757 auto stage_id = g_device_manager->stage_id();
2758 auto stra_ptr = NewStrategy(stage_id, strategy_v);
2759 return stra_ptr;
2760 }
2761
ShapeBasePtr GenerateStra(const ShapeBasePtr &shape) {
  // Recursively build an all-ones strategy mirroring the (possibly nested) shape.
  if (shape->is_list()) {
    std::vector<ShapeBasePtr> nested;
    for (size_t idx = 0; idx < shape->size(); ++idx) {
      nested.push_back(GenerateStra(shape->GetElement(SizeToLong(idx))));
    }
    return std::make_shared<ShapeList>(nested);
  }
  if (shape->empty()) {
    MS_LOG(INFO) << "Elements of shapes is empty.";
    return std::make_shared<ShapeValue>(Dimensions());
  }
  return std::make_shared<ShapeValue>(Dimensions(shape->size(), 1));
}
2783
GenerateStandAloneStrategyForNewShapes(const NewShapes & inputs_shape)2784 StrategyPtr GenerateStandAloneStrategyForNewShapes(const NewShapes &inputs_shape) {
2785 NewStrategies strategy_v;
2786 for (size_t i = 0; i != inputs_shape.size(); i++) {
2787 strategy_v.emplace_back(GenerateStra(inputs_shape[i]));
2788 }
2789 auto stage_id = g_device_manager->stage_id();
2790 auto stra_ptr = NewStrategy(stage_id, strategy_v);
2791 return stra_ptr;
2792 }
2793
IsInsertVirtualOutput(const FuncGraphPtr & root)2794 bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
2795 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2796 auto comm_info = GetCommInfo();
2797 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2798 int64_t per_stage_device_num = comm_info.device_num / split_stage_num;
2799 int64_t current_stage = comm_info.global_rank / per_stage_device_num;
2800 MS_LOG(INFO) << "The current stage is: " << current_stage;
2801 if (!root->has_flag(kTraining) && !ParallelContext::GetInstance()->dataset_strategy().empty()) {
2802 MS_LOG(WARNING) << "In eval/predict net, the output parallel strategy would not follow "
2803 "the input parallel strategy when using context.set_auto_parallel_context(dataset_strategy)"
2804 " to configure the input strategy.";
2805 }
2806 return ((!root->has_flag(kTraining) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
2807 current_stage == split_stage_num - 1) ||
2808 IsPynativeParallel());
2809 }
2810
// Fetch the tensor layout of a cnode input from the operator info attached to that cnode.
// `node_pair.second` is the 1-based input index of the cnode; `make_tuple_index` selects an
// element inside a list-typed tensor info (1-based; -1 means the input is not a tuple element).
// Throws when the index is out of range or the list/scalar kinds do not match.
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair, const int &make_tuple_index) {
  CNodePtr cnode = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  int64_t index = node_pair.second;
  TensorLayout tensorlayout_in;
  if (distribute_operator->inputs_tensor_info_new().empty()) {
    // Legacy tensor-info path: valid 1-based indices are 1..size(), hence `index - 1` below.
    if (index > SizeToLong(distribute_operator->inputs_tensor_info().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << (index - 1)
                        << ", the vector size is " << distribute_operator->inputs_tensor_info().size();
    }
    TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
    tensorlayout_in = tensorinfo_in.tensor_layout();
  } else {
    // New tensor-info path: entries may themselves be lists (tuple inputs).
    if (index > SizeToLong(distribute_operator->inputs_tensor_info_new().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << (index - 1)
                        << ", the vector size is " << distribute_operator->inputs_tensor_info_new().size();
    }
    auto tensorinfo_in = distribute_operator->inputs_tensor_info_new()[LongToSize(index - 1)];
    if (tensorinfo_in->is_list() && make_tuple_index != -1) {
      // Tuple element: make_tuple_index is 1-based inside the list.
      auto new_tensorinfo_in = tensorinfo_in->GetElement(make_tuple_index - 1);
      tensorlayout_in = new_tensorinfo_in->GetValue().tensor_layout();
    } else if (!tensorinfo_in->is_list() && make_tuple_index == -1) {
      tensorlayout_in = tensorinfo_in->GetValue().tensor_layout();
    } else {
      // List-ness of the tensor info and presence of a make_tuple index must agree.
      MS_LOG(EXCEPTION) << "tensorinfo_in does not match with make_tuple_index: make_tuple_index is "
                        << make_tuple_index << ", node is " << node_pair.first->DebugString();
    }
  }
  return tensorlayout_in;
}
2843
IsCellReuseForwardGraph(const FuncGraphPtr & graph)2844 bool IsCellReuseForwardGraph(const FuncGraphPtr &graph) { return graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE); }
2845
// Walk the pattern Return -> MakeTuple -> Partial from the forward graph's return
// node and extract the backward FuncGraph held by the Partial.
// Returns nullptr when any step of the pattern does not match.
FuncGraphPtr GetCellReuseBackwardGraph(const FuncGraphPtr &forward_graph) {
  const std::vector<std::pair<PrimitivePtr, int64_t>> walk_steps = {
    {prim::kPrimReturn, kIndex1}, {prim::kPrimMakeTuple, kIndex2}, {prim::kPrimPartial, kIndex1}};
  AnfNodePtr cursor = forward_graph->get_return();
  for (const auto &[expected_prim, input_index] : walk_steps) {
    auto cursor_cnode = cursor->cast<CNodePtr>();
    if (cursor_cnode == nullptr || !IsPrimitiveCNode(cursor_cnode, expected_prim)) {
      return nullptr;
    }
    if (input_index >= SizeToLong(cursor_cnode->size())) {
      return nullptr;
    }
    cursor = cursor_cnode->input(input_index);
  }
  return GetValueNode<FuncGraphPtr>(cursor);
}
2863
mirror_group_list(const TensorLayoutPtr & layout)2864 Shape mirror_group_list(const TensorLayoutPtr &layout) {
2865 int64_t rank = g_device_manager->global_rank();
2866 auto stage_dev_list = g_device_manager->GetDeviceListInThisStage();
2867 DeviceMatrix dev_matrix(rank, stage_dev_list, layout->device_arrangement().array());
2868 RankList group_devices;
2869 if (dev_matrix.GetDevicesByTensorMap(layout->tensor_map().array(), &group_devices) != SUCCESS) {
2870 MS_LOG(EXCEPTION) << "For layout:" << layout->ToString() << ", infer mirror failed";
2871 }
2872 return group_devices;
2873 }
2874
ChangeAllGatherGroup(const CNodePtr & ag_cnode,const RankList & new_group_ranks)2875 void ChangeAllGatherGroup(const CNodePtr &ag_cnode, const RankList &new_group_ranks) {
2876 Group new_group;
2877 if (g_device_manager->CreateGroup(new_group_ranks, &new_group) != SUCCESS) {
2878 MS_LOG(EXCEPTION) << ": Create communication group failed, the rank_list is: " << new_group_ranks;
2879 }
2880 auto ag_prim = GetCNodePrimitive(ag_cnode);
2881 ag_prim->AddAttr(GROUP, MakeValue(new_group.name()));
2882 ag_prim->AddAttr(GROUP_RANKS, MakeValue(g_device_manager->FindRankListNameByHashName(new_group.name())));
2883 ag_prim->AddAttr(RANK_SIZE, MakeValue<int64_t>(new_group_ranks.size()));
2884 }
2885
InterleavedReplacedConcatNodes(const std::vector<CNodePtr> & ag_vector)2886 std::vector<CNodePtr> InterleavedReplacedConcatNodes(const std::vector<CNodePtr> &ag_vector) {
2887 std::vector<CNodePtr> replace_nodes;
2888 for (const auto &ag : ag_vector) {
2889 auto ag_next_nodes = GetOutputNodesWithFilter(ag, [&](const AnfNodePtr &anode) {
2890 return IsPrimitiveCNode(anode, prim::kPrimSplit) || IsPrimitiveCNode(anode, prim::kPrimTupleGetItem) ||
2891 IsPrimitiveCNode(anode, prim::kPrimMakeTuple);
2892 });
2893 std::set<AnfNodePtr> next_nodes_set;
2894 std::transform(ag_next_nodes.begin(), ag_next_nodes.end(), std::inserter(next_nodes_set, next_nodes_set.begin()),
2895 [](auto pair) { return pair.first; });
2896 if (!(next_nodes_set.size() == kSizeOne && IsPrimitiveCNode(ag_next_nodes.front().first, prim::kPrimConcat))) {
2897 continue;
2898 }
2899 auto concat_cnode = ag_next_nodes.front().first->cast<CNodePtr>();
2900 auto concat_prim = GetCNodePrimitive(concat_cnode);
2901 if (concat_prim->instance_name().find(REDISTRIBUTION_OP) != std::string::npos) {
2902 replace_nodes.push_back(concat_cnode);
2903 }
2904 }
2905 return replace_nodes;
2906 }
2907
CreateInterleavedNeedReplaceOpLists(const CNodePtr & virtual_converter_end,const PrimitivePtr & r_prim)2908 std::vector<std::vector<CNodePtr>> CreateInterleavedNeedReplaceOpLists(const CNodePtr &virtual_converter_end,
2909 const PrimitivePtr &r_prim) {
2910 std::vector<std::vector<CNodePtr>> need_replace_op_lists;
2911 for (size_t j = 1; j < virtual_converter_end->size(); ++j) {
2912 auto current_node = virtual_converter_end->input(j)->cast<CNodePtr>();
2913 MS_EXCEPTION_IF_NULL(current_node);
2914 std::vector<CNodePtr> need_replace_op_list;
2915 while (!IsPrimitiveCNode(current_node, prim::kPrimVirtualConverterBegin)) {
2916 if (IsPrimitiveCNode(current_node, r_prim)) {
2917 need_replace_op_list.push_back(current_node);
2918 }
2919 current_node = current_node->input(kIndex1)->cast<CNodePtr>();
2920 MS_EXCEPTION_IF_NULL(current_node);
2921 }
2922 need_replace_op_lists.push_back(need_replace_op_list);
2923 }
2924 return need_replace_op_lists;
2925 }
2926
// Replace the interleaved AllGather nodes in `ag_vector` with a single Concat over their
// inputs (when independent_size == 1) or over the AllGathers themselves (otherwise).
// When a redistribution-created Concat already consumes the AllGathers and
// independent_size == 1, that Concat (and its axis) is reused as the replacement target.
// Returns the newly created Concat cnode.
CNodePtr ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const std::vector<CNodePtr> &ag_vector,
                                             const std::vector<std::vector<int64_t>> &new_group_ranks_vector,
                                             size_t independent_size) {
  // independent_size == 1 means the AllGather is fully redundant: feed its input directly.
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple->Clone())};
  std::transform(ag_vector.begin(), ag_vector.end(), std::back_inserter(make_tuple_inputs),
                 [&](auto node) { return independent_size == 1 ? node->input(kIndex1) : node; });
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  auto replace_nodes = InterleavedReplacedConcatNodes(ag_vector);
  bool replace_concat = (!replace_nodes.empty() && independent_size == 1);
  // Default concat axis is 0; reuse the existing Concat's axis when replacing it.
  AnfNodePtr axis = NewValueNode(MakeValue<int64_t>(0));
  if (replace_concat) {
    axis = replace_nodes.front()->input(kIndex2);
  }
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(prim::kPrimConcat->Clone()), make_tuple, axis};
  auto concat = func_graph->NewCNode(concat_inputs);
  concat->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
  auto manager = func_graph->manager();

  for (size_t i = 0; i < ag_vector.size(); ++i) {
    auto ag = ag_vector[i];
    if (independent_size != 1) {
      // set allgather attrs
      ChangeAllGatherGroup(ag, new_group_ranks_vector[i]);
    }
    if (!replace_concat) {
      // Redirect every AllGather user to the new Concat.
      (void)manager->Replace(ag, concat);
    }
  }
  if (!replace_concat) {
    return concat;
  }
  // Replace the pre-existing redistribution Concats instead of the AllGathers.
  for (size_t i = 0; i < replace_nodes.size(); ++i) {
    (void)manager->Replace(replace_nodes[i], concat);
  }
  return concat;
}
2963
// For StridedSlice nodes that appear at the same position in every interleaved branch
// and have pairwise-distinct (begin, end, strides) arguments, rewire all of them onto the
// first branch's producer so the work feeding the slices is computed only once.
void MergeOpBeforeInterleaveSlice(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end) {
  std::vector<std::vector<CNodePtr>> need_replace_op_lists =
    CreateInterleavedNeedReplaceOpLists(virtual_converter_end, prim::kPrimStridedSlice);
  auto manager = func_graph->manager();
  if (need_replace_op_lists.empty()) {
    return;
  }
  // All branches must contain the same number of slices for positions to line up.
  auto col_size = need_replace_op_lists.front().size();
  for (size_t i = 0; i < need_replace_op_lists.size(); ++i) {
    if (need_replace_op_lists[i].size() != col_size) {
      MS_LOG(INTERNAL_EXCEPTION) << "Slice redistribution infer failed.";
    }
  }
  for (size_t col = 0; col < col_size; ++col) {
    // Gather the (begin, end, strides) value triples of the slices at this column.
    std::set<std::vector<std::vector<int64_t>>> slice_value_list_set;
    for (size_t row = 0; row < need_replace_op_lists.size(); ++row) {
      auto slice_cnode = need_replace_op_lists[row][col];
      std::vector<std::vector<int64_t>> slice_value_list;
      // Inputs 2..4 of StridedSlice hold begin / end / strides.
      for (size_t i = 2; i < kSizeFive; ++i) {
        ValuePtr slice_value = GetValueNode(slice_cnode->input(i));
        MS_EXCEPTION_IF_NULL(slice_value);
        auto value_vector = GetValue<std::vector<int64_t>>(slice_value);
        slice_value_list.push_back(value_vector);
      }
      slice_value_list_set.insert(slice_value_list);
    }
    // Merge only when every branch slices a distinct region (no duplicate triples).
    if (slice_value_list_set.size() != need_replace_op_lists.size()) {
      continue;
    }
    // merge nodes before multi slice
    auto slice_input = need_replace_op_lists[kIndex0][col]->input(kIndex1);
    need_replace_op_lists[kIndex0][col]->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
    for (size_t row = 1; row < need_replace_op_lists.size(); ++row) {
      auto slice_cnode = need_replace_op_lists[row][col];
      slice_cnode->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
      (void)manager->SetEdge(slice_cnode, kIndex1, slice_input);
    }
  }
}
3003
// Rewrite AllGather nodes introduced by fine-grained micro-interleaving: AllGathers whose
// per-interleave groups collapse to the same device set are fused into a local Concat
// (no communication), while the rest get their groups remapped onto the non-interleaved
// rank space. `ag_group_ranks_vectors` holds, per branch and per AllGather position,
// the original group rank lists.
void ConvertInterleaveAllGatherToConcat(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end,
                                        const std::vector<std::vector<std::vector<int64_t>>> &ag_group_ranks_vectors) {
  // Change communication rank_list && Create communication group
  // Replace AllConcat to Concat
  std::vector<std::vector<CNodePtr>> need_replace_op_lists =
    CreateInterleavedNeedReplaceOpLists(virtual_converter_end, prim::kPrimAllGather);
  MergeOpBeforeInterleaveSlice(func_graph, virtual_converter_end);
  if (need_replace_op_lists.size() != ag_group_ranks_vectors.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "AllGather redistribution infer failed.";
  }
  if (need_replace_op_lists.empty()) {
    return;
  }
  // Every branch must expose the same number of AllGathers, aligned by column.
  auto col_size = need_replace_op_lists.front().size();
  for (size_t i = 0; i < need_replace_op_lists.size(); ++i) {
    if (need_replace_op_lists[i].size() != col_size || ag_group_ranks_vectors[i].size() != col_size) {
      MS_LOG(INTERNAL_EXCEPTION) << "AllGather redistribution infer failed.";
    }
  }
  auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
  for (size_t col = 0; col < col_size; ++col) {
    std::vector<std::vector<int64_t>> new_group_ranks_vector;
    std::vector<CNodePtr> ag_vector;
    size_t independent_size = 0;
    for (size_t row = 0; row < need_replace_op_lists.size(); ++row) {
      auto group_ranks = ag_group_ranks_vectors[row][col];
      // Project the interleaved rank ids back onto physical device ids.
      std::vector<int64_t> new_group_ranks;
      std::set<int64_t> new_group_ranks_set;
      for (const auto &g_rank : group_ranks) {
        new_group_ranks_set.insert(int64_t(g_rank / interleaved_num));
        new_group_ranks.push_back(int64_t(g_rank / interleaved_num));
      }
      if (new_group_ranks_set.size() == new_group_ranks.size()) {
        // No duplicates: the group still spans distinct devices, so keep the
        // AllGather and just remap its group.
        // set allgather attrs
        ChangeAllGatherGroup(need_replace_op_lists[row][col], new_group_ranks);
        continue;
      }
      // Duplicates mean part of the gather is local: the AllGather becomes a Concat.
      std::vector<int64_t> new_group_ranks_no_repeat;
      std::copy(new_group_ranks_set.begin(), new_group_ranks_set.end(), std::back_inserter(new_group_ranks_no_repeat));
      std::sort(new_group_ranks_no_repeat.begin(), new_group_ranks_no_repeat.end());
      new_group_ranks_vector.push_back(new_group_ranks_no_repeat);
      if (independent_size > 0 && new_group_ranks_no_repeat.size() != independent_size) {
        MS_LOG(INTERNAL_EXCEPTION) << "The concat group in micro interleaved is wrong!";
      }
      independent_size = new_group_ranks_no_repeat.size();
      ag_vector.push_back(need_replace_op_lists[row][col]);
    }
    if (new_group_ranks_vector.empty()) {
      continue;
    }

    // Check whether all branch needing be replace
    if (new_group_ranks_vector.size() < need_replace_op_lists.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The concat group in micro interleaved is wrong!";
    }

    // replace allgathers to one concat.
    auto replaced_concat =
      ReplaceInterleavedAllGatherToConcat(func_graph, ag_vector, new_group_ranks_vector, independent_size);
    auto manager = func_graph->manager();
    auto replaced_concat_users =
      GetOutputNodesWithFilter(replaced_concat, [&](const AnfNodePtr &anode) { return false; });
    if (replaced_concat_users.size() == kSizeOne) {
      continue;
    }
    // If every user is an interleave-merged StridedSlice, no further merging is needed.
    if (std::all_of(replaced_concat_users.begin(), replaced_concat_users.end(),
                    [](const std::pair<AnfNodePtr, int> &pair) {
                      return IsPrimitiveCNode(pair.first, prim::kPrimStridedSlice) &&
                             pair.first->cast<CNodePtr>()->HasAttr(INTERLEAVED_PARALLEL);
                    })) {
      continue;
    }
    // merge the nodes afer the interleaved parallel concat.
    auto virtual_end_input1 = virtual_converter_end->input(kIndex1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(virtual_end_input1);
    auto new_virtual_converter_end = CreateVirtualConverterEndNode(func_graph, {virtual_end_input1});

    (void)manager->Replace(virtual_converter_end, new_virtual_converter_end);
  }
}
3084
// Check whether a VirtualConverterBegin duplicates work across interleaved branches:
// its producer is an ordinary (non-parallel-care) cnode and at least two of its
// StridedSlice users carry identical (begin, end, strides) arguments, meaning the
// same region is computed more than once and the input should be split.
bool IsDuplicatedVirtualConverterBegin(const CNodePtr &virtual_converter_begin) {
  auto virtual_converter_begin_input = virtual_converter_begin->input(kSizeOne);
  // Fed directly by a VirtualConverterEnd: nothing upstream to split.
  if (IsPrimitiveCNode(virtual_converter_begin_input, prim::kPrimVirtualConverterEnd)) {
    return false;
  }
  // Non-cnode producers and UpdateState cannot be duplicated work.
  if (!IsPrimitiveCNode(virtual_converter_begin_input) ||
      IsPrimitiveCNode(virtual_converter_begin_input, prim::kPrimUpdateState)) {
    return false;
  }
  auto virtual_converter_begin_input_cnode = virtual_converter_begin_input->cast<CNodePtr>();
  // Parallel-care producers are handled by the regular parallel pipeline.
  if (IsParallelCareNode(virtual_converter_begin_input_cnode)) {
    return false;
  }
  auto virtual_converter_begin_users = GetOutputNodesWithFilter(
    virtual_converter_begin, [&](const AnfNodePtr &anode) { return IsPrimitiveCNode(anode, prim::kPrimTupleGetItem); });
  if (virtual_converter_begin_users.size() <= kSizeOne) {
    return false;
  }
  // Collect the (begin, end, strides) triples of all StridedSlice users.
  std::set<std::vector<std::vector<int64_t>>> slice_value_list_set;
  for (const auto &user_pair : virtual_converter_begin_users) {
    if (!IsPrimitiveCNode(user_pair.first, prim::kPrimStridedSlice)) {
      continue;
    }
    auto slice = user_pair.first->cast<CNodePtr>();
    std::vector<std::vector<int64_t>> slice_value_list;
    // Inputs 2..4 of StridedSlice hold begin / end / strides.
    for (size_t i = 2; i < kSizeFive; ++i) {
      ValuePtr slice_value = GetValueNode(slice->input(i));
      MS_EXCEPTION_IF_NULL(slice_value);
      auto value_vector = GetValue<std::vector<int64_t>>(slice_value);
      slice_value_list.push_back(value_vector);
    }
    slice_value_list_set.insert(slice_value_list);
  }
  // All triples distinct -> each branch consumes a different region; not duplicated.
  if (slice_value_list_set.size() == virtual_converter_begin_users.size()) {
    return false;
  }
  return true;
}
3123
GetOrderOfTwoAnode(const std::pair<AnfNodePtr,int> & pair1,const std::pair<AnfNodePtr,int> & pair2)3124 bool GetOrderOfTwoAnode(const std::pair<AnfNodePtr, int> &pair1, const std::pair<AnfNodePtr, int> &pair2) {
3125 int number1 = pair1.second;
3126 int number2 = pair2.second;
3127 auto pair1_input_node = pair1.first->cast<CNodePtr>()->input(pair1.second);
3128 auto pair2_input_node = pair2.first->cast<CNodePtr>()->input(pair2.second);
3129 if (IsPrimitiveCNode(pair1_input_node, prim::kPrimTupleGetItem)) {
3130 number1 = LongToInt(GetTupleGetItemIndex(pair1_input_node->cast<CNodePtr>()));
3131 }
3132 if (IsPrimitiveCNode(pair2_input_node, prim::kPrimTupleGetItem)) {
3133 number2 = LongToInt(GetTupleGetItemIndex(pair2_input_node->cast<CNodePtr>()));
3134 }
3135 return number1 < number2;
3136 }
3137
// Split the (non-parallel-care) producer of a VirtualConverterBegin so that each
// interleaved branch gets its own copy of the computation: every cnode input of the
// producer is wrapped in a new VirtualConverterBegin, the producer is cloned once per
// branch over TupleGetItem-selected slices, and each original user is rewired to its
// clone. Returns the newly created VirtualConverterBegin nodes for further processing.
std::vector<CNodePtr> DoSplitForNotParallelCareOpsInterleaved(const FuncGraphManagerPtr &manager,
                                                              const CNodePtr &virtual_converter_begin) {
  auto virtual_converter_begin_input = virtual_converter_begin->input(kSizeOne);
  auto virtual_converter_begin_users = GetOutputNodesWithFilter(
    virtual_converter_begin, [&](const AnfNodePtr &anode) { return IsPrimitiveCNode(anode, prim::kPrimTupleGetItem); });
  // Stable branch order: sort users by tuple/edge index so clone i serves branch i.
  std::sort(virtual_converter_begin_users.begin(), virtual_converter_begin_users.end(),
            [](const auto &pair1, const auto &pair2) { return GetOrderOfTwoAnode(pair1, pair2); });
  auto virtual_converter_begin_input_cnode = virtual_converter_begin_input->cast<CNodePtr>();
  std::vector<AnfNodePtr> new_inputs;
  std::vector<CNodePtr> new_virtual_converter_begin_vector;
  for (size_t i = 1; i < virtual_converter_begin_input_cnode->size(); ++i) {
    // Non-cnode inputs (constants, params) and UpdateState are shared, not split.
    if (!IsPrimitiveCNode(virtual_converter_begin_input_cnode->input(i)) ||
        IsPrimitiveCNode(virtual_converter_begin_input_cnode->input(i), prim::kPrimUpdateState)) {
      new_inputs.push_back(virtual_converter_begin_input_cnode->input(i));
      continue;
    }
    auto new_virtual_converter_begin = CreateVirtualConverterBeginNode(
      virtual_converter_begin_input_cnode->input(i)->cast<CNodePtr>(), virtual_converter_begin_users.size());
    new_inputs.push_back(new_virtual_converter_begin);
    new_virtual_converter_begin_vector.push_back(new_virtual_converter_begin);
  }

  for (size_t interleveaved_index = 0; interleveaved_index < virtual_converter_begin_users.size();
       ++interleveaved_index) {
    // Clone the producer for this branch, selecting slice `interleveaved_index`
    // of every split input via TupleGetItem.
    std::vector<AnfNodePtr> splited_node_inputs = {virtual_converter_begin_input_cnode->input(kIndex0)};
    for (size_t i = 0; i < new_inputs.size(); ++i) {
      if (!IsPrimitiveCNode(new_inputs[i])) {
        splited_node_inputs.push_back(new_inputs[i]);
        continue;
      }
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), new_inputs[i],
                                                    CreatInt64Imm(UlongToLong(interleveaved_index))};
      auto tuple_get_item_cnode = virtual_converter_begin_input_cnode->func_graph()->NewCNode(tuple_get_item_inputs);
      splited_node_inputs.push_back(tuple_get_item_cnode);
    }
    auto splited_node = virtual_converter_begin_input_cnode->func_graph()->NewCNode(splited_node_inputs);
    manager->SetEdge(virtual_converter_begin_users[interleveaved_index].first,
                     virtual_converter_begin_users[interleveaved_index].second, splited_node);
  }
  return new_virtual_converter_begin_vector;
}
3179
SplitNotParallelCareOpsInterleaved(const FuncGraphPtr & root)3180 void SplitNotParallelCareOpsInterleaved(const FuncGraphPtr &root) {
3181 AnfNodePtr ret_after = root->get_return();
3182 MS_EXCEPTION_IF_NULL(ret_after);
3183 auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
3184 auto manager = root->manager();
3185 auto node_users = manager->node_users();
3186 for (const auto &node : all_nodes) {
3187 if (!IsPrimitiveCNode(node, prim::kPrimVirtualConverterBegin)) {
3188 continue;
3189 }
3190 std::queue<CNodePtr> visited;
3191 visited.push(node->cast<CNodePtr>());
3192 while (!visited.empty()) {
3193 auto virtual_converter_begin = visited.front();
3194 visited.pop();
3195 if (!IsDuplicatedVirtualConverterBegin(virtual_converter_begin)) {
3196 continue;
3197 }
3198 // Need to split the input
3199 auto new_virtual_converter_begins = DoSplitForNotParallelCareOpsInterleaved(manager, virtual_converter_begin);
3200 for (auto &new_virtual_converter_begin : new_virtual_converter_begins) {
3201 visited.push(new_virtual_converter_begin);
3202 }
3203 }
3204 }
3205 }
3206
// Remove VirtualConverterBegin/End marker pairs from the graph: each TupleGetItem user
// of a Begin is short-circuited to the matching input of the paired End, and any
// remaining single-input End nodes are then replaced by their input.
void EraseVirtualConverter(const FuncGraphPtr &root) {
  AnfNodePtr ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(ret_after);
  auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
  auto manager = root->manager();
  // NOTE(review): this is a snapshot of the user map taken before any Replace below;
  // it is only consulted for nodes not yet rewritten — confirm staleness is benign.
  auto node_users = manager->node_users();
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualConverterBegin)) {
      continue;
    }
    auto virtual_converter_begin = node->cast<CNodePtr>();
    if (!IsPrimitiveCNode(virtual_converter_begin->input(kIndex1), prim::kPrimVirtualConverterEnd)) {
      // Unpaired Begin: bypass it by wiring its users straight to its input.
      MS_LOG(INFO) << "The VirtualConverterBegin input is not VirtualConverterEnd, it is "
                   << virtual_converter_begin->input(kIndex1)->fullname_with_scope();
      auto virtual_converter_begin_input_node = virtual_converter_begin->input(kIndex1);
      for (const auto &v_user_pair : node_users.at(virtual_converter_begin)) {
        (void)manager->Replace(v_user_pair.first, virtual_converter_begin_input_node);
      }
      continue;
    }
    auto virtual_converter_end = virtual_converter_begin->input(kIndex1)->cast<CNodePtr>();
    // Fresh user list (not the snapshot) since earlier iterations may have rewritten edges.
    auto virtual_converter_begin_users = manager->node_users()[virtual_converter_begin];
    if (virtual_converter_begin_users.size() != virtual_converter_end->size() - 1) {
      MS_LOG(INTERNAL_EXCEPTION)
        << "The VirtualConverterBegin users nums is not equal to VirtualConverterEnd inputs nums";
    }
    for (const auto &node_pair : virtual_converter_begin_users) {
      if (!IsPrimitiveCNode(node_pair.first, prim::kPrimTupleGetItem)) {
        MS_LOG(INTERNAL_EXCEPTION) << "The VirtualConverterBegin user should be tuple_get_item.";
      }
      auto tuple_get_item = node_pair.first->cast<CNodePtr>();
      auto tuple_get_item_index_value = GetValueNode(tuple_get_item->input(kIndex2));
      MS_EXCEPTION_IF_NULL(tuple_get_item_index_value);
      auto get_item_index = GetValue<int64_t>(tuple_get_item_index_value);
      // TupleGetItem(i) on the Begin maps to input i+1 of the End.
      (void)manager->Replace(tuple_get_item, virtual_converter_end->input(get_item_index + 1));
    }
  }
  // Second pass: drop any End nodes left with exactly one data input.
  AnfNodePtr new_ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(new_ret_after);
  auto new_all_nodes = TopoSort(new_ret_after, SuccDeeperSimple);
  for (const auto &node : new_all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimVirtualConverterEnd)) {
      auto virtual_converter_end_cnode = node->cast<CNodePtr>();
      if (virtual_converter_end_cnode->size() != kSizeTwo) {
        MS_LOG(INTERNAL_EXCEPTION) << "The VirtualConverterEnd nums is not equal to VirtualConverterBegin nums.";
      }
      auto virtual_converter_end_input = virtual_converter_end_cnode->input(kIndex1);
      (void)manager->Replace(virtual_converter_end_cnode, virtual_converter_end_input);
    }
  }
}
3258
// Convert a number to its English ordinal string, e.g. 1 -> "1st", 2 -> "2nd",
// 11 -> "11th", 22 -> "22nd", 103 -> "103rd".
std::string GetSerialNumberString(size_t number) {
  std::string suffix = "th";
  // Fix: the previous version only special-cased exactly 1/2/3, producing "21th",
  // "22th", ... and would mis-handle the teens if extended naively. English ordinals
  // depend on the last digit, except 11/12/13 (and 111/112/113, ...) which take "th".
  const size_t last_two = number % 100;
  if (last_two < 11 || last_two > 13) {
    switch (number % 10) {
      case 1:
        suffix = "st";
        break;
      case 2:
        suffix = "nd";
        break;
      case 3:
        suffix = "rd";
        break;
      default:
        break;
    }
  }
  std::ostringstream oss;
  oss << number << suffix;
  return oss.str();
}
3272
// Get single device capacity in bytes (the configured parameters are in GB).
size_t GetDeviceCapacity() {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  size_t size_from_context;
  auto max_device_memory = context->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
  // Physical capacity used as an upper bound: 64 GB on 910b/910c, 32 GB otherwise.
  float total_device_memory = 32.0f;
  if (context->ascend_soc_version() == kAscendVersion910b || context->ascend_soc_version() == kAscendVersion910c) {
    total_device_memory = 64.0f;
  }
  if (max_device_memory <= total_device_memory) {
    // max_device_memory is honored only when it does not exceed the physical capacity.
    MS_LOG(DEBUG) << "context max_device_memory:" << max_device_memory;
    size_from_context = FloatToSize(max_device_memory * kGBToByte);
  } else {
    // Fallback: parse variable_memory_max_size — the substring before the first '*'
    // is interpreted as a GB count; "0" means unbounded/unset.
    auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
    if (variable_memory_max_size == "0") {
      return 0;
    }
    MS_LOG(DEBUG) << "context variable_memory_max_size:" << variable_memory_max_size;
    auto pos = variable_memory_max_size.find('*');
    if (pos == std::string::npos) {
      MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
    }
    auto gb_str = variable_memory_max_size.substr(0, pos);
    // NOTE(review): std::stoull throws std::invalid_argument/std::out_of_range on a
    // non-numeric or oversized prefix — confirm the value is validated upstream.
    auto gb_var = std::stoull(gb_str);
    MS_LOG(DEBUG) << "variable_memory_max_size(GB):" << gb_var;
    size_from_context = gb_var * kGBToByte;
  }
  return size_from_context;
}
3303
GenerateAbsByOpInfer(const PrimitivePtr & primitive,const AnfNodePtrList & input_list)3304 abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) {
3305 MS_EXCEPTION_IF_NULL(primitive);
3306 std::vector<AbstractBasePtr> input_args;
3307 (void)std::for_each(input_list.begin(), input_list.end(),
3308 [&input_args](const auto &input) { (void)input_args.emplace_back(input->abstract()); });
3309 auto abs_opt = abstract::TryInferAbstract(primitive, input_args);
3310 if (!abs_opt.has_value()) {
3311 MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered.";
3312 }
3313 auto abs = abs_opt.value();
3314 MS_EXCEPTION_IF_NULL(abs);
3315 MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString();
3316 return abs;
3317 }
3318 } // namespace parallel
3319 } // namespace mindspore
3320