1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_ 18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_ 19 20 #include <cstdint> 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <vector> 25 #include "abstract/abstract_value.h" 26 #include "ir/anf.h" 27 #include "ir/func_graph.h" 28 #include "include/common/utils/convert_utils.h" 29 #include "utils/info.h" 30 #include "include/common/visible.h" 31 #include "mindspore/core/symbolic_shape/symbol_info.h" 32 33 namespace mindspore::parallel { 34 constexpr char kStandalone[] = "stand_alone"; 35 constexpr char kDataParallel[] = "data_parallel"; 36 constexpr char kHybridParallel[] = "hybrid_parallel"; 37 constexpr char kAutoParallel[] = "auto_parallel"; 38 constexpr char kSemiAutoParallel[] = "semi_auto_parallel"; 39 40 constexpr char kDynamicProgramming[] = "dynamic_programming"; 41 constexpr char kRecursiveProgramming[] = "recursive_programming"; 42 constexpr char kShardingPropagation[] = "sharding_propagation"; 43 44 constexpr char kAccumulation[] = "accumulation"; 45 46 constexpr char kAllGroupParallel[] = "all_group_parallel"; 47 constexpr char kSameServerGroupParallel[] = "same_server_group_parallel"; 48 constexpr char kNoGroupParallel[] = "no_group_parallel"; 49 50 constexpr char kHasShard[] = "has_shard"; 51 constexpr char kSharded[] = "sharded"; 52 constexpr char kSkipAutoParallelCompile[] = "skip_auto_parallel_compile"; 53 constexpr char kKeepInputUnchanged[] = "keep_input_unchanged"; 54 55 constexpr char kPipeline1F1B[] = "1f1b"; 56 constexpr char kPipelineGpipe[] = "gpipe"; 57 58 constexpr char kFusionAuto[] = "auto"; 59 constexpr char kFusionSize[] = "size"; 60 constexpr char kFusionIndex[] = "index"; 61 constexpr int64_t kFusionThreshold = 64; 62 constexpr int64_t kDataParallelFusionThreshold = -1; 63 constexpr char kRelatedFusionKey[] = "related_fusion_key"; 64 constexpr char kRelatedNodeId[] = "related_node_id"; 65 constexpr char FIRST_RECEIVE[] = "first_receive"; 66 constexpr char kRelatedCommNodeId[] = "related_comm_node_id"; 67 68 class COMMON_EXPORT ParallelContext { 69 public: 70 static std::shared_ptr<ParallelContext> GetInstance(); 71 ~ParallelContext() = default; 72 ParallelContext(const ParallelContext &) = delete; 73 ParallelContext &operator=(const ParallelContext &) = delete; 74 75 void set_gradients_mean(bool gradients_mean); gradients_mean()76 bool gradients_mean() const { return gradients_mean_; } 77 78 void set_full_batch(bool full_batch); full_batch()79 bool full_batch() const { return full_batch_; } 80 81 void set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy); dataset_strategy()82 std::vector<std::vector<int64_t>> dataset_strategy() const { return dataset_strategy_; } 83 84 void set_gradient_fp32_sync(bool gradient_fp32_sync); gradient_fp32_sync()85 bool gradient_fp32_sync() const { return gradient_fp32_sync_; } 86 87 void set_loss_repeated_mean(bool loss_repeated_mean); loss_repeated_mean()88 bool loss_repeated_mean() const { return loss_repeated_mean_; } 89 90 void set_device_num(int64_t device_num); device_num()91 int64_t device_num() const { return device_num_; } 92 93 void set_fusion_threshold_mb(int64_t fusion_threshold); fusion_threshold_mb()94 int64_t fusion_threshold_mb() const { return fusion_threshold_mb_; } 95 dp_fusion_threshold_mb()96 int64_t dp_fusion_threshold_mb() const { return dp_fusion_threshold_mb_; } 97 98 void set_allgather_fusion_threshold_mb(int64_t fusion_threshold); allgather_fusion_threshold_mb()99 int64_t allgather_fusion_threshold_mb() const { return allgather_fusion_threshold_mb_; } 100 101 void set_reducescatter_fusion_threshold_mb(int64_t fusion_threshold); reducescatter_fusion_threshold_mb()102 int64_t reducescatter_fusion_threshold_mb() const { return reducescatter_fusion_threshold_mb_; } 103 104 bool set_fusion_mode(const std::string &fusion_mode); get_fusion_mode()105 std::string get_fusion_mode() const { return fusion_mode_; } 106 107 void set_pipeline_stage_split_num(const int64_t stage_num); pipeline_stage_split_num()108 int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; } 109 set_pipeline_result_broadcast(const bool flag)110 void set_pipeline_result_broadcast(const bool flag) { pipeline_result_broadcast_ = flag; } pipeline_result_broadcast()111 bool pipeline_result_broadcast() const { return pipeline_result_broadcast_; } 112 void set_pipeline_interleave(const bool pipeline_interleave); pipeline_interleave()113 bool pipeline_interleave() const { return pipeline_interleave_; } 114 115 void set_pipeline_scheduler(const std::string &pipeline_scheduler); pipeline_scheduler()116 std::string pipeline_scheduler() const { return pipeline_scheduler_; } 117 118 void set_global_rank(int64_t global_rank); global_rank()119 int64_t global_rank() const { return global_rank_; } 120 121 void set_grad_accumulation_step(int64_t grad_accumulation_step); grad_accumulation_step()122 int64_t grad_accumulation_step() const { return grad_accumulation_step_; } 123 124 bool set_parallel_mode(const std::string ¶llel_mode); parallel_mode()125 std::string parallel_mode() const { return parallel_mode_; } 126 127 bool set_strategy_search_mode(const std::string &strategy_search_mode); strategy_search_mode()128 std::string strategy_search_mode() const { return strategy_search_mode_; } 129 130 void set_parameter_broadcast(bool parameter_broadcast); parameter_broadcast()131 bool parameter_broadcast() const { return parameter_broadcast_; } 132 device_num_is_set()133 bool device_num_is_set() const { return device_num_is_set_; } global_rank_is_set()134 bool global_rank_is_set() const { return global_rank_is_set_; } parameter_broadcast_is_set()135 bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; } full_batch_is_set()136 bool full_batch_is_set() const { return full_batch_is_set_; } 137 138 void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size); optimizer_weight_shard_size()139 int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; } 140 void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save); optimizer_weight_shard_aggregated_save()141 bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; } 142 143 void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group); 144 std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const; 145 void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group); 146 std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const; set_enable_all_reduce_fusion(bool enable_all_reduce_fusion)147 void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { 148 enable_all_reduce_fusion_ = enable_all_reduce_fusion; 149 } enable_all_reduce_fusion()150 bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; } set_enable_all_gather_fusion(bool enable_all_gather_fusion)151 void set_enable_all_gather_fusion(bool enable_all_gather_fusion) { 152 enable_all_gather_fusion_ = enable_all_gather_fusion; 153 } enable_all_gather_fusion()154 bool enable_all_gather_fusion() const { return enable_all_gather_fusion_; } 155 set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion)156 void set_enable_reduce_scatter_fusion(bool enable_reduce_scatter_fusion) { 157 enable_reduce_scatter_fusion_ = enable_reduce_scatter_fusion; 158 } enable_reduce_scatter_fusion()159 bool enable_reduce_scatter_fusion() const { return enable_reduce_scatter_fusion_; } 160 161 void set_ops_strategy_json_config(const std::string &type, const std::string &path, const std::string &mode); strategy_json_config_file_type()162 std::string strategy_json_config_file_type() const { return strategy_json_config_file_type_; } strategy_json_config_file_path()163 std::string strategy_json_config_file_path() const { return strategy_json_config_file_path_; } strategy_json_config_file_mode()164 std::string strategy_json_config_file_mode() const { return strategy_json_config_file_mode_; } 165 166 void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file); strategy_ckpt_load_file()167 std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; } 168 void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); strategy_ckpt_save_file()169 std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } 170 void set_group_ckpt_save_file(const std::string &group_ckpt_save_file); group_ckpt_save_file()171 std::string group_ckpt_save_file() const { return group_ckpt_save_file_; } 172 set_enable_parallel_optimizer(bool enable_parallel_optimizer)173 void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { 174 enable_parallel_optimizer_ = enable_parallel_optimizer; 175 } enable_parallel_optimizer()176 bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } 177 set_force_fp32_communication(bool force_fp32_communication)178 void set_force_fp32_communication(bool force_fp32_communication) { 179 force_fp32_communication_ = force_fp32_communication; 180 } force_fp32_communication()181 bool force_fp32_communication() const { return force_fp32_communication_; } 182 enable_fold_pipeline()183 bool enable_fold_pipeline() const { return pipeline_segment_split_num_ > 1; } 184 185 void set_pipeline_segment_split_num(const int64_t segments); pipeline_segment_split_num()186 int64_t pipeline_segment_split_num() const { return pipeline_segment_split_num_; } 187 set_hccl_test_available(bool hccl_test_available)188 void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; } hccl_test_available()189 bool hccl_test_available() const { return hccl_test_available_; } set_grad_accumulation_shard(const bool grad_accumulation_shard)190 void set_grad_accumulation_shard(const bool grad_accumulation_shard) { 191 grad_accumulation_shard_ = grad_accumulation_shard; 192 } grad_accumulation_shard()193 bool grad_accumulation_shard() const { return grad_accumulation_shard_; } set_parallel_optimizer_threshold(const int64_t parallel_optimizer_threshold)194 void set_parallel_optimizer_threshold(const int64_t parallel_optimizer_threshold) { 195 parallel_optimizer_threshold_ = parallel_optimizer_threshold; 196 } get_parallel_optimizer_threshold()197 int64_t get_parallel_optimizer_threshold() const { return parallel_optimizer_threshold_; } 198 199 bool set_communi_parallel_mode(const std::string &communi_parallel_mode); communi_parallel_mode()200 std::string communi_parallel_mode() const { return communi_parallel_mode_; } 201 void set_enable_all2all(const bool enable); enable_all2all()202 bool enable_all2all() const { return enable_all2all_; } set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right)203 void set_dataset_repeat_dim_right(const bool dataset_repeat_dim_right) { 204 dataset_repeat_dim_right_ = dataset_repeat_dim_right; 205 } dataset_repeat_dim_right()206 bool dataset_repeat_dim_right() const { return dataset_repeat_dim_right_; } 207 set_direct_split(const bool direct_split)208 void set_direct_split(const bool direct_split) { direct_split_ = direct_split; } direct_split()209 bool direct_split() const { return direct_split_; } 210 211 void Reset(); 212 void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, 213 const AbstractBasePtr &ptr) const; 214 void set_sharding_propagation(const bool stra_pto); sharding_propagation()215 bool sharding_propagation() const { return sharding_propagation_; } 216 217 void set_enable_micro_interleaved(const bool); enable_micro_interleaved()218 bool enable_micro_interleaved() const { return enable_micro_interleaved_; } 219 220 void set_enable_fine_grained_micro_interleaved(const bool); enable_fine_grained_micro_interleaved()221 bool enable_fine_grained_micro_interleaved() const { return enable_fine_grained_micro_interleaved_; } 222 set_fine_grained_micro_interleaved_size(const int64_t fine_grained_micro_interleaved_size)223 void set_fine_grained_micro_interleaved_size(const int64_t fine_grained_micro_interleaved_size) { 224 fine_grained_micro_interleaved_size_ = fine_grained_micro_interleaved_size; 225 } fine_grained_micro_interleaved_size()226 int64_t fine_grained_micro_interleaved_size() const { return fine_grained_micro_interleaved_size_; } 227 228 void set_pipeline_micro_size(const size_t); pipeline_micro_size()229 size_t pipeline_micro_size() const { return pipeline_micro_size_; } 230 231 void set_auto_pipeline(const bool); auto_pipeline()232 bool auto_pipeline() const { return auto_pipeline_; } 233 234 void set_do_transform(const bool); do_transform()235 bool do_transform() const { return do_transform_; } 236 237 void set_stra_file_only_trainable_params(const bool); stra_file_only_trainable_params()238 bool stra_file_only_trainable_params() const { return stra_file_only_trainable_params_; } 239 set_symbol_infos(const std::vector<symshape::SymbolInfoList> & symbol_infos)240 void set_symbol_infos(const std::vector<symshape::SymbolInfoList> &symbol_infos) { symbol_infos_ = symbol_infos; } symbol_infos()241 const std::vector<symshape::SymbolInfoList> &symbol_infos() const { return symbol_infos_; } 242 243 private: 244 ParallelContext(); 245 bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const; 246 247 bool gradients_mean_; 248 bool full_batch_; 249 bool full_batch_is_set_; 250 bool gradient_fp32_sync_; 251 bool loss_repeated_mean_; 252 int64_t device_num_; 253 int64_t dp_fusion_threshold_mb_; 254 int64_t fusion_threshold_mb_; 255 int64_t allgather_fusion_threshold_mb_; 256 int64_t reducescatter_fusion_threshold_mb_; // reducescatter 257 int64_t global_rank_; 258 int64_t grad_accumulation_step_; 259 std::string parallel_mode_; 260 std::string strategy_search_mode_; 261 int64_t pipeline_stage_split_num_; 262 int64_t pipeline_segment_split_num_; 263 bool pipeline_interleave_; 264 std::string pipeline_scheduler_; 265 size_t pipeline_micro_size_; 266 bool auto_pipeline_; 267 bool parameter_broadcast_; 268 bool device_num_is_set_; 269 bool fusion_threshold_is_set_; 270 bool global_rank_is_set_; 271 bool parameter_broadcast_is_set_; 272 bool enable_all_reduce_fusion_; 273 bool enable_all_gather_fusion_; 274 bool enable_reduce_scatter_fusion_; 275 276 std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_; 277 std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_; 278 std::string strategy_json_config_file_type_; 279 std::string strategy_json_config_file_path_; 280 std::string strategy_json_config_file_mode_; 281 std::string strategy_ckpt_load_file_; 282 std::string strategy_ckpt_save_file_; 283 std::string group_ckpt_save_file_; 284 bool enable_parallel_optimizer_; 285 bool enable_fold_pipeline_; 286 bool force_fp32_communication_; 287 std::string communi_parallel_mode_; 288 int64_t optimizer_weight_shard_size_; 289 bool optimizer_weight_shard_aggregated_save_; 290 bool grad_accumulation_shard_; 291 int64_t parallel_optimizer_threshold_; 292 // Enable AllToAll or not. If false, use AllGather and Split. 293 bool enable_all2all_; 294 std::vector<std::vector<int64_t>> dataset_strategy_; 295 bool dataset_repeat_dim_right_ = false; 296 bool hccl_test_available_ = false; 297 bool sharding_propagation_; 298 bool enable_micro_interleaved_ = false; 299 bool enable_fine_grained_micro_interleaved_ = false; 300 int64_t fine_grained_micro_interleaved_size_ = -1; 301 bool do_transform_ = false; 302 bool stra_file_only_trainable_params_ = true; 303 std::string fusion_mode_; 304 bool direct_split_ = false; 305 bool pipeline_result_broadcast_ = false; 306 std::vector<symshape::SymbolInfoList> symbol_infos_; 307 }; 308 } // namespace mindspore::parallel 309 #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PARALLEL_CONTEXT_H_ 310