1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "frontend/parallel/costmodel_context.h" 18 19 #include <memory> 20 21 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" 22 #include "utils/ms_context.h" 23 24 namespace mindspore { 25 namespace parallel { 26 std::shared_ptr<CostModelContext> CostModelContext::cm_context_inst_ = nullptr; 27 GetInstance()28std::shared_ptr<CostModelContext> CostModelContext::GetInstance() { 29 if (cm_context_inst_ == nullptr) { 30 MS_LOG(INFO) << "Create costmodel_context"; 31 cm_context_inst_.reset(new (std::nothrow) CostModelContext()); 32 } 33 return cm_context_inst_; 34 } 35 CostModelContext()36CostModelContext::CostModelContext() { 37 ResetCostModel(); 38 ResetAlgoParameters(); 39 } 40 ResetCostModel()41void CostModelContext::ResetCostModel() { 42 device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY; 43 costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; 44 costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; 45 costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA; 46 costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; 47 costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; 48 costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; 49 is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; 50 run_phase_ = TRAINING_PHASE; 51 costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; 52 costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; 53 costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; 54 costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME; 55 costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME; 56 costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH; 57 costmodel_allreduce_fusion_computation_time_parameter_ = 58 DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER; 59 dp_algo_single_loop_ = DEFAULT_DP_ALGO_SINGLE_LOOP; 60 rp_matmul_mem_coef_ = DEFAULT_RP_MATMUL_MEM_COEF; 61 } 62 ResetAlgoParameters()63void CostModelContext::ResetAlgoParameters() { 64 costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; 65 tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; 66 tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; 67 fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; 68 elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; 69 triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; 70 dp_algo_enable_approxi_ = DEFAULT_DP_ALGO_ENABLE_APPROX; 71 dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; 72 } 73 PrintCostModel()74void CostModelContext::PrintCostModel() { 75 MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << "."; 76 MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; 77 MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; 78 MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << "."; 79 MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << "."; 80 MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << "."; 81 MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << "."; 82 MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << "."; 83 MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << "."; 84 MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << "."; 85 MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << "."; 86 MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << "."; 87 MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << "."; 88 MS_LOG(INFO) << "run_phase: " << run_phase_ << "."; 89 MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << "."; 90 MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << "."; 91 MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << "."; 92 MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << "."; 93 MS_LOG(INFO) << "rp_matmul_mem_coef: " << rp_matmul_mem_coef_ << "."; 94 } 95 set_costmodel_context_for_device(const std::string & device_target)96void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { 97 if (device_target == kGPUDevice) { 98 costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; 99 } 100 } 101 set_dp_algo_approxi_epsilon(double epsilon)102void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { 103 if (epsilon <= 0 || epsilon > 1) { 104 MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; 105 } 106 dp_algo_approxi_epsilon_ = epsilon; 107 } 108 set_rp_matmul_mem_coef(double coef)109void CostModelContext::set_rp_matmul_mem_coef(double coef) { 110 if (coef <= 0) { 111 MS_LOG(EXCEPTION) << "'coef' must be positive"; 112 } 113 rp_matmul_mem_coef_ = coef; 114 } 115 set_dp_algo_enable_approxi(bool approxi)116void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { 117 if (approxi) { 118 MS_LOG(INFO) << "dp_algo_enable_approx: true."; 119 } else { 120 MS_LOG(INFO) << "dp_algo_enable_approx: false."; 121 } 122 dp_algo_enable_approxi_ = approxi; 123 } 124 set_device_memory_capacity(double dm_capacity)125void CostModelContext::set_device_memory_capacity(double dm_capacity) { 126 if (dm_capacity <= 0) { 127 MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; 128 } 129 device_memory_capacity_ = dm_capacity; 130 } 131 set_costmodel_alpha(double cm_alpha)132void CostModelContext::set_costmodel_alpha(double cm_alpha) { 133 if (cm_alpha <= 0) { 134 MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; 135 } 136 costmodel_alpha_ = cm_alpha; 137 } 138 set_costmodel_beta(double cm_beta)139void CostModelContext::set_costmodel_beta(double cm_beta) { 140 if (cm_beta <= 0) { 141 MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; 142 } 143 costmodel_beta_ = cm_beta; 144 } 145 set_costmodel_gamma(double cm_gamma)146void CostModelContext::set_costmodel_gamma(double cm_gamma) { 147 if ((cm_gamma < 0) || (cm_gamma > 1)) { 148 MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; 149 } 150 costmodel_gamma_ = cm_gamma; 151 } 152 set_costmodel_simplify_cal(bool cm_simplify)153void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { 154 if (cm_simplify) { 155 MS_LOG(INFO) << "costmodel_simplify_cal: true."; 156 } else { 157 MS_LOG(INFO) << "costmodel_simplify_cal: false."; 158 } 159 costmodel_simplify_cal_ = cm_simplify; 160 } 161 set_costmodel_communi_threshold(double cm_communi_th)162void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { 163 if (cm_communi_th < 0) { 164 MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; 165 } 166 costmodel_communi_threshold_ = cm_communi_th; 167 } 168 set_costmodel_communi_const(double cm_communi_const)169void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { 170 if (cm_communi_const < 0) { 171 MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; 172 } 173 costmodel_communi_const_ = cm_communi_const; 174 } 175 set_costmodel_communi_bias(double cm_communi_bias)176void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { 177 if (cm_communi_bias < 0) { 178 MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; 179 } 180 costmodel_communi_bias_ = cm_communi_bias; 181 } 182 set_multi_subgraphs(bool multi_graphs)183void CostModelContext::set_multi_subgraphs(bool multi_graphs) { 184 if (multi_graphs) { 185 MS_LOG(INFO) << "multi_subgraphs: true."; 186 } else { 187 MS_LOG(INFO) << "multi_subgraphs: false."; 188 } 189 is_multi_subgraphs_ = multi_graphs; 190 } set_costmodel_allreduce_fusion_algorithm(int64_t algorithm)191void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) { 192 costmodel_allreduce_fusion_algorithm_ = algorithm; 193 } 194 set_costmodel_allreduce_fusion_times(int64_t allreduce_fusion_times)195void CostModelContext::set_costmodel_allreduce_fusion_times(int64_t allreduce_fusion_times) { 196 costmodel_allreduce_fusion_times_ = allreduce_fusion_times; 197 } 198 set_costmodel_allreduce_fusion_tail_percent(double tail_percent)199void CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) { 200 costmodel_allreduce_fusion_tail_percent_ = tail_percent; 201 } 202 set_costmodel_allreduce_fusion_tail_time(double tail_time)203void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) { 204 costmodel_allreduce_fusion_tail_time_ = tail_time; 205 } 206 set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time)207void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) { 208 costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time; 209 } 210 set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth)211void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) { 212 costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth; 213 } 214 set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter)215void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) { 216 costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; 217 } 218 set_tensor_slice_alignment_enable(bool ts_align)219void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { 220 if (ts_align) { 221 MS_LOG(INFO) << "tensor_slice_align_enable: true."; 222 } else { 223 MS_LOG(INFO) << "tensor_slice_align_enable: false."; 224 } 225 tensor_slice_alignment_enable_ = ts_align; 226 } 227 set_tensor_slice_alignment_size(size_t ts_align_size)228void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { 229 if (ts_align_size == 0) { 230 MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; 231 } 232 tensor_slice_alignment_size_ = ts_align_size; 233 } 234 set_fully_use_device(bool fully_use)235void CostModelContext::set_fully_use_device(bool fully_use) { 236 if (fully_use) { 237 MS_LOG(INFO) << "fully_use_devices: true."; 238 } else { 239 MS_LOG(INFO) << "fully_use_devices: false."; 240 } 241 fully_use_device_ = fully_use; 242 } 243 set_elementwise_stra_follow(bool elementwise_follow)244void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { 245 if (elementwise_follow) { 246 MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; 247 } else { 248 MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; 249 } 250 elementwise_stra_follow_ = elementwise_follow; 251 } 252 set_triangle_star_strategy_overwrite(bool overwrite)253void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { 254 if (overwrite) { 255 MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; 256 } else { 257 MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; 258 } 259 triangle_star_strategy_overwrite_ = overwrite; 260 } 261 set_run_phase(int64_t phase)262void CostModelContext::set_run_phase(int64_t phase) { 263 if (phase != 0 && phase != 1) { 264 MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; 265 } 266 run_phase_ = phase; 267 } 268 set_dp_algo_single_loop(bool single_loop)269void CostModelContext::set_dp_algo_single_loop(bool single_loop) { 270 if (single_loop) { 271 MS_LOG(INFO) << "dp_algo_single_loop: true."; 272 } else { 273 MS_LOG(INFO) << "dp_algo_single_loop: false."; 274 } 275 dp_algo_single_loop_ = single_loop; 276 } 277 278 struct CostRegister { CostRegistermindspore::parallel::CostRegister279 CostRegister() { 280 MsContext::device_seter([](const std::string &device_target) { 281 CostModelContext::GetInstance()->set_costmodel_context_for_device(device_target); 282 }); 283 } 284 ~CostRegister() = default; 285 } cost_regsiter; 286 } // namespace parallel 287 } // namespace mindspore 288