• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <memory>
19 #include "pipeline/jit/pipeline_split.h"
20 #include "utils/ms_context.h"
21 #include "utils/comm_manager.h"
22 #include "frontend/parallel/context.h"
23 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
24 #include "frontend/parallel/step_parallel.h"
25 
26 namespace mindspore {
27 namespace pipeline {
GetWorldGroup()28 std::string GetWorldGroup() {
29   auto ms_context = MsContext::GetInstance();
30   MS_EXCEPTION_IF_NULL(ms_context);
31   std::string world_group;
32   std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
33   if (backend == kAscendDevice) {
34     world_group = parallel::HCCL_WORLD_GROUP;
35   } else if (backend == kGPUDevice) {
36     world_group = parallel::NCCL_WORLD_GROUP;
37   } else {
38     MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
39   }
40   return world_group;
41 }
42 
GetRank()43 static int64_t GetRank() {
44   auto ms_context = MsContext::GetInstance();
45   MS_EXCEPTION_IF_NULL(ms_context);
46   auto world_group = GetWorldGroup();
47   int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
48   uint32_t rank_id = 0;
49   if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
50     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
51       MS_LOG(EXCEPTION) << "Get rank id failed.";
52     }
53     global_rank = UintToInt(rank_id);
54   }
55   return global_rank;
56 }
57 
InferStage(int64_t rank_id,int64_t stage_num,int64_t device_num)58 static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
59   if (stage_num == 0) {
60     MS_LOG(EXCEPTION) << "stage_num is zero";
61   }
62   if (device_num % stage_num != 0) {
63     MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
64                       << "stage_num: " << stage_num;
65   }
66   auto per_stage_rank_num = device_num / stage_num;
67   return rank_id / per_stage_rank_num;
68 }
69 
70 // Only auto_parallel and semi_auto_parallel support PipelineSplit
PipelineSplit(const ResourcePtr & res)71 bool PipelineSplit(const ResourcePtr &res) {
72   MS_EXCEPTION_IF_NULL(res);
73   auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
74   if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
75     MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
76     return true;
77   }
78   auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
79   if (stage_num <= 1) {
80     MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split.";
81     return true;
82   }
83   auto manager = res->manager();
84   auto root = res->func_graph();
85   auto global_rank = GetRank();
86   auto world_group = GetWorldGroup();
87   uint32_t world_rank_size = 0;
88   int64_t device_num = 0;
89   if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
90     if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
91       MS_LOG(EXCEPTION) << "Get rank size failed";
92     }
93     device_num = UintToInt(world_rank_size);
94     MS_LOG(INFO) << "Get device num from communication model, the device num is  " << device_num;
95   } else {
96     device_num = parallel::ParallelContext::GetInstance()->device_num();
97   }
98   if (device_num < 1) {
99     MS_LOG(EXCEPTION) << "Invalid device num: " << device_num;
100   }
101   if (global_rank < 0) {
102     MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank;
103   }
104   auto stage = InferStage(global_rank, stage_num, device_num);
105   auto per_stage_rank_num = device_num / stage_num;
106   if (parallel::ParallelInit() != parallel::SUCCESS) {
107     MS_LOG(EXCEPTION) << "parallel init failed.";
108   }
109   auto transformer =
110     std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
111   // step1: Do color graph
112   transformer->Coloring();
113   transformer->MainGraph();
114   // step2: Do color broadcast
115   transformer->BroadCastColoring();
116   transformer->LabelMicroBatch();
117   // step3: Handle shared parameters
118   transformer->ParameterColoring();
119   // step4: Cut Graph
120   transformer->CutGraph();
121   // step5: Handle Sens
122   if (root->has_flag(parallel::TRAINING)) {
123     transformer->CoverSensShape();
124   }
125   // step6: Elim Graph stages and no used parameter
126   transformer->ElimGraphStage();
127   transformer->ElimParameter();
128   return true;
129 }
130 }  // namespace pipeline
131 }  // namespace mindspore
132