/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/operator_costmodel.h"

#include <algorithm>
#include <random>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_parameter_ = is_parameter; }

void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
  is_parameter_involve_ = is_parameter_inv;
  is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
}

void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }

void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                               const std::vector<size_t> &output_lengths) {
  inputs_type_lengths_ = input_lengths;
  outputs_type_lengths_ = output_lengths;
}

void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_ = critical; }

double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
                                   const std::vector<TensorInfo> &outputs) const {
  return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
}
47
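// The memory cost of one input is the number of elements in its slice times the byte width of its element type,
// e.g. a float32 (4-byte) slice of shape [64, 128] contributes 64 * 128 * 4 bytes. Only inputs flagged in
// is_inputs_should_in_memory_ are counted.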
double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
49 double result = 0.0;
50 for (size_t i = 0; i < inputs.size(); ++i) {
51 if (is_inputs_should_in_memory_[i]) {
52 result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
53 }
54 }
55 return result;
56 }
57
double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &,
                                         const std::vector<TensorInfo> &outputs) const {
60 double result = 0.0;
61 if (is_output_should_in_memory_) {
    // When this operator has multiple outputs, they all contribute to the memory.
63 for (size_t i = 0; i < outputs.size(); ++i) {
64 result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
65 }
66 }
67 return result;
68 }
69
double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
                                               const std::vector<TensorInfo> &outputs) const {
72 double result = 0.0;
73 if (is_outputs_critical_ == -1) {
74 MS_LOG(EXCEPTION) << "The critical flag is not set.";
75 }
76 if (is_outputs_critical_ == 1) {
77 for (size_t i = 0; i < outputs.size(); ++i) {
78 result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
79 }
80 }
81 return result;
82 }
83
84 // return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                      int64_t) const {
87 TensorInfo input0 = inputs[0];
88 TensorInfo output0 = outputs[0];
89 Shape input0_shape = input0.shape();
90 Shape input0_slice_shape = input0.slice_shape();
91 if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) {
92 // If the reduced dimension has not been partitioned, then there is no communication cost.
93 return 0.0;
94 } else {
95 // Else, the communication cost is the size (number of bytes) of a slice of output tensor.
96 return ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
97 }
98 }
99
// return the per device communication cost in the backward phase.
double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                       int64_t stage_id) const {
103 // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not
104 // fully utilize all devices
105 double result = 0.0;
106 if (is_parameter_[1]) {
107 TensorInfo input1 = inputs[1]; // tensor B
108 CheckGlobalDeviceManager();
109 MS_EXCEPTION_IF_NULL(g_device_manager);
110 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
111
112 Shape input1_shape = input1.shape();
113 Shape input1_slice_shape = input1.slice_shape();
114 int64_t used_device_num = 1;
115 for (size_t i = 0; i < input1_shape.size(); ++i) {
116 used_device_num *= input1_shape[i] / input1_slice_shape[i];
117 }
118
119 if (total_device_num != LongToSize(used_device_num))
120 result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
121 }
122
123 return result;
124 }
125
126 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
127 // this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                             const std::vector<TensorInfo> &outputs, int64_t) const {
130 // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
131 double result = 0.0;
132 TensorInfo output0 = outputs[0];
133 Shape input0_slice_shape = inputs[0].slice_shape();
134 Shape input1_slice_shape = inputs[1].slice_shape();
135 Shape input0_shape = inputs[0].shape();
136 if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) {
    // If the reduced dimension has been partitioned, add the cost of the AllReduce on the output slice.
138 result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
139 }
140 result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
141 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
142 return result;
143 }
144
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
149 // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
150 double result = 0.0;
151 if (is_parameter_[1]) {
152 TensorInfo input1 = inputs[1]; // tensor B
153 CheckGlobalDeviceManager();
154 MS_EXCEPTION_IF_NULL(g_device_manager);
155 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
156
157 Shape input1_shape = input1.shape();
158 Shape input1_slice_shape = input1.slice_shape();
159 int64_t used_device_num = 1;
160 for (size_t i = 0; i < input1_shape.size(); ++i) {
161 used_device_num *= input1_shape[i] / input1_slice_shape[i];
162 }
163
164 if (total_device_num != LongToSize(used_device_num))
165 result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
166 }
167
168 return result;
169 }
170
// Not taking account of output
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
176 if (is_parameter_[0]) {
177 is_inputs_should_in_memory_[0] = true;
178 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
179 is_inputs_should_in_memory_[1] = true;
180 }
181 } else if (is_parameter_involve_[0]) {
182 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
183 is_inputs_should_in_memory_[1] = true;
184 }
185 }
186
187 if (is_parameter_[1]) {
188 is_inputs_should_in_memory_[1] = true;
189 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
190 is_inputs_should_in_memory_[0] = true;
191 }
192 } else if (is_parameter_involve_[1]) {
193 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
194 is_inputs_should_in_memory_[0] = true;
195 }
196 }
197 }
198
199 // Return the per device communication cost in the forward phase.
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // Cast is an element-wise operator, thus it does not need communication in the forward phase
202 return 0.0;
203 }
204
205 // Return the per device communication cost in the backward phase.
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                     int64_t stage_id) const {
208 double result = 0.0;
209 if (is_parameter_[0]) {
210 TensorInfo input1 = inputs[0];
211 MS_EXCEPTION_IF_NULL(g_device_manager);
212 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
213 Shape input1_shape = input1.shape();
214 Shape input1_slice_shape = input1.slice_shape();
215 int64_t used_device_num = 1;
216 for (size_t i = 0; i < input1_shape.size(); ++i) {
217 used_device_num *= input1_shape[i] / input1_slice_shape[i];
218 }
219 if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
221 }
222 }
223 return result;
224 }
225
226 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
227 // this operator uses
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t) const {
230 TensorInfo input0 = inputs[0];
231 Shape input0_slice_shape = input0.slice_shape();
232 return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
233 }
234
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                            int64_t) const {
239 return 0.0;
240 }
241
242 // Not taking account of output
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
247 is_inputs_should_in_memory_[0] = is_parameter_[0];
248 }
249
250 // Taking account of output
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Taking account of input
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
255 if (is_parameter_[0]) {
256 is_inputs_should_in_memory_[0] = true;
257 } else if (is_parameter_involve_[0]) {
258 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
259 is_inputs_should_in_memory_[0] = true;
260 }
261 }
262 }
263
264 // Return the per device communication cost in the forward phase.
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
267 // In the forward phase, the communication cost = 0
268 return 0.0;
269 }
270
271 // Return the per device communication cost in the backward phase.
double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
274 double result = 0.0;
275 if (is_parameter_[0]) {
276 TensorInfo input1 = inputs[0];
277 MS_EXCEPTION_IF_NULL(g_device_manager);
278 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
279 Shape input1_shape = input1.shape();
280 Shape input1_slice_shape = input1.slice_shape();
281 int64_t used_device_num = 1;
282 for (size_t i = 0; i < input1_shape.size(); ++i) {
283 used_device_num *= input1_shape[i] / input1_slice_shape[i];
284 }
285 if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
287 }
288 }
289 return result;
290 }
291
292 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
293 // this operator uses
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
                                              int64_t) const {
296 if (outputs.empty() || outputs_type_lengths_.empty()) {
297 MS_LOG(EXCEPTION) << "The outputs or outputs_type_length is empty";
298 }
299
300 // use output for Tile operator
301 TensorInfo output_info = outputs[0];
302 Shape output_slice_shape = output_info.slice_shape();
303 return ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
304 }
305
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
310 return 0.0;
311 }
312
313 // Taking account of output
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Not taking account of input
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
318 is_inputs_should_in_memory_[0] = is_parameter_[0];
319 }
320
321 // Not taking account of output
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
326 is_inputs_should_in_memory_[0] = is_parameter_[0];
327 }
328
329 // Not taking account of output
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
334 // When calculating 'dx', taking account of 'y'
335 if (is_parameter_[0]) {
336 is_inputs_should_in_memory_[0] = true;
337 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
338 is_inputs_should_in_memory_[1] = true;
339 }
340 } else if (is_parameter_involve_[0]) {
341 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
342 is_inputs_should_in_memory_[1] = true;
343 }
344 }
345
346 if (!is_inputs_should_in_memory_[1]) {
347 is_inputs_should_in_memory_[1] = is_parameter_[1];
348 }
349 }
350
351 // Not taking account of output
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
353
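// Not taking account of input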
void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
355 is_inputs_should_in_memory_[0] = is_parameter_[0];
356 }
357
358 // Taking account of input
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
360 if (is_parameter_[0]) {
361 is_inputs_should_in_memory_[0] = true;
362 } else if (is_parameter_involve_[0]) {
363 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
364 is_inputs_should_in_memory_[0] = true;
365 }
366 }
367 }
368
369 // Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
371 // When calculating 'dx', taking account of 'y'
372 if (is_parameter_[0]) {
373 is_inputs_should_in_memory_[0] = true;
374 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
375 is_inputs_should_in_memory_[1] = true;
376 }
377 } else if (is_parameter_involve_[0]) {
378 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
379 is_inputs_should_in_memory_[1] = true;
380 }
381 }
382
383 if (!is_inputs_should_in_memory_[1]) {
384 is_inputs_should_in_memory_[1] = is_parameter_[1];
385 }
386 }
387
388 // return the per device communication cost in the forward phase.
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                           const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
391 // Identity is the element-wise operator, thus it does not need communication in the forward phase
392 return 0.0;
393 }
394
395 // return the per device communication cost in the backward phase.
double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                            const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
398 // Identity is the element-wise operator, thus it does not need communication in the backward phase
399 return 0.0;
400 }
401
402 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
403 // this operator uses
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                  const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
406 return 0.0;
407 }
408
409 // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
410 // this operator uses
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                   const std::vector<mindspore::parallel::TensorInfo> &,
                                                   int64_t) const {
414 return 0.0;
415 }
416
417 // Not taking account of output
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
422 is_inputs_should_in_memory_[0] = is_parameter_[0];
423 }
424
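// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses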
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                                    const std::vector<mindspore::parallel::TensorInfo> &,
                                                    int64_t) const {
428 double cost = 0.0;
429 for (size_t i = 0; i < inputs.size(); ++i) {
430 cost += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
431 }
432 return cost;
433 }
434
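// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses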
double BatchParallelCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                                     const std::vector<mindspore::parallel::TensorInfo> &,
                                                     int64_t) const {
438 return 0.0;
439 }
440
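// Return the per device communication cost in the backward phase.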
double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
443 double result = 0.0;
444 CheckGlobalDeviceManager();
445 MS_EXCEPTION_IF_NULL(g_device_manager);
446 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
447
448 for (size_t j = 0; j < inputs.size(); ++j) {
449 if (!is_parameter_[j]) {
450 continue;
451 }
452 TensorInfo input_a_tensor_info = inputs[j];
453 Shape input_a_shape = input_a_tensor_info.shape();
454 Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
455 int64_t used_device_num = 1;
456 for (size_t i = 0; i < input_a_shape.size(); ++i) {
457 used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
458 }
459 if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
461 }
462 }
463
464 return result;
465 }
466
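// Not taking account of output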
void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
468
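// Taking account of input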
void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
470 if (is_parameter_[0]) {
471 is_inputs_should_in_memory_[0] = true;
472 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
473 is_inputs_should_in_memory_[1] = true;
474 }
475 } else if (is_parameter_involve_[0]) {
476 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
477 is_inputs_should_in_memory_[1] = true;
478 }
479 }
480
481 if (is_parameter_[1]) {
482 is_inputs_should_in_memory_[1] = true;
483 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
484 is_inputs_should_in_memory_[0] = true;
485 }
486 } else if (is_parameter_involve_[1]) {
487 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
488 is_inputs_should_in_memory_[0] = true;
489 }
490 }
491 }
492
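// Taking account of output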
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
494 is_output_should_in_memory_ = is_parameter_involve_[0];
495 }
496
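// Not taking account of input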
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
498 is_inputs_should_in_memory_[0] = is_parameter_[0];
499 is_inputs_should_in_memory_[1] = is_parameter_[1];
500 }
501 // return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
503 // prelu does not need communication in the forward phase
504 return 0.0;
505 }
506
507 // return the per device communication cost in the backward phase.
double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                      int64_t stage_id) const {
510 double result = 0.0;
511 if (is_parameter_[1]) {
512 TensorInfo input1 = inputs[1];
513 CheckGlobalDeviceManager();
514 MS_EXCEPTION_IF_NULL(g_device_manager);
515 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
516 Shape input1_shape = input1.shape();
517 Shape input1_slice_shape = input1.slice_shape();
518 int64_t used_device_num = 1;
519 for (size_t i = 0; i < input1_shape.size(); ++i) {
520 used_device_num *= input1_shape[i] / input1_slice_shape[i];
521 }
522 if (total_device_num != LongToSize(used_device_num)) {
523 result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
524 }
525 }
526 return result;
527 }
528
529 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
530 // this operator uses
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                            int64_t) const {
533 // In forward phase, the computation cost = slice(A) + slice(B)
534 Shape input0_slice_shape = inputs[0].slice_shape();
535 Shape input1_slice_shape = inputs[1].slice_shape();
536 double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
537 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
538 return result;
539 }
540
541 // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
542 // this operator uses
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
                                             const std::vector<mindspore::parallel::TensorInfo> &,
                                             int64_t stage_id) const {
546 // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
547 double result = 0.0;
548 if (is_parameter_[1]) {
549 TensorInfo input1 = inputs[1]; // tensor B
550 CheckGlobalDeviceManager();
551 MS_EXCEPTION_IF_NULL(g_device_manager);
552 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
553
554 Shape input1_shape = input1.shape();
555 Shape input1_slice_shape = input1.slice_shape();
556 int64_t used_device_num = 1;
557 for (size_t i = 0; i < input1_shape.size(); ++i) {
558 used_device_num *= input1_shape[i] / input1_slice_shape[i];
559 }
560
561 if (total_device_num != LongToSize(used_device_num)) {
562 result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
563 }
564 }
565 return result;
566 }
567
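// Not taking account of output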
void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
569
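// Taking account of input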
void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
571 // When calculating 'dx', taking account of both 'x' and 'y';
572 // when calculating 'dy', taking account of both 'x' and 'y'
573 if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
574 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
575 is_inputs_should_in_memory_[0] = true;
576 }
577 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
578 is_inputs_should_in_memory_[1] = true;
579 }
580 }
581 }
582
583 // return the per device communication cost in the forward phase.
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
585 // onehot does not need communication in the forward phase
586 return 0.0;
587 }
588
589 // return the per device communication cost in the backward phase.
double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
592 // onehot does not need communication in the backward phase
593 return 0.0;
594 }
595
596 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
597 // this operator uses
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                             int64_t) const {
600 // In onehot's forward phase, the computation cost = slice(A)
601 Shape input0_slice_shape = inputs[0].slice_shape();
602 return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
603 }
604
605 // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
606 // this operator uses
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                              int64_t) const {
609 return 0.0;
610 }
611
612 // Not taking account of output
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
617 is_inputs_should_in_memory_[0] = is_parameter_[0];
618 is_inputs_should_in_memory_[1] = is_parameter_[1];
619 is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 2] = is_parameter_[ONEHOT_INPUTS_SIZE - 2];
620 is_inputs_should_in_memory_[ONEHOT_INPUTS_SIZE - 1] = is_parameter_[ONEHOT_INPUTS_SIZE - 1];
621 }
622
623 // return the per device communication cost in the forward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
                                                             const std::vector<TensorInfo> &, int64_t) const {
626 // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase
627 return 0.0;
628 }
629
630 // return the per device communication cost in the backward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<TensorInfo> &,
                                                              const std::vector<TensorInfo> &, int64_t) const {
633 // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase
634 return 0.0;
635 }
636
637 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
638 // this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                                    const std::vector<TensorInfo> &, int64_t) const {
641 // In forward phase, the computation cost = slice(A) + slice(B)
642 Shape input0_slice_shape = inputs[0].slice_shape();
643 Shape input1_slice_shape = inputs[1].slice_shape();
644 double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
645 ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
646 return result;
647 }
648
649 // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
650 // this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
                                                                     const std::vector<TensorInfo> &, int64_t) const {
653 return 0.0;
654 }
655
656 // Taking account of output
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
658 is_output_should_in_memory_ = is_parameter_involve_[0];
659 }
660
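// Not taking account of input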
void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
662 is_inputs_should_in_memory_[0] = is_parameter_[0];
663 is_inputs_should_in_memory_[1] = is_parameter_[1];
664 }
665
666 // return the per device communication cost in the forward phase.
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                       int64_t stage_id) const {
669 CheckGlobalDeviceManager();
670 MS_EXCEPTION_IF_NULL(g_device_manager);
671 RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
672 TensorRedistribution tensor_redistribution(false, true);
673 if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
674 MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
675 }
676 if (tensor_redistribution.ComputeCost() == FAILED) {
677 MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
678 }
679 return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost());
680 }
681
682 // return the per device communication cost in the backward phase.
double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                        int64_t stage_id) const {
685 double result = 0.0;
686 if (is_parameter_[0]) {
687 TensorInfo input1 = inputs[0];
688 MS_EXCEPTION_IF_NULL(g_device_manager);
689 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
690 Shape input1_shape = input1.shape();
691 Shape input1_slice_shape = input1.slice_shape();
692 int64_t used_device_num = 1;
693 for (size_t i = 0; i < input1_shape.size(); ++i) {
694 used_device_num *= input1_shape[i] / input1_slice_shape[i];
695 }
696 if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
698 }
699 }
700 return result;
701 }
702
703 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
704 // this operator uses
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                              const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
707 CheckGlobalDeviceManager();
708 MS_EXCEPTION_IF_NULL(g_device_manager);
709 RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
710 TensorRedistribution tensor_redistribution(false, true);
711 if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
712 MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
713 }
714 if (tensor_redistribution.ComputeCost() == FAILED) {
715 MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
716 }
717 return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
718 }
719
720 // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
721 // this operator uses
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                               const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
724 return 0.0;
725 }
726
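// Not taking account of output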
void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
728
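// Not taking account of input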
void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
730 is_inputs_should_in_memory_[0] = is_parameter_[0];
731 is_inputs_should_in_memory_[1] = is_parameter_[1];
732 }
733
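// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses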
double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t) const {
736 double result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
737 ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
738 return result;
739 }
740
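// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses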
double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t stage_id) const {
743 double result = 0.0;
744 CheckGlobalDeviceManager();
745 MS_EXCEPTION_IF_NULL(g_device_manager);
746 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
747
748 if (is_parameter_[0]) {
749 TensorInfo input_a_tensor_info = inputs[0];
750 Shape input_a_shape = input_a_tensor_info.shape();
751 Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
752 int64_t used_device_num = 1;
753 for (size_t i = 0; i < input_a_shape.size(); ++i) {
754 used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
755 }
756
757 if (total_device_num != LongToSize(used_device_num))
758 result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
759 }
760
761 if (is_parameter_[1]) {
762 TensorInfo input_b_tensor_info = inputs[1];
763 Shape input_b_shape = input_b_tensor_info.shape();
764 Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
765 int64_t used_device_num = 1;
766 for (size_t i = 0; i < input_b_shape.size(); ++i) {
767 used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
768 }
769
770 if (total_device_num != LongToSize(used_device_num))
771 result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
772 }
773 return result;
774 }
775
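// Return the per device communication cost in the backward phase.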
double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                    int64_t stage_id) const {
778 double result = 0.0;
779 CheckGlobalDeviceManager();
780 MS_EXCEPTION_IF_NULL(g_device_manager);
781 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
782
783 if (is_parameter_[0]) {
784 TensorInfo input_a_tensor_info = inputs[0];
785 Shape input_a_shape = input_a_tensor_info.shape();
786 Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
787 int64_t used_device_num = 1;
788 for (size_t i = 0; i < input_a_shape.size(); ++i) {
789 used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
790 }
791
792 if (total_device_num != LongToSize(used_device_num))
793 result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
794 }
795
796 if (is_parameter_[1]) {
797 TensorInfo input_b_tensor_info = inputs[1];
798 Shape input_b_shape = input_b_tensor_info.shape();
799 Shape input_b_slice_shape = input_b_tensor_info.slice_shape();
800 int64_t used_device_num = 1;
801 for (size_t i = 0; i < input_b_shape.size(); ++i) {
802 used_device_num *= input_b_shape[i] / input_b_slice_shape[i];
803 }
804
805 if (total_device_num != LongToSize(used_device_num))
806 result += ListProduct(input_b_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
807 }
808
809 return result;
810 }
811
812 // Not taking account of output
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
817 is_inputs_should_in_memory_[0] = is_parameter_[0];
818 is_inputs_should_in_memory_[1] = is_parameter_[1];
819 }
820
821 // Taking account of input
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
823 if (is_parameter_[0]) {
824 // 'x' is parameter, so it should be in memory.
825 is_inputs_should_in_memory_[0] = true;
826 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
828 is_inputs_should_in_memory_[1] = true;
829 }
830 } else if (is_parameter_involve_[0]) {
831 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
832 is_inputs_should_in_memory_[1] = true;
833 }
834 }
835
836 if (is_parameter_[1]) {
837 is_inputs_should_in_memory_[1] = true;
838 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
839 is_inputs_should_in_memory_[0] = true;
840 }
841 } else if (is_parameter_involve_[1]) {
842 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
843 is_inputs_should_in_memory_[0] = true;
844 }
845 }
846 }
847
848 // Taking account of output
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }

// Taking account of input
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
853 // When calculating 'dx', taking account of 'y'
854 if (is_parameter_[0]) {
855 // 'x' is parameter, so it should be in memory.
856 is_inputs_should_in_memory_[0] = true;
857 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
859 is_inputs_should_in_memory_[1] = true;
860 }
861 } else if (is_parameter_involve_[0]) {
862 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
863 is_inputs_should_in_memory_[1] = true;
864 }
865 }
866
867 // When calculating 'dy', taking account of 'y'
868 if (is_parameter_involve_[1]) {
869 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
870 is_inputs_should_in_memory_[1] = true;
871 }
872 }
873 }
874
875 // Taking account of input
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
877 // When calculating 'dx', not taking account of 'x' and 'y'
878 is_inputs_should_in_memory_[0] = is_parameter_[0];
879 // When calculating 'dy', taking account of 'x' and 'y'
880 if (is_parameter_involve_[1]) {
881 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
882 is_inputs_should_in_memory_[0] = true;
883 }
884 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
885 is_inputs_should_in_memory_[1] = true;
886 }
887 }
888 }
889
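// Taking account of output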
void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
891
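// Taking account of input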
void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
893 // When calculating 'dx', taking account of both 'x' and 'power'
894 if (is_parameter_involve_[0]) {
895 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
896 is_inputs_should_in_memory_[0] = true;
897 }
898 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
899 is_inputs_should_in_memory_[1] = true;
900 }
901 }
902 // When calculating 'dpower', taking account of 'x'
903 if (is_parameter_[1]) {
904 is_inputs_should_in_memory_[1] = true;
905 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
906 is_inputs_should_in_memory_[0] = true;
907 }
908 } else if (is_parameter_involve_[1]) {
909 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
910 is_inputs_should_in_memory_[0] = true;
911 }
912 }
913 }
914
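// Taking account of input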
void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
916 // When calculating 'dx', taking account of 'x'
917 if (is_parameter_involve_[0]) {
918 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
919 is_inputs_should_in_memory_[0] = true;
920 }
921 }
922 // When calculating 'dy', not taking account of 'x' and 'y'
923 is_inputs_should_in_memory_[1] = is_parameter_[1];
924 }
925
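// Taking account of input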
void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
927 // When calculating 'dx', taking account of both 'x' and 'y'
928 if (is_parameter_involve_[0]) {
929 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
930 is_inputs_should_in_memory_[0] = true;
931 }
932 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
933 is_inputs_should_in_memory_[1] = true;
934 }
935 }
936 // When calculating 'dy', not taking account of 'x' and 'y'
937 if (!is_inputs_should_in_memory_[1]) {
938 is_inputs_should_in_memory_[1] = is_parameter_[1];
939 }
940 }
941
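// Taking account of input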
void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
943 // When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
944 // 'y'
945 if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
946 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
947 is_inputs_should_in_memory_[0] = true;
948 }
949 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
950 is_inputs_should_in_memory_[1] = true;
951 }
952 }
953 }
954
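// Taking account of output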
void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
956
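// Taking account of input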
void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
958 // When calculating 'dx', taking account of 'y'
959 if (is_parameter_[0]) {
960 // 'x' is parameter, so it should be in memory.
961 is_inputs_should_in_memory_[0] = true;
962 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
964 is_inputs_should_in_memory_[1] = true;
965 }
966 } else if (is_parameter_involve_[0]) {
967 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
968 is_inputs_should_in_memory_[1] = true;
969 }
970 }
971
972 // When calculating 'dy', taking account of 'y'
973 if (is_parameter_involve_[1]) {
974 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
975 is_inputs_should_in_memory_[1] = true;
976 }
977 }
978 }
979
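// Taking account of input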
void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
981 // When calculating 'dx', taking account of both 'x' and 'y';
982 // when calculating 'dy', taking account of both 'x' and 'y'
983 if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
984 if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
985 is_inputs_should_in_memory_[0] = true;
986 }
987 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
988 is_inputs_should_in_memory_[1] = true;
989 }
990 }
991 }
992
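// Taking account of input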
void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
994 // When calculating 'dx', taking account of 'y' and 'z'
995 if (is_parameter_[0]) {
996 is_inputs_should_in_memory_[0] = true;
997 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
998 is_inputs_should_in_memory_[1] = true;
999 }
1000 if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1001 (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
1002 is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
1003 }
1004 } else if (is_parameter_involve_[0]) {
1005 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1006 is_inputs_should_in_memory_[1] = true;
1007 }
1008 if ((prev_output_in_mem.find(SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1009 (!prev_output_in_mem.at(SLICE_INPUTS_SIZE - 1))) {
1010 is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = true;
1011 }
1012 }
1013
1014 if (!is_inputs_should_in_memory_[1]) {
1015 is_inputs_should_in_memory_[1] = is_parameter_[1];
1016 }
1017 if (!is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1]) {
1018 is_inputs_should_in_memory_[SLICE_INPUTS_SIZE - 1] = is_parameter_[SLICE_INPUTS_SIZE - 1];
1019 }
1020 }
1021
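// Taking account of input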
void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1023 // When calculating 'dx', taking account of 'y', 'z' and 'w'
1024 if (is_parameter_[0]) {
1025 is_inputs_should_in_memory_[0] = true;
1026 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1027 is_inputs_should_in_memory_[1] = true;
1028 }
1029 if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
1030 (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
1031 is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
1032 }
1033 if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1034 (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
1035 is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
1036 }
1037 } else if (is_parameter_involve_[0]) {
1038 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1039 is_inputs_should_in_memory_[1] = true;
1040 }
1041 if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 2) == prev_output_in_mem.end()) ||
1042 (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 2))) {
1043 is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = true;
1044 }
1045 if ((prev_output_in_mem.find(STRIDED_SLICE_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
1046 (!prev_output_in_mem.at(STRIDED_SLICE_INPUTS_SIZE - 1))) {
1047 is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = true;
1048 }
1049 }
1050
1051 if (!is_inputs_should_in_memory_[1]) {
1052 is_inputs_should_in_memory_[1] = is_parameter_[1];
1053 }
1054 if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2]) {
1055 is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 2] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 2];
1056 }
1057 if (!is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1]) {
1058 is_inputs_should_in_memory_[STRIDED_SLICE_INPUTS_SIZE - 1] = is_parameter_[STRIDED_SLICE_INPUTS_SIZE - 1];
1059 }
1060 }
1061
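// Not taking account of output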
void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
1063
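// Taking account of input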
void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
1065 // When calculating 'dx', taking account of 'y'
1066 if (is_parameter_[0]) {
1067 // 'x' is parameter, so it should be in memory.
1068 is_inputs_should_in_memory_[0] = true;
1069 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, if 'y' is not calculated by the previous operator, then 'y' should be included here.
1071 is_inputs_should_in_memory_[1] = true;
1072 }
1073 } else if (is_parameter_involve_[0]) {
1074 if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
1075 is_inputs_should_in_memory_[1] = true;
1076 }
1077 }
1078
1079 if (!is_inputs_should_in_memory_[1]) {
1080 is_inputs_should_in_memory_[1] = is_parameter_[1];
1081 }
1082 is_inputs_should_in_memory_[DROPOUTDOMASK_INPUTS_SIZE - 1] = is_parameter_[DROPOUTDOMASK_INPUTS_SIZE - 1];
1083 }
1084
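// Check whether the first dimension is split across all devices of the stage, i.e. the strategy is data-parallel.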
bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
1086 CheckGlobalDeviceManager();
1087 MS_EXCEPTION_IF_NULL(g_device_manager);
1088 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1089 auto strategy0 = shape[0] / slice_shape[0];
1090
1091 return (total_device_num == LongToSize(strategy0));
1092 }
1093
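// return the per device communication cost in the forward phase.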
double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                         int64_t stage_id) const {
1096 double result = 0.0;
1097 TensorInfo input0 = inputs[0];
1098 TensorInfo output0 = outputs[0];
1099 Shape input0_shape = input0.shape();
1100 Shape input0_slice_shape = input0.slice_shape();
1101 if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
1102 return result;
1103 }
1104 std::vector<int64_t> dim_list = input0.reduce_dim();
1105 auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
1106 return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
1107 });
1108 if (pos != dim_list.end()) {
1109 result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
1110 }
1111
1112 return result;
1113 }
1114
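// return the per device communication cost in the backward phase.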
double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
1117 double result = 0.0;
1118 if (is_parameter_[0]) {
1119 TensorInfo input_tensor_info = inputs[0];
1120 CheckGlobalDeviceManager();
1121 MS_EXCEPTION_IF_NULL(g_device_manager);
1122 auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1123
1124 Shape input_shape = input_tensor_info.shape();
1125 Shape input_slice_shape = input_tensor_info.slice_shape();
1126 int64_t used_device_num = 1;
1127 for (size_t i = 0; i < input_shape.size(); ++i) {
1128 used_device_num *= input_shape[i] / input_slice_shape[i];
1129 }
1130
1131 if (total_device_num != LongToSize(used_device_num))
1132 result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
1133 }
1134
1135 return result;
1136 }

double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  double result = 0.0;
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
  std::vector<int64_t> dim_list = input0.reduce_dim();
  Shape input0_slice_shape = input0.slice_shape();
  Shape input0_shape = input0.shape();
  if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
    auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
      return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
    });
    if (pos != dim_list.end()) {
      result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
    }
  }
  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);

  return result;
}

// Not taking account of output
void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // If 'y' is not kept in memory by the previous operator, then 'y' must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  // Not taking account of 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  double result = 0.0;
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
  std::vector<int64_t> dim_list = input0.reduce_dim();
  Shape input0_slice_shape = input0.slice_shape();
  Shape input0_shape = input0.shape();
  if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
    auto pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int64_t index) {
      return input0_shape[LongToSize(index)] != input0_slice_shape[LongToSize(index)];
    });
    if (pos != dim_list.end()) {
      result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]) * 2.0;
    }
  }
  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);

  return result;
}

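// The backward of ReduceMin compares the forward output with the input to rebuild the selection mask,
// so the output is kept in memory whenever the gradient path involves a parameter.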
void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // If 'y' is not kept in memory by the previous operator, then 'y' must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  // Not taking account of 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}

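// Like ReduceMin, the backward of ArgMaxWithValue reuses its forward output (the indices), so the output
// is kept whenever a parameter is involved in the gradient path.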
void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}

double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t) const {
  if (inputs.empty()) {
    return 0.0;
  }
  TensorInfo input0 = inputs[0];
  Shape input0_slice_shape = input0.slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
}

// return the per device communication cost in the forward phase.
double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                        int64_t) const {
  // GatherV2Cost does not need communication in the forward phase
  return 0.0;
}

// return the per device communication cost in the backward phase.
double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                         int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  for (size_t j = 0; j < inputs.size(); ++j) {
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }

  return result;
}

double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                               int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}

double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                                int64_t) const {
  return 0.0;
}

// Not taking account of output
void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y' and 'z'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(GATHERV2_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(GATHERV2_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1]) {
    is_inputs_should_in_memory_[GATHERV2_INPUTS_SIZE - 1] = is_parameter_[GATHERV2_INPUTS_SIZE - 1];
  }
}

void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  if (is_inputs_should_in_memory_.size() == 0) {
    return;
  }
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

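// The forward computation cost of DSDMatmul is approximated by the total size of its input slices.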
double DSDMatmulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for DSDMatmul cost";
  }

  for (size_t index = 0; index < inputs.size(); ++index) {
    TensorInfo tensor_info = inputs[index];
    Shape slice_shape = tensor_info.slice_shape();
    result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
  }
  return result;
}

void DSDMatmulCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ =
    (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
}

void DSDMatmulCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  bool keep_mem =
    (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
    (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
  std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
}

void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

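// Backward communication of LayerNorm: every parameter input (gamma/beta, and possibly 'x') that is not
// sliced across all devices in the stage needs a mirror AllReduce of its gradient, charged by slice size.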
double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost";
  }
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
  }

  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  for (size_t index = 0; index < inputs.size(); ++index) {
    if (is_parameter_[index]) {
      TensorInfo tensor_info = inputs[index];
      Shape shape = tensor_info.shape();
      Shape slice_shape = tensor_info.slice_shape();
      int64_t used_device_num = 1;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (slice_shape[i] == 0) {
          MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape);
        }
        used_device_num *= shape[i] / slice_shape[i];
      }
      if (total_device_num != LongToSize(used_device_num)) {
        result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
      }
    }
  }
  return result;
}

double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
  }

  for (size_t index = 0; index < inputs.size(); ++index) {
    TensorInfo tensor_info = inputs[index];
    Shape slice_shape = tensor_info.slice_shape();
    result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
  }
  return result;
}

void LayerNormCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ =
    is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[LAYERNORM_INPUTS_SIZE - 1];
}

void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'
  // When calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  is_inputs_should_in_memory_[LAYERNORM_INPUTS_SIZE - 1] = is_parameter_[LAYERNORM_INPUTS_SIZE - 1];
}

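// Unique has no forward communication; the backward charges one AllReduce-sized cost over the input slice
// when the input is a parameter that is not already sharded across all devices in the stage.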
double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  return 0.0;
}

double UniqueCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                       int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input = inputs[0];
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
    Shape input_shape = input.shape();
    Shape input_slice_shape = input.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_shape.size(); ++i) {
      used_device_num *= input_shape[i] / input_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

double UniqueCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                             int64_t) const {
  // In forward phase, the computation cost = slice(A)
  Shape input_slice_shape = inputs[0].slice_shape();
  double result = ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
  return result;
}

double UniqueCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t stage_id) const {
  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input = inputs[0];  // tensor B
    CheckGlobalDeviceManager();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

    Shape input_shape = input.shape();
    Shape input_slice_shape = input.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_shape.size(); ++i) {
      used_device_num *= input_shape[i] / input_slice_shape[i];
    }

    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

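// Forward communication of GatherV2P: when the gather axis is split, a ReduceScatter over an output-sized
// tensor is needed (the index slice, extended by the non-axis dimension of a 2-D parameter); otherwise no
// communication is charged. Illustrative example (assumed shapes, float32): a [16, 32] parameter split on
// axis 0 with an index slice of shape [8] gives a reduce-scatter shape of [8, 32], i.e. 8 * 32 * 4 = 1024 bytes.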
double GatherV2PCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                         int64_t) const {
  double result = 0.0;
  if (outputs_type_lengths_.size() != outputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid outputs type size " << outputs_type_lengths_.size() << " for gatherv2 cost";
  }
  // don't split axis
  if (strategy_.at(LongToSize(axis_)) == 1) {
    return result;
  }

  // split axis
  auto param_shape = inputs[0].slice_shape();
  auto index_shape = inputs[1].slice_shape();
  Shape reducescatter_shape = index_shape;
  if (param_shape.size() == 2) {
    reducescatter_shape.push_back(param_shape.at(LongToSize(1 - axis_)));
  }
  result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]);
  return result;
}

double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  for (size_t j = 0; j < inputs.size(); ++j) {
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    int64_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    if (total_device_num != LongToSize(used_device_num)) {
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
    }
  }
  return result;
}

double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                              const std::vector<TensorInfo> &, int64_t) const {
  Shape input0_slice_shape = inputs[0].slice_shape();
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
                      << " for UniformCandidateSampler cost";
  }

  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);

  return result;
}

void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}

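// Forward computation of GatherV2P: when the gather axis is not split, the cost is just the sum of the two
// input slice sizes; when it is split, each term is scaled by a weight constant to reflect the extra masking
// and local-gather work.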
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
  }
  // don't split axis
  if (strategy_.at(LongToSize(axis_)) == 1) {
    result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
  } else {
    // split axis
    result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 +
              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1;
  }

  return result;
}

double GatherV2PCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                 const std::vector<TensorInfo> &outputs, int64_t) const {
  double result = 0.0;
  Shape input1_slice_shape = inputs[1].slice_shape();
  Shape output0_slice_shape = outputs[0].slice_shape();
  // don't split axis
  if (strategy_.at(LongToSize(axis_)) == 1) {
    result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
  } else {
    // split axis
    result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 +
              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3;
  }

  return result;
}

// The forward communication is determined by whether the slice is column split or row split.
// The number of segments equals shape[0] of the output, which determines the size of the AllReduce.
double UnsortedSegmentSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
                                                  const std::vector<TensorInfo> &outputs, int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo input1 = inputs[1];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = inputs[0].slice_shape();
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
  }
  // If the slice of input0 differs from its full shape in the segment dimensions, we regard it as a column slice.
  for (size_t i = 0; i < input1.shape().size(); ++i) {
    if (input0_shape[i] != input0_slice_shape[i]) {
      result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
      return result;
    }
  }
  return result;
}

double UnsortedSegmentSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
                                                   const std::vector<TensorInfo> &outputs, int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo input1 = inputs[1];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = inputs[0].slice_shape();
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
  }
  if (is_parameter_[0]) {
    // If the forward process has an AllReduce, then the backward also needs one.
    for (size_t i = 0; i < input1.shape().size(); ++i) {
      if (input0_shape[i] != input0_slice_shape[i]) {
        result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
        return result;
      }
    }
  }
  return result;
}

double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                         const std::vector<TensorInfo> &outputs, int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B) + slice(output)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  Shape output_slice_shape = outputs[0].slice_shape();
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
                  ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
  return result;
}

// Not taking account of output
void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Taking account of input
void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }

  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
}

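// UnsortedSegmentMin mirrors UnsortedSegmentSum: when input0 is sliced in any of the segment dimensions, an
// AllGather-sized communication is charged in the forward pass and a ReduceScatter-sized one in the backward
// pass (only if input0 is a parameter).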
double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
                                                  const std::vector<TensorInfo> &outputs, int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo input1 = inputs[1];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = inputs[0].slice_shape();
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
                      << " for UnsortedSegmentMinCost cost";
  }
  // If the slice of input0 differs from its full shape in the segment dimensions, we regard it as a column slice.
  // The cost is an AllGather operation whose shape is the same as the output of UnsortedSegmentMin.
  for (size_t i = 0; i < input1.shape().size(); ++i) {
    if (input0_shape[i] != input0_slice_shape[i]) {
      result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
      return result;
    }
  }
  return result;
}

double UnsortedSegmentMinCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
                                                   const std::vector<TensorInfo> &outputs, int64_t) const {
  TensorInfo input0 = inputs[0];
  TensorInfo input1 = inputs[1];
  TensorInfo output0 = outputs[0];
  Shape input0_shape = input0.shape();
  Shape input0_slice_shape = inputs[0].slice_shape();
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
                      << " for UnsortedSegmentMinCost cost";
  }
  if (is_parameter_[0]) {
    // If the forward process has an AllGather, then the backward also needs one ReduceScatter.
    for (size_t i = 0; i < input1.shape().size(); ++i) {
      if (input0_shape[i] != input0_slice_shape[i]) {
        result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
        return result;
      }
    }
  }
  return result;
}

double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                         const std::vector<TensorInfo> &outputs, int64_t) const {
  // In forward phase, the computation cost = slice(A) + slice(B) + 2 * slice(output)
  Shape input0_slice_shape = inputs[0].slice_shape();
  Shape input1_slice_shape = inputs[1].slice_shape();
  Shape output_slice_shape = outputs[0].slice_shape();
  // The forward operation is UnsortedSegmentMin + ReduceMin
  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
                  ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]) +
                  ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);  // ReduceMin
  return result;
}

// Taking account of output
void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }

// Taking account of input
void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x', 'y' and 'z'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1) == prev_output_in_mem.end()) ||
        (!prev_output_in_mem.at(UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1))) {
      is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1]) {
    is_inputs_should_in_memory_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1] = is_parameter_[UNSORTEDSEGMENTSUM_INPUTS_SIZE - 1];
  }
}

// Not taking account of output
void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

// Not taking account of input
void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
    is_inputs_should_in_memory_[i] = is_parameter_[i];
  }
}

double MatmulDDSCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for MatmulDDS cost";
  }

  for (size_t index = 0; index < inputs.size(); ++index) {
    TensorInfo tensor_info = inputs[index];
    Shape slice_shape = tensor_info.slice_shape();
    result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
  }
  return result;
}

// Taking account of output when any parameter is involved
void MatmulDDSCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ =
    (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
}

// Taking account of input
void MatmulDDSCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
  bool keep_mem =
    (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
    (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
  std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
}
}  // namespace parallel
}  // namespace mindspore