1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <unordered_map>
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
25 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "frontend/parallel/group_manager.h"
28 #include "frontend/parallel/context.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/node_check.h"
31 #include "frontend/parallel/graph_util/node_info.h"
32 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
33 #include "frontend/parallel/step_parallel_utils.h"
34 #include "ir/anf.h"
35 #include "ir/graph_utils.h"
36 #include "base/core_ops.h"
37 #include "utils/comm_manager.h"
38 #include "utils/ms_context.h"
39 #include "mindspore/core/utils/parallel_node_check.h"
40
41 namespace mindspore {
42 namespace parallel {
// Maps each parameter to the set of pipeline stages whose graphs use it
// (filled by ParameterColoring, consumed by HandleSharedParameter).
std::unordered_map<AnfNodePtr, std::set<int64_t>> parameter_color_map;
// map<rank, tag>: last send/receive tag already issued for a peer rank, so
// repeated transfers to/from the same rank get distinct tags.
std::unordered_map<int64_t, int64_t> send_tag_map;
std::unordered_map<int64_t, int64_t> recv_tag_map;
// Virtual ops that are skipped when searching for pipeline border candidates.
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimCast};
48
IsInWhiteList(const CNodePtr & cnode)49 static bool IsInWhiteList(const CNodePtr &cnode) {
50 for (auto &prim : WHITE_LIST) {
51 if (IsPrimitiveCNode(cnode, prim)) {
52 return true;
53 }
54 }
55 return false;
56 }
57
MainGraph()58 void PipelineTransformer::MainGraph() {
59 if (!root_->has_flag(TRAINING)) {
60 main_graph_ = root_;
61 return;
62 }
63 for (auto &fg : manager_->func_graphs()) {
64 for (auto &node : fg->nodes()) {
65 if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
66 main_graph_ = fg;
67 main_graph_->set_flag(MAIN_GRAPH, true);
68 virtual_dataset_ = node;
69 return;
70 }
71 }
72 }
73 MS_LOG(EXCEPTION) << "Can't find main graph, possible reason is can't find virtual dataset.";
74 }
75
SetMicroBatch(const AnfNodePtr & node,int64_t micro_size)76 ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size) {
77 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
78 MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
79 }
80 auto cnode = node->cast<CNodePtr>();
81 auto value = GetValueNode(cnode->input(2));
82 MS_EXCEPTION_IF_NULL(value);
83 auto tuple = GetValue<std::vector<int64_t>>(value);
84 auto input_shape = GetNodeShape(cnode->input(1)).at(0);
85 int64_t micro = tuple.at(0) * micro_size / input_shape.at(0);
86 cnode->AddPrimalAttr(MICRO, MakeValue(micro));
87 cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
88 return MakeValue(micro);
89 }
90
NeedGrad(const CNodePtr & cnode,const CNodePtr & graph_cnode)91 bool PipelineTransformer::NeedGrad(const CNodePtr &cnode, const CNodePtr &graph_cnode) {
92 for (auto &input : cnode->inputs()) {
93 auto temp = input;
94 while (IsPrimitiveCNode(temp, prim::kPrimLoad) || IsPrimitiveCNode(temp, prim::kPrimCast)) {
95 auto input_cnode = input->cast<CNodePtr>();
96 temp = input_cnode->input(1);
97 }
98 if (temp->isa<Parameter>()) {
99 auto graph = cnode->func_graph();
100 auto parameters = graph->parameters();
101 auto iter = std::find(parameters.begin(), parameters.end(), temp);
102 if (iter == parameters.end() && ParameterRequireGrad(temp)) {
103 return true;
104 }
105 if (iter != parameters.end() && graph != main_graph_) {
106 auto pos = std::distance(parameters.begin(), iter);
107 MS_EXCEPTION_IF_NULL(graph_cnode);
108 auto real_param = graph_cnode->input(LongToSize(pos + 1));
109 if (real_param->isa<Parameter>() && ParameterRequireGrad(real_param)) {
110 return true;
111 }
112 }
113 }
114 }
115 return false;
116 }
117
LabelParameterStart(const FuncGraphPtr & graph,const CNodePtr & graph_cnode)118 bool PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph, const CNodePtr &graph_cnode) {
119 auto orders = graph->GetOrderedCnodes();
120 for (auto &node : orders) {
121 auto cnode = node->cast<CNodePtr>();
122 MS_EXCEPTION_IF_NULL(cnode);
123 if (cnode->stage() > 0) {
124 continue;
125 }
126 if (IsValueNode<FuncGraph>(cnode->input(0))) {
127 auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
128 if (LabelParameterStart(sub_graph, cnode)) {
129 return true;
130 } else {
131 continue;
132 }
133 }
134 if (!IsPipelineCareNode(cnode)) {
135 continue;
136 }
137 if (NeedGrad(cnode, graph_cnode)) {
138 auto prim = GetCNodePrimitive(cnode);
139 (void)prim->AddAttr(PARAMETER_START, MakeValue(0));
140 return true;
141 }
142 }
143 return false;
144 }
145
// Labels every micro-batch slice of the dataset with its micro index and
// broadcasts that label through the graph. Training only. Raises when stage 0
// has no gradient-requiring parameter (LabelParameterStart fails).
void PipelineTransformer::LabelMicroBatch() {
  if (!root_->has_flag(TRAINING)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(main_graph_);
  if (!LabelParameterStart(main_graph_, nullptr)) {
    MS_LOG(EXCEPTION) << "Stage 0 should has at least 1 parameter. but got none.";
  }
  MS_EXCEPTION_IF_NULL(virtual_dataset_);
  auto node_user_map = manager_->node_users();
  auto node_users = node_user_map[virtual_dataset_];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto data_users = manager_->node_users()[node_user.first];
      auto node_first = data_users.front().first;
      // If the TupleGetItem's first user is not a StridedSlice, the slices sit
      // one level deeper; look at that node's users instead.
      if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
        data_users.clear();
        data_users = node_user_map[node_first];
      }
      // One StridedSlice per micro-batch, so the user count is the micro size.
      auto micro_size = int64_t(data_users.size());
      micro_size_ = micro_size;
      MS_LOG(INFO) << "Micro Size is: " << micro_size;
      for (auto &data_user : data_users) {
        auto micro = SetMicroBatch(data_user.first, micro_size);
        SetStridedSliceStrategy(data_user.first);
        auto cnode = data_user.first->cast<CNodePtr>();
        // Propagate the micro label downstream from the slice.
        BroadCastMicroBatch(cnode, &node_user_map, micro, 0);
      }
    }
  }
}
177
CreateForwardGroup()178 void PipelineTransformer::CreateForwardGroup() {
179 std::vector<int64_t> rank_list;
180 auto rank_id = g_device_manager->global_rank();
181 auto stage_id = g_device_manager->stage_id();
182 auto stage_num = g_device_manager->stage_num();
183 for (int64_t i = 0; i < stage_num; ++i) {
184 rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id));
185 }
186 auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
187 auto g = g_device_manager->CreateGroup(rank_list);
188 auto g_back_name = g.name() + BACKWARD;
189 auto g_back = g_device_manager->CreateGroup(g_back_name, dev_list);
190 group_.push_back(g.name());
191 group_.push_back(g_back.name());
192 }
193
// Fixed-point pass that propagates stage labels: every CNode calling a staged
// FuncGraph inherits that graph's stage, and (for the local stage only) the
// caller's graph is labeled too, which may enable further rounds. Afterwards
// verifies that every stage in [0, stage_num) was actually used.
void PipelineTransformer::Coloring() {
  auto need_coloring = true;
  std::set<int64_t> stage_set;
  while (need_coloring) {
    need_coloring = false;
    for (auto &fg : manager_->func_graphs()) {
      // During training the root graph carries no meaningful stage; skip it.
      if (fg == root_ && root_->has_flag(TRAINING)) {
        continue;
      }
      auto value_nodes = fg->value_nodes();
      for (auto &value_pair : value_nodes) {
        auto node = value_pair.first;
        if (!IsValueNode<FuncGraph>(node)) {
          continue;
        }
        auto graph = GetValueNode<FuncGraphPtr>(node);
        if (graph->stage() == -1) {
          continue;
        }
        stage_set.insert(graph->stage());
        auto node_users = manager_->node_users()[node];
        for (auto &user_pair : node_users) {
          // Every caller of the graph inherits the graph's stage.
          auto user_node = user_pair.first->cast<CNodePtr>();
          user_node->set_stage(graph->stage());
          auto user_node_graph = user_node->func_graph();
          // Only local-stage graphs keep the fixed-point iteration going.
          if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
            user_node_graph->set_stage(graph->stage());
            need_coloring = true;
          }
        }
      }
    }
  }
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto stage_num = g_device_manager->stage_num();
  if (SizeToLong(stage_set.size()) != stage_num) {
    MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
  }
}
233
BroadCastColoring()234 void PipelineTransformer::BroadCastColoring() {
235 auto need_coloring = true;
236 while (need_coloring) {
237 need_coloring = false;
238 auto all_nodes = main_graph_->nodes();
239 auto node_users = manager_->node_users();
240 for (auto &node : all_nodes) {
241 if (!node->isa<CNode>() || node->stage() == -1) {
242 continue;
243 }
244 auto stage = node->stage();
245 for (auto &user_pair : node_users[node]) {
246 auto user_node = user_pair.first->cast<CNodePtr>();
247 auto user_node_stage = user_node->stage();
248 if (stage > user_node_stage) {
249 if (IsValueNode<FuncGraph>(user_node->input(0))) {
250 MS_LOG(EXCEPTION) << "The stage setting is incorrect. PreNode's stage:" << stage
251 << " is larger than NextNode's stage:" << user_node_stage;
252 }
253 user_node->set_stage(stage);
254 need_coloring = true;
255 }
256 }
257 }
258 }
259 }
260
IsPipelineCareNode(const CNodePtr & cnode)261 bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
262 MS_EXCEPTION_IF_NULL(cnode);
263 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
264 if (!prim) {
265 return false;
266 }
267 if (IsInWhiteList(cnode)) {
268 return false;
269 }
270 if (IsInParallelBlackList(prim)) {
271 MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name();
272 return false;
273 }
274 return true;
275 }
276
GraphOutNode(const AnfNodePtr & node,int tuple_index)277 CNodePtr PipelineTransformer::GraphOutNode(const AnfNodePtr &node, int tuple_index) {
278 auto cnode = node->cast<CNodePtr>();
279 MS_EXCEPTION_IF_NULL(cnode);
280 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
281 return GraphOutNode(cnode->input(1), tuple_index);
282 }
283 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
284 return cnode->input(IntToSize(tuple_index) + 1)->cast<CNodePtr>();
285 }
286 return cnode;
287 }
288
// Builds and initializes an OperatorInfo for `cnode` — or, when cnode calls a
// sub-graph, for the graph's output node selected by `tuple_index`. The node
// must be a pipeline-care node and must not be a Reshape. The sharding
// strategy comes from the primitive's attrs, or falls back to batch parallel.
OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode, int tuple_index = 0) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto temp_node = cnode;
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    // Resolve a sub-graph call to its actual output node.
    auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
    MS_EXCEPTION_IF_NULL(output);
    temp_node = GraphOutNode(output, tuple_index);
  }
  if (!IsPipelineCareNode(temp_node)) {
    MS_LOG(EXCEPTION) << "Node: " << temp_node->DebugString() << " is not a Pipeline Care Node.";
  }
  if (IsPrimitiveCNode(temp_node, prim::kPrimVirtualDataset)) {
    SetVirtualDatasetStrategy(temp_node);
  }
  auto shape_list = ExtractShape(temp_node);
  if (shape_list.empty()) {
    MS_LOG(EXCEPTION) << "Node: " << temp_node->DebugString() << " failed to extract shape.";
  }
  auto prim = GetValueNode<PrimitivePtr>(temp_node->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == RESHAPE) {
    MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << temp_node->DebugString();
  }
  auto attrs = prim->attrs();
  auto op_info = OperatorInstance(prim, attrs, shape_list);
  // Collect constant input values; non-constant inputs are recorded as null.
  auto &inputs = temp_node->inputs();
  std::vector<ValuePtr> input_value;
  for (size_t index = 1; index < inputs.size(); ++index) {
    if (inputs[index]->isa<ValueNode>()) {
      input_value.push_back(GetValueNode(inputs[index]));
    } else {
      input_value.emplace_back(nullptr);
    }
  }
  op_info->set_input_value(input_value);
  op_info->set_outputs_dtype(temp_node->Type());
  op_info->set_cnode(temp_node);
  // Prefer the user-specified strategy; otherwise generate batch parallel.
  StrategyPtr strategy = nullptr;
  if (!StrategyFound(attrs)) {
    strategy = GenerateBatchParallelStrategy(op_info, prim);
  } else {
    strategy = ExtractStrategy(attrs[STRATEGY]);
  }
  MS_EXCEPTION_IF_NULL(strategy);
  if (op_info->Init(strategy) == FAILED) {
    MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed.";
  }
  return op_info;
}
338
GetOpInfo(const AnfNodePtr & node)339 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
340 MS_EXCEPTION_IF_NULL(node);
341 auto cnode = node->cast<CNodePtr>();
342 MS_EXCEPTION_IF_NULL(cnode);
343 // Handle Cast and TupleGetitem situation
344 int tensor_info_index = 0;
345 OperatorInfoPtr op_info;
346 if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
347 op_info = node->user_data<OperatorInfo>();
348 } else {
349 if (IsPrimitiveCNode(node, prim::kPrimCast)) {
350 cnode = cnode->input(1)->cast<CNodePtr>();
351 } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
352 tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
353 cnode = cnode->input(1)->cast<CNodePtr>();
354 }
355 // Create OperatorInfo to get slice_shape for send/recv
356 MS_EXCEPTION_IF_NULL(cnode);
357 op_info = CreateOpInfo(cnode, tensor_info_index);
358 }
359 return std::make_pair(op_info, tensor_info_index);
360 }
361
// Returns the users of the "actual" node behind `node_pair`. When the user is
// a call into a sub-graph, the argument is mapped to the graph's formal
// parameter at the same position; Load/Cast wrappers are skipped recursively
// through their first user.
AnfNodeIndexSet PipelineTransformer::GetActualOpUsers(const std::pair<AnfNodePtr, int> &node_pair,
                                                      NodeUsersMap *node_users_map) {
  auto temp_node = node_pair.first;
  auto temp_cnode = temp_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(temp_cnode);
  if (IsValueNode<FuncGraph>(temp_cnode->input(0))) {
    auto graph = GetValueNode<FuncGraphPtr>(temp_cnode->input(0));
    auto temp_params = graph->parameters();
    // node_pair.second is the 1-based input index at the call site.
    if (temp_params.size() < IntToSize(node_pair.second)) {
      MS_LOG(EXCEPTION) << "parameter: " << temp_node->DebugString() << " out of graph:" << graph->ToString()
                        << "'s range.";
    }
    temp_node = temp_params[IntToSize(node_pair.second - 1)];
  }
  auto temp_users = (*node_users_map)[temp_node];
  auto node = temp_users.front().first;
  // Look through virtual Load/Cast wrappers via their first user.
  if (IsPrimitiveCNode(node, prim::kPrimLoad) || IsPrimitiveCNode(node, prim::kPrimCast)) {
    return GetActualOpUsers(temp_users.front(), node_users_map);
  }
  return temp_users;
}
383
GetParameterPair(const AnfNodePtr & node)384 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
385 MS_EXCEPTION_IF_NULL(node);
386 auto node_users_map = manager_->node_users();
387 auto node_users = node_users_map[node];
388 for (auto &node_user : node_users) {
389 auto load_users = GetActualOpUsers(node_user, &node_users_map);
390 for (auto &user_pair : load_users) {
391 auto user_node = user_pair.first->cast<CNodePtr>();
392 MS_EXCEPTION_IF_NULL(user_node);
393 auto user_node_graph = user_node->func_graph();
394 MS_EXCEPTION_IF_NULL(user_node_graph);
395 if (user_node_graph->stage() == -1) {
396 continue;
397 }
398 auto index = user_pair.second;
399 if (!IsPipelineCareNode(user_node)) {
400 continue;
401 }
402 auto op_info = CreateOpInfo(user_node);
403 return std::make_pair(op_info, index - 1);
404 }
405 }
406 return std::make_pair(nullptr, 0);
407 }
408
// For every parameter used by more than one stage: the owning stage (smallest
// stage in its color set) sends the parameter to each other using stage, and
// every non-owning stage rewires its local uses to a Receive. Returns the
// MakeTuple inputs accumulating the send-side Depend nodes.
std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
  auto parameters = root_->parameters();
  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  std::vector<AnfNodePtr> recvs = {};
  for (auto &parameter : parameters) {
    auto parameter_stage = parameter_color_map[parameter];
    // Only parameters shared across at least two stages need send/recv.
    if (parameter_stage.size() <= 1) {
      continue;
    }
    auto users = manager_->node_users()[parameter];
    for (auto &user : users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      auto graph = node->func_graph();
      if (IsValueNode<FuncGraph>(cnode->input(0))) {
        graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
      }
      // Skip uses that are unstaged, in the root, or irrelevant to this stage.
      if (graph == root_ || graph->stage() == -1 || !parameter_stage.count(stage_)) {
        continue;
      }
      auto micro = cnode->GetPrimalAttr(MICRO);
      if (!micro) {
        MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
        micro = MakeValue(int64_t(0));
      }
      auto user_stage = node->stage();
      if (stage_ == *parameter_stage.begin()) {
        // This stage owns the parameter: send it to remote using stages,
        // unless the use is local or a send to that stage already exists.
        if (graph->stage() == stage_) {
          continue;
        }
        if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) {
          continue;
        }
        auto send_out = InsertSend(parameter, user_stage, stage_, micro);
        make_tuple_input.push_back(send_out.depend);
      } else {
        // Non-owning stage: reuse an existing Receive from the owner if one
        // exists, otherwise insert a fresh one.
        auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
        if (receive) {
          manager_->SetEdge(node, user.second, receive);
        } else {
          auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro,
                                    parameter);
          recvs.push_back(recv);
        }
      }
    }
  }
  return make_tuple_input;
}
458
ParameterColoring()459 void PipelineTransformer::ParameterColoring() {
460 auto parameters = root_->parameters();
461 for (auto ¶meter : parameters) {
462 auto users = manager_->node_users()[parameter];
463 std::set<int64_t> parameter_stage;
464 for (auto &user : users) {
465 auto node = user.first->cast<CNodePtr>();
466 auto graph = node->func_graph();
467 if (IsValueNode<FuncGraph>(node->input(0))) {
468 graph = GetValueNode<FuncGraphPtr>(node->input(0));
469 }
470 if (graph != root_ && graph->stage() != -1) {
471 parameter_stage.insert(graph->stage());
472 parameter->set_stage(graph->stage());
473 }
474 }
475 auto param_info = parameter->cast<ParameterPtr>()->param_info();
476 if (!param_info) {
477 parameter_color_map[parameter] = parameter_stage;
478 continue;
479 }
480 MS_EXCEPTION_IF_NULL(param_info);
481 auto requires_grad = param_info->requires_grad();
482 if (*parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) {
483 virtual_param_ = parameter;
484 }
485 parameter_color_map[parameter] = parameter_stage;
486 }
487 }
488
GetShapeType(const AnfNodePtr & node,const Shape & shape)489 static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape) {
490 TypePtr type;
491 auto cnode = node->cast<CNodePtr>();
492 if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
493 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
494 auto graph_output = graph->output();
495 type = graph_output->Type();
496 } else {
497 type = node->Type();
498 }
499 MS_EXCEPTION_IF_NULL(type);
500 std::vector<ValuePtr> element;
501 std::transform(shape.begin(), shape.end(), std::back_inserter(element), [](int elem) { return MakeValue(elem); });
502 auto shape_list = std::make_shared<ValueList>(element);
503 auto tensor_type = type->cast<mindspore::TensorTypePtr>();
504 MS_EXCEPTION_IF_NULL(tensor_type);
505 auto dtype = tensor_type->element();
506 MS_EXCEPTION_IF_NULL(dtype);
507 return std::make_pair(shape_list, dtype);
508 }
509
// Traces `node` back to the real border node: unwraps TupleGetItem, resolves
// sub-graph calls to their output (following Depend chains, MakeTuple element
// selection, and parameter pass-through back to the call argument), and
// returns the first whitelisted or pipeline-care CNode found. Raises when the
// border would be a non-care node.
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  int64_t get_item_index = 0;
  if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
    get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
    cnode = cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
  }
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    // The node is a call: descend into the called graph's output.
    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
    auto output = graph->output();
    MS_EXCEPTION_IF_NULL(output);
    while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
      auto output_cnode = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      output = output_cnode->input(1);
    }
    if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
      // Pick the tuple element selected by the earlier TupleGetItem.
      auto make_tuple_cnode = output->cast<CNodePtr>();
      output = make_tuple_cnode->input(LongToSize(get_item_index + 1));
    }
    if (output->isa<Parameter>()) {
      // The graph returns one of its formals: follow the call argument.
      auto parameters = graph->parameters();
      auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
      auto pos = std::distance(parameters.begin(), pos_iter);
      return FindPipelineCareNode(cnode->input(LongToSize(pos + 1)));
    }
    cnode = output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
  }
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
    return FindPipelineCareNode(cnode->input(1));
  }
  if (IsInWhiteList(cnode)) {
    return cnode->cast<AnfNodePtr>();
  }
  if (!IsPipelineCareNode(cnode)) {
    MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."
                      << " border node: " << cnode->DebugString();
  }
  return cnode->cast<AnfNodePtr>();
}
554
// Inserts a Send in main_graph_ that ships `parameter` (or the border node it
// resolves to) from stage `node_stage` to `user_node_stage`. The Send is
// wrapped in a Depend on its own input so it is kept alive in the graph.
// Returns the shape, dtype and the Depend node as a SendAttr.
SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage,
                                         const ValuePtr &value) {
  auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
  // Each additional send to the same destination rank gets the next tag.
  int64_t send_tag;
  if (send_tag_map.find(dest_rank) != send_tag_map.end()) {
    send_tag = send_tag_map[dest_rank] + 1;
    send_tag_map[dest_rank] += 1;
  } else {
    send_tag = 0;
    send_tag_map[dest_rank] = 0;
  }
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
  Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
  auto send_op = CreatOpInstance(attrs, SEND, SEND);
  auto send_node = NewValueNode(send_op);
  auto prim = GetValueNode<PrimitivePtr>(send_node);
  // Determine the slice shape of the value being sent: parameters use the
  // input tensor info of a consuming op; ops use their own output info.
  std::pair<OperatorInfoPtr, int> op_info_pair;
  AnfNodePtr care_node;
  TensorInfo tensor_info;
  if (parameter->isa<Parameter>()) {
    op_info_pair = GetParameterPair(parameter);
    tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
  } else {
    care_node = FindPipelineCareNode(parameter);
    if (care_node->isa<Parameter>()) {
      op_info_pair = GetParameterPair(care_node);
      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
    } else {
      op_info_pair = GetOpInfo(care_node);
      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
    }
  }
  auto index = op_info_pair.second;
  auto op_info = op_info_pair.first;
  auto slice_shape = tensor_info.slice_shape();
  auto shape_type_pair = GetShapeType(parameter, slice_shape);
  prim->set_attr(SHAPE, shape_type_pair.first);
  prim->set_attr(DTYPE, shape_type_pair.second);
  std::vector<AnfNodePtr> send_input = {send_node, parameter};
  auto send = main_graph_->NewCNode(send_input);
  // Parameter sends are labeled PIPELINE_PARAM; activation sends PIPELINE_END.
  if (!parameter->isa<Parameter>() && care_node != nullptr && !care_node->isa<Parameter>()) {
    send->AddPrimalAttr(PIPELINE_END, value);
  } else {
    send->AddPrimalAttr(PIPELINE_PARAM, value);
    send->set_user_data<OperatorInfo>(op_info);
    send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
  }
  send->AddPrimalAttr(MICRO, value);
  // Depend(parameter, send) keeps the Send alive without changing dataflow.
  OperatorAttrs depend_attrs;
  auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
  std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
  auto depend = main_graph_->NewCNode(depend_input);
  auto abstract = parameter->abstract();
  if (care_node) {
    abstract = care_node->abstract();
  }
  depend->set_abstract(abstract);
  send->set_abstract(abstract);
  SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
  return send_out;
}
619
// Inserts a Receive in `graph` standing in for `node`, which lives on stage
// `node_stage`, and rewires `use_node`'s input `index` to it. For parameters
// the slice layout is also written back onto the parameter. Returns the new
// Receive node.
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const AnfNodePtr &use_node, int index, int64_t user_node_stage,
                                              int64_t node_stage, const ValuePtr &value,
                                              const AnfNodePtr &graph_param) {
  auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
  // Each additional receive from the same source rank gets the next tag.
  int64_t recv_tag;
  if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
    recv_tag = recv_tag_map[src_rank] + 1;
    recv_tag_map[src_rank] += 1;
  } else {
    recv_tag = 0;
    recv_tag_map[src_rank] = 0;
  }
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
  Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
  // Resolve the slice shape/layout of the incoming value.
  std::pair<OperatorInfoPtr, int> op_info_pair;
  bool is_param = true;
  TensorInfo tensor_info;
  if (node->isa<Parameter>()) {
    op_info_pair = GetParameterPair(graph_param);
    tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
  } else {
    auto care_node = FindPipelineCareNode(node);
    op_info_pair = GetOpInfo(care_node);
    tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
    is_param = false;
  }
  auto tensor_layout = tensor_info.tensor_layout();
  Shape slice_shape = tensor_info.slice_shape();
  auto shape_type_pair = GetShapeType(node, slice_shape);
  Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
  Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
  auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE);
  // Parameter receives keep the parameter as input; activation receives hang
  // off virtual_param_ so gradients flow.
  std::vector<AnfNodePtr> recv_input;
  if (node->isa<Parameter>()) {
    recv_input = {NewValueNode(recv_op), node};
  } else {
    recv_input = {NewValueNode(recv_op), virtual_param_};
  }
  auto recv = graph->NewCNode(recv_input);
  if (is_param) {
    recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
    recv->AddPrimalAttr(PIPELINE_PARAM, value);
  } else {
    recv->AddPrimalAttr(PIPELINE_BEGIN, value);
  }
  recv->AddPrimalAttr(MICRO, value);
  // The Receive's abstract mirrors the produced value; for a sub-graph call
  // that is the called graph's output abstract.
  auto node_abstract = node->abstract();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (IsValueNode<FuncGraph>(cnode->input(0))) {
      auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
      MS_EXCEPTION_IF_NULL(output);
      node_abstract = output->abstract();
    }
  }
  MS_EXCEPTION_IF_NULL(node_abstract);
  recv->set_abstract(node_abstract);
  if (node->isa<Parameter>()) {
    // Shrink the parameter's abstract to the slice shape and attach layout.
    BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract_clone = node->abstract()->Clone();
    MS_EXCEPTION_IF_NULL(abstract_clone);
    abstract_clone->set_shape(parallel_shape);
    node->set_abstract(abstract_clone);
    node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  }
  recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  recv->set_user_data<OperatorInfo>(op_info_pair.first);

  manager_->SetEdge(use_node, index, recv);
  return recv;
}
696
// Searches `out_input` for an already-inserted Send/Receive of `node` whose
// DEST_RANK/SRC_RANK attr (selected by `tag`) equals `stage`. Returns the
// matching op (still wrapped in its Depend, if any) or nullptr.
AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
                                      const std::string &tag) {
  for (auto &input : out_input) {
    auto cnode = input->cast<CNodePtr>();
    if (!cnode) {
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
      // Sends are wrapped as Depend(param, send); look at the Send itself.
      cnode = cnode->input(2)->cast<CNodePtr>();
      if (!cnode) {
        // Depend's second input is not a CNode; dereferencing it would crash.
        continue;
      }
    }
    if (cnode->input(1) == node) {
      auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      auto dest_rank_send = GetValue<int64_t>(prim->GetAttr(tag));
      if (dest_rank_send == stage) {
        return input;
      }
    }
  }
  return nullptr;
}
717
// skip some virtual op like:Depend, Load, Cast — peels wrappers iteratively
// until a real op (or non-wrapper node) remains.
AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) {
  auto current = node;
  while (IsPrimitiveCNode(current, prim::kPrimDepend) || IsPrimitiveCNode(current, prim::kPrimCast) ||
         IsPrimitiveCNode(current, prim::kPrimLoad)) {
    auto cnode = current->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    current = cnode->input(1);
  }
  return current;
}
728
IsParameterGraph(const AnfNodePtr & node)729 bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) {
730 // ParameterGraph: graph which return a parameter
731 MS_EXCEPTION_IF_NULL(node);
732 auto temp_node = ActualOp(node);
733 auto cnode = temp_node->cast<CNodePtr>();
734 MS_EXCEPTION_IF_NULL(cnode);
735
736 // parameter_graph->return->graph
737 if (!IsValueNode<FuncGraph>(cnode->input(0))) {
738 return false;
739 }
740 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
741 MS_EXCEPTION_IF_NULL(graph);
742 auto graph_out = graph->output();
743 MS_EXCEPTION_IF_NULL(graph_out);
744 auto actual_op = ActualOp(graph_out);
745 MS_EXCEPTION_IF_NULL(actual_op);
746 if (actual_op->isa<Parameter>()) {
747 auto parameter_list = graph->parameters();
748 // parameter_graph->parameter->return->graph
749 auto parameter_iter = std::find(parameter_list.begin(), parameter_list.end(), actual_op);
750 if (parameter_iter == parameter_list.end()) {
751 return true;
752 }
753 // parameter->graph->return->graph
754 auto pos = std::distance(parameter_list.begin(), parameter_iter);
755 if (!cnode->input(LongToSize(pos + 1))->isa<Parameter>()) {
756 return false;
757 }
758 return true;
759 }
760 return false;
761 }
762
// Handles a pipeline border whose producer is a "parameter graph" (a
// sub-graph returning a parameter): resolves the real argument feeding the
// graph, then inserts (or reuses) a Receive on the consuming stage or a Send
// on the producing stage. Returns the Send's Depend node, or nullptr when an
// existing op was reused / a Receive was wired in place.
AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
                                                     int64_t user_stage, const ValuePtr &micro, size_t pos,
                                                     const std::vector<AnfNodePtr> ops) {
  MS_EXCEPTION_IF_NULL(node);
  auto actual_node = ActualOp(node);
  auto cnode = actual_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtr argument;
  AnfNodePtr parameter;

  auto graph_out = ActualOp(graph->output());
  MS_EXCEPTION_IF_NULL(graph_out);
  auto parameter_list = graph->parameters();
  auto param_iter = std::find(parameter_list.begin(), parameter_list.end(), graph_out);
  auto use_cnode = use_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(use_cnode);
  if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
    MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
  }
  auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
  auto use_parameter_list = use_graph->parameters();
  // `pos` is the 1-based input slot at the consuming call site.
  parameter = use_parameter_list.at(pos - 1);
  // argument->load->graph
  if (param_iter == parameter_list.end()) {
    // The graph returns a captured parameter: that is the value transferred.
    argument = graph_out;
  } else {
    // The graph returns one of its formals: follow the call argument.
    auto param_pos = std::distance(parameter_list.begin(), param_iter);
    argument = cnode->input(LongToSize(param_pos + 1));
  }

  // insert receive
  if (stage_ == user_stage) {
    auto recv = Reuse(argument, stage, ops, SRC_RANK);
    if (recv) {
      manager_->SetEdge(use_node, SizeToInt(pos), recv);
      return nullptr;
    }
    return InsertReceive(main_graph_, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
  }
  // insert send
  if (Reuse(argument, user_stage, ops, DEST_RANK)) {
    return nullptr;
  }
  auto send_out = InsertSend(argument, user_stage, stage_, micro);
  // Record shape/dtype on the Depend so later passes can read them back.
  send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
  send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
  return send_out.depend;
}
813
// Walk the graph and insert Send/Receive ops on every edge whose endpoints live
// in different pipeline stages. Returns {send_ops, receive_ops} created for the
// current stage (stage_).
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
  std::vector<AnfNodePtr> receive_ops;
  std::vector<AnfNodePtr> send_ops;
  auto ret = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  // Search returns reverse topological order; process producers before users.
  std::reverse(all_nodes.begin(), all_nodes.end());
  auto stage_num = g_device_manager->stage_num();
  // Pipeline scheduling needs at least one micro-batch per stage when training.
  if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
    MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
  }
  for (auto &node : all_nodes) {
    // Skip non-cnodes, nodes without a stage mark, and UpdateState control edges.
    if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
      continue;
    }
    auto node_users = manager_->node_users()[node];
    // One Receive per producer node is created and then shared by all its users.
    AnfNodePtr receive = nullptr;
    for (auto &user_pair : node_users) {
      auto user_node = user_pair.first;
      auto node_stage = node->stage();
      auto user_node_stage = user_node->stage();
      // Only edges touching the current stage concern this rank.
      if (node_stage != stage_ && user_node_stage != stage_) {
        continue;
      }
      auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
      if (!micro) {
        MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
        micro = MakeValue(int64_t(0));
      }
      if (node_stage < user_node_stage) {
        if (node_stage == stage_) {
          // This rank is the producer: emit a Send (unless reusable).
          if (IsParameterGraph(node)) {
            auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
                                                    IntToSize(user_pair.second), send_ops);
            if (!send_depend) {
              continue;
            }
            // Parameter sends are prepended so they execute before data sends.
            (void)send_ops.insert(send_ops.begin(), send_depend);
            continue;
          }
          if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
            continue;
          }
          auto send_out = InsertSend(node, user_node_stage, node_stage, micro);
          MS_EXCEPTION_IF_NULL(send_out.depend);
          send_ops.push_back(send_out.depend);
          // Record dtype/shape for later sens-shape reconstruction (CutGraph).
          send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
          send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
        } else {
          // This rank is the consumer: emit one Receive, rewire further users.
          if (!receive) {
            if (IsParameterGraph(node)) {
              receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
                                             IntToSize(user_pair.second), receive_ops);
              if (!receive) {
                continue;
              }
              receive_ops.push_back(receive);
            } else {
              receive =
                InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
              receive_ops.push_back(receive);
            }
          } else {
            // Subsequent users of the same producer share the existing Receive.
            manager_->SetEdge(user_node, user_pair.second, receive);
          }
        }
        continue;
      }
      // Pipeline edges must flow forward: a later stage feeding an earlier one
      // is a partitioning error.
      if (node_stage > user_node_stage) {
        MS_LOG(EXCEPTION) << "node_stage: " << node_stage
                          << " must be smaller than user_node_stage: " << user_node_stage;
      }
    }
  }
  return std::make_pair(send_ops, receive_ops);
}
890
CutGraph()891 void PipelineTransformer::CutGraph() {
892 std::vector<AnfNodePtr> make_tuple_inputs;
893 CreateForwardGroup();
894 MS_EXCEPTION_IF_NULL(main_graph_);
895 if (make_tuple_inputs.empty()) {
896 make_tuple_inputs = HandleSharedParameter();
897 }
898 auto send_recv_ops = CutBorder(main_graph_);
899 auto send_ops = send_recv_ops.first;
900 if (IsLastStage()) {
901 return;
902 }
903 if (send_ops.empty() && !root_->has_flag(TRAINING)) {
904 return;
905 }
906 (void)make_tuple_inputs.insert(make_tuple_inputs.end(), send_ops.begin(), send_ops.end());
907 if (!send_ops.empty()) {
908 type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
909 shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
910 }
911 auto make_tuple = main_graph_->NewCNode(make_tuple_inputs);
912 std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend)};
913 out.push_back(send_ops.back());
914 out.push_back(make_tuple);
915 auto out_node = main_graph_->NewCNode(out);
916 (void)manager_->Replace(main_graph_->output(), out_node);
917 }
918
ElimGraphStage()919 void PipelineTransformer::ElimGraphStage() {
920 for (auto &fg : manager_->func_graphs()) {
921 fg->set_stage(-1);
922 }
923 }
924
FindSensNode()925 std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
926 std::pair<CNodePtr, FuncGraphPtr> sens_graph_pair;
927 CNodePtr sens_cnode;
928 FuncGraphPtr func_graph;
929 for (auto &node : root_->nodes()) {
930 if (!node->isa<CNode>()) {
931 continue;
932 }
933 sens_cnode = node->cast<CNodePtr>();
934 AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
935 MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
936 if (!expect_tuple_getitem->isa<CNode>()) {
937 continue;
938 }
939
940 auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
941 if (!IsPrimitiveCNode(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem)) {
942 continue;
943 }
944 auto expect_anonymous = expect_tuple_getitem_cnode->input(1);
945 if (!expect_anonymous->isa<CNode>()) {
946 continue;
947 }
948 auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
949 AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
950 if (!expect_j->isa<CNode>()) {
951 continue;
952 }
953 auto expect_j_cnode = expect_j->cast<CNodePtr>();
954 if (!IsPrimitiveCNode(expect_j_cnode, prim::kPrimJ)) {
955 continue;
956 }
957 func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
958 break;
959 }
960 sens_graph_pair = std::make_pair(sens_cnode, func_graph);
961 return sens_graph_pair;
962 }
963
CoverSensShape()964 void PipelineTransformer::CoverSensShape() {
965 if (IsLastStage()) {
966 return;
967 }
968 auto sens_graph_pair = FindSensNode();
969 auto sens_cnode = sens_graph_pair.first;
970 MS_EXCEPTION_IF_NULL(sens_cnode);
971 OperatorAttrs attrs;
972 auto fill_op = CreatOpInstance(attrs, "Fill", "");
973 MS_EXCEPTION_IF_NULL(type_ptr_);
974 MS_EXCEPTION_IF_NULL(shape_);
975 std::vector<AnfNodePtr> fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_),
976 NewValueNode(MakeValue(shape_->value())), NewValueNode(0)};
977 auto fill = root_->NewCNode(fill_input);
978 std::vector<AnfNodePtr> new_sens_input = {sens_cnode->input(0), fill};
979 auto new_sens_node = root_->NewCNode(new_sens_input);
980 manager_->Replace(sens_cnode, new_sens_node);
981 }
982
ElimParameter()983 void PipelineTransformer::ElimParameter() {
984 auto parameters = root_->parameters();
985 std::vector<AnfNodePtr> parameter_list;
986 for (auto ¶meter : parameters) {
987 auto param = parameter->cast<ParameterPtr>();
988 MS_EXCEPTION_IF_NULL(param);
989 if (!manager_->node_users()[parameter].empty() || !param->has_default()) {
990 parameter_list.push_back(parameter);
991 }
992 }
993 auto del_num = parameters.size() - parameter_list.size();
994 root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
995 manager_->SetParameters(root_, parameter_list);
996 }
997 } // namespace parallel
998 } // namespace mindspore
999