/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/operator_costmodel.h"

#include <algorithm>
#include <random>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_parameter_ = is_parameter; }

void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
  is_parameter_involve_ = is_parameter_inv;
  is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
}

void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }

void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                               const std::vector<size_t> &output_lengths) {
  inputs_type_lengths_ = input_lengths;
  outputs_type_lengths_ = output_lengths;
}

void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_ = critical; }

double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
                                   const std::vector<TensorInfo> &outputs) const {
  return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
}

double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
  double result = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_inputs_should_in_memory_[i]) {
      result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
    }
  }
  return result;
}

double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &,
                                         const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_output_should_in_memory_) {
    // When this operator has multiple outputs, they all contribute to the memory.
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}
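
// Illustrative note: ListProduct(slice_shape) multiplied by the per-element type length gives the size of a
// tensor slice in bytes. For example, a float32 slice of shape [128, 64] would contribute
// 128 * 64 * 4 = 32768 bytes.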

double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
                                               const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_outputs_critical_ == -1) {
    MS_LOG(EXCEPTION) << "The critical flag is not set.";
  }
  if (is_outputs_critical_ == 1) {
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}

// Return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                      int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = input0.slice_shape();
  if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has not been partitioned, then there is no communication cost.
    return 0.0;
  } else {
    // Else, the communication cost is the size (number of bytes) of a slice of the output tensor.
    return ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
  }
}
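
// Illustrative example: for C = MatMul(A, B), if the reduced dimension of A is split across devices, each
// device holds only a partial sum of C, which presumably has to be combined (e.g., by an AllReduce); the cost
// counted above is the byte size of one output slice.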

// Return the per device communication cost in the backward phase.
double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                       int64_t stage_id) const {
  // In the backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B
  // does not fully utilize all devices.
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}
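
// Worked example: if parameter B has shape [8, 4] and slice shape [4, 4], then
// used_device_num = (8 / 4) * (4 / 4) = 2. If the stage actually contains 4 devices, B is not fully sharded,
// so an extra communication of ListProduct([4, 4]) * type_length bytes is charged for its gradient.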

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                             const std::vector<TensorInfo> &outputs, int64_t) const {
  // In the forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
  double result = 0.0;
  TensorInfo output0 = outputs[0];
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  Shape input0_shape = inputs[0].shape();
  if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has been partitioned, the AllReduce of the output slice adds to the cost.
    result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
  }
  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
            ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  // In the backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Not taking account of output
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
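
// Note: for MatMul, the gradient of one operand is presumably computed from the other operand, so whenever one
// side is (or leads to) a parameter, the other input is kept in memory unless the previous operator already
// keeps it (prev_output_in_mem), which is what the branches above encode.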

// Return the per device communication cost in the forward phase.
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // Cast is an element-wise operator, thus it does not need communication in the forward phase.
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                     int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t) const {
  TensorInfo input0 = inputs[0];
  Shape input0_slice_shape = input0.slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                            int64_t) const {
  return 0.0;
}

// Not taking account of output
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Taking account of output
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Taking account of input
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
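
// Note: the gradient of an activation such as GeLU depends on its input 'x', so 'x' is kept in memory when it
// is parameter-involved and the previous operator does not already keep it.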

// Return the per device communication cost in the forward phase.
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
  // In the forward phase, the communication cost = 0
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
                                              int64_t) const {
  if (outputs.empty() || outputs_type_lengths_.empty()) {
    MS_LOG(EXCEPTION) << "The outputs or outputs_type_length is empty";
  }

  // Use the output for the Tile operator.
  TensorInfo output_info = outputs[0];
  Shape output_slice_shape = output_info.slice_shape();
  return ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

// Taking account of output
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Not taking account of input
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Not taking account of output
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Not taking account of output
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}
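
// Note: the backward pass of Tile presumably reduces the gradient back to the original input shape using the
// 'multiples' input 'y', which is why 'y' is kept in memory whenever 'x' is parameter-involved and the
// previous operator does not already keep it.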

// Not taking account of output
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Taking account of input
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

// Return the per device communication cost in the forward phase.
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                           const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  // Identity is an element-wise operator, thus it does not need communication in the forward phase.
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                            const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  // Identity is an element-wise operator, thus it does not need communication in the backward phase.
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                  const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                   const std::vector<mindspore::parallel::TensorInfo> &,
                                                   int64_t) const {
  return 0.0;
}

// Not taking account of output
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                                    const std::vector<mindspore::parallel::TensorInfo> &,
                                                    int64_t) const {
  double cost = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    cost += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
  }
  return cost;
}

double BatchParallelCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                     const std::vector<mindspore::parallel::TensorInfo> &,
                                                     int64_t) const {
  return 0.0;
}

double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  for (size_t j = 0; j < inputs.size(); ++j) {
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
    }
  }

  return result;
}
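
// Note: for a batch-parallel operator, every parameter input whose sharding does not cover all devices in the
// stage incurs an extra backward communication, counted above as the byte size of that parameter's slice.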

void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}

void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// Return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // PReLU does not need communication in the forward phase.
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                      int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                            int64_t) const {
  // In the forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                             const std::vector<mindspore::parallel::TensorInfo> &,
                                             int64_t stage_id) const {
  // In the backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y';
  // when calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
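
// Note: with PReLU(x, w) = max(0, x) + w * min(0, x), dx depends on 'w' (and the sign of 'x') while dw depends
// on 'x', so both inputs are kept in memory when either of them is parameter-involved.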

// Return the per device communication cost in the forward phase.
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // OneHot does not need communication in the forward phase.
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
  // OneHot does not need communication in the backward phase.
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                             int64_t) const {
  // In OneHot's forward phase, the computation cost = slice(A)
  Shape input0_slice_shape = inputs[0].slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                              int64_t) const {
  return 0.0;
}

// Not taking account of output
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
  is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 2] = is_parameter_[ONEHOT_INPUTS_SIZE - 2];
  is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 1] = is_parameter_[ONEHOT_INPUTS_SIZE - 1];
}

// Return the per device communication cost in the forward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
                                                             const std::vector<TensorInfo> &, int64_t) const {
  // SoftmaxCrossEntropyWithLogits does not need communication in the forward phase.
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<TensorInfo> &,
                                                              const std::vector<TensorInfo> &, int64_t) const {
  // SoftmaxCrossEntropyWithLogits does not need communication in the backward phase.
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                                    const std::vector<TensorInfo> &, int64_t) const {
  // In the forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
                                                                     const std::vector<TensorInfo> &, int64_t) const {
  return 0.0;
}

// Taking account of output
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}

void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// Return the per device communication cost in the forward phase.
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                       int64_t stage_id) const {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  TensorRedistribution tensor_redistribution(false, true);
  if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }
  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }
  return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost());
}
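
// Note: Reshape may change the tensor layout across devices, so its forward communication cost is delegated to
// TensorRedistribution; the cost it reports is assumed to be in elements and is scaled by the input's type
// length to obtain bytes.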

// Return the per device communication cost in the backward phase.
double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses.
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                              const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  TensorRedistribution tensor_redistribution(false, true);
  if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }
  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }
  return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses.
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t) const {
  double result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  if (is_parameter_[0]) {
    TensorInfo input_a_tensor_info = inputs[0];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  if (is_parameter_[1]) {
    TensorInfo input_b_tensor_info = inputs[1];
    Shape input_b_shape = input_b_tensor_info.shape();
    Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_b_shape.size(); ++i) {
      used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                    int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  if (is_parameter_[0]) {
    TensorInfo input_a_tensor_info = inputs[0];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  if (is_parameter_[1]) {
    TensorInfo input_b_tensor_info = inputs[1];
    Shape input_b_shape = input_b_tensor_info.shape();
    Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_b_shape.size(); ++i) {
      used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Not taking account of output
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// Taking account of input
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Taking account of output
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

// Taking account of input
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  // When calculating 'dy', taking account of 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
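
// Note: for z = x / y, dx = dout / y needs only 'y', and dy can be computed as -dout * z / y, which is
// presumably why DivCost keeps the output (see CalculateOutputInMemory) and 'y' in memory, but never 'x'.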

// Taking account of input
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  // When calculating 'dy', taking account of 'x' and 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'power'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dpower', taking account of 'x'
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
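
// Note: for z = pow(x, power), dx = power * x^(power - 1) * dout depends on both 'x' and 'power', while
// dpower = z * ln(x) * dout depends on 'x' and the output z, which matches keeping the output in memory when
// 'power' is parameter-involved.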
914 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)915 void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
916   // When calculating 'dx', taking account of 'x'
917   if (is_parameter_involve_[0]) {
918     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
919       is_inputs_should_in_memory_[0] = true;
920     }
921   }
922   // When calculating 'dy', not taking account of 'x' and 'y'
923   is_inputs_should_in_memory_[1] = is_parameter_[1];
924 }
925 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)926 void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
927   // When calculating 'dx', taking account of both 'x' and 'y'
928   if (is_parameter_involve_[0]) {
929     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
930       is_inputs_should_in_memory_[0] = true;
931     }
932     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
933       is_inputs_should_in_memory_[1] = true;
934     }
935   }
936   // When calculating 'dy', not taking account of 'x' and 'y'
937   if (!is_inputs_should_in_memory_[1]) {
938     is_inputs_should_in_memory_[1] = is_parameter_[1];
939   }
940 }
941 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)942 void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
943   // When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
944   // 'y'
945   if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
946     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
947       is_inputs_should_in_memory_[0] = true;
948     }
949     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
950       is_inputs_should_in_memory_[1] = true;
951     }
952   }
953 }
954 
CalculateOutputInMemory()955 void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
956 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)957 void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
958   // When calculating 'dx', taking account of 'y'
959   if (is_parameter_[0]) {
960     // 'x' is parameter, so it should be in memory.
961     is_inputs_should_in_memory_[0] = true;
962     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
963       // In this case, if 'y' is not be calculated by the previous operator, then 'y' should be included here.
964       is_inputs_should_in_memory_[1] = true;
965     }
966   } else if (is_parameter_involve_[0]) {
967     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
968       is_inputs_should_in_memory_[1] = true;
969     }
970   }
971 
972   // When calculating 'dy', taking account of 'y'
973   if (is_parameter_involve_[1]) {
974     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
975       is_inputs_should_in_memory_[1] = true;
976     }
977   }
978 }
979 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)980 void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
981   // When calculating 'dx', taking account of both 'x' and 'y';
982   // when calculating 'dy', taking account of both 'x' and 'y'
983   if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
984     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
985       is_inputs_should_in_memory_[0] = true;
986     }
987     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
988       is_inputs_should_in_memory_[1] = true;
989     }
990   }
991 }
992 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)993 void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
994   // When calculating 'dx', taking account of 'y' and 'z'
995   if (is_parameter_[0]) {
996     is_inputs_should_in_memory_[0] = true;
997     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
998       is_inputs_should_in_memory_[1] = true;
999     }
1000     if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1001         (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
1002       is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
1003     }
1004   } else if (is_parameter_involve_[0]) {
1005     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1006       is_inputs_should_in_memory_[1] = true;
1007     }
1008     if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1009         (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
1010       is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
1011     }
1012   }
1013 
1014   if (!is_inputs_should_in_memory_[1]) {
1015     is_inputs_should_in_memory_[1] = is_parameter_[1];
1016   }
1017   if (!is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1]) {
1018     is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = is_parameter_[SLICE_INPUTS_SIZE - 1];
1019   }
1020 }
1021 
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1022 void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1023   // When calculating 'dx', taking account of 'y', 'z' and 'w'
1024   if (is_parameter_[0]) {
1025     is_inputs_should_in_memory_[0] = true;
1026     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1027       is_inputs_should_in_memory_[1] = true;
1028     }
1029     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
1030         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
1031       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
1032     }
1033     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1034         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
1035       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
1036     }
1037   } else if (is_parameter_involve_[0]) {
1038     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1039       is_inputs_should_in_memory_[1] = true;
1040     }
1041     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
1042         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
1043       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
1044     }
1045     if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1046         (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
1047       is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
1048     }
1049   }
1050 
1051   if (!is_inputs_should_in_memory_[1]) {
1052     is_inputs_should_in_memory_[1] = is_parameter_[1];
1053   }
1054   if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2]) {
1055     is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 2];
1056   }
1057   if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1]) {
1058     is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 1];
1059   }
1060 }
1061 
1062 void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1063 
1064 void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1065   // When calculating 'dx', taking account of 'y'
1066   if (is_parameter_[0]) {
1067     // 'x' is a parameter, so it should be in memory.
1068     is_inputs_should_in_memory_[0] = true;
1069     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1070       // In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be counted here.
1071       is_inputs_should_in_memory_[1] = true;
1072     }
1073   } else if (is_parameter_involve_[0]) {
1074     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1075       is_inputs_should_in_memory_[1] = true;
1076     }
1077   }
1078 
1079   if (!is_inputs_should_in_memory_[1]) {
1080     is_inputs_should_in_memory_[1] = is_parameter_[1];
1081   }
1082   is_inputs_should_in_memory_[DROPOUTDOMASK_INPUTS_SIZE - 1] = is_parameter_[DROPOUTDOMASK_INPUTS_SIZE - 1];
1083 }
1084 
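// Returns true when splitting dim 0 alone accounts for every device in the stage,
// i.e. the tensor layout is plain data parallelism along the first dimension.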
1085 bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
1086   CheckGlobalDeviceManager();
1087   MS_EXCEPTION_IF_NULL(g_device_manager);
1088   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1089   auto strategy0 = shape[0] / slice_shape[0];
1090 
1091   return (total_device_num == LongToSize(strategy0));
1092 }
1093 
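// Forward communication of ReduceSum: if any reduced dimension is split, the partial results
// have to be combined across devices, costing one output slice; it is free under cross-batch
// data parallelism.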
1094 double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1095                                          int64_t stage_id) const {
1096   double result = 0.0;
1097   TensorInfo input0 = inputs[0];
1098   TensorInfo output0 = outputs[0];
1099   Shape input0_shape = input0.shape();
1100   Shape input0_slice_shape = input0.slice_shape();
1101   if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1102     return result;
1103   }
1104   std::vector<int64_t> dim_list = input0.reduce_dim();
1105   auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1106     return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1107   });
1108   if (pos != dim_list.end()) {
1109     result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1110   }
1111 
1112   return result;
1113 }
1114 
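// Backward communication of ReduceSum: when input 0 is a parameter whose sharding does not
// occupy all devices in the stage, its gradient has to be aggregated, costing one input slice.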
1115 double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1116                                           int64_t stage_id) const {
1117   double result = 0.0;
1118   if (is_parameter_[0]) {
1119     TensorInfo input_tensor_info = inputs[0];
1120     CheckGlobalDeviceManager();
1121     MS_EXCEPTION_IF_NULL(g_device_manager);
1122     auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1123 
1124     Shape input_shape = input_tensor_info.shape();
1125     Shape input_slice_shape = input_tensor_info.slice_shape();
1126     int64_t used_device_num = 1;
1127     for (size_t i = 0; i < input_shape.size(); ++i) {
1128       used_device_num *= input_shape[i] / input_slice_shape[i];
1129     }
1130 
1131     if (total_device_num != LongToSize(used_device_num))
1132       result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1133   }
1134 
1135   return result;
1136 }
1137 
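// Forward computation of ReduceSum: one pass over the input slice, plus one output slice
// when a split reduced dimension requires combining partial results.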
1138 double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1139                                                 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
1140   double result = 0.0;
1141   TensorInfo input0 = inputs[0];
1142   TensorInfo output0 = outputs[0];
1143   std::vector<int64_t> dim_list = input0.reduce_dim();
1144   Shape input0_slice_shape = input0.slice_shape();
1145   Shape input0_shape = input0.shape();
1146   if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1147     auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1148       return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1149     });
1150     if (pos != dim_list.end()) {
1151       result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1152     }
1153   }
1154   result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1155 
1156   return result;
1157 }
1158 
1159 // Not taking account of output
1160 void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1161 
1162 void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1163   // When calculating 'dx', taking account of 'y'
1164   if (is_parameter_[0]) {
1165     // 'x' is a parameter, so it should be in memory.
1166     is_inputs_should_in_memory_[0] = true;
1167     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1168       // In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be counted here.
1169       is_inputs_should_in_memory_[1] = true;
1170     }
1171   } else if (is_parameter_involve_[0]) {
1172     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1173       is_inputs_should_in_memory_[1] = true;
1174     }
1175   }
1176 
1177   // 'y' is not needed for computing the gradient; keep it in memory only if it is itself a parameter
1178   if (!is_inputs_should_in_memory_[1]) {
1179     is_inputs_should_in_memory_[1] = is_parameter_[1];
1180   }
1181 }
1182 
1183 double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1184                                                  const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
1185   double result = 0.0;
1186   TensorInfo input0 = inputs[0];
1187   TensorInfo output0 = outputs[0];
1188   std::vector<int64_t> dim_list = input0.reduce_dim();
1189   Shape input0_slice_shape = input0.slice_shape();
1190   Shape input0_shape = input0.shape();
1191   if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1192     auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1193       return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1194     });
1195     if (pos != dim_list.end()) {
1196       result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]) * 2.0;
1197     }
1198   }
1199   result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1200 
1201   return result;
1202 }
1203 
1204 void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1205 
1206 void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1207   // When calculating 'dx', taking account of 'y'
1208   if (is_parameter_[0]) {
1209     // 'x' is a parameter, so it should be in memory.
1210     is_inputs_should_in_memory_[0] = true;
1211     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1212       // In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be counted here.
1213       is_inputs_should_in_memory_[1] = true;
1214     }
1215   } else if (is_parameter_involve_[0]) {
1216     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1217       is_inputs_should_in_memory_[1] = true;
1218     }
1219   }
1220 
1221   // 'y' is not needed for computing the gradient; keep it in memory only if it is itself a parameter
1222   if (!is_inputs_should_in_memory_[1]) {
1223     is_inputs_should_in_memory_[1] = is_parameter_[1];
1224   }
1225 }
1226 
1227 void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1228 
1229 void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1230   // When calculating 'dx', taking account of 'x'
1231   if (is_parameter_[0]) {
1232     is_inputs_should_in_memory_[0] = true;
1233   } else if (is_parameter_involve_[0]) {
1234     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1235       is_inputs_should_in_memory_[0] = true;
1236     }
1237   }
1238 }
1239 
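// Forward computation of DropOut: the input slice size scaled by DROPOUT_COST_RATE.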
1240 double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1241                                               int64_t) const {
1242   if (inputs.empty()) {
1243     return 0.0;
1244   }
1245   TensorInfo input0 = inputs[0];
1246   Shape input0_slice_shape = input0.slice_shape();
1247   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
1248 }
1249 
1250 // return the per device communication cost in the forward phase.
1251 double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1252                                         int64_t) const {
1253   // GatherV2Cost does not need communication in the forward phase
1254   return 0.0;
1255 }
1256 
1257 // return the per device communication cost in the backward phase.
1258 double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1259                                          int64_t stage_id) const {
1260   double result = 0.0;
1261   CheckGlobalDeviceManager();
1262   MS_EXCEPTION_IF_NULL(g_device_manager);
1263   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1264 
1265   for (size_t j = 0; j < inputs.size(); ++j) {
1266     if (!is_parameter_[j]) {
1267       continue;
1268     }
1269     TensorInfo input_a_tensor_info = inputs[j];
1270     Shape input_a_shape = input_a_tensor_info.shape();
1271     Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1272     int64_t used_device_num = 1;
1273     for (size_t i = 0; i < input_a_shape.size(); ++i) {
1274       used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1275     }
1276     if (total_device_num != LongToSize(used_device_num)) {
1277       result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1278     }
1279   }
1280 
1281   return result;
1282 }
1283 
1284 double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1285                                                int64_t) const {
1286   // In forward phase, the computation cost = slice(A) + slice(B)
1287   Shape input0_slice_shape = inputs[0].slice_shape();
1288   Shape input1_slice_shape = inputs[1].slice_shape();
1289   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1290                   ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
1291   return result;
1292 }
1293 
1294 double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1295                                                 int64_t) const {
1296   return 0.0;
1297 }
1298 
1299 // Not taking account of output
1300 void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1301 
1302 void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1303   // When calculating 'dx', taking account of 'y' and 'z'
1304   if (is_parameter_[0]) {
1305     // 'x' is a parameter, so it should be in memory.
1306     is_inputs_should_in_memory_[0] = true;
1307     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1308       is_inputs_should_in_memory_[1] = true;
1309     }
1310     if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1311         (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
1312       is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
1313     }
1314   } else if (is_parameter_involve_[0]) {
1315     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1316       is_inputs_should_in_memory_[1] = true;
1317     }
1318     if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1319         (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
1320       is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
1321     }
1322   }
1323 
1324   if (!is_inputs_should_in_memory_[1]) {
1325     is_inputs_should_in_memory_[1] = is_parameter_[1];
1326   }
1327   if (!is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1]) {
1328     is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = is_parameter_[GATHERV2_INPUTS_SIZE - 1];
1329   }
1330 }
1331 
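// GetNext only feeds data into the graph, so its output never needs to be kept for the backward pass.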
1332 void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1333 
1334 void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1335   if (is_inputs_should_in_memory_.size() == 0) {
1336     return;
1337   }
1338   is_inputs_should_in_memory_[0] = is_parameter_[0];
1339 }
1340 
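// Forward computation of DSDMatmul: approximated by the sum of all input slice sizes.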
1341 double DSDMatmulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1342                                                 int64_t) const {
1343   double result = 0.0;
1344   if (inputs_type_lengths_.size() != inputs.size()) {
1345     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for DSDMatmul cost";
1346   }
1347 
1348   for (size_t index = 0; index < inputs.size(); ++index) {
1349     TensorInfo tensor_info = inputs[index];
1350     Shape slice_shape = tensor_info.slice_shape();
1351     result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1352   }
1353   return result;
1354 }
1355 
1356 void DSDMatmulCost::CalculateOutputInMemory() {
1357   is_output_should_in_memory_ =
1358     (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1359 }
1360 
1361 void DSDMatmulCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1362   bool keep_mem =
1363     (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
1364     (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1365   std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
1366 }
1367 
1368 void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1369 
1370 void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1371   is_inputs_should_in_memory_[0] = is_parameter_[0];
1372 }
1373 
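// Backward communication of LayerNorm: every parameter input whose sharding does not occupy
// all devices in the stage needs its gradient aggregated, costing one slice per such input.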
1374 double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1375                                           int64_t stage_id) const {
1376   double result = 0.0;
1377   if (is_parameter_.size() != inputs.size()) {
1378     MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost";
1379   }
1380   if (inputs_type_lengths_.size() != inputs.size()) {
1381     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
1382   }
1383 
1384   MS_EXCEPTION_IF_NULL(g_device_manager);
1385   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1386 
1387   for (size_t index = 0; index < inputs.size(); ++index) {
1388     if (is_parameter_[index]) {
1389       TensorInfo tensor_info = inputs[index];
1390       Shape shape = tensor_info.shape();
1391       Shape slice_shape = tensor_info.slice_shape();
1392       int64_t used_device_num = 1;
1393       for (size_t i = 0; i < shape.size(); ++i) {
1394         if (slice_shape[i] == 0) {
1395           MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape);
1396         }
1397         used_device_num *= shape[i] / slice_shape[i];
1398       }
1399       if (total_device_num != LongToSize(used_device_num)) {
1400         result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1401       }
1402     }
1403   }
1404   return result;
1405 }
1406 
1407 double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1408                                                 int64_t) const {
1409   double result = 0.0;
1410   if (inputs_type_lengths_.size() != inputs.size()) {
1411     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
1412   }
1413 
1414   for (size_t index = 0; index < inputs.size(); ++index) {
1415     TensorInfo tensor_info = inputs[index];
1416     Shape slice_shape = tensor_info.slice_shape();
1417     result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1418   }
1419   return result;
1420 }
1421 
1422 void LayerNormCost::CalculateOutputInMemory() {
1423   is_output_should_in_memory_ =
1424     is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[LAYERNORM_INPUTS_SIZE - 1];
1425 }
1426 
1427 void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1428   // When calculating 'dx', taking account of both 'x' and 'y'
1429   // When calculating 'dy', taking account of both 'x' and 'y'
1430   if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
1431     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1432       is_inputs_should_in_memory_[0] = true;
1433     }
1434     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1435       is_inputs_should_in_memory_[1] = true;
1436     }
1437   }
1438   is_inputs_should_in_memory_[LAYERNORM_INPUTS_SIZE - 1] = is_parameter_[LAYERNORM_INPUTS_SIZE - 1];
1439 }
1440 
1441 double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
1442   return 0.0;
1443 }
1444 double UniqueCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1445                                        int64_t stage_id) const {
1446   double result = 0.0;
1447   if (is_parameter_[0]) {
1448     TensorInfo input = inputs[0];
1449     CheckGlobalDeviceManager();
1450     MS_EXCEPTION_IF_NULL(g_device_manager);
1451     auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1452     Shape input_shape = input.shape();
1453     Shape input_slice_shape = input.slice_shape();
1454     int64_t used_device_num = 1;
1455     for (size_t i = 0; i < input_shape.size(); ++i) {
1456       used_device_num *= input_shape[i] / input_slice_shape[i];
1457     }
1458     if (total_device_num != LongToSize(used_device_num)) {
1459       result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1460     }
1461   }
1462   return result;
1463 }
1464 double UniqueCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1465                                              int64_t) const {
1466   // In forward phase, the computation cost = slice(A)
1467   Shape input_slice_shape = inputs[0].slice_shape();
1468   double result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1469   return result;
1470 }
1471 double UniqueCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1472                                               int64_t stage_id) const {
1473   // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
1474   double result = 0.0;
1475   if (is_parameter_[0]) {
1476     TensorInfo input = inputs[0];  // tensor B
1477     CheckGlobalDeviceManager();
1478     MS_EXCEPTION_IF_NULL(g_device_manager);
1479     auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1480 
1481     Shape input_shape = input.shape();
1482     Shape input_slice_shape = input.slice_shape();
1483     int64_t used_device_num = 1;
1484     for (size_t i = 0; i < input_shape.size(); ++i) {
1485       used_device_num *= input_shape[i] / input_slice_shape[i];
1486     }
1487 
1488     if (total_device_num != LongToSize(used_device_num)) {
1489       result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1490     }
1491   }
1492   return result;
1493 }
1494 
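// Forward communication of GatherV2P: free when the gather axis is not split; otherwise the
// partial results are combined with a ReduceScatter whose shape is the index slice, extended by
// the non-axis dimension of the parameter slice for 2-D parameters.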
1495 double GatherV2PCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1496                                          int64_t) const {
1497   double result = 0.0;
1498   if (outputs_type_lengths_.size() != outputs.size()) {
1499     MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for gatherv2 cost";
1500   }
1501   // don't split axis
1502   if (strategy_.at(LongToSize(axis_)) == 1) {
1503     return result;
1504   }
1505 
1506   // split axis
1507   auto param_shape = inputs[0].slice_shape();
1508   auto index_shape = inputs[1].slice_shape();
1509   Shape reducescatter_shape = index_shape;
1510   if (param_shape.size() == 2) {
1511     reducescatter_shape.push_back(param_shape.at(LongToSize(1 - axis_)));
1512   }
1513   result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]);
1514   return result;
1515 }
1516 
1517 double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1518                                           int64_t stage_id) const {
1519   double result = 0.0;
1520   CheckGlobalDeviceManager();
1521   MS_EXCEPTION_IF_NULL(g_device_manager);
1522   auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1523 
1524   for (size_t j = 0; j < inputs.size(); ++j) {
1525     if (!is_parameter_[j]) {
1526       continue;
1527     }
1528     TensorInfo input_a_tensor_info = inputs[j];
1529     Shape input_a_shape = input_a_tensor_info.shape();
1530     Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1531     int64_t used_device_num = 1;
1532     for (size_t i = 0; i < input_a_shape.size(); ++i) {
1533       used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1534     }
1535     if (total_device_num != LongToSize(used_device_num)) {
1536       result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1537     }
1538   }
1539   return result;
1540 }
1541 
1542 double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1543                                                               const std::vector<TensorInfo> &, int64_t) const {
1544   Shape input0_slice_shape = inputs[0].slice_shape();
1545   if (inputs_type_lengths_.size() != inputs.size()) {
1546     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1547                       << " for UniformCandidateSampler cost";
1548   }
1549 
1550   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1551 
1552   return result;
1553 }
1554 
1555 void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1556 
1557 void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1558   is_inputs_should_in_memory_[0] = is_parameter_[0];
1559 }
1560 
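// Forward computation of GatherV2P: the two input slices, weighted by GATHERV2_COST_WEIGHT0/1
// when the gather axis is split.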
1561 double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1562                                                 int64_t) const {
1563   double result = 0.0;
1564   Shape input0_slice_shape = inputs[0].slice_shape();
1565   Shape input1_slice_shape = inputs[1].slice_shape();
1566   if (inputs_type_lengths_.size() != inputs.size()) {
1567     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
1568   }
1569   // don't split axis
1570   if (strategy_.at(LongToSize(axis_)) == 1) {
1571     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1572               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
1573   } else {
1574     // split axis
1575     result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 +
1576               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1;
1577   }
1578 
1579   return result;
1580 }
1581 
1582 double GatherV2PCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
1583                                                  const std::vector<TensorInfo> &outputs, int64_t) const {
1584   double result = 0.0;
1585   Shape input1_slice_shape = inputs[1].slice_shape();
1586   Shape output0_slice_shape = outputs[0].slice_shape();
1587   // don't split axis
1588   if (strategy_.at(LongToSize(axis_)) == 1) {
1589     result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1590   } else {
1591     // split axis
1592     result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 +
1593               ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3;
1594   }
1595 
1596   return result;
1597 }
1598 
1599 // The forward communication is determined by whether the slice is a column split or a row split.
1600 // The number of segments is the output's shape[0], which determines the size of the AllReduce.
1601 double UnsortedSegmentSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1602                                                   const std::vector<TensorInfo> &outputs, int64_t) const {
1603   TensorInfo input0 = inputs[0];
1604   TensorInfo input1 = inputs[1];
1605   TensorInfo output0 = outputs[0];
1606   Shape input0_shape = input0.shape();
1607   Shape input0_slice_shape = inputs[0].slice_shape();
1608   double result = 0.0;
1609   if (inputs_type_lengths_.size() != inputs.size()) {
1610     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
1611   }
1612   // If shape b is not the same as shape a, we regard it as a column slice
1613   for (size_t i = 0; i < input1.shape().size(); ++i) {
1614     if (input0_shape[i] != input0_slice_shape[i]) {
1615       result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1616       return result;
1617     }
1618   }
1619   return result;
1620 }
1621 
1622 double UnsortedSegmentSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
1623                                                    const std::vector<TensorInfo> &outputs, int64_t) const {
1624   TensorInfo input0 = inputs[0];
1625   TensorInfo input1 = inputs[1];
1626   TensorInfo output0 = outputs[0];
1627   Shape input0_shape = input0.shape();
1628   Shape input0_slice_shape = inputs[0].slice_shape();
1629   double result = 0.0;
1630   if (inputs_type_lengths_.size() != inputs.size()) {
1631     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
1632   }
1633   if (is_parameter_[0]) {
1634     // If the forward process has an AllReduce, then the backward also needs one.
1635     for (size_t i = 0; i < input1.shape().size(); ++i) {
1636       if (input0_shape[i] != input0_slice_shape[i]) {
1637         result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1638         return result;
1639       }
1640     }
1641   }
1642   return result;
1643 }
1644 double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1645                                                          const std::vector<TensorInfo> &outputs, int64_t) const {
1646   // In forward phase, the computation cost = slice(A) + slice(B) + slice(output)
1647   Shape input0_slice_shape = inputs[0].slice_shape();
1648   Shape input1_slice_shape = inputs[1].slice_shape();
1649   Shape output_slice_shape = outputs[0].slice_shape();
1650   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1651                   ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1652                   ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
1653   return result;
1654 }
1655 
1656 // Not taking account of output
1657 void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1658 
1659 // Taking account of input
1660 void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1661   // When calculating 'dx', taking account of 'y'
1662   if (is_parameter_[0]) {
1663     is_inputs_should_in_memory_[0] = true;
1664     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1665       is_inputs_should_in_memory_[1] = true;
1666     }
1667   } else if (is_parameter_involve_[0]) {
1668     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1669       is_inputs_should_in_memory_[1] = true;
1670     }
1671   }
1672 
1673   if (!is_inputs_should_in_memory_[1]) {
1674     is_inputs_should_in_memory_[1] = is_parameter_[1];
1675   }
1676   is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
1677 }
1678 
1679 double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1680                                                   const std::vector<TensorInfo> &outputs, int64_t) const {
1681   TensorInfo input0 = inputs[0];
1682   TensorInfo input1 = inputs[1];
1683   TensorInfo output0 = outputs[0];
1684   Shape input0_shape = input0.shape();
1685   Shape input0_slice_shape = inputs[0].slice_shape();
1686   double result = 0.0;
1687   if (inputs_type_lengths_.size() != inputs.size()) {
1688     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1689                       << " for UnsortedSegmentMinCost cost";
1690   }
1691   // If shape b is not the same as shape a, we regard it as a column slice
1692   // The cost is an AllGather operation whose shape is the same as the output of UnsortedSegmentMin.
1693   for (size_t i = 0; i < input1.shape().size(); ++i) {
1694     if (input0_shape[i] != input0_slice_shape[i]) {
1695       result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1696       return result;
1697     }
1698   }
1699   return result;
1700 }
1701 
1702 double UnsortedSegmentMinCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
1703                                                    const std::vector<TensorInfo> &outputs, int64_t) const {
1704   TensorInfo input0 = inputs[0];
1705   TensorInfo input1 = inputs[1];
1706   TensorInfo output0 = outputs[0];
1707   Shape input0_shape = input0.shape();
1708   Shape input0_slice_shape = inputs[0].slice_shape();
1709   double result = 0.0;
1710   if (inputs_type_lengths_.size() != inputs.size()) {
1711     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1712                       << " for UnsortedSegmentMinCost cost";
1713   }
1714   if (is_parameter_[0]) {
1715     // If the forward process has an AllGather, then the backward also needs one ReduceScatter.
1716     for (size_t i = 0; i < input1.shape().size(); ++i) {
1717       if (input0_shape[i] != input0_slice_shape[i]) {
1718         result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1719         return result;
1720       }
1721     }
1722   }
1723   return result;
1724 }
1725 double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1726                                                          const std::vector<TensorInfo> &outputs, int64_t) const {
1727   // In forward phase, the computation cost = slice(A) + slice(B)
1728   Shape input0_slice_shape = inputs[0].slice_shape();
1729   Shape input1_slice_shape = inputs[1].slice_shape();
1730   Shape output_slice_shape = outputs[0].slice_shape();
1731   // The forward operation is UnsortedSegmentMin + ReduceMin
1732   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1733                   ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1734                   ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]) +
1735                   ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);  // ReduceMin
1736   return result;
1737 }
1738 
1739 // Taking account of output
1740 void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1741 
1742 // Taking account of input
1743 void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1744   // When calculating 'dx', taking account of 'x', 'y' and 'z'
1745   if (is_parameter_involve_[0]) {
1746     if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1747       is_inputs_should_in_memory_[0] = true;
1748     }
1749     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1750       is_inputs_should_in_memory_[1] = true;
1751     }
1752     if ((prev_output_in_mem.find(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1753         (!prev_output_in_mem.at(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1))) {
1754       is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = true;
1755     }
1756   }
1757   if (!is_inputs_should_in_memory_[1]) {
1758     is_inputs_should_in_memory_[1] = is_parameter_[1];
1759   }
1760   if (!is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1]) {
1761     is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
1762   }
1763 }
1764 
1765 // Not taking account of output
1766 void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1767 
1768 // Not taking account of input
1769 void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1770   for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
1771     is_inputs_should_in_memory_[i] = is_parameter_[i];
1772   }
1773 }
1774 
1775 double MatmulDDSCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1776                                                 int64_t) const {
1777   double result = 0.0;
1778   if (inputs_type_lengths_.size() != inputs.size()) {
1779     MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for MatmulDDS cost";
1780   }
1781 
1782   for (size_t index = 0; index < inputs.size(); ++index) {
1783     TensorInfo tensor_info = inputs[index];
1784     Shape slice_shape = tensor_info.slice_shape();
1785     result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1786   }
1787   return result;
1788 }
1789 
1790 // Taking account of output when any parameter is involved
1791 void MatmulDDSCost::CalculateOutputInMemory() {
1792   is_output_should_in_memory_ =
1793     (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1794 }
1795 
1796 // Taking account of input
1797 void MatmulDDSCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1798   bool keep_mem =
1799     (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
1800     (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1801   std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
1802 }
1803 }  // namespace parallel
1804 }  // namespace mindspore
1805