1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "frontend/parallel/costmodel_context.h" 18 19 #include <memory> 20 21 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" 22 #include "utils/ms_context.h" 23 24 namespace mindspore { 25 namespace parallel { 26 std::shared_ptr<CostModelContext> CostModelContext::cm_context_inst_ = nullptr; 27 GetInstance()28std::shared_ptr<CostModelContext> CostModelContext::GetInstance() { 29 if (cm_context_inst_ == nullptr) { 30 MS_LOG(INFO) << "Create costmodel_context"; 31 cm_context_inst_.reset(new (std::nothrow) CostModelContext()); 32 } 33 return cm_context_inst_; 34 } 35 CostModelContext()36CostModelContext::CostModelContext() { 37 ResetCostModel(); 38 ResetAlgoParameters(); 39 } 40 ResetCostModel()41void CostModelContext::ResetCostModel() { 42 device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY; 43 costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; 44 costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; 45 costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA; 46 costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; 47 costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; 48 costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; 49 is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; 50 run_phase_ = TRAINING_PHASE; 51 costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; 52 costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; 53 costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; 54 costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME; 55 costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME; 56 costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH; 57 costmodel_allreduce_fusion_computation_time_parameter_ = 58 DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER; 59 dp_algo_single_loop_ = DEFAULT_DP_ALGO_SINGLE_LOOP; 60 } 61 ResetAlgoParameters()62void CostModelContext::ResetAlgoParameters() { 63 costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; 64 tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; 65 tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; 66 fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; 67 elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; 68 triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; 69 dp_algo_enable_approxi_ = DEFAULT_DP_ALGO_ENABLE_APPROX; 70 dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; 71 } 72 PrintCostModel()73void CostModelContext::PrintCostModel() { 74 MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << "."; 75 MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; 76 MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; 77 MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << "."; 78 MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << "."; 79 MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << "."; 80 MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << "."; 81 MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << "."; 82 MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << "."; 83 MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << "."; 84 MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << "."; 85 MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << "."; 86 MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << "."; 87 MS_LOG(INFO) << "run_phase: " << run_phase_ << "."; 88 MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << "."; 89 MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << "."; 90 MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << "."; 91 MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << "."; 92 } 93 set_costmodel_context_for_device(const std::string & device_target)94void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { 95 if (device_target == kGPUDevice) { 96 costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; 97 } 98 } 99 set_dp_algo_approxi_epsilon(double epsilon)100void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { 101 if (epsilon <= 0 || epsilon > 1) { 102 MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; 103 } 104 dp_algo_approxi_epsilon_ = epsilon; 105 } 106 set_dp_algo_enable_approxi(bool approxi)107void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { 108 if (approxi) { 109 MS_LOG(INFO) << "dp_algo_enable_approx: true."; 110 } else { 111 MS_LOG(INFO) << "dp_algo_enable_approx: false."; 112 } 113 dp_algo_enable_approxi_ = approxi; 114 } 115 set_device_memory_capacity(double dm_capacity)116void CostModelContext::set_device_memory_capacity(double dm_capacity) { 117 if (dm_capacity <= 0) { 118 MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; 119 } 120 device_memory_capacity_ = dm_capacity; 121 } 122 set_costmodel_alpha(double cm_alpha)123void CostModelContext::set_costmodel_alpha(double cm_alpha) { 124 if (cm_alpha <= 0) { 125 MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; 126 } 127 costmodel_alpha_ = cm_alpha; 128 } 129 set_costmodel_beta(double cm_beta)130void CostModelContext::set_costmodel_beta(double cm_beta) { 131 if (cm_beta <= 0) { 132 MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; 133 } 134 costmodel_beta_ = cm_beta; 135 } 136 set_costmodel_gamma(double cm_gamma)137void CostModelContext::set_costmodel_gamma(double cm_gamma) { 138 if ((cm_gamma < 0) || (cm_gamma > 1)) { 139 MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; 140 } 141 costmodel_gamma_ = cm_gamma; 142 } 143 set_costmodel_simplify_cal(bool cm_simplify)144void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { 145 if (cm_simplify) { 146 MS_LOG(INFO) << "costmodel_simplify_cal: true."; 147 } else { 148 MS_LOG(INFO) << "costmodel_simplify_cal: false."; 149 } 150 costmodel_simplify_cal_ = cm_simplify; 151 } 152 set_costmodel_communi_threshold(double cm_communi_th)153void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { 154 if (cm_communi_th < 0) { 155 MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; 156 } 157 costmodel_communi_threshold_ = cm_communi_th; 158 } 159 set_costmodel_communi_const(double cm_communi_const)160void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { 161 if (cm_communi_const < 0) { 162 MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; 163 } 164 costmodel_communi_const_ = cm_communi_const; 165 } 166 set_costmodel_communi_bias(double cm_communi_bias)167void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { 168 if (cm_communi_bias < 0) { 169 MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; 170 } 171 costmodel_communi_bias_ = cm_communi_bias; 172 } 173 set_multi_subgraphs(bool multi_graphs)174void CostModelContext::set_multi_subgraphs(bool multi_graphs) { 175 if (multi_graphs) { 176 MS_LOG(INFO) << "multi_subgraphs: true."; 177 } else { 178 MS_LOG(INFO) << "multi_subgraphs: false."; 179 } 180 is_multi_subgraphs_ = multi_graphs; 181 } set_costmodel_allreduce_fusion_algorithm(int64_t algorithm)182void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) { 183 costmodel_allreduce_fusion_algorithm_ = algorithm; 184 } 185 set_costmodel_allreduce_fusion_times(int64_t allreduce_fusion_times)186void CostModelContext::set_costmodel_allreduce_fusion_times(int64_t allreduce_fusion_times) { 187 costmodel_allreduce_fusion_times_ = allreduce_fusion_times; 188 } 189 set_costmodel_allreduce_fusion_tail_percent(double tail_percent)190void CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) { 191 costmodel_allreduce_fusion_tail_percent_ = tail_percent; 192 } 193 set_costmodel_allreduce_fusion_tail_time(double tail_time)194void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) { 195 costmodel_allreduce_fusion_tail_time_ = tail_time; 196 } 197 set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time)198void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) { 199 costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time; 200 } 201 set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth)202void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) { 203 costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth; 204 } 205 set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter)206void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) { 207 costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; 208 } 209 set_tensor_slice_alignment_enable(bool ts_align)210void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { 211 if (ts_align) { 212 MS_LOG(INFO) << "tensor_slice_align_enable: true."; 213 } else { 214 MS_LOG(INFO) << "tensor_slice_align_enable: false."; 215 } 216 tensor_slice_alignment_enable_ = ts_align; 217 } 218 set_tensor_slice_alignment_size(size_t ts_align_size)219void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { 220 if (ts_align_size == 0) { 221 MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; 222 } 223 tensor_slice_alignment_size_ = ts_align_size; 224 } 225 set_fully_use_device(bool fully_use)226void CostModelContext::set_fully_use_device(bool fully_use) { 227 if (fully_use) { 228 MS_LOG(INFO) << "fully_use_devices: true."; 229 } else { 230 MS_LOG(INFO) << "fully_use_devices: false."; 231 } 232 fully_use_device_ = fully_use; 233 } 234 set_elementwise_stra_follow(bool elementwise_follow)235void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { 236 if (elementwise_follow) { 237 MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; 238 } else { 239 MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; 240 } 241 elementwise_stra_follow_ = elementwise_follow; 242 } 243 set_triangle_star_strategy_overwrite(bool overwrite)244void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { 245 if (overwrite) { 246 MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; 247 } else { 248 MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; 249 } 250 triangle_star_strategy_overwrite_ = overwrite; 251 } 252 set_run_phase(int64_t phase)253void CostModelContext::set_run_phase(int64_t phase) { 254 if (phase != 0 && phase != 1) { 255 MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; 256 } 257 run_phase_ = phase; 258 } 259 set_dp_algo_single_loop(bool single_loop)260void CostModelContext::set_dp_algo_single_loop(bool single_loop) { 261 if (single_loop) { 262 MS_LOG(INFO) << "dp_algo_single_loop: true."; 263 } else { 264 MS_LOG(INFO) << "dp_algo_single_loop: false."; 265 } 266 dp_algo_single_loop_ = single_loop; 267 } 268 269 struct CostRegister { CostRegistermindspore::parallel::CostRegister270 CostRegister() { 271 MsContext::device_seter([](const std::string &device_target) { 272 CostModelContext::GetInstance()->set_costmodel_context_for_device(device_target); 273 }); 274 } 275 ~CostRegister() = default; 276 } cost_regsiter; 277 } // namespace parallel 278 } // namespace mindspore 279