1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
17 #define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
18
19 #include <memory>
20 #include <vector>
21 #include <map>
22 #include "frontend/parallel/device_manager.h"
23 #include "frontend/parallel/tensor_layout/tensor_info.h"
24 #include "frontend/parallel/ops_info/ops_utils.h"
25 #include "frontend/parallel/auto_parallel/costmodel.h"
26
27 namespace mindspore {
28 namespace parallel {
29 constexpr size_t MAXIMUM_INPUT_NUMBER = 100;
30 constexpr size_t DEFAULT_DATA_TYPE_LENGTH = 4;
31 constexpr double DROPOUT_COST_RATE = 1.125; // the DropoutGenMask need 12.5% memory
32 constexpr size_t GATHERV2_COST_WEIGHT0 = 3;
33 constexpr size_t GATHERV2_COST_WEIGHT1 = 7;
34 constexpr size_t GATHERV2_COST_WEIGHT2 = 2;
35 constexpr size_t GATHERV2_COST_WEIGHT3 = 6;
36
37 class OperatorCost;
38 using OperatorCostPtr = std::shared_ptr<OperatorCost>;
39
40 template <typename T>
ListProduct(std::vector<T> vec)41 double ListProduct(std::vector<T> vec) {
42 double result = 1;
43 for (size_t i = 0; i < vec.size(); ++i) {
44 result *= vec[i];
45 }
46 return result;
47 }
48 // NOTE: Currently, the returned value in each method is bytes of memory size, which is calculated by the number of
49 // entries timing the length of each entry's data type
50 class OperatorCost {
51 public:
OperatorCost()52 OperatorCost() {
53 // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
54 for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
55 is_parameter_.push_back(false);
56 is_parameter_involve_.push_back(false);
57 inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
58 outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
59 }
60 }
61 virtual ~OperatorCost() = default;
62
63 void set_is_parameter(const std::vector<bool> &is_parameter);
64 void set_is_parameter_involve(const std::vector<bool> &);
65 void set_output_parameter_involve(int64_t);
66 void set_output_critical(int64_t);
67 void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
inputs_type_lengths()68 std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
outputs_type_lengths()69 std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
70
71 // per device communication cost
72 virtual double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
73 int64_t stage_id) const = 0;
74 virtual double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
75 int64_t stage_id) const = 0;
76 virtual double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
77 int64_t stage_id) const = 0;
78 // per device computation cost
79 virtual double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
80 int64_t stage_id) const = 0;
81 virtual double GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
82 const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
83 virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
84 const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
85 virtual void CalculateOutputInMemory() = 0;
86 virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
is_output_in_memory()87 bool is_output_in_memory() const { return is_output_should_in_memory_; }
88 // per device PEAK memory cost in a training iteration
89 // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
90 // plus necessary inputs.
91 virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
92 // Contributing the input part for 'GetMemoryCost'
93 double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
94 // Contributing the output part for 'GetMemoryCost'
95 double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
96 // per device memory cost in a inference phase
97 double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs) const;
98
99 protected:
100 // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
101 // pre-operator that has parameters as input.
102 std::vector<bool> is_parameter_involve_;
103 int64_t output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
104 // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
105 std::vector<bool> is_parameter_;
106 // Whether the input should keep in memory in training phase. It depends on the operator and the operator's
107 // previous operators.
108 std::vector<bool> is_inputs_should_in_memory_;
109 // Whether the output should keep in memory in training phase. It depends on 'is_parameter_involve_' and the operator.
110 bool is_output_should_in_memory_ = false;
111 // For each input and output, the followings record the number of bytes of each element
112 std::vector<size_t> inputs_type_lengths_;
113 std::vector<size_t> outputs_type_lengths_;
114 // Whether the output is critical, which means that this output is included in calculating peak memory cost
115 // in the inference phase.
116 int64_t is_outputs_critical_ = -1;
117 };
118
119 class MatMulCost : public OperatorCost {
120 public:
MatMulCost()121 MatMulCost() : OperatorCost() {}
122 ~MatMulCost() override = default;
123
124 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)125 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
126 int64_t stage_id) const override {
127 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
128 }
129 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
130 int64_t stage_id) const override;
131 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
132 int64_t stage_id) const override;
133
134 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)135 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
136 int64_t stage_id) const override {
137 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
138 }
139 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
140 int64_t stage_id) const override;
141 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
142 int64_t stage_id) const override;
143 void CalculateOutputInMemory() override;
144 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
145 };
146 using TensorDotCost = MatMulCost;
147
148 class BatchNormCost : public OperatorCost {
149 public:
BatchNormCost()150 BatchNormCost() : OperatorCost() {}
151 ~BatchNormCost() override = default;
152
153 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)154 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
155 int64_t stage_id) const override {
156 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
157 }
158 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
159 int64_t stage_id) const override;
160 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
161 int64_t stage_id) const override;
162
163 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)164 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
165 int64_t stage_id) const override {
166 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
167 }
168 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
169 int64_t stage_id) const override;
170 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
171 int64_t) const override;
172 void CalculateOutputInMemory() override;
173 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
174 };
175
176 class CastCost : public OperatorCost {
177 public:
CastCost()178 CastCost() : OperatorCost() {}
179 ~CastCost() override = default;
180
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)181 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
182 int64_t stage_id) const override {
183 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
184 }
185 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
186 int64_t stage_id) const override;
187 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
188 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)189 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
190 int64_t stage_id) const override {
191 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
192 }
193 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
194 int64_t stage_id) const override;
195 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
196 int64_t stage_id) const override;
197 // Not taking account of output
198 void CalculateOutputInMemory() override;
199 // Not Taking account of input
200 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
201 };
202 using RepeatElementsCost = CastCost;
203 using NegCost = CastCost;
204 using ExpandDimsCost = CastCost;
205 using SqueezeCost = CastCost;
206 using ConcatCost = CastCost;
207 using LogicalNotCost = CastCost;
208 using SignCost = CastCost;
209 using FloorCost = CastCost;
210 using RoundCost = CastCost;
211 using CeilCost = CastCost;
212 using ZerosLikeCost = CastCost;
213 using OnesLikeCost = CastCost;
214 using RangeCost = CastCost;
215 using SplitCost = CastCost;
216 using ScatterUpdateCost = CastCost;
217 using RandomDistributeCost = CastCost;
218 using FillV2Cost = CastCost;
219 using ResizeBilinearCost = CastCost;
220 using BoundingBoxEncodeCost = CastCost;
221 using IOUCost = CastCost;
222 using RandomChoicWithMaskCost = CastCost;
223 using IsFiniteCost = CastCost;
224 using RintCost = CastCost;
225 using GammaCost = CastCost;
226 using LinSpaceCost = CastCost;
227
228 class SqrtCost : public CastCost {
229 public:
SqrtCost()230 SqrtCost() : CastCost() {}
231 ~SqrtCost() override = default;
232 // Taking account of output, not taking accounting of input
233 void CalculateOutputInMemory() override;
234 };
235 using TanhCost = SqrtCost;
236 using EluCost = SqrtCost;
237 using ReLUCost = SqrtCost;
238 using SiLUCost = SqrtCost;
239 using identityCost = SqrtCost;
240 using SigmoidCost = SqrtCost;
241 using ReciprocalCost =
242 SqrtCost; // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
243 using InvCost = SqrtCost;
244 using RsqrtCost = SqrtCost;
245 using AsinhCost = SqrtCost;
246 using AcoshCost = SqrtCost;
247 using TopKCost = SqrtCost;
248 using HShrinkCost = SqrtCost;
249 using HSigmoidCost = SqrtCost;
250 using MishCost = SqrtCost;
251 using SeLUCost = SqrtCost;
252 using SoftShrinkCost = SqrtCost;
253
254 class ReLU6Cost : public CastCost {
255 public:
ReLU6Cost()256 ReLU6Cost() : CastCost() {}
257 ~ReLU6Cost() override = default;
258 // Taking account of input, not taking account of output
259 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
260 };
261 using SoftsignCost = ReLU6Cost;
262 using SoftplusCost = ReLU6Cost;
263 using SquareCost = ReLU6Cost;
264 using ExpCost = ReLU6Cost;
265 using LogCost = ReLU6Cost;
266 using CosCost = ReLU6Cost;
267 using ACosCost = ReLU6Cost;
268 using AbsCost = ReLU6Cost;
269 using TanCost = ReLU6Cost;
270 using SinCost = ReLU6Cost;
271 using SinhCost = ReLU6Cost;
272 using Log1pCost = ReLU6Cost;
273 using Expm1Cost = ReLU6Cost;
274 using CoshCost = ReLU6Cost;
275 using AtanhCost = ReLU6Cost;
276 using AtanCost = ReLU6Cost;
277 using AsinCost = ReLU6Cost;
278 using ErfCost = ReLU6Cost;
279 using ErfcCost = ReLU6Cost;
280 using ActivationInfoCost = ReLU6Cost;
281 using SelectCost = ReLU6Cost;
282 using XlogyCost = ReLU6Cost;
283 using ErfinvCost = ReLU6Cost;
284
285 class TransposeCost : public CastCost {
286 public:
TransposeCost()287 TransposeCost() : CastCost() {}
288 ~TransposeCost() override = default;
289 // Taking account of input, not taking account of output
290 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
291 };
292
293 class GeLUCost : public SqrtCost {
294 public:
GeLUCost()295 GeLUCost() : SqrtCost() {}
296 ~GeLUCost() override = default;
297 // Taking account of input and output
298 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
299 };
300 using FastGeLUCost = GeLUCost;
301 using BesselI0eCost = GeLUCost;
302 using BesselI1eCost = GeLUCost;
303 using L2NormalizeCost = GeLUCost;
304 using MaxPoolCost = GeLUCost;
305
306 class SoftmaxCost : public OperatorCost {
307 public:
SoftmaxCost()308 SoftmaxCost() : OperatorCost() {}
309 ~SoftmaxCost() override = default;
310
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)311 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
312 int64_t stage_id) const override {
313 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
314 }
315 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
316 int64_t stage_id) const override;
317 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
318 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)319 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
320 int64_t stage_id) const override {
321 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
322 }
323 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
324 int64_t stage_id) const override;
325 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
326 int64_t) const override;
327 // Taking account of output
328 void CalculateOutputInMemory() override;
329 // Not Taking account of input
330 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
331 };
332
333 using CumSumCost = SoftmaxCost;
334 using CumProdCost = SoftmaxCost;
335
336 class TileCost : public SoftmaxCost {
337 public:
TileCost()338 TileCost() : SoftmaxCost() {}
339 ~TileCost() override = default;
340 // Not taking account of output
341 void CalculateOutputInMemory() override;
342 // Taking account of input
343 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
344 };
345
346 class PackCost : public SoftmaxCost {
347 public:
PackCost()348 PackCost() : SoftmaxCost() {}
349 ~PackCost() override = default;
350 // Not taking account of output
351 void CalculateOutputInMemory() override;
352 // Not taking account of input
353 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
354 };
355
356 class BroadcastToCost : public SoftmaxCost {
357 public:
BroadcastToCost()358 BroadcastToCost() : SoftmaxCost() {}
359 ~BroadcastToCost() override = default;
360 // Not taking account of output
361 void CalculateOutputInMemory() override;
362 // Not Taking account of input
363 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
364 };
365
366 class TmpIdentityCost : public OperatorCost {
367 public:
TmpIdentityCost()368 TmpIdentityCost() : OperatorCost() {}
369 ~TmpIdentityCost() override = default;
370
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)371 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
372 int64_t stage_id) const override {
373 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
374 }
375 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
376 int64_t stage_id) const override;
377 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
378 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)379 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
380 int64_t stage_id) const override {
381 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
382 }
383 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
384 int64_t stage_id) const override;
385 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
386 int64_t stage_id) const override;
387 // Not taking account of output
388 void CalculateOutputInMemory() override;
389 // Not taking account of input
390 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
391 };
392 using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
393
394 class BatchParallelCost : public OperatorCost {
395 public:
BatchParallelCost()396 BatchParallelCost() : OperatorCost() {}
397 ~BatchParallelCost() override = default;
398
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)399 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
400 int64_t stage_id) const override {
401 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
402 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)403 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
404 return 0.0;
405 }
406 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
407 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)408 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
409 int64_t stage_id) const override {
410 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
411 }
412 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
413 int64_t stage_id) const override;
414 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
415 int64_t stage_id) const override;
416 // Not taking account of output
417 void CalculateOutputInMemory() override;
418 // Taking account of input
419 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
420 };
421
422 class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost {
423 public:
SparseSoftmaxCrossEntropyWithLogitsCost()424 SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {}
425 ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default;
426 // Taking account of output
427 void CalculateOutputInMemory() override;
428 // Not taking account of input
429 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
430 };
431
432 class VirtualDatasetCost : public OperatorCost {
433 public:
VirtualDatasetCost()434 VirtualDatasetCost() : OperatorCost() {}
435 ~VirtualDatasetCost() override = default;
436
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)437 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
438 int64_t stage_id) const override {
439 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
440 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)441 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
442 return 0.0;
443 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)444 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
445 return 0.0;
446 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)447 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
448 int64_t stage_id) const override {
449 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
450 }
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)451 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
452 int64_t) const override {
453 return 0.0;
454 }
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)455 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
456 int64_t) const override {
457 return 0.0;
458 }
459 // Not taking account of output
460 void CalculateOutputInMemory() override;
461 // Not taking account of input
462 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
463 };
464
465 class GeneratorBaseCost : public OperatorCost {
466 public:
GeneratorBaseCost()467 GeneratorBaseCost() : OperatorCost() {}
468 ~GeneratorBaseCost() override = default;
469
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)470 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
471 int64_t stage_id) const override {
472 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
473 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)474 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
475 return 0.0;
476 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)477 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
478 return 0.0;
479 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)480 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
481 int64_t stage_id) const override {
482 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
483 }
484 // Inputs vector is empty for generator ops.
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)485 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
486 int64_t) const override {
487 return 0.0;
488 }
489 // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)490 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
491 int64_t) const override {
492 return 0.0;
493 }
494 };
495 using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
496
497 class PReLUCost : public OperatorCost {
498 public:
PReLUCost()499 PReLUCost() : OperatorCost() {}
500 ~PReLUCost() override = default;
501
502 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)503 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
504 int64_t stage_id) const override {
505 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
506 }
507 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
508 int64_t stage_id) const override;
509 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
510 int64_t stage_id) const override;
511
512 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)513 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
514 int64_t stage_id) const override {
515 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
516 }
517 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
518 int64_t stage_id) const override;
519 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
520 int64_t stage_id) const override;
521 // Not taking account of output
522 void CalculateOutputInMemory() override;
523 // Taking account of input
524 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
525 };
526 using PReLUCostPtr = std::shared_ptr<PReLUCost>;
527
528 class OneHotCost : public OperatorCost {
529 public:
OneHotCost()530 OneHotCost() : OperatorCost() {}
531 ~OneHotCost() override = default;
532
533 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)534 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
535 int64_t stage_id) const override {
536 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
537 }
538 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
539 int64_t stage_id) const override;
540 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
541 int64_t stage_id) const override;
542
543 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)544 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
545 int64_t stage_id) const override {
546 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
547 }
548 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
549 int64_t stage_id) const override;
550 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
551 int64_t stage_id) const override;
552 // Not taking account of output
553 void CalculateOutputInMemory() override;
554 // Not taking account of input
555 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
556 };
557 using OneHotCostPtr = std::shared_ptr<OneHotCost>;
558
559 class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
560 public:
SoftmaxCrossEntropyWithLogitsCost()561 SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {}
562 ~SoftmaxCrossEntropyWithLogitsCost() override = default;
563
564 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)565 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
566 int64_t stage_id) const override {
567 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
568 }
569 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
570 int64_t stage_id) const override;
571 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
572 int64_t stage_id) const override;
573
574 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)575 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
576 int64_t stage_id) const override {
577 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
578 }
579 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
580 int64_t stage_id) const override;
581 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
582 int64_t stage_id) const override;
583 // Taking account of output
584 void CalculateOutputInMemory() override;
585 // Not taking account of input
586 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
587 };
588
589 class ReshapeCost : public OperatorCost {
590 public:
ReshapeCost()591 ReshapeCost() : OperatorCost() {}
592
593 ~ReshapeCost() override = default;
594
595 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)596 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
597 int64_t stage_id) const override {
598 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
599 }
600
601 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
602 int64_t stage_id) const override;
603
604 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
605 int64_t stage_id) const override;
606
607 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)608 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
609 int64_t stage_id) const override {
610 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
611 }
612
613 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
614 int64_t stage_id) const override;
615
616 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
617 int64_t stage_id) const override;
618 // Not taking account of output
619 void CalculateOutputInMemory() override;
620 // Not taking account of input
621 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
622 };
623 using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
624
625 class SubCost : public OperatorCost {
626 public:
SubCost()627 SubCost() : OperatorCost() {}
628 ~SubCost() override = default;
629
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)630 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
631 int64_t stage_id) const override {
632 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
633 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)634 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
635 return 0.0;
636 }
637 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
638 int64_t stage_id) const override;
639
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)640 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
641 int64_t stage_id) const override {
642 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
643 }
644 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
645 int64_t stage_id) const override;
646 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
647 int64_t stage_id) const override;
648 // Not taking account of output
649 void CalculateOutputInMemory() override;
650 // Not taking account of input
651 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
652 };
653 using TensorAddCost = SubCost;
654 using FloorDivCost = SubCost;
655 using AssignSubCost = SubCost;
656 using AssignAddCost = SubCost;
657 using LogicalAndCost = SubCost;
658 using LogicalOrCost = SubCost;
659 using BiasAddCost = SubCost;
660 using EqualCost = SubCost;
661 using ApproximateEqualCost = SubCost;
662 using NotEqualCost = SubCost;
663 using GreaterCost = SubCost;
664 using GreaterEqualCost = SubCost;
665 using LessCost = SubCost;
666 using LessEqualCost = SubCost;
667 using GatherNdCost = SubCost;
668 using BitwiseAndCost = SubCost;
669 using BitwiseOrCost = SubCost;
670 using BitwiseXorCost = SubCost;
671 using AddNCost = SubCost;
672 using InplaceAddCost = SubCost;
673 using InplaceSubCost = InplaceAddCost;
674 using InplaceUpdateCost = InplaceAddCost;
675 using MaskedFillCost = SubCost;
676
677 class MulCost : public SubCost {
678 public:
MulCost()679 MulCost() : SubCost() {}
680 ~MulCost() override = default;
681 // Taking account of input, not taking account of output
682 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
683 };
684
685 using MulNoNanCost = MulCost;
686 using GatherDCost = MulCost;
687 using LerpCost = MulCost;
688 using SquaredDifferenceCost = MulCost;
689
690 class DivCost : public SubCost {
691 public:
DivCost()692 DivCost() : SubCost() {}
693 ~DivCost() override = default;
694 // Taking account of output
695 void CalculateOutputInMemory() override;
696 // Taking account of input
697 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
698 };
699 using ReadDivCost = DivCost;
700 using TruncateDivCost = DivCost;
701 using XdivyCost = DivCost;
702 using CdistCost = DivCost;
703
704 class ModCost : public SubCost {
705 public:
ModCost()706 ModCost() : SubCost() {}
707 ~ModCost() override = default;
708 // Taking account of input, not taking account of output
709 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
710 };
711 using FloorModCost = ModCost;
712 using TruncateModCost = ModCost;
713
714 class PowCost : public SubCost {
715 public:
PowCost()716 PowCost() : SubCost() {}
717 ~PowCost() override = default;
718 // Taking account of output
719 void CalculateOutputInMemory() override;
720 // Taking account of input
721 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
722 };
723
724 class AssignCost : public SubCost {
725 public:
AssignCost()726 AssignCost() : SubCost() {}
727 ~AssignCost() override = default;
728 // Taking account of input, not taking account of output
729 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
730 };
731
732 class SigmoidCrossEntropyWithLogitsCost : public SubCost {
733 public:
SigmoidCrossEntropyWithLogitsCost()734 SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
735 ~SigmoidCrossEntropyWithLogitsCost() override = default;
736 // Taking account of input, not taking account of output
737 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
738 };
739
740 class Atan2Cost : public SubCost {
741 public:
Atan2Cost()742 Atan2Cost() : SubCost() {}
743 ~Atan2Cost() override = default;
744 // Taking account of input, not taking account of output
745 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
746 };
747
748 class DivNoNanCost : public SubCost {
749 public:
DivNoNanCost()750 DivNoNanCost() : SubCost() {}
751 ~DivNoNanCost() override = default;
752 // Taking account of output
753 void CalculateOutputInMemory() override;
754 // Taking account of input
755 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
756 };
757
758 class MaximumCost : public SubCost {
759 public:
MaximumCost()760 MaximumCost() : SubCost() {}
761 ~MaximumCost() override = default;
762 // Taking account of input, not taking account of output
763 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
764 };
765 using MinimumCost = MaximumCost;
766 using CumminCost = MaximumCost;
767
768 class SliceCost : public CastCost {
769 public:
SliceCost()770 SliceCost() : CastCost() {}
771 ~SliceCost() override = default;
772 // Not taking account of output, taking account of input
773 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
774 };
775
776 class StridedSliceCost : public CastCost {
777 public:
StridedSliceCost()778 StridedSliceCost() : CastCost() {}
779 ~StridedSliceCost() override = default;
780 // Not taking account of output, taking account of input
781 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
782 };
783
784 class ReduceSumCost : public OperatorCost {
785 public:
ReduceSumCost()786 ReduceSumCost() : OperatorCost() {}
787 ~ReduceSumCost() override = default;
788
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)789 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
790 int64_t stage_id) const override {
791 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
792 }
793 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
794 int64_t stage_id) const override;
795 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
796 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)797 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
798 int64_t stage_id) const override {
799 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
800 }
801 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
802 int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)803 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
804 int64_t) const override {
805 return 0.0;
806 }
set_cross_batch(bool cb)807 void set_cross_batch(bool cb) { cross_batch_ = cb; }
808 // Not taking account of output
809 void CalculateOutputInMemory() override;
810 // Taking account of input
811 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
812
813 protected:
814 bool cross_batch_ = false;
815 };
816 using ReduceMethodCost = ReduceSumCost;
817 using ReduceProdCost = ReduceSumCost;
818 using SquareSumAllCost = ReduceSumCost;
819 using L2LossCost = ReduceSumCost;
820 using KLDivLossCost = ReduceSumCost;
821
822 class ReduceMeanCost : public ReduceSumCost {
823 public:
ReduceMeanCost()824 ReduceMeanCost() : ReduceSumCost() {}
825 ~ReduceMeanCost() override = default;
826
827 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
828 int64_t stage_id) const override;
829 };
830
831 class ReduceMinCost : public ReduceSumCost {
832 public:
ReduceMinCost()833 ReduceMinCost() : ReduceSumCost() {}
834 ~ReduceMinCost() override = default;
835 // Taking account of output
836 void CalculateOutputInMemory() override;
837 // Taking account of input
838 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
839 };
840 using ReduceMaxCost = ReduceMinCost;
841
842 class ArgMaxWithValueCost : public ReduceSumCost {
843 public:
ArgMaxWithValueCost()844 ArgMaxWithValueCost() : ReduceSumCost() {}
845 ~ArgMaxWithValueCost() override = default;
846 // Taking account of output
847 void CalculateOutputInMemory() override;
848 // Taking account of input
849 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
850 };
851 using ArgMinWithValueCost = ArgMaxWithValueCost;
852 using ArgmaxCost = ArgMaxWithValueCost;
853
854 class GetNextCost : public OperatorCost {
855 public:
GetNextCost()856 GetNextCost() : OperatorCost() {}
857 ~GetNextCost() override = default;
858
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)859 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
860 int64_t stage_id) const override {
861 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
862 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t stage_id)863 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
864 int64_t stage_id) const override {
865 return 0.0;
866 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t stage_id)867 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
868 int64_t stage_id) const override {
869 return 0.0;
870 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)871 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
872 int64_t stage_id) const override {
873 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
874 }
875 // Inputs vector is empty for generator ops.
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)876 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
877 int64_t) const override {
878 return 0.0;
879 }
880 // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)881 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
882 int64_t) const override {
883 return 0.0;
884 }
885 // Not taking account of output
886 void CalculateOutputInMemory() override;
887 // Not Taking account of input
888 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
889 };
890 using GetNextCostPtr = std::shared_ptr<GetNextCost>;
891
892 class DSDMatmulCost : public OperatorCost {
893 public:
DSDMatmulCost()894 DSDMatmulCost() : OperatorCost() {}
895 ~DSDMatmulCost() override = default;
896
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)897 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
898 int64_t stage_id) const override {
899 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
900 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)901 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
902 return 0.0;
903 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)904 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
905 return 0.0;
906 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)907 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
908 int64_t stage_id) const override {
909 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
910 }
911 // Inputs vector is empty for generator ops.
912 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
913 int64_t) const override;
914 // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)915 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
916 int64_t) const override {
917 return 0.0;
918 }
919 // Not taking account of output
920 void CalculateOutputInMemory() override;
921 // Not Taking account of input
922 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
923 };
924 using DSDMatmulCostPtr = std::shared_ptr<DSDMatmulCost>;
925
926 // For memory cost, taking account of output, not taking account of input
927 class DropOutCost : public SqrtCost {
928 public:
DropOutCost()929 DropOutCost() : SqrtCost() {}
930 ~DropOutCost() override = default;
931
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)932 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
933 int64_t stage_id) const override {
934 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
935 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)936 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
937 return 0.0;
938 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)939 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
940 return 0.0;
941 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)942 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
943 int64_t stage_id) const override {
944 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
945 }
946 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
947 int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)948 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
949 int64_t) const override {
950 return 0.0;
951 }
952 };
953
954 class DropOutDoMaskCost : public DropOutCost {
955 public:
DropOutDoMaskCost()956 DropOutDoMaskCost() : DropOutCost() {}
957 ~DropOutDoMaskCost() override = default;
958 // Not taking account of output
959 void CalculateOutputInMemory() override;
960 // Taking account of input
961 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
962 };
963
964 class UnsortedSegmentSumCost : public OperatorCost {
965 public:
UnsortedSegmentSumCost()966 UnsortedSegmentSumCost() : OperatorCost() {}
967 ~UnsortedSegmentSumCost() override = default;
968
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)969 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
970 int64_t stage_id) const override {
971 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
972 }
973 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
974 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)975 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
976 int64_t stage_id) const override {
977 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
978 }
979 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
980 int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)981 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
982 int64_t) const override {
983 return 0.0;
984 }
985 // Not taking account of output
986 void CalculateOutputInMemory() override;
987 // Taking account of input
988 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
989 };
990 using UnsortedSegmentProdCost = UnsortedSegmentSumCost;
991
992 class UnsortedSegmentMinCost : public OperatorCost {
993 public:
UnsortedSegmentMinCost()994 UnsortedSegmentMinCost() : OperatorCost() {}
995 ~UnsortedSegmentMinCost() override = default;
996
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)997 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
998 int64_t stage_id) const override {
999 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1000 }
1001 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
1002 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1003 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1004 int64_t stage_id) const override {
1005 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1006 }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;

class LayerNormCost : public OperatorCost {
 public:
  LayerNormCost() : OperatorCost() {}
  ~LayerNormCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using GroupNormCost = LayerNormCost;

class UniqueCost : public OperatorCost {
 public:
  UniqueCost() : OperatorCost() {}
  ~UniqueCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class UniformCandidateSamplerCost : public OperatorCost {
 public:
  UniformCandidateSamplerCost() : OperatorCost() {}
  ~UniformCandidateSamplerCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GatherV2Cost : public OperatorCost {
 public:
  GatherV2Cost() : OperatorCost() {}
  ~GatherV2Cost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GatherCost : public GatherV2Cost {
 public:
  GatherCost() : GatherV2Cost(), axis_(0) {}
  ~GatherCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  void set_axis(int64_t axis) { axis_ = axis; }
  void set_strategy(const Shape &strategy) { strategy_ = strategy; }

 protected:
  int64_t axis_;
  Shape strategy_;
};
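
// Illustrative usage sketch (not part of this header; the axis and strategy values below are hypothetical):
// a GatherCost is typically configured through its setters before its cost interfaces are queried.
//   GatherCost gather_cost;
//   gather_cost.set_axis(0);            // gather along the first dimension
//   gather_cost.set_strategy({8, 1});   // sharding strategy of the parameter input
//   double cost = gather_cost.GetComputationCost(inputs, outputs, stage_id);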

class MatmulDDSCost : public OperatorCost {
 public:
  MatmulDDSCost() : OperatorCost() {}
  ~MatmulDDSCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MatmulDDSCostPtr = std::shared_ptr<MatmulDDSCost>;

class ScatterMathOpsCost : public OperatorCost {
 public:
  ScatterMathOpsCost() : OperatorCost() {}
  ~ScatterMathOpsCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override { is_output_should_in_memory_ = false; }
  // Taking account of the first input only
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override {
    is_inputs_should_in_memory_[0] = true;
  }

  void set_is_split_axis(bool is_split_axis) { is_split_axis_ = is_split_axis; }
  void set_coefficient(int32_t input_coefficient, int32_t indices_coefficient, int32_t updates_coefficient) {
    input_coefficient_ = input_coefficient;
    indices_coefficient_ = indices_coefficient;
    updates_coefficient_ = updates_coefficient;
  }

 protected:
  bool is_split_axis_ = false;
  int32_t input_coefficient_ = 1;
  int32_t indices_coefficient_ = 3;
  int32_t updates_coefficient_ = 2;
};
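
// Illustrative usage sketch (not part of this header; the argument values are hypothetical): the three
// coefficients presumably weight the computation contributions of the input, indices, and updates tensors in
// the forward cost, and can be tuned per scatter operator.
//   ScatterMathOpsCost scatter_cost;
//   scatter_cost.set_is_split_axis(true);   // the scattered axis is sharded
//   scatter_cost.set_coefficient(1, 3, 2);  // mirrors the default weights above
//   double fwd = scatter_cost.GetForwardComputationCost(inputs, outputs, stage_id);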

class ScatterNdOpsCost : public ScatterMathOpsCost {
 public:
  ScatterNdOpsCost() : ScatterMathOpsCost() {}
  ~ScatterNdOpsCost() override = default;
};

class TensorScatterOpsCost : public ScatterMathOpsCost {
 public:
  TensorScatterOpsCost() : ScatterMathOpsCost() {}
  ~TensorScatterOpsCost() override = default;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
};

class CropAndResizeCost : public OperatorCost {
 public:
  CropAndResizeCost() : OperatorCost() {}
  ~CropAndResizeCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
  // Not taking account of output
  void CalculateOutputInMemory() override;

  void set_strategy(const Shape &strategy) { strategy_ = strategy; }
  void set_crop_size(const std::vector<int64_t> &crop_size) { crop_size_ = crop_size; }

 protected:
  Shape strategy_;
  std::vector<int64_t> crop_size_;

 private:
  static const size_t CROP_AND_RESIZE_COST_WEIGHT0 = 1;
  static const size_t CROP_AND_RESIZE_COST_WEIGHT1 = 1;
  static const size_t CROP_AND_RESIZE_COST_WEIGHT2 = 8;
  static const size_t CROP_AND_RESIZE_COST_WEIGHT3 = 2;
};
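
// Illustrative usage sketch (not part of this header; the shapes below are hypothetical): CropAndResizeCost is
// typically given the sharding strategy and the crop size before its cost interfaces are queried.
//   CropAndResizeCost crop_cost;
//   crop_cost.set_strategy({4, 1, 1, 1});
//   crop_cost.set_crop_size({14, 14});
//   double comm = crop_cost.GetCommCost(inputs, outputs, stage_id);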

class ROIAlignCost : public CropAndResizeCost {
 public:
  ROIAlignCost() : CropAndResizeCost() {}
  ~ROIAlignCost() override = default;

  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t) const override;
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
  // Taking account of output
  void CalculateOutputInMemory() override;

  void set_pooled_shape(const Shape &pooled_shape) { pooled_shape_ = pooled_shape; }

 protected:
  Shape pooled_shape_;

 private:
  static const size_t ROI_ALIGN_COST_WEIGHT0 = 1;
  static const size_t ROI_ALIGN_COST_WEIGHT1 = 4;
  static const size_t ROI_ALIGN_COST_WEIGHT2 = 2;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_