1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/common/utils/parallel_context.h"
18
19 #include <algorithm>
20 #include <cstdint>
21 #include <functional>
22 #include <map>
23 #include <memory>
24
25 namespace mindspore::parallel {
26 namespace {
27 std::vector<std::string> kParallelModeList = {kStandalone, kDataParallel, kHybridParallel, kSemiAutoParallel,
28 kAutoParallel};
29 std::vector<std::string> kStrategySearchModeList = {kDynamicProgramming, kRecursiveProgramming, kShardingPropagation};
30
31 std::vector<std::string> kCommuniParallelModeList = {kAllGroupParallel, kSameServerGroupParallel, kNoGroupParallel};
32
33 std::vector<std::string> kFusionModeList = {kFusionAuto, kFusionSize, kFusionIndex};
34 } // namespace
35
GetInstance()36 std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
37 static std::shared_ptr<ParallelContext> inst_context_ = std::shared_ptr<ParallelContext>(new ParallelContext());
38 return inst_context_;
39 }
40
ParallelContext()41 ParallelContext::ParallelContext() { Reset(); }
42
Reset()43 void ParallelContext::Reset() {
44 gradients_mean_ = false;
45 full_batch_ = false;
46 full_batch_is_set_ = false;
47 gradient_fp32_sync_ = true;
48 loss_repeated_mean_ = true;
49 device_num_ = 1;
50 global_rank_ = 0;
51 device_num_is_set_ = false;
52 global_rank_is_set_ = false;
53 parallel_mode_ = kStandalone;
54 parameter_broadcast_ = false;
55 parameter_broadcast_is_set_ = false;
56 enable_all_reduce_fusion_ = true;
57 enable_all_gather_fusion_ = true;
58 enable_reduce_scatter_fusion_ = true;
59 strategy_json_config_file_type_ = "";
60 strategy_json_config_file_path_ = "";
61 strategy_json_config_file_mode_ = "";
62 strategy_ckpt_load_file_ = "";
63 strategy_ckpt_save_file_ = "";
64 enable_parallel_optimizer_ = false;
65 force_fp32_communication_ = false;
66 all_reduce_fusion_split_indices_.clear();
67 all_reduce_fusion_split_sizes_.clear();
68 strategy_search_mode_ = kRecursiveProgramming;
69 pipeline_stage_split_num_ = 1;
70 pipeline_segment_split_num_ = 1;
71 grad_accumulation_step_ = 1;
72 communi_parallel_mode_ = kAllGroupParallel;
73 optimizer_weight_shard_size_ = -1;
74 optimizer_weight_shard_aggregated_save_ = false;
75 enable_all2all_ = false;
76 grad_accumulation_shard_ = false;
77 parallel_optimizer_threshold_ = -1;
78 sharding_propagation_ = false;
79 dataset_strategy_.clear();
80 dp_fusion_threshold_mb_ = kDataParallelFusionThreshold;
81 fusion_threshold_mb_ = kFusionThreshold;
82 allgather_fusion_threshold_mb_ = kFusionThreshold;
83 reducescatter_fusion_threshold_mb_ = kFusionThreshold;
84 fusion_threshold_is_set_ = true;
85 fusion_mode_ = kFusionAuto;
86 group_ckpt_save_file_ = "";
87 pipeline_micro_size_ = 1;
88 dataset_repeat_dim_right_ = false;
89 hccl_test_available_ = false;
90 enable_micro_interleaved_ = false;
91 enable_fine_grained_micro_interleaved_ = false;
92 do_transform_ = false;
93 direct_split_ = false;
94 pipeline_result_broadcast_ = false;
95 stra_file_only_trainable_params_ = true;
96 pipeline_interleave_ = false;
97 pipeline_scheduler_ = kPipeline1F1B;
98 auto_pipeline_ = false;
99 }
100
set_device_num(int64_t device_num)101 void ParallelContext::set_device_num(int64_t device_num) {
102 device_num_ = device_num;
103 device_num_is_set_ = true;
104 }
105
set_fusion_threshold_mb(int64_t fusion_threshold)106 void ParallelContext::set_fusion_threshold_mb(int64_t fusion_threshold) {
107 fusion_threshold_mb_ = fusion_threshold;
108 dp_fusion_threshold_mb_ = fusion_threshold;
109 fusion_threshold_is_set_ = true;
110 enable_all_reduce_fusion_ = true;
111 }
112
set_allgather_fusion_threshold_mb(int64_t fusion_threshold)113 void ParallelContext::set_allgather_fusion_threshold_mb(int64_t fusion_threshold) {
114 allgather_fusion_threshold_mb_ = fusion_threshold;
115 enable_all_gather_fusion_ = true;
116 }
117
set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold)118 void ParallelContext::set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold) {
119 reducescatter_fusion_threshold_mb_ = fusion_threshold;
120 enable_reduce_scatter_fusion_ = true;
121 }
122
set_fusion_mode(const std::string & fusion_mode)123 bool ParallelContext::set_fusion_mode(const std::string &fusion_mode) {
124 auto iter = std::find(kFusionModeList.begin(), kFusionModeList.end(), fusion_mode);
125 if (iter == kFusionModeList.end()) {
126 MS_LOG(INFO) << "Invalid fusion mode:" << fusion_mode;
127 return false;
128 }
129 fusion_mode_ = fusion_mode;
130 return true;
131 }
132
set_global_rank(int64_t global_rank)133 void ParallelContext::set_global_rank(int64_t global_rank) {
134 global_rank_ = global_rank;
135 global_rank_is_set_ = true;
136 }
137
set_gradients_mean(bool gradients_mean)138 void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }
139
set_full_batch(bool full_batch)140 void ParallelContext::set_full_batch(bool full_batch) {
141 full_batch_ = full_batch;
142 full_batch_is_set_ = true;
143 }
144
set_dataset_strategy(const std::vector<std::vector<int64_t>> & dataset_strategy)145 void ParallelContext::set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy) {
146 dataset_strategy_ = dataset_strategy;
147 }
148
set_grad_accumulation_step(int64_t grad_accumulation_step)149 void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
150 grad_accumulation_step_ = grad_accumulation_step;
151 }
152
set_gradient_fp32_sync(bool gradient_fp32_sync)153 void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
154
set_loss_repeated_mean(bool loss_repeated_mean)155 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
156
set_pipeline_stage_split_num(const int64_t stage_num)157 void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
158
set_pipeline_interleave(const bool pipeline_interleave)159 void ParallelContext::set_pipeline_interleave(const bool pipeline_interleave) {
160 pipeline_interleave_ = pipeline_interleave;
161 }
162
set_pipeline_scheduler(const std::string & pipeline_scheduler)163 void ParallelContext::set_pipeline_scheduler(const std::string &pipeline_scheduler) {
164 pipeline_scheduler_ = pipeline_scheduler;
165 }
166
set_pipeline_segment_split_num(const int64_t segment_num)167 void ParallelContext::set_pipeline_segment_split_num(const int64_t segment_num) {
168 pipeline_segment_split_num_ = segment_num;
169 }
170
set_parallel_mode(const std::string & parallel_mode)171 bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) {
172 auto iter = std::find(kParallelModeList.begin(), kParallelModeList.end(), parallel_mode);
173 if (iter == kParallelModeList.end()) {
174 MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode;
175 return false;
176 }
177 parallel_mode_ = parallel_mode;
178 return true;
179 }
180
set_strategy_search_mode(const std::string & strategy_search_mode)181 bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
182 auto iter = std::find(kStrategySearchModeList.begin(), kStrategySearchModeList.end(), strategy_search_mode);
183 if (iter == kStrategySearchModeList.end()) {
184 MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode;
185 return false;
186 }
187 strategy_search_mode_ = strategy_search_mode;
188 return true;
189 }
190
set_parameter_broadcast(bool parameter_broadcast)191 void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
192 parameter_broadcast_ = parameter_broadcast;
193 parameter_broadcast_is_set_ = true;
194 }
195
set_ops_strategy_json_config(const std::string & type,const std::string & path,const std::string & mode)196 void ParallelContext::set_ops_strategy_json_config(const std::string &type, const std::string &path,
197 const std::string &mode) {
198 strategy_json_config_file_type_ = type;
199 strategy_json_config_file_path_ = path;
200 strategy_json_config_file_mode_ = mode;
201 }
202
set_strategy_ckpt_load_file(const std::string & strategy_ckpt_load_file)203 void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
204 strategy_ckpt_load_file_ = strategy_ckpt_load_file;
205 }
206
set_strategy_ckpt_save_file(const std::string & strategy_ckpt_save_file)207 void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
208 strategy_ckpt_save_file_ = strategy_ckpt_save_file;
209 }
210
set_group_ckpt_save_file(const std::string & group_ckpt_save_file)211 void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
212 group_ckpt_save_file_ = group_ckpt_save_file;
213 }
214
set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size)215 void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
216 optimizer_weight_shard_size_ = optimizer_weight_shard_size;
217 }
218
set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save)219 void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
220 optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
221 }
222
SetAllReduceFusionSplitIndices(const std::vector<uint32_t> & indices,const std::string & group)223 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
224 if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
225 group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
226 group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
227 all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
228 all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
229 all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
230 }
231 all_reduce_fusion_split_indices_[group] = indices;
232 }
233
GetAllReduceFusionSplitIndices(const std::string & group) const234 std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
235 auto iter = all_reduce_fusion_split_indices_.find(group);
236 if (iter != all_reduce_fusion_split_indices_.end()) {
237 return iter->second;
238 }
239 return {};
240 }
241
SetAllReduceFusionSplitSizes(const std::vector<uint32_t> & sizes,const std::string & group)242 void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
243 if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
244 group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
245 group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
246 all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
247 all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
248 all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
249 }
250 all_reduce_fusion_split_sizes_[group] = sizes;
251 }
252
GetAllReduceFusionSplitSizes(const std::string & group) const253 std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
254 auto iter = all_reduce_fusion_split_sizes_.find(group);
255 if (iter != all_reduce_fusion_split_sizes_.end()) {
256 return iter->second;
257 }
258 return {};
259 }
260
set_communi_parallel_mode(const std::string & communi_parallel_mode)261 bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
262 auto iter = std::find(kCommuniParallelModeList.begin(), kCommuniParallelModeList.end(), communi_parallel_mode);
263 if (iter == kCommuniParallelModeList.end()) {
264 MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
265 return false;
266 }
267
268 communi_parallel_mode_ = communi_parallel_mode;
269 return true;
270 }
271
272 // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
ParallelParameterContextRestoreShape(const FuncGraphPtr & func_graph,const ParameterPtr & param_node,const AbstractBasePtr & ptr) const273 void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
274 const ParameterPtr ¶m_node,
275 const AbstractBasePtr &ptr) const {
276 MS_EXCEPTION_IF_NULL(func_graph);
277 MS_EXCEPTION_IF_NULL(param_node);
278 MS_EXCEPTION_IF_NULL(ptr);
279 if (!ParallelContextCareGraph(func_graph)) {
280 return;
281 }
282
283 auto param_info = param_node->param_info();
284 if (!param_info) {
285 return;
286 }
287 auto shape = param_info->parameter_shape();
288 if (shape.empty()) {
289 MS_LOG(INFO) << "The parameter " << param_node->name() << "'s parameter_shape in param_info is empty";
290 return;
291 }
292 std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
293 ptr->set_shape(base_shape);
294 MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
295 }
296
ParallelContextCareGraph(const FuncGraphPtr & func_graph) const297 bool ParallelContext::ParallelContextCareGraph(const FuncGraphPtr &func_graph) const {
298 MS_EXCEPTION_IF_NULL(func_graph);
299 if (func_graph->has_flag(kSkipAutoParallelCompile)) {
300 return false;
301 }
302
303 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
304 if (parallel_mode != kAutoParallel && parallel_mode != kSemiAutoParallel) {
305 return false;
306 }
307
308 return true;
309 }
310
set_enable_all2all(const bool enable)311 void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
312
set_enable_micro_interleaved(const bool enable_micro_interleaved)313 void ParallelContext::set_enable_micro_interleaved(const bool enable_micro_interleaved) {
314 enable_micro_interleaved_ = enable_micro_interleaved;
315 }
316
set_enable_fine_grained_micro_interleaved(const bool enable_fine_grained_micro_interleaved)317 void ParallelContext::set_enable_fine_grained_micro_interleaved(const bool enable_fine_grained_micro_interleaved) {
318 enable_fine_grained_micro_interleaved_ = enable_fine_grained_micro_interleaved;
319 }
320
set_pipeline_micro_size(const size_t pipeline_micro_size)321 void ParallelContext::set_pipeline_micro_size(const size_t pipeline_micro_size) {
322 pipeline_micro_size_ = pipeline_micro_size;
323 }
324
set_auto_pipeline(const bool auto_pipeline)325 void ParallelContext::set_auto_pipeline(const bool auto_pipeline) { auto_pipeline_ = auto_pipeline; }
326
set_do_transform(const bool do_transform)327 void ParallelContext::set_do_transform(const bool do_transform) { do_transform_ = do_transform; }
328
set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params)329 void ParallelContext::set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params) {
330 stra_file_only_trainable_params_ = stra_file_only_trainable_params;
331 }
332
set_sharding_propagation(const bool stra_pto)333 void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
334 } // namespace mindspore::parallel
335