/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "abstract/abstract_value.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/status.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/convert_utils.h"
#include "utils/info.h"
#include "pipeline/jit/pipeline.h"

namespace mindspore {
namespace parallel {
constexpr char STAND_ALONE[] = "stand_alone";
constexpr char DATA_PARALLEL[] = "data_parallel";
constexpr char HYBRID_PARALLEL[] = "hybrid_parallel";
constexpr char AUTO_PARALLEL[] = "auto_parallel";
constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";

constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";

constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";

constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";

constexpr char IS_FIRST_ITERATION[] = "is_first_iteration";

// Process-wide singleton holding the parallel-training configuration: parallel mode,
// device number, global rank, pipeline stages, all-reduce fusion settings, strategy
// checkpoint files and related options. Obtain it through GetInstance().
class ParallelContext {
 public:
  ~ParallelContext() = default;
  ParallelContext(const ParallelContext &) = delete;
  ParallelContext &operator=(const ParallelContext &) = delete;

  static std::shared_ptr<ParallelContext> GetInstance();

  void set_gradients_mean(bool gradients_mean);
  bool gradients_mean() const { return gradients_mean_; }

  void set_full_batch(bool full_batch);
  bool full_batch() const { return full_batch_; }

  void set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy);
  std::vector<std::vector<int64_t>> dataset_strategy() const { return dataset_strategy_; }

  void set_gradient_fp32_sync(bool gradient_fp32_sync);
  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }

  void set_loss_repeated_mean(bool loss_repeated_mean);
  bool loss_repeated_mean() const { return loss_repeated_mean_; }

  void set_device_num(int64_t device_num);
  int64_t device_num() const { return device_num_; }

  void set_pipeline_stage_split_num(const int64_t stages);
  int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }

  void set_global_rank(int64_t global_rank);
  int64_t global_rank() const { return global_rank_; }

  void set_grad_accumulation_step(int64_t grad_accumulation_step);
  int64_t grad_accumulation_step() const { return grad_accumulation_step_; }
  bool set_parallel_mode(const std::string &parallel_mode);
  std::string parallel_mode() const { return parallel_mode_; }

  bool set_strategy_search_mode(const std::string &strategy_search_mode);
  std::string strategy_search_mode() const { return strategy_search_mode_; }

  void set_parameter_broadcast(bool parameter_broadcast);
  bool parameter_broadcast() const { return parameter_broadcast_; }

  bool device_num_is_set() const { return device_num_is_set_; }
  bool global_rank_is_set() const { return global_rank_is_set_; }
  bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }

  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
  void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save);
  bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; }

  void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
  std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
  void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
  std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
  void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
    enable_all_reduce_fusion_ = enable_all_reduce_fusion;
  }
  bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }

  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
  void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
  std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }

  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
    enable_parallel_optimizer_ = enable_parallel_optimizer;
  }
  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }

  bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
  std::string communi_parallel_mode() const { return communi_parallel_mode_; }
  void set_sharding_propagation(const bool);
  bool sharding_propagation() const { return sharding_propagation_; }
  void set_enable_all2all(const bool);
  bool enable_all2all() const { return enable_all2all_; }

  void Reset();
  void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
  void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                            const AbstractBasePtr &ptr);
  void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                         const AbstractBasePtr &ptr);

 private:
  ParallelContext();
  static std::shared_ptr<ParallelContext> inst_context_;
  bool gradients_mean_;
  bool full_batch_;
  bool gradient_fp32_sync_;
  bool loss_repeated_mean_;
  int64_t device_num_;
  int64_t global_rank_;
  int64_t grad_accumulation_step_;
  std::string parallel_mode_;
  std::string strategy_search_mode_;
  int64_t pipeline_stage_split_num_;
  bool parameter_broadcast_;
  bool device_num_is_set_;
  bool global_rank_is_set_;
  bool parameter_broadcast_is_set_;
  bool enable_all_reduce_fusion_;
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
  std::string strategy_ckpt_load_file_;
  std::string strategy_ckpt_save_file_;
  std::string group_ckpt_save_file_;
  bool enable_parallel_optimizer_;
  bool init_param_shape_;
  std::string communi_parallel_mode_;
  int64_t optimizer_weight_shard_size_;
  bool optimizer_weight_shard_aggregated_save_;
  // In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
  // will propagate the sharding strategies to other operators with minimum redistribution cost.
  bool sharding_propagation_;
  // Enable AllToAll or not. If false, use AllGather and Split.
  bool enable_all2all_;
  std::vector<std::vector<int64_t>> dataset_strategy_;
};

}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
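
// Usage sketch (illustrative only; not part of the original header). It shows how client code
// would typically obtain the singleton declared above and configure it before parallel
// compilation. The concrete values (8 devices, rank 0, 2 pipeline stages) are made-up examples,
// and the assumption that set_parallel_mode() returns false for an unrecognized mode string is
// inferred from its bool return type.
//
//   auto parallel_context = mindspore::parallel::ParallelContext::GetInstance();
//   parallel_context->set_device_num(8);
//   parallel_context->set_global_rank(0);
//   if (!parallel_context->set_parallel_mode(mindspore::parallel::SEMI_AUTO_PARALLEL)) {
//     // Presumably rejects anything outside STAND_ALONE / DATA_PARALLEL / HYBRID_PARALLEL /
//     // AUTO_PARALLEL / SEMI_AUTO_PARALLEL.
//     return;
//   }
//   parallel_context->set_gradients_mean(true);
//   parallel_context->set_pipeline_stage_split_num(2);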