1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "base/base.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/other_ops.h"
27 #include "mindspore/core/ops/nn_ops.h"
28 #include "mindspore/core/ops/array_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "mindspore/core/ops/arithmetic_ops.h"
31 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
32 #include "frontend/parallel/ops_info/ops_utils.h"
33 #include "frontend/parallel/group_manager.h"
34 #include "frontend/parallel/parameter_manager.h"
35 #include "include/common/utils/parallel_context.h"
36 #include "frontend/parallel/step_parallel.h"
37 #include "frontend/parallel/node_check.h"
38 #include "frontend/parallel/graph_util/node_info.h"
39 #include "frontend/parallel/graph_util/graph_info.h"
40 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
41 #include "frontend/parallel/step_parallel_utils.h"
42 #include "frontend/parallel/graph_util/graph_splitter.h"
43 #include "frontend/parallel/tensor_layout/shared_parameter.h"
44 #include "ir/anf.h"
45 #include "ir/graph_utils.h"
46 #include "include/common/utils/comm_manager.h"
47 #include "utils/ms_context.h"
48 #include "utils/tensor_construct_utils.h"
49 #include "mindspore/core/utils/parallel_node_check.h"
50 #include "include/common/debug/anf_ir_dump.h"
51 
52 namespace mindspore {
53 namespace parallel {
54 namespace {
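// Rebuilds a MakeTuple node's abstract from the abstracts of its inputs.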
55 void SetMakeTupleAbstract(const CNodePtr &node) {
56   if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
57     return;
58   }
59 
60   AbstractBasePtrList abstract_list;
61   for (size_t i = 1; i < node->inputs().size(); i++) {
62     abstract_list.emplace_back(node->input(i)->abstract());
63   }
64   auto abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
65   node->set_abstract(abs);
66 }
67 }  // namespace
68 
69 mindspore::HashMap<int64_t, int64_t> send_tag_map;
70 mindspore::HashMap<int64_t, int64_t> recv_tag_map;
71 const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimCast};
72 
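// Checks whether the cnode is one of the pass-through primitives (TupleGetItem, MakeTuple, Cast) in WHITE_LIST.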
73 bool IsInWhiteList(const CNodePtr &cnode) {
74   for (auto prim = WHITE_LIST.cbegin(); prim != WHITE_LIST.cend(); ++prim) {
75     if (IsPrimitiveCNode(cnode, *prim)) {
76       return true;
77     }
78   }
79   return false;
80 }
81 
82 static AbstractBasePtr GetRealAbstract(const AnfNodePtr &node) {
83   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
84     auto &input = node->cast<CNodePtr>()->input(1);
85     MS_EXCEPTION_IF_NULL(input);
86     return input->abstract();
87   }
88   return node->abstract();
89 }
90 
91 FuncGraphPtr FindNodeGraph(const CNodePtr &cnode) {
92   auto graph = cnode->func_graph();
93   if (IsValueNode<FuncGraph>(cnode->input(0))) {
94     graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
95   }
96   return graph;
97 }
98 
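// Attaches SharedParameter info (send/receive direction, peer rank, sr_tag) taken from the Send/Receive primitive to the root parameter.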
99 void PipelineTransformer::UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communicate_op,
100                                                     bool is_send) {
101   MS_EXCEPTION_IF_NULL(node);
102   MS_EXCEPTION_IF_NULL(communicate_op);
103 
104   if (!node->isa<Parameter>()) {
105     return;
106   }
107   auto root_param = node;
108   if (node->func_graph() != root_) {
109     root_param = GetArgumentsByParameter(node);
110     MS_EXCEPTION_IF_NULL(root_param);
111   }
112 
113   // get communication info from cnode.
114   auto prim = GetCNodePrimitive(communicate_op);
115   MS_EXCEPTION_IF_NULL(prim);
116 
117   auto sr_tag_attr = prim->GetAttr(SR_TAG);
118   MS_EXCEPTION_IF_NULL(sr_tag_attr);
119   auto sr_tag = GetValue<int64_t>(sr_tag_attr);
120   auto peer_rank_attr = is_send ? prim->GetAttr(DEST_RANK) : prim->GetAttr(SRC_RANK);
121   MS_EXCEPTION_IF_NULL(peer_rank_attr);
122   auto peer_rank = GetValue<int64_t>(peer_rank_attr);
123   auto group_attr = prim->GetAttr(GROUP);
124   MS_EXCEPTION_IF_NULL(group_attr);
125   auto group = GetValue<std::string>(group_attr);
126 
127   // Use global rank since local group may not exist after loading checkpoint.
128   auto rank_list = g_device_manager->FindRankListByHashName(group);
129   peer_rank = rank_list.at(peer_rank);
130 
131   // update tensor layout.
132   auto param = root_param->cast<ParameterPtr>();
133   MS_EXCEPTION_IF_NULL(param);
134   auto shared_parameters = std::make_shared<SharedParameter>(true, is_send, peer_rank, sr_tag);
135   param->set_user_data<SharedParameter>(shared_parameters);
136 }
137 
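// Returns the operator's input tensor info (for parameters) or output tensor info at the recorded index.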
138 TensorInfo PipelineTransformer::GetTensorInfo(const std::pair<OperatorInfoPtr, int> &op_info_pair, bool is_param) {
139   if (is_param) {
140     auto inputs_tensor_info = op_info_pair.first->inputs_tensor_info();
141     return inputs_tensor_info.at(IntToSize(op_info_pair.second));
142   } else {
143     auto outputs_tensor_info = op_info_pair.first->outputs_tensor_info();
144     return outputs_tensor_info.at(IntToSize(op_info_pair.second));
145   }
146 }
147 
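// Splits the given communication nodes into parameter ops (tagged PIPELINE_PARAM) and pipeline-border ops.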
148 static void SeparateParamBorder(const std::vector<AnfNodePtr> &nodes, bool send, std::vector<AnfNodePtr> *const params,
149                                 std::vector<AnfNodePtr> *const borders) {
150   std::vector<AnfNodePtr> real_comm_ops;
151   if (send) {
152     (void)std::transform(nodes.begin(), nodes.end(), std::back_inserter(real_comm_ops), [](const AnfNodePtr &n) {
153       const auto &cnode = n->cast<CNodePtr>();
154       MS_EXCEPTION_IF_NULL(cnode);
155       if (cnode->inputs().size() <= INDEX_TWO) {
156         return cnode;
157       }
158       const auto &real = cnode->input(INDEX_TWO)->cast<CNodePtr>();
159       MS_EXCEPTION_IF_NULL(real);
160       return real;
161     });
162   } else {
163     real_comm_ops = nodes;
164   }
165   for (auto &node : real_comm_ops) {
166     const auto &cnode = node->cast<CNodePtr>();
167     MS_EXCEPTION_IF_NULL(cnode);
168     if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
169       (*params).push_back(node);
170     } else {
171       (*borders).push_back(node);
172     }
173   }
174 }
175 
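// Finds the graph containing VirtualDataset as the main graph and, when cell reuse is enabled, records the shared cell and its callers.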
176 bool PipelineTransformer::MainGraph() {
177   bool find_main_graph = false;
178   for (auto &fg : manager_->func_graphs()) {
179     for (auto &node : fg->nodes()) {
180       if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
181         main_graph_ = fg;
182         main_graph_->set_flag(MAIN_GRAPH, true);
183         virtual_dataset_ = node;
184         find_main_graph = true;
185         break;
186       }
187     }
188     if (find_main_graph) {
189       break;
190     }
191   }
192   if (!find_main_graph) {
193     MS_LOG(WARNING) << "Can't find the main graph; a possible reason is that the virtual dataset cannot be found.";
194     return false;
195   }
196   for (auto &fg : manager_->func_graphs()) {
197     if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
198       shared_cell_ = fg;
199       break;
200     }
201   }
202   if (!shared_cell_) {
203     return true;
204   }
205   auto value_nodes = main_graph_->value_nodes();
206   mindspore::CompactSet<AnfNodePtr> shared_cell_nodes;
207   for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
208     auto node = (*value_pair).first;
209     if (!IsValueNode<FuncGraph>(node)) {
210       continue;
211     }
212     auto graph = GetValueNode<FuncGraphPtr>(node);
213     MS_EXCEPTION_IF_NULL(graph);
214     if (graph == shared_cell_) {
215       (void)(shared_cell_nodes.insert(node));
216     }
217   }
218   if (shared_cell_nodes.empty()) {
219     return true;
220   }
221   for (auto node : shared_cell_nodes) {
222     auto node_users = manager_->node_users()[node];
223     for (auto &node_user : node_users) {
224       auto user = node_user.first;
225       if (user->func_graph() == main_graph_) {
226         if (std::find(shared_cell_users_.begin(), shared_cell_users_.end(), user) == shared_cell_users_.end()) {
227           shared_cell_users_.push_back(user);
228         }
229       }
230     }
231   }
232   MS_LOG(INFO) << "Enable micro-fold, the folded cell is " << shared_cell_->ToString();
233   enable_share_cell_ = true;
234   return true;
235 }
236 
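// Derives the micro-batch index from a StridedSlice begin value (or from the dynamic-shape MakeTuple pattern) and
// stores it in the MICRO/PIPELINE_BEGIN/SEGMENT primal attributes.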
237 ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const {
238   if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
239     MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
240   }
241   auto cnode = node->cast<CNodePtr>();
242 
243   int64_t micro = 0;
244   auto value = GetValueNode(cnode->input(2));
245   if (value != nullptr) {
246     auto tuple = GetValue<std::vector<int64_t>>(value);  // begin
247     auto input_tmp = GetNodeShape(cnode->input(1));
248     auto input_shape = input_tmp.at(0);
249     auto slice_batch_size = input_shape.at(batch_axis);  // batch shape
250     if (slice_batch_size == 0) {
251       MS_LOG(EXCEPTION) << "slice_batch_size should be a positive integer, but got " << slice_batch_size;
252     }
253     micro = tuple.at(batch_axis) * micro_size / slice_batch_size;  // micro-index
254   } else {
255     // dynamic shape
256     // if micro is not 1: stridedslice --> maketuple --> scalarmul --> micro
257     // if micro is 1: stridedslice --> maketuple --> scalarfloordiv
258     if (!IsPrimitiveCNode(cnode->input(2), prim::kPrimMakeTuple)) {
259       MS_LOG(EXCEPTION) << "The 'begin' input of StridedSlice is neither a constant value nor a MakeTuple.";
260     }
261     auto make_tuple_cnode = cnode->input(2)->cast<CNodePtr>();
262 
263     if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarMul)) {
264       auto scalar_mul_cnode = make_tuple_cnode->input(1)->cast<CNodePtr>();
265       auto mul_value = GetValueNode(scalar_mul_cnode->input(2));
266       micro = GetValue<int64_t>(mul_value);
267     } else if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarFloorDiv)) {
268       micro = 1;
269     } else {
270       MS_LOG(EXCEPTION) << "Cannot find the micro info; the input op of MakeTuple is "
271                         << GetCNodePrimitive(make_tuple_cnode->input(1))->name();
272     }
273   }
274 
275   cnode->AddPrimalAttr(MICRO, MakeValue(micro));
276   cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
277   int64_t seg = 0;
278   cnode->AddPrimalAttr(SEGMENT, MakeValue(seg));
279   return MakeValue(micro);
280 }
281 
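// Traces a sub-graph parameter back through its call sites to the corresponding root-graph parameter, or returns nullptr.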
282 AnfNodePtr PipelineTransformer::GetArgumentsByParameter(const AnfNodePtr &parameter) {
283   auto fg = parameter->func_graph();
284   if (fg == root_) {
285     return parameter;
286   }
287   auto parameters = fg->parameters();
288   auto iter = std::find(parameters.begin(), parameters.end(), parameter);
289   if (iter != parameters.end()) {
290     auto pos = std::distance(parameters.begin(), iter);
291     auto fg_used_map = fg->func_graph_cnodes_index();
292     for (auto &cur_fg_use : fg_used_map) {
293       if (cur_fg_use.first->second != 0) {
294         continue;
295       }
296       auto cur_fg = cur_fg_use.first->first->cast<CNodePtr>();
297       auto argument = cur_fg->input(pos + 1);
298       if (argument->isa<Parameter>()) {
299         return GetArgumentsByParameter(argument);
300       }
301     }
302   }
303   return nullptr;
304 }
305 
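// Returns true if any input of the cnode resolves to a parameter that requires gradients.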
306 bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
307   for (auto &input : cnode->inputs()) {
308     auto temp = input;
309     while (IsPrimitiveCNode(temp, prim::kPrimLoad) || IsPrimitiveCNode(temp, prim::kPrimCast) ||
310            IsPrimitiveCNode(temp, prim::kPrimDepend)) {
311       auto input_cnode = temp->cast<CNodePtr>();
312       MS_EXCEPTION_IF_NULL(input_cnode);
313       temp = input_cnode->input(1);
314     }
315     if (temp->isa<Parameter>()) {
316       auto argument = GetArgumentsByParameter(temp);
317       if (!argument || !GetRealKernelNode(argument, -1, nullptr).first->isa<Parameter>()) {
318         continue;
319       }
320       if (ParameterRequireGrad(argument)) {
321         return true;
322       }
323     }
324   }
325   return false;
326 }
327 
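// Marks the first stage-0 pipeline-care node that needs gradients with the PARAMETER_START attribute (or its share-cell variant).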
328 bool PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph) {
329   auto orders = graph->GetOrderedCnodes();
330   for (auto node = orders.cbegin(); node != orders.cend(); ++node) {
331     auto cnode = (*node)->cast<CNodePtr>();
332     MS_EXCEPTION_IF_NULL(cnode);
333     auto stage_info = cnode->user_data<NodeStageInfo>();
334     if (stage_info == nullptr || stage_info->stage() != 0) {
335       continue;
336     }
337     if (IsValueNode<FuncGraph>(cnode->input(0))) {
338       auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
339       if (LabelParameterStart(sub_graph)) {
340         return true;
341       } else {
342         continue;
343       }
344     }
345     if (!IsPipelineCareNode(cnode)) {
346       continue;
347     }
348     if (NeedGrad(cnode)) {
349       auto prim = GetCNodePrimitive(cnode);
350       if (enable_share_cell_) {
351         (void)prim->AddAttr(PARAMETER_START_SHARE_CELL, MakeValue(0));
352       } else {
353         (void)prim->AddAttr(PARAMETER_START, MakeValue(0));
354       }
355       return true;
356     }
357   }
358   return false;
359 }
360 
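// Determines which input dimension is sliced into micro batches by comparing the begin values of the StridedSlice users.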
361 size_t PipelineTransformer::GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const {
362   Shapes inputs_tuple;
363   for (const auto &input_node_user : input_node_users) {
364     auto node = input_node_user.first;
365     if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
366       return 0;  // simply return 0 when dynamic shape
367     }
368     auto cnode = node->cast<CNodePtr>();
369     auto value = GetValueNode(cnode->input(2));
370     if (value == nullptr) {
371       return 0;  // simply return 0 when dynamic shape
372     }
373     auto tuple = GetValue<std::vector<int64_t>>(value);
374     inputs_tuple.push_back(tuple);
375   }
376   size_t batch_axis = 0;
377   size_t batch_axis_count = 0;
378   size_t input_dim = inputs_tuple.at(0).size();
379   size_t micro_num = inputs_tuple.size();
380   for (size_t axis = 0; axis < input_dim; ++axis) {
381     for (size_t i = 1; i < micro_num; ++i) {
382       if (inputs_tuple[i][axis] != inputs_tuple[i - 1][axis]) {
383         batch_axis = axis;
384         ++batch_axis_count;
385         break;
386       }
387     }
388   }
389   if (is_train_ && batch_axis_count != kSizeOne) {
390     MS_LOG(EXCEPTION)
391       << "For pipeline parallelism, the input must be partitioned into micro_size slices along exactly one "
392       << "dimension, but " << batch_axis_count << " partitioned dimensions were found.";
393   }
394   return batch_axis;
395 }
396 
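// Counts the StridedSlice users of an input, which equals the micro-batch number.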
397 size_t MicroSize(const AnfNodeIndexSet &input_node_users) {
398   size_t micro_size = 0;
399   for (const auto &input_node_user : input_node_users) {
400     auto node = input_node_user.first;
401     if (IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
402       micro_size++;
403     }
404   }
405 
406   return micro_size;
407 }
408 
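// Labels each micro-batch slice of every dataset input with its micro index and broadcasts the label downstream.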
409 void PipelineTransformer::LabelMicroBatch() {
410   auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
411   MS_EXCEPTION_IF_NULL(graph);
412   if (!LabelParameterStart(graph)) {
413     MS_LOG(EXCEPTION) << "Stage 0 should have at least 1 parameter, but got none. "
414                       << "One possible cause is that the @lazy_inline decorator is misplaced.";
415   }
416   MS_EXCEPTION_IF_NULL(virtual_dataset_);
417   auto node_user_map = manager_->node_users();
418   auto node_users = node_user_map[virtual_dataset_];
419   auto stage_num = g_device_manager->stage_num();
420   for (auto &node_user : node_users) {
421     if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
422       auto data_users = manager_->node_users()[node_user.first];
423       auto node_first = data_users.front().first;
424       if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice) && !IsPrimitiveCNode(node_first, prim::kPrimShape)) {
425         data_users.clear();
426         data_users = node_user_map[node_first];
427       }
428       auto micro_size = int64_t(MicroSize(data_users));
429       if (is_train_ && micro_size < stage_num) {
430         MS_LOG(EXCEPTION) << "The size of micro_batch must be greater than or equal to stage_num. But the size of "
431                           << "micro_batch is " << micro_size << " and the stage_num is " << stage_num;
432       }
433       micro_size_ = micro_size;
434       auto batch_axis = GetBatchAxisForInput(data_users);
435       MS_LOG(INFO) << "For the "
436                    << GetSerialNumberString(
437                         GetValue<int64_t>(GetValueNode(node_user.first->cast<CNodePtr>()->input(kIndex2))))
438                    << "input, batch axis is " << batch_axis << ", micro size is : " << micro_size;
439       for (auto &data_user : data_users) {
440         if (!IsPrimitiveCNode(data_user.first, prim::kPrimStridedSlice)) {
441           continue;
442         }
443         auto micro = SetMicroBatch(data_user.first, micro_size, batch_axis);
444         SetStridedSliceStrategy(data_user.first);
445         auto cnode = data_user.first->cast<CNodePtr>();
446         BroadCastMicroBatch(cnode, &node_user_map, micro, 0);
447       }
448     }
449   }
450 }
451 
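// Assigns increasing fusion ids to Dropout/DropoutGenMask/DropoutDoMaskV3 nodes belonging to this stage.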
452 void PipelineTransformer::LabelGenMaskFusion() {
453   auto fgs = manager_->func_graphs();
454   int64_t fusion_id = 0;
455   for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
456     if (*fg == root_ || *fg == main_graph_) {
457       continue;
458     }
459     auto stage = (*fg)->stage();
460     if (stage != -1 && stage != stage_) {
461       continue;
462     }
463     auto nodes = (*fg)->nodes();
464     for (auto node = nodes.cbegin(); node != nodes.cend(); ++node) {
465       if (!IsPrimitiveCNode(*node, prim::kPrimDropoutGenMask) && !IsPrimitiveCNode(*node, prim::kPrimDropoutDoMaskV3) &&
466           !IsPrimitiveCNode(*node, prim::kPrimDropout)) {
467         continue;
468       }
469       auto cnode = (*node)->cast<CNodePtr>();
470       MS_EXCEPTION_IF_NULL(cnode);
471       cnode->AddPrimalAttr(kAttrFusion, MakeValue(fusion_id));
472       fusion_id += 1;
473     }
474   }
475 }
476 
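// Propagates each sub-graph's stage to its call sites until the coloring converges, then checks that every stage is used.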
477 void PipelineTransformer::Coloring() {
478   auto need_coloring = true;
479   std::set<int64_t> stage_set;
480   if (!IsTraining(manager_)) {
481     is_train_ = false;
482   }
483   while (need_coloring) {
484     need_coloring = false;
485     for (auto &fg : manager_->func_graphs()) {
486       if (fg == root_ && is_train_) {
487         continue;
488       }
489       auto value_nodes = fg->value_nodes();
490       for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
491         auto node = (*value_pair).first;
492         if (!IsValueNode<FuncGraph>(node)) {
493           continue;
494         }
495         auto graph = GetValueNode<FuncGraphPtr>(node);
496         if (graph->stage() == -1) {
497           continue;
498         }
499         (void)stage_set.insert(graph->stage());
500         auto node_users = manager_->node_users()[node];
501         for (auto &user_pair : node_users) {
502           auto user_node = user_pair.first->cast<CNodePtr>();
503           user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
504           auto user_node_graph = user_node->func_graph();
505           if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
506             user_node_graph->set_stage(graph->stage());
507             need_coloring = true;
508           }
509         }
510       }
511     }
512   }
513   MS_EXCEPTION_IF_NULL(g_device_manager);
514   auto stage_num = g_device_manager->stage_num();
515   if (SizeToLong(stage_set.size()) != stage_num) {
516     MS_LOG(EXCEPTION) << "Stage num " << stage_num << " is not equal to the number of stages used: " << stage_set.size();
517   }
518 }
519 
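// Spreads stage info along data-flow edges in the main (or shared-cell) graph and fills in stage info for nodes of staged sub-graphs.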
520 void PipelineTransformer::BroadCastColoring() {
521   auto need_coloring = true;
522   while (need_coloring) {
523     need_coloring = false;
524     auto all_nodes = enable_share_cell_ ? shared_cell_->nodes() : main_graph_->nodes();
525     auto node_users = manager_->node_users();
526     for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) {
527       auto stage_info = (*node)->user_data<NodeStageInfo>();
528       if (!(*node)->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
529           IsPrimitiveCNode(*node, prim::kPrimUpdateState)) {
530         continue;
531       }
532       auto stage = stage_info->stage();
533       for (auto &user_pair : node_users[*node]) {
534         auto user_node = user_pair.first->cast<CNodePtr>();
535         auto user_stage_info = user_node->user_data<NodeStageInfo>();
536         if (user_stage_info == nullptr) {
537           user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
538           need_coloring = true;
539           continue;
540         }
541         auto user_node_stage = user_stage_info->stage();
542         if (stage > user_node_stage) {
543           if (IsValueNode<FuncGraph>(user_node->input(0))) {
544             MS_LOG(EXCEPTION) << "The stage setting is incorrect. PreNode's stage:" << stage
545                               << " is larger than NextNode's stage:" << user_node_stage;
546           }
547           user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
548           need_coloring = true;
549         }
550       }
551     }
552   }
553   for (auto &fg : manager_->func_graphs()) {
554     auto stage = fg->stage();
555     if (stage < 0) {
556       continue;
557     }
558     if (fg == root_ || fg == main_graph_ || fg == shared_cell_) {
559       continue;
560     }
561     auto all_nodes = fg->nodes();
562     for (auto node : all_nodes) {
563       if (node->user_data<NodeStageInfo>() != nullptr) {
564         continue;
565       }
566       node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
567     }
568   }
569 }
570 
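// Returns the parameter together with its Load users that live in stage-labelled graphs.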
571 std::vector<AnfNodePtr> PipelineTransformer::GetLoadNodeByParam(const AnfNodePtr &param) const {
572   std::vector<AnfNodePtr> load_vec = {param};
573   auto node_users = manager_->node_users()[param];
574   for (auto &param_user : node_users) {
575     if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) {
576       auto graph = param_user.first->func_graph();
577       // exclude opt graphs
578       if (graph == root_ || (graph->stage() == -1 && graph != main_graph_)) {
579         continue;
580       }
581       (void)load_vec.emplace_back(param_user.first);
582     }
583   }
584   return load_vec;
585 }
586 
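// A primitive cnode is pipeline-care if it is not in the white list and is considered by the parallel pass.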
587 bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) const {
588   MS_EXCEPTION_IF_NULL(cnode);
589   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
590   if (!prim) {
591     return false;
592   }
593   if (IsInWhiteList(cnode)) {
594     return false;
595   }
596   if (!IsParallelConsiderCNode(cnode)) {
597     MS_LOG(INFO) << "PipelineSplit doesn't care about node: " << prim->name();
598     return false;
599   }
600   return true;
601 }
602 
603 CNodePtr PipelineTransformer::GraphOutNode(const AnfNodePtr &node, int tuple_index) {
604   auto cnode = node->cast<CNodePtr>();
605   MS_EXCEPTION_IF_NULL(cnode);
606   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
607     return GraphOutNode(cnode->input(1), tuple_index);
608   }
609   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
610     return cnode->input(IntToSize(tuple_index) + 1)->cast<CNodePtr>();
611   }
612   return cnode;
613 }
614 
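// Creates and initializes an OperatorInfo for a border node, using the annotated strategy or a batch-parallel default.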
615 OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode, int tuple_index = 0) {
616   MS_EXCEPTION_IF_NULL(cnode);
617   auto temp_node = cnode;
618   if (IsValueNode<FuncGraph>(cnode->input(0))) {
619     auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
620     MS_EXCEPTION_IF_NULL(output);
621     temp_node = GraphOutNode(output, tuple_index);
622   }
623   if (!IsPipelineCareNode(temp_node)) {
624     MS_LOG(EXCEPTION) << "Node: " << temp_node->DebugString() << " is not a Pipeline Care Node.";
625   }
626   if (IsPrimitiveCNode(temp_node, prim::kPrimVirtualDataset)) {
627     SetVirtualDatasetStrategy(temp_node);
628   }
629 
630   auto prim = GetValueNode<PrimitivePtr>(temp_node->input(0));
631   MS_EXCEPTION_IF_NULL(prim);
632   if (prim->name() == RESHAPE) {
633     MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << temp_node->DebugString();
634   }
635   auto attrs = prim->attrs();
636   auto op_info = CreateOperatorInfo(temp_node);
637 
638   StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
639   if (!StrategyFound(attrs)) {
640     in_strategy = GenerateBatchParallelStrategy(op_info, prim);
641   } else {
642     in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
643     out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
644   }
645   MS_EXCEPTION_IF_NULL(in_strategy);
646   if (op_info->Init(in_strategy, out_strategy) == FAILED) {
647     MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed.";
648   }
649   return op_info;
650 }
651 
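// Resolves the OperatorInfo and tensor index for a node, looking through Cast and TupleGetItem wrappers.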
652 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
653   MS_EXCEPTION_IF_NULL(node);
654   auto cnode = node->cast<CNodePtr>();
655   MS_EXCEPTION_IF_NULL(cnode);
656   // Handle Cast and TupleGetitem situation
657   int tensor_info_index = 0;
658   OperatorInfoPtr op_info;
659   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
660     op_info = node->user_data<OperatorInfo>();
661   } else {
662     if (IsPrimitiveCNode(node, prim::kPrimCast)) {
663       cnode = cnode->input(1)->cast<CNodePtr>();
664     } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
665       tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
666       cnode = cnode->input(1)->cast<CNodePtr>();
667     }
668     // Create OperatorInfo to get slice_shape for send/recv
669     MS_EXCEPTION_IF_NULL(cnode);
670     if (cnode->has_user_data<OperatorInfo>()) {
671       op_info = cnode->user_data<OperatorInfo>();
672     } else {
673       op_info = CreateOpInfo(cnode, tensor_info_index);
674     }
675   }
676   return std::make_pair(op_info, tensor_info_index);
677 }
678 
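// Recursively resolves the real users of a node, following graph calls and Load/Cast/MirrorSilentCheck wrappers.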
679 AnfNodeIndexSet GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map) {
680   AnfNodeIndexSet users;
681   auto user_pairs = (*node_users_map)[node];
682   for (const auto &user_pair : user_pairs) {
683     const auto user = user_pair.first;
684     const auto &cuser = user->cast<CNodePtr>();
685     MS_EXCEPTION_IF_NULL(cuser);
686     const auto &input = cuser->input(0);
687     MS_EXCEPTION_IF_NULL(input);
688     AnfNodePtr temp_node = nullptr;
689     if (IsValueNode<FuncGraph>(input)) {
690       auto graph = GetValueNode<FuncGraphPtr>(input);
691       MS_EXCEPTION_IF_NULL(graph);
692       auto temp_params = graph->parameters();
693       auto index = user_pair.second;
694       if (temp_params.size() < IntToSize(index)) {
695         MS_LOG(EXCEPTION) << "The argument index " << index << " is out of range of graph: " << graph->ToString()
696                           << "'s parameters.";
697       }
698       temp_node = temp_params[IntToSize(index - 1)];
699     } else if (IsPrimitiveCNode(cuser, prim::kPrimLoad) || IsPrimitiveCNode(cuser, prim::kPrimCast) ||
700                IsPrimitiveCNode(cuser, prim::kPrimMirrorSilentCheck)) {
701       temp_node = cuser;
702     }
703     if (temp_node) {
704       const auto &temp_users = GetActualOpUsers(temp_node, node_users_map);
705       (void)(users.insert(temp_users.begin(), temp_users.end()));
706     } else {
707       (void)(users.insert(user_pair));
708     }
709   }
710   return users;
711 }
712 
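// Finds the first pipeline-care user of a parameter and returns its OperatorInfo and the zero-based input index.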
713 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
714   MS_EXCEPTION_IF_NULL(node);
715   auto node_users_map = manager_->node_users();
716   const auto &node_users = GetActualOpUsers(node, &node_users_map);
717   for (auto &node_user : node_users) {
718     auto user = node_user.first->cast<CNodePtr>();
719     MS_EXCEPTION_IF_NULL(user);
720     auto user_graph = user->func_graph();
721     MS_EXCEPTION_IF_NULL(user_graph);
722     if (user_graph->stage() == -1) {
723       continue;
724     }
725     auto index = node_user.second;
726     if (!IsPipelineCareNode(user)) {
727       continue;
728     }
729     OperatorInfoPtr op_info;
730     if (user->has_user_data<OperatorInfo>()) {
731       op_info = user->user_data<OperatorInfo>();
732     } else {
733       op_info = CreateOpInfo(user);
734     }
735     return std::make_pair(op_info, index - 1);
736   }
737   return std::make_pair(nullptr, 0);
738 }
739 
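// Collects the users of a parameter's Load nodes, expanding calls into the shared cell to the users of its inner parameter.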
740 AnfNodeIndexSet PipelineTransformer::GetParameterLoadUsers(const AnfNodePtr &node,
741                                                            const NodeUsersMap &node_users_map) const {
742   AnfNodeIndexSet users;
743   if (node_users_map.find(node) == node_users_map.end()) {
744     return users;
745   }
746   auto loads = GetLoadNodeByParam(node);
747   for (auto &load : loads) {
748     auto iter = node_users_map.find(load);
749     if (iter == node_users_map.end()) {
750       continue;
751     }
752     const auto &temp_users = iter->second;
753     for (const auto &user : temp_users) {
754       auto cuser = user.first->cast<CNodePtr>();
755       MS_EXCEPTION_IF_NULL(cuser);
756       const auto &input = cuser->input(0);
757       MS_EXCEPTION_IF_NULL(input);
758       if (enable_share_cell_ && IsValueNode<FuncGraph>(input) && GetValueNode<FuncGraphPtr>(input) == shared_cell_) {
759         auto index = user.second;
760         auto pos = index - 1;
761         const auto &share_cell_params = shared_cell_->parameters();
762         const auto &param = share_cell_params.at(pos);
763         const auto &param_iter = node_users_map.find(param);
764         if (param_iter == node_users_map.end()) {
765           continue;
766         }
767         const auto &param_users = param_iter->second;
768         users.insert(param_users.begin(), param_users.end());
769       } else {
770         users.insert(user);
771       }
772     }
773   }
774   return users;
775 }
776 
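// Inserts Send/Receive pairs for parameters used by more than one stage and returns the created ops.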
777 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::HandleSharedParameter() {
778   auto parameters = root_->parameters();
779   std::vector<AnfNodePtr> sends = {};
780   std::vector<AnfNodePtr> recvs = {};
781   for (auto &parameter : parameters) {
782     auto parameter_stage = parameter_color_map_[parameter];
783     if (parameter_stage.size() <= 1) {
784       continue;
785     }
786     const auto &node_users_map = manager_->node_users();
787     auto users = GetParameterLoadUsers(parameter, node_users_map);
788     for (auto &user : users) {
789       if (!is_train_ && !enable_share_cell_) {
790         continue;
791       }
792       auto node = user.first;
793       auto cnode = node->cast<CNodePtr>();
794       auto graph = FindNodeGraph(cnode);
795       if (graph == root_ || graph->stage() == -1 || parameter_stage.count(stage_) == 0) {
796         continue;
797       }
798       auto micro = cnode->GetPrimalAttr(MICRO);
799       if (!micro) {
800         MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
801         micro = MakeValue(int64_t(0));
802       }
803       if (stage_ == *(parameter_stage.begin())) {
804         auto user_stage = graph->stage();
805         auto stage_info = node->user_data<NodeStageInfo>();
806         if (stage_info) {
807           user_stage = stage_info->stage();
808         }
809         if (graph->stage() == stage_ || user_stage == -1) {
810           continue;
811         }
812         if (Reuse(parameter, user_stage, sends, DEST_RANK)) {
813           continue;
814         }
815         auto send_out = InsertSend(parameter, user_stage, stage_, micro);
816         sends.push_back(send_out.depend);
817       } else {
818         auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
819         if (receive) {
820           manager_->SetEdge(node, user.second, receive);
821         } else {
822           AnfNodePtr recv;
823           auto fg = enable_share_cell_ ? shared_cell_ : main_graph_;
824           recv = InsertReceive(fg, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, parameter);
825           (void)(recvs.push_back(recv));
826         }
827       }
828     }
829   }
830   return std::make_pair(sends, recvs);
831 }
832 
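// Adds the stage of the node (or of its enclosing graph) to the parameter's stage set.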
833 void PipelineTransformer::FillParameterStage(const CNodePtr &node, std::set<int64_t> *const parameter_stage) {
834   auto stage_info = node->user_data<NodeStageInfo>();
835   if (stage_info != nullptr && stage_info->stage() != -1) {
836     (void)(parameter_stage->insert(stage_info->stage()));
837   } else {
838     auto graph = node->func_graph();
839     MS_EXCEPTION_IF_NULL(graph);
840     if (graph != root_ && graph != main_graph_ && graph != shared_cell_ && graph->stage() != -1) {
841       (void)(parameter_stage->insert(graph->stage()));
842     }
843   }
844 }
845 
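// With cell reuse enabled, resolves the stages of a parameter forwarded into the shared cell via its call-site argument.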
846 bool PipelineTransformer::GetStageByArgument(const CNodePtr &node, size_t index,
847                                              const std::vector<AnfNodePtr> &parameters,
848                                              const NodeUsersMap &node_users_map,
849                                              std::set<int64_t> *const parameter_stage) {
850   if (!enable_share_cell_) {
851     return false;
852   }
853   if (index < 1) {
854     return false;
855   }
856   const auto &input = node->input(0);
857   if (!IsValueNode<FuncGraph>(input)) {
858     FillParameterStage(node, parameter_stage);
859     return true;
860   }
861   if (GetValueNode<FuncGraphPtr>(input) != shared_cell_) {
862     return false;
863   }
864   auto pos = index - 1;
865   const auto &param = parameters.at(pos);
866   MS_EXCEPTION_IF_NULL(param);
867   auto loads = GetLoadNodeByParam(param);
868   for (auto &load : loads) {
869     const auto &iter = node_users_map.find(load);
870     if (iter == node_users_map.end()) {
871       continue;
872     }
873     const auto &users = (*iter).second;
874     for (auto &user : users) {
875       auto user_cnode = user.first->cast<CNodePtr>();
876       MS_EXCEPTION_IF_NULL(user_cnode);
877       FillParameterStage(user_cnode, parameter_stage);
878     }
879   }
880   return true;
881 }
882 
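// Computes the set of stages using each root parameter and selects a trainable parameter of this stage as virtual_param_.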
883 void PipelineTransformer::ParameterColoring() {
884   auto parameters = root_->parameters();
885   auto &node_users_map = manager_->node_users();
886   const auto &share_cell_parameters = shared_cell_->parameters();
887   for (auto &parameter : parameters) {
888     auto loads = GetLoadNodeByParam(parameter);
889     std::set<int64_t> parameter_stage;
890     for (auto &load : loads) {
891       auto load_users = node_users_map[load];
892       for (auto &load_user : load_users) {
893         auto user_cnode = load_user.first->cast<CNodePtr>();
894         MS_EXCEPTION_IF_NULL(user_cnode);
895         if (GetStageByArgument(user_cnode, load_user.second, share_cell_parameters, node_users_map, &parameter_stage)) {
896           continue;
897         }
898         FillParameterStage(user_cnode, &parameter_stage);
899       }
900     }
901     auto param_info = parameter->cast<ParameterPtr>()->param_info();
902     if (!param_info) {
903       parameter_color_map_[parameter] = parameter_stage;
904       continue;
905     }
906     MS_EXCEPTION_IF_NULL(param_info);
907     auto requires_grad = param_info->requires_grad();
908     if (!parameter_stage.empty() && *parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) {
909       virtual_param_ = parameter;
910     }
911     parameter_color_map_[parameter] = parameter_stage;
912   }
913 }
914 
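// Replaces the users of UpdateState nodes that belong to other stages with fresh U/IO monad value nodes.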
915 void PipelineTransformer::RemoveMonadNode() {
916   auto all_nodes = DeepScopedGraphSearch(main_graph_->get_return());
917   auto node_users_map = manager_->node_users();
918   for (auto &node : all_nodes) {
919     if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
920       continue;
921     }
922     auto cnode = node->cast<CNodePtr>();
923     MS_EXCEPTION_IF_NULL(cnode);
924     auto abs = cnode->abstract();
925     MS_EXCEPTION_IF_NULL(abs);
926     auto stage_info = cnode->user_data<NodeStageInfo>();
927     if (stage_info == nullptr) {
928       continue;
929     }
930     auto stage = stage_info->stage();
931     if (stage != stage_ && stage != -1) {
932       auto node_users = node_users_map[node];
933       for (auto &user_node : node_users) {
934         auto monad_node = NewValueNode(kUMonad);
935         if (abs->isa<abstract::AbstractIOMonad>()) {
936           monad_node = NewValueNode(kIOMonad);
937         }
938         manager_->SetEdge(user_node.first, user_node.second, monad_node);
939       }
940     }
941   }
942 }
943 
944 static ValueListPtr GetShapeValue(const Shape &shape) {
945   std::vector<ValuePtr> element;
946   (void)std::transform(shape.begin(), shape.end(), std::back_inserter(element),
947                        [](int elem) { return MakeValue(elem); });
948   return std::make_shared<ValueList>(element);
949 }
950 
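// Returns the shape as a ValueList together with the element dtype of the node (or of its index-th tuple element).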
951 std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape, size_t index) {
952   TypePtr type;
953   auto cnode = node->cast<CNodePtr>();
954   if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
955     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
956     auto graph_output = graph->output();
957     type = graph_output->Type();
958   } else {
959     if (node->isa<CNode>() && IsPrimitiveCNode(node->cast<CNodePtr>(), prim::kPrimDepend)) {
960       type = cnode->input(1)->Type();
961     } else {
962       type = node->Type();
963     }
964   }
965   MS_EXCEPTION_IF_NULL(type);
966 
967   TensorTypePtr tensor_type;
968   if (type->isa<mindspore::TensorType>()) {
969     tensor_type = type->cast<mindspore::TensorTypePtr>();
970   } else if (type->isa<Tuple>()) {
971     auto tuple_type = type->cast<TuplePtr>();
972     MS_EXCEPTION_IF_NULL(tuple_type);
973     tensor_type = tuple_type->elements().at(index)->cast<TensorTypePtr>();
974   }
975   MS_EXCEPTION_IF_NULL(tensor_type);
976   auto dtype = tensor_type->element();
977   MS_EXCEPTION_IF_NULL(dtype);
978   auto shape_list = GetShapeValue(shape);
979   return std::make_pair(shape_list, dtype);
980 }
981 
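// Returns the real node behind wrappers; white-listed nodes are returned directly and non-pipeline-care nodes raise an exception.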
982 AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) const {
983   MS_EXCEPTION_IF_NULL(node);
984   auto real_node = GetRealKernelNode(node, -1).first;
985   if (!real_node->isa<CNode>()) {
986     return real_node;
987   }
988   auto cnode = real_node->cast<CNodePtr>();
989   MS_EXCEPTION_IF_NULL(cnode);
990   if (IsInWhiteList(cnode)) {
991     return cnode->cast<AnfNodePtr>();
992   }
993   if (!IsPipelineCareNode(cnode)) {
994     MS_LOG(EXCEPTION) << "Only a node that PipelineSplit cares about can be a border."
995                       << " border node: " << cnode->DebugString();
996   }
997   return cnode->cast<AnfNodePtr>();
998 }
999 
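// Creates a Send op toward the peer stage, tagging it with sr_tag, destination rank, shape/dtype and micro-batch attributes.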
1000 SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage,
1001                                          const ValuePtr &value) {
1002   auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
1003   int64_t send_tag = send_tag_map[dest_rank];
1004   send_tag_map[dest_rank]++;
1005   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
1006   Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(dest_rank));
1007   Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1008   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1009   OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
1010   AnfNodePtr care_node;
1011   bool is_param = true;
1012   auto op_info_pair = GetOpInfoPair(parameter, parameter, &care_node, &is_param);
1013   MS_EXCEPTION_IF_NULL(op_info_pair.first);
1014   auto tensor_info = GetTensorInfo(op_info_pair, is_param);
1015   auto index = op_info_pair.second;
1016   auto op_info = op_info_pair.first;
1017   auto slice_shape = tensor_info.slice_shape();
1018   auto shape_type_pair = GetShapeType(parameter, slice_shape, 0);
1019   auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
1020   CNodePtr send = CreateCNodeByInputsAndAttr(graph, SEND, SEND, AnfNodePtrList{parameter}, attrs);
1021   auto prim = GetCNodePrimitive(send);
1022   prim->set_attr(SHAPE, shape_type_pair.first);
1023   prim->set_attr(DTYPE, shape_type_pair.second);
1024 
1025   if (!is_param) {
1026     send->AddPrimalAttr(PIPELINE_END, value);
1027   } else {
1028     send->AddPrimalAttr(PIPELINE_PARAM, value);
1029     send->set_user_data<OperatorInfo>(op_info);
1030     send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
1031     auto param = care_node ? care_node : parameter;
1032     send->set_user_data<AnfNode>(INPUT_PARAM, param);
1033   }
1034   send->AddPrimalAttr(MICRO, value);
1035   send->AddPrimalAttr(DEST_RANK, MakeValue(user_node_stage));
1036   auto abstract = parameter->abstract();
1037   if (care_node) {
1038     abstract = care_node->abstract();
1039   }
1040   send->set_abstract(abstract);
1041   SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, send};
1042 
1043   // for FetchSends
1044   send->set_user_data<int64_t>(DEST_RANK, std::make_shared<int64_t>(dest_rank));
1045   send->set_user_data<int64_t>(USER_NODE_STAGE, std::make_shared<int64_t>(user_node_stage));
1046   return send_out;
1047 }
1048 
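// Creates a Receive op from the peer stage, sets its abstract and tensor layout, and rewires the user node to consume it.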
1049 AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
1050                                               const AnfNodePtr &use_node, int index, int64_t user_node_stage,
1051                                               int64_t node_stage, const ValuePtr &value,
1052                                               const AnfNodePtr &graph_param) {
1053   auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
1054   int64_t recv_tag = recv_tag_map[src_rank];
1055   recv_tag_map[src_rank]++;
1056   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
1057   Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(src_rank));
1058   bool is_param = true;
1059   AnfNodePtr care_node;
1060   auto op_info_pair = GetOpInfoPair(node, graph_param, &care_node, &is_param);
1061   auto tensor_info = GetTensorInfo(op_info_pair, is_param);
1062   auto tensor_layout = tensor_info.tensor_layout();
1063   Shape slice_shape = tensor_info.slice_shape();
1064   auto shape_type_pair = GetShapeType(node, slice_shape, 0);
1065   Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
1066   Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
1067   Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1068   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1069   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
1070   std::vector<AnfNodePtr> recv_input;
1071   if (node->isa<Parameter>()) {
1072     recv_input = {node};
1073   } else {
1074     recv_input = {virtual_param_};
1075     if (enable_share_cell_ || !is_train_) {
1076       auto recv_tensor = TensorConstructUtils::CreateZerosTensor(kFloat16, {1});
1077       recv_input = {NewValueNode(recv_tensor)};
1078     } else {
1079       if (virtual_param_ == nullptr) {
1080         MS_LOG(EXCEPTION)
1081           << "For Pipeline Parallel, each stage must have at least one parameter that needs to be trained, but stage: "
1082           << stage_ << " has none.";
1083       }
1084     }
1085   }
1086   auto recv = CreateCNodeByInputsAndAttr(graph, RECEIVE, RECEIVE, recv_input, attrs);
1087   if (is_param) {
1088     recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
1089     recv->AddPrimalAttr(PIPELINE_PARAM, value);
1090     auto param = care_node ? care_node : node;
1091     recv->set_user_data<AnfNode>(INPUT_PARAM, param);
1092   } else {
1093     recv->AddPrimalAttr(PIPELINE_BEGIN, value);
1094   }
1095   recv->AddPrimalAttr(MICRO, value);
1096   recv->AddPrimalAttr(SRC_RANK, MakeValue(node_stage));
1097   auto node_abstract = node->abstract();
1098   if (node->isa<CNode>()) {
1099     auto cnode = node->cast<CNodePtr>();
1100     MS_EXCEPTION_IF_NULL(cnode);
1101     if (IsValueNode<FuncGraph>(cnode->input(0))) {
1102       auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
1103       MS_EXCEPTION_IF_NULL(output);
1104       node_abstract = output->abstract();
1105     }
1106   }
1107   MS_EXCEPTION_IF_NULL(node_abstract);
1108   recv->set_abstract(node_abstract);
1109   if (node->isa<Parameter>()) {
1110     BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
1111     auto abstract_clone = node->abstract()->Clone();
1112     MS_EXCEPTION_IF_NULL(abstract_clone);
1113     abstract_clone->set_shape(parallel_shape);
1114     node->set_abstract(abstract_clone);
1115     node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1116     auto actual_param = RefParameterToActualParameter(node);
1117     if (actual_param) {
1118       actual_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1119       auto actual_param_abstract = actual_param->abstract()->Clone();
1120       actual_param_abstract->set_shape(parallel_shape);
1121       actual_param->set_abstract(actual_param_abstract);
1122     }
1123   }
1124   recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1125   recv->set_user_data<OperatorInfo>(op_info_pair.first);
1126 
1127   // for FetchRecvs
1128   recv->set_user_data<int64_t>(SRC_RANK, std::make_shared<int64_t>(src_rank));
1129   recv->set_user_data<int64_t>(NODE_STAGE, std::make_shared<int64_t>(node_stage));
1130   recv->set_user_data<Type>(SLICE_DTYPE, shape_type_pair.second);
1131   recv->set_user_data<Shape>(SLICE_SHAPE, std::make_shared<Shape>(slice_shape));
1132 
1133   manager_->SetEdge(use_node, index, recv);
1134   return recv;
1135 }
1136 
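// Returns an already inserted Send/Receive for the same source node and peer stage, if one exists.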
1137 AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
1138                                       const std::string &tag) const {
1139   for (auto &input : out_input) {
1140     auto cnode = input->cast<CNodePtr>();
1141     if (!cnode) {
1142       continue;
1143     }
1144     if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
1145       cnode = cnode->input(2)->cast<CNodePtr>();
1146     }
1147     if (cnode->input(1) == node) {
1148       auto dest_rank_send = GetValue<int64_t>(cnode->GetPrimalAttr(tag));
1149       if (dest_rank_send == stage) {
1150         return input;
1151       }
1152     }
1153   }
1154   return nullptr;
1155 }
1156 
1157 AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) {
1158   // Skip virtual ops like Depend, Load and Cast.
1159   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimCast) ||
1160       IsPrimitiveCNode(node, prim::kPrimLoad)) {
1161     auto cnode = node->cast<CNodePtr>();
1162     MS_EXCEPTION_IF_NULL(cnode);
1163     return ActualOp(cnode->input(1));
1164   }
1165   return node;
1166 }
1167 
1168 bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) const {
1169   // ParameterGraph: graph which return a parameter
1170   MS_EXCEPTION_IF_NULL(node);
1171   CNodePtr call_node = nullptr;
1172   auto real_kernel = GetRealKernelNode(node, -1, &call_node).first;
1173   if (call_node != nullptr && real_kernel->isa<Parameter>()) {
1174     return true;
1175   }
1176   return false;
1177 }
1178 
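// Handles a sub-graph that returns a parameter: inserts the Receive on the consumer stage or the Send on the producer stage.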
1179 AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
1180                                                      int64_t user_stage, const ValuePtr &micro, size_t pos,
1181                                                      const std::vector<AnfNodePtr> &ops) {
1182   CNodePtr call_node = nullptr;
1183   auto argument = GetRealKernelNode(node, -1, &call_node).first;
1184 
1185   auto use_cnode = use_node->cast<CNodePtr>();
1186   MS_EXCEPTION_IF_NULL(use_cnode);
1187   if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
1188     MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
1189   }
1190   auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
1191   auto use_parameter_list = use_graph->parameters();
1192   auto parameter = use_parameter_list.at(pos - 1);
1193   // insert receive
1194   if (stage_ == user_stage) {
1195     auto recv = Reuse(argument, stage, ops, SRC_RANK);
1196     if (recv) {
1197       manager_->SetEdge(use_node, SizeToInt(pos), recv);
1198       return nullptr;
1199     }
1200     auto root_param = argument;
1201     if (argument->isa<Parameter>() && argument->func_graph() != root_) {
1202       root_param = GetArgumentsByParameter(argument);
1203     }
1204     (void)parameter_color_map_[root_param].insert(user_stage);
1205     auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
1206     auto recv_node = InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
1207     UpdateParameterSharedInfo(root_param, recv_node, false);
1208     return recv_node;
1209   }
1210   // insert send
1211   if (Reuse(argument, user_stage, ops, DEST_RANK)) {
1212     return nullptr;
1213   }
1214   auto send_out = InsertSend(argument, user_stage, stage_, micro);
1215   send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
1216   send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
1217   UpdateParameterSharedInfo(argument, send_out.depend, true);
1218   return send_out.depend;
1219 }
1220 
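// For a single node, inserts a Send on the producer stage and a Receive (or reuses one) on the consumer stage for every cross-stage user.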
1221 void PipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
1222                                            std::vector<AnfNodePtr> *send_ops, std::vector<AnfNodePtr> *receive_ops) {
1223   auto stage_info = node->user_data<NodeStageInfo>();
1224   auto node_users = manager_->node_users()[node];
1225   AnfNodePtr receive = nullptr;
1226   for (auto &user_pair : node_users) {
1227     auto user_node = user_pair.first;
1228     auto node_stage = stage_info->stage();
1229     auto user_stage_info = user_node->user_data<NodeStageInfo>();
1230     if (user_stage_info == nullptr) {
1231       continue;
1232     }
1233     auto user_node_stage = user_stage_info->stage();
1234     if (node_stage != stage_ && user_node_stage != stage_) {
1235       continue;
1236     }
1237     auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
1238     if (!micro) {
1239       MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
1240       micro = MakeValue(int64_t(0));
1241     }
1242     if (node_stage < user_node_stage) {
1243       if (node_stage == stage_) {
1244         if (IsParameterGraph(node)) {
1245           if (!is_train_ && !enable_share_cell_) {
1246             continue;
1247           }
1248           auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
1249                                                   IntToSize(user_pair.second), *send_ops);
1250           if (!send_depend) {
1251             continue;
1252           }
1253           (void)send_ops->insert(send_ops->cbegin(), send_depend);
1254           continue;
1255         }
1256         if (Reuse(node, user_node_stage, *send_ops, DEST_RANK)) {
1257           continue;
1258         }
1259         auto send_out = InsertSend(node, user_node_stage, node_stage, micro);
1260         MS_EXCEPTION_IF_NULL(send_out.depend);
1261         send_ops->push_back(send_out.depend);
1262         send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
1263         send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
1264       } else {
1265         if (!receive) {
1266           if (IsParameterGraph(node)) {
1267             if (!is_train_ && !enable_share_cell_) {
1268               continue;
1269             }
1270             receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
1271                                            IntToSize(user_pair.second), *receive_ops);
1272             if (!receive) {
1273               continue;
1274             }
1275             receive_ops->push_back(receive);
1276           } else {
1277             receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
1278             receive_ops->push_back(receive);
1279           }
1280         } else {
1281           manager_->SetEdge(user_node, user_pair.second, receive);
1282         }
1283       }
1284       continue;
1285     }
1286     if (node_stage > user_node_stage) {
1287       MS_LOG(EXCEPTION) << "node_stage: " << node_stage << " must not be larger than user_node_stage: " << user_node_stage;
1288     }
1289   }
1290 }
1291 
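// Traverses the graph in topological order and cuts every stage border, returning the inserted send and receive ops.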
1292 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
1293   std::vector<AnfNodePtr> send_ops;
1294   std::vector<AnfNodePtr> receive_ops;
1295   auto ret = graph->get_return();
1296   MS_EXCEPTION_IF_NULL(ret);
1297   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
1298   std::reverse(all_nodes.begin(), all_nodes.end());
1299   for (auto &node : all_nodes) {
1300     auto stage_info = node->user_data<NodeStageInfo>();
1301     if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
1302         IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
1303       continue;
1304     }
1305     // Extracted into CutBorderForNode to keep the lizard cyclomatic complexity low.
1306     CutBorderForNode(graph, node, &send_ops, &receive_ops);
1307   }
1308   RemoveMonadNode();
1309   return std::make_pair(send_ops, receive_ops);
1310 }
1311 
1312 AnfNodePtr PipelineTransformer::CreateZeroseOutput(const AnfNodePtr &node, size_t index) {
1313   auto out_shapes = GetNodeShape(node);
1314   if (out_shapes.size() <= index) {
1315     MS_LOG(EXCEPTION) << "The index is out of range: the size of output_shapes is " << out_shapes.size()
1316                       << ", but the index is " << index;
1317   }
1318   auto out_shape = out_shapes.at(index);
1319   if (std::count(out_shape.cbegin(), out_shape.cend(), DYNAMIC_DIM_VAL) > 0) {
1320     MS_LOG(EXCEPTION) << "A non-scalar loss is not supported when dynamic shape and pipeline parallel are both "
1321                          "enabled. The output shape is "
1322                       << out_shape;
1323   }
1324 
1325   // Modify the output dimension when data parallel is enabled, since only the last stage enables VirtualOutput redistribution.
1326   bool full_batch = ParallelContext::GetInstance()->full_batch();
1327   int64_t dev_num = full_batch ? 1 : g_device_manager->stage_device_num();
1328   if (dev_num == 0) {
1329     MS_LOG(EXCEPTION) << "Device num must be larger than 0, but got 0.";
1330   }
1331 
1332   if (!is_train_ && !out_shape.empty() && out_shape[0] % dev_num == 0) {
1333     out_shape[0] /= dev_num;
1334   }
1335 
1336   auto out_shape_type = GetShapeType(node, out_shape, index);
1337   auto zero_tensor = TensorConstructUtils::CreateZerosTensor(out_shape_type.second, out_shape);
1338   MS_EXCEPTION_IF_NULL(zero_tensor);
1339 
1340   auto value_node = NewValueNode(zero_tensor);
1341   MS_EXCEPTION_IF_NULL(value_node);
1342 
1343   // Build abstract from node to prevent confusion between Scalar and 0D-Tensor.
1344   auto abs = node->abstract()->Clone();
1345   MS_EXCEPTION_IF_NULL(abs);
1346   if (abs->isa<abstract::AbstractSequence>()) {
1347     auto elements = abs->cast<abstract::AbstractSequencePtr>()->elements();
1348     abs = elements.at(index)->Clone();
1349     MS_EXCEPTION_IF_NULL(abs);
1350   }
1351 
1352   abs->set_shape(std::make_shared<abstract::Shape>(out_shape));
1353   value_node->set_abstract(abs);
1354   return value_node;
1355 }
1356 
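// Build a zero-valued stand-in for the graph's real output: a MakeTuple of zero tensors when the
// real output is a MakeTuple (tuple elements that are themselves tuples get a nested zero tuple),
// otherwise a single zero tensor, or a tuple of zeros for a multi-output kernel.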
1357 AnfNodePtr PipelineTransformer::GetZeroOutputs(const FuncGraphPtr &graph) {
1358   // first: output node, second: getitem index
1359   auto real_kernel = GetRealKernelNode(graph->output(), -1);
1360   auto real_out = real_kernel.first;
1361   MS_EXCEPTION_IF_NULL(real_out);
1362   std::vector<AnfNodePtr> out_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1363   if (IsPrimitiveCNode(real_out, prim::kPrimMakeTuple)) {
1364     auto real_out_cnode = real_out->cast<CNodePtr>();
1365     for (size_t i = 1; i < real_out_cnode->size(); ++i) {
1366       auto each_out_shapes = GetNodeShape(real_out_cnode->input(i));
1367       // Handle the case where the tuple's element is itself a tuple.
1368       if (each_out_shapes.size() > 1) {
1369         auto temp_tuple = CreateTupleZeroTensor(real_out_cnode->input(i), each_out_shapes.size());
1370         (void)out_tuple_inputs.emplace_back(temp_tuple);
1371         continue;
1372       }
1373       (void)out_tuple_inputs.emplace_back(CreateZeroseOutput(real_out_cnode->input(i), 0));
1374     }
1375   }
1376   if (out_tuple_inputs.size() > INDEX_ONE) {
1377     auto out_tuple = main_graph_->NewCNode(out_tuple_inputs);
1378     SetMakeTupleAbstract(out_tuple);
1379     return out_tuple;
1380   } else {
1381     auto real_out_shapes = GetNodeShape(real_out);
1382     AnfNodePtr out_tensor;
1383     // Handle the case where the op has multiple outputs.
1384     if (real_out_shapes.size() > 1 && real_kernel.second == -1) {
1385       out_tensor = CreateTupleZeroTensor(real_out, real_out_shapes.size());
1386     } else {
1387       out_tensor = CreateZeroseOutput(real_out, 0);
1388     }
1389     return out_tensor;
1390   }
1391   return nullptr;
1392 }
1393 
1394 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfoPair(const AnfNodePtr &node,
1395                                                                    const AnfNodePtr &graph_param, AnfNodePtr *care_node,
1396                                                                    bool *is_param) {
1397   if (node->isa<Parameter>()) {
1398     return GetParameterPair(graph_param);
1399   } else {
1400     *care_node = FindPipelineCareNode(node);
1401     if ((*care_node)->isa<Parameter>()) {
1402       return GetParameterPair(*care_node);
1403     } else {
1404       *is_param = false;
1405       return GetOpInfo(*care_node);
1406     }
1407   }
1408 }
1409 
1410 void PipelineTransformer::SetNodeAbstract(const std::vector<AnfNodePtr> &nodes) {
1411   AbstractBasePtr abs;
1412   if (nodes.size() == 1) {
1413     auto cnode = nodes.front()->cast<CNodePtr>();
1414     MS_EXCEPTION_IF_NULL(cnode);
1415     abs = GetRealAbstract(cnode->input(INDEX_ONE));
1416   } else {
1417     AbstractBasePtrList abstract_list;
1418     abstract_list.resize(nodes.size());
1419     (void)std::transform(nodes.begin(), nodes.end(), abstract_list.begin(), [](const AnfNodePtr &node) {
1420       auto cnode = node->cast<CNodePtr>();
1421       MS_EXCEPTION_IF_NULL(cnode);
1422       return GetRealAbstract(cnode->input(INDEX_ONE));
1423     });
1424     abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
1425   }
1426   for (auto &user : shared_cell_users_) {
1427     user->set_abstract(abs);
1428   }
1429 }
1430 
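// Clone an existing Send border op against the root graph: allocate the next sr_tag for the
// destination rank, rebuild the Send primitive with shape/dtype taken from the input's sliced
// tensor info, and carry over the PIPELINE_PARAM / PIPELINE_END and MICRO primal attributes.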
1431 AnfNodePtr PipelineTransformer::GenNewSendFromOld(const AnfNodePtr &node, const AnfNodePtr &input,
1432                                                   const ValuePtr &value) {
1433   const auto &old = node->cast<CNodePtr>();
1434   MS_EXCEPTION_IF_NULL(old);
1435   auto old_is_pipeline_param = old->HasPrimalAttr(PIPELINE_PARAM);
1436   auto dest_rank_ptr = old->user_data<int64_t>(DEST_RANK);
1437   MS_EXCEPTION_IF_NULL(dest_rank_ptr);
1438   auto dest_rank = *dest_rank_ptr;
1439   auto send_tag = send_tag_map[dest_rank];
1440   send_tag_map[dest_rank]++;
1441   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
1442   Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(dest_rank));
1443   Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1444   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1445   OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
1446   std::vector<AnfNodePtr> send_input{input};
1447   auto send = CreateCNodeByInputsAndAttr(main_graph_, SEND, SEND, send_input, attrs);
1448   AnfNodePtr care_node;
1449   bool is_param = true;
1450   auto op_info_pair = GetOpInfoPair(input, input, &care_node, &is_param);
1451   auto tensor_info = GetTensorInfo(op_info_pair, is_param);
1452   auto op_info = op_info_pair.first;
1453   auto index = op_info_pair.second;
1454   auto slice_shape = tensor_info.slice_shape();
1455   auto shape_type_pair = GetShapeType(input, slice_shape, 0);
1456   auto prim = GetCNodePrimitive(send);
1457   prim->set_attr(SHAPE, shape_type_pair.first);
1458   prim->set_attr(DTYPE, shape_type_pair.second);
1459   if (!is_param) {
1460     if (old_is_pipeline_param) {
1461       MS_LOG(EXCEPTION) << "The old send is pipeline_param, but new send is not pipeline_param.";
1462     }
1463     send->AddPrimalAttr(PIPELINE_END, value);
1464   } else {
1465     if (!old_is_pipeline_param) {
1466       MS_LOG(EXCEPTION) << "The old send is not pipeline_param, but new send is pipeline_param.";
1467     }
1468     send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
1469     send->AddPrimalAttr(PIPELINE_PARAM, value);
1470     send->set_user_data<OperatorInfo>(op_info);
1471   }
1472   send->AddPrimalAttr(MICRO, value);
1473   auto abstract = input->abstract();
1474   if (care_node) {
1475     abstract = care_node->abstract();
1476   }
1477   send->set_abstract(abstract);
1478   return send;
1479 }
1480 
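// Re-create the Send ops in the root graph. For a pipeline parameter a single Send is generated
// from the matching call argument (or the node's original input); otherwise one Send is generated
// per shared-cell user, using the call itself or a TupleGetItem of it as the Send input.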
1481 std::vector<AnfNodePtr> PipelineTransformer::FetchSend(const AnfNodePtr &node, bool pipeline_param,
1482                                                        bool single_pipeline_end, size_t end_index) {
1483   std::vector<AnfNodePtr> depends;
1484   AnfNodePtr send_input;
1485   if (pipeline_param) {
1486     auto param = node->user_data<AnfNode>(INPUT_PARAM);
1487     MS_EXCEPTION_IF_NULL(param);
1488     auto params = shared_cell_->parameters();
1489     auto iter = std::find(params.begin(), params.end(), param);
1490     if (iter != params.end()) {
1491       auto input_pos = std::distance(params.begin(), iter) + 1;
1492       auto &front = shared_cell_users_.front();
1493       MS_EXCEPTION_IF_NULL(front);
1494       const auto &user = front->cast<CNodePtr>();
1495       MS_EXCEPTION_IF_NULL(user);
1496       send_input = user->input(input_pos);
1497     } else {
1498       const auto &cnode = node->cast<CNodePtr>();
1499       MS_EXCEPTION_IF_NULL(cnode);
1500       send_input = cnode->input(INDEX_ONE);
1501     }
1502     MS_EXCEPTION_IF_NULL(send_input);
1503     auto value = MakeValue(int64_t(0));
1504     (void)(depends.emplace_back(GenNewSendFromOld(node, send_input, value)));
1505     return depends;
1506   }
1507   for (auto &user : shared_cell_users_) {
1508     auto cuser = user->cast<CNodePtr>();
1509     MS_EXCEPTION_IF_NULL(cuser);
1510     auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
1511     MS_EXCEPTION_IF_NULL(value);
1512     send_input = single_pipeline_end ? user : CreateTupleGetItemNode(main_graph_, user, end_index);
1513     (void)(depends.emplace_back(GenNewSendFromOld(node, send_input, value)));
1514   }
1515   return depends;
1516 }
1517 
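// Rewire the outputs of a non-last stage when the shared cell is enabled: make the shared cell
// return the pipeline-end values, re-emit the Send ops in the root graph, and replace the root
// output with zero placeholders attached to the sends through a Depend node.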
1518 void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &nodes) {
1519   std::vector<AnfNodePtr> pipeline_params;
1520   std::vector<AnfNodePtr> pipeline_ends;
1521   SeparateParamBorder(nodes, true, &pipeline_params, &pipeline_ends);
1522   std::vector<AnfNodePtr> sends;
1523   SetNodeAbstract(pipeline_ends);
1524 
1525   // Create the root graph output before modifying the subgraph (shared cell).
1526   // This ordering is crucial when the subgraph's output is used directly as the root graph's output.
1527   auto zero_outputs = GetZeroOutputs(main_graph_);
1528 
1529   size_t ends_size = pipeline_ends.size();
1530   bool single_pipeline_end = ends_size == 1;
1531   if (single_pipeline_end) {
1532     auto &depend = pipeline_ends.front();
1533     const auto &cdepend = depend->cast<CNodePtr>();
1534     MS_EXCEPTION_IF_NULL(cdepend);
1535     (void)manager_->Replace(shared_cell_->output(), cdepend->input(INDEX_ONE));
1536   } else {
1537     std::vector<AnfNodePtr> rets;
1538     (void)std::transform(pipeline_ends.begin(), pipeline_ends.end(), std::back_inserter(rets),
1539                          [](const AnfNodePtr &depend) {
1540                            const auto &cdepend = depend->cast<CNodePtr>();
1541                            MS_EXCEPTION_IF_NULL(cdepend);
1542                            return cdepend->input(INDEX_ONE);
1543                          });
1544     auto out = CreateMakeTupleNode(shared_cell_, rets);
1545     (void)manager_->Replace(shared_cell_->output(), out);
1546   }
1547   for (auto &node : pipeline_params) {
1548     auto params = FetchSend(node, true, false, 0);
1549     if (is_train_) {
1550       (void)std::copy(params.begin(), params.end(), std::back_inserter(sends));
1551     }
1552   }
1553   for (size_t i = 0; i < ends_size; i++) {
1554     auto node = pipeline_ends[i];
1555     auto ends = FetchSend(node, false, single_pipeline_end, i);
1556     (void)std::copy(ends.begin(), ends.end(), std::back_inserter(sends));
1557   }
1558   auto make_tuple = CreateMakeTupleNode(main_graph_, sends);
1559   std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
1560   auto out_node = main_graph_->NewCNode(out);
1561   out_node->set_abstract(zero_outputs->abstract());
1562   (void)manager_->Replace(main_graph_->output(), out_node);
1563 }
1564 
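// Clone an existing Receive border op against the root graph: allocate the next sr_tag for the
// source rank, rebuild the Receive primitive with the recorded slice shape and dtype, and copy the
// tensor layout, operator info and pipeline-related primal attributes from the old node.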
1565 AnfNodePtr PipelineTransformer::GenNewRecvFromOld(const AnfNodePtr &node, const AnfNodePtr &input,
1566                                                   const ValuePtr &value) {
1567   auto cnode = node->cast<CNodePtr>();
1568   MS_EXCEPTION_IF_NULL(cnode);
1569   auto src_rank_ptr = cnode->user_data<int64_t>(SRC_RANK);
1570   MS_EXCEPTION_IF_NULL(src_rank_ptr);
1571   auto src_rank = *src_rank_ptr;
1572   auto recv_tag = recv_tag_map[src_rank];
1573   recv_tag_map[src_rank]++;
1574   auto dtype = node->user_data<Type>(SLICE_DTYPE);
1575   auto slice_shape = *(cnode->user_data<Shape>(SLICE_SHAPE));
1576   auto shape = GetShapeValue(slice_shape);
1577   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
1578   Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(src_rank));
1579   Attr attr_shape = std::make_pair(SHAPE, shape);
1580   Attr attr_dtype = std::make_pair(DTYPE, dtype);
1581   Attr attr_group = std::make_pair(GROUP, MakeValue(world_group_));
1582   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(world_group_));
1583   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
1584 
1585   std::vector<AnfNodePtr> recv_input = {input};
1586   auto recv = CreateCNodeByInputsAndAttr(main_graph_, RECEIVE, RECEIVE, recv_input, attrs);
1587   auto tensor_layout = node->user_data<TensorLayout>();
1588   if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1589     auto abstract_clone = node->abstract()->Clone();
1590     MS_EXCEPTION_IF_NULL(abstract_clone);
1591     recv->set_user_data<AnfNode>(PIPELINE_PARAM, recv_input[INDEX_ZERO]);
1592     recv->AddPrimalAttr(PIPELINE_PARAM, value);
1593     recv_input[INDEX_ZERO]->set_abstract(abstract_clone);
1594     recv_input[INDEX_ZERO]->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(*tensor_layout));
1595   } else {
1596     recv->AddPrimalAttr(PIPELINE_BEGIN, value);
1597   }
1598   auto abstract_clone = node->abstract()->Clone();
1599   MS_EXCEPTION_IF_NULL(abstract_clone);
1600   recv->set_abstract(abstract_clone);
1601 
1602   recv->AddPrimalAttr(MICRO, value);
1603   recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(*tensor_layout));
1604   recv->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
1605   return recv;
1606 }
1607 
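// Re-create the Receive ops in the root graph. A pipeline-parameter Receive is built once and, when
// its input is a top-level call argument, wired into every shared-cell user; other Receives are
// built once per shared-cell user with a dummy input (a zeros tensor or the virtual parameter).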
1608 std::vector<AnfNodePtr> PipelineTransformer::FetchRecv(const AnfNodePtr &node, bool pipeline_param) {
1609   std::vector<AnfNodePtr> recvs;
1610   AnfNodePtr recv_input;
1611   AnfNodePtr recv;
1612   if (pipeline_param) {
1613     auto value = MakeValue(int64_t(0));
1614     auto param = node->user_data<AnfNode>(INPUT_PARAM);
1615     MS_EXCEPTION_IF_NULL(param);
1616     auto &front = shared_cell_users_.front();
1617     MS_EXCEPTION_IF_NULL(front);
1618     const auto &user = front->cast<CNodePtr>();
1619     MS_EXCEPTION_IF_NULL(user);
1620     auto params = shared_cell_->parameters();
1621     auto user_inputs = user->inputs();
1622     auto iter = std::find(user_inputs.begin(), user_inputs.end(), param);
1623     if (iter != user_inputs.end()) {
1624       auto input_pos = std::distance(user_inputs.begin(), iter);
1625       auto argu = params.at(input_pos - 1);
1626       manager_->SetEdge(node, 1, argu);
1627       node->set_user_data<AnfNode>(INPUT_PARAM, argu);
1628       recv_input = user->input(input_pos);
1629       recv = GenNewRecvFromOld(node, recv_input, value);
1630       for (auto &share_user : shared_cell_users_) {
1631         if (is_train_) {
1632           manager_->SetEdge(share_user, input_pos, recv);
1633         } else {
1634           manager_->SetEdge(share_user, input_pos, recv_input);
1635         }
1636       }
1637       node->set_user_data<bool>(ORIGIN_INPUT_IS_PARAM, std::make_shared<bool>(true));
1638     } else {
1639       const auto &cnode = node->cast<CNodePtr>();
1640       MS_EXCEPTION_IF_NULL(cnode);
1641       recv_input = cnode->input(INDEX_ONE);
1642       recv = GenNewRecvFromOld(node, recv_input, value);
1643     }
1644     (void)(recvs.emplace_back(recv));
1645     return recvs;
1646   }
1647   for (auto &user : shared_cell_users_) {
1648     auto cuser = user->cast<CNodePtr>();
1649     MS_EXCEPTION_IF_NULL(cuser);
1650     auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
1651     MS_EXCEPTION_IF_NULL(value);
1652     if (enable_share_cell_ || !is_train_) {
1653       auto recv_tensor = TensorConstructUtils::CreateZerosTensor(kFloat16, {1});
1654       recv = GenNewRecvFromOld(node, NewValueNode(recv_tensor), value);
1655     } else {
1656       recv = GenNewRecvFromOld(node, virtual_param_, value);
1657     }
1658     (void)(recvs.emplace_back(recv));
1659   }
1660   return recvs;
1661 }
1662 
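// Rebuild the shared cell's parameter list, keeping only parameters still reachable from its return
// (monad parameters last) plus the newly added ones, and rebuild every call site's argument list to
// match, appending the reserved inputs, the fetched pipeline-begin Receives and the monad inputs.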
1663 void PipelineTransformer::ResetSharedCellParamAndArgu(
1664   const std::vector<std::vector<AnfNodePtr>> &pipeline_begins_fetched,
1665   const std::vector<AnfNodePtr> &newly_added_params, const std::vector<AnfNodePtr> &reserved_inputs) {
1666   // set shared_cell_ parameters, and call_input
1667   auto params = shared_cell_->parameters();
1668   auto ret = shared_cell_->get_return();
1669   MS_EXCEPTION_IF_NULL(ret);
1670   std::vector<AnfNodePtr> searched_params;
1671   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
1672   for (auto &node : all_nodes) {
1673     if (node->isa<Parameter>()) {
1674       searched_params.push_back(node);
1675     }
1676   }
1677   std::set<size_t> reserved_param_index;
1678   std::vector<AnfNodePtr> new_params;
1679   std::vector<AnfNodePtr> monad_params;
1680   // set shared_cell_ parameters
1681   for (size_t i = 0; i < params.size(); i++) {
1682     auto param = params[i];
1683     if (std::find(searched_params.begin(), searched_params.end(), param) == searched_params.end()) {
1684       continue;
1685     }
1686     if (HasAbstractMonad(param)) {
1687       monad_params.push_back(param);
1688     } else {
1689       new_params.push_back(param);
1690     }
1691     (void)(reserved_param_index.insert(i));
1692   }
1693   (void)(new_params.insert(new_params.end(), newly_added_params.begin(), newly_added_params.end()));
1694   (void)(new_params.insert(new_params.end(), monad_params.begin(), monad_params.end()));
1695   MS_LOG(DEBUG) << "The shared cell origin params size is " << params.size() << ", new params size is "
1696                 << new_params.size();
1697   manager_->SetParameters(shared_cell_, new_params);
1698   shared_cell_->set_fv_param_count(new_params.size());
1699   // set call inputs
1700   size_t user_index = 0;
1701   for (auto &user : shared_cell_users_) {
1702     auto cuser = user->cast<CNodePtr>();
1703     MS_EXCEPTION_IF_NULL(cuser);
1704     const auto &old_inputs = cuser->inputs();
1705     std::vector<AnfNodePtr> new_inputs{old_inputs.front()};
1706     std::vector<AnfNodePtr> monad_inputs;
1707     for (size_t i = 1; i < old_inputs.size(); i++) {
1708       if (reserved_param_index.find(i - 1) == reserved_param_index.end()) {
1709         continue;
1710       }
1711       auto old_input = old_inputs[i];
1712       if (HasAbstractMonad(old_input)) {
1713         monad_inputs.push_back(old_input);
1714       } else {
1715         new_inputs.push_back(old_input);
1716       }
1717     }
1718     auto newly_added_inputs = reserved_inputs;
1719     auto begins = pipeline_begins_fetched.at(user_index);
1720     (void)(newly_added_inputs.insert(newly_added_inputs.end(), begins.begin(), begins.end()));
1721     (void)(newly_added_inputs.insert(newly_added_inputs.end(), monad_inputs.begin(), monad_inputs.end()));
1722     (void)(new_inputs.insert(new_inputs.end(), newly_added_inputs.begin(), newly_added_inputs.end()));
1723     auto new_call = main_graph_->NewCNode(new_inputs);
1724     new_call->set_attrs(cuser->attrs());
1725     new_call->set_primal_attrs(cuser->primal_attrs());
1726     new_call->set_abstract(cuser->abstract());
1727     (void)manager_->Replace(user, new_call);
1728     user_index++;
1729   }
1730 }
1731 
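// Rewire the inputs of a stage: fetch the Receive ops for pipeline parameters and pipeline-begin
// borders, relink their users either to the original parameter or to a newly created shared-cell
// parameter, then reset the shared cell's parameters and call arguments accordingly.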
1732 void PipelineTransformer::HandleGraphInputs(const std::vector<AnfNodePtr> &recv_ops) {
1733   std::vector<AnfNodePtr> pipeline_params;
1734   std::vector<AnfNodePtr> pipeline_begins;
1735   SeparateParamBorder(recv_ops, false, &pipeline_params, &pipeline_begins);
1736 
1737   // reserved inputs
1738   std::vector<AnfNodePtr> reserved_inputs;
1739   // pipeline_param whose input is a parameter
1740   std::vector<AnfNodePtr> pipeline_params_with_param_input;
1741   std::vector<AnfNodePtr> need_link_to_new_param;
1742 
1743   for (auto &node : pipeline_params) {
1744     auto recvs = FetchRecv(node, true);
1745     auto cnode = node->cast<CNodePtr>();
1746     MS_EXCEPTION_IF_NULL(cnode);
1747     if (cnode->has_user_data(ORIGIN_INPUT_IS_PARAM)) {
1748       pipeline_params_with_param_input.push_back(node);
1749     } else {
1750       (void)(reserved_inputs.insert(reserved_inputs.end(), recvs.begin(), recvs.end()));
1751       need_link_to_new_param.push_back(node);
1752     }
1753   }
1754   (void)(need_link_to_new_param.insert(need_link_to_new_param.end(), pipeline_begins.begin(), pipeline_begins.end()));
1755 
1756   size_t begin_size = pipeline_begins.size();
1757   // The outer dimension corresponds to shared_cell users,
1758   // the inner dimension corresponds to recvs. For example:
1759   // user0: recv0_0, recv0_1
1760   // user1: recv1_0, recv1_1
1761   size_t shared_cell_users_size = shared_cell_users_.size();
1762   std::vector<std::vector<AnfNodePtr>> pipeline_begins_fetched(shared_cell_users_size, std::vector<AnfNodePtr>());
1763   for (size_t i = 0; i < begin_size; i++) {
1764     auto node = pipeline_begins[i];
1765     auto begins = FetchRecv(node, false);
1766     for (size_t j = 0; j < shared_cell_users_size; j++) {
1767       pipeline_begins_fetched[j].push_back(begins.at(j));
1768     }
1769   }
1770   auto &node_users_map = manager_->node_users();
1771   // relink pipeline_param_with_param_input's users to its input
1772   for (const auto &param : pipeline_params_with_param_input) {
1773     const auto &users = node_users_map[param];
1774     auto input = param->user_data<AnfNode>(INPUT_PARAM);
1775     MS_EXCEPTION_IF_NULL(input);
1776     for (const auto &user : users) {
1777       manager_->SetEdge(user.first, user.second, input);
1778     }
1779   }
1780 
1781   std::vector<AnfNodePtr> newly_added_params;
1782   // relink pipeline_param_without_param_input and pipeline_begins's users to new parameter
1783   for (const auto &node : need_link_to_new_param) {
1784     auto param = std::make_shared<Parameter>(shared_cell_);
1785     param->set_abstract(node->abstract()->Clone());
1786     newly_added_params.push_back(param);
1787     const auto &users = node_users_map[node];
1788     for (const auto &user : users) {
1789       manager_->SetEdge(user.first, user.second, param);
1790     }
1791   }
1792   ResetSharedCellParamAndArgu(pipeline_begins_fetched, newly_added_params, reserved_inputs);
1793 }
1794 
1795 AnfNodePtr PipelineTransformer::CreateTupleZeroTensor(const AnfNodePtr &node, size_t index) {
1796   std::vector<AnfNodePtr> temp_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1797   auto out_shapes = GetNodeShape(node);
1798   for (size_t ele = 0; ele < out_shapes.size(); ++ele) {
1799     temp_tuple_inputs.emplace_back(CreateZeroseOutput(node, ele));
1800   }
1801   auto temp_tuple = main_graph_->NewCNode(temp_tuple_inputs);
1802   SetMakeTupleAbstract(temp_tuple);
1803   return temp_tuple;
1804 }
1805 
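// Entry point of the border-cutting phase: handle shared parameters, cut stage borders on the
// working graph, and then, depending on whether the shared cell is enabled and whether this is the
// last stage, patch the graph outputs and inputs with the regenerated Send/Receive ops.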
1806 void PipelineTransformer::CutGraph() {
1807   world_group_ = GetWorldGroup();
1808   auto send_recv_shared_param = HandleSharedParameter();
1809   auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
1810   MS_EXCEPTION_IF_NULL(graph);
1811   auto send_recv_cut_border = CutBorder(graph);
1812   std::vector<AnfNodePtr> send_ops;
1813 
1814   (void)(send_ops.insert(send_ops.end(), send_recv_shared_param.first.begin(), send_recv_shared_param.first.end()));
1815   (void)(send_ops.insert(send_ops.end(), send_recv_cut_border.first.begin(), send_recv_cut_border.first.end()));
1816   if (IsLastStage() && !enable_share_cell_) {
1817     return;
1818   }
1819   if (!send_ops.empty()) {
1820     type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
1821     shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
1822   }
1823   if (!enable_share_cell_) {
1824     auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);
1825     auto zero_outputs = GetZeroOutputs(main_graph_);
1826     std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
1827     auto out_node = main_graph_->NewCNode(out);
1828     (void)manager_->Replace(main_graph_->output(), out_node);
1829     return;
1830   }
1831   if (!IsLastStage()) {
1832     HandleGraphOutputs(send_ops);
1833   }
1834   std::vector<AnfNodePtr> recv_ops;
1835 
1836   (void)(recv_ops.insert(recv_ops.end(), send_recv_shared_param.second.begin(), send_recv_shared_param.second.end()));
1837   (void)(recv_ops.insert(recv_ops.end(), send_recv_cut_border.second.begin(), send_recv_cut_border.second.end()));
1838   HandleGraphInputs(recv_ops);
1839 }
1840 
1841 void PipelineTransformer::ElimGraphStage() {
1842   for (auto &fg : manager_->func_graphs()) {
1843     fg->set_stage(-1);
1844     fg->set_segment(-1);
1845   }
1846 }
1847 
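// Recursively detach a redundant node from its users: UpdateState users get a fresh U/IO monad
// input, MakeTuple/MakeList users are recorded in make_tuple_map for a later batched rewrite, and
// all other users are traversed transitively.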
1848 void PipelineTransformer::RedundancyNode(const AnfNodePtr &node,
1849                                          mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map) {
1850   auto node_users = manager_->node_users()[node];
1851   for (auto &node_user_pair : node_users) {
1852     auto cnode = node_user_pair.first->cast<CNodePtr>();
1853     // node->UpdateState: replace the node with a fresh U/IO monad.
1854     auto fg = cnode->func_graph();
1855     MS_EXCEPTION_IF_NULL(fg);
1856     if (fg->stage() != -1 && fg != main_graph_) {
1857       continue;
1858     }
1859     if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
1860       auto abs = cnode->abstract();
1861       MS_EXCEPTION_IF_NULL(abs);
1862       auto monad_node = NewValueNode(kUMonad);
1863       if (abs->isa<abstract::AbstractIOMonad>()) {
1864         monad_node = NewValueNode(kIOMonad);
1865       }
1866       manager_->SetEdge(cnode, node_user_pair.second, monad_node);
1867       continue;
1868     }
1869     // node->make_tuple: record it in a map; the recorded entries are deleted together later.
1870     if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
1871       if (make_tuple_map->find(cnode) == (*make_tuple_map).end()) {
1872         (*make_tuple_map)[cnode] = {node};
1873       } else {
1874         (*make_tuple_map)[cnode].push_back(node);
1875       }
1876     } else {
1877       RedundancyNode(node_user_pair.first, make_tuple_map);
1878     }
1879   }
1880 }
1881 
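// A parameter with a default value is redundant on this stage when its stage set (or, for a cloned
// parameter, the stage set of the non-cloned parameter it was derived from) does not contain the
// current stage.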
1882 bool PipelineTransformer::IsRedundancyParameter(const AnfNodePtr &parameter,
1883                                                 const std::vector<AnfNodePtr> &non_cloned_parameters) {
1884   // A redundancy parameter is another stage's parameter, including its corresponding cloned parameters.
1885   auto param_ptr = parameter->cast<ParameterPtr>();
1886   MS_EXCEPTION_IF_NULL(param_ptr);
1887   if (!param_ptr->has_default()) {
1888     return false;
1889   }
1890   std::set<int64_t> stage_set;
1891   if (!ParameterIsCloned(parameter)) {
1892     stage_set = parameter_color_map_.at(parameter);
1893   } else {
1894     auto parameters = root_->parameters();
1895     auto param_name = param_ptr->name();
1896     auto non_clone_name = param_name.substr(param_name.find_first_of('.') + 1);
1897     for (auto &param : non_cloned_parameters) {
1898       auto non_cloned_param = param->cast<ParameterPtr>();
1899       if (non_clone_name != non_cloned_param->name()) {
1900         continue;
1901       }
1902       stage_set = parameter_color_map_.at(param);
1903       break;
1904     }
1905   }
1906   if (stage_set.empty()) {
1907     return false;
1908   }
1909   return stage_set.count(stage_) == 0;
1910 }
1911 
1912 bool PipelineTransformer::HasNoUpdateParameter() {
1913   auto parameters = root_->parameters();
1914   for (auto &parameter : parameters) {
1915     if (ParameterIsCloned(parameter)) {
1916       continue;
1917     }
1918     auto param_info = parameter->cast<ParameterPtr>()->param_info();
1919     if (!param_info) {
1920       continue;
1921     }
1922     auto stage_set = parameter_color_map_.at(parameter);
1923     auto requires_grad = param_info->requires_grad();
1924     if (requires_grad && stage_set.count(stage_)) {
1925       return false;
1926     }
1927   }
1928   return true;
1929 }
1930 
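// If training and no parameter on this stage requires an update, flag the root graph as NO_UPDATE
// and insert Depend nodes so that the gradient output and the input of NPUGetFloatStatusV2 stay
// ordered with respect to the (otherwise unused) gradients.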
1931 void PipelineTransformer::FreezeGradient() {
1932   auto node_users_map = manager_->node_users();
1933   if (HasNoUpdateParameter() && is_train_) {
1934     root_->set_flag(NO_UPDATE, true);
1935     auto nodes = root_->nodes();
1936     for (auto &node : nodes) {
1937       if (!IsPrimitiveCNode(node, prim::kPrimJ)) {
1938         continue;
1939       }
1940       auto node_users = node_users_map.at(node);
1941       auto grad_users = node_users_map.at(node_users.front().first);
1942       for (auto &grad_user : grad_users) {
1943         auto user_node = grad_user.first->cast<CNodePtr>();
1944         if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
1945           continue;
1946         }
1947         auto index = GetTupleGetItemIndex(user_node);
1948         if (index != 1) {
1949           continue;
1950         }
1951         auto temp = node_users_map.at(user_node).front().first;
1952         auto out = root_->output();
1953         std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), out, temp};
1954         auto new_node = root_->NewCNode(depend_input);
1955         manager_->Replace(out, new_node);
1956         break;
1957       }
1958       break;
1959     }
1960     for (auto &node : nodes) {
1961       if (!IsPrimitiveCNode(node, prim::kPrimNPUGetFloatStatusV2)) {
1962         continue;
1963       }
1964       auto cnode = node->cast<CNodePtr>();
1965       auto out_cnode = root_->output()->cast<CNodePtr>();
1966       auto grads = out_cnode->input(INDEX_TWO);
1967       std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), cnode->input(1), grads};
1968       auto new_node = root_->NewCNode(depend_input);
1969       new_node->set_abstract(cnode->input(1)->abstract());
1970       manager_->Replace(cnode->input(1), new_node);
1971       break;
1972     }
1973   }
1974 }
1975 
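// Detach parameters that are redundant on this stage from their users and rebuild the affected
// MakeTuple nodes, substituting zero tensors when the root graph is flagged NO_UPDATE and the tuple
// feeds an AddN.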
1976 void PipelineTransformer::ElimParameter() {
1977   auto parameters = root_->parameters();
1978   mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> make_tuple_map;
1979   std::vector<AnfNodePtr> non_cloned_parameters;
1980   FreezeGradient();
1981   auto node_users_map = manager_->node_users();
1982   for (auto &parameter : parameters) {
1983     if (ParameterIsCloned(parameter)) {
1984       continue;
1985     }
1986     non_cloned_parameters.push_back(parameter);
1987   }
1988   for (auto &parameter : parameters) {
1989     if (!IsRedundancyParameter(parameter, non_cloned_parameters)) {
1990       continue;
1991     }
1992     MS_LOG(INFO) << "Parameter:" << parameter->DebugString() << " is Redundancy.";
1993     RedundancyNode(parameter, &make_tuple_map);
1994   }
1995   for (auto &temp : make_tuple_map) {
1996     auto make_tuple = temp.first;
1997     auto fg = make_tuple->func_graph();
1998     MS_EXCEPTION_IF_NULL(fg);
1999     auto remove_vector = temp.second;
2000     if (remove_vector.empty()) {
2001       continue;
2002     }
2003     auto make_tuple_user = node_users_map.at(make_tuple).front().first;
2004     auto make_tuple_inputs = make_tuple->inputs();
2005     std::vector<AnfNodePtr> new_inputs;
2006     for (auto &input : make_tuple_inputs) {
2007       if (std::find(remove_vector.begin(), remove_vector.end(), input) == remove_vector.end()) {
2008         new_inputs.push_back(input);
2009         continue;
2010       }
2011       if (root_->has_flag(NO_UPDATE) && IsPrimitiveCNode(make_tuple_user, prim::kPrimAddN)) {
2012         new_inputs.push_back(CreateZeroseOutput(input, 0));
2013       }
2014     }
2015     auto new_make_tuple = fg->NewCNode(new_inputs);
2016     (void)manager_->Replace(make_tuple, new_make_tuple);
2017   }
2018 }
2019 
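// After eliminating redundant parameters, keep only those that still have users or carry no default
// value, and shrink the root graph's parameter list and fv_param_count accordingly.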
2020 void PipelineTransformer::ModifyParameterList() {
2021   ElimParameter();
2022   auto parameters = root_->parameters();
2023   std::vector<AnfNodePtr> parameter_list;
2024   for (auto &parameter : parameters) {
2025     auto param = parameter->cast<ParameterPtr>();
2026     MS_EXCEPTION_IF_NULL(param);
2027     if (!manager_->node_users()[parameter].empty() || !param->has_default()) {
2028       parameter_list.push_back(parameter);
2029     }
2030   }
2031   auto del_num = parameters.size() - parameter_list.size();
2032   root_->set_fv_param_count(root_->fv_param_count() - del_num);
2033   manager_->SetParameters(root_, parameter_list);
2034 }
2035 }  // namespace parallel
2036 }  // namespace mindspore
2037