1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include <memory>
19 #include "pipeline/jit/pipeline_split.h"
20 #include "utils/ms_context.h"
21 #include "utils/comm_manager.h"
22 #include "frontend/parallel/context.h"
23 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
24 #include "frontend/parallel/step_parallel.h"
25
26 namespace mindspore {
27 namespace pipeline {
GetWorldGroup()28 std::string GetWorldGroup() {
29 auto ms_context = MsContext::GetInstance();
30 MS_EXCEPTION_IF_NULL(ms_context);
31 std::string world_group;
32 std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
33 if (backend == kAscendDevice) {
34 world_group = parallel::HCCL_WORLD_GROUP;
35 } else if (backend == kGPUDevice) {
36 world_group = parallel::NCCL_WORLD_GROUP;
37 } else {
38 MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
39 }
40 return world_group;
41 }
42
GetRank()43 static int64_t GetRank() {
44 auto ms_context = MsContext::GetInstance();
45 MS_EXCEPTION_IF_NULL(ms_context);
46 auto world_group = GetWorldGroup();
47 int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
48 uint32_t rank_id = 0;
49 if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
50 if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
51 MS_LOG(EXCEPTION) << "Get rank id failed.";
52 }
53 global_rank = UintToInt(rank_id);
54 }
55 return global_rank;
56 }
57
InferStage(int64_t rank_id,int64_t stage_num,int64_t device_num)58 static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
59 if (stage_num == 0) {
60 MS_LOG(EXCEPTION) << "stage_num is zero";
61 }
62 if (device_num % stage_num != 0) {
63 MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
64 << "stage_num: " << stage_num;
65 }
66 auto per_stage_rank_num = device_num / stage_num;
67 return rank_id / per_stage_rank_num;
68 }
69
70 // Only auto_parallel and semi_auto_parallel support PipelineSplit
PipelineSplit(const ResourcePtr & res)71 bool PipelineSplit(const ResourcePtr &res) {
72 MS_EXCEPTION_IF_NULL(res);
73 auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
74 if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
75 MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
76 return true;
77 }
78 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
79 if (stage_num <= 1) {
80 MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split.";
81 return true;
82 }
83 auto manager = res->manager();
84 auto root = res->func_graph();
85 auto global_rank = GetRank();
86 auto world_group = GetWorldGroup();
87 uint32_t world_rank_size = 0;
88 int64_t device_num = 0;
89 if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
90 if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
91 MS_LOG(EXCEPTION) << "Get rank size failed";
92 }
93 device_num = UintToInt(world_rank_size);
94 MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
95 } else {
96 device_num = parallel::ParallelContext::GetInstance()->device_num();
97 }
98 if (device_num < 1) {
99 MS_LOG(EXCEPTION) << "Invalid device num: " << device_num;
100 }
101 if (global_rank < 0) {
102 MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank;
103 }
104 auto stage = InferStage(global_rank, stage_num, device_num);
105 auto per_stage_rank_num = device_num / stage_num;
106 if (parallel::ParallelInit() != parallel::SUCCESS) {
107 MS_LOG(EXCEPTION) << "parallel init failed.";
108 }
109 auto transformer =
110 std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
111 // step1: Do color graph
112 transformer->Coloring();
113 transformer->MainGraph();
114 // step2: Do color broadcast
115 transformer->BroadCastColoring();
116 transformer->LabelMicroBatch();
117 // step3: Handle shared parameters
118 transformer->ParameterColoring();
119 // step4: Cut Graph
120 transformer->CutGraph();
121 // step5: Handle Sens
122 if (root->has_flag(parallel::TRAINING)) {
123 transformer->CoverSensShape();
124 }
125 // step6: Elim Graph stages and no used parameter
126 transformer->ElimGraphStage();
127 transformer->ElimParameter();
128 return true;
129 }
130 } // namespace pipeline
131 } // namespace mindspore
132