• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/utils/parallel_context.h"
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <functional>
22 #include <map>
23 #include <memory>
24 
25 namespace mindspore::parallel {
26 namespace {
27 std::vector<std::string> kParallelModeList = {kStandalone, kDataParallel, kHybridParallel, kSemiAutoParallel,
28                                               kAutoParallel};
29 std::vector<std::string> kStrategySearchModeList = {kDynamicProgramming, kRecursiveProgramming, kShardingPropagation};
30 
31 std::vector<std::string> kCommuniParallelModeList = {kAllGroupParallel, kSameServerGroupParallel, kNoGroupParallel};
32 
33 std::vector<std::string> kFusionModeList = {kFusionAuto, kFusionSize, kFusionIndex};
34 }  // namespace
35 
GetInstance()36 std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
37   static std::shared_ptr<ParallelContext> inst_context_ = std::shared_ptr<ParallelContext>(new ParallelContext());
38   return inst_context_;
39 }
40 
ParallelContext()41 ParallelContext::ParallelContext() { Reset(); }
42 
Reset()43 void ParallelContext::Reset() {
44   gradients_mean_ = false;
45   full_batch_ = false;
46   full_batch_is_set_ = false;
47   gradient_fp32_sync_ = true;
48   loss_repeated_mean_ = true;
49   device_num_ = 1;
50   global_rank_ = 0;
51   device_num_is_set_ = false;
52   global_rank_is_set_ = false;
53   parallel_mode_ = kStandalone;
54   parameter_broadcast_ = false;
55   parameter_broadcast_is_set_ = false;
56   enable_all_reduce_fusion_ = true;
57   enable_all_gather_fusion_ = true;
58   enable_reduce_scatter_fusion_ = true;
59   strategy_json_config_file_type_ = "";
60   strategy_json_config_file_path_ = "";
61   strategy_json_config_file_mode_ = "";
62   strategy_ckpt_load_file_ = "";
63   strategy_ckpt_save_file_ = "";
64   enable_parallel_optimizer_ = false;
65   force_fp32_communication_ = false;
66   all_reduce_fusion_split_indices_.clear();
67   all_reduce_fusion_split_sizes_.clear();
68   strategy_search_mode_ = kRecursiveProgramming;
69   pipeline_stage_split_num_ = 1;
70   pipeline_segment_split_num_ = 1;
71   grad_accumulation_step_ = 1;
72   communi_parallel_mode_ = kAllGroupParallel;
73   optimizer_weight_shard_size_ = -1;
74   optimizer_weight_shard_aggregated_save_ = false;
75   enable_all2all_ = false;
76   grad_accumulation_shard_ = false;
77   parallel_optimizer_threshold_ = -1;
78   sharding_propagation_ = false;
79   dataset_strategy_.clear();
80   dp_fusion_threshold_mb_ = kDataParallelFusionThreshold;
81   fusion_threshold_mb_ = kFusionThreshold;
82   allgather_fusion_threshold_mb_ = kFusionThreshold;
83   reducescatter_fusion_threshold_mb_ = kFusionThreshold;
84   fusion_threshold_is_set_ = true;
85   fusion_mode_ = kFusionAuto;
86   group_ckpt_save_file_ = "";
87   pipeline_micro_size_ = 1;
88   dataset_repeat_dim_right_ = false;
89   hccl_test_available_ = false;
90   enable_micro_interleaved_ = false;
91   enable_fine_grained_micro_interleaved_ = false;
92   do_transform_ = false;
93   direct_split_ = false;
94   pipeline_result_broadcast_ = false;
95   stra_file_only_trainable_params_ = true;
96   pipeline_interleave_ = false;
97   pipeline_scheduler_ = kPipeline1F1B;
98   auto_pipeline_ = false;
99 }
100 
set_device_num(int64_t device_num)101 void ParallelContext::set_device_num(int64_t device_num) {
102   device_num_ = device_num;
103   device_num_is_set_ = true;
104 }
105 
set_fusion_threshold_mb(int64_t fusion_threshold)106 void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
107   fusion_threshold_mb_ = fusion_threshold;
108   dp_fusion_threshold_mb_ = fusion_threshold;
109   fusion_threshold_is_set_ = true;
110   enable_all_reduce_fusion_ = true;
111 }
112 
set_allgather_fusion_threshold_mb(int64_t fusion_threshold)113 void ParallelContext::set_allgather_fusion_threshold_mb(int64_t fusion_threshold) {
114   allgather_fusion_threshold_mb_ = fusion_threshold;
115   enable_all_gather_fusion_ = true;
116 }
117 
set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold)118 void ParallelContext::set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold) {
119   reducescatter_fusion_threshold_mb_ = fusion_threshold;
120   enable_reduce_scatter_fusion_ = true;
121 }
122 
set_fusion_mode(const std::string & fusion_mode)123 bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
124   auto iter = std::find(kFusionModeList.begin(), kFusionModeList.end(), fusion_mode);
125   if (iter == kFusionModeList.end()) {
126     MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode;
127     return false;
128   }
129   fusion_mode_ = fusion_mode;
130   return true;
131 }
132 
set_global_rank(int64_t global_rank)133 void ParallelContext::set_global_rank(int64_t global_rank) {
134   global_rank_ = global_rank;
135   global_rank_is_set_ = true;
136 }
137 
set_gradients_mean(bool gradients_mean)138 void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }
139 
set_full_batch(bool full_batch)140 void ParallelContext::set_full_batch(bool full_batch) {
141   full_batch_ = full_batch;
142   full_batch_is_set_ = true;
143 }
144 
set_dataset_strategy(const std::vector<std::vector<int64_t>> & dataset_strategy)145 void ParallelContext::set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy) {
146   dataset_strategy_ = dataset_strategy;
147 }
148 
set_grad_accumulation_step(int64_t grad_accumulation_step)149 void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
150   grad_accumulation_step_ = grad_accumulation_step;
151 }
152 
set_gradient_fp32_sync(bool gradient_fp32_sync)153 void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
154 
set_loss_repeated_mean(bool loss_repeated_mean)155 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
156 
set_pipeline_stage_split_num(const int64_t stage_num)157 void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
158 
set_pipeline_interleave(const bool pipeline_interleave)159 void ParallelContext::set_pipeline_interleave(const bool pipeline_interleave) {
160   pipeline_interleave_ = pipeline_interleave;
161 }
162 
set_pipeline_scheduler(const std::string & pipeline_scheduler)163 void ParallelContext::set_pipeline_scheduler(const std::string &pipeline_scheduler) {
164   pipeline_scheduler_ = pipeline_scheduler;
165 }
166 
set_pipeline_segment_split_num(const int64_t segment_num)167 void ParallelContext::set_pipeline_segment_split_num(const int64_t segment_num) {
168   pipeline_segment_split_num_ = segment_num;
169 }
170 
set_parallel_mode(const std::string & parallel_mode)171 bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
172   auto iter = std::find(kParallelModeList.begin(), kParallelModeList.end(), parallel_mode);
173   if (iter == kParallelModeList.end()) {
174     MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode;
175     return false;
176   }
177   parallel_mode_ = parallel_mode;
178   return true;
179 }
180 
set_strategy_search_mode(const std::string & strategy_search_mode)181 bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
182   auto iter = std::find(kStrategySearchModeList.begin(), kStrategySearchModeList.end(), strategy_search_mode);
183   if (iter == kStrategySearchModeList.end()) {
184     MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode;
185     return false;
186   }
187   strategy_search_mode_ = strategy_search_mode;
188   return true;
189 }
190 
set_parameter_broadcast(bool parameter_broadcast)191 void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
192   parameter_broadcast_ = parameter_broadcast;
193   parameter_broadcast_is_set_ = true;
194 }
195 
set_ops_strategy_json_config(const std::string & type,const std::string & path,const std::string & mode)196 void ParallelContext::set_ops_strategy_json_config(const std::string &type, const std::string &path,
197                                                    const std::string &mode) {
198   strategy_json_config_file_type_ = type;
199   strategy_json_config_file_path_ = path;
200   strategy_json_config_file_mode_ = mode;
201 }
202 
set_strategy_ckpt_load_file(const std::string & strategy_ckpt_load_file)203 void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
204   strategy_ckpt_load_file_ = strategy_ckpt_load_file;
205 }
206 
set_strategy_ckpt_save_file(const std::string & strategy_ckpt_save_file)207 void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
208   strategy_ckpt_save_file_ = strategy_ckpt_save_file;
209 }
210 
set_group_ckpt_save_file(const std::string & group_ckpt_save_file)211 void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
212   group_ckpt_save_file_ = group_ckpt_save_file;
213 }
214 
set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size)215 void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
216   optimizer_weight_shard_size_ = optimizer_weight_shard_size;
217 }
218 
set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save)219 void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
220   optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
221 }
222 
SetAllReduceFusionSplitIndices(const std::vector<uint32_t> & indices,const std::string & group)223 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
224   if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
225       group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
226       group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
227     all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
228     all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
229     all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
230   }
231   all_reduce_fusion_split_indices_[group] = indices;
232 }
233 
GetAllReduceFusionSplitIndices(const std::string & group) const234 std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
235   auto iter = all_reduce_fusion_split_indices_.find(group);
236   if (iter != all_reduce_fusion_split_indices_.end()) {
237     return iter->second;
238   }
239   return {};
240 }
241 
SetAllReduceFusionSplitSizes(const std::vector<uint32_t> & sizes,const std::string & group)242 void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
243   if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
244       group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
245       group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
246     all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
247     all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
248     all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
249   }
250   all_reduce_fusion_split_sizes_[group] = sizes;
251 }
252 
GetAllReduceFusionSplitSizes(const std::string & group) const253 std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
254   auto iter = all_reduce_fusion_split_sizes_.find(group);
255   if (iter != all_reduce_fusion_split_sizes_.end()) {
256     return iter->second;
257   }
258   return {};
259 }
260 
set_communi_parallel_mode(const std::string & communi_parallel_mode)261 bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
262   auto iter = std::find(kCommuniParallelModeList.begin(), kCommuniParallelModeList.end(), communi_parallel_mode);
263   if (iter == kCommuniParallelModeList.end()) {
264     MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
265     return false;
266   }
267 
268   communi_parallel_mode_ = communi_parallel_mode;
269   return true;
270 }
271 
272 // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
ParallelParameterContextRestoreShape(const FuncGraphPtr & func_graph,const ParameterPtr & param_node,const AbstractBasePtr & ptr) const273 void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
274                                                            const ParameterPtr &param_node,
275                                                            const AbstractBasePtr &ptr) const {
276   MS_EXCEPTION_IF_NULL(func_graph);
277   MS_EXCEPTION_IF_NULL(param_node);
278   MS_EXCEPTION_IF_NULL(ptr);
279   if (!ParallelContextCareGraph(func_graph)) {
280     return;
281   }
282 
283   auto param_info = param_node->param_info();
284   if (!param_info) {
285     return;
286   }
287   auto shape = param_info->parameter_shape();
288   if (shape.empty()) {
289     MS_LOG(INFO) << "The parameter " << param_node->name() << "'s parameter_shape in param_info is empty";
290     return;
291   }
292   std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
293   ptr->set_shape(base_shape);
294   MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
295 }
296 
ParallelContextCareGraph(const FuncGraphPtr & func_graph) const297 bool ParallelContext::ParallelContextCareGraph(const FuncGraphPtr &func_graph) const {
298   MS_EXCEPTION_IF_NULL(func_graph);
299   if (func_graph->has_flag(kSkipAutoParallelCompile)) {
300     return false;
301   }
302 
303   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
304   if (parallel_mode != kAutoParallel && parallel_mode != kSemiAutoParallel) {
305     return false;
306   }
307 
308   return true;
309 }
310 
set_enable_all2all(const bool enable)311 void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
312 
set_enable_micro_interleaved(const bool enable_micro_interleaved)313 void ParallelContext::set_enable_micro_interleaved(const bool enable_micro_interleaved) {
314   enable_micro_interleaved_ = enable_micro_interleaved;
315 }
316 
set_enable_fine_grained_micro_interleaved(const bool enable_fine_grained_micro_interleaved)317 void ParallelContext::set_enable_fine_grained_micro_interleaved(const bool enable_fine_grained_micro_interleaved) {
318   enable_fine_grained_micro_interleaved_ = enable_fine_grained_micro_interleaved;
319 }
320 
set_pipeline_micro_size(const size_t pipeline_micro_size)321 void ParallelContext::set_pipeline_micro_size(const size_t pipeline_micro_size) {
322   pipeline_micro_size_ = pipeline_micro_size;
323 }
324 
set_auto_pipeline(const bool auto_pipeline)325 void ParallelContext::set_auto_pipeline(const bool auto_pipeline) { auto_pipeline_ = auto_pipeline; }
326 
set_do_transform(const bool do_transform)327 void ParallelContext::set_do_transform(const bool do_transform) { do_transform_ = do_transform; }
328 
set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params)329 void ParallelContext::set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params) {
330   stra_file_only_trainable_params_ = stra_file_only_trainable_params;
331 }
332 
set_sharding_propagation(const bool stra_pto)333 void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
334 }  // namespace mindspore::parallel
335