/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/context.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <utility>

#include "frontend/parallel/device_manager.h"

namespace mindspore {
namespace parallel {
std::map<std::string, Shape> param_shapes;

std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
                                               AUTO_PARALLEL};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};

std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
                                                       NO_GROUP_PARALLEL};

std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;

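// Lazily creates the process-wide singleton instance on first access.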
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
  if (inst_context_ == nullptr) {
    inst_context_.reset(new (std::nothrow) ParallelContext());
  }
  return inst_context_;
}

ParallelContext::ParallelContext() { Reset(); }

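// Restores every parallel configuration field to its default value.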
void ParallelContext::Reset() {
  init_param_shape_ = true;
  gradients_mean_ = false;
  full_batch_ = false;
  gradient_fp32_sync_ = true;
  loss_repeated_mean_ = true;
  device_num_ = 1;
  global_rank_ = 0;
  device_num_is_set_ = false;
  global_rank_is_set_ = false;
  parallel_mode_ = STAND_ALONE;
  parameter_broadcast_ = false;
  parameter_broadcast_is_set_ = false;
  enable_all_reduce_fusion_ = false;
  strategy_ckpt_load_file_ = "";
  strategy_ckpt_save_file_ = "";
  enable_parallel_optimizer_ = false;
  all_reduce_fusion_split_indices_.clear();
  all_reduce_fusion_split_sizes_.clear();
  strategy_search_mode_ = DYNAMIC_PROGRAMMING;
  pipeline_stage_split_num_ = 1;
  grad_accumulation_step_ = 1;
  communi_parallel_mode_ = ALL_GROUP_PARALLEL;
  optimizer_weight_shard_size_ = -1;
  optimizer_weight_shard_aggregated_save_ = false;
  sharding_propagation_ = false;
  enable_all2all_ = false;
  dataset_strategy_.clear();
}

void ParallelContext::set_device_num(int64_t device_num) {
  device_num_ = device_num;
  device_num_is_set_ = true;
}

void ParallelContext::set_global_rank(int64_t global_rank) {
  global_rank_ = global_rank;
  global_rank_is_set_ = true;
}

void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }

void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }

void ParallelContext::set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy) {
  dataset_strategy_ = dataset_strategy;
}

void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
  grad_accumulation_step_ = grad_accumulation_step;
}

void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }

void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

void ParallelContext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_num; }

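// Accepts only modes listed in PARALLEL_MODE_LIST; returns false for anything else.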
bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
  auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
  if (iter == PARALLEL_MODE_LIST.end()) {
    MS_LOG(INFO) << "Invalid parallel mode: " << parallel_mode;
    return false;
  }
  parallel_mode_ = parallel_mode;
  return true;
}

bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
  auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode);
  if (iter == STRATEGY_SEARCH_MODE_LIST.end()) {
    MS_LOG(INFO) << "Invalid strategy search mode: " << strategy_search_mode;
    return false;
  }
  strategy_search_mode_ = strategy_search_mode;
  return true;
}

void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
  parameter_broadcast_ = parameter_broadcast;
  parameter_broadcast_is_set_ = true;
}

void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
}

void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
}

void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
  group_ckpt_save_file_ = group_ckpt_save_file;
}

void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
  optimizer_weight_shard_size_ = optimizer_weight_shard_size;
}

void ParallelContext::set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save) {
  optimizer_weight_shard_aggregated_save_ = optimizer_weight_shard_aggregated_save;
}

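// If the non-empty group name does not already carry a float type label, the indices are also
// registered under the group name suffixed with the float, float16, and float32 labels.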
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
  if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
      group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
      group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
    all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
    all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
    all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
  }
  all_reduce_fusion_split_indices_[group] = indices;
}

std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
  auto iter = all_reduce_fusion_split_indices_.find(group);
  if (iter != all_reduce_fusion_split_indices_.end()) {
    return iter->second;
  }
  return {};
}

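// Same keying rule as SetAllReduceFusionSplitIndices, applied to the fusion split sizes.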
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
  if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
      group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
      group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
    all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
    all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
    all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
  }
  all_reduce_fusion_split_sizes_[group] = sizes;
}

std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
  auto iter = all_reduce_fusion_split_sizes_.find(group);
  if (iter != all_reduce_fusion_split_sizes_.end()) {
    return iter->second;
  }
  return {};
}

bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
  auto iter = std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode);
  if (iter == COMMUNI_PARALLEL_MODE_LIST.end()) {
    MS_LOG(INFO) << "Invalid communication parallel mode: " << communi_parallel_mode;
    return false;
  }

  communi_parallel_mode_ = communi_parallel_mode;
  return true;
}

// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }
  if (func_graph->has_flag(IS_FIRST_ITERATION)) {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Initialize the parameter shape dict for incremental prediction with two graphs";
    return;
  }
  if (!func_graph->has_flag(TRAINING)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In parallel evaluation or prediction, the parameter shape may need to be restored";
    return;
  }

  if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In the second graph of parallel gradient accumulation, the parameter shape needs to be restored";
  } else {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Initialize the parameter shape dict";
  }
}

// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
                                                           const ParameterPtr &param_node, const AbstractBasePtr &ptr) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }

  if (init_param_shape_) {
    return;
  }
  auto iter = param_shapes.find(param_node->name());
  if (iter == param_shapes.end()) {
    MS_LOG(WARNING) << "Cannot find the shape for parameter " << param_node->name();
    return;
  }
  Shape shape = iter->second;
  std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
  ptr->set_shape(base_shape);
  MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}

// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                                        const AbstractBasePtr &ptr) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }

  if (!init_param_shape_) {
    return;
  }
  std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
  auto ret = param_shapes.try_emplace(param_node->name(), shape);
  if (!ret.second) {
    MS_LOG(EXCEPTION) << "The shape for parameter " << param_node->name() << " already exists";
    return;
  }

  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}

void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }

void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
}  // namespace parallel
}  // namespace mindspore