• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/costmodel_context.h"
18 
19 #include <memory>
20 
21 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
22 #include "utils/ms_context.h"
23 
24 namespace mindspore {
25 namespace parallel {
26 std::shared_ptr<CostModelContext> CostModelContext::cm_context_inst_ = nullptr;
27 
GetInstance()28 std::shared_ptr<CostModelContext> CostModelContext::GetInstance() {
29   if (cm_context_inst_ == nullptr) {
30     MS_LOG(INFO) << "Create costmodel_context";
31     cm_context_inst_.reset(new (std::nothrow) CostModelContext());
32   }
33   return cm_context_inst_;
34 }
35 
CostModelContext()36 CostModelContext::CostModelContext() {
37   ResetCostModel();
38   ResetAlgoParameters();
39 }
40 
ResetCostModel()41 void CostModelContext::ResetCostModel() {
42   device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
43   costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
44   costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND;
45   costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA;
46   costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
47   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
48   costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
49   is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
50   run_phase_ = TRAINING_PHASE;
51   costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
52   costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
53   costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
54   costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME;
55   costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME;
56   costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH;
57   costmodel_allreduce_fusion_computation_time_parameter_ =
58     DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER;
59   dp_algo_single_loop_ = DEFAULT_DP_ALGO_SINGLE_LOOP;
60   rp_matmul_mem_coef_ = DEFAULT_RP_MATMUL_MEM_COEF;
61 }
62 
ResetAlgoParameters()63 void CostModelContext::ResetAlgoParameters() {
64   costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
65   tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
66   tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
67   fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
68   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
69   triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
70   dp_algo_enable_approxi_ = DEFAULT_DP_ALGO_ENABLE_APPROX;
71   dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON;
72 }
73 
PrintCostModel()74 void CostModelContext::PrintCostModel() {
75   MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << ".";
76   MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
77   MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
78   MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << ".";
79   MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << ".";
80   MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << ".";
81   MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << ".";
82   MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << ".";
83   MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << ".";
84   MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << ".";
85   MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << ".";
86   MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << ".";
87   MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << ".";
88   MS_LOG(INFO) << "run_phase: " << run_phase_ << ".";
89   MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << ".";
90   MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << ".";
91   MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << ".";
92   MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << ".";
93   MS_LOG(INFO) << "rp_matmul_mem_coef: " << rp_matmul_mem_coef_ << ".";
94 }
95 
set_costmodel_context_for_device(const std::string & device_target)96 void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
97   if (device_target == kGPUDevice) {
98     costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU;
99   }
100 }
101 
set_dp_algo_approxi_epsilon(double epsilon)102 void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) {
103   if (epsilon <= 0 || epsilon > 1) {
104     MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
105   }
106   dp_algo_approxi_epsilon_ = epsilon;
107 }
108 
set_rp_matmul_mem_coef(double coef)109 void CostModelContext::set_rp_matmul_mem_coef(double coef) {
110   if (coef <= 0) {
111     MS_LOG(EXCEPTION) << "'coef' must be positive";
112   }
113   rp_matmul_mem_coef_ = coef;
114 }
115 
set_dp_algo_enable_approxi(bool approxi)116 void CostModelContext::set_dp_algo_enable_approxi(bool approxi) {
117   if (approxi) {
118     MS_LOG(INFO) << "dp_algo_enable_approx: true.";
119   } else {
120     MS_LOG(INFO) << "dp_algo_enable_approx: false.";
121   }
122   dp_algo_enable_approxi_ = approxi;
123 }
124 
set_device_memory_capacity(double dm_capacity)125 void CostModelContext::set_device_memory_capacity(double dm_capacity) {
126   if (dm_capacity <= 0) {
127     MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
128   }
129   device_memory_capacity_ = dm_capacity;
130 }
131 
set_costmodel_alpha(double cm_alpha)132 void CostModelContext::set_costmodel_alpha(double cm_alpha) {
133   if (cm_alpha <= 0) {
134     MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
135   }
136   costmodel_alpha_ = cm_alpha;
137 }
138 
set_costmodel_beta(double cm_beta)139 void CostModelContext::set_costmodel_beta(double cm_beta) {
140   if (cm_beta <= 0) {
141     MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
142   }
143   costmodel_beta_ = cm_beta;
144 }
145 
set_costmodel_gamma(double cm_gamma)146 void CostModelContext::set_costmodel_gamma(double cm_gamma) {
147   if ((cm_gamma < 0) || (cm_gamma > 1)) {
148     MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
149   }
150   costmodel_gamma_ = cm_gamma;
151 }
152 
set_costmodel_simplify_cal(bool cm_simplify)153 void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) {
154   if (cm_simplify) {
155     MS_LOG(INFO) << "costmodel_simplify_cal: true.";
156   } else {
157     MS_LOG(INFO) << "costmodel_simplify_cal: false.";
158   }
159   costmodel_simplify_cal_ = cm_simplify;
160 }
161 
set_costmodel_communi_threshold(double cm_communi_th)162 void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) {
163   if (cm_communi_th < 0) {
164     MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
165   }
166   costmodel_communi_threshold_ = cm_communi_th;
167 }
168 
set_costmodel_communi_const(double cm_communi_const)169 void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
170   if (cm_communi_const < 0) {
171     MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
172   }
173   costmodel_communi_const_ = cm_communi_const;
174 }
175 
set_costmodel_communi_bias(double cm_communi_bias)176 void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) {
177   if (cm_communi_bias < 0) {
178     MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
179   }
180   costmodel_communi_bias_ = cm_communi_bias;
181 }
182 
set_multi_subgraphs(bool multi_graphs)183 void CostModelContext::set_multi_subgraphs(bool multi_graphs) {
184   if (multi_graphs) {
185     MS_LOG(INFO) << "multi_subgraphs: true.";
186   } else {
187     MS_LOG(INFO) << "multi_subgraphs: false.";
188   }
189   is_multi_subgraphs_ = multi_graphs;
190 }
set_costmodel_allreduce_fusion_algorithm(int64_t algorithm)191 void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) {
192   costmodel_allreduce_fusion_algorithm_ = algorithm;
193 }
194 
set_costmodel_allreduce_fusion_times(int64_t allreduce_fusion_times)195 void CostModelContext::set_costmodel_allreduce_fusion_times(int64_t allreduce_fusion_times) {
196   costmodel_allreduce_fusion_times_ = allreduce_fusion_times;
197 }
198 
set_costmodel_allreduce_fusion_tail_percent(double tail_percent)199 void CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) {
200   costmodel_allreduce_fusion_tail_percent_ = tail_percent;
201 }
202 
set_costmodel_allreduce_fusion_tail_time(double tail_time)203 void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) {
204   costmodel_allreduce_fusion_tail_time_ = tail_time;
205 }
206 
set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time)207 void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) {
208   costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time;
209 }
210 
set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth)211 void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) {
212   costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth;
213 }
214 
set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter)215 void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) {
216   costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter;
217 }
218 
set_tensor_slice_alignment_enable(bool ts_align)219 void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) {
220   if (ts_align) {
221     MS_LOG(INFO) << "tensor_slice_align_enable: true.";
222   } else {
223     MS_LOG(INFO) << "tensor_slice_align_enable: false.";
224   }
225   tensor_slice_alignment_enable_ = ts_align;
226 }
227 
set_tensor_slice_alignment_size(size_t ts_align_size)228 void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) {
229   if (ts_align_size == 0) {
230     MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
231   }
232   tensor_slice_alignment_size_ = ts_align_size;
233 }
234 
set_fully_use_device(bool fully_use)235 void CostModelContext::set_fully_use_device(bool fully_use) {
236   if (fully_use) {
237     MS_LOG(INFO) << "fully_use_devices: true.";
238   } else {
239     MS_LOG(INFO) << "fully_use_devices: false.";
240   }
241   fully_use_device_ = fully_use;
242 }
243 
set_elementwise_stra_follow(bool elementwise_follow)244 void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
245   if (elementwise_follow) {
246     MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
247   } else {
248     MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
249   }
250   elementwise_stra_follow_ = elementwise_follow;
251 }
252 
set_triangle_star_strategy_overwrite(bool overwrite)253 void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) {
254   if (overwrite) {
255     MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
256   } else {
257     MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
258   }
259   triangle_star_strategy_overwrite_ = overwrite;
260 }
261 
set_run_phase(int64_t phase)262 void CostModelContext::set_run_phase(int64_t phase) {
263   if (phase != 0 && phase != 1) {
264     MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
265   }
266   run_phase_ = phase;
267 }
268 
set_dp_algo_single_loop(bool single_loop)269 void CostModelContext::set_dp_algo_single_loop(bool single_loop) {
270   if (single_loop) {
271     MS_LOG(INFO) << "dp_algo_single_loop: true.";
272   } else {
273     MS_LOG(INFO) << "dp_algo_single_loop: false.";
274   }
275   dp_algo_single_loop_ = single_loop;
276 }
277 
278 struct CostRegister {
CostRegistermindspore::parallel::CostRegister279   CostRegister() {
280     MsContext::device_seter([](const std::string &device_target) {
281       CostModelContext::GetInstance()->set_costmodel_context_for_device(device_target);
282     });
283   }
284   ~CostRegister() = default;
285 } cost_regsiter;
286 }  // namespace parallel
287 }  // namespace mindspore
288