1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <unordered_map>
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
25 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "frontend/parallel/group_manager.h"
28 #include "frontend/parallel/context.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/node_check.h"
31 #include "frontend/parallel/graph_util/node_info.h"
32 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
33 #include "frontend/parallel/step_parallel_utils.h"
34 #include "ir/anf.h"
35 #include "ir/graph_utils.h"
36 #include "base/core_ops.h"
37 #include "utils/comm_manager.h"
38 #include "utils/ms_context.h"
39 #include "mindspore/core/utils/parallel_node_check.h"
40 
41 namespace mindspore {
42 namespace parallel {
43 std::unordered_map<AnfNodePtr, std::set<int64_t>> parameter_color_map;
44 // map<rank, tag>
45 std::unordered_map<int64_t, int64_t> send_tag_map;
46 std::unordered_map<int64_t, int64_t> recv_tag_map;
47 const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimCast};
48 
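// Primitives for which IsPipelineCareNode returns false: TupleGetItem, MakeTuple and Cast are skipped when deciding whether a node is a pipeline care node.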
49 static bool IsInWhiteList(const CNodePtr &cnode) {
50   for (auto &prim : WHITE_LIST) {
51     if (IsPrimitiveCNode(cnode, prim)) {
52       return true;
53     }
54   }
55   return false;
56 }
57 
58 void PipelineTransformer::MainGraph() {
59   if (!root_->has_flag(TRAINING)) {
60     main_graph_ = root_;
61     return;
62   }
63   for (auto &fg : manager_->func_graphs()) {
64     for (auto &node : fg->nodes()) {
65       if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
66         main_graph_ = fg;
67         main_graph_->set_flag(MAIN_GRAPH, true);
68         virtual_dataset_ = node;
69         return;
70       }
71     }
72   }
73   MS_LOG(EXCEPTION) << "Can't find the main graph; a possible reason is that the virtual dataset can't be found.";
74 }
75 
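// Derive the micro-batch index of a StridedSlice over the dataset input: begin[0] * micro_size / input_shape[0], and record it on the node as the MICRO and PIPELINE_BEGIN primal attrs.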
76 ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size) {
77   if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
78     MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
79   }
80   auto cnode = node->cast<CNodePtr>();
81   auto value = GetValueNode(cnode->input(2));
82   MS_EXCEPTION_IF_NULL(value);
83   auto tuple = GetValue<std::vector<int64_t>>(value);
84   auto input_shape = GetNodeShape(cnode->input(1)).at(0);
85   int64_t micro = tuple.at(0) * micro_size / input_shape.at(0);
86   cnode->AddPrimalAttr(MICRO, MakeValue(micro));
87   cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
88   return MakeValue(micro);
89 }
90 
91 bool PipelineTransformer::NeedGrad(const CNodePtr &cnode, const CNodePtr &graph_cnode) {
92   for (auto &input : cnode->inputs()) {
93     auto temp = input;
94     while (IsPrimitiveCNode(temp, prim::kPrimLoad) || IsPrimitiveCNode(temp, prim::kPrimCast)) {
95       auto input_cnode = input->cast<CNodePtr>();
96       temp = input_cnode->input(1);
97     }
98     if (temp->isa<Parameter>()) {
99       auto graph = cnode->func_graph();
100       auto parameters = graph->parameters();
101       auto iter = std::find(parameters.begin(), parameters.end(), temp);
102       if (iter == parameters.end() && ParameterRequireGrad(temp)) {
103         return true;
104       }
105       if (iter != parameters.end() && graph != main_graph_) {
106         auto pos = std::distance(parameters.begin(), iter);
107         MS_EXCEPTION_IF_NULL(graph_cnode);
108         auto real_param = graph_cnode->input(LongToSize(pos + 1));
109         if (real_param->isa<Parameter>() && ParameterRequireGrad(real_param)) {
110           return true;
111         }
112       }
113     }
114   }
115   return false;
116 }
117 
118 bool PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph, const CNodePtr &graph_cnode) {
119   auto orders = graph->GetOrderedCnodes();
120   for (auto &node : orders) {
121     auto cnode = node->cast<CNodePtr>();
122     MS_EXCEPTION_IF_NULL(cnode);
123     if (cnode->stage() > 0) {
124       continue;
125     }
126     if (IsValueNode<FuncGraph>(cnode->input(0))) {
127       auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
128       if (LabelParameterStart(sub_graph, cnode)) {
129         return true;
130       } else {
131         continue;
132       }
133     }
134     if (!IsPipelineCareNode(cnode)) {
135       continue;
136     }
137     if (NeedGrad(cnode, graph_cnode)) {
138       auto prim = GetCNodePrimitive(cnode);
139       (void)prim->AddAttr(PARAMETER_START, MakeValue(0));
140       return true;
141     }
142   }
143   return false;
144 }
145 
146 void PipelineTransformer::LabelMicroBatch() {
147   if (!root_->has_flag(TRAINING)) {
148     return;
149   }
150   MS_EXCEPTION_IF_NULL(main_graph_);
151   if (!LabelParameterStart(main_graph_, nullptr)) {
152     MS_LOG(EXCEPTION) << "Stage 0 should have at least 1 parameter, but got none.";
153   }
154   MS_EXCEPTION_IF_NULL(virtual_dataset_);
155   auto node_user_map = manager_->node_users();
156   auto node_users = node_user_map[virtual_dataset_];
157   for (auto &node_user : node_users) {
158     if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
159       auto data_users = manager_->node_users()[node_user.first];
160       auto node_first = data_users.front().first;
161       if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
162         data_users.clear();
163         data_users = node_user_map[node_first];
164       }
165       auto micro_size = int64_t(data_users.size());
166       micro_size_ = micro_size;
167       MS_LOG(INFO) << "Micro Size is: " << micro_size;
168       for (auto &data_user : data_users) {
169         auto micro = SetMicroBatch(data_user.first, micro_size);
170         SetStridedSliceStrategy(data_user.first);
171         auto cnode = data_user.first->cast<CNodePtr>();
172         BroadCastMicroBatch(cnode, &node_user_map, micro, 0);
173       }
174     }
175   }
176 }
177 
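// Build the communication group connecting the ranks that hold the same position in every stage; a second group with the BACKWARD suffix is created for the backward pass.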
178 void PipelineTransformer::CreateForwardGroup() {
179   std::vector<int64_t> rank_list;
180   auto rank_id = g_device_manager->global_rank();
181   auto stage_id = g_device_manager->stage_id();
182   auto stage_num = g_device_manager->stage_num();
183   for (int64_t i = 0; i < stage_num; ++i) {
184     rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id));
185   }
186   auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
187   auto g = g_device_manager->CreateGroup(rank_list);
188   auto g_back_name = g.name() + BACKWARD;
189   auto g_back = g_device_manager->CreateGroup(g_back_name, dev_list);
190   group_.push_back(g.name());
191   group_.push_back(g_back.name());
192 }
193 
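// Propagate each sub-graph's stage to the CNodes that call it, and color the caller's graph with the same stage when it matches the local stage; iterate until no more graphs are colored.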
194 void PipelineTransformer::Coloring() {
195   auto need_coloring = true;
196   std::set<int64_t> stage_set;
197   while (need_coloring) {
198     need_coloring = false;
199     for (auto &fg : manager_->func_graphs()) {
200       if (fg == root_ && root_->has_flag(TRAINING)) {
201         continue;
202       }
203       auto value_nodes = fg->value_nodes();
204       for (auto &value_pair : value_nodes) {
205         auto node = value_pair.first;
206         if (!IsValueNode<FuncGraph>(node)) {
207           continue;
208         }
209         auto graph = GetValueNode<FuncGraphPtr>(node);
210         if (graph->stage() == -1) {
211           continue;
212         }
213         stage_set.insert(graph->stage());
214         auto node_users = manager_->node_users()[node];
215         for (auto &user_pair : node_users) {
216           auto user_node = user_pair.first->cast<CNodePtr>();
217           user_node->set_stage(graph->stage());
218           auto user_node_graph = user_node->func_graph();
219           if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
220             user_node_graph->set_stage(graph->stage());
221             need_coloring = true;
222           }
223         }
224       }
225     }
226   }
227   MS_EXCEPTION_IF_NULL(g_device_manager);
228   auto stage_num = g_device_manager->stage_num();
229   if (SizeToLong(stage_set.size()) != stage_num) {
230     MS_LOG(EXCEPTION) << "Stage num " << stage_num << " is not equal to the number of stages used: " << stage_set.size();
231   }
232 }
233 
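// Push stages forward along the main graph's data-flow edges so that a consumer's stage is never smaller than its producer's stage.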
234 void PipelineTransformer::BroadCastColoring() {
235   auto need_coloring = true;
236   while (need_coloring) {
237     need_coloring = false;
238     auto all_nodes = main_graph_->nodes();
239     auto node_users = manager_->node_users();
240     for (auto &node : all_nodes) {
241       if (!node->isa<CNode>() || node->stage() == -1) {
242         continue;
243       }
244       auto stage = node->stage();
245       for (auto &user_pair : node_users[node]) {
246         auto user_node = user_pair.first->cast<CNodePtr>();
247         auto user_node_stage = user_node->stage();
248         if (stage > user_node_stage) {
249           if (IsValueNode<FuncGraph>(user_node->input(0))) {
250             MS_LOG(EXCEPTION) << "The stage setting is incorrect. PreNode's stage:" << stage
251                               << " is larger than NextNode's stage:" << user_node_stage;
252           }
253           user_node->set_stage(stage);
254           need_coloring = true;
255         }
256       }
257     }
258   }
259 }
260 
261 bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
262   MS_EXCEPTION_IF_NULL(cnode);
263   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
264   if (!prim) {
265     return false;
266   }
267   if (IsInWhiteList(cnode)) {
268     return false;
269   }
270   if (IsInParallelBlackList(prim)) {
271     MS_LOG(INFO) << "PipelineSplit doesn't care about node: " << prim->name();
272     return false;
273   }
274   return true;
275 }
276 
277 CNodePtr PipelineTransformer::GraphOutNode(const AnfNodePtr &node, int tuple_index) {
278   auto cnode = node->cast<CNodePtr>();
279   MS_EXCEPTION_IF_NULL(cnode);
280   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
281     return GraphOutNode(cnode->input(1), tuple_index);
282   }
283   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
284     return cnode->input(IntToSize(tuple_index) + 1)->cast<CNodePtr>();
285   }
286   return cnode;
287 }
288 
289 OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode, int tuple_index = 0) {
290   MS_EXCEPTION_IF_NULL(cnode);
291   auto temp_node = cnode;
292   if (IsValueNode<FuncGraph>(cnode->input(0))) {
293     auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
294     MS_EXCEPTION_IF_NULL(output);
295     temp_node = GraphOutNode(output, tuple_index);
296   }
297   if (!IsPipelineCareNode(temp_node)) {
298     MS_LOG(EXCEPTION) << "Node: " << temp_node->DebugString() << " is not a Pipeline Care Node.";
299   }
300   if (IsPrimitiveCNode(temp_node, prim::kPrimVirtualDataset)) {
301     SetVirtualDatasetStrategy(temp_node);
302   }
303   auto shape_list = ExtractShape(temp_node);
304   if (shape_list.empty()) {
305     MS_LOG(EXCEPTION) << "Node: " << temp_node->DebugString() << " failed to extract shape.";
306   }
307   auto prim = GetValueNode<PrimitivePtr>(temp_node->input(0));
308   MS_EXCEPTION_IF_NULL(prim);
309   if (prim->name() == RESHAPE) {
310     MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << temp_node->DebugString();
311   }
312   auto attrs = prim->attrs();
313   auto op_info = OperatorInstance(prim, attrs, shape_list);
314   auto &inputs = temp_node->inputs();
315   std::vector<ValuePtr> input_value;
316   for (size_t index = 1; index < inputs.size(); ++index) {
317     if (inputs[index]->isa<ValueNode>()) {
318       input_value.push_back(GetValueNode(inputs[index]));
319     } else {
320       input_value.emplace_back(nullptr);
321     }
322   }
323   op_info->set_input_value(input_value);
324   op_info->set_outputs_dtype(temp_node->Type());
325   op_info->set_cnode(temp_node);
326   StrategyPtr strategy = nullptr;
327   if (!StrategyFound(attrs)) {
328     strategy = GenerateBatchParallelStrategy(op_info, prim);
329   } else {
330     strategy = ExtractStrategy(attrs[STRATEGY]);
331   }
332   MS_EXCEPTION_IF_NULL(strategy);
333   if (op_info->Init(strategy) == FAILED) {
334     MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed.";
335   }
336   return op_info;
337 }
338 
339 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
340   MS_EXCEPTION_IF_NULL(node);
341   auto cnode = node->cast<CNodePtr>();
342   MS_EXCEPTION_IF_NULL(cnode);
343   // Handle the Cast and TupleGetItem situations
344   int tensor_info_index = 0;
345   OperatorInfoPtr op_info;
346   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
347     op_info = node->user_data<OperatorInfo>();
348   } else {
349     if (IsPrimitiveCNode(node, prim::kPrimCast)) {
350       cnode = cnode->input(1)->cast<CNodePtr>();
351     } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
352       tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
353       cnode = cnode->input(1)->cast<CNodePtr>();
354     }
355     // Create OperatorInfo to get slice_shape for send/recv
356     MS_EXCEPTION_IF_NULL(cnode);
357     op_info = CreateOpInfo(cnode, tensor_info_index);
358   }
359   return std::make_pair(op_info, tensor_info_index);
360 }
361 
362 AnfNodeIndexSet PipelineTransformer::GetActualOpUsers(const std::pair<AnfNodePtr, int> &node_pair,
363                                                       NodeUsersMap *node_users_map) {
364   auto temp_node = node_pair.first;
365   auto temp_cnode = temp_node->cast<CNodePtr>();
366   MS_EXCEPTION_IF_NULL(temp_cnode);
367   if (IsValueNode<FuncGraph>(temp_cnode->input(0))) {
368     auto graph = GetValueNode<FuncGraphPtr>(temp_cnode->input(0));
369     auto temp_params = graph->parameters();
370     if (temp_params.size() < IntToSize(node_pair.second)) {
371       MS_LOG(EXCEPTION) << "parameter: " << temp_node->DebugString() << " out of graph:" << graph->ToString()
372                         << "'s range.";
373     }
374     temp_node = temp_params[IntToSize(node_pair.second - 1)];
375   }
376   auto temp_users = (*node_users_map)[temp_node];
377   auto node = temp_users.front().first;
378   if (IsPrimitiveCNode(node, prim::kPrimLoad) || IsPrimitiveCNode(node, prim::kPrimCast)) {
379     return GetActualOpUsers(temp_users.front(), node_users_map);
380   }
381   return temp_users;
382 }
383 
384 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
385   MS_EXCEPTION_IF_NULL(node);
386   auto node_users_map = manager_->node_users();
387   auto node_users = node_users_map[node];
388   for (auto &node_user : node_users) {
389     auto load_users = GetActualOpUsers(node_user, &node_users_map);
390     for (auto &user_pair : load_users) {
391       auto user_node = user_pair.first->cast<CNodePtr>();
392       MS_EXCEPTION_IF_NULL(user_node);
393       auto user_node_graph = user_node->func_graph();
394       MS_EXCEPTION_IF_NULL(user_node_graph);
395       if (user_node_graph->stage() == -1) {
396         continue;
397       }
398       auto index = user_pair.second;
399       if (!IsPipelineCareNode(user_node)) {
400         continue;
401       }
402       auto op_info = CreateOpInfo(user_node);
403       return std::make_pair(op_info, index - 1);
404     }
405   }
406   return std::make_pair(nullptr, 0);
407 }
408 
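// For a parameter used by more than one stage, the stage that owns it (the first stage in its color set) sends it to the other stages, which insert matching Receive nodes; existing Send/Receive nodes are reused when possible.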
409 std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
410   auto parameters = root_->parameters();
411   std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
412   std::vector<AnfNodePtr> recvs = {};
413   for (auto &parameter : parameters) {
414     auto parameter_stage = parameter_color_map[parameter];
415     if (parameter_stage.size() <= 1) {
416       continue;
417     }
418     auto users = manager_->node_users()[parameter];
419     for (auto &user : users) {
420       auto node = user.first;
421       auto cnode = node->cast<CNodePtr>();
422       auto graph = node->func_graph();
423       if (IsValueNode<FuncGraph>(cnode->input(0))) {
424         graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
425       }
426       if (graph == root_ || graph->stage() == -1 || !parameter_stage.count(stage_)) {
427         continue;
428       }
429       auto micro = cnode->GetPrimalAttr(MICRO);
430       if (!micro) {
431         MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
432         micro = MakeValue(int64_t(0));
433       }
434       auto user_stage = node->stage();
435       if (stage_ == *parameter_stage.begin()) {
436         if (graph->stage() == stage_) {
437           continue;
438         }
439         if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) {
440           continue;
441         }
442         auto send_out = InsertSend(parameter, user_stage, stage_, micro);
443         make_tuple_input.push_back(send_out.depend);
444       } else {
445         auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
446         if (receive) {
447           manager_->SetEdge(node, user.second, receive);
448         } else {
449           auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro,
450                                     parameter);
451           recvs.push_back(recv);
452         }
453       }
454     }
455   }
456   return make_tuple_input;
457 }
458 
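// Record, for every root parameter, the set of stages whose sub-graphs use it; the first parameter that requires grad and whose lowest user stage is the local stage is remembered as virtual_param_.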
459 void PipelineTransformer::ParameterColoring() {
460   auto parameters = root_->parameters();
461   for (auto &parameter : parameters) {
462     auto users = manager_->node_users()[parameter];
463     std::set<int64_t> parameter_stage;
464     for (auto &user : users) {
465       auto node = user.first->cast<CNodePtr>();
466       auto graph = node->func_graph();
467       if (IsValueNode<FuncGraph>(node->input(0))) {
468         graph = GetValueNode<FuncGraphPtr>(node->input(0));
469       }
470       if (graph != root_ && graph->stage() != -1) {
471         parameter_stage.insert(graph->stage());
472         parameter->set_stage(graph->stage());
473       }
474     }
475     auto param_info = parameter->cast<ParameterPtr>()->param_info();
476     if (!param_info) {
477       parameter_color_map[parameter] = parameter_stage;
478       continue;
479     }
480     MS_EXCEPTION_IF_NULL(param_info);
481     auto requires_grad = param_info->requires_grad();
482     if (*parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) {
483       virtual_param_ = parameter;
484     }
485     parameter_color_map[parameter] = parameter_stage;
486   }
487 }
488 
489 static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape) {
490   TypePtr type;
491   auto cnode = node->cast<CNodePtr>();
492   if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
493     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
494     auto graph_output = graph->output();
495     type = graph_output->Type();
496   } else {
497     type = node->Type();
498   }
499   MS_EXCEPTION_IF_NULL(type);
500   std::vector<ValuePtr> element;
501   std::transform(shape.begin(), shape.end(), std::back_inserter(element), [](int elem) { return MakeValue(elem); });
502   auto shape_list = std::make_shared<ValueList>(element);
503   auto tensor_type = type->cast<mindspore::TensorTypePtr>();
504   MS_EXCEPTION_IF_NULL(tensor_type);
505   auto dtype = tensor_type->element();
506   MS_EXCEPTION_IF_NULL(dtype);
507   return std::make_pair(shape_list, dtype);
508 }
509 
510 AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) {
511   MS_EXCEPTION_IF_NULL(node);
512   auto cnode = node->cast<CNodePtr>();
513   MS_EXCEPTION_IF_NULL(cnode);
514   int64_t get_item_index = 0;
515   if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
516     get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
517     cnode = cnode->input(1)->cast<CNodePtr>();
518     MS_EXCEPTION_IF_NULL(cnode);
519   }
520   if (IsValueNode<FuncGraph>(cnode->input(0))) {
521     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
522     auto output = graph->output();
523     MS_EXCEPTION_IF_NULL(output);
524     while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
525       auto output_cnode = output->cast<CNodePtr>();
526       MS_EXCEPTION_IF_NULL(output_cnode);
527       output = output_cnode->input(1);
528     }
529     if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
530       auto make_tuple_cnode = output->cast<CNodePtr>();
531       output = make_tuple_cnode->input(LongToSize(get_item_index + 1));
532     }
533     if (output->isa<Parameter>()) {
534       auto parameters = graph->parameters();
535       auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
536       auto pos = std::distance(parameters.begin(), pos_iter);
537       return FindPipelineCareNode(cnode->input(LongToSize(pos + 1)));
538     }
539     cnode = output->cast<CNodePtr>();
540     MS_EXCEPTION_IF_NULL(cnode);
541   }
542   if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
543     return FindPipelineCareNode(cnode->input(1));
544   }
545   if (IsInWhiteList(cnode)) {
546     return cnode->cast<AnfNodePtr>();
547   }
548   if (!IsPipelineCareNode(cnode)) {
549     MS_LOG(EXCEPTION) << "Only nodes that PipelineSplit cares about can be borders."
550                       << " border node: " << cnode->DebugString();
551   }
552   return cnode->cast<AnfNodePtr>();
553 }
554 
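// Insert a Send toward the stage user_node_stage. The sr_tag is allocated per destination rank, so successive sends to the same rank get increasing tags, and the result is wrapped in a Depend on the sent value so the Send is kept alive.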
555 SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage,
556                                          const ValuePtr &value) {
557   auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
558   int64_t send_tag;
559   if (send_tag_map.find(dest_rank) != send_tag_map.end()) {
560     send_tag = send_tag_map[dest_rank] + 1;
561     send_tag_map[dest_rank] += 1;
562   } else {
563     send_tag = 0;
564     send_tag_map[dest_rank] = 0;
565   }
566   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
567   Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage));
568   Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
569   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
570   OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
571   auto send_op = CreatOpInstance(attrs, SEND, SEND);
572   auto send_node = NewValueNode(send_op);
573   auto prim = GetValueNode<PrimitivePtr>(send_node);
574   std::pair<OperatorInfoPtr, int> op_info_pair;
575   AnfNodePtr care_node;
576   TensorInfo tensor_info;
577   if (parameter->isa<Parameter>()) {
578     op_info_pair = GetParameterPair(parameter);
579     tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
580   } else {
581     care_node = FindPipelineCareNode(parameter);
582     if (care_node->isa<Parameter>()) {
583       op_info_pair = GetParameterPair(care_node);
584       tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
585     } else {
586       op_info_pair = GetOpInfo(care_node);
587       tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
588     }
589   }
590   auto index = op_info_pair.second;
591   auto op_info = op_info_pair.first;
592   auto slice_shape = tensor_info.slice_shape();
593   auto shape_type_pair = GetShapeType(parameter, slice_shape);
594   prim->set_attr(SHAPE, shape_type_pair.first);
595   prim->set_attr(DTYPE, shape_type_pair.second);
596   std::vector<AnfNodePtr> send_input = {send_node, parameter};
597   auto send = main_graph_->NewCNode(send_input);
598   if (!parameter->isa<Parameter>() && care_node != nullptr && !care_node->isa<Parameter>()) {
599     send->AddPrimalAttr(PIPELINE_END, value);
600   } else {
601     send->AddPrimalAttr(PIPELINE_PARAM, value);
602     send->set_user_data<OperatorInfo>(op_info);
603     send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
604   }
605   send->AddPrimalAttr(MICRO, value);
606   OperatorAttrs depend_attrs;
607   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
608   std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
609   auto depend = main_graph_->NewCNode(depend_input);
610   auto abstract = parameter->abstract();
611   if (care_node) {
612     abstract = care_node->abstract();
613   }
614   depend->set_abstract(abstract);
615   send->set_abstract(abstract);
616   SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
617   return send_out;
618 }
619 
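// Insert a Receive from stage node_stage with a tag allocated per source rank (mirroring InsertSend), attach the slice shape and dtype as attributes, and rewire use_node's input to the new Receive.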
620 AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
621                                               const AnfNodePtr &use_node, int index, int64_t user_node_stage,
622                                               int64_t node_stage, const ValuePtr &value,
623                                               const AnfNodePtr &graph_param) {
624   auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
625   int64_t recv_tag;
626   if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
627     recv_tag = recv_tag_map[src_rank] + 1;
628     recv_tag_map[src_rank] += 1;
629   } else {
630     recv_tag = 0;
631     recv_tag_map[src_rank] = 0;
632   }
633   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
634   Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
635   std::pair<OperatorInfoPtr, int> op_info_pair;
636   bool is_param = true;
637   TensorInfo tensor_info;
638   if (node->isa<Parameter>()) {
639     op_info_pair = GetParameterPair(graph_param);
640     tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
641   } else {
642     auto care_node = FindPipelineCareNode(node);
643     op_info_pair = GetOpInfo(care_node);
644     tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
645     is_param = false;
646   }
647   auto tensor_layout = tensor_info.tensor_layout();
648   Shape slice_shape = tensor_info.slice_shape();
649   auto shape_type_pair = GetShapeType(node, slice_shape);
650   Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
651   Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
652   Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
653   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
654   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
655   auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE);
656   std::vector<AnfNodePtr> recv_input;
657   if (node->isa<Parameter>()) {
658     recv_input = {NewValueNode(recv_op), node};
659   } else {
660     recv_input = {NewValueNode(recv_op), virtual_param_};
661   }
662   auto recv = graph->NewCNode(recv_input);
663   if (is_param) {
664     recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
665     recv->AddPrimalAttr(PIPELINE_PARAM, value);
666   } else {
667     recv->AddPrimalAttr(PIPELINE_BEGIN, value);
668   }
669   recv->AddPrimalAttr(MICRO, value);
670   auto node_abstract = node->abstract();
671   if (node->isa<CNode>()) {
672     auto cnode = node->cast<CNodePtr>();
673     MS_EXCEPTION_IF_NULL(cnode);
674     if (IsValueNode<FuncGraph>(cnode->input(0))) {
675       auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
676       MS_EXCEPTION_IF_NULL(output);
677       node_abstract = output->abstract();
678     }
679   }
680   MS_EXCEPTION_IF_NULL(node_abstract);
681   recv->set_abstract(node_abstract);
682   if (node->isa<Parameter>()) {
683     BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
684     auto abstract_clone = node->abstract()->Clone();
685     MS_EXCEPTION_IF_NULL(abstract_clone);
686     abstract_clone->set_shape(parallel_shape);
687     node->set_abstract(abstract_clone);
688     node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
689   }
690   recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
691   recv->set_user_data<OperatorInfo>(op_info_pair.first);
692 
693   manager_->SetEdge(use_node, index, recv);
694   return recv;
695 }
696 
697 AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
698                                       const std::string &tag) {
699   for (auto &input : out_input) {
700     auto cnode = input->cast<CNodePtr>();
701     if (!cnode) {
702       continue;
703     }
704     if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
705       cnode = cnode->input(2)->cast<CNodePtr>();
706     }
707     if (cnode->input(1) == node) {
708       auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
709       auto dest_rank_send = GetValue<int64_t>(prim->GetAttr(tag));
710       if (dest_rank_send == stage) {
711         return input;
712       }
713     }
714   }
715   return nullptr;
716 }
717 
718 AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) {
719   // Skip virtual ops such as Depend, Load and Cast
720   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimCast) ||
721       IsPrimitiveCNode(node, prim::kPrimLoad)) {
722     auto cnode = node->cast<CNodePtr>();
723     MS_EXCEPTION_IF_NULL(cnode);
724     return ActualOp(cnode->input(1));
725   }
726   return node;
727 }
728 
729 bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) {
730   // ParameterGraph: a graph which returns a parameter
731   MS_EXCEPTION_IF_NULL(node);
732   auto temp_node = ActualOp(node);
733   auto cnode = temp_node->cast<CNodePtr>();
734   MS_EXCEPTION_IF_NULL(cnode);
735 
736   // parameter_graph->return->graph
737   if (!IsValueNode<FuncGraph>(cnode->input(0))) {
738     return false;
739   }
740   auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
741   MS_EXCEPTION_IF_NULL(graph);
742   auto graph_out = graph->output();
743   MS_EXCEPTION_IF_NULL(graph_out);
744   auto actual_op = ActualOp(graph_out);
745   MS_EXCEPTION_IF_NULL(actual_op);
746   if (actual_op->isa<Parameter>()) {
747     auto parameter_list = graph->parameters();
748     // parameter_graph->parameter->return->graph
749     auto parameter_iter = std::find(parameter_list.begin(), parameter_list.end(), actual_op);
750     if (parameter_iter == parameter_list.end()) {
751       return true;
752     }
753     // parameter->graph->return->graph
754     auto pos = std::distance(parameter_list.begin(), parameter_iter);
755     if (!cnode->input(LongToSize(pos + 1))->isa<Parameter>()) {
756       return false;
757     }
758     return true;
759   }
760   return false;
761 }
762 
763 AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
764                                                      int64_t user_stage, const ValuePtr &micro, size_t pos,
765                                                      const std::vector<AnfNodePtr> ops) {
766   MS_EXCEPTION_IF_NULL(node);
767   auto actual_node = ActualOp(node);
768   auto cnode = actual_node->cast<CNodePtr>();
769   MS_EXCEPTION_IF_NULL(cnode);
770   auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
771   MS_EXCEPTION_IF_NULL(graph);
772   AnfNodePtr argument;
773   AnfNodePtr parameter;
774 
775   auto graph_out = ActualOp(graph->output());
776   MS_EXCEPTION_IF_NULL(graph_out);
777   auto parameter_list = graph->parameters();
778   auto param_iter = std::find(parameter_list.begin(), parameter_list.end(), graph_out);
779   auto use_cnode = use_node->cast<CNodePtr>();
780   MS_EXCEPTION_IF_NULL(use_cnode);
781   if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
782     MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
783   }
784   auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
785   auto use_parameter_list = use_graph->parameters();
786   parameter = use_parameter_list.at(pos - 1);
787   // argument->load->graph
788   if (param_iter == parameter_list.end()) {
789     argument = graph_out;
790   } else {
791     auto param_pos = std::distance(parameter_list.begin(), param_iter);
792     argument = cnode->input(LongToSize(param_pos + 1));
793   }
794 
795   // insert receive
796   if (stage_ == user_stage) {
797     auto recv = Reuse(argument, stage, ops, SRC_RANK);
798     if (recv) {
799       manager_->SetEdge(use_node, SizeToInt(pos), recv);
800       return nullptr;
801     }
802     return InsertReceive(main_graph_, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
803   }
804   // insert send
805   if (Reuse(argument, user_stage, ops, DEST_RANK)) {
806     return nullptr;
807   }
808   auto send_out = InsertSend(argument, user_stage, stage_, micro);
809   send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
810   send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
811   return send_out.depend;
812 }
813 
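// Scan every edge of the graph and insert Send/Receive pairs on the edges that cross a stage boundary involving the local stage, reusing existing communication ops where possible.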
814 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
815   std::vector<AnfNodePtr> receive_ops;
816   std::vector<AnfNodePtr> send_ops;
817   auto ret = graph->get_return();
818   MS_EXCEPTION_IF_NULL(ret);
819   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
820   std::reverse(all_nodes.begin(), all_nodes.end());
821   auto stage_num = g_device_manager->stage_num();
822   if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
823     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't be less than stage num: " << stage_num;
824   }
825   for (auto &node : all_nodes) {
826     if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
827       continue;
828     }
829     auto node_users = manager_->node_users()[node];
830     AnfNodePtr receive = nullptr;
831     for (auto &user_pair : node_users) {
832       auto user_node = user_pair.first;
833       auto node_stage = node->stage();
834       auto user_node_stage = user_node->stage();
835       if (node_stage != stage_ && user_node_stage != stage_) {
836         continue;
837       }
838       auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
839       if (!micro) {
840         MS_LOG(INFO) << "Can't find micro_batch information, using micro(0)";
841         micro = MakeValue(int64_t(0));
842       }
843       if (node_stage < user_node_stage) {
844         if (node_stage == stage_) {
845           if (IsParameterGraph(node)) {
846             auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
847                                                     IntToSize(user_pair.second), send_ops);
848             if (!send_depend) {
849               continue;
850             }
851             (void)send_ops.insert(send_ops.begin(), send_depend);
852             continue;
853           }
854           if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
855             continue;
856           }
857           auto send_out = InsertSend(node, user_node_stage, node_stage, micro);
858           MS_EXCEPTION_IF_NULL(send_out.depend);
859           send_ops.push_back(send_out.depend);
860           send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
861           send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
862         } else {
863           if (!receive) {
864             if (IsParameterGraph(node)) {
865               receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
866                                              IntToSize(user_pair.second), receive_ops);
867               if (!receive) {
868                 continue;
869               }
870               receive_ops.push_back(receive);
871             } else {
872               receive =
873                 InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
874               receive_ops.push_back(receive);
875             }
876           } else {
877             manager_->SetEdge(user_node, user_pair.second, receive);
878           }
879         }
880         continue;
881       }
882       if (node_stage > user_node_stage) {
883         MS_LOG(EXCEPTION) << "node_stage: " << node_stage
884                           << " can't be larger than user_node_stage: " << user_node_stage;
885       }
886     }
887   }
888   return std::make_pair(send_ops, receive_ops);
889 }
890 
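// Entry point of the transformation: create the forward/backward groups, handle parameters shared across stages, cut the stage borders, and make the main graph's output depend on all inserted Send ops so they are not pruned.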
891 void PipelineTransformer::CutGraph() {
892   std::vector<AnfNodePtr> make_tuple_inputs;
893   CreateForwardGroup();
894   MS_EXCEPTION_IF_NULL(main_graph_);
895   if (make_tuple_inputs.empty()) {
896     make_tuple_inputs = HandleSharedParameter();
897   }
898   auto send_recv_ops = CutBorder(main_graph_);
899   auto send_ops = send_recv_ops.first;
900   if (IsLastStage()) {
901     return;
902   }
903   if (send_ops.empty() && !root_->has_flag(TRAINING)) {
904     return;
905   }
906   (void)make_tuple_inputs.insert(make_tuple_inputs.end(), send_ops.begin(), send_ops.end());
907   if (!send_ops.empty()) {
908     type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
909     shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
910   }
911   auto make_tuple = main_graph_->NewCNode(make_tuple_inputs);
912   std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend)};
913   out.push_back(send_ops.back());
914   out.push_back(make_tuple);
915   auto out_node = main_graph_->NewCNode(out);
916   (void)manager_->Replace(main_graph_->output(), out_node);
917 }
918 
919 void PipelineTransformer::ElimGraphStage() {
920   for (auto &fg : manager_->func_graphs()) {
921     fg->set_stage(-1);
922   }
923 }
924 
925 std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
926   std::pair<CNodePtr, FuncGraphPtr> sens_graph_pair;
927   CNodePtr sens_cnode;
928   FuncGraphPtr func_graph;
929   for (auto &node : root_->nodes()) {
930     if (!node->isa<CNode>()) {
931       continue;
932     }
933     sens_cnode = node->cast<CNodePtr>();
934     AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
935     MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
936     if (!expect_tuple_getitem->isa<CNode>()) {
937       continue;
938     }
939 
940     auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
941     if (!IsPrimitiveCNode(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem)) {
942       continue;
943     }
944     auto expect_anonymous = expect_tuple_getitem_cnode->input(1);
945     if (!expect_anonymous->isa<CNode>()) {
946       continue;
947     }
948     auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
949     AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
950     if (!expect_j->isa<CNode>()) {
951       continue;
952     }
953     auto expect_j_cnode = expect_j->cast<CNodePtr>();
954     if (!IsPrimitiveCNode(expect_j_cnode, prim::kPrimJ)) {
955       continue;
956     }
957     func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
958     break;
959   }
960   sens_graph_pair = std::make_pair(sens_cnode, func_graph);
961   return sens_graph_pair;
962 }
963 
964 void PipelineTransformer::CoverSensShape() {
965   if (IsLastStage()) {
966     return;
967   }
968   auto sens_graph_pair = FindSensNode();
969   auto sens_cnode = sens_graph_pair.first;
970   MS_EXCEPTION_IF_NULL(sens_cnode);
971   OperatorAttrs attrs;
972   auto fill_op = CreatOpInstance(attrs, "Fill", "");
973   MS_EXCEPTION_IF_NULL(type_ptr_);
974   MS_EXCEPTION_IF_NULL(shape_);
975   std::vector<AnfNodePtr> fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_),
976                                         NewValueNode(MakeValue(shape_->value())), NewValueNode(0)};
977   auto fill = root_->NewCNode(fill_input);
978   std::vector<AnfNodePtr> new_sens_input = {sens_cnode->input(0), fill};
979   auto new_sens_node = root_->NewCNode(new_sens_input);
980   manager_->Replace(sens_cnode, new_sens_node);
981 }
982 
983 void PipelineTransformer::ElimParameter() {
984   auto parameters = root_->parameters();
985   std::vector<AnfNodePtr> parameter_list;
986   for (auto &parameter : parameters) {
987     auto param = parameter->cast<ParameterPtr>();
988     MS_EXCEPTION_IF_NULL(param);
989     if (!manager_->node_users()[parameter].empty() || !param->has_default()) {
990       parameter_list.push_back(parameter);
991     }
992   }
993   auto del_num = parameters.size() - parameter_list.size();
994   root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
995   manager_->SetParameters(root_, parameter_list);
996 }
997 }  // namespace parallel
998 }  // namespace mindspore
999