1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/other_ops.h"
26 #include "mindspore/core/ops/nn_ops.h"
27 #include "mindspore/core/ops/array_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "mindspore/core/ops/arithmetic_ops.h"
30 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
31 #include "frontend/parallel/ops_info/ops_utils.h"
32 #include "frontend/parallel/group_manager.h"
33 #include "frontend/parallel/parameter_manager.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "frontend/parallel/step_parallel.h"
36 #include "frontend/parallel/node_check.h"
37 #include "frontend/parallel/graph_util/node_info.h"
38 #include "frontend/parallel/graph_util/graph_info.h"
39 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
40 #include "frontend/parallel/step_parallel_utils.h"
41 #include "frontend/parallel/graph_util/graph_splitter.h"
42 #include "ir/anf.h"
43 #include "ir/graph_utils.h"
44 #include "ir/func_graph_cloner.h"
45 #include "include/common/utils/comm_manager.h"
46 #include "utils/ms_context.h"
47 #include "utils/tensor_construct_utils.h"
48 #include "mindspore/core/utils/parallel_node_check.h"
49
50 namespace mindspore {
51 namespace parallel {
GetRealAbstract(const AnfNodePtr & node)52 static AbstractBasePtr GetRealAbstract(const AnfNodePtr &node) {
53 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
54 auto &input = node->cast<CNodePtr>()->input(1);
55 MS_EXCEPTION_IF_NULL(input);
56 return input->abstract();
57 }
58 return node->abstract();
59 }
60
// Locates the main training graph by searching all func_graphs for a
// VirtualDataset node, then locates the lazy-inline (cell-reuse) shared cell
// referenced from that graph. Returns false when either cannot be found;
// raises when pipeline interleave is configured but no shared cell exists.
bool PipelineInterleave::MainGraph() {
  bool find_main_graph = false;
  for (auto &fg : manager_->func_graphs()) {
    for (auto &node : fg->nodes()) {
      if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
        main_graph_ = fg;
        main_graph_->set_flag(MAIN_GRAPH, true);
        virtual_dataset_ = node;
        find_main_graph = true;
        break;
      }
    }
    if (find_main_graph) {
      break;
    }
  }
  if (!find_main_graph) {
    MS_LOG(WARNING) << "Can't find main graph, possible reason is can't find virtual dataset.";
    return false;
  }
  // Scan the main graph's value nodes for a FuncGraph flagged for cell reuse
  // (produced by lazy_inline); that graph becomes the shared cell.
  auto value_nodes = main_graph_->value_nodes();
  for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
    auto node = (*value_pair).first;
    if (!IsValueNode<FuncGraph>(node)) {
      continue;
    }
    auto graph = GetValueNode<FuncGraphPtr>(node);
    MS_EXCEPTION_IF_NULL(graph);
    if (!graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
      continue;
    }
    shared_cell_ = graph;
    break;
  }
  if (!shared_cell_) {
    MS_LOG(ERROR) << "Pipeline parallel now only support shared_cell.";
    auto parallel_context = parallel::ParallelContext::GetInstance();
    MS_EXCEPTION_IF_NULL(parallel_context);
    auto is_pp_interleave = parallel_context->pipeline_interleave();
    if (is_pp_interleave) {
      // Interleaved pipeline cannot proceed without the shared cell.
      MS_LOG(EXCEPTION) << "Using pipeline parallel with interleave, should enable lazy_inline.";
    }
    return false;
  }
  return true;
}
107
CreateSendReceiveGroup()108 void PipelineInterleave::CreateSendReceiveGroup() {
109 MS_EXCEPTION_IF_NULL(g_device_manager);
110 auto rank_list = g_device_manager->GetDeviceListBetweenStage();
111 auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
112 Group forward_send_group;
113 if (g_device_manager->CreateGroup(rank_list, &forward_send_group) != SUCCESS) {
114 MS_LOG(EXCEPTION) << "Create forward Send communication group failed, the rank list is: " << rank_list;
115 }
116 group_.emplace_back(forward_send_group.name());
117
118 Group backward_send_group;
119 auto backward_send_group_name = forward_send_group.name() + BACKWARD;
120 if (g_device_manager->CreateGroup(backward_send_group_name, dev_list, &backward_send_group) != SUCCESS) {
121 MS_LOG(EXCEPTION) << "Create backward Send communication group failed, the rank list is: " << rank_list;
122 }
123 group_.emplace_back(backward_send_group_name);
124
125 Group forward_recv_group;
126 auto forward_recv_group_name = forward_send_group.name() + RECEIVE;
127 if (g_device_manager->CreateGroup(forward_recv_group_name, dev_list, &forward_recv_group) != SUCCESS) {
128 MS_LOG(EXCEPTION) << "Create forward Receive communication group failed, the rank list is: " << rank_list;
129 }
130 group_.emplace_back(forward_recv_group_name);
131
132 Group backward_recv_group;
133 auto backward_recv_group_name = forward_recv_group_name + BACKWARD;
134 if (g_device_manager->CreateGroup(backward_recv_group_name, dev_list, &backward_recv_group) != SUCCESS) {
135 MS_LOG(EXCEPTION) << "Create backward Receive communication group failed, the rank list is: " << rank_list;
136 }
137 group_.emplace_back(backward_recv_group_name);
138 }
139
// Derives the micro-batch index for a data-slicing StridedSlice node and
// records it on the node as primal attrs (MICRO, PIPELINE_BEGIN, SEGMENT).
// `micro_size` is the total number of micro batches; `batch_axis` is the axis
// the input batch is split along. Returns the micro index wrapped in a Value.
ValuePtr PipelineInterleave::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const {
  if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
    MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
  }
  auto cnode = node->cast<CNodePtr>();

  int64_t micro = 0;
  auto value = GetValueNode(cnode->input(2));
  if (value != nullptr) {
    // Static shape: micro index = begin[batch_axis] / per-slice batch size,
    // computed as begin * micro_size / full batch extent of the slice input.
    auto tuple = GetValue<std::vector<int64_t>>(value);  // begin
    auto input_tmp = GetNodeShape(cnode->input(1));
    auto input_shape = input_tmp.at(0);
    auto slice_batch_size = input_shape.at(batch_axis);  // batch extent along the split axis
    if (slice_batch_size == 0) {
      MS_LOG(EXCEPTION) << "slice_batch_size should be a positive integer, but got " << slice_batch_size;
    }
    micro = tuple.at(batch_axis) * micro_size / slice_batch_size;  // micro-index
  } else {
    // dynamic shape
    // if micro is not 1: stridedslice --> maketuple --> scalarmul --> micro
    // if micro is 1: stridedslice --> maketuple --> scalarfloordiv
    if (!IsPrimitiveCNode(cnode->input(2), prim::kPrimMakeTuple)) {
      MS_LOG(EXCEPTION) << "The begin of stridedslice is not constant value, and not make tuple";
    }
    auto make_tuple_cnode = cnode->input(2)->cast<CNodePtr>();
    if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarMul)) {
      // Micro index is the constant multiplier feeding the ScalarMul.
      auto scalar_mul_cnode = make_tuple_cnode->input(1)->cast<CNodePtr>();
      auto mul_value = GetValueNode(scalar_mul_cnode->input(2));
      micro = GetValue<int64_t>(mul_value);
    } else if (IsPrimitiveCNode(make_tuple_cnode->input(1), prim::kPrimScalarFloorDiv)) {
      micro = 1;
    } else {
      MS_LOG(EXCEPTION) << "Can not find the micro info, the input op of make tuple is "
                        << GetCNodePrimitive(make_tuple_cnode->input(1))->name();
    }
  }

  cnode->AddPrimalAttr(MICRO, MakeValue(micro));
  cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
  int64_t seg = 0;
  cnode->AddPrimalAttr(SEGMENT, MakeValue(seg));
  return MakeValue(micro);
}
183
Init()184 void PipelineInterleave::Init() {
185 auto ms_context = MsContext::GetInstance();
186 MS_EXCEPTION_IF_NULL(ms_context);
187 world_group_ = GetWorldGroup();
188 uint32_t world_rank_size = 0;
189 global_rank_ = parallel::ParallelContext::GetInstance()->global_rank();
190 uint32_t rank_id = 0;
191 if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
192 if (!CommManager::GetInstance().GetRankID(world_group_, &rank_id)) {
193 MS_LOG(EXCEPTION) << "Get rank id failed.";
194 }
195 global_rank_ = UintToInt(rank_id);
196 }
197 int64_t device_num = 0;
198 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
199 if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
200 if (!CommManager::GetInstance().GetRankSize(world_group_, &world_rank_size)) {
201 MS_LOG(EXCEPTION) << "Get rank size failed";
202 }
203 device_num = UintToInt(world_rank_size);
204 MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
205 } else {
206 device_num = parallel::ParallelContext::GetInstance()->device_num();
207 }
208 per_stage_rank_num_ = device_num / stage_num;
209 return;
210 }
211
GetBatchAxisForInput(const AnfNodeIndexSet & input_node_users) const212 size_t PipelineInterleave::GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const {
213 Shapes inputs_tuple;
214 for (const auto &input_node_user : input_node_users) {
215 auto node = input_node_user.first;
216 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
217 return 0; // simply return 0 when dynamic shape
218 }
219 auto cnode = node->cast<CNodePtr>();
220 auto value = GetValueNode(cnode->input(2));
221 if (value == nullptr) {
222 return 0; // simply return 0 when dynamic shape
223 }
224 auto tuple = GetValue<std::vector<int64_t>>(value);
225 inputs_tuple.push_back(tuple);
226 }
227 size_t batch_axis = 0;
228 size_t batch_axis_count = 0;
229 size_t input_dim = inputs_tuple.at(0).size();
230 size_t micro_num = inputs_tuple.size();
231 for (size_t axis = 0; axis < input_dim; ++axis) {
232 for (size_t i = 1; i < micro_num; ++i) {
233 if (inputs_tuple[i][axis] != inputs_tuple[i - 1][axis]) {
234 batch_axis = axis;
235 ++batch_axis_count;
236 break;
237 }
238 }
239 }
240 if (batch_axis_count != kSizeOne) {
241 MS_LOG(EXCEPTION)
242 << "For pipeline parallelism, micro_size partitioning of the input along a certain dimension is and "
243 << "is only allowed, but it is found that " << batch_axis_count << " to be partitioned.";
244 }
245 return batch_axis;
246 }
247
// Labels every data-slicing StridedSlice under the VirtualDataset with its
// micro-batch index and broadcasts the label to downstream nodes. Training
// graphs only; a no-op for inference.
void PipelineInterleave::LabelMicroBatch() {
  if (!is_train_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(virtual_dataset_);
  auto node_user_map = manager_->node_users();
  auto node_users = node_user_map[virtual_dataset_];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto data_users = manager_->node_users()[node_user.first];
      auto node_first = data_users.front().first;
      // If the first user is neither StridedSlice nor Shape, the real slicing
      // users are one level further down the graph.
      if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice) && !IsPrimitiveCNode(node_first, prim::kPrimShape)) {
        data_users.clear();
        data_users = node_user_map[node_first];
      }
      auto micro_size = int64_t(MicroSize(data_users));
      micro_size_ = micro_size;
      auto batch_axis = GetBatchAxisForInput(data_users);
      MS_LOG(INFO) << "For the "
                   << GetSerialNumberString(
                        GetValue<int64_t>(GetValueNode(node_user.first->cast<CNodePtr>()->input(kIndex2))))
                   << "input, batch axis is " << batch_axis << ", micro size is : " << micro_size;
      for (auto &data_user : data_users) {
        if (!IsPrimitiveCNode(data_user.first, prim::kPrimStridedSlice)) {
          continue;
        }
        // Tag the slice with its micro index, fix its strategy, and propagate
        // the micro label to every consumer.
        auto micro = SetMicroBatch(data_user.first, micro_size, batch_axis);
        SetStridedSliceStrategy(data_user.first);
        auto cnode = data_user.first->cast<CNodePtr>();
        BroadCastMicroBatch(cnode, &node_user_map, micro, 0);
      }
    }
  }
}
282
LabelGenMaskFusion()283 void PipelineInterleave::LabelGenMaskFusion() {
284 auto fgs = manager_->func_graphs();
285 int64_t fusion_id = 0;
286 for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
287 if (*fg == root_ || *fg == main_graph_) {
288 continue;
289 }
290 auto stage = (*fg)->stage();
291 if (stage != -1 && stage != stage_) {
292 continue;
293 }
294 auto nodes = (*fg)->nodes();
295 for (auto node = nodes.cbegin(); node != nodes.cend(); ++node) {
296 if (!IsPrimitiveCNode(*node, prim::kPrimDropoutGenMask) && !IsPrimitiveCNode(*node, prim::kPrimDropoutDoMaskV3) &&
297 !IsPrimitiveCNode(*node, prim::kPrimDropout)) {
298 continue;
299 }
300 auto cnode = (*node)->cast<CNodePtr>();
301 MS_EXCEPTION_IF_NULL(cnode);
302 cnode->AddPrimalAttr(kAttrFusion, MakeValue(fusion_id));
303 fusion_id += 1;
304 }
305 }
306 }
307
// Assigns pipeline-stage labels: every CNode calling a stage-tagged FuncGraph
// gets that graph's stage as NodeStageInfo, and graphs of the local stage pull
// their callers' enclosing graphs into the same stage. Repeats until a
// fixpoint, then validates that all configured stages are actually used.
void PipelineInterleave::Coloring() {
  auto need_coloring = true;
  std::set<int64_t> stage_set;
  if (!IsTraining(manager_)) {
    is_train_ = false;
  }
  while (need_coloring) {
    need_coloring = false;
    for (auto &fg : manager_->func_graphs()) {
      // In training, the root graph is handled elsewhere; skip it.
      if (fg == root_ && is_train_) {
        continue;
      }
      auto value_nodes = fg->value_nodes();
      for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
        auto node = (*value_pair).first;
        if (!IsValueNode<FuncGraph>(node)) {
          continue;
        }
        auto graph = GetValueNode<FuncGraphPtr>(node);
        if (graph->stage() == -1) {
          continue;
        }
        (void)stage_set.insert(graph->stage());
        auto node_users = manager_->node_users()[node];
        for (auto &user_pair : node_users) {
          auto user_node = user_pair.first->cast<CNodePtr>();
          user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
          auto user_node_graph = user_node->func_graph();
          // Spread the local stage to yet-unlabeled enclosing graphs and
          // iterate again, since this may expose new stage-tagged callers.
          if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
            user_node_graph->set_stage(graph->stage());
            need_coloring = true;
          }
        }
      }
    }
  }
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto stage_num = g_device_manager->stage_num();
  if (SizeToLong(stage_set.size()) != stage_num) {
    MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " which is not equal to stage used: " << stage_set.size();
  }
}
350
// Propagates stage/chunk labels through the shared cell until a fixpoint.
// Rules per producer->user edge: an unlabeled user inherits the producer's
// stage and chunk; same-stage nodes converge on the larger chunk; a user on an
// earlier stage than its producer starts a new chunk (chunk + 1), which is how
// interleaved pipeline chunks are discovered. Labels are mirrored into the
// CHUNK/STAGE primal attrs for later passes.
void PipelineInterleave::BroadCastColoring() {
  auto need_coloring = true;
  while (need_coloring) {
    need_coloring = false;
    auto all_nodes = shared_cell_->nodes();
    auto node_users = manager_->node_users();
    for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) {
      auto stage_info = (*node)->user_data<NodeStageInfo>();
      if (!(*node)->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
          IsPrimitiveCNode(*node, prim::kPrimUpdateState)) {
        continue;
      }
      auto cnode = (*node)->cast<CNodePtr>();
      auto stage = stage_info->stage();
      auto chunk = stage_info->chunk();
      for (auto &user_pair : node_users[*node]) {
        auto user_node = user_pair.first->cast<CNodePtr>();
        auto user_stage_info = user_node->user_data<NodeStageInfo>();
        // Unlabeled user: inherit stage and chunk from the producer.
        if (user_stage_info == nullptr) {
          user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage, chunk));
          need_coloring = true;
          user_node->AddPrimalAttr(CHUNK, MakeValue(chunk));
          user_node->AddPrimalAttr(STAGE, MakeValue(stage));
          continue;
        }
        auto user_node_stage = user_stage_info->stage();
        auto user_node_chunk = user_stage_info->chunk();
        // Same stage: both ends converge on the larger chunk value.
        if (stage == user_node_stage) {
          if (chunk > user_node_chunk) {
            user_stage_info->set_chunk(chunk);
            need_coloring = true;
            user_node->AddPrimalAttr(CHUNK, MakeValue(chunk));
            user_node->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
            continue;
          }
          if (chunk < user_node_chunk) {
            // Pull the producer up to the user's chunk and relabel it.
            stage_info->set_chunk(user_node_chunk);
            chunk = user_node_chunk;
            need_coloring = true;
            cnode->AddPrimalAttr(CHUNK, MakeValue(chunk));
            cnode->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
            continue;
          }
        }
        // Data flowing back to an earlier stage opens the next chunk.
        if (stage > user_node_stage) {
          if ((chunk >= user_node_chunk)) {
            user_stage_info->set_chunk(chunk + 1);
            need_coloring = true;
            user_node->AddPrimalAttr(CHUNK, MakeValue(chunk + 1));
            user_node->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
            continue;
          }
        }
        // Forward edge whose user lags in chunk: lift the user's chunk.
        if ((stage < user_node_stage) && (chunk > user_node_chunk)) {
          user_stage_info->set_chunk(chunk);
          need_coloring = true;
          user_node->AddPrimalAttr(CHUNK, MakeValue(chunk));
          user_node->AddPrimalAttr(STAGE, MakeValue(user_node_stage));
        }
      }
    }
  }
}
414
GetLoadNodeByParam(const AnfNodePtr & param) const415 std::vector<AnfNodePtr> PipelineInterleave::GetLoadNodeByParam(const AnfNodePtr ¶m) const {
416 std::vector<AnfNodePtr> load_vec = {param};
417 auto node_users = manager_->node_users()[param];
418 for (auto ¶m_user : node_users) {
419 if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) {
420 auto graph = param_user.first->func_graph();
421 // exclude opt graphs
422 if (graph == root_ || (graph->stage() == -1 && graph != main_graph_)) {
423 continue;
424 }
425 (void)load_vec.emplace_back(param_user.first);
426 }
427 }
428 return load_vec;
429 }
430
// If `node` is a call to the shared cell and `index` is an argument position,
// collects the pipeline stages of all users of the corresponding shared-cell
// parameter (through its Load nodes) into `parameter_stage` and returns true.
// Returns false when `node` is not a shared-cell call (caller falls back to a
// direct stage lookup on the user node).
bool PipelineInterleave::GetStageByArgument(const CNodePtr &node, size_t index,
                                           const std::vector<AnfNodePtr> &parameters,
                                           const NodeUsersMap &node_users_map,
                                           std::set<int64_t> *const parameter_stage) {
  // Index 0 is the callee itself, not an argument.
  if (index < 1) {
    return false;
  }
  const auto &input = node->input(0);
  if (!IsValueNode<FuncGraph>(input)) {
    return false;
  }
  if (GetValueNode<FuncGraphPtr>(input) != shared_cell_) {
    return false;
  }
  // Map the call-site argument to the corresponding formal parameter.
  auto pos = index - 1;
  const auto &param = parameters.at(pos);
  MS_EXCEPTION_IF_NULL(param);
  auto loads = GetLoadNodeByParam(param);
  const auto &iter = node_users_map.find(loads.back());
  if (iter == node_users_map.end()) {
    return true;
  }
  const auto &users = (*iter).second;
  for (auto &user : users) {
    auto user_cnode = user.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    auto stage_info = user_cnode->user_data<NodeStageInfo>();
    if (stage_info != nullptr && stage_info->stage() != -1) {
      (void)((*parameter_stage).insert(stage_info->stage()));
    } else {
      // No node-level stage: fall back to the stage of the enclosing graph,
      // ignoring the root/main/shared graphs and unstaged graphs.
      auto graph = user_cnode->func_graph();
      MS_EXCEPTION_IF_NULL(graph);
      if (graph != root_ && graph != main_graph_ && graph != shared_cell_ && graph->stage() != -1) {
        (void)((*parameter_stage).insert(graph->stage()));
      }
    }
  }
  return true;
}
470
// Computes, for every root-graph parameter, the set of pipeline stages that
// consume it (found through its Load users), storing the result in
// parameter_color_map_ for later redundancy elimination.
void PipelineInterleave::ParameterColoring() {
  auto parameters = root_->parameters();
  auto &node_users_map = manager_->node_users();
  const auto &share_cell_parameters = shared_cell_->parameters();
  for (auto &parameter : parameters) {
    auto loads = GetLoadNodeByParam(parameter);
    std::set<int64_t> parameter_stage;
    for (auto &load : loads) {
      auto load_users = node_users_map[load];
      for (auto &load_user : load_users) {
        auto user_cnode = load_user.first->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(user_cnode);
        // Shared-cell calls are resolved through the callee's parameter.
        if (GetStageByArgument(user_cnode, load_user.second, share_cell_parameters, node_users_map, &parameter_stage)) {
          continue;
        }
        auto stage_info = user_cnode->user_data<NodeStageInfo>();
        if (stage_info != nullptr && stage_info->stage() != -1) {
          (void)parameter_stage.insert(stage_info->stage());
          continue;
        } else {
          // No node-level stage: use the enclosing graph's stage when it is a
          // genuine pipeline stage graph.
          auto graph = user_cnode->func_graph();
          MS_EXCEPTION_IF_NULL(graph);
          if (graph != root_ && graph != main_graph_ && graph != shared_cell_ && graph->stage() != -1) {
            (void)parameter_stage.insert(graph->stage());
            continue;
          }
        }
      }
    }
    parameter_color_map_[parameter] = parameter_stage;
  }
}
503
RemoveMonadNode()504 void PipelineInterleave::RemoveMonadNode() {
505 auto all_nodes = DeepScopedGraphSearch(shared_cell_->get_return());
506 auto node_users_map = manager_->node_users();
507 for (auto &node : all_nodes) {
508 if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
509 continue;
510 }
511 auto cnode = node->cast<CNodePtr>();
512 MS_EXCEPTION_IF_NULL(cnode);
513 auto abs = cnode->abstract();
514 MS_EXCEPTION_IF_NULL(abs);
515 auto stage_info = cnode->user_data<NodeStageInfo>();
516 if (stage_info == nullptr) {
517 continue;
518 }
519 auto stage = stage_info->stage();
520 if (stage != stage_ && stage != -1) {
521 auto node_users = node_users_map[node];
522 for (auto &user_node : node_users) {
523 auto monad_node = NewValueNode(kUMonad);
524 if (abs->isa<abstract::AbstractIOMonad>()) {
525 monad_node = NewValueNode(kIOMonad);
526 }
527 manager_->SetEdge(user_node.first, user_node.second, monad_node);
528 }
529 }
530 }
531 }
532
CreateZeroseOutput(const AnfNodePtr & node,size_t index)533 static tensor::TensorPtr CreateZeroseOutput(const AnfNodePtr &node, size_t index) {
534 auto out_shapes = GetNodeShape(node);
535 auto out_shape_type = GetShapeType(node, out_shapes.at(index), index);
536 auto zero_tensor = TensorConstructUtils::CreateZerosTensor(out_shape_type.second, out_shapes.at(index));
537 return zero_tensor;
538 }
539
// Builds a MakeTuple of zero tensors, one per output shape of `node`.
// NOTE(review): `index` is accepted for interface compatibility but is not
// read here — the element count comes from GetNodeShape(node).
static AnfNodePtr CreateTupleZeroTensor(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t index) {
  std::vector<AnfNodePtr> tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  const auto out_shapes = GetNodeShape(node);
  for (size_t i = 0; i < out_shapes.size(); ++i) {
    (void)tuple_inputs.emplace_back(NewValueNode(CreateZeroseOutput(node, i)));
  }
  return graph->NewCNode(tuple_inputs);
}
549
// Inserts a Send/Receive pair between `node` (producer) and `user_node`
// (consumer) that sit on different pipeline stages, then redirects all users
// of `node` to the Receive. `order` gives a deterministic serialization index
// to the inserted communication ops.
void PipelineInterleave::InsertSendReceive(const AnfNodePtr &node, const AnfNodePtr &user_node, int64_t order) {
  auto node_stage_info = node->user_data<NodeStageInfo>();
  auto user_node_stage_info = user_node->user_data<NodeStageInfo>();
  auto node_stage = node_stage_info->stage();
  auto user_stage = user_node_stage_info->stage();
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(0));
  Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_stage));
  // Forward edges use the send groups (group_[0]/[1]); edges flowing back to
  // an earlier stage use the receive groups (group_[2]/[3]) — the order the
  // groups were pushed in CreateSendReceiveGroup.
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  if (node_stage > user_stage) {
    attr_group = std::make_pair(GROUP, MakeValue(group_[INDEX_TWO]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[INDEX_THREE]));
  }
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
  auto send_op = CreateOpInstance(attrs, SEND, SEND);
  auto send_node = NewValueNode(send_op);
  std::vector<AnfNodePtr> send_input = {send_node, node};
  auto graph = shared_cell_;
  auto send = graph->NewCNode(send_input);
  // The Send keeps the producer's stage/chunk labels.
  send->set_user_data<NodeStageInfo>(node_stage_info);
  send->set_abstract(node->abstract());
  send->AddPrimalAttr(CHUNK, MakeValue(node_stage_info->chunk()));
  send->AddPrimalAttr(STAGE, MakeValue(node_stage_info->stage()));
  send->AddPrimalAttr(ORDER, MakeValue(order));

  attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
  auto shape_type_pair = GetShapeType(node, {1}, 0);
  Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
  Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
  auto send_prim = GetCNodePrimitive(send);
  send_prim->set_attr(DTYPE, shape_type_pair.second);
  OperatorAttrs attrs_recv = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
  auto recv_op = CreateOpInstance(attrs_recv, RECEIVE, RECEIVE);
  // The Receive consumes the Send so the pair stays ordered in the graph.
  std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), send};
  auto recv = graph->NewCNode(recv_input);
  recv->set_abstract(node->abstract());
  // The Receive carries the consumer's stage/chunk labels.
  recv->set_user_data<NodeStageInfo>(user_node_stage_info);
  recv->AddPrimalAttr(CHUNK, MakeValue(user_node_stage_info->chunk()));
  recv->AddPrimalAttr(STAGE, MakeValue(user_node_stage_info->stage()));
  recv->AddPrimalAttr(ORDER, MakeValue(order));
  auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
  if (micro != nullptr) {
    recv->AddPrimalAttr(MICRO, micro);
  }
  manager_->Replace(node, recv);
}
596
// For every user of `node` labeled with a different pipeline stage, inserts a
// Send/Receive pair. When the real producer behind `node` is a Parameter, the
// consumer's stage is also added to its color map so the parameter survives
// redundancy elimination on that stage. `order` is incremented per inserted
// pair to serialize the communication ops.
void PipelineInterleave::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t *order) {
  auto stage_info = node->user_data<NodeStageInfo>();
  auto node_users = manager_->node_users()[node];
  // NOTE(review): `receive` is never assigned in this function and `graph` is
  // not read — both look vestigial.
  AnfNodePtr receive = nullptr;
  auto pre_node = GetRealKernelNode(node, -1).first;
  bool send_param = false;
  if (pre_node->isa<Parameter>()) {
    send_param = true;
  }
  for (auto &user_pair : node_users) {
    auto user_node = user_pair.first;
    auto node_stage = stage_info->stage();
    auto user_stage_info = user_node->user_data<NodeStageInfo>();
    if (user_stage_info == nullptr) {
      continue;
    }
    auto user_node_stage = user_stage_info->stage();
    // NOTE(review): `micro` is computed (and defaulted to 0 with a log) but
    // not used afterwards here — InsertSendReceive re-reads the MICRO primal
    // attr from the user node itself.
    auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
    if (!micro) {
      MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
      micro = MakeValue(int64_t(0));
    }
    if (node_stage != user_node_stage) {
      InsertSendReceive(node, user_node, *order);
      (*order) += 1;
      if (send_param) {
        parameter_color_map_[pre_node].insert(user_node_stage);
      }
    }
  }
}
628
// Recursively walks the users of `node` (a value redundant on this stage) and
// detaches it: UpdateState users get a fresh U monad, MakeTuple users are
// recorded in `make_tuple_map` for batched removal later, and any other user
// is recursed into so the redundancy propagates through intermediate nodes.
void PipelineInterleave::RedundancyNode(const AnfNodePtr &node,
                                        mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map) {
  auto node_users = manager_->node_users()[node];
  for (auto &node_user_pair : node_users) {
    auto cnode = node_user_pair.first->cast<CNodePtr>();
    // node->UpdateState: replace the edge with a U monad.
    auto fg = cnode->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    // Only touch unstaged graphs and the main graph.
    if (fg->stage() != -1 && fg != main_graph_) {
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
      auto u_node = NewValueNode(kUMonad);
      manager_->SetEdge(cnode, node_user_pair.second, u_node);
      continue;
    }
    // node->make_tuple: record in the map; all removals happen together later.
    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      if (fg == main_graph_) {
        continue;
      }
      if (make_tuple_map->find(cnode) == (*make_tuple_map).end()) {
        (*make_tuple_map)[cnode] = {node};
      } else {
        (*make_tuple_map)[cnode].push_back(node);
      }
    } else {
      RedundancyNode(node_user_pair.first, make_tuple_map);
    }
  }
}
660
IsRedundancyParameter(const AnfNodePtr & parameter,const std::vector<AnfNodePtr> & non_cloned_parameters)661 bool PipelineInterleave::IsRedundancyParameter(const AnfNodePtr ¶meter,
662 const std::vector<AnfNodePtr> &non_cloned_parameters) {
663 // RedundancyParameter: other stage's parameters included corresponding cloned parameters.
664 auto param_ptr = parameter->cast<ParameterPtr>();
665 MS_EXCEPTION_IF_NULL(param_ptr);
666 if (!param_ptr->has_default()) {
667 return false;
668 }
669 std::set<int64_t> stage_set;
670 if (!ParameterIsCloned(parameter)) {
671 stage_set = parameter_color_map_.at(parameter);
672 } else {
673 auto parameters = root_->parameters();
674 auto param_name = param_ptr->name();
675 auto non_clone_name = param_name.substr(param_name.find_first_of('.') + 1);
676 for (auto ¶m : non_cloned_parameters) {
677 auto non_cloned_param = param->cast<ParameterPtr>();
678 if (non_clone_name != non_cloned_param->name()) {
679 continue;
680 }
681 stage_set = parameter_color_map_.at(param);
682 break;
683 }
684 }
685 if (stage_set.empty()) {
686 return false;
687 }
688 return stage_set.count(stage_) == 0;
689 }
690
// Eliminates parameters that are redundant on this stage: their uses are
// detached via RedundancyNode, and every affected MakeTuple is rebuilt
// without the removed inputs. When gradients are frozen (NO_UPDATE) and the
// MakeTuple feeds an AddN, zero tensors are substituted so arity is kept.
void PipelineInterleave::ElimParameter() {
  auto parameters = root_->parameters();
  mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> make_tuple_map;
  std::vector<AnfNodePtr> non_cloned_parameters;
  FreezeGradient();
  auto node_users_map = manager_->node_users();
  // Collect the non-cloned (original) parameters; clone lookup needs them.
  for (auto &parameter : parameters) {
    if (ParameterIsCloned(parameter)) {
      continue;
    }
    non_cloned_parameters.push_back(parameter);
  }
  for (auto &parameter : parameters) {
    if (!IsRedundancyParameter(parameter, non_cloned_parameters)) {
      continue;
    }
    MS_LOG(INFO) << "Parameter:" << parameter->DebugString() << " is Redundancy.";
    RedundancyNode(parameter, &make_tuple_map);
  }
  // Rebuild each recorded MakeTuple without the redundant inputs.
  for (auto &temp : make_tuple_map) {
    auto make_tuple = temp.first;
    auto fg = make_tuple->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    auto remove_vector = temp.second;
    if (remove_vector.empty()) {
      continue;
    }
    auto make_tuple_user = node_users_map.at(make_tuple).front().first;
    auto make_tuple_inputs = make_tuple->inputs();
    std::vector<AnfNodePtr> new_inputs;
    for (auto &input : make_tuple_inputs) {
      if (std::find(remove_vector.begin(), remove_vector.end(), input) == remove_vector.end()) {
        new_inputs.push_back(input);
      }
      // Frozen gradients feeding an AddN: pad with a zero tensor of the same
      // shape so the AddN's input count stays valid.
      if (root_->has_flag(NO_UPDATE) && IsPrimitiveCNode(make_tuple_user, prim::kPrimAddN)) {
        auto zeros = CreateZeroseOutput(input, 0);
        new_inputs.push_back(NewValueNode(zeros));
      }
    }
    auto new_make_tuple = fg->NewCNode(new_inputs);
    (void)manager_->Replace(make_tuple, new_make_tuple);
  }
}
734
ModifyParameterList()735 void PipelinePostProcess::ModifyParameterList() {
736 auto parameters = root_->parameters();
737 std::vector<AnfNodePtr> parameter_list;
738 for (auto ¶meter : parameters) {
739 auto param = parameter->cast<ParameterPtr>();
740 MS_EXCEPTION_IF_NULL(param);
741 if (!manager_->node_users()[parameter].empty() || !param->has_default()) {
742 parameter_list.push_back(parameter);
743 }
744 }
745 auto del_num = parameters.size() - parameter_list.size();
746 root_->set_fv_param_count(root_->fv_param_count() - del_num);
747 manager_->SetParameters(root_, parameter_list);
748 }
749
CutBorder()750 void PipelineInterleave::CutBorder() {
751 CreateSendReceiveGroup();
752 MS_EXCEPTION_IF_NULL(shared_cell_);
753 auto ret = shared_cell_->get_return();
754 MS_EXCEPTION_IF_NULL(ret);
755 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
756 std::reverse(all_nodes.begin(), all_nodes.end());
757 int64_t order = 0;
758 for (auto &node : all_nodes) {
759 auto stage_info = node->user_data<NodeStageInfo>();
760 if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
761 IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
762 continue;
763 }
764 // Modify for lizard cyclomatic complexity.
765 CutBorderForNode(shared_cell_, node, &order);
766 }
767 RemoveMonadNode();
768 }
769
GetZeroOutputs(const FuncGraphPtr & graph)770 AnfNodePtr PipelinePostProcess::GetZeroOutputs(const FuncGraphPtr &graph) {
771 auto real_kernel = GetRealKernelNode(graph->output(), -1);
772 AnfNodePtr node = real_kernel.first;
773 MS_EXCEPTION_IF_NULL(node);
774 std::vector<AnfNodePtr> out_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
775 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
776 auto cnode = node->cast<CNodePtr>();
777 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
778 auto each_out_shapes = GetNodeShape(cnode->input(i));
779 if (each_out_shapes.size() > 1) {
780 auto temp_tuple = CreateTupleZeroTensor(graph, cnode->input(i), each_out_shapes.size());
781 (void)out_tuple_inputs.emplace_back(temp_tuple);
782 continue;
783 }
784 (void)out_tuple_inputs.emplace_back(NewValueNode(CreateZeroseOutput(cnode->input(i), 0)));
785 }
786 }
787 AnfNodePtr zero_outputs;
788 if (out_tuple_inputs.size() > INDEX_ONE) {
789 auto out_tuple = graph->NewCNode(out_tuple_inputs);
790 return out_tuple;
791 } else {
792 auto out_shapes = GetNodeShape(node);
793 AnfNodePtr out_tensor;
794 if (out_shapes.size() > 1 && real_kernel.second == -1) {
795 out_tensor = CreateTupleZeroTensor(graph, node, out_shapes.size());
796 } else {
797 out_tensor = NewValueNode(CreateZeroseOutput(node, 0));
798 }
799 return out_tensor;
800 }
801 return nullptr;
802 }
803
// Set the abstract of every shared-cell call site from the given nodes:
// a single node yields its first input's abstract, several nodes yield an
// AbstractTuple of each node's first-input abstract.
// NOTE(review): assumes each element of `nodes` is a CNode whose input
// INDEX_ONE holds the forwarded value (e.g. a Send) — confirm at call sites.
void PipelinePostProcess::SetNodeAbstract(const std::vector<AnfNodePtr> &nodes) {
  AbstractBasePtr abs;
  if (nodes.size() == 1) {
    auto cnode = nodes.front()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    abs = GetRealAbstract(cnode->input(INDEX_ONE));
  } else {
    // Multiple nodes: the call sites' output becomes a tuple abstract.
    AbstractBasePtrList abstract_list;
    abstract_list.resize(nodes.size());
    (void)std::transform(nodes.begin(), nodes.end(), abstract_list.begin(), [](const AnfNodePtr &node) {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      return GetRealAbstract(cnode->input(INDEX_ONE));
    });
    abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
  }
  for (auto &user : shared_cell_users_) {
    user->set_abstract(abs);
  }
}
824
ModifySendRecvAttr(const std::vector<AnfNodePtr> & all_nodes)825 void PipelinePostProcess::ModifySendRecvAttr(const std::vector<AnfNodePtr> &all_nodes) {
826 for (auto &node : all_nodes) {
827 if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
828 continue;
829 }
830 auto pre_node_pair = GetRealKernelNode(node, -1);
831 auto pre_node = pre_node_pair.first;
832 auto cnode = node->cast<CNodePtr>();
833 auto prim = GetCNodePrimitive(node);
834 Shape slice_shape;
835 if (pre_node->isa<Parameter>()) {
836 auto base_shape = pre_node->Shape();
837 MS_EXCEPTION_IF_NULL(base_shape);
838 auto shape_ptr = dyn_cast<abstract::Shape>(base_shape);
839 MS_EXCEPTION_IF_NULL(shape_ptr);
840 slice_shape = shape_ptr->shape();
841 cnode->AddPrimalAttr(PIPELINE_PARAM, MakeValue(0));
842 cnode->AddPrimalAttr(MICRO, MakeValue(int64_t(0)));
843 cnode->set_user_data<AnfNode>(INPUT_PARAM, pre_node);
844 } else {
845 auto op_info = pre_node->cast<CNodePtr>()->user_data<OperatorInfo>();
846 MS_EXCEPTION_IF_NULL(op_info);
847 auto tensor_info = op_info->outputs_tensor_info();
848 if (pre_node_pair.second != -1 && tensor_info.size() > 1) {
849 slice_shape = tensor_info.at(pre_node_pair.second).slice_shape();
850 } else {
851 slice_shape = tensor_info.at(0).slice_shape();
852 }
853 }
854 auto abstract = node->abstract();
855 abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
856 std::vector<ValuePtr> element;
857 (void)std::transform(slice_shape.begin(), slice_shape.end(), std::back_inserter(element),
858 [](int elem) { return MakeValue(int64_t(elem)); });
859 auto value = std::make_shared<ValueList>(element);
860 prim->set_attr(SHAPE, value);
861 }
862 }
863
CalSrTag(int64_t order,int64_t micro,int64_t interleave_index)864 static int64_t CalSrTag(int64_t order, int64_t micro, int64_t interleave_index) {
865 return order * MAX_MICRO_BATCH_NUM * MAX_INTERLEAVE_NUM + interleave_index * MAX_INTERLEAVE_NUM + micro;
866 }
867
// Clone Send/Receive `node` into the main graph with `input` as its data
// input and a fresh SR_TAG derived from (order, micro, interleave index).
// The clone inherits the original's abstract and primal attrs; ORDER is then
// overridden with the computed tag.
// Fix: added MS_EXCEPTION_IF_NULL(prim) before prim->Clone(), matching the
// null-check discipline used elsewhere in this file.
AnfNodePtr PipelinePostProcess::GenNewNodeFromOld(const AnfNodePtr &node, const AnfNodePtr &input, int64_t micro,
                                                  int64_t index) {
  const auto &old = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old);
  auto prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(prim);
  auto cloned_prim = prim->Clone();
  auto attrs = prim->attrs();
  auto order = GetValue<int64_t>(old->GetPrimalAttr(ORDER));
  auto sr_tag = CalSrTag(order, micro, index);
  attrs[SR_TAG] = MakeValue(sr_tag);
  cloned_prim->SetAttrs(attrs);
  std::vector<AnfNodePtr> new_node_input = {NewValueNode(cloned_prim), input};
  auto new_node = main_graph_->NewCNode(new_node_input);
  new_node->set_abstract(old->abstract());
  // NOTE(review): this attr is re-applied below via set_primal_attrs (the old
  // node also carries PIPELINE_PARAM when this branch is taken) — the early
  // add looks redundant but is kept for behavioral parity.
  if (old->HasPrimalAttr(PIPELINE_PARAM)) {
    new_node->AddPrimalAttr(PIPELINE_PARAM, MakeValue(0));
  }
  new_node->set_primal_attrs(old->primal_attrs());
  new_node->AddPrimalAttr(ORDER, MakeValue(sr_tag));
  return new_node;
}
889
GenerateMainGraphSend(const std::vector<AnfNodePtr> & nodes,const AnfNodePtr & node,const ValuePtr & micro,const ValuePtr & index)890 std::vector<AnfNodePtr> PipelinePostProcess::GenerateMainGraphSend(const std::vector<AnfNodePtr> &nodes,
891 const AnfNodePtr &node, const ValuePtr µ,
892 const ValuePtr &index) {
893 std::vector<AnfNodePtr> sends;
894 auto index_value = GetValue<int64_t>(index);
895 for (size_t i = 0; i < nodes.size(); ++i) {
896 auto send = nodes[i];
897 auto csend = send->cast<CNodePtr>();
898 if (csend->HasPrimalAttr(PIPELINE_PARAM)) {
899 if (csend->HasPrimalAttr("send_once")) {
900 continue;
901 }
902 auto param = csend->cast<CNodePtr>()->user_data<AnfNode>(INPUT_PARAM);
903 csend->AddPrimalAttr("send_once", MakeValue(true));
904 auto new_send = GenNewNodeFromOld(send, param, 0, 0);
905 sends.emplace_back(new_send);
906 continue;
907 }
908 auto micro_value = GetValue<int64_t>(micro);
909 auto send_input = CreateTupleGetItemNode(main_graph_, node, i);
910 auto new_send = GenNewNodeFromOld(send, send_input, micro_value, index_value)->cast<CNodePtr>();
911 new_send->AddPrimalAttr(PIPELINE_END, micro);
912 new_send->AddPrimalAttr(MICRO, micro);
913 sends.emplace_back(new_send);
914 }
915 return sends;
916 }
917
// Re-create a chunk-graph Receive in the main graph for call site `fg_node`.
// Parameter Receives are rebuilt from the recorded source parameter with tag
// (0, 0); activation Receives are tagged with the call site's micro and
// interleave index and labeled as the pipeline begin of that micro batch.
// An edge from the call site to the new Receive preserves scheduling order.
AnfNodePtr PipelinePostProcess::GenerateMainGraphRecv(const AnfNodePtr &fg_node, const AnfNodePtr &recv) {
  auto cuser = fg_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cuser);
  auto crecv = recv->cast<CNodePtr>();
  AnfNodePtr new_recv;
  if (crecv->HasPrimalAttr(PIPELINE_PARAM)) {
    // Parameter receive: rebuild from the parameter saved by ModifySendRecvAttr.
    auto param = crecv->user_data<AnfNode>(INPUT_PARAM);
    MS_EXCEPTION_IF_NULL(param);
    new_recv = GenNewNodeFromOld(recv, param, 0, 0);
  } else {
    auto index = cuser->GetPrimalAttr(INDEX);
    MS_EXCEPTION_IF_NULL(index);
    auto index_value = GetValue<int64_t>(index);
    new_recv = GenNewNodeFromOld(recv, crecv->input(1), GetValue<int64_t>(cuser->GetPrimalAttr(MICRO)), index_value);
    new_recv->cast<CNodePtr>()->AddPrimalAttr(PIPELINE_BEGIN, cuser->GetPrimalAttr(MICRO));
  }
  new_recv->cast<CNodePtr>()->AddPrimalAttr(MICRO, cuser->GetPrimalAttr(MICRO));
  // Tie the receive to its call site so the scheduler keeps their order.
  manager_->AddEdge(cuser, new_recv);
  return new_recv;
}
938
// Collect post-processing context from the given nodes: the shared
// (lazy-inline) cell containing Send/Receive, the main graph (the graph
// being differentiated via J), the number of pipeline chunks, and every
// call site of the shared cell inside the main graph.
void PipelinePostProcess::Init(const std::vector<AnfNodePtr> &nodes) {
  for (auto &node : nodes) {
    // The first Send/Receive seen pins down which graph is the shared cell.
    if ((IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimReceive)) &&
        shared_cell_ == nullptr) {
      shared_cell_ = node->cast<CNodePtr>()->func_graph();
    }
    // The J node's first input is the forward graph → that is the main graph.
    if (IsPrimitiveCNode(node, prim::kPrimJ)) {
      auto cnode = node->cast<CNodePtr>();
      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      main_graph_ = graph;
    }
    if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    // chunk_num_ = max(CHUNK attr) + 1 over all Send/Receive nodes.
    auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
    chunk_num_ = (chunk + 1) > chunk_num_ ? (chunk + 1) : chunk_num_;
  }
  // Find the ValueNode holding the shared cell and record its direct callers
  // that live in the main graph.
  auto value_nodes = main_graph_->value_nodes();
  for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
    auto node = (*value_pair).first;
    if (!IsValueNode<FuncGraph>(node)) {
      continue;
    }
    auto fg = GetValueNode<FuncGraphPtr>(node);
    if (fg != shared_cell_) {
      continue;
    }
    auto node_users = manager_->node_users()[node];
    for (auto &node_user : node_users) {
      auto user = node_user.first;
      if (user->func_graph() == main_graph_) {
        shared_cell_users_.emplace_back(user);
      }
    }
    break;
  }
}
977
// Gather this stage's Send/Receive nodes belonging to `chunk` inside `fg`.
// Activation Send inputs are additionally collected into `temp` (they become
// the chunk graph's new outputs). Each Receive's data input is replaced by a
// dummy zero tensor of the recorded DTYPE, since the real data will arrive
// through a main-graph Receive created later.
void PipelinePostProcess::GetSendsRecvs(const FuncGraphPtr &fg, int64_t chunk, std::vector<AnfNodePtr> *recvs,
                                        std::vector<AnfNodePtr> *sends, std::vector<AnfNodePtr> *temp) {
  const auto &all_nodes = TopoSort(fg->get_return());
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!cnode->HasPrimalAttr(STAGE)) {
      continue;
    }
    // Skip communication ops that belong to other stages.
    auto stage_value = cnode->GetPrimalAttr(STAGE);
    if (stage_value && GetValue<int64_t>(stage_value) != stage_) {
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimSend) && GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK)) == chunk) {
      // Parameter sends don't contribute to the chunk graph's output tuple.
      if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
        temp->emplace_back(cnode->input(INDEX_ONE));
      }
      sends->emplace_back(node);
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK)) == chunk) {
      auto prim = GetCNodePrimitive(node);
      auto attrs = prim->attrs();
      // Placeholder input; the Receive is re-created in the main graph later.
      auto zero_tensor = TensorConstructUtils::CreateZerosTensor(attrs[DTYPE]->cast<TypePtr>(), {1});
      manager_->SetEdge(node, 1, NewValueNode(zero_tensor));
      recvs->emplace_back(node);
    }
  }
  return;
}
1009
// Label each shared-cell call site with its interleave INDEX within its
// micro batch: the first occurrence of a micro value gets 0, a repeat gets 1.
// NOTE(review): `index` is reset per call site and incremented at most once,
// so a third call site with the same micro value would also get index 1 —
// presumably at most two interleave slices per micro batch are expected;
// confirm against the interleave scheduler.
void PipelinePostProcess::LabelInterleaveIndex() {
  std::vector<int64_t> micro_visited;
  for (auto &usr : shared_cell_users_) {
    int64_t index = 0;
    auto cusr = usr->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cusr);
    auto micro = cusr->GetPrimalAttr(MICRO);
    MS_EXCEPTION_IF_NULL(micro);
    auto micro_value = GetValue<int64_t>(micro);
    if (!std::count(micro_visited.begin(), micro_visited.end(), micro_value)) {
      micro_visited.emplace_back(micro_value);
    } else {
      index += 1;
    }
    cusr->AddPrimalAttr(INDEX, MakeValue(index));
  }
}
1027
PartitionChunkGraph(const FuncGraphPtr & fg,int64_t chunk)1028 std::vector<AnfNodePtr> PipelinePostProcess::PartitionChunkGraph(const FuncGraphPtr &fg, int64_t chunk) {
1029 std::vector<AnfNodePtr> temp;
1030 std::vector<AnfNodePtr> recvs;
1031 std::vector<AnfNodePtr> sends;
1032 GetSendsRecvs(fg, chunk, &recvs, &sends, &temp);
1033 AnfNodePtr out;
1034 if (!temp.empty()) {
1035 out = CreateMakeTupleNode(fg, temp);
1036 manager_->Replace(fg->output(), out);
1037 }
1038
1039 auto params = fg->parameters();
1040 std::vector<AnfNodePtr> new_params;
1041 auto node_users_map = manager_->node_users();
1042 std::vector<size_t> temp_index;
1043 for (size_t i = 0; i < params.size(); ++i) {
1044 auto param = params.at(i);
1045 if (node_users_map[param].size() == 0) {
1046 temp_index.emplace_back(i + 1);
1047 continue;
1048 }
1049 new_params.emplace_back(param);
1050 }
1051 for (auto &node : recvs) {
1052 auto crecv = node->cast<CNodePtr>();
1053 auto new_shared_cell_param = std::make_shared<Parameter>(fg);
1054 new_shared_cell_param->set_abstract(node->abstract());
1055 new_params.emplace_back(new_shared_cell_param);
1056 manager_->Replace(node, new_shared_cell_param);
1057 }
1058 manager_->SetParameters(fg, new_params);
1059 std::vector<AnfNodePtr> main_graph_sends;
1060 mindspore::HashMap<AnfNodePtr, AnfNodePtr> recv_map;
1061 for (auto &usr : shared_cell_users_) {
1062 auto cusr = usr->cast<CNodePtr>();
1063 std::vector<AnfNodePtr> usr_new_inputs = {NewValueNode(fg)};
1064 for (size_t i = 1; i < cusr->inputs().size(); ++i) {
1065 if (std::find(temp_index.begin(), temp_index.end(), i) == temp_index.end()) {
1066 usr_new_inputs.emplace_back(cusr->input(i));
1067 }
1068 }
1069 auto new_usr = main_graph_->NewCNode(usr_new_inputs);
1070 new_usr->set_primal_attrs(cusr->primal_attrs());
1071 new_usr->AddPrimalAttr(CHUNK, MakeValue(chunk));
1072 if (out != nullptr) {
1073 new_usr->set_abstract(out->abstract());
1074 }
1075 auto micro = cusr->GetPrimalAttr(MICRO);
1076 auto index = cusr->GetPrimalAttr(INDEX);
1077 auto temp_sends = GenerateMainGraphSend(sends, new_usr, micro, index);
1078 if (temp_sends.empty()) {
1079 if (stage_ != stage_num_ - 1) {
1080 MS_LOG(EXCEPTION) << "Some wrong with PipelineParallel.";
1081 }
1082 manager_->Replace(usr, new_usr);
1083 }
1084 main_graph_sends.insert(main_graph_sends.end(), temp_sends.begin(), temp_sends.end());
1085 for (auto &recv : recvs) {
1086 auto crecv = recv->cast<CNodePtr>();
1087 if (crecv->HasPrimalAttr(PIPELINE_PARAM)) {
1088 if (recv_map.find(recv) == recv_map.end()) {
1089 auto temp_recv = GenerateMainGraphRecv(new_usr, recv);
1090 recv_map[recv] = temp_recv;
1091 continue;
1092 }
1093 manager_->AddEdge(new_usr, recv_map[recv]);
1094 continue;
1095 }
1096 (void)GenerateMainGraphRecv(new_usr, recv);
1097 }
1098 }
1099 return main_graph_sends;
1100 }
1101
// Split the shared cell into chunk_num_ chunk graphs, generate their
// main-graph Sends, and make the final (possibly zeroed) output Depend on
// all Sends so dead-code elimination cannot remove them.
void PipelinePostProcess::GraphPartition(const std::vector<AnfNodePtr> &all_nodes) {
  LabelInterleaveIndex();
  std::vector<AnfNodePtr> send_ops;
  for (size_t i = 0; i < LongToSize(chunk_num_); ++i) {
    auto chunk_fg = shared_cell_;
    // Every chunk gets its own clone, except the last chunk on the last
    // stage which keeps the original graph (it produces the real output).
    if (stage_ != stage_num_ - 1 || i != LongToSize(chunk_num_ - 1)) {
      chunk_fg = BasicClone(shared_cell_);
      chunk_fg->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, true);
      manager_->AddFuncGraph(chunk_fg);
    }
    auto sends = PartitionChunkGraph(chunk_fg, i);
    send_ops.insert(send_ops.begin(), sends.begin(), sends.end());
  }
  auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);
  // Non-last stages return zeros of the right structure; the last stage
  // keeps the genuine model output.
  auto outputs = GetZeroOutputs(main_graph_);
  if (stage_ == stage_num_ - 1) {
    outputs = main_graph_->output();
  }
  std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), outputs, make_tuple};
  auto out_node = main_graph_->NewCNode(out);
  (void)manager_->Replace(main_graph_->output(), out_node);
}
1124
HandleSendParam()1125 void PipelinePostProcess::HandleSendParam() {
1126 auto parameters = root_->parameters();
1127 auto node_users_map = manager_->node_users();
1128 auto nodes = DeepScopedGraphSearch(root_->get_return());
1129 for (auto &node : nodes) {
1130 if (!IsPrimitiveCNode(node, prim::kPrimSend)) {
1131 continue;
1132 }
1133 auto cnode = node->cast<CNodePtr>();
1134 if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1135 continue;
1136 }
1137 auto param = cnode->input(1);
1138 if (IsPrimitiveCNode(param, prim::kPrimVirtualAssignAdd)) {
1139 param = param->cast<CNodePtr>()->input(1);
1140 }
1141 auto param_ptr = param->cast<ParameterPtr>();
1142 MS_EXCEPTION_IF_NULL(param_ptr);
1143 auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
1144 if (!accu_parameter) {
1145 continue;
1146 }
1147 auto accu_users = node_users_map.at(accu_parameter);
1148 AnfNodePtr share_node = nullptr;
1149 for (auto &user : accu_users) {
1150 auto user_node = user.first;
1151 while (IsSomePrimitiveList(user_node->cast<CNodePtr>(),
1152 {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name()})) {
1153 share_node = user_node;
1154 user_node = node_users_map.at(user_node).front().first;
1155 }
1156 if (share_node == nullptr) {
1157 continue;
1158 }
1159 auto base_shape = accu_parameter->Shape();
1160 auto shape_ptr = dyn_cast<abstract::Shape>(base_shape);
1161 auto slice_shape = shape_ptr->shape();
1162 auto prim = GetCNodePrimitive(cnode);
1163 std::vector<ValuePtr> element;
1164 (void)std::transform(slice_shape.begin(), slice_shape.end(), std::back_inserter(element),
1165 [](int elem) { return MakeValue(int64_t(elem)); });
1166 auto value = std::make_shared<ValueList>(element);
1167 prim->set_attr(SHAPE, value);
1168 manager_->SetEdge(cnode, 1, share_node);
1169 break;
1170 }
1171 }
1172 }
1173
ElimGraphStage()1174 void PipelinePostProcess::ElimGraphStage() {
1175 for (auto &fg : manager_->func_graphs()) {
1176 fg->set_stage(-1);
1177 }
1178 }
1179
// Return true when this stage updates no trainable parameter, i.e. no
// non-cloned parameter both requires gradients and is colored with stage_.
bool PipelineInterleave::HasNoUpdateParameter() {
  auto parameters = root_->parameters();
  for (auto &parameter : parameters) {
    // Cloned parameters (e.g. accumulation/mirror copies) are not updated
    // directly and are skipped.
    if (ParameterIsCloned(parameter)) {
      continue;
    }
    auto param_info = parameter->cast<ParameterPtr>()->param_info();
    if (!param_info) {
      continue;
    }
    auto stage_set = parameter_color_map_.at(parameter);
    auto requires_grad = param_info->requires_grad();
    if (requires_grad && stage_set.count(stage_)) {
      return false;
    }
  }
  return true;
}
1198
// When this stage updates no parameter during training, freeze gradients:
// flag the root graph NO_UPDATE, attach the consumer of the grad output's
// item 1 (dout) to the root output with a Depend so the backward send chain
// survives dead-code elimination, and route NPUGetFloatStatusV2's input
// through a Depend on the grads so overflow detection still sees them.
void PipelineInterleave::FreezeGradient() {
  auto node_users_map = manager_->node_users();
  if (HasNoUpdateParameter() && is_train_) {
    root_->set_flag(NO_UPDATE, true);
    auto nodes = root_->nodes();
    for (auto &node : nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimJ)) {
        continue;
      }
      auto node_users = node_users_map.at(node);
      auto grad_users = node_users_map.at(node_users.front().first);
      for (auto &grad_user : grad_users) {
        auto user_node = grad_user.first->cast<CNodePtr>();
        if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
          continue;
        }
        // Item 1 of the grad call result is dout; keep its consumer alive.
        auto index = GetTupleGetItemIndex(user_node);
        if (index != 1) {
          continue;
        }
        auto temp = node_users_map.at(user_node).front().first;
        auto out = root_->output();
        std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), out, temp};
        auto new_node = root_->NewCNode(depend_input);
        manager_->Replace(out, new_node);
        break;
      }
      break;
    }
    for (auto &node : nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimNPUGetFloatStatusV2)) {
        continue;
      }
      auto cnode = node->cast<CNodePtr>();
      // NOTE(review): assumes root output is a CNode whose input INDEX_TWO
      // holds the grads (the Depend built above) — confirm output structure.
      auto out_cnode = root_->output()->cast<CNodePtr>();
      auto grads = out_cnode->input(INDEX_TWO);
      std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), cnode->input(1), grads};
      auto new_node = root_->NewCNode(depend_input);
      new_node->set_abstract(cnode->input(1)->abstract());
      manager_->Replace(cnode->input(1), new_node);
      break;
    }
  }
}
1243
// Find the consumer of `node`'s grad output item 1 (the dout slot): locate
// the TupleGetItem(node, 1) user and return that getitem's first user, or
// nullptr when no such user exists.
static AnfNodePtr GetDout(const AnfNodePtr &node, const NodeUsersMap &node_users_map) {
  for (const auto &usr_pair : node_users_map.at(node)) {
    auto usr_cnode = usr_pair.first->cast<CNodePtr>();
    if (!IsPrimitiveCNode(usr_cnode, prim::kPrimTupleGetItem)) {
      continue;
    }
    if (GetTupleGetItemIndex(usr_cnode) != 1) {
      continue;
    }
    const auto &item_users = node_users_map.at(usr_cnode);
    if (item_users.size() != 1) {
      MS_LOG(WARNING) << "Get Multi grad usrs. Use first.";
    }
    return item_users.begin()->first;
  }
  return nullptr;
}
1263
NeedAttach(const FuncGraphManagerPtr & manager)1264 static bool NeedAttach(const FuncGraphManagerPtr &manager) {
1265 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
1266 if (parallel_mode != kAutoParallel && parallel_mode != kSemiAutoParallel) {
1267 return false;
1268 }
1269 bool cell_reuse = false;
1270 for (auto &fg : manager->func_graphs()) {
1271 if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
1272 cell_reuse = true;
1273 break;
1274 }
1275 }
1276 auto stage_num = g_device_manager->stage_num();
1277 if (!cell_reuse || stage_num <= 1) {
1278 return false;
1279 }
1280 return true;
1281 }
1282
// Attach the gradients of "isolated" Receive-only sub-graph calls to the
// grad graph's output. Pipeline partitioning with lazy inline creates calls
// whose sub-graph output is a tuple starting with a Receive and whose first
// argument is a constant (zero) tensor; their backward Sends would otherwise
// be dropped as dead code. Returns true when an attachment was made.
bool IsolatedNodeAttach(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
  // Run at most once per root graph.
  if (root->has_flag(HAS_ATTACHED)) {
    return false;
  }
  root->set_flag(HAS_ATTACHED, true);
  auto manager = root->manager();
  if (!NeedAttach(manager)) {
    return false;
  }
  auto ret_after = root->get_return();
  MS_EXCEPTION_IF_NULL(ret_after);
  auto all_nodes = DeepScopedGraphSearch(ret_after);
  const auto &node_users_map = manager->node_users();
  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  FuncGraphPtr main_graph;
  FuncGraphPtr grad_graph;
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<FuncGraph>(cnode->input(0))) {
      continue;
    }
    // Match calls of a sub-graph whose output tuple begins with a Receive ...
    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
    auto sub_graph_output = graph->output();
    if (!IsPrimitiveCNode(sub_graph_output, prim::kPrimMakeTuple)) {
      continue;
    }
    auto csub_graph_output = sub_graph_output->cast<CNodePtr>();
    if (!IsPrimitiveCNode(csub_graph_output->input(1), prim::kPrimReceive)) {
      continue;
    }
    // ... and whose first argument is a constant tensor (zero placeholder),
    // used by exactly one TupleGetItem.
    auto call_node_input = cnode->input(1);
    if (!IsValueNode<tensor::Tensor>(call_node_input)) {
      continue;
    }
    auto call_node_users = node_users_map.at(node);
    if (call_node_users.size() != 1) {
      continue;
    }
    auto usr_node = call_node_users.begin()->first;
    if (!IsPrimitiveCNode(usr_node, prim::kPrimTupleGetItem)) {
      continue;
    }
    // Sum the dout contributions of every getitem user with an AddN.
    auto get_item_usrs = node_users_map.at(usr_node);
    std::vector<AnfNodePtr> addn_input = {NewValueNode(prim::kPrimAddN)};
    main_graph = node->func_graph();
    for (auto &get_item_usr_pair : get_item_usrs) {
      auto get_item_usr = get_item_usr_pair.first;
      auto grad_node = GetDout(get_item_usr, node_users_map);
      // All dout consumers must live in a single grad graph.
      if (grad_graph == nullptr) {
        grad_graph = grad_node->func_graph();
      } else {
        if (grad_graph != grad_node->func_graph()) {
          MS_LOG(EXCEPTION) << "Got Wrong Grad graph when attached Receive's grad, Maybe don't use lazy inline.";
        }
      }
      std::vector<AnfNodePtr> new_get_item_input = {NewValueNode(prim::kPrimTupleGetItem), grad_node,
                                                    NewValueNode(MakeValue(SizeToLong(get_item_usr_pair.second)))};
      auto new_get_item = grad_graph->NewCNode(new_get_item_input);
      addn_input.emplace_back(new_get_item);
    }
    AnfNodePtr temp;
    // A single contribution needs no AddN.
    if (addn_input.size() > SIZE_TWO) {
      temp = grad_graph->NewCNode(addn_input);
    } else {
      temp = addn_input.at(1);
    }
    // Call the forward call's grad closure (item 1 of the call result) with
    // the summed dout, then collect item 1 of that grad call's result.
    std::vector<AnfNodePtr> send_grad_fn_input = {NewValueNode(prim::kPrimTupleGetItem), node,
                                                  NewValueNode(MakeValue(int64_t(1)))};
    auto send_grad_fn = main_graph->NewCNode(send_grad_fn_input);
    auto call_grad_node = grad_graph->NewCNode({send_grad_fn, temp});
    std::vector<AnfNodePtr> call_grad_get_item_input = {NewValueNode(prim::kPrimTupleGetItem), call_grad_node,
                                                        NewValueNode(MakeValue(int64_t(1)))};
    auto call_grad_get_item = grad_graph->NewCNode(call_grad_get_item_input);
    make_tuple_input.emplace_back(call_grad_get_item);
  }
  if (make_tuple_input.size() <= 1) {
    return false;
  }
  auto make_tuple = grad_graph->NewCNode(make_tuple_input);
  // With NO_UPDATE the collected grads *become* the grad output; otherwise
  // they are attached via Depend so they execute without changing the value.
  if (root->has_flag(NO_UPDATE)) {
    manager->Replace(grad_graph->output(), make_tuple);
    return true;
  }
  std::vector<AnfNodePtr> attach_node_input = {NewValueNode(prim::kPrimDepend), grad_graph->output(), make_tuple};
  auto attach_node = grad_graph->NewCNode(attach_node_input);
  manager->Replace(grad_graph->output(), attach_node);
  return true;
}
1374 } // namespace parallel
1375 } // namespace mindspore
1376