/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/operator_costmodel.h"

#include <algorithm>
#include <random>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_parameter_ = is_parameter; }

void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
  is_parameter_involve_ = is_parameter_inv;
  is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
}

void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }

void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                               const std::vector<size_t> &output_lengths) {
  inputs_type_lengths_ = input_lengths;
  outputs_type_lengths_ = output_lengths;
}

void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_ = critical; }

double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
                                   const std::vector<TensorInfo> &outputs) const {
  return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
}

double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
  double result = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_inputs_should_in_memory_[i]) {
      result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
    }
  }
  return result;
}

double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &,
                                         const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_output_should_in_memory_) {
    // When this operator has multiple outputs, they all contribute to the memory.
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}

double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
                                               const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_outputs_critical_ == -1) {
    MS_LOG(EXCEPTION) << "The critical flag is not set.";
  }
  if (is_outputs_critical_ == 1) {
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}

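// Illustrative sketch (hypothetical numbers, not part of the upstream file): if an operator has one input whose
// slice shape is [32, 64] with a 4-byte element type and is_inputs_should_in_memory_[0] is true, while the output
// need not stay in memory, then GetInputMemoryCost returns 32 * 64 * 4 = 8192 bytes, GetOutputMemoryCost returns
// 0, and GetMemoryCost reports 8192.
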
// return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                      int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = input0.slice_shape();
  if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has not been partitioned, then there is no communication cost.
    return 0.0;
  } else {
    // Else, the communication cost is the size (number of bytes) of a slice of output tensor.
    return ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
  }
}

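// Illustrative sketch (hypothetical shapes, not part of the upstream file): for C = A x B with A of shape
// [128, 256] and the reduced dimension (256) split across 4 devices, the slice shape of A is [128, 64]; the last
// dimensions differ, so an AllReduce of slice(C) is needed. With slice(C) = [128, 512] and 4-byte elements, the
// forward communication cost is 128 * 512 * 4 = 262144 bytes per device. If the reduced dimension is not split,
// the cost is 0.
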
// return the per device communication cost in the backward phase.
double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                       int64_t stage_id) const {
  // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not
  // fully utilize all devices
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

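// Illustrative sketch (hypothetical shapes, not part of the upstream file): if parameter B has shape [256, 512]
// and slice shape [256, 128], then used_device_num = (256 / 256) * (512 / 128) = 4. With 8 devices in the stage,
// B does not fully utilize all devices, so the gradient of B must be reduced and a backward communication cost of
// 256 * 128 * type_length is charged; with exactly 4 devices in the stage, the cost stays 0.
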
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                             const std::vector<TensorInfo> &outputs, int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
  double result = 0.0;
  TensorInfo output0 = outputs[0];
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  Shape input0_shape = inputs[0].shape();
  if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has been partitioned, an AllReduce of slice(C) is needed, so its size is added.
    result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
  }
  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
            ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Not taking account of output
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

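// Illustrative sketch of the rule above (hypothetical values, not part of the upstream file): suppose input 0 of
// MatMul is a Parameter and the operator producing input 1 reports prev_output_in_mem = {{1, false}}. Computing
// the gradient of input 0 requires input 1, and input 1 is not already kept in memory upstream, so both
// is_inputs_should_in_memory_[0] and is_inputs_should_in_memory_[1] become true.
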
// return the per device communication cost in the forward phase.
double BatchNormCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                         int64_t) const {
  TensorInfo input0 = inputs[0];
  Shape input0_shape = input0.shape();
  if (input0_shape.size() < 2) {
    MS_LOG(EXCEPTION) << "The dimension of first input can not be smaller than 2, but got " << input0_shape.size();
  }
  Shape input0_slice_shape = input0.slice_shape();
  if (input0_shape[1] == input0_slice_shape[1]) {
    // If the 'channel' dimension has not been partitioned, then there is no communication cost.
    return 0.0;
  } else {
    // Else, the communication cost is the size (number of bytes) of a slice of input tensor.
    return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
  }
}

// return the per device communication cost in the backward phase.
double BatchNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
  // In backward phase, the communication cost is incurred only when the last 4 tensors are Parameters and they do
  // not fully utilize all devices
  double result = 0.0, tmp_cost = 0.0;

  TensorInfo input1 = inputs[1];  // tensor gamma
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  Shape input1_shape = input1.shape();
  Shape input1_slice_shape = input1.slice_shape();
  int64_t used_device_num = 1;
  for (size_t i = 0; i < input1_shape.size(); ++i) {
    used_device_num *= input1_shape[i] / input1_slice_shape[i];
  }

  if (total_device_num != LongToSize(used_device_num)) {
    tmp_cost = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  }

  for (size_t i = 1; i < is_parameter_.size(); ++i) {
    if (is_parameter_[i]) {
      result += tmp_cost;
    }
  }
  return result;
}

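// Illustrative sketch (hypothetical values, not part of the upstream file): if gamma has shape [64] and slice
// shape [32], then used_device_num = 2; with 8 devices in the stage, tmp_cost = 32 * type_length. If all four
// trailing inputs (e.g. gamma, beta, moving mean and moving variance) are Parameters, the loop adds tmp_cost once
// per parameter, giving a backward communication cost of 4 * 32 * type_length.
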
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double BatchNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
  }
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double BatchNormCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                                 int64_t) const {
  return 0.0;
}

// Taking account of output
void BatchNormCost::CalculateOutputInMemory() { is_output_should_in_memory_ = true; }

// Taking account of input
void BatchNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
    is_inputs_should_in_memory_[i] = true;
  }
}

// Return the per device communication cost in the forward phase.
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // Cast is an element-wise operator, thus it does not need communication in the forward phase
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                     int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t) const {
  TensorInfo input0 = inputs[0];
  Shape input0_slice_shape = input0.slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                            int64_t) const {
  return 0.0;
}

// Not taking account of output
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Taking account of output
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Taking account of input
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Return the per device communication cost in the forward phase.
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
  // In the forward phase, the communication cost = 0
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
                                              int64_t) const {
  if (outputs.empty() || outputs_type_lengths_.empty()) {
    MS_LOG(EXCEPTION) << "The outputs or outputs_type_length is empty";
  }

  // use output for Tile operator
  TensorInfo output_info = outputs[0];
  Shape output_slice_shape = output_info.slice_shape();
  return ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

// Taking account of output
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Not taking account of input
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Not taking account of output
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Not taking account of output
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

// Not taking account of output
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Taking account of input
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

// return the per device communication cost in the forward phase.
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                           const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  // Identity is the element-wise operator, thus it does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                            const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  // Identity is the element-wise operator, thus it does not need communication in the backward phase
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                  const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                   const std::vector<mindspore::parallel::TensorInfo> &,
                                                   int64_t) const {
  return 0.0;
}

// Not taking account of output
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                                    const std::vector<mindspore::parallel::TensorInfo> &,
                                                    int64_t) const {
  double cost = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    cost += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
  }
  return cost;
}

double BatchParallelCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                     const std::vector<mindspore::parallel::TensorInfo> &,
                                                     int64_t) const {
  return 0.0;
}

double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  for (size_t j = 0; j < inputs.size(); ++j) {
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
    }
  }

  return result;
}

void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}

void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // prelu does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                      int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                            int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                             const std::vector<mindspore::parallel::TensorInfo> &,
                                             int64_t stage_id) const {
  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y';
  // when calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

// return the per device communication cost in the forward phase.
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // onehot does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
  // onehot does not need communication in the backward phase
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                             int64_t) const {
  // In onehot's forward phase, the computation cost = slice(A)
  Shape input0_slice_shape = inputs[0].slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                              int64_t) const {
  return 0.0;
}

// Not taking account of output
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
  is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 2] = is_parameter_[ONEHOT_INPUTS_SIZE - 2];
  is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 1] = is_parameter_[ONEHOT_INPUTS_SIZE - 1];
}

// return the per device communication cost in the forward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
                                                             const std::vector<TensorInfo> &, int64_t) const {
  // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<TensorInfo> &,
                                                              const std::vector<TensorInfo> &, int64_t) const {
  // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                                    const std::vector<TensorInfo> &, int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
                                                                     const std::vector<TensorInfo> &, int64_t) const {
  return 0.0;
}

// Taking account of output
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}

void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// return the per device communication cost in the forward phase.
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                       int64_t stage_id) const {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  TensorRedistribution tensor_redistribution(false, true);
  if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }
  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }
  return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost());
}

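// Illustrative note (not part of the upstream file): the forward communication cost of Reshape is delegated to
// TensorRedistribution. For example, redistributing a tensor from a layout sliced across 4 devices to a fully
// replicated layout involves an AllGather, and tensor_redistribution.comm_cost() accounts for that collective in
// element counts; the result is then scaled by the input element size inputs_type_lengths_[0].
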
// return the per device communication cost in the backward phase.
double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                              const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  TensorRedistribution tensor_redistribution(false, true);
  if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }
  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }
  return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t) const {
  double result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  if (is_parameter_[0]) {
    TensorInfo input_a_tensor_info = inputs[0];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  if (is_parameter_[1]) {
    TensorInfo input_b_tensor_info = inputs[1];
    Shape input_b_shape = input_b_tensor_info.shape();
    Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_b_shape.size(); ++i) {
      used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                    int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  if (is_parameter_[0]) {
    TensorInfo input_a_tensor_info = inputs[0];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  if (is_parameter_[1]) {
    TensorInfo input_b_tensor_info = inputs[1];
    Shape input_b_shape = input_b_tensor_info.shape();
    Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_b_shape.size(); ++i) {
      used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Not taking account of output
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// Taking account of input
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    // 'x' is parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not kept in memory by the previous operator, then 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Taking account of output
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

// Taking account of input
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not kept in memory by the previous operator, then 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  // When calculating 'dy', taking account of 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

// Taking account of input
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  // When calculating 'dy', taking account of 'x' and 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'power'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dpower', taking account of 'x'
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
  // When calculating 'dy', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dy', not taking account of 'x' and 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
  // 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

1038 void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
1039 
1040 void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1041   // When calculating 'dx', taking account of 'y'
1042   if (is_parameter_[0]) {
1043     // 'x' is parameter, so it should be in memory.
1044     is_inputs_should_in_memory_[0] = true;
1045     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1046       // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
1047       is_inputs_should_in_memory_[1] = true;
1048     }
1049   } else if (is_parameter_involve_[0]) {
1050     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1051       is_inputs_should_in_memory_[1] = true;
1052     }
1053   }
1054 
1055   // When calculating 'dy', taking account of 'y'
1056   if (is_parameter_involve_[1]) {
1057     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1058       is_inputs_should_in_memory_[1] = true;
1059     }
1060   }
1061 }
1062 
1063 void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1064   // When calculating 'dx', taking account of both 'x' and 'y';
1065   // when calculating 'dy', taking account of both 'x' and 'y'
1066   if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
1067     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1068       is_inputs_should_in_memory_[0] = true;
1069     }
1070     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1071       is_inputs_should_in_memory_[1] = true;
1072     }
1073   }
1074 }
1075 
1076 void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1077   // When calculating 'dx', taking account of 'y' and 'z'
1078   if (is_parameter_[0]) {
1079     is_inputs_should_in_memory_[0] = true;
1080     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1081       is_inputs_should_in_memory_[1] = true;
1082     }
1083     if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1084         (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
1085       is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
1086     }
1087   } else if (is_parameter_involve_[0]) {
1088     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1089       is_inputs_should_in_memory_[1] = true;
1090     }
1091     if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1092         (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
1093       is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
1094     }
1095   }
1096 
1097   if (!is_inputs_should_in_memory_[1]) {
1098     is_inputs_should_in_memory_[1] = is_parameter_[1];
1099   }
1100   if (!is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1]) {
1101     is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = is_parameter_[SLICE_INPUTS_SIZE - 1];
1102   }
1103 }
1104 
1105 void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1106   // When calculating 'dx', taking account of 'y', 'z' and 'w'
1107   if (is_parameter_[0]) {
1108     is_inputs_should_in_memory_[0] = true;
1109     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1110       is_inputs_should_in_memory_[1] = true;
1111     }
1112     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
1113         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
1114       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
1115     }
1116     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1117         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
1118       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
1119     }
1120   } else if (is_parameter_involve_[0]) {
1121     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1122       is_inputs_should_in_memory_[1] = true;
1123     }
1124     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
1125         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
1126       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
1127     }
1128     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1129         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
1130       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
1131     }
1132   }
1133 
1134   if (!is_inputs_should_in_memory_[1]) {
1135     is_inputs_should_in_memory_[1] = is_parameter_[1];
1136   }
1137   if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2]) {
1138     is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 2];
1139   }
1140   if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1]) {
1141     is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 1];
1142   }
1143 }
1144 
1145 void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1146 
1147 void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1148   // When calculating 'dx', taking account of 'y'
1149   if (is_parameter_[0]) {
1150     // 'x' is parameter, so it should be in memory.
1151     is_inputs_should_in_memory_[0] = true;
1152     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1153       // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
1154       is_inputs_should_in_memory_[1] = true;
1155     }
1156   } else if (is_parameter_involve_[0]) {
1157     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1158       is_inputs_should_in_memory_[1] = true;
1159     }
1160   }
1161 
1162   if (!is_inputs_should_in_memory_[1]) {
1163     is_inputs_should_in_memory_[1] = is_parameter_[1];
1164   }
1165   is_inputs_should_in_memory_[DROPOUTDOMASK_INPUTS_SIZE - 1] = is_parameter_[DROPOUTDOMASK_INPUTS_SIZE - 1];
1166 }
1167 
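// Returns true when the first dimension of 'shape' is split across every device of the stage,
// i.e. the sliced layout corresponds to pure data parallelism.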
1168 bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
1169   CheckGlobalDeviceManager();
1170   MS_EXCEPTION_IF_NULL(g_device_manager);
1171   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1172   auto strategy0 = shape[0] / slice_shape[0];
1173 
1174   return (total_device_num == LongToSize(strategy0));
1175 }
1176 
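// Forward communication of ReduceSum: no cost in the cross-batch data-parallel case; otherwise, if any
// reduced dimension is split, an AllReduce over the output slice is charged.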
1177 double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1178                                          int64_t stage_id) const {
1179   double result = 0.0;
1180   TensorInfo input0 = inputs[0];
1181   TensorInfo output0 = outputs[0];
1182   Shape input0_shape = input0.shape();
1183   Shape input0_slice_shape = input0.slice_shape();
1184   if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1185     return result;
1186   }
1187   std::vector<int64_t> dim_list = input0.reduce_dim();
1188   auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1189     return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1190   });
1191   if (pos != dim_list.end()) {
1192     result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1193   }
1194 
1195   return result;
1196 }
1197 
1198 double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1199                                           int64_t stage_id) const {
1200   double result = 0.0;
1201   if (is_parameter_[0]) {
1202     TensorInfo input_tensor_info = inputs[0];
1203     CheckGlobalDeviceManager();
1204     MS_EXCEPTION_IF_NULL(g_device_manager);
1205     auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1206 
1207     Shape input_shape = input_tensor_info.shape();
1208     Shape input_slice_shape = input_tensor_info.slice_shape();
1209     int64_t used_device_num = 1;
1210     for (size_t i = 0; i < input_shape.size(); ++i) {
1211       used_device_num *= input_shape[i] / input_slice_shape[i];
1212     }
1213 
1214     if (total_device_num != LongToSize(used_device_num)) {
1215       result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1216     }
1217   }
1218 
1219   return result;
1220 }
1221 
1222 double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1223                                                 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
1224   double result = 0.0;
1225   TensorInfo input0 = inputs[0];
1226   TensorInfo output0 = outputs[0];
1227   std::vector<int64_t> dim_list = input0.reduce_dim();
1228   Shape input0_slice_shape = input0.slice_shape();
1229   Shape input0_shape = input0.shape();
1230   if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1231     auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1232       return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1233     });
1234     if (pos != dim_list.end()) {
1235       result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1236     }
1237   }
1238   result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1239 
1240   return result;
1241 }
1242 
1243 // Not taking account of output
1244 void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1245 
1246 void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1247   // When calculating 'dx', taking account of 'y'
1248   if (is_parameter_[0]) {
1249     // 'x' is parameter, so it should be in memory.
1250     is_inputs_should_in_memory_[0] = true;
1251     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1252       // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
1253       is_inputs_should_in_memory_[1] = true;
1254     }
1255   } else if (is_parameter_involve_[0]) {
1256     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1257       is_inputs_should_in_memory_[1] = true;
1258     }
1259   }
1260 
1261   // Not taking account of 'y'
1262   if (!is_inputs_should_in_memory_[1]) {
1263     is_inputs_should_in_memory_[1] = is_parameter_[1];
1264   }
1265 }
1266 
1267 double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1268                                                  const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
1269   double result = 0.0;
1270   TensorInfo input0 = inputs[0];
1271   TensorInfo output0 = outputs[0];
1272   std::vector<int64_t> dim_list = input0.reduce_dim();
1273   Shape input0_slice_shape = input0.slice_shape();
1274   Shape input0_shape = input0.shape();
1275   if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1276     auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1277       return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1278     });
1279     if (pos != dim_list.end()) {
1280       result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]) * 2.0;
1281     }
1282   }
1283   result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1284 
1285   return result;
1286 }
1287 
1288 void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1289 
1290 void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1291   // When calculating 'dx', taking account of 'y'
1292   if (is_parameter_[0]) {
1293     // 'x' is parameter, so it should be in memory.
1294     is_inputs_should_in_memory_[0] = true;
1295     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1296       // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
1297       is_inputs_should_in_memory_[1] = true;
1298     }
1299   } else if (is_parameter_involve_[0]) {
1300     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1301       is_inputs_should_in_memory_[1] = true;
1302     }
1303   }
1304 
1305   // Not taking account of 'y'
1306   if (!is_inputs_should_in_memory_[1]) {
1307     is_inputs_should_in_memory_[1] = is_parameter_[1];
1308   }
1309 }
1310 
1311 void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1312 
1313 void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1314   // When calculating 'dx', taking account of 'x'
1315   if (is_parameter_[0]) {
1316     is_inputs_should_in_memory_[0] = true;
1317   } else if (is_parameter_involve_[0]) {
1318     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1319       is_inputs_should_in_memory_[0] = true;
1320     }
1321   }
1322 }
1323 
1324 double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1325                                               int64_t) const {
1326   if (inputs.empty()) {
1327     return 0.0;
1328   }
1329   TensorInfo input0 = inputs[0];
1330   Shape input0_slice_shape = input0.slice_shape();
1331   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
1332 }
1333 
1334 // return the per device communication cost in the forward phase.
1335 double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1336                                         int64_t) const {
1337   // GatherV2Cost does not need communication in the forward phase
1338   return 0.0;
1339 }
1340 
1341 // return the per device communication cost in the backward phase.
1342 double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1343                                          int64_t stage_id) const {
1344   double result = 0.0;
1345   CheckGlobalDeviceManager();
1346   MS_EXCEPTION_IF_NULL(g_device_manager);
1347   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1348 
1349   for (size_t j = 0; j < inputs.size(); ++j) {
1350     if (!is_parameter_[j]) {
1351       continue;
1352     }
1353     TensorInfo input_a_tensor_info = inputs[j];
1354     Shape input_a_shape = input_a_tensor_info.shape();
1355     Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1356     int64_t used_device_num = 1;
1357     for (size_t i = 0; i < input_a_shape.size(); ++i) {
1358       used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1359     }
1360     if (total_device_num != LongToSize(used_device_num)) {
1361       result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1362     }
1363   }
1364 
1365   return result;
1366 }
1367 
1368 double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1369                                                int64_t) const {
1370   // In forward phase, the computation cost = slice(A) + slice(B)
1371   Shape input0_slice_shape = inputs[0].slice_shape();
1372   Shape input1_slice_shape = inputs[1].slice_shape();
1373   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1374                   ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
1375   return result;
1376 }
1377 
1378 double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1379                                                 int64_t) const {
1380   return 0.0;
1381 }
1382 
1383 // Not taking account of output
1384 void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1385 
1386 void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1387   // When calculating 'dx', taking account of 'y' and 'z'
1388   if (is_parameter_[0]) {
1389     // 'x' is parameter, so it should be in memory.
1390     is_inputs_should_in_memory_[0] = true;
1391     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1392       is_inputs_should_in_memory_[1] = true;
1393     }
1394     if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1395         (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
1396       is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
1397     }
1398   } else if (is_parameter_involve_[0]) {
1399     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1400       is_inputs_should_in_memory_[1] = true;
1401     }
1402     if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1403         (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
1404       is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
1405     }
1406   }
1407 
1408   if (!is_inputs_should_in_memory_[1]) {
1409     is_inputs_should_in_memory_[1] = is_parameter_[1];
1410   }
1411   if (!is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1]) {
1412     is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = is_parameter_[GATHERV2_INPUTS_SIZE - 1];
1413   }
1414 }
1415 
1416 void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1417 
1418 void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1419   if (is_inputs_should_in_memory_.size() == 0) {
1420     return;
1421   }
1422   is_inputs_should_in_memory_[0] = is_parameter_[0];
1423 }
1424 
1425 double DSDMatmulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1426                                                 int64_t) const {
1427   double result = 0.0;
1428   if (inputs_type_lengths_.size() != inputs.size()) {
1429     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for DSDMatmul cost";
1430   }
1431 
1432   for (size_t index = 0; index < inputs.size(); ++index) {
1433     TensorInfo tensor_info = inputs[index];
1434     Shape slice_shape = tensor_info.slice_shape();
1435     result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1436   }
1437   return result;
1438 }
1439 
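// DSDMatmul keeps its output, and all of its inputs, for the backward pass whenever any input is
// parameter-involved.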
1440 void DSDMatmulCost::CalculateOutputInMemory() {
1441   is_output_should_in_memory_ =
1442     (std::find(is_parameter_involve_.cbegin(), is_parameter_involve_.cend(), true) != is_parameter_involve_.cend());
1443 }
1444 
1445 void DSDMatmulCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1446   bool keep_mem =
1447     (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
1448     (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1449   std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
1450 }
1451 
1452 void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1453 
1454 void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1455   is_inputs_should_in_memory_[0] = is_parameter_[0];
1456 }
1457 
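// Backward communication of LayerNorm: each parameter input that is not fully sharded across the stage
// devices is charged one slice-sized gradient synchronization (typically an AllReduce).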
1458 double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1459                                           int64_t stage_id) const {
1460   double result = 0.0;
1461   if (is_parameter_.size() != inputs.size()) {
1462     MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost";
1463   }
1464   if (inputs_type_lengths_.size() != inputs.size()) {
1465     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
1466   }
1467 
1468   MS_EXCEPTION_IF_NULL(g_device_manager);
1469   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1470 
1471   for (size_t index = 0; index < inputs.size(); ++index) {
1472     if (is_parameter_[index]) {
1473       TensorInfo tensor_info = inputs[index];
1474       Shape shape = tensor_info.shape();
1475       Shape slice_shape = tensor_info.slice_shape();
1476       int64_t used_device_num = 1;
1477       for (size_t i = 0; i < shape.size(); ++i) {
1478         if (slice_shape[i] == 0) {
1479           MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape);
1480         }
1481         used_device_num *= shape[i] / slice_shape[i];
1482       }
1483       if (total_device_num != LongToSize(used_device_num)) {
1484         result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1485       }
1486     }
1487   }
1488   return result;
1489 }
1490 
1491 double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1492                                                 int64_t) const {
1493   double result = 0.0;
1494   if (inputs_type_lengths_.size() != inputs.size()) {
1495     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
1496   }
1497 
1498   for (size_t index = 0; index < inputs.size(); ++index) {
1499     TensorInfo tensor_info = inputs[index];
1500     Shape slice_shape = tensor_info.slice_shape();
1501     result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1502   }
1503   return result;
1504 }
1505 
1506 void LayerNormCost::CalculateOutputInMemory() {
1507   is_output_should_in_memory_ =
1508     is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[LAYERNORM_INPUTS_SIZE - 1];
1509 }
1510 
1511 void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1512   // When calculating 'dx', taking account of both 'x' and 'y'
1513   // When calculating 'dy', taking account of both 'x' and 'y'
1514   if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
1515     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1516       is_inputs_should_in_memory_[0] = true;
1517     }
1518     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1519       is_inputs_should_in_memory_[1] = true;
1520     }
1521   }
1522   is_inputs_should_in_memory_[LAYERNORM_INPUTS_SIZE - 1] = is_parameter_[LAYERNORM_INPUTS_SIZE - 1];
1523 }
1524 
1525 double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
1526   return 0.0;
1527 }
1528 double UniqueCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1529                                        int64_t stage_id) const {
1530   double result = 0.0;
1531   if (is_parameter_[0]) {
1532     TensorInfo input = inputs[0];
1533     CheckGlobalDeviceManager();
1534     MS_EXCEPTION_IF_NULL(g_device_manager);
1535     auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1536     Shape input_shape = input.shape();
1537     Shape input_slice_shape = input.slice_shape();
1538     int64_t used_device_num = 1;
1539     for (size_t i = 0; i < input_shape.size(); ++i) {
1540       used_device_num *= input_shape[i] / input_slice_shape[i];
1541     }
1542     if (total_device_num != LongToSize(used_device_num)) {
1543       result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1544     }
1545   }
1546   return result;
1547 }
1548 double UniqueCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1549                                              int64_t) const {
1550   // In forward phase, the computation cost = slice(A)
1551   Shape input_slice_shape = inputs[0].slice_shape();
1552   double result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1553   return result;
1554 }
1555 double UniqueCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1556                                               int64_t stage_id) const {
1557   // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
1558   double result = 0.0;
1559   if (is_parameter_[0]) {
1560     TensorInfo input = inputs[0];  // tensor B
1561     CheckGlobalDeviceManager();
1562     MS_EXCEPTION_IF_NULL(g_device_manager);
1563     auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1564 
1565     Shape input_shape = input.shape();
1566     Shape input_slice_shape = input.slice_shape();
1567     int64_t used_device_num = 1;
1568     for (size_t i = 0; i < input_shape.size(); ++i) {
1569       used_device_num *= input_shape[i] / input_slice_shape[i];
1570     }
1571 
1572     if (total_device_num != LongToSize(used_device_num)) {
1573       result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1574     }
1575   }
1576   return result;
1577 }
1578 
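// Forward communication of Gather: free when the gather axis is not split; when it is split, the partial
// results have to be combined across devices, and the charged tensor is the index slice extended by the
// non-split parameter dimension (a ReduceScatter-sized shape).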
1579 double GatherCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1580                                       int64_t) const {
1581   double result = 0.0;
1582   if (outputs_type_lengths_.size() != outputs.size()) {
1583     MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for gatherv2 cost";
1584   }
1585   // don't split axis
1586   if (strategy_.at(LongToSize(axis_)) == 1) {
1587     return result;
1588   }
1589 
1590   // split axis
1591   auto param_shape = inputs[0].slice_shape();
1592   auto index_shape = inputs[1].slice_shape();
1593   Shape reducescatter_shape = index_shape;
1594   if (param_shape.size() == 2) {
1595     reducescatter_shape.push_back(param_shape.at(LongToSize(1 - axis_)));
1596   }
1597   result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]);
1598   return result;
1599 }
1600 
1601 double GatherCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1602                                        int64_t stage_id) const {
1603   double result = 0.0;
1604   CheckGlobalDeviceManager();
1605   MS_EXCEPTION_IF_NULL(g_device_manager);
1606   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1607 
1608   for (size_t j = 0; j < inputs.size(); ++j) {
1609     if (!is_parameter_[j]) {
1610       continue;
1611     }
1612     TensorInfo input_a_tensor_info = inputs[j];
1613     Shape input_a_shape = input_a_tensor_info.shape();
1614     Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1615     int64_t used_device_num = 1;
1616     for (size_t i = 0; i < input_a_shape.size(); ++i) {
1617       used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1618     }
1619     if (total_device_num != LongToSize(used_device_num)) {
1620       result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1621     }
1622   }
1623   return result;
1624 }
1625 
1626 double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1627                                                               const std::vector<TensorInfo> &, int64_t) const {
1628   Shape input0_slice_shape = inputs[0].slice_shape();
1629   if (inputs_type_lengths_.size() != inputs.size()) {
1630     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1631                       << " for UniformCandidateSampler cost";
1632   }
1633 
1634   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1635 
1636   return result;
1637 }
1638 
1639 void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1640 
1641 void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1642   is_inputs_should_in_memory_[0] = is_parameter_[0];
1643 }
1644 
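// Forward computation of Gather: plain slice(A) + slice(B) when the gather axis is not split; when it is
// split, both terms are scaled by GATHERV2_COST_WEIGHT0/1 to model the extra per-device work of the
// split-axis implementation.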
1645 double GatherCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1646                                              int64_t) const {
1647   double result = 0.0;
1648   Shape input0_slice_shape = inputs[0].slice_shape();
1649   Shape input1_slice_shape = inputs[1].slice_shape();
1650   if (inputs_type_lengths_.size() != inputs.size()) {
1651     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
1652   }
1653   // don't split axis
1654   if (strategy_.at(LongToSize(axis_)) == 1) {
1655     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1656               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
1657   } else {
1658     // split axis
1659     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 +
1660               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1;
1661   }
1662 
1663   return result;
1664 }
1665 
1666 double GatherCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
1667                                               const std::vector<TensorInfo> &outputs, int64_t) const {
1668   double result = 0.0;
1669   Shape input1_slice_shape = inputs[1].slice_shape();
1670   Shape output0_slice_shape = outputs[0].slice_shape();
1671   // don't split axis
1672   if (strategy_.at(LongToSize(axis_)) == 1) {
1673     result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1674   } else {
1675     // split axis
1676     result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 +
1677               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3;
1678   }
1679 
1680   return result;
1681 }
1682 
1683 // The forward communication is determined by whether the slice is column split or row split
1684 // The number of segments equals shape[0] of the output, which determines the size of the AllReduce
1685 double UnsortedSegmentSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1686                                                   const std::vector<TensorInfo> &outputs, int64_t) const {
1687   TensorInfo input0 = inputs[0];
1688   TensorInfo input1 = inputs[1];
1689   TensorInfo output0 = outputs[0];
1690   Shape input0_shape = input0.shape();
1691   Shape input0_slice_shape = inputs[0].slice_shape();
1692   double result = 0.0;
1693   if (inputs_type_lengths_.size() != inputs.size()) {
1694     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
1695   }
1696   // If shape b is not the same as shape a, we regard it as a column slice
1697   for (size_t i = 0; i < input1.shape().size(); ++i) {
1698     if (input0_shape[i] != input0_slice_shape[i]) {
1699       result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1700       return result;
1701     }
1702   }
1703   return result;
1704 }
1705 
1706 double UnsortedSegmentSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
1707                                                    const std::vector<TensorInfo> &outputs, int64_t) const {
1708   TensorInfo input0 = inputs[0];
1709   TensorInfo input1 = inputs[1];
1710   TensorInfo output0 = outputs[0];
1711   Shape input0_shape = input0.shape();
1712   Shape input0_slice_shape = inputs[0].slice_shape();
1713   double result = 0.0;
1714   if (inputs_type_lengths_.size() != inputs.size()) {
1715     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
1716   }
1717   if (is_parameter_[0]) {
1718     // If the forward process has an AllReduce, then the backward also needs one.
1719     for (size_t i = 0; i < input1.shape().size(); ++i) {
1720       if (input0_shape[i] != input0_slice_shape[i]) {
1721         result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1722         return result;
1723       }
1724     }
1725   }
1726   return result;
1727 }
1728 double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1729                                                          const std::vector<TensorInfo> &outputs, int64_t) const {
1730   // In forward phase, the computation cost = slice(A) + slice(B) + slice(output)
1731   Shape input0_slice_shape = inputs[0].slice_shape();
1732   Shape input1_slice_shape = inputs[1].slice_shape();
1733   Shape output_slice_shape = outputs[0].slice_shape();
1734   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1735                   ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1736                   ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
1737   return result;
1738 }
1739 
1740 // Not taking account of output
1741 void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1742 
1743 // Taking account of input
1744 void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1745   // When calculating 'dx', taking account of 'y'
1746   if (is_parameter_[0]) {
1747     is_inputs_should_in_memory_[0] = true;
1748     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1749       is_inputs_should_in_memory_[1] = true;
1750     }
1751   } else if (is_parameter_involve_[0]) {
1752     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1753       is_inputs_should_in_memory_[1] = true;
1754     }
1755   }
1756 
1757   if (!is_inputs_should_in_memory_[1]) {
1758     is_inputs_should_in_memory_[1] = is_parameter_[1];
1759   }
1760   is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
1761 }
1762 
1763 double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1764                                                   const std::vector<TensorInfo> &outputs, int64_t) const {
1765   TensorInfo input0 = inputs[0];
1766   TensorInfo input1 = inputs[1];
1767   TensorInfo output0 = outputs[0];
1768   Shape input0_shape = input0.shape();
1769   Shape input0_slice_shape = inputs[0].slice_shape();
1770   double result = 0.0;
1771   if (inputs_type_lengths_.size() != inputs.size()) {
1772     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1773                       << " for UnsortedSegmentMinCost cost";
1774   }
1775   // If shape b is not the same as shape a, we regard it as a column slice
1776   // The cost is an AllGather operation whose shape is the same as the output of UnsortedSegmentMin.
1777   for (size_t i = 0; i < input1.shape().size(); ++i) {
1778     if (input0_shape[i] != input0_slice_shape[i]) {
1779       result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1780       return result;
1781     }
1782   }
1783   return result;
1784 }
1785 
1786 double UnsortedSegmentMinCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
1787                                                    const std::vector<TensorInfo> &outputs, int64_t) const {
1788   TensorInfo input0 = inputs[0];
1789   TensorInfo input1 = inputs[1];
1790   TensorInfo output0 = outputs[0];
1791   Shape input0_shape = input0.shape();
1792   Shape input0_slice_shape = inputs[0].slice_shape();
1793   double result = 0.0;
1794   if (inputs_type_lengths_.size() != inputs.size()) {
1795     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1796                       << " for UnsortedSegmentMinCost cost";
1797   }
1798   if (is_parameter_[0]) {
1799     // If the forward process has an AllGather, then the backward also needs one ReduceScatter.
1800     for (size_t i = 0; i < input1.shape().size(); ++i) {
1801       if (input0_shape[i] != input0_slice_shape[i]) {
1802         result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1803         return result;
1804       }
1805     }
1806   }
1807   return result;
1808 }
1809 double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1810                                                          const std::vector<TensorInfo> &outputs, int64_t) const {
1811   // In forward phase, the computation cost = slice(A) + slice(B) + 2 * slice(output)
1812   Shape input0_slice_shape = inputs[0].slice_shape();
1813   Shape input1_slice_shape = inputs[1].slice_shape();
1814   Shape output_slice_shape = outputs[0].slice_shape();
1815   // The forward operation is UnsortedSegmentMin + ReduceMin
1816   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1817                   ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1818                   ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]) +
1819                   ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);  // ReduceMin
1820   return result;
1821 }
1822 
1823 // Taking account of output
1824 void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1825 
1826 // Taking account of input
1827 void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1828   // When calculating 'dx', taking account of 'x', 'y' and 'z'
1829   if (is_parameter_involve_[0]) {
1830     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1831       is_inputs_should_in_memory_[0] = true;
1832     }
1833     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1834       is_inputs_should_in_memory_[1] = true;
1835     }
1836     if ((prev_output_in_mem.find(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1837         (!prev_output_in_mem.at(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1))) {
1838       is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = true;
1839     }
1840   }
1841   if (!is_inputs_should_in_memory_[1]) {
1842     is_inputs_should_in_memory_[1] = is_parameter_[1];
1843   }
1844   if (!is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1]) {
1845     is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
1846   }
1847 }
1848 
1849 // Not taking account of output
1850 void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1851 
1852 // Not taking account of input
1853 void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1854   for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
1855     is_inputs_should_in_memory_[i] = is_parameter_[i];
1856   }
1857 }
1858 
1859 double MatmulDDSCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1860                                                 int64_t) const {
1861   double result = 0.0;
1862   if (inputs_type_lengths_.size() != inputs.size()) {
1863     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for MatmulDDS cost";
1864   }
1865 
1866   for (size_t index = 0; index < inputs.size(); ++index) {
1867     TensorInfo tensor_info = inputs[index];
1868     Shape slice_shape = tensor_info.slice_shape();
1869     result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1870   }
1871   return result;
1872 }
1873 
1874 // Taking account of the output when any input is parameter-involved
1875 void MatmulDDSCost::CalculateOutputInMemory() {
1876   is_output_should_in_memory_ =
1877     (std::find(is_parameter_involve_.cbegin(), is_parameter_involve_.cend(), true) != is_parameter_involve_.cend());
1878 }
1879 
1880 // Taking account of input
1881 void MatmulDDSCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1882   bool keep_mem =
1883     (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
1884     (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1885   std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
1886 }
1887 
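// Forward computation of scatter-style math ops: slice(input) + slice(indices) + slice(updates); when the
// scatter axis is split, each term is scaled by its operator-specific coefficient
// (input_coefficient_, indices_coefficient_, updates_coefficient_).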
1888 double ScatterMathOpsCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1889                                                      const std::vector<TensorInfo> &, int64_t) const {
1890   double result = 0.0;
1891   Shape input0_slice_shape = inputs[0].slice_shape();
1892   Shape input1_slice_shape = inputs[1].slice_shape();
1893   Shape input2_slice_shape = inputs[2].slice_shape();
1894   if (inputs_type_lengths_.size() != inputs.size()) {
1895     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for ScatterMathOps cost";
1896   }
1897   // don't split axis
1898   if (!is_split_axis_) {
1899     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1900               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1901               ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
1902   } else {
1903     // split axis
1904     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * input_coefficient_ +
1905               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * indices_coefficient_ +
1906               ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * updates_coefficient_;
1907   }
1908 
1909   return result;
1910 }
1911 
1912 double TensorScatterOpsCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1913                                                  int64_t stage_id) const {
1914   double result = 0.0;
1915   CheckGlobalDeviceManager();
1916   MS_EXCEPTION_IF_NULL(g_device_manager);
1917   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1918 
1919   for (size_t j = 0; j < inputs.size(); ++j) {
1920     if (!is_parameter_[j]) {
1921       continue;
1922     }
1923     TensorInfo input_a_tensor_info = inputs[j];
1924     Shape input_a_shape = input_a_tensor_info.shape();
1925     Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1926     int64_t used_device_num = 1;
1927     for (size_t i = 0; i < input_a_shape.size(); ++i) {
1928       used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1929     }
1930     if (total_device_num != LongToSize(used_device_num)) {
1931       result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1932     }
1933   }
1934   return result;
1935 }
1936 
1937 double TensorScatterOpsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
1938                                                         const std::vector<TensorInfo> &outputs, int64_t) const {
1939   double result = 0.0;
1940   Shape input0_slice_shape = inputs[0].slice_shape();
1941   Shape input1_slice_shape = inputs[1].slice_shape();
1942   Shape input2_slice_shape = inputs[2].slice_shape();  // equal to output shape
1943   // The bprop function uses each input/output about 3 times on average.
1944   // don't split axis
1945   if (!is_split_axis_) {
1946     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1947               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1948               ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
1949     result *= 3;
1950 
1951   } else {
1952     // split axis
1953     result +=
1954       ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * (input_coefficient_ + 3) +
1955       ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * (indices_coefficient_ + 3) +
1956       ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * (updates_coefficient_ + 3);
1957   }
1958 
1959   return result;
1960 }
1961 
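// Forward communication of CropAndResize: free when the batch dimension is not split; otherwise the
// partial crops are summed across devices over a [num_boxes, crop_h, crop_w, channels] tensor.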
1962 double CropAndResizeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1963                                              const std::vector<TensorInfo> &outputs, int64_t) const {
1964   double result = 0.0;
1965   if (outputs_type_lengths_.size() != outputs.size()) {
1966     MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for CropAndResize cost.";
1967   }
1968 
1969   // don't split the batch
1970   if (strategy_[0] == 1) {
1971     return result;
1972   }
1973 
1974   // split batch
1975   auto x_shape = inputs[0].slice_shape();
1976   auto box_shape = inputs[1].slice_shape();
1977   Shape reduce_sum_shape = {box_shape[0], crop_size_[0], crop_size_[1], x_shape[3]};
1978   result += ListProduct(reduce_sum_shape) * static_cast<double>(outputs_type_lengths_[0]);
1979   return result;
1980 }
1981 
1982 double CropAndResizeCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1983                                               int64_t stage_id) const {
1984   double result = 0.0;
1985   CheckGlobalDeviceManager();
1986   MS_EXCEPTION_IF_NULL(g_device_manager);
1987   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1988 
1989   for (size_t i = 0; i < inputs.size(); ++i) {
1990     if (!is_parameter_[i]) {
1991       continue;
1992     }
1993     TensorInfo input_tensor_info = inputs[i];
1994     Shape input_shape = input_tensor_info.shape();
1995     Shape input_slice_shape = input_tensor_info.slice_shape();
1996     int64_t used_device_num = 1;
1997     for (size_t j = 0; j < input_shape.size(); ++j) {
1998       used_device_num *= input_shape[j] / input_slice_shape[j];
1999     }
2000     if (total_device_num != LongToSize(used_device_num)) {
2001       result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[i]);
2002     }
2003   }
2004   return result;
2005 }
2006 
2007 double CropAndResizeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
2008                                                     const std::vector<TensorInfo> &, int64_t) const {
2009   if (inputs_type_lengths_.size() != inputs.size()) {
2010     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for CropAndResize cost.";
2011   }
2012 
2013   Shape input0_slice_shape = inputs.at(0).slice_shape();
2014   Shape input1_slice_shape = inputs.at(1).slice_shape();
2015   Shape input2_slice_shape = inputs.at(2).slice_shape();
2016   double result = 0.0;
2017   // don't split batch
2018   if (strategy_[0] == 1) {
2019     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
2020               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
2021               ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
2022   } else {
2023     // split batch
2024     result +=
2025       ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * CROP_AND_RESIZE_COST_WEIGHT0 +
2026       ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * CROP_AND_RESIZE_COST_WEIGHT1 +
2027       ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * CROP_AND_RESIZE_COST_WEIGHT2;
2028   }
2029 
2030   return result;
2031 }
2032 
2033 double CropAndResizeCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
2034                                                      const std::vector<TensorInfo> &outputs, int64_t) const {
2035   Shape output0_slice_shape = outputs[0].slice_shape();
2036   double result = 0.0;
2037   // don't split batch
2038   if (strategy_.at(0) == 1) {
2039     result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
2040   } else {
2041     // split batch
2042     result +=
2043       ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * CROP_AND_RESIZE_COST_WEIGHT3;
2044   }
2045 
2046   return result;
2047 }
2048 
2049 // Not taking account of output
2050 void CropAndResizeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
2051 
2052 // Taking account of input
2053 void CropAndResizeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
2054   for (size_t i = 0; i < CROP_AND_RESIZE_INPUTS_SIZE; ++i) {
2055     is_inputs_should_in_memory_[i] = is_parameter_[i];
2056   }
2057 
2058   // When calculating 'dx', taking account of 'y' and 'z'
2059   if (is_parameter_[0] || is_parameter_involve_[0]) {
2060     if (prev_output_in_mem.find(1) == prev_output_in_mem.end() || !prev_output_in_mem.at(1)) {
2061       is_inputs_should_in_memory_[1] = true;
2062     }
2063     if (prev_output_in_mem.find(CROP_AND_RESIZE_INPUTS_SIZE - 1) == prev_output_in_mem.end() ||
2064         !prev_output_in_mem.at(CROP_AND_RESIZE_INPUTS_SIZE - 1)) {
2065       is_inputs_should_in_memory_[CROP_AND_RESIZE_INPUTS_SIZE - 1] = true;
2066     }
2067   }
2068 }
2069 
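// Forward communication of ROIAlign: free when the batch dimension is not split; otherwise the partial
// results are summed across devices over a [num_rois, channels, pooled_h, pooled_w] tensor.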
2070 double ROIAlignCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
2071                                         int64_t) const {
2072   double result = 0.0;
2073   if (outputs_type_lengths_.size() != outputs.size()) {
2074     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for CropAndResize cost.";
2075   }
2076 
  // don't split the batch
  if (strategy_[0] == 1) {
    return result;
  }

  // split batch
  auto features_shape = inputs[0].slice_shape();
  auto rois_shape = inputs[1].slice_shape();
  Shape reduce_sum_shape = {rois_shape[0], features_shape[1], pooled_shape_[0], pooled_shape_[1]};
  result += ListProduct(reduce_sum_shape) * static_cast<double>(outputs_type_lengths_[0]);
  return result;
}

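// Forward computation cost: the byte sizes of the features and rois slices; when the batch dimension is split,
// each term is scaled by its ROI_ALIGN_COST_WEIGHT (presumably tuned for the batch-split implementation).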
double ROIAlignCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                               int64_t) const {
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for ROIAlign cost.";
  }

  Shape input0_slice_shape = inputs.at(0).slice_shape();
  Shape input1_slice_shape = inputs.at(1).slice_shape();
  double result = 0.0;
  // don't split batch
  if (strategy_[0] == 1) {
    result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  } else {
    // split batch
    result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * ROI_ALIGN_COST_WEIGHT0 +
              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * ROI_ALIGN_COST_WEIGHT1;
  }

  return result;
}

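// Backward computation cost mirrors CropAndResize: the byte size of the incoming gradient (slice shape of
// output 0), scaled by ROI_ALIGN_COST_WEIGHT2 when the batch dimension is split.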
double ROIAlignCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
                                                int64_t) const {
  Shape output0_slice_shape = outputs[0].slice_shape();
  double result = 0.0;
  // don't split batch
  if (strategy_.at(0) == 1) {
    result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
  } else {
    // split batch
    result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * ROI_ALIGN_COST_WEIGHT2;
  }

  return result;
}

// Taking account of output
void ROIAlignCost::CalculateOutputInMemory() { is_output_should_in_memory_ = true; }

// Taking account of input
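// Unlike CropAndResize, only the last input is revisited when computing 'dx', so only that index is checked
// against the outputs that previous operators already keep in memory.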
void ROIAlignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  for (size_t i = 0; i < ROI_ALIGN_INPUTS_SIZE; ++i) {
    is_inputs_should_in_memory_[i] = is_parameter_[i];
  }

  // When calculating 'dx', taking account of the last input
  if (is_parameter_[0] || is_parameter_involve_[0]) {
    if (prev_output_in_mem.find(ROI_ALIGN_INPUTS_SIZE - 1) == prev_output_in_mem.end() ||
        !prev_output_in_mem.at(ROI_ALIGN_INPUTS_SIZE - 1)) {
      is_inputs_should_in_memory_[ROI_ALIGN_INPUTS_SIZE - 1] = true;
    }
  }
}
}  // namespace parallel
}  // namespace mindspore