1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "base/base.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/other_ops.h"
27 #include "mindspore/core/ops/nn_ops.h"
28 #include "mindspore/core/ops/array_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "mindspore/core/ops/arithmetic_ops.h"
31 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
32 #include "frontend/parallel/ops_info/ops_utils.h"
33 #include "frontend/parallel/group_manager.h"
34 #include "frontend/parallel/parameter_manager.h"
35 #include "include/common/utils/parallel_context.h"
36 #include "frontend/parallel/step_parallel.h"
37 #include "frontend/parallel/node_check.h"
38 #include "frontend/parallel/graph_util/node_info.h"
39 #include "frontend/parallel/graph_util/graph_info.h"
40 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
41 #include "frontend/parallel/step_parallel_utils.h"
42 #include "frontend/parallel/graph_util/graph_splitter.h"
43 #include "frontend/parallel/tensor_layout/shared_parameter.h"
44 #include "ir/anf.h"
45 #include "ir/graph_utils.h"
46 #include "include/common/utils/comm_manager.h"
47 #include "utils/ms_context.h"
48 #include "utils/tensor_construct_utils.h"
49 #include "mindspore/core/utils/parallel_node_check.h"
50 #include "include/common/debug/anf_ir_dump.h"
51
52 namespace mindspore {
53 namespace parallel {
54 namespace {
void SetMakeTupleAbstract(const CNodePtr &node) {
56 if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
57 return;
58 }
59
60 AbstractBasePtrList abstract_list;
61 for (size_t i = 1; i < node->inputs().size(); i++) {
62 abstract_list.emplace_back(node->input(i)->abstract());
63 }
64 auto abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
65 node->set_abstract(abs);
66 }
67 } // namespace
68
69 mindspore::HashMap<int64_t, int64_t> send_tag_map;
70 mindspore::HashMap<int64_t, int64_t> recv_tag_map;
71 const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimCast};
72
bool IsInWhiteList(const CNodePtr &cnode) {
74 for (auto prim = WHITE_LIST.cbegin(); prim != WHITE_LIST.cend(); ++prim) {
75 if (IsPrimitiveCNode(cnode, *prim)) {
76 return true;
77 }
78 }
79 return false;
80 }
81
static AbstractBasePtr GetRealAbstract(const AnfNodePtr &node) {
83 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
84 auto &input = node->cast<CNodePtr>()->input(1);
85 MS_EXCEPTION_IF_NULL(input);
86 return input->abstract();
87 }
88 return node->abstract();
89 }
90
FuncGraphPtr FindNodeGraph(const CNodePtr &cnode) {
92 auto graph = cnode->func_graph();
93 if (IsValueNode<FuncGraph>(cnode->input(0))) {
94 graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
95 }
96 return graph;
97 }
98
void PipelineTransformer::UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communcate_op,
                                                    bool is_send) {
101 MS_EXCEPTION_IF_NULL(node);
102 MS_EXCEPTION_IF_NULL(communcate_op);
103
104 if (!node->isa<Parameter>()) {
105 return;
106 }
107 auto root_param = node;
108 if (node->func_graph() != root_) {
109 root_param = GetArgumentsByParameter(node);
110 MS_EXCEPTION_IF_NULL(root_param);
111 }
112
113 // get communication info from cnode.
114 auto prim = GetCNodePrimitive(communcate_op);
115 MS_EXCEPTION_IF_NULL(prim);
116
117 auto sr_tag_attr = prim->GetAttr(SR_TAG);
118 MS_EXCEPTION_IF_NULL(sr_tag_attr);
119 auto sr_tag = GetValue<int64_t>(sr_tag_attr);
120 auto peer_rank_attr = is_send ? prim->GetAttr(DEST_RANK) : prim->GetAttr(SRC_RANK);
121 MS_EXCEPTION_IF_NULL(peer_rank_attr);
122 auto peer_rank = GetValue<int64_t>(peer_rank_attr);
123 auto group_attr = prim->GetAttr(GROUP);
124 MS_EXCEPTION_IF_NULL(group_attr);
125 auto group = GetValue<std::string>(group_attr);
126
127 // Use global rank since local group may not exist after loading checkpoint.
128 auto rank_list = g_device_manager->FindRankListByHashName(group);
129 peer_rank = rank_list.at(peer_rank);
130
131 // update tensor layout.
132 auto param = root_param->cast<ParameterPtr>();
133 MS_EXCEPTION_IF_NULL(param);
134 auto shared_parameters = std::make_shared<SharedParameter>(true, is_send, peer_rank, sr_tag);
135 param->set_user_data<SharedParameter>(shared_parameters);
136 }
137
TensorInfo PipelineTransformer::GetTensorInfo(const std::pair<OperatorInfoPtr, int> &op_info_pair, bool is_param) {
139 if (is_param) {
140 auto inputs_tensor_info = op_info_pair.first->inputs_tensor_info();
141 return inputs_tensor_info.at(IntToSize(op_info_pair.second));
142 } else {
143 auto outputs_tensor_info = op_info_pair.first->outputs_tensor_info();
144 return outputs_tensor_info.at(IntToSize(op_info_pair.second));
145 }
146 }
147
static void SeparateParamBorder(const std::vector<AnfNodePtr> &nodes, bool send, std::vector<AnfNodePtr> *const params,
                                std::vector<AnfNodePtr> *const borders) {
150 std::vector<AnfNodePtr> real_comm_ops;
151 if (send) {
152 (void)std::transform(nodes.begin(), nodes.end(), std::back_inserter(real_comm_ops), [](const AnfNodePtr &n) {
153 const auto &cnode = n->cast<CNodePtr>();
154 MS_EXCEPTION_IF_NULL(cnode);
155 if (cnode->inputs().size() <= INDEX_TWO) {
156 return cnode;
157 }
158 const auto &real = cnode->input(INDEX_TWO)->cast<CNodePtr>();
159 MS_EXCEPTION_IF_NULL(real);
160 return real;
161 });
162 } else {
163 real_comm_ops = nodes;
164 }
165 for (auto &node : real_comm_ops) {
166 const auto &cnode = node->cast<CNodePtr>();
167 MS_EXCEPTION_IF_NULL(cnode);
168 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
169 (*params).push_back(node);
170 } else {
171 (*borders).push_back(node);
172 }
173 }
174 }
175
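// Locate the main graph by searching for the VirtualDataset node, then detect a cell-reuse (lazy_inline)
// sub-graph; if one is found, record its users in the main graph and enable shared-cell mode.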
bool PipelineTransformer::MainGraph() {
177 bool find_main_graph = false;
178 for (auto &fg : manager_->func_graphs()) {
179 for (auto &node : fg->nodes()) {
180 if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
181 main_graph_ = fg;
182 main_graph_->set_flag(MAIN_GRAPH, true);
183 virtual_dataset_ = node;
184 find_main_graph = true;
185 break;
186 }
187 }
188 if (find_main_graph) {
189 break;
190 }
191 }
192 if (!find_main_graph) {
    MS_LOG(WARNING) << "Can't find the main graph; a possible reason is that the virtual dataset can't be found.";
194 return false;
195 }
196 for (auto &fg : manager_->func_graphs()) {
197 if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
198 shared_cell_ = fg;
199 break;
200 }
201 }
202 if (!shared_cell_) {
203 return true;
204 }
205 auto value_nodes = main_graph_->value_nodes();
206 mindspore::CompactSet<AnfNodePtr> shared_cell_nodes;
207 for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
208 auto node = (*value_pair).first;
209 if (!IsValueNode<FuncGraph>(node)) {
210 continue;
211 }
212 auto graph = GetValueNode<FuncGraphPtr>(node);
213 MS_EXCEPTION_IF_NULL(graph);
214 if (graph == shared_cell_) {
215 (void)(shared_cell_nodes.insert(node));
216 }
217 }
218 if (shared_cell_nodes.empty()) {
219 return true;
220 }
221 for (auto node : shared_cell_nodes) {
222 auto node_users = manager_->node_users()[node];
223 for (auto &node_user : node_users) {
224 auto user = node_user.first;
225 if (user->func_graph() == main_graph_) {
226 if (std::find(shared_cell_users_.begin(), shared_cell_users_.end(), user) == shared_cell_users_.end()) {
227 shared_cell_users_.push_back(user);
228 }
229 }
230 }
231 }
232 MS_LOG(INFO) << "Enable micro-fold, the folded cell is " << shared_cell_->ToString();
233 enable_share_cell_ = true;
234 return true;
235 }
236
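// Derive the micro-batch index of a StridedSlice node: from its constant begin offset for static shapes,
// or from the ScalarMul/ScalarFloorDiv pattern for dynamic shapes, then tag MICRO/PIPELINE_BEGIN/SEGMENT.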
ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const {
238 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
239 MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
240 }
241 auto cnode = node->cast<CNodePtr>();
242
243 int64_t micro = 0;
244 auto value = GetValueNode(cnode->input(2));
245 if (value != nullptr) {
246 auto tuple = GetValue<std::vector<int64_t>>(value); // begin
247 auto input_tmp = GetNodeShape(cnode->input(1));
248 auto input_shape = input_tmp.at(0);
    auto slice_batch_size = input_shape.at(batch_axis);  // batch dimension of the sliced input
250 if (slice_batch_size == 0) {
251 MS_LOG(EXCEPTION) << "slice_batch_size should be a positive integer, but got " << slice_batch_size;
252 }
253 micro = tuple.at(batch_axis) * micro_size / slice_batch_size; // micro-index
254 } else {
255 // dynamic shape
256 // if micro is not 1: stridedslice --> maketuple --> scalarmul --> micro
257 // if micro is 1: stridedslice --> maketuple --> scalarfloordiv
258 if (!IsPrimitiveCNode(cnode->input(2), prim::kPrimMakeTuple)) {
      MS_LOG(EXCEPTION) << "The begin input of StridedSlice is neither a constant value nor a MakeTuple.";
260 }
261 auto make_tuple_cnode = cnode->input(2)->cast<CNodePtr>();
262
263 if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarMul)) {
264 auto scalar_mul_cnode = make_tuple_cnode->input(1)->cast<CNodePtr>();
265 auto mul_value = GetValueNode(scalar_mul_cnode->input(2));
266 micro = GetValue<int64_t>(mul_value);
267 } else if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarFloorDiv)) {
268 micro = 1;
269 } else {
      MS_LOG(EXCEPTION) << "Cannot find the micro info; the input op of MakeTuple is "
                        << GetCNodePrimitive(make_tuple_cnode->input(1))->name();
272 }
273 }
274
275 cnode->AddPrimalAttr(MICRO, MakeValue(micro));
276 cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
277 int64_t seg = 0;
278 cnode->AddPrimalAttr(SEGMENT, MakeValue(seg));
279 return MakeValue(micro);
280 }
281
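// Map a sub-graph parameter back to the corresponding argument of the root graph by walking the call sites.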
AnfNodePtr PipelineTransformer::GetArgumentsByParameter(const AnfNodePtr &parameter) {
283 auto fg = parameter->func_graph();
284 if (fg == root_) {
285 return parameter;
286 }
287 auto parameters = fg->parameters();
288 auto iter = std::find(parameters.begin(), parameters.end(), parameter);
289 if (iter != parameters.end()) {
290 auto pos = std::distance(parameters.begin(), iter);
291 auto fg_used_map = fg->func_graph_cnodes_index();
292 for (auto &cur_fg_use : fg_used_map) {
293 if (cur_fg_use.first->second != 0) {
294 continue;
295 }
296 auto cur_fg = cur_fg_use.first->first->cast<CNodePtr>();
297 auto argument = cur_fg->input(pos + 1);
298 if (argument->isa<Parameter>()) {
299 return GetArgumentsByParameter(argument);
300 }
301 }
302 }
303 return nullptr;
304 }
305
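// Return true if any input of the cnode (looking through Load/Cast/Depend) is a parameter that requires grad.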
bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
307 for (auto &input : cnode->inputs()) {
308 auto temp = input;
309 while (IsPrimitiveCNode(temp, prim::kPrimLoad) || IsPrimitiveCNode(temp, prim::kPrimCast) ||
310 IsPrimitiveCNode(temp, prim::kPrimDepend)) {
311 auto input_cnode = temp->cast<CNodePtr>();
312 MS_EXCEPTION_IF_NULL(input_cnode);
313 temp = input_cnode->input(1);
314 }
315 if (temp->isa<Parameter>()) {
316 auto argument = GetArgumentsByParameter(temp);
317 if (!argument || !GetRealKernelNode(argument, -1, nullptr).first->isa<Parameter>()) {
318 continue;
319 }
320 if (ParameterRequireGrad(argument)) {
321 return true;
322 }
323 }
324 }
325 return false;
326 }
327
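// Mark the first stage-0 pipeline-care node that consumes trainable parameters with the PARAMETER_START
// attribute (or PARAMETER_START_SHARE_CELL in shared-cell mode).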
bool PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph) {
329 auto orders = graph->GetOrderedCnodes();
330 for (auto node = orders.cbegin(); node != orders.cend(); ++node) {
331 auto cnode = (*node)->cast<CNodePtr>();
332 MS_EXCEPTION_IF_NULL(cnode);
333 auto stage_info = cnode->user_data<NodeStageInfo>();
334 if (stage_info == nullptr || stage_info->stage() != 0) {
335 continue;
336 }
337 if (IsValueNode<FuncGraph>(cnode->input(0))) {
338 auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
339 if (LabelParameterStart(sub_graph)) {
340 return true;
341 } else {
342 continue;
343 }
344 }
345 if (!IsPipelineCareNode(cnode)) {
346 continue;
347 }
348 if (NeedGrad(cnode)) {
349 auto prim = GetCNodePrimitive(cnode);
350 if (enable_share_cell_) {
351 (void)prim->AddAttr(PARAMETER_START_SHARE_CELL, MakeValue(0));
352 } else {
353 (void)prim->AddAttr(PARAMETER_START, MakeValue(0));
354 }
355 return true;
356 }
357 }
358 return false;
359 }
360
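// Infer the batch axis by comparing the StridedSlice begin offsets of all micro-batches; for training,
// exactly one axis is expected to differ.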
size_t PipelineTransformer::GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const {
362 Shapes inputs_tuple;
363 for (const auto &input_node_user : input_node_users) {
364 auto node = input_node_user.first;
365 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
366 return 0; // simply return 0 when dynamic shape
367 }
368 auto cnode = node->cast<CNodePtr>();
369 auto value = GetValueNode(cnode->input(2));
370 if (value == nullptr) {
371 return 0; // simply return 0 when dynamic shape
372 }
373 auto tuple = GetValue<std::vector<int64_t>>(value);
374 inputs_tuple.push_back(tuple);
375 }
376 size_t batch_axis = 0;
377 size_t batch_axis_count = 0;
378 size_t input_dim = inputs_tuple.at(0).size();
379 size_t micro_num = inputs_tuple.size();
380 for (size_t axis = 0; axis < input_dim; ++axis) {
381 for (size_t i = 1; i < micro_num; ++i) {
382 if (inputs_tuple[i][axis] != inputs_tuple[i - 1][axis]) {
383 batch_axis = axis;
384 ++batch_axis_count;
385 break;
386 }
387 }
388 }
389 if (is_train_ && batch_axis_count != kSizeOne) {
    MS_LOG(EXCEPTION) << "For pipeline parallelism, the input must be partitioned into micro_size slices along "
                      << "exactly one dimension, but " << batch_axis_count << " partitioned dimensions were found.";
393 }
394 return batch_axis;
395 }
396
size_t MicroSize(const AnfNodeIndexSet &input_node_users) {
398 size_t micro_size = 0;
399 for (const auto &input_node_user : input_node_users) {
400 auto node = input_node_user.first;
401 if (IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
402 micro_size++;
403 }
404 }
405
406 return micro_size;
407 }
408
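// Label every micro-batch StridedSlice under the VirtualDataset with its micro index and broadcast the label
// to downstream nodes.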
void PipelineTransformer::LabelMicroBatch() {
410 auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
411 MS_EXCEPTION_IF_NULL(graph);
412 if (!LabelParameterStart(graph)) {
    MS_LOG(EXCEPTION) << "Stage 0 should have at least 1 parameter, but got none. "
                      << "One possible cause is that the @lazy_inline decorator is misplaced.";
415 }
416 MS_EXCEPTION_IF_NULL(virtual_dataset_);
417 auto node_user_map = manager_->node_users();
418 auto node_users = node_user_map[virtual_dataset_];
419 auto stage_num = g_device_manager->stage_num();
420 for (auto &node_user : node_users) {
421 if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
422 auto data_users = manager_->node_users()[node_user.first];
423 auto node_first = data_users.front().first;
424 if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice) && !IsPrimitiveCNode(node_first, prim::kPrimShape)) {
425 data_users.clear();
426 data_users = node_user_map[node_first];
427 }
428 auto micro_size = int64_t(MicroSize(data_users));
429 if (is_train_ && micro_size < stage_num) {
430 MS_LOG(EXCEPTION) << "The size of micro_batch must be greater than or equal to stage_num. But got the size of "
431 << "micro_batch is " << micro_size << " and the stage_num is " << stage_num;
432 }
433 micro_size_ = micro_size;
434 auto batch_axis = GetBatchAxisForInput(data_users);
435 MS_LOG(INFO) << "For the "
436 << GetSerialNumberString(
437 GetValue<int64_t>(GetValueNode(node_user.first->cast<CNodePtr>()->input(kIndex2))))
438 << "input, batch axis is " << batch_axis << ", micro size is : " << micro_size;
439 for (auto &data_user : data_users) {
440 if (!IsPrimitiveCNode(data_user.first, prim::kPrimStridedSlice)) {
441 continue;
442 }
443 auto micro = SetMicroBatch(data_user.first, micro_size, batch_axis);
444 SetStridedSliceStrategy(data_user.first);
445 auto cnode = data_user.first->cast<CNodePtr>();
446 BroadCastMicroBatch(cnode, &node_user_map, micro, 0);
447 }
448 }
449 }
450 }
451
void PipelineTransformer::LabelGenMaskFusion() {
453 auto fgs = manager_->func_graphs();
454 int64_t fusion_id = 0;
455 for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
456 if (*fg == root_ || *fg == main_graph_) {
457 continue;
458 }
459 auto stage = (*fg)->stage();
460 if (stage != -1 && stage != stage_) {
461 continue;
462 }
463 auto nodes = (*fg)->nodes();
464 for (auto node = nodes.cbegin(); node != nodes.cend(); ++node) {
465 if (!IsPrimitiveCNode(*node, prim::kPrimDropoutGenMask) && !IsPrimitiveCNode(*node, prim::kPrimDropoutDoMaskV3) &&
466 !IsPrimitiveCNode(*node, prim::kPrimDropout)) {
467 continue;
468 }
469 auto cnode = (*node)->cast<CNodePtr>();
470 MS_EXCEPTION_IF_NULL(cnode);
471 cnode->AddPrimalAttr(kAttrFusion, MakeValue(fusion_id));
472 fusion_id += 1;
473 }
474 }
475 }
476
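// Propagate stage labels from sub-graph value nodes to their users until a fixed point is reached,
// and check that every pipeline stage is actually used.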
void PipelineTransformer::Coloring() {
478 auto need_coloring = true;
479 std::set<int64_t> stage_set;
480 if (!IsTraining(manager_)) {
481 is_train_ = false;
482 }
483 while (need_coloring) {
484 need_coloring = false;
485 for (auto &fg : manager_->func_graphs()) {
486 if (fg == root_ && is_train_) {
487 continue;
488 }
489 auto value_nodes = fg->value_nodes();
490 for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
491 auto node = (*value_pair).first;
492 if (!IsValueNode<FuncGraph>(node)) {
493 continue;
494 }
495 auto graph = GetValueNode<FuncGraphPtr>(node);
496 if (graph->stage() == -1) {
497 continue;
498 }
499 (void)stage_set.insert(graph->stage());
500 auto node_users = manager_->node_users()[node];
501 for (auto &user_pair : node_users) {
502 auto user_node = user_pair.first->cast<CNodePtr>();
503 user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
504 auto user_node_graph = user_node->func_graph();
505 if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
506 user_node_graph->set_stage(graph->stage());
507 need_coloring = true;
508 }
509 }
510 }
511 }
512 }
513 MS_EXCEPTION_IF_NULL(g_device_manager);
514 auto stage_num = g_device_manager->stage_num();
515 if (SizeToLong(stage_set.size()) != stage_num) {
    MS_LOG(EXCEPTION) << "Stage num " << stage_num
                      << " is not equal to the number of stages used: " << stage_set.size();
517 }
518 }
519
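// Broadcast stage information along data-flow edges inside the main (or shared-cell) graph, raising a user's
// stage to at least that of its producer.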
void PipelineTransformer::BroadCastColoring() {
521 auto need_coloring = true;
522 while (need_coloring) {
523 need_coloring = false;
524 auto all_nodes = enable_share_cell_ ? shared_cell_->nodes() : main_graph_->nodes();
525 auto node_users = manager_->node_users();
526 for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) {
527 auto stage_info = (*node)->user_data<NodeStageInfo>();
528 if (!(*node)->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
529 IsPrimitiveCNode(*node, prim::kPrimUpdateState)) {
530 continue;
531 }
532 auto stage = stage_info->stage();
533 for (auto &user_pair : node_users[*node]) {
534 auto user_node = user_pair.first->cast<CNodePtr>();
535 auto user_stage_info = user_node->user_data<NodeStageInfo>();
536 if (user_stage_info == nullptr) {
537 user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
538 need_coloring = true;
539 continue;
540 }
541 auto user_node_stage = user_stage_info->stage();
542 if (stage > user_node_stage) {
543 if (IsValueNode<FuncGraph>(user_node->input(0))) {
544 MS_LOG(EXCEPTION) << "The stage setting is incorrect. PreNode's stage:" << stage
545 << " is larger than NextNode's stage:" << user_node_stage;
546 }
547 user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
548 need_coloring = true;
549 }
550 }
551 }
552 }
553 for (auto &fg : manager_->func_graphs()) {
554 auto stage = fg->stage();
555 if (stage < 0) {
556 continue;
557 }
558 if (fg == root_ || fg == main_graph_ || fg == shared_cell_) {
559 continue;
560 }
561 auto all_nodes = fg->nodes();
562 for (auto node : all_nodes) {
563 if (node->user_data<NodeStageInfo>() != nullptr) {
564 continue;
565 }
566 node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
567 }
568 }
569 }
570
std::vector<AnfNodePtr> PipelineTransformer::GetLoadNodeByParam(const AnfNodePtr &param) const {
572 std::vector<AnfNodePtr> load_vec = {param};
573 auto node_users = manager_->node_users()[param];
  for (auto &param_user : node_users) {
575 if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) {
576 auto graph = param_user.first->func_graph();
577 // exclude opt graphs
578 if (graph == root_ || (graph->stage() == -1 && graph != main_graph_)) {
579 continue;
580 }
581 (void)load_vec.emplace_back(param_user.first);
582 }
583 }
584 return load_vec;
585 }
586
bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) const {
588 MS_EXCEPTION_IF_NULL(cnode);
589 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
590 if (!prim) {
591 return false;
592 }
593 if (IsInWhiteList(cnode)) {
594 return false;
595 }
596 if (!IsParallelConsiderCNode(cnode)) {
    MS_LOG(INFO) << "PipelineSplit doesn't care about node: " << prim->name();
598 return false;
599 }
600 return true;
601 }
602
CNodePtr PipelineTransformer::GraphOutNode(const AnfNodePtr &node, int tuple_index) {
604 auto cnode = node->cast<CNodePtr>();
605 MS_EXCEPTION_IF_NULL(cnode);
606 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
607 return GraphOutNode(cnode->input(1), tuple_index);
608 }
609 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
610 return cnode->input(IntToSize(tuple_index) + 1)->cast<CNodePtr>();
611 }
612 return cnode;
613 }
614
OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode, int tuple_index = 0) {
616 MS_EXCEPTION_IF_NULL(cnode);
617 auto temp_node = cnode;
618 if (IsValueNode<FuncGraph>(cnode->input(0))) {
619 auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
620 MS_EXCEPTION_IF_NULL(output);
621 temp_node = GraphOutNode(output, tuple_index);
622 }
623 if (!IsPipelineCareNode(temp_node)) {
624 MS_LOG(EXCEPTION) << "Node: " << temp_node->DebugString() << " is not a Pipeline Care Node.";
625 }
626 if (IsPrimitiveCNode(temp_node, prim::kPrimVirtualDataset)) {
627 SetVirtualDatasetStrategy(temp_node);
628 }
629
630 auto prim = GetValueNode<PrimitivePtr>(temp_node->input(0));
631 MS_EXCEPTION_IF_NULL(prim);
632 if (prim->name() == RESHAPE) {
633 MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << temp_node->DebugString();
634 }
635 auto attrs = prim->attrs();
636 auto op_info = CreateOperatorInfo(temp_node);
637
638 StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
639 if (!StrategyFound(attrs)) {
640 in_strategy = GenerateBatchParallelStrategy(op_info, prim);
641 } else {
642 in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
643 out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
644 }
645 MS_EXCEPTION_IF_NULL(in_strategy);
646 if (op_info->Init(in_strategy, out_strategy) == FAILED) {
647 MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed.";
648 }
649 return op_info;
650 }
651
std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
653 MS_EXCEPTION_IF_NULL(node);
654 auto cnode = node->cast<CNodePtr>();
655 MS_EXCEPTION_IF_NULL(cnode);
656 // Handle Cast and TupleGetitem situation
657 int tensor_info_index = 0;
658 OperatorInfoPtr op_info;
659 if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
660 op_info = node->user_data<OperatorInfo>();
661 } else {
662 if (IsPrimitiveCNode(node, prim::kPrimCast)) {
663 cnode = cnode->input(1)->cast<CNodePtr>();
664 } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
665 tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
666 cnode = cnode->input(1)->cast<CNodePtr>();
667 }
668 // Create OperatorInfo to get slice_shape for send/recv
669 MS_EXCEPTION_IF_NULL(cnode);
670 if (cnode->has_user_data<OperatorInfo>()) {
671 op_info = cnode->user_data<OperatorInfo>();
672 } else {
673 op_info = CreateOpInfo(cnode, tensor_info_index);
674 }
675 }
676 return std::make_pair(op_info, tensor_info_index);
677 }
678
AnfNodeIndexSet GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map) {
680 AnfNodeIndexSet users;
681 auto user_pairs = (*node_users_map)[node];
682 for (const auto &user_pair : user_pairs) {
683 const auto user = user_pair.first;
684 const auto &cuser = user->cast<CNodePtr>();
685 MS_EXCEPTION_IF_NULL(cuser);
686 const auto &input = cuser->input(0);
687 MS_EXCEPTION_IF_NULL(input);
688 AnfNodePtr temp_node = nullptr;
689 if (IsValueNode<FuncGraph>(input)) {
690 auto graph = GetValueNode<FuncGraphPtr>(input);
691 MS_EXCEPTION_IF_NULL(graph);
692 auto temp_params = graph->parameters();
693 auto index = user_pair.second;
694 if (temp_params.size() < IntToSize(index)) {
        MS_LOG(EXCEPTION) << "The input index " << index << " of node: " << cuser->DebugString()
                          << " is out of range of graph: " << graph->ToString() << "'s parameters.";
697 }
698 temp_node = temp_params[IntToSize(index - 1)];
699 } else if (IsPrimitiveCNode(cuser, prim::kPrimLoad) || IsPrimitiveCNode(cuser, prim::kPrimCast) ||
700 IsPrimitiveCNode(cuser, prim::kPrimMirrorSilentCheck)) {
701 temp_node = cuser;
702 }
703 if (temp_node) {
704 const auto &temp_users = GetActualOpUsers(temp_node, node_users_map);
705 (void)(users.insert(temp_users.begin(), temp_users.end()));
706 } else {
707 (void)(users.insert(user_pair));
708 }
709 }
710 return users;
711 }
712
std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
714 MS_EXCEPTION_IF_NULL(node);
715 auto node_users_map = manager_->node_users();
716 const auto &node_users = GetActualOpUsers(node, &node_users_map);
717 for (auto &node_user : node_users) {
718 auto user = node_user.first->cast<CNodePtr>();
719 MS_EXCEPTION_IF_NULL(user);
720 auto user_graph = user->func_graph();
721 MS_EXCEPTION_IF_NULL(user_graph);
722 if (user_graph->stage() == -1) {
723 continue;
724 }
725 auto index = node_user.second;
726 if (!IsPipelineCareNode(user)) {
727 continue;
728 }
729 OperatorInfoPtr op_info;
730 if (user->has_user_data<OperatorInfo>()) {
731 op_info = user->user_data<OperatorInfo>();
732 } else {
733 op_info = CreateOpInfo(user);
734 }
735 return std::make_pair(op_info, index - 1);
736 }
737 return std::make_pair(nullptr, 0);
738 }
739
AnfNodeIndexSet PipelineTransformer::GetParameterLoadUsers(const AnfNodePtr &node,
                                                           const NodeUsersMap &node_users_map) const {
742 AnfNodeIndexSet users;
743 if (node_users_map.find(node) == node_users_map.end()) {
744 return users;
745 }
746 auto loads = GetLoadNodeByParam(node);
747 for (auto &load : loads) {
748 auto iter = node_users_map.find(load);
749 if (iter == node_users_map.end()) {
750 continue;
751 }
752 const auto &temp_users = iter->second;
753 for (const auto &user : temp_users) {
754 auto cuser = user.first->cast<CNodePtr>();
755 MS_EXCEPTION_IF_NULL(cuser);
756 const auto &input = cuser->input(0);
757 MS_EXCEPTION_IF_NULL(input);
758 if (enable_share_cell_ && IsValueNode<FuncGraph>(input) && GetValueNode<FuncGraphPtr>(input) == shared_cell_) {
759 auto index = user.second;
760 auto pos = index - 1;
761 const auto &share_cell_params = shared_cell_->parameters();
        const auto &param = share_cell_params.at(pos);
        const auto &param_iter = node_users_map.find(param);
        if (param_iter == node_users_map.end()) {
          continue;
        }
        const auto &param_users = param_iter->second;
768 users.insert(param_users.begin(), param_users.end());
769 } else {
770 users.insert(user);
771 }
772 }
773 }
774 return users;
775 }
776
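// For parameters shared by multiple stages, insert Send on the stage that owns the parameter and Receive on
// the other stages, reusing existing communication ops where possible.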
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::HandleSharedParameter() {
778 auto parameters = root_->parameters();
779 std::vector<AnfNodePtr> sends = {};
780 std::vector<AnfNodePtr> recvs = {};
  for (auto &parameter : parameters) {
782 auto parameter_stage = parameter_color_map_[parameter];
783 if (parameter_stage.size() <= 1) {
784 continue;
785 }
786 const auto &node_users_map = manager_->node_users();
787 auto users = GetParameterLoadUsers(parameter, node_users_map);
788 for (auto &user : users) {
789 if (!is_train_ && !enable_share_cell_) {
790 continue;
791 }
792 auto node = user.first;
793 auto cnode = node->cast<CNodePtr>();
794 auto graph = FindNodeGraph(cnode);
795 if (graph == root_ || graph->stage() == -1 || parameter_stage.count(stage_) == 0) {
796 continue;
797 }
798 auto micro = cnode->GetPrimalAttr(MICRO);
799 if (!micro) {
800 MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
801 micro = MakeValue(int64_t(0));
802 }
803 if (stage_ == *(parameter_stage.begin())) {
804 auto user_stage = graph->stage();
805 auto stage_info = node->user_data<NodeStageInfo>();
806 if (stage_info) {
807 user_stage = stage_info->stage();
808 }
809 if (graph->stage() == stage_ || user_stage == -1) {
810 continue;
811 }
812 if (Reuse(parameter, user_stage, sends, DEST_RANK)) {
813 continue;
814 }
815 auto send_out = InsertSend(parameter, user_stage, stage_, micro);
816 sends.push_back(send_out.depend);
817 } else {
818 auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
819 if (receive) {
820 manager_->SetEdge(node, user.second, receive);
821 } else {
822 AnfNodePtr recv;
823 auto fg = enable_share_cell_ ? shared_cell_ : main_graph_;
824 recv = InsertReceive(fg, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, parameter);
825 (void)(recvs.push_back(recv));
826 }
827 }
828 }
829 }
830 return std::make_pair(sends, recvs);
831 }
832
void PipelineTransformer::FillParameterStage(const CNodePtr &node, std::set<int64_t> *const parameter_stage) {
834 auto stage_info = node->user_data<NodeStageInfo>();
835 if (stage_info != nullptr && stage_info->stage() != -1) {
836 (void)(parameter_stage->insert(stage_info->stage()));
837 } else {
838 auto graph = node->func_graph();
839 MS_EXCEPTION_IF_NULL(graph);
840 if (graph != root_ && graph != main_graph_ && graph != shared_cell_ && graph->stage() != -1) {
841 (void)(parameter_stage->insert(graph->stage()));
842 }
843 }
844 }
845
bool PipelineTransformer::GetStageByArgument(const CNodePtr &node, size_t index,
                                             const std::vector<AnfNodePtr> &parameters,
                                             const NodeUsersMap &node_users_map,
                                             std::set<int64_t> *const parameter_stage) {
850 if (!enable_share_cell_) {
851 return false;
852 }
853 if (index < 1) {
854 return false;
855 }
856 const auto &input = node->input(0);
857 if (!IsValueNode<FuncGraph>(input)) {
858 FillParameterStage(node, parameter_stage);
859 return true;
860 }
861 if (GetValueNode<FuncGraphPtr>(input) != shared_cell_) {
862 return false;
863 }
864 auto pos = index - 1;
  const auto &param = parameters.at(pos);
866 MS_EXCEPTION_IF_NULL(param);
867 auto loads = GetLoadNodeByParam(param);
868 for (auto &load : loads) {
869 const auto &iter = node_users_map.find(load);
870 if (iter == node_users_map.end()) {
871 continue;
872 }
873 const auto &users = (*iter).second;
874 for (auto &user : users) {
875 auto user_cnode = user.first->cast<CNodePtr>();
876 MS_EXCEPTION_IF_NULL(user_cnode);
877 FillParameterStage(user_cnode, parameter_stage);
878 }
879 }
880 return true;
881 }
882
void PipelineTransformer::ParameterColoring() {
884 auto parameters = root_->parameters();
885 auto &node_users_map = manager_->node_users();
886 const auto &share_cell_parameters = shared_cell_->parameters();
  for (auto &parameter : parameters) {
888 auto loads = GetLoadNodeByParam(parameter);
889 std::set<int64_t> parameter_stage;
890 for (auto &load : loads) {
891 auto load_users = node_users_map[load];
892 for (auto &load_user : load_users) {
893 auto user_cnode = load_user.first->cast<CNodePtr>();
894 MS_EXCEPTION_IF_NULL(user_cnode);
        if (GetStageByArgument(user_cnode, load_user.second, share_cell_parameters, node_users_map,
                               &parameter_stage)) {
          continue;
        }
        FillParameterStage(user_cnode, &parameter_stage);
899 }
900 }
901 auto param_info = parameter->cast<ParameterPtr>()->param_info();
902 if (!param_info) {
903 parameter_color_map_[parameter] = parameter_stage;
904 continue;
905 }
906 MS_EXCEPTION_IF_NULL(param_info);
907 auto requires_grad = param_info->requires_grad();
908 if (!parameter_stage.empty() && *parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) {
909 virtual_param_ = parameter;
910 }
911 parameter_color_map_[parameter] = parameter_stage;
912 }
913 }
914
void PipelineTransformer::RemoveMonadNode() {
916 auto all_nodes = DeepScopedGraphSearch(main_graph_->get_return());
917 auto node_users_map = manager_->node_users();
918 for (auto &node : all_nodes) {
919 if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
920 continue;
921 }
922 auto cnode = node->cast<CNodePtr>();
923 MS_EXCEPTION_IF_NULL(cnode);
924 auto abs = cnode->abstract();
925 MS_EXCEPTION_IF_NULL(abs);
926 auto stage_info = cnode->user_data<NodeStageInfo>();
927 if (stage_info == nullptr) {
928 continue;
929 }
930 auto stage = stage_info->stage();
931 if (stage != stage_ && stage != -1) {
932 auto node_users = node_users_map[node];
933 for (auto &user_node : node_users) {
934 auto monad_node = NewValueNode(kUMonad);
935 if (abs->isa<abstract::AbstractIOMonad>()) {
936 monad_node = NewValueNode(kIOMonad);
937 }
938 manager_->SetEdge(user_node.first, user_node.second, monad_node);
939 }
940 }
941 }
942 }
943
static ValueListPtr GetShapeValue(const Shape &shape) {
945 std::vector<ValuePtr> element;
946 (void)std::transform(shape.begin(), shape.end(), std::back_inserter(element),
947 [](int elem) { return MakeValue(elem); });
948 return std::make_shared<ValueList>(element);
949 }
950
std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape, size_t index) {
952 TypePtr type;
953 auto cnode = node->cast<CNodePtr>();
954 if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
955 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
956 auto graph_output = graph->output();
957 type = graph_output->Type();
958 } else {
959 if (node->isa<CNode>() && IsPrimitiveCNode(node->cast<CNodePtr>(), prim::kPrimDepend)) {
960 type = cnode->input(1)->Type();
961 } else {
962 type = node->Type();
963 }
964 }
965 MS_EXCEPTION_IF_NULL(type);
966
967 TensorTypePtr tensor_type;
968 if (type->isa<mindspore::TensorType>()) {
969 tensor_type = type->cast<mindspore::TensorTypePtr>();
970 } else if (type->isa<Tuple>()) {
971 auto tuple_type = type->cast<TuplePtr>();
972 MS_EXCEPTION_IF_NULL(tuple_type);
973 tensor_type = tuple_type->elements().at(index)->cast<TensorTypePtr>();
974 }
975 MS_EXCEPTION_IF_NULL(tensor_type);
976 auto dtype = tensor_type->element();
977 MS_EXCEPTION_IF_NULL(dtype);
978 auto shape_list = GetShapeValue(shape);
979 return std::make_pair(shape_list, dtype);
980 }
981
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) const {
983 MS_EXCEPTION_IF_NULL(node);
984 auto real_node = GetRealKernelNode(node, -1).first;
985 if (!real_node->isa<CNode>()) {
986 return real_node;
987 }
988 auto cnode = real_node->cast<CNodePtr>();
989 MS_EXCEPTION_IF_NULL(cnode);
990 if (IsInWhiteList(cnode)) {
991 return cnode->cast<AnfNodePtr>();
992 }
993 if (!IsPipelineCareNode(cnode)) {
994 MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."
995 << " border node: " << cnode->DebugString();
996 }
997 return cnode->cast<AnfNodePtr>();
998 }
999
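// Create a Send node carrying sr_tag/dest_rank/group attributes and the slice shape/dtype of the data,
// and tag it with micro-batch and pipeline primal attributes.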
SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage,
                                         const ValuePtr &value) {
1002 auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
1003 int64_t send_tag = send_tag_map[dest_rank];
1004 send_tag_map[dest_rank]++;
1005 Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
1006 Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(dest_rank));
1007 Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1008 Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1009 OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
1010 AnfNodePtr care_node;
1011 bool is_param = true;
1012 auto op_info_pair = GetOpInfoPair(parameter, parameter, &care_node, &is_param);
1013 MS_EXCEPTION_IF_NULL(op_info_pair.first);
1014 auto tensor_info = GetTensorInfo(op_info_pair, is_param);
1015 auto index = op_info_pair.second;
1016 auto op_info = op_info_pair.first;
1017 auto slice_shape = tensor_info.slice_shape();
1018 auto shape_type_pair = GetShapeType(parameter, slice_shape, 0);
1019 auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
1020 CNodePtr send = CreateCNodeByInputsAndAttr(graph, SEND, SEND, AnfNodePtrList{parameter}, attrs);
1021 auto prim = GetCNodePrimitive(send);
1022 prim->set_attr(SHAPE, shape_type_pair.first);
1023 prim->set_attr(DTYPE, shape_type_pair.second);
1024
1025 if (!is_param) {
1026 send->AddPrimalAttr(PIPELINE_END, value);
1027 } else {
1028 send->AddPrimalAttr(PIPELINE_PARAM, value);
1029 send->set_user_data<OperatorInfo>(op_info);
1030 send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
1031 auto param = care_node ? care_node : parameter;
1032 send->set_user_data<AnfNode>(INPUT_PARAM, param);
1033 }
1034 send->AddPrimalAttr(MICRO, value);
1035 send->AddPrimalAttr(DEST_RANK, MakeValue(user_node_stage));
1036 auto abstract = parameter->abstract();
1037 if (care_node) {
1038 abstract = care_node->abstract();
1039 }
1040 send->set_abstract(abstract);
1041 SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, send};
1042
1043 // for FetchSends
1044 send->set_user_data<int64_t>(DEST_RANK, std::make_shared<int64_t>(dest_rank));
1045 send->set_user_data<int64_t>(USER_NODE_STAGE, std::make_shared<int64_t>(user_node_stage));
1046 return send_out;
1047 }
1048
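// Create a Receive node with the slice shape/dtype of the peer's data, attach the tensor layout and abstract,
// and redirect the use edge to the new node.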
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const AnfNodePtr &use_node, int index, int64_t user_node_stage,
                                              int64_t node_stage, const ValuePtr &value,
                                              const AnfNodePtr &graph_param) {
1053 auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
1054 int64_t recv_tag = recv_tag_map[src_rank];
1055 recv_tag_map[src_rank]++;
1056 Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
1057 Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(src_rank));
1058 bool is_param = true;
1059 AnfNodePtr care_node;
1060 auto op_info_pair = GetOpInfoPair(node, graph_param, &care_node, &is_param);
1061 auto tensor_info = GetTensorInfo(op_info_pair, is_param);
1062 auto tensor_layout = tensor_info.tensor_layout();
1063 Shape slice_shape = tensor_info.slice_shape();
1064 auto shape_type_pair = GetShapeType(node, slice_shape, 0);
1065 Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
1066 Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
1067 Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1068 Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1069 OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
1070 std::vector<AnfNodePtr> recv_input;
1071 if (node->isa<Parameter>()) {
1072 recv_input = {node};
1073 } else {
1074 recv_input = {virtual_param_};
1075 if (enable_share_cell_ || !is_train_) {
1076 auto recv_tensor = TensorConstructUtils::CreateZerosTensor(kFloat16, {1});
1077 recv_input = {NewValueNode(recv_tensor)};
1078 } else {
1079 if (virtual_param_ == nullptr) {
1080 MS_LOG(EXCEPTION)
1081 << "For Pipeline Parallel, each stage must have at least one parameter that needs to be trained, but stage: "
1082 << stage_ << " has none.";
1083 }
1084 }
1085 }
1086 auto recv = CreateCNodeByInputsAndAttr(graph, RECEIVE, RECEIVE, recv_input, attrs);
1087 if (is_param) {
1088 recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
1089 recv->AddPrimalAttr(PIPELINE_PARAM, value);
1090 auto param = care_node ? care_node : node;
1091 recv->set_user_data<AnfNode>(INPUT_PARAM, param);
1092 } else {
1093 recv->AddPrimalAttr(PIPELINE_BEGIN, value);
1094 }
1095 recv->AddPrimalAttr(MICRO, value);
1096 recv->AddPrimalAttr(SRC_RANK, MakeValue(node_stage));
1097 auto node_abstract = node->abstract();
1098 if (node->isa<CNode>()) {
1099 auto cnode = node->cast<CNodePtr>();
1100 MS_EXCEPTION_IF_NULL(cnode);
1101 if (IsValueNode<FuncGraph>(cnode->input(0))) {
1102 auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
1103 MS_EXCEPTION_IF_NULL(output);
1104 node_abstract = output->abstract();
1105 }
1106 }
1107 MS_EXCEPTION_IF_NULL(node_abstract);
1108 recv->set_abstract(node_abstract);
1109 if (node->isa<Parameter>()) {
1110 BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
1111 auto abstract_clone = node->abstract()->Clone();
1112 MS_EXCEPTION_IF_NULL(abstract_clone);
1113 abstract_clone->set_shape(parallel_shape);
1114 node->set_abstract(abstract_clone);
1115 node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1116 auto actual_param = RefParameterToActualParameter(node);
1117 if (actual_param) {
1118 actual_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1119 auto actual_param_abstract = actual_param->abstract()->Clone();
1120 actual_param_abstract->set_shape(parallel_shape);
1121 actual_param->set_abstract(actual_param_abstract);
1122 }
1123 }
1124 recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1125 recv->set_user_data<OperatorInfo>(op_info_pair.first);
1126
1127 // for FetchRecvs
1128 recv->set_user_data<int64_t>(SRC_RANK, std::make_shared<int64_t>(src_rank));
1129 recv->set_user_data<int64_t>(NODE_STAGE, std::make_shared<int64_t>(node_stage));
1130 recv->set_user_data<Type>(SLICE_DTYPE, shape_type_pair.second);
1131 recv->set_user_data<Shape>(SLICE_SHAPE, std::make_shared<Shape>(slice_shape));
1132
1133 manager_->SetEdge(use_node, index, recv);
1134 return recv;
1135 }
1136
AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
                                      const std::string &tag) const {
1139 for (auto &input : out_input) {
1140 auto cnode = input->cast<CNodePtr>();
1141 if (!cnode) {
1142 continue;
1143 }
1144 if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
1145 cnode = cnode->input(2)->cast<CNodePtr>();
1146 }
1147 if (cnode->input(1) == node) {
1148 auto dest_rank_send = GetValue<int64_t>(cnode->GetPrimalAttr(tag));
1149 if (dest_rank_send == stage) {
1150 return input;
1151 }
1152 }
1153 }
1154 return nullptr;
1155 }
1156
AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) {
  // Skip virtual ops such as Depend, Load and Cast.
1159 if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimCast) ||
1160 IsPrimitiveCNode(node, prim::kPrimLoad)) {
1161 auto cnode = node->cast<CNodePtr>();
1162 MS_EXCEPTION_IF_NULL(cnode);
1163 return ActualOp(cnode->input(1));
1164 }
1165 return node;
1166 }
1167
bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) const {
1169 // ParameterGraph: graph which return a parameter
1170 MS_EXCEPTION_IF_NULL(node);
1171 CNodePtr call_node = nullptr;
1172 auto real_kernel = GetRealKernelNode(node, -1, &call_node).first;
1173 if (call_node != nullptr && real_kernel->isa<Parameter>()) {
1174 return true;
1175 }
1176 return false;
1177 }
1178
AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
                                                     int64_t user_stage, const ValuePtr &micro, size_t pos,
                                                     const std::vector<AnfNodePtr> &ops) {
1182 CNodePtr call_node = nullptr;
1183 auto argument = GetRealKernelNode(node, -1, &call_node).first;
1184
1185 auto use_cnode = use_node->cast<CNodePtr>();
1186 MS_EXCEPTION_IF_NULL(use_cnode);
1187 if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
1188 MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
1189 }
1190 auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
1191 auto use_parameter_list = use_graph->parameters();
1192 auto parameter = use_parameter_list.at(pos - 1);
1193 // insert receive
1194 if (stage_ == user_stage) {
1195 auto recv = Reuse(argument, stage, ops, SRC_RANK);
1196 if (recv) {
1197 manager_->SetEdge(use_node, SizeToInt(pos), recv);
1198 return nullptr;
1199 }
1200 auto root_param = argument;
1201 if (argument->isa<Parameter>() && argument->func_graph() != root_) {
1202 root_param = GetArgumentsByParameter(argument);
1203 }
1204 (void)parameter_color_map_[root_param].insert(user_stage);
1205 auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
1206 auto recv_node = InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
1207 UpdateParameterSharedInfo(root_param, recv_node, false);
1208 return recv_node;
1209 }
1210 // insert send
1211 if (Reuse(argument, user_stage, ops, DEST_RANK)) {
1212 return nullptr;
1213 }
1214 auto send_out = InsertSend(argument, user_stage, stage_, micro);
1215 send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
1216 send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
1217 UpdateParameterSharedInfo(argument, send_out.depend, true);
1218 return send_out.depend;
1219 }
1220
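// For every user of the node that lives in a different stage, insert a Send (when this stage produces the
// data) or a Receive (when this stage consumes it), reusing already inserted border ops when possible.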
void PipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                           std::vector<AnfNodePtr> *send_ops, std::vector<AnfNodePtr> *receive_ops) {
1223 auto stage_info = node->user_data<NodeStageInfo>();
1224 auto node_users = manager_->node_users()[node];
1225 AnfNodePtr receive = nullptr;
1226 for (auto &user_pair : node_users) {
1227 auto user_node = user_pair.first;
1228 auto node_stage = stage_info->stage();
1229 auto user_stage_info = user_node->user_data<NodeStageInfo>();
1230 if (user_stage_info == nullptr) {
1231 continue;
1232 }
1233 auto user_node_stage = user_stage_info->stage();
1234 if (node_stage != stage_ && user_node_stage != stage_) {
1235 continue;
1236 }
1237 auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
1238 if (!micro) {
1239 MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
1240 micro = MakeValue(int64_t(0));
1241 }
1242 if (node_stage < user_node_stage) {
1243 if (node_stage == stage_) {
1244 if (IsParameterGraph(node)) {
1245 if (!is_train_ && !enable_share_cell_) {
1246 continue;
1247 }
1248 auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
1249 IntToSize(user_pair.second), *send_ops);
1250 if (!send_depend) {
1251 continue;
1252 }
1253 (void)send_ops->insert(send_ops->cbegin(), send_depend);
1254 continue;
1255 }
1256 if (Reuse(node, user_node_stage, *send_ops, DEST_RANK)) {
1257 continue;
1258 }
1259 auto send_out = InsertSend(node, user_node_stage, node_stage, micro);
1260 MS_EXCEPTION_IF_NULL(send_out.depend);
1261 send_ops->push_back(send_out.depend);
1262 send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
1263 send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
1264 } else {
1265 if (!receive) {
1266 if (IsParameterGraph(node)) {
1267 if (!is_train_ && !enable_share_cell_) {
1268 continue;
1269 }
1270 receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
1271 IntToSize(user_pair.second), *receive_ops);
1272 if (!receive) {
1273 continue;
1274 }
1275 receive_ops->push_back(receive);
1276 } else {
1277 receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
1278 receive_ops->push_back(receive);
1279 }
1280 } else {
1281 manager_->SetEdge(user_node, user_pair.second, receive);
1282 }
1283 }
1284 continue;
1285 }
1286 if (node_stage > user_node_stage) {
      MS_LOG(EXCEPTION) << "node_stage: " << node_stage
                        << " must not be larger than user_node_stage: " << user_node_stage;
1288 }
1289 }
1290 }
1291
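// Traverse the graph, insert Send/Receive pairs at stage borders via CutBorderForNode, then remove
// UpdateState nodes that belong to other stages.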
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
1293 std::vector<AnfNodePtr> send_ops;
1294 std::vector<AnfNodePtr> receive_ops;
1295 auto ret = graph->get_return();
1296 MS_EXCEPTION_IF_NULL(ret);
1297 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
1298 std::reverse(all_nodes.begin(), all_nodes.end());
1299 for (auto &node : all_nodes) {
1300 auto stage_info = node->user_data<NodeStageInfo>();
1301 if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
1302 IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
1303 continue;
1304 }
1305 // Modify for lizard cyclomatic complexity.
1306 CutBorderForNode(graph, node, &send_ops, &receive_ops);
1307 }
1308 RemoveMonadNode();
1309 return std::make_pair(send_ops, receive_ops);
1310 }
1311
AnfNodePtr PipelineTransformer::CreateZeroseOutput(const AnfNodePtr &node, size_t index) {
1313 auto out_shapes = GetNodeShape(node);
1314 if (out_shapes.size() <= index) {
1315 MS_LOG(EXCEPTION) << "the index is out of range, the size of output_shapes is " << out_shapes.size()
1316 << ", but the index is " << index;
1317 }
1318 auto out_shape = out_shapes.at(index);
1319 if (std::count(out_shape.cbegin(), out_shape.cend(), DYNAMIC_DIM_VAL) > 0) {
    MS_LOG(EXCEPTION) << "A non-scalar loss is not supported in dynamic shape and pipeline parallel scenarios; "
                      << "the output shape is " << out_shape;
1323 }
1324
  // Modify the output dimension when data parallel is enabled,
  // since only the last stage enables VirtualOutput redistribution.
1326 bool full_batch = ParallelContext::GetInstance()->full_batch();
1327 int64_t dev_num = full_batch ? 1 : g_device_manager->stage_device_num();
1328 if (dev_num == 0) {
1329 MS_LOG(EXCEPTION) << "Device num must be larger than 0, but get 0.";
1330 }
1331
1332 if (!is_train_ && !out_shape.empty() && out_shape[0] % dev_num == 0) {
1333 out_shape[0] /= dev_num;
1334 }
1335
1336 auto out_shape_type = GetShapeType(node, out_shape, index);
1337 auto zero_tensor = TensorConstructUtils::CreateZerosTensor(out_shape_type.second, out_shape);
1338 MS_EXCEPTION_IF_NULL(zero_tensor);
1339
1340 auto value_node = NewValueNode(zero_tensor);
1341 MS_EXCEPTION_IF_NULL(value_node);
1342
1343 // Build abstract from node to prevent confusion between Scalar and 0D-Tensor.
1344 auto abs = node->abstract()->Clone();
1345 MS_EXCEPTION_IF_NULL(abs);
1346 if (abs->isa<abstract::AbstractSequence>()) {
1347 auto elements = abs->cast<abstract::AbstractSequencePtr>()->elements();
1348 abs = elements.at(index)->Clone();
1349 MS_EXCEPTION_IF_NULL(abs);
1350 }
1351
1352 abs->set_shape(std::make_shared<abstract::Shape>(out_shape));
1353 value_node->set_abstract(abs);
1354 return value_node;
1355 }
1356
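// Build zero-filled placeholder outputs matching the graph's output shapes, used by stages that do not
// own the real output.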
AnfNodePtr PipelineTransformer::GetZeroOutputs(const FuncGraphPtr &graph) {
1358 // first: out node second: getitem index
1359 auto real_kernel = GetRealKernelNode(graph->output(), -1);
1360 auto real_out = real_kernel.first;
1361 MS_EXCEPTION_IF_NULL(real_out);
1362 std::vector<AnfNodePtr> out_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1363 if (IsPrimitiveCNode(real_out, prim::kPrimMakeTuple)) {
1364 auto real_out_cnode = real_out->cast<CNodePtr>();
1365 for (size_t i = 1; i < real_out_cnode->size(); ++i) {
1366 auto each_out_shapes = GetNodeShape(real_out_cnode->input(i));
1367 // In case: tuple's input is also a tuple
1368 if (each_out_shapes.size() > 1) {
1369 auto temp_tuple = CreateTupleZeroTensor(real_out_cnode->input(i), each_out_shapes.size());
1370 (void)out_tuple_inputs.emplace_back(temp_tuple);
1371 continue;
1372 }
1373 (void)out_tuple_inputs.emplace_back(CreateZeroseOutput(real_out_cnode->input(i), 0));
1374 }
1375 }
1376 if (out_tuple_inputs.size() > INDEX_ONE) {
1377 auto out_tuple = main_graph_->NewCNode(out_tuple_inputs);
1378 SetMakeTupleAbstract(out_tuple);
1379 return out_tuple;
1380 } else {
1381 auto real_out_shapes = GetNodeShape(real_out);
1382 AnfNodePtr out_tensor;
1383 // In case: op has multioutput
1384 if (real_out_shapes.size() > 1 && real_kernel.second == -1) {
1385 out_tensor = CreateTupleZeroTensor(real_out, real_out_shapes.size());
1386 } else {
1387 out_tensor = CreateZeroseOutput(real_out, 0);
1388 }
1389 return out_tensor;
1390 }
1391 return nullptr;
1392 }
1393
std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfoPair(const AnfNodePtr &node,
                                                                   const AnfNodePtr &graph_param, AnfNodePtr *care_node,
                                                                   bool *is_param) {
1397 if (node->isa<Parameter>()) {
1398 return GetParameterPair(graph_param);
1399 } else {
1400 *care_node = FindPipelineCareNode(node);
1401 if ((*care_node)->isa<Parameter>()) {
1402 return GetParameterPair(*care_node);
1403 } else {
1404 *is_param = false;
1405 return GetOpInfo(*care_node);
1406 }
1407 }
1408 }
1409
void PipelineTransformer::SetNodeAbstract(const std::vector<AnfNodePtr> &nodes) {
1411 AbstractBasePtr abs;
1412 if (nodes.size() == 1) {
1413 auto cnode = nodes.front()->cast<CNodePtr>();
1414 MS_EXCEPTION_IF_NULL(cnode);
1415 abs = GetRealAbstract(cnode->input(INDEX_ONE));
1416 } else {
1417 AbstractBasePtrList abstract_list;
1418 abstract_list.resize(nodes.size());
1419 (void)std::transform(nodes.begin(), nodes.end(), abstract_list.begin(), [](const AnfNodePtr &node) {
1420 auto cnode = node->cast<CNodePtr>();
1421 MS_EXCEPTION_IF_NULL(cnode);
1422 return GetRealAbstract(cnode->input(INDEX_ONE));
1423 });
1424 abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
1425 }
1426 for (auto &user : shared_cell_users_) {
1427 user->set_abstract(abs);
1428 }
1429 }
1430
AnfNodePtr PipelineTransformer::GenNewSendFromOld(const AnfNodePtr &node, const AnfNodePtr &input,
                                                  const ValuePtr &value) {
1433 const auto &old = node->cast<CNodePtr>();
1434 MS_EXCEPTION_IF_NULL(old);
1435 auto old_is_pipeline_param = old->HasPrimalAttr(PIPELINE_PARAM);
1436 auto dest_rank_ptr = old->user_data<int64_t>(DEST_RANK);
1437 MS_EXCEPTION_IF_NULL(dest_rank_ptr);
1438 auto dest_rank = *dest_rank_ptr;
1439 auto send_tag = send_tag_map[dest_rank];
1440 send_tag_map[dest_rank]++;
1441 Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
1442 Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(dest_rank));
1443 Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1444 Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1445 OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
1446 std::vector<AnfNodePtr> send_input{input};
1447 auto send = CreateCNodeByInputsAndAttr(main_graph_, SEND, SEND, send_input, attrs);
1448 AnfNodePtr care_node;
1449 bool is_param = true;
1450 auto op_info_pair = GetOpInfoPair(input, input, &care_node, &is_param);
1451 auto tensor_info = GetTensorInfo(op_info_pair, is_param);
1452 auto op_info = op_info_pair.first;
1453 auto index = op_info_pair.second;
1454 auto slice_shape = tensor_info.slice_shape();
1455 auto shape_type_pair = GetShapeType(input, slice_shape, 0);
1456 auto prim = GetCNodePrimitive(send);
1457 prim->set_attr(SHAPE, shape_type_pair.first);
1458 prim->set_attr(DTYPE, shape_type_pair.second);
1459 if (!is_param) {
1460 if (old_is_pipeline_param) {
1461 MS_LOG(EXCEPTION) << "The old send is pipeline_param, but new send is not pipeline_param.";
1462 }
1463 send->AddPrimalAttr(PIPELINE_END, value);
1464 } else {
1465 if (!old_is_pipeline_param) {
1466 MS_LOG(EXCEPTION) << "The old send is not pipeline_param, but new send is pipeline_param.";
1467 }
1468 send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
1469 send->AddPrimalAttr(PIPELINE_PARAM, value);
1470 send->set_user_data<OperatorInfo>(op_info);
1471 }
1472 send->AddPrimalAttr(MICRO, value);
1473 auto abstract = input->abstract();
1474 if (care_node) {
1475 abstract = care_node->abstract();
1476 }
1477 send->set_abstract(abstract);
1478 return send;
1479 }
1480
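// FetchSend materializes the new Send ops for one old border node. For a pipeline parameter the send
// input is resolved once (either the caller's argument that feeds the matching shared_cell parameter, or
// the node's own first input) and a single Send is produced. Otherwise one Send is created per shared_cell
// user, taking either the call itself or a TupleGetItem on it; when there are multiple users, each Send
// inherits the MICRO primal attr (the micro-batch index) of its user.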
std::vector<AnfNodePtr> PipelineTransformer::FetchSend(const AnfNodePtr &node, bool pipeline_param,
                                                       bool single_pipeline_end, size_t end_index) {
  std::vector<AnfNodePtr> depends;
  AnfNodePtr send_input;
  if (pipeline_param) {
    auto param = node->user_data<AnfNode>(INPUT_PARAM);
    MS_EXCEPTION_IF_NULL(param);
    auto params = shared_cell_->parameters();
    auto iter = std::find(params.begin(), params.end(), param);
    if (iter != params.end()) {
      auto input_pos = std::distance(params.begin(), iter) + 1;
      auto &front = shared_cell_users_.front();
      MS_EXCEPTION_IF_NULL(front);
      const auto &user = front->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(user);
      send_input = user->input(input_pos);
    } else {
      const auto &cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      send_input = cnode->input(INDEX_ONE);
    }
    MS_EXCEPTION_IF_NULL(send_input);
    auto value = MakeValue(int64_t(0));
    (void)(depends.emplace_back(GenNewSendFromOld(node, send_input, value)));
    return depends;
  }
  for (auto &user : shared_cell_users_) {
    auto cuser = user->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cuser);
    auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
    MS_EXCEPTION_IF_NULL(value);
    send_input = single_pipeline_end ? user : CreateTupleGetItemNode(main_graph_, user, end_index);
    (void)(depends.emplace_back(GenNewSendFromOld(node, send_input, value)));
  }
  return depends;
}

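// HandleGraphOutputs rewires the outputs of a non-last stage: the shared cell returns the raw
// pipeline-end values, the real data leaves the stage through Send ops, and the root graph output is
// replaced by a Depend so it keeps the original structure. Roughly (an illustration with assumed names,
// not an actual IR dump):
//   shared_cell:  return MakeTuple(end_0, end_1, ...)
//   main_graph:   return Depend(zero_outputs, MakeTuple(Send(end_0), Send(end_1), ...))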
void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &nodes) {
  std::vector<AnfNodePtr> pipeline_params;
  std::vector<AnfNodePtr> pipeline_ends;
  SeparateParamBorder(nodes, true, &pipeline_params, &pipeline_ends);
  std::vector<AnfNodePtr> sends;
  SetNodeAbstract(pipeline_ends);

  // Create the root graph output before modifying the subgraph (shared cell).
  // This ordering is crucial when the subgraph's output is used directly as the root graph output.
  auto zero_outputs = GetZeroOutputs(main_graph_);

  size_t ends_size = pipeline_ends.size();
  bool single_pipeline_end = ends_size == 1;
  if (single_pipeline_end) {
    auto &depend = pipeline_ends.front();
    const auto &cdepend = depend->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cdepend);
    (void)manager_->Replace(shared_cell_->output(), cdepend->input(INDEX_ONE));
  } else {
    std::vector<AnfNodePtr> rets;
    (void)std::transform(pipeline_ends.begin(), pipeline_ends.end(), std::back_inserter(rets),
                         [](const AnfNodePtr &depend) {
                           const auto &cdepend = depend->cast<CNodePtr>();
                           MS_EXCEPTION_IF_NULL(cdepend);
                           return cdepend->input(INDEX_ONE);
                         });
    auto out = CreateMakeTupleNode(shared_cell_, rets);
    (void)manager_->Replace(shared_cell_->output(), out);
  }
  for (auto &node : pipeline_params) {
    auto params = FetchSend(node, true, false, 0);
    if (is_train_) {
      (void)std::copy(params.begin(), params.end(), std::back_inserter(sends));
    }
  }
  for (size_t i = 0; i < ends_size; i++) {
    auto node = pipeline_ends[i];
    auto ends = FetchSend(node, false, single_pipeline_end, i);
    (void)std::copy(ends.begin(), ends.end(), std::back_inserter(sends));
  }
  auto make_tuple = CreateMakeTupleNode(main_graph_, sends);
  std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
  auto out_node = main_graph_->NewCNode(out);
  out_node->set_abstract(zero_outputs->abstract());
  (void)manager_->Replace(main_graph_->output(), out_node);
}

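// GenNewRecvFromOld mirrors GenNewSendFromOld on the receiving side: it rebuilds a Receive with the
// source rank, a per-source sr_tag from recv_tag_map, and the slice shape/dtype recorded on the old node,
// then copies over the tensor layout and operator info. A PIPELINE_PARAM receive additionally remembers
// its backing input and propagates abstract/layout onto it; other receives are tagged PIPELINE_BEGIN.
// Both carry the MICRO value of the micro-batch they belong to.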
AnfNodePtr PipelineTransformer::GenNewRecvFromOld(const AnfNodePtr &node, const AnfNodePtr &input,
                                                  const ValuePtr &value) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto src_rank_ptr = cnode->user_data<int64_t>(SRC_RANK);
  MS_EXCEPTION_IF_NULL(src_rank_ptr);
  auto src_rank = *src_rank_ptr;
  auto recv_tag = recv_tag_map[src_rank];
  recv_tag_map[src_rank]++;
  auto dtype = node->user_data<Type>(SLICE_DTYPE);
  auto slice_shape = *(cnode->user_data<Shape>(SLICE_SHAPE));
  auto shape = GetShapeValue(slice_shape);
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
  Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(src_rank));
  Attr attr_shape = std::make_pair(SHAPE, shape);
  Attr attr_dtype = std::make_pair(DTYPE, dtype);
  Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};

  std::vector<AnfNodePtr> recv_input = {input};
  auto recv = CreateCNodeByInputsAndAttr(main_graph_, RECEIVE, RECEIVE, recv_input, attrs);
  auto tensor_layout = node->user_data<TensorLayout>();
  if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
    auto abstract_clone = node->abstract()->Clone();
    MS_EXCEPTION_IF_NULL(abstract_clone);
    recv->set_user_data<AnfNode>(PIPELINE_PARAM, recv_input[INDEX_ZERO]);
    recv->AddPrimalAttr(PIPELINE_PARAM, value);
    recv_input[INDEX_ZERO]->set_abstract(abstract_clone);
    recv_input[INDEX_ZERO]->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(*tensor_layout));
  } else {
    recv->AddPrimalAttr(PIPELINE_BEGIN, value);
  }
  auto abstract_clone = node->abstract()->Clone();
  MS_EXCEPTION_IF_NULL(abstract_clone);
  recv->set_abstract(abstract_clone);

  recv->AddPrimalAttr(MICRO, value);
  recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(*tensor_layout));
  recv->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
  return recv;
}

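// FetchRecv creates the new Receive ops for one old border node. For a pipeline parameter a single
// Receive is built; if the old receive was fed through a caller argument, the shared_cell calls are
// re-edged to take the Receive (training) or the original argument (inference), and the node is marked
// ORIGIN_INPUT_IS_PARAM. For ordinary borders one Receive per shared_cell user is built, fed by a small
// placeholder zeros tensor (or by virtual_param_ when the shared cell is disabled during training), again
// propagating each user's MICRO value.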
std::vector<AnfNodePtr> PipelineTransformer::FetchRecv(const AnfNodePtr &node, bool pipeline_param) {
  std::vector<AnfNodePtr> recvs;
  AnfNodePtr recv_input;
  AnfNodePtr recv;
  if (pipeline_param) {
    auto value = MakeValue(int64_t(0));
    auto param = node->user_data<AnfNode>(INPUT_PARAM);
    MS_EXCEPTION_IF_NULL(param);
    auto &front = shared_cell_users_.front();
    MS_EXCEPTION_IF_NULL(front);
    const auto &user = front->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user);
    auto params = shared_cell_->parameters();
    auto user_inputs = user->inputs();
    auto iter = std::find(user_inputs.begin(), user_inputs.end(), param);
    if (iter != user_inputs.end()) {
      auto input_pos = std::distance(user_inputs.begin(), iter);
      auto argu = params.at(input_pos - 1);
      manager_->SetEdge(node, 1, argu);
      node->set_user_data<AnfNode>(INPUT_PARAM, argu);
      recv_input = user->input(input_pos);
      recv = GenNewRecvFromOld(node, recv_input, value);
      for (auto &share_user : shared_cell_users_) {
        if (is_train_) {
          manager_->SetEdge(share_user, input_pos, recv);
        } else {
          manager_->SetEdge(share_user, input_pos, recv_input);
        }
      }
      node->set_user_data<bool>(ORIGIN_INPUT_IS_PARAM, std::make_shared<bool>(true));
    } else {
      const auto &cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      recv_input = cnode->input(INDEX_ONE);
      recv = GenNewRecvFromOld(node, recv_input, value);
    }
    (void)(recvs.emplace_back(recv));
    return recvs;
  }
  for (auto &user : shared_cell_users_) {
    auto cuser = user->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cuser);
    auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
    MS_EXCEPTION_IF_NULL(value);
    if (enable_share_cell_ || !is_train_) {
      auto recv_tensor = TensorConstructUtils::CreateZerosTensor(kFloat16, {1});
      recv = GenNewRecvFromOld(node, NewValueNode(recv_tensor), value);
    } else {
      recv = GenNewRecvFromOld(node, virtual_param_, value);
    }
    (void)(recvs.emplace_back(recv));
  }
  return recvs;
}

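// ResetSharedCellParamAndArgu rebuilds the shared cell's parameter list and every call site after the
// borders have been rewritten: parameters still referenced inside the cell are kept (monad parameters
// moved to the end), the newly created parameters are appended, and each call is re-emitted with the
// surviving original arguments, the reserved inputs, the per-user pipeline-begin receives and the monad
// inputs, preserving the original attrs, primal attrs and abstract of the call.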
void PipelineTransformer::ResetSharedCellParamAndArgu(
  const std::vector<std::vector<AnfNodePtr>> &pipeline_begins_fetched,
  const std::vector<AnfNodePtr> &newly_added_params, const std::vector<AnfNodePtr> &reserved_inputs) {
  // set shared_cell_ parameters and call inputs
  auto params = shared_cell_->parameters();
  auto ret = shared_cell_->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> searched_params;
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (auto &node : all_nodes) {
    if (node->isa<Parameter>()) {
      searched_params.push_back(node);
    }
  }
  std::set<size_t> reserved_param_index;
  std::vector<AnfNodePtr> new_params;
  std::vector<AnfNodePtr> monad_params;
  // set shared_cell_ parameters
  for (size_t i = 0; i < params.size(); i++) {
    auto param = params[i];
    if (std::find(searched_params.begin(), searched_params.end(), param) == searched_params.end()) {
      continue;
    }
    if (HasAbstractMonad(param)) {
      monad_params.push_back(param);
    } else {
      new_params.push_back(param);
    }
    (void)(reserved_param_index.insert(i));
  }
  (void)(new_params.insert(new_params.end(), newly_added_params.begin(), newly_added_params.end()));
  (void)(new_params.insert(new_params.end(), monad_params.begin(), monad_params.end()));
  MS_LOG(DEBUG) << "The shared cell origin params size is " << params.size() << ", new params size is "
                << new_params.size();
  manager_->SetParameters(shared_cell_, new_params);
  shared_cell_->set_fv_param_count(new_params.size());
  // set call inputs
  size_t user_index = 0;
  for (auto &user : shared_cell_users_) {
    auto cuser = user->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cuser);
    const auto &old_inputs = cuser->inputs();
    std::vector<AnfNodePtr> new_inputs{old_inputs.front()};
    std::vector<AnfNodePtr> monad_inputs;
    for (size_t i = 1; i < old_inputs.size(); i++) {
      if (reserved_param_index.find(i - 1) == reserved_param_index.end()) {
        continue;
      }
      auto old_input = old_inputs[i];
      if (HasAbstractMonad(old_input)) {
        monad_inputs.push_back(old_input);
      } else {
        new_inputs.push_back(old_input);
      }
    }
    auto newly_added_inputs = reserved_inputs;
    auto begins = pipeline_begins_fetched.at(user_index);
    (void)(newly_added_inputs.insert(newly_added_inputs.end(), begins.begin(), begins.end()));
    (void)(newly_added_inputs.insert(newly_added_inputs.end(), monad_inputs.begin(), monad_inputs.end()));
    (void)(new_inputs.insert(new_inputs.end(), newly_added_inputs.begin(), newly_added_inputs.end()));
    auto new_call = main_graph_->NewCNode(new_inputs);
    new_call->set_attrs(cuser->attrs());
    new_call->set_primal_attrs(cuser->primal_attrs());
    new_call->set_abstract(cuser->abstract());
    (void)manager_->Replace(user, new_call);
    user_index++;
  }
}

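// HandleGraphInputs is the input-side counterpart of HandleGraphOutputs. It fetches the new Receive ops,
// then redirects users inside the shared cell: borders whose origin input is already a parameter are
// linked back to that parameter, while the remaining pipeline-param borders and all pipeline-begin
// borders are linked to freshly created shared_cell parameters whose values are supplied through the
// rebuilt call sites (see ResetSharedCellParamAndArgu above).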
void PipelineTransformer::HandleGraphInputs(const std::vector<AnfNodePtr> &recv_ops) {
  std::vector<AnfNodePtr> pipeline_params;
  std::vector<AnfNodePtr> pipeline_begins;
  SeparateParamBorder(recv_ops, false, &pipeline_params, &pipeline_begins);

  // reserved inputs
  std::vector<AnfNodePtr> reserved_inputs;
  // pipeline_params whose input is a parameter
  std::vector<AnfNodePtr> pipeline_params_with_param_input;
  std::vector<AnfNodePtr> need_link_to_new_param;

  for (auto &node : pipeline_params) {
    auto recvs = FetchRecv(node, true);
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->has_user_data(ORIGIN_INPUT_IS_PARAM)) {
      pipeline_params_with_param_input.push_back(node);
    } else {
      (void)(reserved_inputs.insert(reserved_inputs.end(), recvs.begin(), recvs.end()));
      need_link_to_new_param.push_back(node);
    }
  }
  (void)(need_link_to_new_param.insert(need_link_to_new_param.end(), pipeline_begins.begin(), pipeline_begins.end()));

  size_t begin_size = pipeline_begins.size();
  // The first dimension indexes shared_cell users, the second indexes the fetched recvs:
  // user0: recv0_0, recv0_1
  // user1: recv1_0, recv1_1
  size_t shared_cell_users_size = shared_cell_users_.size();
  std::vector<std::vector<AnfNodePtr>> pipeline_begins_fetched(shared_cell_users_size, std::vector<AnfNodePtr>());
  for (size_t i = 0; i < begin_size; i++) {
    auto node = pipeline_begins[i];
    auto begins = FetchRecv(node, false);
    for (size_t j = 0; j < shared_cell_users_size; j++) {
      pipeline_begins_fetched[j].push_back(begins.at(j));
    }
  }
  auto &node_users_map = manager_->node_users();
  // relink the users of each pipeline_param whose input is a parameter to that input
  for (const auto &param : pipeline_params_with_param_input) {
    const auto &users = node_users_map[param];
    auto input = param->user_data<AnfNode>(INPUT_PARAM);
    MS_EXCEPTION_IF_NULL(input);
    for (const auto &user : users) {
      manager_->SetEdge(user.first, user.second, input);
    }
  }

  std::vector<AnfNodePtr> newly_added_params;
  // relink the users of pipeline_params without a parameter input and of pipeline_begins to new parameters
  for (const auto &node : need_link_to_new_param) {
    auto param = std::make_shared<Parameter>(shared_cell_);
    param->set_abstract(node->abstract()->Clone());
    newly_added_params.push_back(param);
    const auto &users = node_users_map[node];
    for (const auto &user : users) {
      manager_->SetEdge(user.first, user.second, param);
    }
  }
  ResetSharedCellParamAndArgu(pipeline_begins_fetched, newly_added_params, reserved_inputs);
}

AnfNodePtr PipelineTransformer::CreateTupleZeroTensor(const AnfNodePtr &node, size_t index) {
  std::vector<AnfNodePtr> temp_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  auto out_shapes = GetNodeShape(node);
  for (size_t ele = 0; ele < out_shapes.size(); ++ele) {
    temp_tuple_inputs.emplace_back(CreateZeroseOutput(node, ele));
  }
  auto temp_tuple = main_graph_->NewCNode(temp_tuple_inputs);
  SetMakeTupleAbstract(temp_tuple);
  return temp_tuple;
}

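// CutGraph is the entry point of the stage split: it collects the Send/Receive pairs produced for shared
// parameters and for the cut borders, then rewrites the graph outputs and inputs. Without the shared cell
// it only wraps the root output as Depend(zero_outputs, MakeTuple(sends)); with the shared cell enabled,
// non-last stages additionally go through HandleGraphOutputs, and every stage goes through
// HandleGraphInputs.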
void PipelineTransformer::CutGraph() {
  world_group_ = GetWorldGroup();
  auto send_recv_shared_param = HandleSharedParameter();
  auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
  MS_EXCEPTION_IF_NULL(graph);
  auto send_recv_cut_border = CutBorder(graph);
  std::vector<AnfNodePtr> send_ops;

  (void)(send_ops.insert(send_ops.end(), send_recv_shared_param.first.begin(), send_recv_shared_param.first.end()));
  (void)(send_ops.insert(send_ops.end(), send_recv_cut_border.first.begin(), send_recv_cut_border.first.end()));
  if (IsLastStage() && !enable_share_cell_) {
    return;
  }
  if (!send_ops.empty()) {
    type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
    shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
  }
  if (!enable_share_cell_) {
    auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);
    auto zero_outputs = GetZeroOutputs(main_graph_);
    std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
    auto out_node = main_graph_->NewCNode(out);
    (void)manager_->Replace(main_graph_->output(), out_node);
    return;
  }
  if (!IsLastStage()) {
    HandleGraphOutputs(send_ops);
  }
  std::vector<AnfNodePtr> recv_ops;

  (void)(recv_ops.insert(recv_ops.end(), send_recv_shared_param.second.begin(), send_recv_shared_param.second.end()));
  (void)(recv_ops.insert(recv_ops.end(), send_recv_cut_border.second.begin(), send_recv_cut_border.second.end()));
  HandleGraphInputs(recv_ops);
}

void PipelineTransformer::ElimGraphStage() {
  for (auto &fg : manager_->func_graphs()) {
    fg->set_stage(-1);
    fg->set_segment(-1);
  }
}

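// RedundancyNode walks the users of a redundant node (typically a parameter owned by another stage).
// UpdateState users are detached by replacing the edge with a fresh U/IO monad, MakeTuple/MakeList users
// are recorded in make_tuple_map so the inputs can be removed in one pass later, and any other user is
// followed recursively.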
void PipelineTransformer::RedundancyNode(const AnfNodePtr &node,
                                         mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map) {
  auto node_users = manager_->node_users()[node];
  for (auto &node_user_pair : node_users) {
    auto cnode = node_user_pair.first->cast<CNodePtr>();
    // node->UpdateState: replace the node with a U/IO monad.
    auto fg = cnode->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    if (fg->stage() != -1 && fg != main_graph_) {
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
      auto abs = cnode->abstract();
      MS_EXCEPTION_IF_NULL(abs);
      auto monad_node = NewValueNode(kUMonad);
      if (abs->isa<abstract::AbstractIOMonad>()) {
        monad_node = NewValueNode(kIOMonad);
      }
      manager_->SetEdge(cnode, node_user_pair.second, monad_node);
      continue;
    }
    // node->make_tuple: record it in a map; the recorded inputs are deleted together later.
    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
      if (make_tuple_map->find(cnode) == (*make_tuple_map).end()) {
        (*make_tuple_map)[cnode] = {node};
      } else {
        (*make_tuple_map)[cnode].push_back(node);
      }
    } else {
      RedundancyNode(node_user_pair.first, make_tuple_map);
    }
  }
}

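// For cloned parameters the stage set is looked up through the original parameter: the clone's name is
// expected to be "<prefix>.<original_name>", so everything up to the first '.' is stripped before
// matching against the non-cloned parameters.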
bool PipelineTransformer::IsRedundancyParameter(const AnfNodePtr &parameter,
                                                const std::vector<AnfNodePtr> &non_cloned_parameters) {
  // A redundancy parameter is another stage's parameter, including its corresponding cloned parameters.
  auto param_ptr = parameter->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param_ptr);
  if (!param_ptr->has_default()) {
    return false;
  }
  std::set<int64_t> stage_set;
  if (!ParameterIsCloned(parameter)) {
    stage_set = parameter_color_map_.at(parameter);
  } else {
    auto parameters = root_->parameters();
    auto param_name = param_ptr->name();
    auto non_clone_name = param_name.substr(param_name.find_first_of('.') + 1);
    for (auto &param : non_cloned_parameters) {
      auto non_cloned_param = param->cast<ParameterPtr>();
      if (non_clone_name != non_cloned_param->name()) {
        continue;
      }
      stage_set = parameter_color_map_.at(param);
      break;
    }
  }
  if (stage_set.empty()) {
    return false;
  }
  return stage_set.count(stage_) == 0;
}

bool PipelineTransformer::HasNoUpdateParameter() {
  auto parameters = root_->parameters();
  for (auto &parameter : parameters) {
    if (ParameterIsCloned(parameter)) {
      continue;
    }
    auto param_info = parameter->cast<ParameterPtr>()->param_info();
    if (!param_info) {
      continue;
    }
    auto stage_set = parameter_color_map_.at(parameter);
    auto requires_grad = param_info->requires_grad();
    if (requires_grad && stage_set.count(stage_)) {
      return false;
    }
  }
  return true;
}

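// FreezeGradient handles training graphs where no parameter of this stage requires an update: the root
// graph is flagged NO_UPDATE, the consumer of the gradients (TupleGetItem index 1 of the grad call) is
// chained to the root output through a Depend so it is still executed, and the input of
// NPUGetFloatStatusV2 is likewise made to depend on the gradients so that overflow detection keeps its
// ordering.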
void PipelineTransformer::FreezeGradient() {
  auto node_users_map = manager_->node_users();
  if (HasNoUpdateParameter() && is_train_) {
    root_->set_flag(NO_UPDATE, true);
    auto nodes = root_->nodes();
    for (auto &node : nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimJ)) {
        continue;
      }
      auto node_users = node_users_map.at(node);
      auto grad_users = node_users_map.at(node_users.front().first);
      for (auto &grad_user : grad_users) {
        auto user_node = grad_user.first->cast<CNodePtr>();
        if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
          continue;
        }
        auto index = GetTupleGetItemIndex(user_node);
        if (index != 1) {
          continue;
        }
        auto temp = node_users_map.at(user_node).front().first;
        auto out = root_->output();
        std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), out, temp};
        auto new_node = root_->NewCNode(depend_input);
        manager_->Replace(out, new_node);
        break;
      }
      break;
    }
    for (auto &node : nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimNPUGetFloatStatusV2)) {
        continue;
      }
      auto cnode = node->cast<CNodePtr>();
      auto out_cnode = root_->output()->cast<CNodePtr>();
      auto grads = out_cnode->input(INDEX_TWO);
      std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), cnode->input(1), grads};
      auto new_node = root_->NewCNode(depend_input);
      new_node->set_abstract(cnode->input(1)->abstract());
      manager_->Replace(cnode->input(1), new_node);
      break;
    }
  }
}

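// ElimParameter removes parameters that are redundant on this stage: their users are either detached
// (UpdateState) or pruned from the MakeTuple/MakeList nodes collected by RedundancyNode. When the root
// graph is flagged NO_UPDATE and the tuple feeds an AddN, a zero tensor of the same shape is substituted
// instead, so the accumulation keeps its arity.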
void PipelineTransformer::ElimParameter() {
  auto parameters = root_->parameters();
  mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> make_tuple_map;
  std::vector<AnfNodePtr> non_cloned_parameters;
  FreezeGradient();
  auto node_users_map = manager_->node_users();
  for (auto &parameter : parameters) {
    if (ParameterIsCloned(parameter)) {
      continue;
    }
    non_cloned_parameters.push_back(parameter);
  }
  for (auto &parameter : parameters) {
    if (!IsRedundancyParameter(parameter, non_cloned_parameters)) {
      continue;
    }
    MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << " is redundant.";
    RedundancyNode(parameter, &make_tuple_map);
  }
  for (auto &temp : make_tuple_map) {
    auto make_tuple = temp.first;
    auto fg = make_tuple->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    auto remove_vector = temp.second;
    if (remove_vector.empty()) {
      continue;
    }
    auto make_tuple_user = node_users_map.at(make_tuple).front().first;
    auto make_tuple_inputs = make_tuple->inputs();
    std::vector<AnfNodePtr> new_inputs;
    for (auto &input : make_tuple_inputs) {
      if (std::find(remove_vector.begin(), remove_vector.end(), input) == remove_vector.end()) {
        new_inputs.push_back(input);
        continue;
      }
      if (root_->has_flag(NO_UPDATE) && IsPrimitiveCNode(make_tuple_user, prim::kPrimAddN)) {
        new_inputs.push_back(CreateZeroseOutput(input, 0));
      }
    }
    auto new_make_tuple = fg->NewCNode(new_inputs);
    (void)manager_->Replace(make_tuple, new_make_tuple);
  }
}

void PipelineTransformer::ModifyParameterList() {
  ElimParameter();
  auto parameters = root_->parameters();
  std::vector<AnfNodePtr> parameter_list;
  for (auto &parameter : parameters) {
    auto param = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (!manager_->node_users()[parameter].empty() || !param->has_default()) {
      parameter_list.push_back(parameter);
    }
  }
  auto del_num = parameters.size() - parameter_list.size();
  root_->set_fv_param_count(root_->fv_param_count() - del_num);
  manager_->SetParameters(root_, parameter_list);
}
}  // namespace parallel
}  // namespace mindspore