/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_
19 
20 #include <cstdint>
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #include "abstract/abstract_value.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "include/common/utils/convert_utils.h"
29 #include "utils/info.h"
30 #include "include/common/visible.h"
31 #include "mindspore/core/symbolic_shape/symbol_info.h"
32 
33 namespace mindspore::parallel {
// String and numeric constants used by the parallel context.
// The grouping below follows the blank-line grouping of the declarations; what each group is
// used for is inferred from the names and should be confirmed against the implementation file.

// Presumably the values accepted by ParallelContext::set_parallel_mode() -- TODO confirm.
constexpr char kStandalone[] = "stand_alone";
constexpr char kDataParallel[] = "data_parallel";
constexpr char kHybridParallel[] = "hybrid_parallel";
constexpr char kAutoParallel[] = "auto_parallel";
constexpr char kSemiAutoParallel[] = "semi_auto_parallel";

// Presumably the values accepted by ParallelContext::set_strategy_search_mode() -- TODO confirm.
constexpr char kDynamicProgramming[] = "dynamic_programming";
constexpr char kRecursiveProgramming[] = "recursive_programming";
constexpr char kShardingPropagation[] = "sharding_propagation";

constexpr char kAccumulation[] = "accumulation";

// Presumably the values accepted by ParallelContext::set_communi_parallel_mode() -- TODO confirm.
constexpr char kAllGroupParallel[] = "all_group_parallel";
constexpr char kSameServerGroupParallel[] = "same_server_group_parallel";
constexpr char kNoGroupParallel[] = "no_group_parallel";

// Flag / attribute key names (consumers are not visible in this header).
constexpr char kHasShard[] = "has_shard";
constexpr char kSharded[] = "sharded";
constexpr char kSkipAutoParallelCompile[] = "skip_auto_parallel_compile";
constexpr char kKeepInputUnchanged[] = "keep_input_unchanged";

// Presumably the values accepted by ParallelContext::set_pipeline_scheduler() -- TODO confirm.
constexpr char kPipeline1F1B[] = "1f1b";
constexpr char kPipelineGpipe[] = "gpipe";

// Presumably the values accepted by ParallelContext::set_fusion_mode() -- TODO confirm.
constexpr char kFusionAuto[] = "auto";
constexpr char kFusionSize[] = "size";
constexpr char kFusionIndex[] = "index";
// Fusion threshold constants; likely the defaults (in MB) for the *_fusion_threshold_mb
// settings of ParallelContext -- TODO confirm against Reset().
constexpr int64_t kFusionThreshold = 64;
constexpr int64_t kDataParallelFusionThreshold = -1;
constexpr char kRelatedFusionKey[] = "related_fusion_key";
constexpr char kRelatedNodeId[] = "related_node_id";
// NOTE(review): this name does not follow the k-prefixed convention of its siblings; it is kept
// unchanged because renaming it would break existing users of the constant.
constexpr char FIRST_RECEIVE[] = "first_receive";
constexpr char kRelatedCommNodeId[] = "related_comm_node_id";
67 
68 class COMMON_EXPORT ParallelContext {
69  public:
70   static std::shared_ptr<ParallelContext> GetInstance();
71   ~ParallelContext() = default;
72   ParallelContext(const ParallelContext &) = delete;
73   ParallelContext &operator=(const ParallelContext &) = delete;
74 
75   void set_gradients_mean(bool gradients_mean);
gradients_mean()76   bool gradients_mean() const { return gradients_mean_; }
77 
78   void set_full_batch(bool full_batch);
full_batch()79   bool full_batch() const { return full_batch_; }
80 
81   void set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy);
dataset_strategy()82   std::vector<std::vector<int64_t>> dataset_strategy() const { return dataset_strategy_; }
83 
84   void set_gradient_fp32_sync(bool gradient_fp32_sync);
gradient_fp32_sync()85   bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
86 
87   void set_loss_repeated_mean(bool loss_repeated_mean);
loss_repeated_mean()88   bool loss_repeated_mean() const { return loss_repeated_mean_; }
89 
90   void set_device_num(int64_t device_num);
device_num()91   int64_t device_num() const { return device_num_; }
92 
93   void set_fusion_threshold_mb(int64_t fusion_threshold);
fusion_threshold_mb()94   int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; }
95 
dp_fusion_threshold_mb()96   int64_t dp_fusion_threshold_mb() const { return dp_fusion_threshold_mb_; }
97 
98   void set_allgather_fusion_threshold_mb(int64_t fusion_threshold);
allgather_fusion_threshold_mb()99   int64_t allgather_fusion_threshold_mb() const { return allgather_fusion_threshold_mb_; }
100 
101   void set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold);
reducescatter_fusion_threshold_mb()102   int64_t reducescatter_fusion_threshold_mb() const { return reducescatter_fusion_threshold_mb_; }
103 
104   bool set_fusion_mode(const std::string &fusion_mode);
get_fusion_mode()105   std::string get_fusion_mode() const { return fusion_mode_; }
106 
107   void set_pipeline_stage_split_num(const int64_t stage_num);
pipeline_stage_split_num()108   int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
109 
set_pipeline_result_broadcast(const bool flag)110   void set_pipeline_result_broadcast(const bool flag) { pipeline_result_broadcast_ = flag; }
pipeline_result_broadcast()111   bool pipeline_result_broadcast() const { return pipeline_result_broadcast_; }
112   void set_pipeline_interleave(const bool pipeline_interleave);
pipeline_interleave()113   bool pipeline_interleave() const { return pipeline_interleave_; }
114 
115   void set_pipeline_scheduler(const std::string &pipeline_scheduler);
pipeline_scheduler()116   std::string pipeline_scheduler() const { return pipeline_scheduler_; }
117 
118   void set_global_rank(int64_t global_rank);
global_rank()119   int64_t global_rank() const { return global_rank_; }
120 
121   void set_grad_accumulation_step(int64_t grad_accumulation_step);
grad_accumulation_step()122   int64_t grad_accumulation_step() const { return grad_accumulation_step_; }
123 
124   bool set_parallel_mode(const std::string &parallel_mode);
parallel_mode()125   std::string parallel_mode() const { return parallel_mode_; }
126 
127   bool set_strategy_search_mode(const std::string &strategy_search_mode);
strategy_search_mode()128   std::string strategy_search_mode() const { return strategy_search_mode_; }
129 
130   void set_parameter_broadcast(bool parameter_broadcast);
parameter_broadcast()131   bool parameter_broadcast() const { return parameter_broadcast_; }
132 
device_num_is_set()133   bool device_num_is_set() const { return device_num_is_set_; }
global_rank_is_set()134   bool global_rank_is_set() const { return global_rank_is_set_; }
parameter_broadcast_is_set()135   bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
full_batch_is_set()136   bool full_batch_is_set() const { return full_batch_is_set_; }
137 
138   void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
optimizer_weight_shard_size()139   int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
140   void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save);
optimizer_weight_shard_aggregated_save()141   bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; }
142 
143   void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
144   std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
145   void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
146   std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
set_enable_all_reduce_fusion(bool enable_all_reduce_fusion)147   void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
148     enable_all_reduce_fusion_ = enable_all_reduce_fusion;
149   }
enable_all_reduce_fusion()150   bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
set_enable_all_gather_fusion(bool enable_all_gather_fusion)151   void set_enable_all_gather_fusion(bool enable_all_gather_fusion) {
152     enable_all_gather_fusion_ = enable_all_gather_fusion;
153   }
enable_all_gather_fusion()154   bool enable_all_gather_fusion() const { return enable_all_gather_fusion_; }
155 
set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion)156   void set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion) {
157     enable_reduce_scatter_fusion_ = enable_reduce_scatter_fusion;
158   }
enable_reduce_scatter_fusion()159   bool enable_reduce_scatter_fusion() const { return enable_reduce_scatter_fusion_; }
160 
161   void set_ops_strategy_json_config(const std::string &type, const std::string &path, const std::string &mode);
strategy_json_config_file_type()162   std::string strategy_json_config_file_type() const { return strategy_json_config_file_type_; }
strategy_json_config_file_path()163   std::string strategy_json_config_file_path() const { return strategy_json_config_file_path_; }
strategy_json_config_file_mode()164   std::string strategy_json_config_file_mode() const { return strategy_json_config_file_mode_; }
165 
166   void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
strategy_ckpt_load_file()167   std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
168   void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
strategy_ckpt_save_file()169   std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
170   void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
group_ckpt_save_file()171   std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }
172 
set_enable_parallel_optimizer(bool enable_parallel_optimizer)173   void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
174     enable_parallel_optimizer_ = enable_parallel_optimizer;
175   }
enable_parallel_optimizer()176   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
177 
set_force_fp32_communication(bool force_fp32_communication)178   void set_force_fp32_communication(bool force_fp32_communication) {
179     force_fp32_communication_ = force_fp32_communication;
180   }
force_fp32_communication()181   bool force_fp32_communication() const { return force_fp32_communication_; }
182 
enable_fold_pipeline()183   bool enable_fold_pipeline() const { return pipeline_segment_split_num_ > 1; }
184 
185   void set_pipeline_segment_split_num(const int64_t segments);
pipeline_segment_split_num()186   int64_t pipeline_segment_split_num() const { return pipeline_segment_split_num_; }
187 
set_hccl_test_available(bool hccl_test_available)188   void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
hccl_test_available()189   bool hccl_test_available() const { return hccl_test_available_; }
set_grad_accumulation_shard(const bool grad_accumulation_shard)190   void set_grad_accumulation_shard(const bool grad_accumulation_shard) {
191     grad_accumulation_shard_ = grad_accumulation_shard;
192   }
grad_accumulation_shard()193   bool grad_accumulation_shard() const { return grad_accumulation_shard_; }
set_parallel_optimizer_threshold(const int64_t parallel_optimizer_threshold)194   void set_parallel_optimizer_threshold(const int64_t parallel_optimizer_threshold) {
195     parallel_optimizer_threshold_ = parallel_optimizer_threshold;
196   }
get_parallel_optimizer_threshold()197   int64_t get_parallel_optimizer_threshold() const { return parallel_optimizer_threshold_; }
198 
199   bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
communi_parallel_mode()200   std::string communi_parallel_mode() const { return communi_parallel_mode_; }
201   void set_enable_all2all(const bool enable);
enable_all2all()202   bool enable_all2all() const { return enable_all2all_; }
set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right)203   void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) {
204     dataset_repeat_dim_right_ = dataset_repeat_dim_right;
205   }
dataset_repeat_dim_right()206   bool dataset_repeat_dim_right() const { return dataset_repeat_dim_right_; }
207 
set_direct_split(const bool direct_split)208   void set_direct_split(const bool direct_split) { direct_split_ = direct_split; }
direct_split()209   bool direct_split() const { return direct_split_; }
210 
211   void Reset();
212   void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
213                                             const AbstractBasePtr &ptr) const;
214   void set_sharding_propagation(const bool stra_pto);
sharding_propagation()215   bool sharding_propagation() const { return sharding_propagation_; }
216 
217   void set_enable_micro_interleaved(const bool);
enable_micro_interleaved()218   bool enable_micro_interleaved() const { return enable_micro_interleaved_; }
219 
220   void set_enable_fine_grained_micro_interleaved(const bool);
enable_fine_grained_micro_interleaved()221   bool enable_fine_grained_micro_interleaved() const { return enable_fine_grained_micro_interleaved_; }
222 
set_fine_grained_micro_interleaved_size(const int64_t fine_grained_micro_interleaved_size)223   void set_fine_grained_micro_interleaved_size(const int64_t fine_grained_micro_interleaved_size) {
224     fine_grained_micro_interleaved_size_ = fine_grained_micro_interleaved_size;
225   }
fine_grained_micro_interleaved_size()226   int64_t fine_grained_micro_interleaved_size() const { return fine_grained_micro_interleaved_size_; }
227 
228   void set_pipeline_micro_size(const size_t);
pipeline_micro_size()229   size_t pipeline_micro_size() const { return pipeline_micro_size_; }
230 
231   void set_auto_pipeline(const bool);
auto_pipeline()232   bool auto_pipeline() const { return auto_pipeline_; }
233 
234   void set_do_transform(const bool);
do_transform()235   bool do_transform() const { return do_transform_; }
236 
237   void set_stra_file_only_trainable_params(const bool);
stra_file_only_trainable_params()238   bool stra_file_only_trainable_params() const { return stra_file_only_trainable_params_; }
239 
set_symbol_infos(const std::vector<symshape::SymbolInfoList> & symbol_infos)240   void set_symbol_infos(const std::vector<symshape::SymbolInfoList> &symbol_infos) { symbol_infos_ = symbol_infos; }
symbol_infos()241   const std::vector<symshape::SymbolInfoList> &symbol_infos() const { return symbol_infos_; }
242 
243  private:
244   ParallelContext();
245   bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const;
246 
247   bool gradients_mean_;
248   bool full_batch_;
249   bool full_batch_is_set_;
250   bool gradient_fp32_sync_;
251   bool loss_repeated_mean_;
252   int64_t device_num_;
253   int64_t dp_fusion_threshold_mb_;
254   int64_t fusion_threshold_mb_;
255   int64_t allgather_fusion_threshold_mb_;
256   int64_t reducescatter_fusion_threshold_mb_;  // reducescatter
257   int64_t global_rank_;
258   int64_t grad_accumulation_step_;
259   std::string parallel_mode_;
260   std::string strategy_search_mode_;
261   int64_t pipeline_stage_split_num_;
262   int64_t pipeline_segment_split_num_;
263   bool pipeline_interleave_;
264   std::string pipeline_scheduler_;
265   size_t pipeline_micro_size_;
266   bool auto_pipeline_;
267   bool parameter_broadcast_;
268   bool device_num_is_set_;
269   bool fusion_threshold_is_set_;
270   bool global_rank_is_set_;
271   bool parameter_broadcast_is_set_;
272   bool enable_all_reduce_fusion_;
273   bool enable_all_gather_fusion_;
274   bool enable_reduce_scatter_fusion_;
275 
276   std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
277   std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
278   std::string strategy_json_config_file_type_;
279   std::string strategy_json_config_file_path_;
280   std::string strategy_json_config_file_mode_;
281   std::string strategy_ckpt_load_file_;
282   std::string strategy_ckpt_save_file_;
283   std::string group_ckpt_save_file_;
284   bool enable_parallel_optimizer_;
285   bool enable_fold_pipeline_;
286   bool force_fp32_communication_;
287   std::string communi_parallel_mode_;
288   int64_t optimizer_weight_shard_size_;
289   bool optimizer_weight_shard_aggregated_save_;
290   bool grad_accumulation_shard_;
291   int64_t parallel_optimizer_threshold_;
292   // Enable AllToAll or not. If false, use AllGather and Split.
293   bool enable_all2all_;
294   std::vector<std::vector<int64_t>> dataset_strategy_;
295   bool dataset_repeat_dim_right_ = false;
296   bool hccl_test_available_ = false;
297   bool sharding_propagation_;
298   bool enable_micro_interleaved_ = false;
299   bool enable_fine_grained_micro_interleaved_ = false;
300   int64_t fine_grained_micro_interleaved_size_ = -1;
301   bool do_transform_ = false;
302   bool stra_file_only_trainable_params_ = true;
303   std::string fusion_mode_;
304   bool direct_split_ = false;
305   bool pipeline_result_broadcast_ = false;
306   std::vector<symshape::SymbolInfoList> symbol_infos_;
307 };
308 }  // namespace mindspore::parallel
309 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_
310