• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/other_ops.h"
26 #include "mindspore/core/ops/nn_ops.h"
27 #include "mindspore/core/ops/array_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "mindspore/core/ops/arithmetic_ops.h"
30 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
31 #include "frontend/parallel/ops_info/ops_utils.h"
32 #include "frontend/parallel/group_manager.h"
33 #include "frontend/parallel/parameter_manager.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "frontend/parallel/step_parallel.h"
36 #include "frontend/parallel/node_check.h"
37 #include "frontend/parallel/graph_util/node_info.h"
38 #include "frontend/parallel/graph_util/graph_info.h"
39 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
40 #include "frontend/parallel/step_parallel_utils.h"
41 #include "frontend/parallel/graph_util/graph_splitter.h"
42 #include "ir/anf.h"
43 #include "ir/graph_utils.h"
44 #include "ir/func_graph_cloner.h"
45 #include "include/common/utils/comm_manager.h"
46 #include "utils/ms_context.h"
47 #include "utils/tensor_construct_utils.h"
48 #include "mindspore/core/utils/parallel_node_check.h"
49 
50 namespace mindspore {
51 namespace parallel {
GetRealAbstract(const AnfNodePtr & node)52 static AbstractBasePtr GetRealAbstract(const AnfNodePtr &node) {
53   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
54     auto &input = node->cast<CNodePtr>()->input(1);
55     MS_EXCEPTION_IF_NULL(input);
56     return input->abstract();
57   }
58   return node->abstract();
59 }
60 
MainGraph()61 bool PipelineInterleave::MainGraph() {
62   bool find_main_graph = false;
63   for (auto &fg : manager_->func_graphs()) {
64     for (auto &node : fg->nodes()) {
65       if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
66         main_graph_ = fg;
67         main_graph_->set_flag(MAIN_GRAPH, true);
68         virtual_dataset_ = node;
69         find_main_graph = true;
70         break;
71       }
72     }
73     if (find_main_graph) {
74       break;
75     }
76   }
77   if (!find_main_graph) {
78     MS_LOG(WARNING) << "Can't find main graph, possible reason is can't find virtual dataset.";
79     return false;
80   }
81   auto value_nodes = main_graph_->value_nodes();
82   for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
83     auto node = (*value_pair).first;
84     if (!IsValueNode<FuncGraph>(node)) {
85       continue;
86     }
87     auto graph = GetValueNode<FuncGraphPtr>(node);
88     MS_EXCEPTION_IF_NULL(graph);
89     if (!graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
90       continue;
91     }
92     shared_cell_ = graph;
93     break;
94   }
95   if (!shared_cell_) {
96     MS_LOG(ERROR) << "Pipeline parallel now only support shared_cell.";
97     auto parallel_context = parallel::ParallelContext::GetInstance();
98     MS_EXCEPTION_IF_NULL(parallel_context);
99     auto is_pp_interleave = parallel_context->pipeline_interleave();
100     if (is_pp_interleave) {
101       MS_LOG(EXCEPTION) << "Using pipeline parallel with interleave, should enable lazy_inline.";
102     }
103     return false;
104   }
105   return true;
106 }
107 
CreateSendReceiveGroup()108 void PipelineInterleave::CreateSendReceiveGroup() {
109   MS_EXCEPTION_IF_NULL(g_device_manager);
110   auto rank_list = g_device_manager->GetDeviceListBetweenStage();
111   auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
112   Group forward_send_group;
113   if (g_device_manager->CreateGroup(rank_list, &forward_send_group) != SUCCESS) {
114     MS_LOG(EXCEPTION) << "Create forward Send communication group failed, the rank list is: " << rank_list;
115   }
116   group_.emplace_back(forward_send_group.name());
117 
118   Group backward_send_group;
119   auto backward_send_group_name = forward_send_group.name() + BACKWARD;
120   if (g_device_manager->CreateGroup(backward_send_group_name, dev_list, &backward_send_group) != SUCCESS) {
121     MS_LOG(EXCEPTION) << "Create backward Send communication group failed, the rank list is: " << rank_list;
122   }
123   group_.emplace_back(backward_send_group_name);
124 
125   Group forward_recv_group;
126   auto forward_recv_group_name = forward_send_group.name() + RECEIVE;
127   if (g_device_manager->CreateGroup(forward_recv_group_name, dev_list, &forward_recv_group) != SUCCESS) {
128     MS_LOG(EXCEPTION) << "Create forward Receive communication group failed, the rank list is: " << rank_list;
129   }
130   group_.emplace_back(forward_recv_group_name);
131 
132   Group backward_recv_group;
133   auto backward_recv_group_name = forward_recv_group_name + BACKWARD;
134   if (g_device_manager->CreateGroup(backward_recv_group_name, dev_list, &backward_recv_group) != SUCCESS) {
135     MS_LOG(EXCEPTION) << "Create backward Receive communication group failed, the rank list is: " << rank_list;
136   }
137   group_.emplace_back(backward_recv_group_name);
138 }
139 
SetMicroBatch(const AnfNodePtr & node,int64_t micro_size,size_t batch_axis) const140 ValuePtr PipelineInterleave::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const {
141   if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
142     MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
143   }
144   auto cnode = node->cast<CNodePtr>();
145 
146   int64_t micro = 0;
147   auto value = GetValueNode(cnode->input(2));
148   if (value != nullptr) {
149     auto tuple = GetValue<std::vector<int64_t>>(value);  // begin
150     auto input_tmp = GetNodeShape(cnode->input(1));
151     auto input_shape = input_tmp.at(0);
152     auto slice_batch_size = input_shape.at(batch_axis);  // betch shape
153     if (slice_batch_size == 0) {
154       MS_LOG(EXCEPTION) << "slice_batch_size should be a positive integer, but got " << slice_batch_size;
155     }
156     micro = tuple.at(batch_axis) * micro_size / slice_batch_size;  // micro-index
157   } else {
158     // dynamic shape
159     // if micro is not 1: stridedslice --> maketuple --> scalarmul --> micro
160     // if micro is 1: stridedslice --> maketuple --> scalarfloordiv
161     if (!IsPrimitiveCNode(cnode->input(2), prim::kPrimMakeTuple)) {
162       MS_LOG(EXCEPTION) << "The begin of stridedslice is not constant value, and not make tuple";
163     }
164     auto make_tuple_cnode = cnode->input(2)->cast<CNodePtr>();
165     if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarMul)) {
166       auto scalar_mul_cnode = make_tuple_cnode->input(1)->cast<CNodePtr>();
167       auto mul_value = GetValueNode(scalar_mul_cnode->input(2));
168       micro = GetValue<int64_t>(mul_value);
169     } else if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarFloorDiv)) {
170       micro = 1;
171     } else {
172       MS_LOG(EXCEPTION) << "Can not find the micro info, the input op of make tuple is "
173                         << GetCNodePrimitive(make_tuple_cnode->input(1))->name();
174     }
175   }
176 
177   cnode->AddPrimalAttr(MICRO, MakeValue(micro));
178   cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
179   int64_t seg = 0;
180   cnode->AddPrimalAttr(SEGMENT, MakeValue(seg));
181   return MakeValue(micro);
182 }
183 
Init()184 void PipelineInterleave::Init() {
185   auto ms_context = MsContext::GetInstance();
186   MS_EXCEPTION_IF_NULL(ms_context);
187   world_group_ = GetWorldGroup();
188   uint32_t world_rank_size = 0;
189   global_rank_ = parallel::ParallelContext::GetInstance()->global_rank();
190   uint32_t rank_id = 0;
191   if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
192     if (!CommManager::GetInstance().GetRankID(world_group_, &rank_id)) {
193       MS_LOG(EXCEPTION) << "Get rank id failed.";
194     }
195     global_rank_ = UintToInt(rank_id);
196   }
197   int64_t device_num = 0;
198   auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
199   if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
200     if (!CommManager::GetInstance().GetRankSize(world_group_, &world_rank_size)) {
201       MS_LOG(EXCEPTION) << "Get rank size failed";
202     }
203     device_num = UintToInt(world_rank_size);
204     MS_LOG(INFO) << "Get device num from communication model, the device num is  " << device_num;
205   } else {
206     device_num = parallel::ParallelContext::GetInstance()->device_num();
207   }
208   per_stage_rank_num_ = device_num / stage_num;
209   return;
210 }
211 
GetBatchAxisForInput(const AnfNodeIndexSet & input_node_users) const212 size_t PipelineInterleave::GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const {
213   Shapes inputs_tuple;
214   for (const auto &input_node_user : input_node_users) {
215     auto node = input_node_user.first;
216     if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
217       return 0;  // simply return 0 when dynamic shape
218     }
219     auto cnode = node->cast<CNodePtr>();
220     auto value = GetValueNode(cnode->input(2));
221     if (value == nullptr) {
222       return 0;  // simply return 0 when dynamic shape
223     }
224     auto tuple = GetValue<std::vector<int64_t>>(value);
225     inputs_tuple.push_back(tuple);
226   }
227   size_t batch_axis = 0;
228   size_t batch_axis_count = 0;
229   size_t input_dim = inputs_tuple.at(0).size();
230   size_t micro_num = inputs_tuple.size();
231   for (size_t axis = 0; axis < input_dim; ++axis) {
232     for (size_t i = 1; i < micro_num; ++i) {
233       if (inputs_tuple[i][axis] != inputs_tuple[i - 1][axis]) {
234         batch_axis = axis;
235         ++batch_axis_count;
236         break;
237       }
238     }
239   }
240   if (batch_axis_count != kSizeOne) {
241     MS_LOG(EXCEPTION)
242       << "For pipeline parallelism, micro_size partitioning of the input along a certain dimension is and "
243       << "is only allowed, but it is found that " << batch_axis_count << " to be partitioned.";
244   }
245   return batch_axis;
246 }
247 
LabelMicroBatch()248 void PipelineInterleave::LabelMicroBatch() {
249   if (!is_train_) {
250     return;
251   }
252   MS_EXCEPTION_IF_NULL(virtual_dataset_);
253   auto node_user_map = manager_->node_users();
254   auto node_users = node_user_map[virtual_dataset_];
255   for (auto &node_user : node_users) {
256     if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
257       auto data_users = manager_->node_users()[node_user.first];
258       auto node_first = data_users.front().first;
259       if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice) && !IsPrimitiveCNode(node_first, prim::kPrimShape)) {
260         data_users.clear();
261         data_users = node_user_map[node_first];
262       }
263       auto micro_size = int64_t(MicroSize(data_users));
264       micro_size_ = micro_size;
265       auto batch_axis = GetBatchAxisForInput(data_users);
266       MS_LOG(INFO) << "For the "
267                    << GetSerialNumberString(
268                         GetValue<int64_t>(GetValueNode(node_user.first->cast<CNodePtr>()->input(kIndex2))))
269                    << "input, batch axis is " << batch_axis << ", micro size is : " << micro_size;
270       for (auto &data_user : data_users) {
271         if (!IsPrimitiveCNode(data_user.first, prim::kPrimStridedSlice)) {
272           continue;
273         }
274         auto micro = SetMicroBatch(data_user.first, micro_size, batch_axis);
275         SetStridedSliceStrategy(data_user.first);
276         auto cnode = data_user.first->cast<CNodePtr>();
277         BroadCastMicroBatch(cnode, &node_user_map, micro, 0);
278       }
279     }
280   }
281 }
282 
LabelGenMaskFusion()283 void PipelineInterleave::LabelGenMaskFusion() {
284   auto fgs = manager_->func_graphs();
285   int64_t fusion_id = 0;
286   for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
287     if (*fg == root_ || *fg == main_graph_) {
288       continue;
289     }
290     auto stage = (*fg)->stage();
291     if (stage != -1 && stage != stage_) {
292       continue;
293     }
294     auto nodes = (*fg)->nodes();
295     for (auto node = nodes.cbegin(); node != nodes.cend(); ++node) {
296       if (!IsPrimitiveCNode(*node, prim::kPrimDropoutGenMask) && !IsPrimitiveCNode(*node, prim::kPrimDropoutDoMaskV3) &&
297           !IsPrimitiveCNode(*node, prim::kPrimDropout)) {
298         continue;
299       }
300       auto cnode = (*node)->cast<CNodePtr>();
301       MS_EXCEPTION_IF_NULL(cnode);
302       cnode->AddPrimalAttr(kAttrFusion, MakeValue(fusion_id));
303       fusion_id += 1;
304     }
305   }
306 }
307 
Coloring()308 void PipelineInterleave::Coloring() {
309   auto need_coloring = true;
310   std::set<int64_t> stage_set;
311   if (!IsTraining(manager_)) {
312     is_train_ = false;
313   }
314   while (need_coloring) {
315     need_coloring = false;
316     for (auto &fg : manager_->func_graphs()) {
317       if (fg == root_ && is_train_) {
318         continue;
319       }
320       auto value_nodes = fg->value_nodes();
321       for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
322         auto node = (*value_pair).first;
323         if (!IsValueNode<FuncGraph>(node)) {
324           continue;
325         }
326         auto graph = GetValueNode<FuncGraphPtr>(node);
327         if (graph->stage() == -1) {
328           continue;
329         }
330         (void)stage_set.insert(graph->stage());
331         auto node_users = manager_->node_users()[node];
332         for (auto &user_pair : node_users) {
333           auto user_node = user_pair.first->cast<CNodePtr>();
334           user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
335           auto user_node_graph = user_node->func_graph();
336           if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
337             user_node_graph->set_stage(graph->stage());
338             need_coloring = true;
339           }
340         }
341       }
342     }
343   }
344   MS_EXCEPTION_IF_NULL(g_device_manager);
345   auto stage_num = g_device_manager->stage_num();
346   if (SizeToLong(stage_set.size()) != stage_num) {
347     MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " which is not equal to stage used: " << stage_set.size();
348   }
349 }
350 
BroadCastColoring()351 void PipelineInterleave::BroadCastColoring() {
352   auto need_coloring = true;
353   while (need_coloring) {
354     need_coloring = false;
355     auto all_nodes = shared_cell_->nodes();
356     auto node_users = manager_->node_users();
357     for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) {
358       auto stage_info = (*node)->user_data<NodeStageInfo>();
359       if (!(*node)->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
360           IsPrimitiveCNode(*node, prim::kPrimUpdateState)) {
361         continue;
362       }
363       auto cnode = (*node)->cast<CNodePtr>();
364       auto stage = stage_info->stage();
365       auto chunk = stage_info->chunk();
366       for (auto &user_pair : node_users[*node]) {
367         auto user_node = user_pair.first->cast<CNodePtr>();
368         auto user_stage_info = user_node->user_data<NodeStageInfo>();
369         if (user_stage_info == nullptr) {
370           user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage, chunk));
371           need_coloring = true;
372           user_node->AddPrimalAttr(CHUNK, MakeValue(chunk));
373           user_node->AddPrimalAttr(STAGE, MakeValue(stage));
374           continue;
375         }
376         auto user_node_stage = user_stage_info->stage();
377         auto user_node_chunk = user_stage_info->chunk();
378         if (stage == user_node_stage) {
379           if (chunk > user_node_chunk) {
380             user_stage_info->set_chunk(chunk);
381             need_coloring = true;
382             user_node->AddPrimalAttr(CHUNK, MakeValue(chunk));
383             user_node->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
384             continue;
385           }
386           if (chunk < user_node_chunk) {
387             stage_info->set_chunk(user_node_chunk);
388             chunk = user_node_chunk;
389             need_coloring = true;
390             cnode->AddPrimalAttr(CHUNK, MakeValue(chunk));
391             cnode->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
392             continue;
393           }
394         }
395         if (stage > user_node_stage) {
396           if ((chunk >= user_node_chunk)) {
397             user_stage_info->set_chunk(chunk + 1);
398             need_coloring = true;
399             user_node->AddPrimalAttr(CHUNK, MakeValue(chunk + 1));
400             user_node->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
401             continue;
402           }
403         }
404         if ((stage < user_node_stage) && (chunk > user_node_chunk)) {
405           user_stage_info->set_chunk(chunk);
406           need_coloring = true;
407           user_node->AddPrimalAttr(CHUNK, MakeValue(chunk));
408           user_node->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
409         }
410       }
411     }
412   }
413 }
414 
GetLoadNodeByParam(const AnfNodePtr & param) const415 std::vector<AnfNodePtr> PipelineInterleave::GetLoadNodeByParam(const AnfNodePtr &param) const {
416   std::vector<AnfNodePtr> load_vec = {param};
417   auto node_users = manager_->node_users()[param];
418   for (auto &param_user : node_users) {
419     if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) {
420       auto graph = param_user.first->func_graph();
421       // exclude opt graphs
422       if (graph == root_ || (graph->stage() == -1 && graph != main_graph_)) {
423         continue;
424       }
425       (void)load_vec.emplace_back(param_user.first);
426     }
427   }
428   return load_vec;
429 }
430 
GetStageByArgument(const CNodePtr & node,size_t index,const std::vector<AnfNodePtr> & parameters,const NodeUsersMap & node_users_map,std::set<int64_t> * const parameter_stage)431 bool PipelineInterleave::GetStageByArgument(const CNodePtr &node, size_t index,
432                                             const std::vector<AnfNodePtr> &parameters,
433                                             const NodeUsersMap &node_users_map,
434                                             std::set<int64_t> *const parameter_stage) {
435   if (index < 1) {
436     return false;
437   }
438   const auto &input = node->input(0);
439   if (!IsValueNode<FuncGraph>(input)) {
440     return false;
441   }
442   if (GetValueNode<FuncGraphPtr>(input) != shared_cell_) {
443     return false;
444   }
445   auto pos = index - 1;
446   const auto &param = parameters.at(pos);
447   MS_EXCEPTION_IF_NULL(param);
448   auto loads = GetLoadNodeByParam(param);
449   const auto &iter = node_users_map.find(loads.back());
450   if (iter == node_users_map.end()) {
451     return true;
452   }
453   const auto &users = (*iter).second;
454   for (auto &user : users) {
455     auto user_cnode = user.first->cast<CNodePtr>();
456     MS_EXCEPTION_IF_NULL(user_cnode);
457     auto stage_info = user_cnode->user_data<NodeStageInfo>();
458     if (stage_info != nullptr && stage_info->stage() != -1) {
459       (void)((*parameter_stage).insert(stage_info->stage()));
460     } else {
461       auto graph = user_cnode->func_graph();
462       MS_EXCEPTION_IF_NULL(graph);
463       if (graph != root_ && graph != main_graph_ && graph != shared_cell_ && graph->stage() != -1) {
464         (void)((*parameter_stage).insert(graph->stage()));
465       }
466     }
467   }
468   return true;
469 }
470 
ParameterColoring()471 void PipelineInterleave::ParameterColoring() {
472   auto parameters = root_->parameters();
473   auto &node_users_map = manager_->node_users();
474   const auto &share_cell_parameters = shared_cell_->parameters();
475   for (auto &parameter : parameters) {
476     auto loads = GetLoadNodeByParam(parameter);
477     std::set<int64_t> parameter_stage;
478     for (auto &load : loads) {
479       auto load_users = node_users_map[load];
480       for (auto &load_user : load_users) {
481         auto user_cnode = load_user.first->cast<CNodePtr>();
482         MS_EXCEPTION_IF_NULL(user_cnode);
483         if (GetStageByArgument(user_cnode, load_user.second, share_cell_parameters, node_users_map, &parameter_stage)) {
484           continue;
485         }
486         auto stage_info = user_cnode->user_data<NodeStageInfo>();
487         if (stage_info != nullptr && stage_info->stage() != -1) {
488           (void)parameter_stage.insert(stage_info->stage());
489           continue;
490         } else {
491           auto graph = user_cnode->func_graph();
492           MS_EXCEPTION_IF_NULL(graph);
493           if (graph != root_ && graph != main_graph_ && graph != shared_cell_ && graph->stage() != -1) {
494             (void)parameter_stage.insert(graph->stage());
495             continue;
496           }
497         }
498       }
499     }
500     parameter_color_map_[parameter] = parameter_stage;
501   }
502 }
503 
RemoveMonadNode()504 void PipelineInterleave::RemoveMonadNode() {
505   auto all_nodes = DeepScopedGraphSearch(shared_cell_->get_return());
506   auto node_users_map = manager_->node_users();
507   for (auto &node : all_nodes) {
508     if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
509       continue;
510     }
511     auto cnode = node->cast<CNodePtr>();
512     MS_EXCEPTION_IF_NULL(cnode);
513     auto abs = cnode->abstract();
514     MS_EXCEPTION_IF_NULL(abs);
515     auto stage_info = cnode->user_data<NodeStageInfo>();
516     if (stage_info == nullptr) {
517       continue;
518     }
519     auto stage = stage_info->stage();
520     if (stage != stage_ && stage != -1) {
521       auto node_users = node_users_map[node];
522       for (auto &user_node : node_users) {
523         auto monad_node = NewValueNode(kUMonad);
524         if (abs->isa<abstract::AbstractIOMonad>()) {
525           monad_node = NewValueNode(kIOMonad);
526         }
527         manager_->SetEdge(user_node.first, user_node.second, monad_node);
528       }
529     }
530   }
531 }
532 
CreateZeroseOutput(const AnfNodePtr & node,size_t index)533 static tensor::TensorPtr CreateZeroseOutput(const AnfNodePtr &node, size_t index) {
534   auto out_shapes = GetNodeShape(node);
535   auto out_shape_type = GetShapeType(node, out_shapes.at(index), index);
536   auto zero_tensor = TensorConstructUtils::CreateZerosTensor(out_shape_type.second, out_shapes.at(index));
537   return zero_tensor;
538 }
539 
// Creates, in `graph`, a MakeTuple CNode holding one zero tensor per output shape of `node`.
// The `index` parameter is not used by the body.
static AnfNodePtr CreateTupleZeroTensor(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t index) {
  const size_t output_num = GetNodeShape(node).size();
  std::vector<AnfNodePtr> tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    tuple_inputs.emplace_back(NewValueNode(CreateZeroseOutput(node, output_idx)));
  }
  return graph->NewCNode(tuple_inputs);
}
549 
InsertSendReceive(const AnfNodePtr & node,const AnfNodePtr & user_node,int64_t order)550 void PipelineInterleave::InsertSendReceive(const AnfNodePtr &node, const AnfNodePtr &user_node, int64_t order) {
551   auto node_stage_info = node->user_data<NodeStageInfo>();
552   auto user_node_stage_info = user_node->user_data<NodeStageInfo>();
553   auto node_stage = node_stage_info->stage();
554   auto user_stage = user_node_stage_info->stage();
555   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(0));
556   Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_stage));
557   Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
558   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
559   if (node_stage > user_stage) {
560     attr_group = std::make_pair(GROUP, MakeValue(group_[INDEX_TWO]));
561     attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[INDEX_THREE]));
562   }
563   OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
564   auto send_op = CreateOpInstance(attrs, SEND, SEND);
565   auto send_node = NewValueNode(send_op);
566   std::vector<AnfNodePtr> send_input = {send_node, node};
567   auto graph = shared_cell_;
568   auto send = graph->NewCNode(send_input);
569   send->set_user_data<NodeStageInfo>(node_stage_info);
570   send->set_abstract(node->abstract());
571   send->AddPrimalAttr(CHUNK, MakeValue(node_stage_info->chunk()));
572   send->AddPrimalAttr(STAGE, MakeValue(node_stage_info->stage()));
573   send->AddPrimalAttr(ORDER, MakeValue(order));
574 
575   attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
576   auto shape_type_pair = GetShapeType(node, {1}, 0);
577   Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
578   Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
579   auto send_prim = GetCNodePrimitive(send);
580   send_prim->set_attr(DTYPE, shape_type_pair.second);
581   OperatorAttrs attrs_recv = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
582   auto recv_op = CreateOpInstance(attrs_recv, RECEIVE, RECEIVE);
583   std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), send};
584   auto recv = graph->NewCNode(recv_input);
585   recv->set_abstract(node->abstract());
586   recv->set_user_data<NodeStageInfo>(user_node_stage_info);
587   recv->AddPrimalAttr(CHUNK, MakeValue(user_node_stage_info->chunk()));
588   recv->AddPrimalAttr(STAGE, MakeValue(user_node_stage_info->stage()));
589   recv->AddPrimalAttr(ORDER, MakeValue(order));
590   auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
591   if (micro != nullptr) {
592     recv->AddPrimalAttr(MICRO, micro);
593   }
594   manager_->Replace(node, recv);
595 }
596 
CutBorderForNode(const FuncGraphPtr & graph,const AnfNodePtr & node,int64_t * order)597 void PipelineInterleave::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t *order) {
598   auto stage_info = node->user_data<NodeStageInfo>();
599   auto node_users = manager_->node_users()[node];
600   AnfNodePtr receive = nullptr;
601   auto pre_node = GetRealKernelNode(node, -1).first;
602   bool send_param = false;
603   if (pre_node->isa<Parameter>()) {
604     send_param = true;
605   }
606   for (auto &user_pair : node_users) {
607     auto user_node = user_pair.first;
608     auto node_stage = stage_info->stage();
609     auto user_stage_info = user_node->user_data<NodeStageInfo>();
610     if (user_stage_info == nullptr) {
611       continue;
612     }
613     auto user_node_stage = user_stage_info->stage();
614     auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
615     if (!micro) {
616       MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
617       micro = MakeValue(int64_t(0));
618     }
619     if (node_stage != user_node_stage) {
620       InsertSendReceive(node, user_node, *order);
621       (*order) += 1;
622       if (send_param) {
623         parameter_color_map_[pre_node].insert(user_node_stage);
624       }
625     }
626   }
627 }
628 
RedundancyNode(const AnfNodePtr & node,mindspore::HashMap<CNodePtr,std::vector<AnfNodePtr>> * make_tuple_map)629 void PipelineInterleave::RedundancyNode(const AnfNodePtr &node,
630                                         mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map) {
631   auto node_users = manager_->node_users()[node];
632   for (auto &node_user_pair : node_users) {
633     auto cnode = node_user_pair.first->cast<CNodePtr>();
634     // node->UpdateState, replaced node wiht U.
635     auto fg = cnode->func_graph();
636     MS_EXCEPTION_IF_NULL(fg);
637     if (fg->stage() != -1 && fg != main_graph_) {
638       continue;
639     }
640     if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
641       auto u_node = NewValueNode(kUMonad);
642       manager_->SetEdge(cnode, node_user_pair.second, u_node);
643       continue;
644     }
645     // node->make_tuple, record with a map, Unified deleted later.
646     if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
647       if (fg == main_graph_) {
648         continue;
649       }
650       if (make_tuple_map->find(cnode) == (*make_tuple_map).end()) {
651         (*make_tuple_map)[cnode] = {node};
652       } else {
653         (*make_tuple_map)[cnode].push_back(node);
654       }
655     } else {
656       RedundancyNode(node_user_pair.first, make_tuple_map);
657     }
658   }
659 }
660 
IsRedundancyParameter(const AnfNodePtr & parameter,const std::vector<AnfNodePtr> & non_cloned_parameters)661 bool PipelineInterleave::IsRedundancyParameter(const AnfNodePtr &parameter,
662                                                const std::vector<AnfNodePtr> &non_cloned_parameters) {
663   // RedundancyParameter: other stage's parameters included corresponding cloned parameters.
664   auto param_ptr = parameter->cast<ParameterPtr>();
665   MS_EXCEPTION_IF_NULL(param_ptr);
666   if (!param_ptr->has_default()) {
667     return false;
668   }
669   std::set<int64_t> stage_set;
670   if (!ParameterIsCloned(parameter)) {
671     stage_set = parameter_color_map_.at(parameter);
672   } else {
673     auto parameters = root_->parameters();
674     auto param_name = param_ptr->name();
675     auto non_clone_name = param_name.substr(param_name.find_first_of('.') + 1);
676     for (auto &param : non_cloned_parameters) {
677       auto non_cloned_param = param->cast<ParameterPtr>();
678       if (non_clone_name != non_cloned_param->name()) {
679         continue;
680       }
681       stage_set = parameter_color_map_.at(param);
682       break;
683     }
684   }
685   if (stage_set.empty()) {
686     return false;
687   }
688   return stage_set.count(stage_) == 0;
689 }
690 
ElimParameter()691 void PipelineInterleave::ElimParameter() {
692   auto parameters = root_->parameters();
693   mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> make_tuple_map;
694   std::vector<AnfNodePtr> non_cloned_parameters;
695   FreezeGradient();
696   auto node_users_map = manager_->node_users();
697   for (auto &parameter : parameters) {
698     if (ParameterIsCloned(parameter)) {
699       continue;
700     }
701     non_cloned_parameters.push_back(parameter);
702   }
703   for (auto &parameter : parameters) {
704     if (!IsRedundancyParameter(parameter, non_cloned_parameters)) {
705       continue;
706     }
707     MS_LOG(INFO) << "Parameter:" << parameter->DebugString() << " is Redundancy.";
708     RedundancyNode(parameter, &make_tuple_map);
709   }
710   for (auto &temp : make_tuple_map) {
711     auto make_tuple = temp.first;
712     auto fg = make_tuple->func_graph();
713     MS_EXCEPTION_IF_NULL(fg);
714     auto remove_vector = temp.second;
715     if (remove_vector.empty()) {
716       continue;
717     }
718     auto make_tuple_user = node_users_map.at(make_tuple).front().first;
719     auto make_tuple_inputs = make_tuple->inputs();
720     std::vector<AnfNodePtr> new_inputs;
721     for (auto &input : make_tuple_inputs) {
722       if (std::find(remove_vector.begin(), remove_vector.end(), input) == remove_vector.end()) {
723         new_inputs.push_back(input);
724       }
725       if (root_->has_flag(NO_UPDATE) && IsPrimitiveCNode(make_tuple_user, prim::kPrimAddN)) {
726         auto zeros = CreateZeroseOutput(input, 0);
727         new_inputs.push_back(NewValueNode(zeros));
728       }
729     }
730     auto new_make_tuple = fg->NewCNode(new_inputs);
731     (void)manager_->Replace(make_tuple, new_make_tuple);
732   }
733 }
734 
ModifyParameterList()735 void PipelinePostProcess::ModifyParameterList() {
736   auto parameters = root_->parameters();
737   std::vector<AnfNodePtr> parameter_list;
738   for (auto &parameter : parameters) {
739     auto param = parameter->cast<ParameterPtr>();
740     MS_EXCEPTION_IF_NULL(param);
741     if (!manager_->node_users()[parameter].empty() || !param->has_default()) {
742       parameter_list.push_back(parameter);
743     }
744   }
745   auto del_num = parameters.size() - parameter_list.size();
746   root_->set_fv_param_count(root_->fv_param_count() - del_num);
747   manager_->SetParameters(root_, parameter_list);
748 }
749 
CutBorder()750 void PipelineInterleave::CutBorder() {
751   CreateSendReceiveGroup();
752   MS_EXCEPTION_IF_NULL(shared_cell_);
753   auto ret = shared_cell_->get_return();
754   MS_EXCEPTION_IF_NULL(ret);
755   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
756   std::reverse(all_nodes.begin(), all_nodes.end());
757   int64_t order = 0;
758   for (auto &node : all_nodes) {
759     auto stage_info = node->user_data<NodeStageInfo>();
760     if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
761         IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
762       continue;
763     }
764     // Modify for lizard cyclomatic complexity.
765     CutBorderForNode(shared_cell_, node, &order);
766   }
767   RemoveMonadNode();
768 }
769 
GetZeroOutputs(const FuncGraphPtr & graph)770 AnfNodePtr PipelinePostProcess::GetZeroOutputs(const FuncGraphPtr &graph) {
771   auto real_kernel = GetRealKernelNode(graph->output(), -1);
772   AnfNodePtr node = real_kernel.first;
773   MS_EXCEPTION_IF_NULL(node);
774   std::vector<AnfNodePtr> out_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
775   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
776     auto cnode = node->cast<CNodePtr>();
777     for (size_t i = 1; i < cnode->inputs().size(); ++i) {
778       auto each_out_shapes = GetNodeShape(cnode->input(i));
779       if (each_out_shapes.size() > 1) {
780         auto temp_tuple = CreateTupleZeroTensor(graph, cnode->input(i), each_out_shapes.size());
781         (void)out_tuple_inputs.emplace_back(temp_tuple);
782         continue;
783       }
784       (void)out_tuple_inputs.emplace_back(NewValueNode(CreateZeroseOutput(cnode->input(i), 0)));
785     }
786   }
787   AnfNodePtr zero_outputs;
788   if (out_tuple_inputs.size() > INDEX_ONE) {
789     auto out_tuple = graph->NewCNode(out_tuple_inputs);
790     return out_tuple;
791   } else {
792     auto out_shapes = GetNodeShape(node);
793     AnfNodePtr out_tensor;
794     if (out_shapes.size() > 1 && real_kernel.second == -1) {
795       out_tensor = CreateTupleZeroTensor(graph, node, out_shapes.size());
796     } else {
797       out_tensor = NewValueNode(CreateZeroseOutput(node, 0));
798     }
799     return out_tensor;
800   }
801   return nullptr;
802 }
803 
SetNodeAbstract(const std::vector<AnfNodePtr> & nodes)804 void PipelinePostProcess::SetNodeAbstract(const std::vector<AnfNodePtr> &nodes) {
805   AbstractBasePtr abs;
806   if (nodes.size() == 1) {
807     auto cnode = nodes.front()->cast<CNodePtr>();
808     MS_EXCEPTION_IF_NULL(cnode);
809     abs = GetRealAbstract(cnode->input(INDEX_ONE));
810   } else {
811     AbstractBasePtrList abstract_list;
812     abstract_list.resize(nodes.size());
813     (void)std::transform(nodes.begin(), nodes.end(), abstract_list.begin(), [](const AnfNodePtr &node) {
814       auto cnode = node->cast<CNodePtr>();
815       MS_EXCEPTION_IF_NULL(cnode);
816       return GetRealAbstract(cnode->input(INDEX_ONE));
817     });
818     abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
819   }
820   for (auto &user : shared_cell_users_) {
821     user->set_abstract(abs);
822   }
823 }
824 
ModifySendRecvAttr(const std::vector<AnfNodePtr> & all_nodes)825 void PipelinePostProcess::ModifySendRecvAttr(const std::vector<AnfNodePtr> &all_nodes) {
826   for (auto &node : all_nodes) {
827     if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
828       continue;
829     }
830     auto pre_node_pair = GetRealKernelNode(node, -1);
831     auto pre_node = pre_node_pair.first;
832     auto cnode = node->cast<CNodePtr>();
833     auto prim = GetCNodePrimitive(node);
834     Shape slice_shape;
835     if (pre_node->isa<Parameter>()) {
836       auto base_shape = pre_node->Shape();
837       MS_EXCEPTION_IF_NULL(base_shape);
838       auto shape_ptr = dyn_cast<abstract::Shape>(base_shape);
839       MS_EXCEPTION_IF_NULL(shape_ptr);
840       slice_shape = shape_ptr->shape();
841       cnode->AddPrimalAttr(PIPELINE_PARAM, MakeValue(0));
842       cnode->AddPrimalAttr(MICRO, MakeValue(int64_t(0)));
843       cnode->set_user_data<AnfNode>(INPUT_PARAM, pre_node);
844     } else {
845       auto op_info = pre_node->cast<CNodePtr>()->user_data<OperatorInfo>();
846       MS_EXCEPTION_IF_NULL(op_info);
847       auto tensor_info = op_info->outputs_tensor_info();
848       if (pre_node_pair.second != -1 && tensor_info.size() > 1) {
849         slice_shape = tensor_info.at(pre_node_pair.second).slice_shape();
850       } else {
851         slice_shape = tensor_info.at(0).slice_shape();
852       }
853     }
854     auto abstract = node->abstract();
855     abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
856     std::vector<ValuePtr> element;
857     (void)std::transform(slice_shape.begin(), slice_shape.end(), std::back_inserter(element),
858                          [](int elem) { return MakeValue(int64_t(elem)); });
859     auto value = std::make_shared<ValueList>(element);
860     prim->set_attr(SHAPE, value);
861   }
862 }
863 
CalSrTag(int64_t order,int64_t micro,int64_t interleave_index)864 static int64_t CalSrTag(int64_t order, int64_t micro, int64_t interleave_index) {
865   return order * MAX_MICRO_BATCH_NUM * MAX_INTERLEAVE_NUM + interleave_index * MAX_INTERLEAVE_NUM + micro;
866 }
867 
// Create, in main_graph_, a new CNode modelled on the CNode `node`, with `input` as its only data input.
//
// The node's primitive is cloned and its SR_TAG attribute is set to CalSrTag(order, micro, index), where
// `order` is the value of the old node's ORDER primal attribute. The new node receives the old node's
// abstract and primal attributes, and its ORDER primal attribute is then set to the computed tag.
// Raises (via MS_EXCEPTION_IF_NULL) when `node` is not a CNode, has no primitive, or lacks the ORDER attr.
AnfNodePtr PipelinePostProcess::GenNewNodeFromOld(const AnfNodePtr &node, const AnfNodePtr &input, int64_t micro,
                                                  int64_t index) {
  const auto &old = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old);
  auto prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(prim);
  auto cloned_prim = prim->Clone();
  MS_EXCEPTION_IF_NULL(cloned_prim);
  auto attrs = prim->attrs();
  auto order_attr = old->GetPrimalAttr(ORDER);
  MS_EXCEPTION_IF_NULL(order_attr);
  auto order = GetValue<int64_t>(order_attr);
  auto sr_tag = CalSrTag(order, micro, index);
  attrs[SR_TAG] = MakeValue(sr_tag);
  cloned_prim->SetAttrs(attrs);
  std::vector<AnfNodePtr> new_node_input = {NewValueNode(cloned_prim), input};
  auto new_node = main_graph_->NewCNode(new_node_input);
  new_node->set_abstract(old->abstract());
  // Statement order below is kept exactly as before; the primal-attr calls are order-sensitive.
  if (old->HasPrimalAttr(PIPELINE_PARAM)) {
    new_node->AddPrimalAttr(PIPELINE_PARAM, MakeValue(0));
  }
  new_node->set_primal_attrs(old->primal_attrs());
  new_node->AddPrimalAttr(ORDER, MakeValue(sr_tag));
  return new_node;
}
889 
GenerateMainGraphSend(const std::vector<AnfNodePtr> & nodes,const AnfNodePtr & node,const ValuePtr & micro,const ValuePtr & index)890 std::vector<AnfNodePtr> PipelinePostProcess::GenerateMainGraphSend(const std::vector<AnfNodePtr> &nodes,
891                                                                    const AnfNodePtr &node, const ValuePtr &micro,
892                                                                    const ValuePtr &index) {
893   std::vector<AnfNodePtr> sends;
894   auto index_value = GetValue<int64_t>(index);
895   for (size_t i = 0; i < nodes.size(); ++i) {
896     auto send = nodes[i];
897     auto csend = send->cast<CNodePtr>();
898     if (csend->HasPrimalAttr(PIPELINE_PARAM)) {
899       if (csend->HasPrimalAttr("send_once")) {
900         continue;
901       }
902       auto param = csend->cast<CNodePtr>()->user_data<AnfNode>(INPUT_PARAM);
903       csend->AddPrimalAttr("send_once", MakeValue(true));
904       auto new_send = GenNewNodeFromOld(send, param, 0, 0);
905       sends.emplace_back(new_send);
906       continue;
907     }
908     auto micro_value = GetValue<int64_t>(micro);
909     auto send_input = CreateTupleGetItemNode(main_graph_, node, i);
910     auto new_send = GenNewNodeFromOld(send, send_input, micro_value, index_value)->cast<CNodePtr>();
911     new_send->AddPrimalAttr(PIPELINE_END, micro);
912     new_send->AddPrimalAttr(MICRO, micro);
913     sends.emplace_back(new_send);
914   }
915   return sends;
916 }
917 
// Create a main-graph Receive for the call site `fg_node`, modelled on `recv`, and append it to the call
// site's inputs through the manager.
//
//   - A Receive tagged PIPELINE_PARAM is re-created on its INPUT_PARAM user data with micro = 0, index = 0.
//   - Any other Receive is re-created on its own first input using the call site's MICRO and INDEX primal
//     attrs, and is additionally tagged PIPELINE_BEGIN with the call site's MICRO value.
// In both cases the new Receive gets the call site's MICRO primal attr. Returns the new Receive.
// Raises (via MS_EXCEPTION_IF_NULL) when either node is not a CNode or a required attr/user data is missing.
AnfNodePtr PipelinePostProcess::GenerateMainGraphRecv(const AnfNodePtr &fg_node, const AnfNodePtr &recv) {
  auto cuser = fg_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cuser);
  auto crecv = recv->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(crecv);
  auto micro = cuser->GetPrimalAttr(MICRO);
  const bool is_param_recv = crecv->HasPrimalAttr(PIPELINE_PARAM);
  AnfNodePtr new_recv;
  if (is_param_recv) {
    auto param = crecv->user_data<AnfNode>(INPUT_PARAM);
    MS_EXCEPTION_IF_NULL(param);
    new_recv = GenNewNodeFromOld(recv, param, 0, 0);
  } else {
    auto index = cuser->GetPrimalAttr(INDEX);
    MS_EXCEPTION_IF_NULL(index);
    auto index_value = GetValue<int64_t>(index);
    // MICRO is only converted to an integer on this path, so it is only validated here.
    MS_EXCEPTION_IF_NULL(micro);
    new_recv = GenNewNodeFromOld(recv, crecv->input(1), GetValue<int64_t>(micro), index_value);
  }
  MS_EXCEPTION_IF_NULL(new_recv);
  auto new_crecv = new_recv->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_crecv);
  if (!is_param_recv) {
    new_crecv->AddPrimalAttr(PIPELINE_BEGIN, micro);
  }
  new_crecv->AddPrimalAttr(MICRO, micro);
  manager_->AddEdge(cuser, new_recv);
  return new_recv;
}
938 
Init(const std::vector<AnfNodePtr> & nodes)939 void PipelinePostProcess::Init(const std::vector<AnfNodePtr> &nodes) {
940   for (auto &node : nodes) {
941     if ((IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimReceive)) &&
942         shared_cell_ == nullptr) {
943       shared_cell_ = node->cast<CNodePtr>()->func_graph();
944     }
945     if (IsPrimitiveCNode(node, prim::kPrimJ)) {
946       auto cnode = node->cast<CNodePtr>();
947       auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
948       main_graph_ = graph;
949     }
950     if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
951       continue;
952     }
953     auto cnode = node->cast<CNodePtr>();
954     auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
955     chunk_num_ = (chunk + 1) > chunk_num_ ? (chunk + 1) : chunk_num_;
956   }
957   auto value_nodes = main_graph_->value_nodes();
958   for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
959     auto node = (*value_pair).first;
960     if (!IsValueNode<FuncGraph>(node)) {
961       continue;
962     }
963     auto fg = GetValueNode<FuncGraphPtr>(node);
964     if (fg != shared_cell_) {
965       continue;
966     }
967     auto node_users = manager_->node_users()[node];
968     for (auto &node_user : node_users) {
969       auto user = node_user.first;
970       if (user->func_graph() == main_graph_) {
971         shared_cell_users_.emplace_back(user);
972       }
973     }
974     break;
975   }
976 }
977 
GetSendsRecvs(const FuncGraphPtr & fg,int64_t chunk,std::vector<AnfNodePtr> * recvs,std::vector<AnfNodePtr> * sends,std::vector<AnfNodePtr> * temp)978 void PipelinePostProcess::GetSendsRecvs(const FuncGraphPtr &fg, int64_t chunk, std::vector<AnfNodePtr> *recvs,
979                                         std::vector<AnfNodePtr> *sends, std::vector<AnfNodePtr> *temp) {
980   const auto &all_nodes = TopoSort(fg->get_return());
981   for (auto &node : all_nodes) {
982     if (!node->isa<CNode>()) {
983       continue;
984     }
985     auto cnode = node->cast<CNodePtr>();
986     if (!cnode->HasPrimalAttr(STAGE)) {
987       continue;
988     }
989     auto stage_value = cnode->GetPrimalAttr(STAGE);
990     if (stage_value && GetValue<int64_t>(stage_value) != stage_) {
991       continue;
992     }
993     if (IsPrimitiveCNode(cnode, prim::kPrimSend) && GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK)) == chunk) {
994       if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
995         temp->emplace_back(cnode->input(INDEX_ONE));
996       }
997       sends->emplace_back(node);
998     }
999     if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK)) == chunk) {
1000       auto prim = GetCNodePrimitive(node);
1001       auto attrs = prim->attrs();
1002       auto zero_tensor = TensorConstructUtils::CreateZerosTensor(attrs[DTYPE]->cast<TypePtr>(), {1});
1003       manager_->SetEdge(node, 1, NewValueNode(zero_tensor));
1004       recvs->emplace_back(node);
1005     }
1006   }
1007   return;
1008 }
1009 
LabelInterleaveIndex()1010 void PipelinePostProcess::LabelInterleaveIndex() {
1011   std::vector<int64_t> micro_visited;
1012   for (auto &usr : shared_cell_users_) {
1013     int64_t index = 0;
1014     auto cusr = usr->cast<CNodePtr>();
1015     MS_EXCEPTION_IF_NULL(cusr);
1016     auto micro = cusr->GetPrimalAttr(MICRO);
1017     MS_EXCEPTION_IF_NULL(micro);
1018     auto micro_value = GetValue<int64_t>(micro);
1019     if (!std::count(micro_visited.begin(), micro_visited.end(), micro_value)) {
1020       micro_visited.emplace_back(micro_value);
1021     } else {
1022       index += 1;
1023     }
1024     cusr->AddPrimalAttr(INDEX, MakeValue(index));
1025   }
1026 }
1027 
PartitionChunkGraph(const FuncGraphPtr & fg,int64_t chunk)1028 std::vector<AnfNodePtr> PipelinePostProcess::PartitionChunkGraph(const FuncGraphPtr &fg, int64_t chunk) {
1029   std::vector<AnfNodePtr> temp;
1030   std::vector<AnfNodePtr> recvs;
1031   std::vector<AnfNodePtr> sends;
1032   GetSendsRecvs(fg, chunk, &recvs, &sends, &temp);
1033   AnfNodePtr out;
1034   if (!temp.empty()) {
1035     out = CreateMakeTupleNode(fg, temp);
1036     manager_->Replace(fg->output(), out);
1037   }
1038 
1039   auto params = fg->parameters();
1040   std::vector<AnfNodePtr> new_params;
1041   auto node_users_map = manager_->node_users();
1042   std::vector<size_t> temp_index;
1043   for (size_t i = 0; i < params.size(); ++i) {
1044     auto param = params.at(i);
1045     if (node_users_map[param].size() == 0) {
1046       temp_index.emplace_back(i + 1);
1047       continue;
1048     }
1049     new_params.emplace_back(param);
1050   }
1051   for (auto &node : recvs) {
1052     auto crecv = node->cast<CNodePtr>();
1053     auto new_shared_cell_param = std::make_shared<Parameter>(fg);
1054     new_shared_cell_param->set_abstract(node->abstract());
1055     new_params.emplace_back(new_shared_cell_param);
1056     manager_->Replace(node, new_shared_cell_param);
1057   }
1058   manager_->SetParameters(fg, new_params);
1059   std::vector<AnfNodePtr> main_graph_sends;
1060   mindspore::HashMap<AnfNodePtr, AnfNodePtr> recv_map;
1061   for (auto &usr : shared_cell_users_) {
1062     auto cusr = usr->cast<CNodePtr>();
1063     std::vector<AnfNodePtr> usr_new_inputs = {NewValueNode(fg)};
1064     for (size_t i = 1; i < cusr->inputs().size(); ++i) {
1065       if (std::find(temp_index.begin(), temp_index.end(), i) == temp_index.end()) {
1066         usr_new_inputs.emplace_back(cusr->input(i));
1067       }
1068     }
1069     auto new_usr = main_graph_->NewCNode(usr_new_inputs);
1070     new_usr->set_primal_attrs(cusr->primal_attrs());
1071     new_usr->AddPrimalAttr(CHUNK, MakeValue(chunk));
1072     if (out != nullptr) {
1073       new_usr->set_abstract(out->abstract());
1074     }
1075     auto micro = cusr->GetPrimalAttr(MICRO);
1076     auto index = cusr->GetPrimalAttr(INDEX);
1077     auto temp_sends = GenerateMainGraphSend(sends, new_usr, micro, index);
1078     if (temp_sends.empty()) {
1079       if (stage_ != stage_num_ - 1) {
1080         MS_LOG(EXCEPTION) << "Some wrong with PipelineParallel.";
1081       }
1082       manager_->Replace(usr, new_usr);
1083     }
1084     main_graph_sends.insert(main_graph_sends.end(), temp_sends.begin(), temp_sends.end());
1085     for (auto &recv : recvs) {
1086       auto crecv = recv->cast<CNodePtr>();
1087       if (crecv->HasPrimalAttr(PIPELINE_PARAM)) {
1088         if (recv_map.find(recv) == recv_map.end()) {
1089           auto temp_recv = GenerateMainGraphRecv(new_usr, recv);
1090           recv_map[recv] = temp_recv;
1091           continue;
1092         }
1093         manager_->AddEdge(new_usr, recv_map[recv]);
1094         continue;
1095       }
1096       (void)GenerateMainGraphRecv(new_usr, recv);
1097     }
1098   }
1099   return main_graph_sends;
1100 }
1101 
GraphPartition(const std::vector<AnfNodePtr> & all_nodes)1102 void PipelinePostProcess::GraphPartition(const std::vector<AnfNodePtr> &all_nodes) {
1103   LabelInterleaveIndex();
1104   std::vector<AnfNodePtr> send_ops;
1105   for (size_t i = 0; i < LongToSize(chunk_num_); ++i) {
1106     auto chunk_fg = shared_cell_;
1107     if (stage_ != stage_num_ - 1 || i != LongToSize(chunk_num_ - 1)) {
1108       chunk_fg = BasicClone(shared_cell_);
1109       chunk_fg->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, true);
1110       manager_->AddFuncGraph(chunk_fg);
1111     }
1112     auto sends = PartitionChunkGraph(chunk_fg, i);
1113     send_ops.insert(send_ops.begin(), sends.begin(), sends.end());
1114   }
1115   auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);
1116   auto outputs = GetZeroOutputs(main_graph_);
1117   if (stage_ == stage_num_ - 1) {
1118     outputs = main_graph_->output();
1119   }
1120   std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), outputs, make_tuple};
1121   auto out_node = main_graph_->NewCNode(out);
1122   (void)manager_->Replace(main_graph_->output(), out_node);
1123 }
1124 
HandleSendParam()1125 void PipelinePostProcess::HandleSendParam() {
1126   auto parameters = root_->parameters();
1127   auto node_users_map = manager_->node_users();
1128   auto nodes = DeepScopedGraphSearch(root_->get_return());
1129   for (auto &node : nodes) {
1130     if (!IsPrimitiveCNode(node, prim::kPrimSend)) {
1131       continue;
1132     }
1133     auto cnode = node->cast<CNodePtr>();
1134     if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1135       continue;
1136     }
1137     auto param = cnode->input(1);
1138     if (IsPrimitiveCNode(param, prim::kPrimVirtualAssignAdd)) {
1139       param = param->cast<CNodePtr>()->input(1);
1140     }
1141     auto param_ptr = param->cast<ParameterPtr>();
1142     MS_EXCEPTION_IF_NULL(param_ptr);
1143     auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
1144     if (!accu_parameter) {
1145       continue;
1146     }
1147     auto accu_users = node_users_map.at(accu_parameter);
1148     AnfNodePtr share_node = nullptr;
1149     for (auto &user : accu_users) {
1150       auto user_node = user.first;
1151       while (IsSomePrimitiveList(user_node->cast<CNodePtr>(),
1152                                  {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name()})) {
1153         share_node = user_node;
1154         user_node = node_users_map.at(user_node).front().first;
1155       }
1156       if (share_node == nullptr) {
1157         continue;
1158       }
1159       auto base_shape = accu_parameter->Shape();
1160       auto shape_ptr = dyn_cast<abstract::Shape>(base_shape);
1161       auto slice_shape = shape_ptr->shape();
1162       auto prim = GetCNodePrimitive(cnode);
1163       std::vector<ValuePtr> element;
1164       (void)std::transform(slice_shape.begin(), slice_shape.end(), std::back_inserter(element),
1165                            [](int elem) { return MakeValue(int64_t(elem)); });
1166       auto value = std::make_shared<ValueList>(element);
1167       prim->set_attr(SHAPE, value);
1168       manager_->SetEdge(cnode, 1, share_node);
1169       break;
1170     }
1171   }
1172 }
1173 
ElimGraphStage()1174 void PipelinePostProcess::ElimGraphStage() {
1175   for (auto &fg : manager_->func_graphs()) {
1176     fg->set_stage(-1);
1177   }
1178 }
1179 
HasNoUpdateParameter()1180 bool PipelineInterleave::HasNoUpdateParameter() {
1181   auto parameters = root_->parameters();
1182   for (auto &parameter : parameters) {
1183     if (ParameterIsCloned(parameter)) {
1184       continue;
1185     }
1186     auto param_info = parameter->cast<ParameterPtr>()->param_info();
1187     if (!param_info) {
1188       continue;
1189     }
1190     auto stage_set = parameter_color_map_.at(parameter);
1191     auto requires_grad = param_info->requires_grad();
1192     if (requires_grad && stage_set.count(stage_)) {
1193       return false;
1194     }
1195   }
1196   return true;
1197 }
1198 
FreezeGradient()1199 void PipelineInterleave::FreezeGradient() {
1200   auto node_users_map = manager_->node_users();
1201   if (HasNoUpdateParameter() && is_train_) {
1202     root_->set_flag(NO_UPDATE, true);
1203     auto nodes = root_->nodes();
1204     for (auto &node : nodes) {
1205       if (!IsPrimitiveCNode(node, prim::kPrimJ)) {
1206         continue;
1207       }
1208       auto node_users = node_users_map.at(node);
1209       auto grad_users = node_users_map.at(node_users.front().first);
1210       for (auto &grad_user : grad_users) {
1211         auto user_node = grad_user.first->cast<CNodePtr>();
1212         if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
1213           continue;
1214         }
1215         auto index = GetTupleGetItemIndex(user_node);
1216         if (index != 1) {
1217           continue;
1218         }
1219         auto temp = node_users_map.at(user_node).front().first;
1220         auto out = root_->output();
1221         std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), out, temp};
1222         auto new_node = root_->NewCNode(depend_input);
1223         manager_->Replace(out, new_node);
1224         break;
1225       }
1226       break;
1227     }
1228     for (auto &node : nodes) {
1229       if (!IsPrimitiveCNode(node, prim::kPrimNPUGetFloatStatusV2)) {
1230         continue;
1231       }
1232       auto cnode = node->cast<CNodePtr>();
1233       auto out_cnode = root_->output()->cast<CNodePtr>();
1234       auto grads = out_cnode->input(INDEX_TWO);
1235       std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), cnode->input(1), grads};
1236       auto new_node = root_->NewCNode(depend_input);
1237       new_node->set_abstract(cnode->input(1)->abstract());
1238       manager_->Replace(cnode->input(1), new_node);
1239       break;
1240     }
1241   }
1242 }
1243 
// Among the users of `node`, find the first TupleGetItem with index 1 and return the first user of that
// TupleGetItem. A warning is logged when that TupleGetItem does not have exactly one user.
// Returns nullptr when `node` has no such TupleGetItem user.
static AnfNodePtr GetDout(const AnfNodePtr &node, const NodeUsersMap &node_users_map) {
  auto node_usrs = node_users_map.at(node);
  for (auto &node_user_pair : node_usrs) {
    auto node_usr = node_user_pair.first->cast<CNodePtr>();
    const bool is_item_one = IsPrimitiveCNode(node_usr, prim::kPrimTupleGetItem) && GetTupleGetItemIndex(node_usr) == 1;
    if (!is_item_one) {
      continue;
    }
    auto get_item_usrs = node_users_map.at(node_usr);
    if (get_item_usrs.size() != 1) {
      MS_LOG(WARNING) << "Get Multi grad usrs. Use first.";
    }
    return get_item_usrs.begin()->first;
  }
  return nullptr;
}
1263 
NeedAttach(const FuncGraphManagerPtr & manager)1264 static bool NeedAttach(const FuncGraphManagerPtr &manager) {
1265   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
1266   if (parallel_mode != kAutoParallel && parallel_mode != kSemiAutoParallel) {
1267     return false;
1268   }
1269   bool cell_reuse = false;
1270   for (auto &fg : manager->func_graphs()) {
1271     if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
1272       cell_reuse = true;
1273       break;
1274     }
1275   }
1276   auto stage_num = g_device_manager->stage_num();
1277   if (!cell_reuse || stage_num <= 1) {
1278     return false;
1279   }
1280   return true;
1281 }
1282 
IsolatedNodeAttach(const FuncGraphPtr & root,const opt::OptimizerPtr & optimizer)1283 bool IsolatedNodeAttach(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
1284   if (root->has_flag(HAS_ATTACHED)) {
1285     return false;
1286   }
1287   root->set_flag(HAS_ATTACHED, true);
1288   auto manager = root->manager();
1289   if (!NeedAttach(manager)) {
1290     return false;
1291   }
1292   auto ret_after = root->get_return();
1293   MS_EXCEPTION_IF_NULL(ret_after);
1294   auto all_nodes = DeepScopedGraphSearch(ret_after);
1295   const auto &node_users_map = manager->node_users();
1296   std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
1297   FuncGraphPtr main_graph;
1298   FuncGraphPtr grad_graph;
1299   for (auto &node : all_nodes) {
1300     if (!node->isa<CNode>()) {
1301       continue;
1302     }
1303     auto cnode = node->cast<CNodePtr>();
1304     if (!IsValueNode<FuncGraph>(cnode->input(0))) {
1305       continue;
1306     }
1307     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
1308     auto sub_graph_output = graph->output();
1309     if (!IsPrimitiveCNode(sub_graph_output, prim::kPrimMakeTuple)) {
1310       continue;
1311     }
1312     auto csub_graph_output = sub_graph_output->cast<CNodePtr>();
1313     if (!IsPrimitiveCNode(csub_graph_output->input(1), prim::kPrimReceive)) {
1314       continue;
1315     }
1316     auto call_node_input = cnode->input(1);
1317     if (!IsValueNode<tensor::Tensor>(call_node_input)) {
1318       continue;
1319     }
1320     auto call_node_users = node_users_map.at(node);
1321     if (call_node_users.size() != 1) {
1322       continue;
1323     }
1324     auto usr_node = call_node_users.begin()->first;
1325     if (!IsPrimitiveCNode(usr_node, prim::kPrimTupleGetItem)) {
1326       continue;
1327     }
1328     auto get_item_usrs = node_users_map.at(usr_node);
1329     std::vector<AnfNodePtr> addn_input = {NewValueNode(prim::kPrimAddN)};
1330     main_graph = node->func_graph();
1331     for (auto &get_item_usr_pair : get_item_usrs) {
1332       auto get_item_usr = get_item_usr_pair.first;
1333       auto grad_node = GetDout(get_item_usr, node_users_map);
1334       if (grad_graph == nullptr) {
1335         grad_graph = grad_node->func_graph();
1336       } else {
1337         if (grad_graph != grad_node->func_graph()) {
1338           MS_LOG(EXCEPTION) << "Got Wrong Grad graph when attached Receive's grad, Maybe don't use lazy inline.";
1339         }
1340       }
1341       std::vector<AnfNodePtr> new_get_item_input = {NewValueNode(prim::kPrimTupleGetItem), grad_node,
1342                                                     NewValueNode(MakeValue(SizeToLong(get_item_usr_pair.second)))};
1343       auto new_get_item = grad_graph->NewCNode(new_get_item_input);
1344       addn_input.emplace_back(new_get_item);
1345     }
1346     AnfNodePtr temp;
1347     if (addn_input.size() > SIZE_TWO) {
1348       temp = grad_graph->NewCNode(addn_input);
1349     } else {
1350       temp = addn_input.at(1);
1351     }
1352     std::vector<AnfNodePtr> send_grad_fn_input = {NewValueNode(prim::kPrimTupleGetItem), node,
1353                                                   NewValueNode(MakeValue(int64_t(1)))};
1354     auto send_grad_fn = main_graph->NewCNode(send_grad_fn_input);
1355     auto call_grad_node = grad_graph->NewCNode({send_grad_fn, temp});
1356     std::vector<AnfNodePtr> call_grad_get_item_input = {NewValueNode(prim::kPrimTupleGetItem), call_grad_node,
1357                                                         NewValueNode(MakeValue(int64_t(1)))};
1358     auto call_grad_get_item = grad_graph->NewCNode(call_grad_get_item_input);
1359     make_tuple_input.emplace_back(call_grad_get_item);
1360   }
1361   if (make_tuple_input.size() <= 1) {
1362     return false;
1363   }
1364   auto make_tuple = grad_graph->NewCNode(make_tuple_input);
1365   if (root->has_flag(NO_UPDATE)) {
1366     manager->Replace(grad_graph->output(), make_tuple);
1367     return true;
1368   }
1369   std::vector<AnfNodePtr> attach_node_input = {NewValueNode(prim::kPrimDepend), grad_graph->output(), make_tuple};
1370   auto attach_node = grad_graph->NewCNode(attach_node_input);
1371   manager->Replace(grad_graph->output(), attach_node);
1372   return true;
1373 }
1374 }  // namespace parallel
1375 }  // namespace mindspore
1376