/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/operator_costmodel.h"

#include <algorithm>
#include <random>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_parameter_ = is_parameter; }

void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
  is_parameter_involve_ = is_parameter_inv;
  is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
}

void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }

void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                               const std::vector<size_t> &output_lengths) {
  inputs_type_lengths_ = input_lengths;
  outputs_type_lengths_ = output_lengths;
}

void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_ = critical; }

double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
                                   const std::vector<TensorInfo> &outputs) const {
  return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
}

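// Return the accumulated size (in bytes) of the input slices that are marked as needing to reside in memory.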
double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
  double result = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_inputs_should_in_memory_[i]) {
      result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
    }
  }
  return result;
}

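// Return the accumulated size (in bytes) of the output slices when the output is marked as needing to reside in memory.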
double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &,
                                         const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_output_should_in_memory_) {
    // When this operator has multiple outputs, they all contribute to the memory.
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}

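// In the inference phase, only the outputs of operators marked as critical contribute to the memory cost.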
double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
                                               const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_outputs_critical_ == -1) {
    MS_LOG(EXCEPTION) << "The critical flag is not set.";
  }
  if (is_outputs_critical_ == 1) {
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}

// return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                      int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = input0.slice_shape();
  if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has not been partitioned, then there is no communication cost.
    return 0.0;
  } else {
    // Else, the communication cost is the size (number of bytes) of a slice of the output tensor.
    return ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
  }
}

// return the per device communication cost in the backward phase.
double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                       int64_t stage_id) const {
  // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not
  // fully utilize all devices
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                             const std::vector<TensorInfo> &outputs, int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
  double result = 0.0;
  TensorInfo output0 = outputs[0];
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  Shape input0_shape = inputs[0].shape();
  if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has been partitioned, the partial results of the sliced MatMul must be combined by an
    // AllReduce, so the size of a slice of the output tensor is added to the cost.
    result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
  }
  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
            ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Not taking account of output
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// return the per device communication cost in the forward phase.
double BatchNormCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                         int64_t) const {
  TensorInfo input0 = inputs[0];
  Shape input0_shape = input0.shape();
  if (input0_shape.size() < 2) {
    MS_LOG(EXCEPTION) << "The dimension of the first input cannot be smaller than 2, but got " << input0_shape.size();
  }
  Shape input0_slice_shape = input0.slice_shape();
  if (input0_shape[1] == input0_slice_shape[1]) {
    // If the 'channel' dimension has not been partitioned, then there is no communication cost.
    return 0.0;
  } else {
    // Else, the communication cost is the size (number of bytes) of a slice of the input tensor.
    return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
  }
}

// return the per device communication cost in the backward phase.
double BatchNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
  // In backward phase, the communication cost is incurred only when the last four tensors are Parameters and they do
  // not fully utilize all devices
  double result = 0.0, tmp_cost = 0.0;

  TensorInfo input1 = inputs[1];  // tensor gamma
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  Shape input1_shape = input1.shape();
  Shape input1_slice_shape = input1.slice_shape();
  int64_t used_device_num = 1;
  for (size_t i = 0; i < input1_shape.size(); ++i) {
    used_device_num *= input1_shape[i] / input1_slice_shape[i];
  }

  if (total_device_num != LongToSize(used_device_num)) {
    tmp_cost = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  }

  for (size_t i = 1; i < is_parameter_.size(); ++i) {
    if (is_parameter_[i]) {
      result += tmp_cost;
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double BatchNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
  }
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double BatchNormCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                                 int64_t) const {
  return 0.0;
}

// Taking account of output
void BatchNormCost::CalculateOutputInMemory() { is_output_should_in_memory_ = true; }

// Taking account of input
void BatchNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
    is_inputs_should_in_memory_[i] = true;
  }
}

// Return the per device communication cost in the forward phase.
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // Cast is an element-wise operator, thus it does not need communication in the forward phase
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                     int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t) const {
  TensorInfo input0 = inputs[0];
  Shape input0_slice_shape = input0.slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                            int64_t) const {
  return 0.0;
}

// Not taking account of output
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Taking account of output
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Taking account of input
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Return the per device communication cost in the forward phase.
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
  // In the forward phase, the communication cost = 0
  return 0.0;
}

// Return the per device communication cost in the backward phase.
double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
                                              int64_t) const {
  if (outputs.empty() || outputs_type_lengths_.empty()) {
    MS_LOG(EXCEPTION) << "The outputs or outputs_type_length is empty";
  }

  // use output for Tile operator
  TensorInfo output_info = outputs[0];
  Shape output_slice_shape = output_info.slice_shape();
  return ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

// Taking account of output
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Not taking account of input
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Not taking account of output
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Not taking account of output
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

// Not taking account of output
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

// Taking account of input
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

// return the per device communication cost in the forward phase.
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                           const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  // Identity is an element-wise operator, thus it does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                            const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  // Identity is an element-wise operator, thus it does not need communication in the backward phase
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                  const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                   const std::vector<mindspore::parallel::TensorInfo> &,
                                                   int64_t) const {
  return 0.0;
}

// Not taking account of output
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

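// Return the per device computation cost in the forward phase. The cost is the sum of the sizes (in bytes) of all
// input slices.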
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                                    const std::vector<mindspore::parallel::TensorInfo> &,
                                                    int64_t) const {
  double cost = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    cost += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
  }
  return cost;
}

double BatchParallelCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                     const std::vector<mindspore::parallel::TensorInfo> &,
                                                     int64_t) const {
  return 0.0;
}

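// Return the per device communication cost in the backward phase. Each Parameter input that does not fully utilize
// all devices in the stage incurs the size (in bytes) of its slice.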
double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  for (size_t j = 0; j < inputs.size(); ++j) {
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  return result;
}

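// Not taking account of output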
void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

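// Taking account of input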
void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

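// Taking account of output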
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}

void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // prelu does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                      int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                            int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                             const std::vector<mindspore::parallel::TensorInfo> &,
                                             int64_t stage_id) const {
  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[1]) {
    TensorInfo input1 = inputs[1];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

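// Not taking account of output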
void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y';
  // when calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

// return the per device communication cost in the forward phase.
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // onehot does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
  // onehot does not need communication in the backward phase
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                             int64_t) const {
  // In onehot's forward phase, the computation cost = slice(A)
  Shape input0_slice_shape = inputs[0].slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                              int64_t) const {
  return 0.0;
}

// Not taking account of output
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
  is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 2] = is_parameter_[ONEHOT_INPUTS_SIZE - 2];
  is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 1] = is_parameter_[ONEHOT_INPUTS_SIZE - 1];
}

// return the per device communication cost in the forward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
                                                             const std::vector<TensorInfo> &, int64_t) const {
  // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<TensorInfo> &,
                                                              const std::vector<TensorInfo> &, int64_t) const {
  // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase
  return 0.0;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                                    const std::vector<TensorInfo> &, int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
                                                                     const std::vector<TensorInfo> &, int64_t) const {
  return 0.0;
}

// Taking account of output
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}

void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// return the per device communication cost in the forward phase.
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                       int64_t stage_id) const {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  TensorRedistribution tensor_redistribution(false, true);
  if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }
  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }
  return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost());
}

// return the per device communication cost in the backward phase.
double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input1_shape = input1.shape();
    Shape input1_slice_shape = input1.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input1_shape.size(); ++i) {
      used_device_num *= input1_shape[i] / input1_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                              const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  TensorRedistribution tensor_redistribution(false, true);
  if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }
  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }
  return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
}

// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
  return 0.0;
}

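// Not taking account of output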
void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

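// Return the per device computation cost in the forward phase. The cost is the sum of the sizes (in bytes) of the
// two input slices.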
double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t) const {
  double result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

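// Return the per device computation cost in the backward phase. Each Parameter input that does not fully utilize all
// devices contributes the size (in bytes) of its slice.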
double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  if (is_parameter_[0]) {
    TensorInfo input_a_tensor_info = inputs[0];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  if (is_parameter_[1]) {
    TensorInfo input_b_tensor_info = inputs[1];
    Shape input_b_shape = input_b_tensor_info.shape();
    Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_b_shape.size(); ++i) {
      used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }
  return result;
}

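// Return the per device communication cost in the backward phase. Each Parameter input that does not fully utilize
// all devices incurs the size (in bytes) of its slice.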
double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                    int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  if (is_parameter_[0]) {
    TensorInfo input_a_tensor_info = inputs[0];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  if (is_parameter_[1]) {
    TensorInfo input_b_tensor_info = inputs[1];
    Shape input_b_shape = input_b_tensor_info.shape();
    Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_b_shape.size(); ++i) {
      used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
    }
  }

  return result;
}

// Not taking account of output
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

// Taking account of input
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not kept in memory by the previous operator, then 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

// Taking account of output
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

// Taking account of input
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not kept in memory by the previous operator, then 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  // When calculating 'dy', taking account of 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

// Taking account of input
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  // When calculating 'dy', taking account of 'x' and 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

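// Taking account of output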
void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'power'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dpower', taking account of 'x'
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

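// Taking account of input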
void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
  // When calculating 'dy', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}

void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dy', not taking account of 'x' and 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
  // 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not kept in memory by the previous operator, then 'y' should be included here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  // When calculating 'dy', taking account of 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y';
  // when calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}

void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y' and 'z'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1]) {
    is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = is_parameter_[SLICE_INPUTS_SIZE - 1];
  }
}

void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y', 'z' and 'w'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
      is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
    }
    if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
      is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
    }
    if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2]) {
    is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 2];
  }
  if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1]) {
    is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 1];
  }
}

CalculateOutputInMemory()1145 void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1146
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1147 void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1148 // When calculating 'dx', taking account of 'y'
1149 if (is_parameter_[0]) {
1150 // 'x' is parameter, so it should be in memory.
1151 is_inputs_should_in_memory_[0] = true;
1152 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1153 // In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
1154 is_inputs_should_in_memory_[1] = true;
1155 }
1156 } else if (is_parameter_involve_[0]) {
1157 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1158 is_inputs_should_in_memory_[1] = true;
1159 }
1160 }
1161
1162 if (!is_inputs_should_in_memory_[1]) {
1163 is_inputs_should_in_memory_[1] = is_parameter_[1];
1164 }
1165 is_inputs_should_in_memory_[DROPOUTDOMASK_INPUTS_SIZE - 1] = is_parameter_[DROPOUTDOMASK_INPUTS_SIZE - 1];
1166 }
1167
IsDataParallel(const Shape & shape,const Shape & slice_shape,int64_t stage_id)1168 bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
1169 CheckGlobalDeviceManager();
1170 MS_EXCEPTION_IF_NULL(g_device_manager);
1171 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1172 auto strategy0 = shape[0] / slice_shape[0];
1173
1174 return (total_device_num == LongToSize(strategy0));
1175 }
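// Illustrative example (hypothetical numbers): with shape = {64, 32} and slice_shape = {8, 32},
// strategy0 = 64 / 8 = 8; if the stage holds exactly 8 devices, the layout is treated as pure
// data parallelism.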
1176
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id) const1177 double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1178 int64_t stage_id) const {
1179 double result = 0.0;
1180 TensorInfo input0 = inputs[0];
1181 TensorInfo output0 = outputs[0];
1182 Shape input0_shape = input0.shape();
1183 Shape input0_slice_shape = input0.slice_shape();
1184 if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1185 return result;
1186 }
1187 std::vector<int64_t> dim_list = input0.reduce_dim();
1188 auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1189 return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1190 });
1191 if (pos != dim_list.end()) {
1192 result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1193 }
1194
1195 return result;
1196 }
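// Illustrative example (hypothetical numbers): an input of shape {64, 32} sliced to {64, 8} with
// reduce_dim = {1} has a split reduced dimension, so the forward phase adds an AllReduce whose
// volume is the output slice size times outputs_type_lengths_[0]; if no reduced dimension is
// split, the forward communication cost stays 0.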
1197
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1198 double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1199 int64_t stage_id) const {
1200 double result = 0.0;
1201 if (is_parameter_[0]) {
1202 TensorInfo input_tensor_info = inputs[0];
1203 CheckGlobalDeviceManager();
1204 MS_EXCEPTION_IF_NULL(g_device_manager);
1205 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1206
1207 Shape input_shape = input_tensor_info.shape();
1208 Shape input_slice_shape = input_tensor_info.slice_shape();
1209 int64_t used_device_num = 1;
1210 for (size_t i = 0; i < input_shape.size(); ++i) {
1211 used_device_num *= input_shape[i] / input_slice_shape[i];
1212 }
1213
1214 if (total_device_num != LongToSize(used_device_num)) {
1215 result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1216 }
1217 }
1218
1219 return result;
1220 }
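// Illustrative example (hypothetical numbers): a parameter of shape {64, 32} sliced to {32, 32}
// occupies 64/32 * 32/32 = 2 devices; in a stage of 8 devices the gradient must also be
// synchronized across replicas, so one parameter-slice worth of communication is added.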
1221
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id) const1222 double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1223 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
1224 double result = 0.0;
1225 TensorInfo input0 = inputs[0];
1226 TensorInfo output0 = outputs[0];
1227 std::vector<int64_t> dim_list = input0.reduce_dim();
1228 Shape input0_slice_shape = input0.slice_shape();
1229 Shape input0_shape = input0.shape();
1230 if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1231 auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1232 return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1233 });
1234 if (pos != dim_list.end()) {
1235 result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1236 }
1237 }
1238 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1239
1240 return result;
1241 }
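// Illustrative example (hypothetical numbers): an input slice of {32, 8} in float32 (type length
// 4) contributes 32 * 8 * 4 = 1024 to the computation cost; the output-slice term above is added
// only when a reduced dimension is split.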
1242
1243 // Not taking account of output
CalculateOutputInMemory()1244 void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1245
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1246 void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1247 // When calculating 'dx', taking account of 'y'
1248 if (is_parameter_[0]) {
1249 // 'x' is parameter, so it should be in memory.
1250 is_inputs_should_in_memory_[0] = true;
1251 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1252 // In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
1253 is_inputs_should_in_memory_[1] = true;
1254 }
1255 } else if (is_parameter_involve_[0]) {
1256 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1257 is_inputs_should_in_memory_[1] = true;
1258 }
1259 }
1260
1261 // Otherwise, 'y' is kept in memory only if it is a parameter
1262 if (!is_inputs_should_in_memory_[1]) {
1263 is_inputs_should_in_memory_[1] = is_parameter_[1];
1264 }
1265 }
1266
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id) const1267 double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1268 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
1269 double result = 0.0;
1270 TensorInfo input0 = inputs[0];
1271 TensorInfo output0 = outputs[0];
1272 std::vector<int64_t> dim_list = input0.reduce_dim();
1273 Shape input0_slice_shape = input0.slice_shape();
1274 Shape input0_shape = input0.shape();
1275 if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1276 auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1277 return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1278 });
1279 if (pos != dim_list.end()) {
1280 result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]) * 2.0;
1281 }
1282 }
1283 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1284
1285 return result;
1286 }
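// Note: the only difference from ReduceSumCost::GetForwardComputationCost is the 2.0 factor on
// the output term, which appears to account for the extra division that ReduceMean performs
// after the summation.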
1287
CalculateOutputInMemory()1288 void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1289
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1290 void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1291 // When calculating 'dx', taking account of 'y'
1292 if (is_parameter_[0]) {
1293 // 'x' is parameter, so it should be in memory.
1294 is_inputs_should_in_memory_[0] = true;
1295 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1296 // In this case, if 'y' is not already kept in memory by the previous operator, then 'y' should be included here.
1297 is_inputs_should_in_memory_[1] = true;
1298 }
1299 } else if (is_parameter_involve_[0]) {
1300 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1301 is_inputs_should_in_memory_[1] = true;
1302 }
1303 }
1304
1305 // Otherwise, 'y' is kept in memory only if it is a parameter
1306 if (!is_inputs_should_in_memory_[1]) {
1307 is_inputs_should_in_memory_[1] = is_parameter_[1];
1308 }
1309 }
1310
CalculateOutputInMemory()1311 void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1312
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1313 void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1314 // When calculating 'dx', taking account of 'x'
1315 if (is_parameter_[0]) {
1316 is_inputs_should_in_memory_[0] = true;
1317 } else if (is_parameter_involve_[0]) {
1318 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1319 is_inputs_should_in_memory_[0] = true;
1320 }
1321 }
1322 }
1323
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1324 double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1325 int64_t) const {
1326 if (inputs.empty()) {
1327 return 0.0;
1328 }
1329 TensorInfo input0 = inputs[0];
1330 Shape input0_slice_shape = input0.slice_shape();
1331 return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
1332 }
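// Illustrative example (hypothetical numbers): an input slice of {128, 64} in float16 (type
// length 2) costs 128 * 64 * 2 * DROPOUT_COST_RATE.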
1333
1334 // return the per device communication cost in the forward phase.
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t) const1335 double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1336 int64_t) const {
1337 // GatherV2Cost does not need communication in the forward phase
1338 return 0.0;
1339 }
1340
1341 // return the per device communication cost in the backward phase.
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1342 double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1343 int64_t stage_id) const {
1344 double result = 0.0;
1345 CheckGlobalDeviceManager();
1346 MS_EXCEPTION_IF_NULL(g_device_manager);
1347 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1348
1349 for (size_t j = 0; j < inputs.size(); ++j) {
1350 if (!is_parameter_[j]) {
1351 continue;
1352 }
1353 TensorInfo input_a_tensor_info = inputs[j];
1354 Shape input_a_shape = input_a_tensor_info.shape();
1355 Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1356 int64_t used_device_num = 1;
1357 for (size_t i = 0; i < input_a_shape.size(); ++i) {
1358 used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1359 }
1360 if (total_device_num != LongToSize(used_device_num)) {
1361 result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1362 }
1363 }
1364
1365 return result;
1366 }
1367
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1368 double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1369 int64_t) const {
1370 // In forward phase, the computation cost = slice(A) + slice(B)
1371 Shape input0_slice_shape = inputs[0].slice_shape();
1372 Shape input1_slice_shape = inputs[1].slice_shape();
1373 double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1374 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
1375 return result;
1376 }
1377
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t) const1378 double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1379 int64_t) const {
1380 return 0.0;
1381 }
1382
1383 // Not taking account of output
CalculateOutputInMemory()1384 void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1385
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1386 void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1387 // When calculating 'dx', taking account of 'y' and 'z'
1388 if (is_parameter_[0]) {
1389 // 'x' is parameter, so it should be in memory.
1390 is_inputs_should_in_memory_[0] = true;
1391 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1392 is_inputs_should_in_memory_[1] = true;
1393 }
1394 if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1395 (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
1396 is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
1397 }
1398 } else if (is_parameter_involve_[0]) {
1399 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1400 is_inputs_should_in_memory_[1] = true;
1401 }
1402 if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1403 (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
1404 is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
1405 }
1406 }
1407
1408 if (!is_inputs_should_in_memory_[1]) {
1409 is_inputs_should_in_memory_[1] = is_parameter_[1];
1410 }
1411 if (!is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1]) {
1412 is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = is_parameter_[GATHERV2_INPUTS_SIZE - 1];
1413 }
1414 }
1415
CalculateOutputInMemory()1416 void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1417
CalculateInputsInMemory(const std::map<size_t,bool> &)1418 void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1419 if (is_inputs_should_in_memory_.size() == 0) {
1420 return;
1421 }
1422 is_inputs_should_in_memory_[0] = is_parameter_[0];
1423 }
1424
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1425 double DSDMatmulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1426 int64_t) const {
1427 double result = 0.0;
1428 if (inputs_type_lengths_.size() != inputs.size()) {
1429 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for DSDMatmul cost";
1430 }
1431
1432 for (size_t index = 0; index < inputs.size(); ++index) {
1433 TensorInfo tensor_info = inputs[index];
1434 Shape slice_shape = tensor_info.slice_shape();
1435 result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1436 }
1437 return result;
1438 }
1439
CalculateOutputInMemory()1440 void DSDMatmulCost::CalculateOutputInMemory() {
1441 is_output_should_in_memory_ =
1442 (std::find(is_parameter_involve_.cbegin(), is_parameter_involve_.cend(), true) != is_parameter_involve_.cend());
1443 }
1444
CalculateInputsInMemory(const std::map<size_t,bool> &)1445 void DSDMatmulCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1446 bool keep_mem =
1447 (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
1448 (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1449 std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
1450 }
1451
CalculateOutputInMemory()1452 void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1453
CalculateInputsInMemory(const std::map<size_t,bool> &)1454 void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1455 is_inputs_should_in_memory_[0] = is_parameter_[0];
1456 }
1457
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1458 double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1459 int64_t stage_id) const {
1460 double result = 0.0;
1461 if (is_parameter_.size() != inputs.size()) {
1462 MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost";
1463 }
1464 if (inputs_type_lengths_.size() != inputs.size()) {
1465 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
1466 }
1467
1468 MS_EXCEPTION_IF_NULL(g_device_manager);
1469 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1470
1471 for (size_t index = 0; index < inputs.size(); ++index) {
1472 if (is_parameter_[index]) {
1473 TensorInfo tensor_info = inputs[index];
1474 Shape shape = tensor_info.shape();
1475 Shape slice_shape = tensor_info.slice_shape();
1476 int64_t used_device_num = 1;
1477 for (size_t i = 0; i < shape.size(); ++i) {
1478 if (slice_shape[i] == 0) {
1479 MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape);
1480 }
1481 used_device_num *= shape[i] / slice_shape[i];
1482 }
1483 if (total_device_num != LongToSize(used_device_num)) {
1484 result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1485 }
1486 }
1487 }
1488 return result;
1489 }
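// Illustrative example (hypothetical numbers): if 'gamma' has shape {1024} and is not split
// (slice shape {1024}), used_device_num = 1; with 8 devices in the stage, 1024 * type_length is
// added to the backward communication cost for its gradient.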
1490
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1491 double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1492 int64_t) const {
1493 double result = 0.0;
1494 if (inputs_type_lengths_.size() != inputs.size()) {
1495 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
1496 }
1497
1498 for (size_t index = 0; index < inputs.size(); ++index) {
1499 TensorInfo tensor_info = inputs[index];
1500 Shape slice_shape = tensor_info.slice_shape();
1501 result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1502 }
1503 return result;
1504 }
1505
CalculateOutputInMemory()1506 void LayerNormCost::CalculateOutputInMemory() {
1507 is_output_should_in_memory_ =
1508 is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[LAYERNORM_INPUTS_SIZE - 1];
1509 }
1510
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1511 void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1512 // When calculating 'dx', taking account of both 'x' and 'y'
1513 // When calculating 'dy', taking account of both 'x' and 'y'
1514 if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
1515 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1516 is_inputs_should_in_memory_[0] = true;
1517 }
1518 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1519 is_inputs_should_in_memory_[1] = true;
1520 }
1521 }
1522 is_inputs_should_in_memory_[LAYERNORM_INPUTS_SIZE - 1] = is_parameter_[LAYERNORM_INPUTS_SIZE - 1];
1523 }
1524
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t) const1525 double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
1526 return 0.0;
1527 }
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1528 double UniqueCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1529 int64_t stage_id) const {
1530 double result = 0.0;
1531 if (is_parameter_[0]) {
1532 TensorInfo input = inputs[0];
1533 CheckGlobalDeviceManager();
1534 MS_EXCEPTION_IF_NULL(g_device_manager);
1535 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1536 Shape input_shape = input.shape();
1537 Shape input_slice_shape = input.slice_shape();
1538 int64_t used_device_num = 1;
1539 for (size_t i = 0; i < input_shape.size(); ++i) {
1540 used_device_num *= input_shape[i] / input_slice_shape[i];
1541 }
1542 if (total_device_num != LongToSize(used_device_num)) {
1543 result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1544 }
1545 }
1546 return result;
1547 }
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1548 double UniqueCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1549 int64_t) const {
1550 // In forward phase, the computation cost = slice(A)
1551 Shape input_slice_shape = inputs[0].slice_shape();
1552 double result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1553 return result;
1554 }
GetBackwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1555 double UniqueCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1556 int64_t stage_id) const {
1557 // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
1558 double result = 0.0;
1559 if (is_parameter_[0]) {
1560 TensorInfo input = inputs[0]; // tensor B
1561 CheckGlobalDeviceManager();
1562 MS_EXCEPTION_IF_NULL(g_device_manager);
1563 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1564
1565 Shape input_shape = input.shape();
1566 Shape input_slice_shape = input.slice_shape();
1567 int64_t used_device_num = 1;
1568 for (size_t i = 0; i < input_shape.size(); ++i) {
1569 used_device_num *= input_shape[i] / input_slice_shape[i];
1570 }
1571
1572 if (total_device_num != LongToSize(used_device_num)) {
1573 result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1574 }
1575 }
1576 return result;
1577 }
1578
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1579 double GatherCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1580 int64_t) const {
1581 double result = 0.0;
1582 if (outputs_type_lengths_.size() != outputs.size()) {
1583 MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for gatherv2 cost";
1584 }
1585 // don't split axis
1586 if (strategy_.at(LongToSize(axis_)) == 1) {
1587 return result;
1588 }
1589
1590 // split axis
1591 auto param_shape = inputs[0].slice_shape();
1592 auto index_shape = inputs[1].slice_shape();
1593 Shape reducescatter_shape = index_shape;
1594 if (param_shape.size() == 2) {
1595 reducescatter_shape.push_back(param_shape.at(LongToSize(1 - axis_)));
1596 }
1597 result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]);
1598 return result;
1599 }
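// Illustrative example (hypothetical numbers): with axis_ = 0, a parameter slice of {16, 64} and
// an index slice of {32}, the gather axis is split, so reducescatter_shape = {32, 64} and the
// forward communication cost is 32 * 64 * outputs_type_lengths_[0].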
1600
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1601 double GatherCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1602 int64_t stage_id) const {
1603 double result = 0.0;
1604 CheckGlobalDeviceManager();
1605 MS_EXCEPTION_IF_NULL(g_device_manager);
1606 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1607
1608 for (size_t j = 0; j < inputs.size(); ++j) {
1609 if (!is_parameter_[j]) {
1610 continue;
1611 }
1612 TensorInfo input_a_tensor_info = inputs[j];
1613 Shape input_a_shape = input_a_tensor_info.shape();
1614 Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1615 int64_t used_device_num = 1;
1616 for (size_t i = 0; i < input_a_shape.size(); ++i) {
1617 used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1618 }
1619 if (total_device_num != LongToSize(used_device_num)) {
1620 result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1621 }
1622 }
1623 return result;
1624 }
1625
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1626 double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1627 const std::vector<TensorInfo> &, int64_t) const {
1628 Shape input0_slice_shape = inputs[0].slice_shape();
1629 if (inputs_type_lengths_.size() != inputs.size()) {
1630 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1631 << " for UniformCandidateSampler cost";
1632 }
1633
1634 double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1635
1636 return result;
1637 }
1638
CalculateOutputInMemory()1639 void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1640
CalculateInputsInMemory(const std::map<size_t,bool> &)1641 void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1642 is_inputs_should_in_memory_[0] = is_parameter_[0];
1643 }
1644
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1645 double GatherCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1646 int64_t) const {
1647 double result = 0.0;
1648 Shape input0_slice_shape = inputs[0].slice_shape();
1649 Shape input1_slice_shape = inputs[1].slice_shape();
1650 if (inputs_type_lengths_.size() != inputs.size()) {
1651 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
1652 }
1653 // don't split axis
1654 if (strategy_.at(LongToSize(axis_)) == 1) {
1655 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1656 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
1657 } else {
1658 // split axis
1659 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 +
1660 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1;
1661 }
1662
1663 return result;
1664 }
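// Illustrative note: when the gather axis is split, the two slice terms are scaled by
// GATHERV2_COST_WEIGHT0 and GATHERV2_COST_WEIGHT1, presumably to model the extra local work
// (masking and re-indexing) of the split-axis implementation; otherwise the cost is simply the
// sum of the two input slices.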
1665
GetBackwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1666 double GatherCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
1667 const std::vector<TensorInfo> &outputs, int64_t) const {
1668 double result = 0.0;
1669 Shape input1_slice_shape = inputs[1].slice_shape();
1670 Shape output0_slice_shape = outputs[0].slice_shape();
1671 // don't split axis
1672 if (strategy_.at(LongToSize(axis_)) == 1) {
1673 result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1674 } else {
1675 // split axis
1676 result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 +
1677 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3;
1678 }
1679
1680 return result;
1681 }
1682
1683 // The forward communication is determined by whether the slice is column split or row split
1684 // The number of segments equals shape[0] of the output, so the output slice size determines the AllReduce volume
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1685 double UnsortedSegmentSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1686 const std::vector<TensorInfo> &outputs, int64_t) const {
1687 TensorInfo input0 = inputs[0];
1688 TensorInfo input1 = inputs[1];
1689 TensorInfo output0 = outputs[0];
1690 Shape input0_shape = input0.shape();
1691 Shape input0_slice_shape = inputs[0].slice_shape();
1692 double result = 0.0;
1693 if (inputs_type_lengths_.size() != inputs.size()) {
1694 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
1695 }
1696 // If input0's slice shape differs from its full shape within the dimensions covered by input1, we regard it as a column slice
1697 for (size_t i = 0; i < input1.shape().size(); ++i) {
1698 if (input0_shape[i] != input0_slice_shape[i]) {
1699 result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1700 return result;
1701 }
1702 }
1703 return result;
1704 }
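// Illustrative example (hypothetical numbers): input0 of shape {16, 32} sliced to {8, 32} with a
// rank-1 input1 differs in dimension 0, so the forward cost equals the output slice size times
// outputs_type_lengths_[0]; if no dimension within input1's rank is split, the cost is 0.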
1705
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1706 double UnsortedSegmentSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
1707 const std::vector<TensorInfo> &outputs, int64_t) const {
1708 TensorInfo input0 = inputs[0];
1709 TensorInfo input1 = inputs[1];
1710 TensorInfo output0 = outputs[0];
1711 Shape input0_shape = input0.shape();
1712 Shape input0_slice_shape = inputs[0].slice_shape();
1713 double result = 0.0;
1714 if (inputs_type_lengths_.size() != inputs.size()) {
1715 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
1716 }
1717 if (is_parameter_[0]) {
1718 // If the forward process has an AllReduce, then the backward also needs one.
1719 for (size_t i = 0; i < input1.shape().size(); ++i) {
1720 if (input0_shape[i] != input0_slice_shape[i]) {
1721 result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1722 return result;
1723 }
1724 }
1725 }
1726 return result;
1727 }
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1728 double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1729 const std::vector<TensorInfo> &outputs, int64_t) const {
1730 // In forward phase, the computation cost = slice(A) + slice(B) + slice(output)
1731 Shape input0_slice_shape = inputs[0].slice_shape();
1732 Shape input1_slice_shape = inputs[1].slice_shape();
1733 Shape output_slice_shape = outputs[0].slice_shape();
1734 double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1735 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1736 ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
1737 return result;
1738 }
1739
1740 // Not taking account of output
CalculateOutputInMemory()1741 void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1742
1743 // Taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1744 void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1745 // When calculating 'dx', taking account of 'y'
1746 if (is_parameter_[0]) {
1747 is_inputs_should_in_memory_[0] = true;
1748 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1749 is_inputs_should_in_memory_[1] = true;
1750 }
1751 } else if (is_parameter_involve_[0]) {
1752 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1753 is_inputs_should_in_memory_[1] = true;
1754 }
1755 }
1756
1757 if (!is_inputs_should_in_memory_[1]) {
1758 is_inputs_should_in_memory_[1] = is_parameter_[1];
1759 }
1760 is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
1761 }
1762
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1763 double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1764 const std::vector<TensorInfo> &outputs, int64_t) const {
1765 TensorInfo input0 = inputs[0];
1766 TensorInfo input1 = inputs[1];
1767 TensorInfo output0 = outputs[0];
1768 Shape input0_shape = input0.shape();
1769 Shape input0_slice_shape = inputs[0].slice_shape();
1770 double result = 0.0;
1771 if (inputs_type_lengths_.size() != inputs.size()) {
1772 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1773 << " for UnsortedSegmentMinCost cost";
1774 }
1775 // If input0's slice shape differs from its full shape within the dimensions covered by input1, we regard it as a column slice
1776 // The cost is an AllGather operation, whose shape is the same as the output of UnsortedSegmentMin.
1777 for (size_t i = 0; i < input1.shape().size(); ++i) {
1778 if (input0_shape[i] != input0_slice_shape[i]) {
1779 result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1780 return result;
1781 }
1782 }
1783 return result;
1784 }
1785
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1786 double UnsortedSegmentMinCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
1787 const std::vector<TensorInfo> &outputs, int64_t) const {
1788 TensorInfo input0 = inputs[0];
1789 TensorInfo input1 = inputs[1];
1790 TensorInfo output0 = outputs[0];
1791 Shape input0_shape = input0.shape();
1792 Shape input0_slice_shape = inputs[0].slice_shape();
1793 double result = 0.0;
1794 if (inputs_type_lengths_.size() != inputs.size()) {
1795 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
1796 << " for UnsortedSegmentMinCost cost";
1797 }
1798 if (is_parameter_[0]) {
1799 // If the forward process has an AllGather, then the backward also needs a ReduceScatter.
1800 for (size_t i = 0; i < input1.shape().size(); ++i) {
1801 if (input0_shape[i] != input0_slice_shape[i]) {
1802 result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1803 return result;
1804 }
1805 }
1806 }
1807 return result;
1808 }
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1809 double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1810 const std::vector<TensorInfo> &outputs, int64_t) const {
1811 // In forward phase, the computation cost = slice(A) + slice(B) + 2 * slice(output)
1812 Shape input0_slice_shape = inputs[0].slice_shape();
1813 Shape input1_slice_shape = inputs[1].slice_shape();
1814 Shape output_slice_shape = outputs[0].slice_shape();
1815 // The forward operation is UnsortedSegmentMin + ReduceMin
1816 double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1817 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1818 ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]) +
1819 ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]); // ReduceMin
1820 return result;
1821 }
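// Illustrative example (hypothetical numbers): with slice sizes |A| = 512, |B| = 16 and an output
// slice of 128 elements, all with type length 4, the cost is (512 + 16 + 2 * 128) * 4.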
1822
1823 // Taking account of output
CalculateOutputInMemory()1824 void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
1825
1826 // Taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1827 void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1828 // When calculating 'dx', taking account of 'x', 'y' and 'z'
1829 if (is_parameter_involve_[0]) {
1830 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
1831 is_inputs_should_in_memory_[0] = true;
1832 }
1833 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1834 is_inputs_should_in_memory_[1] = true;
1835 }
1836 if ((prev_output_in_mem.find(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1837 (!prev_output_in_mem.at(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1))) {
1838 is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = true;
1839 }
1840 }
1841 if (!is_inputs_should_in_memory_[1]) {
1842 is_inputs_should_in_memory_[1] = is_parameter_[1];
1843 }
1844 if (!is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1]) {
1845 is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
1846 }
1847 }
1848
1849 // Not taking account of output
CalculateOutputInMemory()1850 void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1851
1852 // Not taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> &)1853 void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1854 for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
1855 is_inputs_should_in_memory_[i] = is_parameter_[i];
1856 }
1857 }
1858
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1859 double MatmulDDSCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1860 int64_t) const {
1861 double result = 0.0;
1862 if (inputs_type_lengths_.size() != inputs.size()) {
1863 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for MatmulDDS cost";
1864 }
1865
1866 for (size_t index = 0; index < inputs.size(); ++index) {
1867 TensorInfo tensor_info = inputs[index];
1868 Shape slice_shape = tensor_info.slice_shape();
1869 result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
1870 }
1871 return result;
1872 }
1873
1874 // Taking account of output when any input is parameter-involved
CalculateOutputInMemory()1875 void MatmulDDSCost::CalculateOutputInMemory() {
1876 is_output_should_in_memory_ =
1877 (std::find(is_parameter_involve_.cbegin(), is_parameter_involve_.cend(), true) != is_parameter_involve_.cend());
1878 }
1879
1880 // Taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> &)1881 void MatmulDDSCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
1882 bool keep_mem =
1883 (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
1884 (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
1885 std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
1886 }
1887
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const1888 double ScatterMathOpsCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
1889 const std::vector<TensorInfo> &, int64_t) const {
1890 double result = 0.0;
1891 Shape input0_slice_shape = inputs[0].slice_shape();
1892 Shape input1_slice_shape = inputs[1].slice_shape();
1893 Shape input2_slice_shape = inputs[2].slice_shape();
1894 if (inputs_type_lengths_.size() != inputs.size()) {
1895 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for ScatterMathOps cost";
1896 }
1897 // don't split axis
1898 if (!is_split_axis_) {
1899 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1900 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1901 ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
1902 } else {
1903 // split axis
1904 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * input_coefficient_ +
1905 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * indices_coefficient_ +
1906 ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * updates_coefficient_;
1907 }
1908
1909 return result;
1910 }
1911
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1912 double TensorScatterOpsCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1913 int64_t stage_id) const {
1914 double result = 0.0;
1915 CheckGlobalDeviceManager();
1916 MS_EXCEPTION_IF_NULL(g_device_manager);
1917 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1918
1919 for (size_t j = 0; j < inputs.size(); ++j) {
1920 if (!is_parameter_[j]) {
1921 continue;
1922 }
1923 TensorInfo input_a_tensor_info = inputs[j];
1924 Shape input_a_shape = input_a_tensor_info.shape();
1925 Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
1926 int64_t used_device_num = 1;
1927 for (size_t i = 0; i < input_a_shape.size(); ++i) {
1928 used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
1929 }
1930 if (total_device_num != LongToSize(used_device_num)) {
1931 result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
1932 }
1933 }
1934 return result;
1935 }
1936
GetBackwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1937 double TensorScatterOpsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
1938 const std::vector<TensorInfo> &outputs, int64_t) const {
1939 double result = 0.0;
1940 Shape input0_slice_shape = inputs[0].slice_shape();
1941 Shape input1_slice_shape = inputs[1].slice_shape();
1942 Shape input2_slice_shape = inputs[2].slice_shape(); // equal to output shape
1943 // The bprop function uses each input/output about 3 times on average.
1944 // don't split axis
1945 if (!is_split_axis_) {
1946 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
1947 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
1948 ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
1949 result *= 3;
1950
1951 } else {
1952 // split axis
1953 result +=
1954 ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * (input_coefficient_ + 3) +
1955 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * (indices_coefficient_ + 3) +
1956 ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * (updates_coefficient_ + 3);
1957 }
1958
1959 return result;
1960 }
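// Illustrative note: in the non-split case the three slice terms are summed and then tripled,
// matching the comment above that the bprop reuses each tensor about three times; in the
// split-axis case each term instead carries its own coefficient plus 3.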
1961
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const1962 double CropAndResizeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
1963 const std::vector<TensorInfo> &outputs, int64_t) const {
1964 double result = 0.0;
1965 if (outputs_type_lengths_.size() != outputs.size()) {
1966 MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for CropAndResize cost.";
1967 }
1968
1969 // don't split the batch
1970 if (strategy_[0] == 1) {
1971 return result;
1972 }
1973
1974 // split batch
1975 auto x_shape = inputs[0].slice_shape();
1976 auto box_shape = inputs[1].slice_shape();
1977 Shape reduce_sum_shape = {box_shape[0], crop_size_[0], crop_size_[1], x_shape[3]};
1978 result += ListProduct(reduce_sum_shape) * static_cast<double>(outputs_type_lengths_[0]);
1979 return result;
1980 }
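// Illustrative example (hypothetical numbers): with 8 boxes in the slice, crop_size_ = {7, 7} and
// 256 channels in 'x', the communication volume is 8 * 7 * 7 * 256 elements times the output
// type length.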
1981
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t stage_id) const1982 double CropAndResizeCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1983 int64_t stage_id) const {
1984 double result = 0.0;
1985 CheckGlobalDeviceManager();
1986 MS_EXCEPTION_IF_NULL(g_device_manager);
1987 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1988
1989 for (size_t i = 0; i < inputs.size(); ++i) {
1990 if (!is_parameter_[i]) {
1991 continue;
1992 }
1993 TensorInfo input_tensor_info = inputs[i];
1994 Shape input_shape = input_tensor_info.shape();
1995 Shape input_slice_shape = input_tensor_info.slice_shape();
1996 int64_t used_device_num = 1;
1997 for (size_t j = 0; j < input_shape.size(); ++j) {
1998 used_device_num *= input_shape[j] / input_slice_shape[j];
1999 }
2000 if (total_device_num != LongToSize(used_device_num)) {
2001 result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[i]);
2002 }
2003 }
2004 return result;
2005 }
2006
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const2007 double CropAndResizeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
2008 const std::vector<TensorInfo> &, int64_t) const {
2009 if (inputs_type_lengths_.size() != inputs.size()) {
2010 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for CropAndResize cost.";
2011 }
2012
2013 Shape input0_slice_shape = inputs.at(0).slice_shape();
2014 Shape input1_slice_shape = inputs.at(1).slice_shape();
2015 Shape input2_slice_shape = inputs.at(2).slice_shape();
2016 double result = 0.0;
2017 // don't split batch
2018 if (strategy_[0] == 1) {
2019 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
2020 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
2021 ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
2022 } else {
2023 // split batch
2024 result +=
2025 ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * CROP_AND_RESIZE_COST_WEIGHT0 +
2026 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * CROP_AND_RESIZE_COST_WEIGHT1 +
2027 ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * CROP_AND_RESIZE_COST_WEIGHT2;
2028 }
2029
2030 return result;
2031 }
2032
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> & outputs,int64_t) const2033 double CropAndResizeCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
2034 const std::vector<TensorInfo> &outputs, int64_t) const {
2035 Shape output0_slice_shape = outputs[0].slice_shape();
2036 double result = 0.0;
2037 // don't split batch
2038 if (strategy_.at(0) == 1) {
2039 result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
2040 } else {
2041 // split batch
2042 result +=
2043 ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * CROP_AND_RESIZE_COST_WEIGHT3;
2044 }
2045
2046 return result;
2047 }
2048
2049 // Not taking account of output
CalculateOutputInMemory()2050 void CropAndResizeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
2051
2052 // Taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)2053 void CropAndResizeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
2054 for (size_t i = 0; i < CROP_AND_RESIZE_INPUTS_SIZE; ++i) {
2055 is_inputs_should_in_memory_[i] = is_parameter_[i];
2056 }
2057
2058 // When calculating 'dx', taking account of 'y' and 'z'
2059 if (is_parameter_[0] || is_parameter_involve_[0]) {
2060 if (prev_output_in_mem.find(1) == prev_output_in_mem.end() || !prev_output_in_mem.at(1)) {
2061 is_inputs_should_in_memory_[1] = true;
2062 }
2063 if (prev_output_in_mem.find(CROP_AND_RESIZE_INPUTS_SIZE - 1) == prev_output_in_mem.end() ||
2064 !prev_output_in_mem.at(CROP_AND_RESIZE_INPUTS_SIZE - 1)) {
2065 is_inputs_should_in_memory_[CROP_AND_RESIZE_INPUTS_SIZE - 1] = true;
2066 }
2067 }
2068 }
2069
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t) const2070 double ROIAlignCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
2071 int64_t) const {
2072 double result = 0.0;
2073 if (outputs_type_lengths_.size() != outputs.size()) {
2074 MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for ROIAlign cost.";
2075 }
2076
2077 // don't split the batch
2078 if (strategy_[0] == 1) {
2079 return result;
2080 }
2081
2082 // split batch
2083 auto features_shape = inputs[0].slice_shape();
2084 auto rois_shape = inputs[1].slice_shape();
2085 Shape reduce_sum_shape = {rois_shape[0], features_shape[1], pooled_shape_[0], pooled_shape_[1]};
2086 result += ListProduct(reduce_sum_shape) * static_cast<double>(outputs_type_lengths_[0]);
2087 return result;
2088 }
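// Illustrative example (hypothetical numbers): with 8 ROIs in the slice, 256 feature channels and
// pooled_shape_ = {7, 7}, the communication volume is 8 * 256 * 7 * 7 elements times the output
// type length.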
2089
GetForwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> &,int64_t) const2090 double ROIAlignCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
2091 int64_t) const {
2092 if (inputs_type_lengths_.size() != inputs.size()) {
2093 MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for ROIAlign cost.";
2094 }
2095
2096 Shape input0_slice_shape = inputs.at(0).slice_shape();
2097 Shape input1_slice_shape = inputs.at(1).slice_shape();
2098 double result = 0.0;
2099 // don't split batch
2100 if (strategy_[0] == 1) {
2101 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
2102 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
2103 } else {
2104 // split batch
2105 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * ROI_ALIGN_COST_WEIGHT0 +
2106 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * ROI_ALIGN_COST_WEIGHT1;
2107 }
2108
2109 return result;
2110 }
2111
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> & outputs,int64_t) const2112 double ROIAlignCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
2113 int64_t) const {
2114 Shape output0_slice_shape = outputs[0].slice_shape();
2115 double result = 0.0;
2116 // don't split batch
2117 if (strategy_.at(0) == 1) {
2118 result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
2119 } else {
2120 // split batch
2121 result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * ROI_ALIGN_COST_WEIGHT2;
2122 }
2123
2124 return result;
2125 }
2126
2127 // Taking account of output
CalculateOutputInMemory()2128 void ROIAlignCost::CalculateOutputInMemory() { is_output_should_in_memory_ = true; }
2129
2130 // Taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)2131 void ROIAlignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
2132 for (size_t i = 0; i < ROI_ALIGN_INPUTS_SIZE; ++i) {
2133 is_inputs_should_in_memory_[i] = is_parameter_[i];
2134 }
2135
2136 // When calculating 'dx', taking account of 'y'
2137 if (is_parameter_[0] || is_parameter_involve_[0]) {
2138 if (prev_output_in_mem.find(ROI_ALIGN_INPUTS_SIZE - 1) == prev_output_in_mem.end() ||
2139 !prev_output_in_mem.at(ROI_ALIGN_INPUTS_SIZE - 1)) {
2140 is_inputs_should_in_memory_[ROI_ALIGN_INPUTS_SIZE - 1] = true;
2141 }
2142 }
2143 }
2144 } // namespace parallel
2145 } // namespace mindspore
2146