1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/context.h"
18
19 #include <algorithm>
20 #include <cstdint>
21 #include <functional>
22 #include <map>
23 #include <memory>
24 #include <utility>
25
26 #include "frontend/parallel/device_manager.h"
27
28 namespace mindspore {
29 namespace parallel {
// Maps parameter name -> recorded shape. Filled by ParallelParameterContextCkptShape
// and consumed by ParallelParameterContextRestoreShape.
std::map<std::string, Shape> param_shapes;

// Accepted values for set_parallel_mode().
std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
                                               AUTO_PARALLEL};
// Accepted values for set_strategy_search_mode().
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};

// Accepted values for set_communi_parallel_mode().
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
                                                       NO_GROUP_PARALLEL};

// Singleton instance; created lazily by GetInstance().
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
40
GetInstance()41 std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
42 if (inst_context_ == nullptr) {
43 inst_context_.reset(new (std::nothrow) ParallelContext());
44 }
45 return inst_context_;
46 }
47
ParallelContext()48 ParallelContext::ParallelContext() { Reset(); }
49
Reset()50 void ParallelContext::Reset() {
51 init_param_shape_ = true;
52 gradients_mean_ = false;
53 full_batch_ = false;
54 gradient_fp32_sync_ = true;
55 loss_repeated_mean_ = true;
56 device_num_ = 1;
57 global_rank_ = 0;
58 device_num_is_set_ = false;
59 global_rank_is_set_ = false;
60 parallel_mode_ = STAND_ALONE;
61 parameter_broadcast_ = false;
62 parameter_broadcast_is_set_ = false;
63 enable_all_reduce_fusion_ = false;
64 strategy_ckpt_load_file_ = "";
65 strategy_ckpt_save_file_ = "";
66 enable_parallel_optimizer_ = false;
67 all_reduce_fusion_split_indices_.clear();
68 all_reduce_fusion_split_sizes_.clear();
69 strategy_search_mode_ = DYNAMIC_PROGRAMMING;
70 pipeline_stage_split_num_ = 1;
71 grad_accumulation_step_ = 1;
72 communi_parallel_mode_ = ALL_GROUP_PARALLEL;
73 optimizer_weight_shard_size_ = -1;
74 optimizer_weight_shard_aggregated_save_ = false;
75 sharding_propagation_ = false;
76 enable_all2all_ = false;
77 dataset_strategy_.clear();
78 }
79
// Set the total device count and remember that it was set explicitly.
void ParallelContext::set_device_num(int64_t device_num) {
  device_num_ = device_num;
  device_num_is_set_ = true;
}
84
// Set this process's global rank and remember that it was set explicitly.
void ParallelContext::set_global_rank(int64_t global_rank) {
  global_rank_ = global_rank;
  global_rank_is_set_ = true;
}
89
set_gradients_mean(bool gradients_mean)90 void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }
91
set_full_batch(bool full_batch)92 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
93
// Set the sharding strategy applied to dataset inputs (one inner vector per input).
void ParallelContext::set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy) {
  dataset_strategy_ = dataset_strategy;
}
97
// Set the number of gradient accumulation steps (1 means no accumulation).
void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
  grad_accumulation_step_ = grad_accumulation_step;
}
101
set_gradient_fp32_sync(bool gradient_fp32_sync)102 void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
103
set_loss_repeated_mean(bool loss_repeated_mean)104 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
105
set_pipeline_stage_split_num(const int64_t stage_num)106 void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }
107
set_parallel_mode(const std::string & parallel_mode)108 bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) {
109 auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
110 if (iter == PARALLEL_MODE_LIST.end()) {
111 MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode;
112 return false;
113 }
114 parallel_mode_ = parallel_mode;
115 return true;
116 }
117
set_strategy_search_mode(const std::string & strategy_search_mode)118 bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
119 auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode);
120 if (iter == STRATEGY_SEARCH_MODE_LIST.end()) {
121 MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode;
122 return false;
123 }
124 strategy_search_mode_ = strategy_search_mode;
125 return true;
126 }
127
// Enable/disable parameter broadcast and remember that it was set explicitly.
void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
  parameter_broadcast_ = parameter_broadcast;
  parameter_broadcast_is_set_ = true;
}
132
// Set the file path from which the sharding strategy checkpoint is loaded.
void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
}
136
// Set the file path to which the sharding strategy checkpoint is saved.
void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
}
140
// Set the file path to which communication group info is saved.
void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
  group_ckpt_save_file_ = group_ckpt_save_file;
}
144
// Set the shard group size for optimizer weight sharding.
void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
  optimizer_weight_shard_size_ = optimizer_weight_shard_size;
}
148
// Enable/disable saving optimizer-sharded weights in aggregated form.
void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
  optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
}
152
SetAllReduceFusionSplitIndices(const std::vector<uint32_t> & indices,const std::string & group)153 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
154 if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
155 group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
156 group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
157 all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
158 all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
159 all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
160 }
161 all_reduce_fusion_split_indices_[group] = indices;
162 }
163
GetAllReduceFusionSplitIndices(const std::string & group) const164 std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
165 auto iter = all_reduce_fusion_split_indices_.find(group);
166 if (iter != all_reduce_fusion_split_indices_.end()) {
167 return iter->second;
168 }
169 return {};
170 }
171
SetAllReduceFusionSplitSizes(const std::vector<uint32_t> & sizes,const std::string & group)172 void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
173 if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
174 group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
175 group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
176 all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
177 all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
178 all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
179 }
180 all_reduce_fusion_split_sizes_[group] = sizes;
181 }
182
GetAllReduceFusionSplitSizes(const std::string & group) const183 std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
184 auto iter = all_reduce_fusion_split_sizes_.find(group);
185 if (iter != all_reduce_fusion_split_sizes_.end()) {
186 return iter->second;
187 }
188 return {};
189 }
190
set_communi_parallel_mode(const std::string & communi_parallel_mode)191 bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
192 auto iter = std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode);
193 if (iter == COMMUNI_PARALLEL_MODE_LIST.end()) {
194 MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
195 return false;
196 }
197
198 communi_parallel_mode_ = communi_parallel_mode;
199 return true;
200 }
201
// Decide whether param_shapes should be (re)initialized for this graph, or
// whether previously checkpointed shapes must be restored instead.
// Only graphs flagged AUTO_PARALLEL are affected; the checks below are ordered
// by priority (first-iteration > non-training > grad-accumulation second graph).
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }
  // First iteration of incremental prediction with two graphs: start fresh.
  if (func_graph->has_flag(IS_FIRST_ITERATION)) {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Init the parameter shape dict in increment predict with two graph";
    return;
  }
  // Evaluation/prediction graph: keep the recorded shapes for restoring.
  if (!func_graph->has_flag(TRAINING)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In parallel evaluation or prediction, may be need to restore the parameter shape";
    return;
  }

  // Second graph of gradient accumulation also restores shapes; any other
  // training graph re-initializes the shape dict.
  if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
  } else {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Init the parameter shape dict";
  }
}
229
230 // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
ParallelParameterContextRestoreShape(const FuncGraphPtr & func_graph,const ParameterPtr & param_node,const AbstractBasePtr & ptr)231 void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
232 const ParameterPtr ¶m_node, const AbstractBasePtr &ptr) {
233 MS_EXCEPTION_IF_NULL(func_graph);
234 MS_EXCEPTION_IF_NULL(param_node);
235 MS_EXCEPTION_IF_NULL(ptr);
236 if (!func_graph->has_flag(AUTO_PARALLEL)) {
237 return;
238 }
239
240 if (init_param_shape_) {
241 return;
242 }
243 auto iter = param_shapes.find(param_node->name());
244 if (iter == param_shapes.end()) {
245 MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
246 return;
247 }
248 Shape shape = iter->second;
249 std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
250 ptr->set_shape(base_shape);
251 MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
252 }
253
254 // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
255 // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
ParallelParameterContextCkptShape(const FuncGraphPtr & func_graph,const ParameterPtr & param_node,const AbstractBasePtr & ptr)256 void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node,
257 const AbstractBasePtr &ptr) {
258 MS_EXCEPTION_IF_NULL(func_graph);
259 MS_EXCEPTION_IF_NULL(param_node);
260 MS_EXCEPTION_IF_NULL(ptr);
261 if (!func_graph->has_flag(AUTO_PARALLEL)) {
262 return;
263 }
264
265 if (!init_param_shape_) {
266 return;
267 }
268 std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
269 auto ret = param_shapes.try_emplace(param_node->name(), shape);
270 if (!ret.second) {
271 MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
272 return;
273 }
274
275 MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
276 }
277
set_sharding_propagation(const bool stra_pto)278 void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
279
set_enable_all2all(const bool enable)280 void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
281 } // namespace parallel
282 } // namespace mindspore
283