• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
17 #define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
18 
19 #include <memory>
20 #include <vector>
21 #include <map>
22 #include "frontend/parallel/device_manager.h"
23 #include "frontend/parallel/tensor_layout/tensor_info.h"
24 #include "frontend/parallel/ops_info/ops_utils.h"
25 #include "frontend/parallel/auto_parallel/costmodel.h"
26 
27 namespace mindspore {
28 namespace parallel {
// Number of default entries that OperatorCost places in each of its per-input/per-output vectors.
constexpr size_t MAXIMUM_INPUT_NUMBER{100};
// Default length, in bytes, assumed for one element of an input or output.
constexpr size_t DEFAULT_DATA_TYPE_LENGTH{4};
// DropoutGenMask needs an additional 12.5% of memory.
constexpr double DROPOUT_COST_RATE{1.125};
// Weights for the GatherV2 cost.
// NOTE(review): their exact use is outside this header section - confirm against the implementation file.
constexpr size_t GATHERV2_COST_WEIGHT0{3};
constexpr size_t GATHERV2_COST_WEIGHT1{7};
constexpr size_t GATHERV2_COST_WEIGHT2{2};
constexpr size_t GATHERV2_COST_WEIGHT3{6};

// Forward declaration so that the shared-pointer alias can be introduced ahead of the class definition.
class OperatorCost;
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
39 
// Returns the product of all elements of 'vec' as a double.
// An empty vector yields 1 (the multiplicative identity).
// 'vec' is taken by const reference so that a call does not copy the whole vector.
template <typename T>
double ListProduct(const std::vector<T> &vec) {
  double result = 1;
  for (const auto &element : vec) {
    // Each element is implicitly converted to double before being multiplied in.
    result *= element;
  }
  return result;
}
48 // NOTE: Currently, the returned value in each method is bytes of memory size, which is calculated by the number of
49 // entries timing the length of each entry's data type
50 class OperatorCost {
51  public:
OperatorCost()52   OperatorCost() {
53     // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
54     for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
55       is_parameter_.push_back(false);
56       is_parameter_involve_.push_back(false);
57       inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
58       outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
59     }
60   }
61   virtual ~OperatorCost() = default;
62 
63   void set_is_parameter(const std::vector<bool> &is_parameter);
64   void set_is_parameter_involve(const std::vector<bool> &);
65   void set_output_parameter_involve(int64_t);
66   void set_output_critical(int64_t);
67   void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
inputs_type_lengths()68   std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
outputs_type_lengths()69   std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
70 
71   // per device communication cost
72   virtual double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
73                              int64_t stage_id) const = 0;
74   virtual double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
75                                     int64_t stage_id) const = 0;
76   virtual double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
77                                      int64_t stage_id) const = 0;
78   // per device computation cost
79   virtual double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
80                                     int64_t stage_id) const = 0;
81   virtual double GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
82                                            const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
83   virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
84                                             const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
85   virtual void CalculateOutputInMemory() = 0;
86   virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
is_output_in_memory()87   bool is_output_in_memory() const { return is_output_should_in_memory_; }
88   // per device PEAK memory cost in a training iteration
89   // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
90   // plus necessary inputs.
91   virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
92   // Contributing the input part for 'GetMemoryCost'
93   double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
94   // Contributing the output part for 'GetMemoryCost'
95   double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
96   // per device memory cost in a inference phase
97   double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs) const;
98 
99  protected:
100   // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
101   // pre-operator that has parameters as input.
102   std::vector<bool> is_parameter_involve_;
103   int64_t output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
104   // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
105   std::vector<bool> is_parameter_;
106   // Whether the input should keep in memory in training phase. It depends on the operator and the operator's
107   // previous operators.
108   std::vector<bool> is_inputs_should_in_memory_;
109   // Whether the output should keep in memory in training phase. It depends on 'is_parameter_involve_' and the operator.
110   bool is_output_should_in_memory_ = false;
111   // For each input and output, the followings record the number of bytes of each element
112   std::vector<size_t> inputs_type_lengths_;
113   std::vector<size_t> outputs_type_lengths_;
114   // Whether the output is critical, which means that this output is included in calculating peak memory cost
115   // in the inference phase.
116   int64_t is_outputs_critical_ = -1;
117 };
118 
119 class MatMulCost : public OperatorCost {
120  public:
MatMulCost()121   MatMulCost() : OperatorCost() {}
122   ~MatMulCost() override = default;
123 
124   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)125   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
126                      int64_t stage_id) const override {
127     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
128   }
129   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
130                             int64_t stage_id) const override;
131   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
132                              int64_t stage_id) const override;
133 
134   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)135   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
136                             int64_t stage_id) const override {
137     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
138   }
139   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
140                                    int64_t stage_id) const override;
141   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
142                                     int64_t stage_id) const override;
143   void CalculateOutputInMemory() override;
144   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
145 };
146 using TensorDotCost = MatMulCost;
147 
148 class BatchNormCost : public OperatorCost {
149  public:
BatchNormCost()150   BatchNormCost() : OperatorCost() {}
151   ~BatchNormCost() override = default;
152 
153   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)154   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
155                      int64_t stage_id) const override {
156     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
157   }
158   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
159                             int64_t stage_id) const override;
160   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
161                              int64_t stage_id) const override;
162 
163   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)164   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
165                             int64_t stage_id) const override {
166     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
167   }
168   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
169                                    int64_t stage_id) const override;
170   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
171                                     int64_t) const override;
172   void CalculateOutputInMemory() override;
173   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
174 };
175 
176 class CastCost : public OperatorCost {
177  public:
CastCost()178   CastCost() : OperatorCost() {}
179   ~CastCost() override = default;
180 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)181   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
182                      int64_t stage_id) const override {
183     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
184   }
185   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
186                             int64_t stage_id) const override;
187   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
188                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)189   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
190                             int64_t stage_id) const override {
191     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
192   }
193   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
194                                    int64_t stage_id) const override;
195   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
196                                     int64_t stage_id) const override;
197   // Not taking account of output
198   void CalculateOutputInMemory() override;
199   // Not Taking account of input
200   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
201 };
202 using RepeatElementsCost = CastCost;
203 using NegCost = CastCost;
204 using ExpandDimsCost = CastCost;
205 using SqueezeCost = CastCost;
206 using ConcatCost = CastCost;
207 using LogicalNotCost = CastCost;
208 using SignCost = CastCost;
209 using FloorCost = CastCost;
210 using RoundCost = CastCost;
211 using CeilCost = CastCost;
212 using ZerosLikeCost = CastCost;
213 using OnesLikeCost = CastCost;
214 using RangeCost = CastCost;
215 using SplitCost = CastCost;
216 using ScatterUpdateCost = CastCost;
217 using RandomDistributeCost = CastCost;
218 using FillV2Cost = CastCost;
219 using ResizeBilinearCost = CastCost;
220 using BoundingBoxEncodeCost = CastCost;
221 using IOUCost = CastCost;
222 using RandomChoicWithMaskCost = CastCost;
223 using IsFiniteCost = CastCost;
224 using RintCost = CastCost;
225 using GammaCost = CastCost;
226 using LinSpaceCost = CastCost;
227 
228 class SqrtCost : public CastCost {
229  public:
SqrtCost()230   SqrtCost() : CastCost() {}
231   ~SqrtCost() override = default;
232   // Taking account of output, not taking accounting of input
233   void CalculateOutputInMemory() override;
234 };
235 using TanhCost = SqrtCost;
236 using EluCost = SqrtCost;
237 using ReLUCost = SqrtCost;
238 using SiLUCost = SqrtCost;
239 using identityCost = SqrtCost;
240 using SigmoidCost = SqrtCost;
241 using ReciprocalCost =
242   SqrtCost;  // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
243 using InvCost = SqrtCost;
244 using RsqrtCost = SqrtCost;
245 using AsinhCost = SqrtCost;
246 using AcoshCost = SqrtCost;
247 using TopKCost = SqrtCost;
248 using HShrinkCost = SqrtCost;
249 using HSigmoidCost = SqrtCost;
250 using MishCost = SqrtCost;
251 using SeLUCost = SqrtCost;
252 using SoftShrinkCost = SqrtCost;
253 
254 class ReLU6Cost : public CastCost {
255  public:
ReLU6Cost()256   ReLU6Cost() : CastCost() {}
257   ~ReLU6Cost() override = default;
258   // Taking account of input, not taking account of output
259   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
260 };
261 using SoftsignCost = ReLU6Cost;
262 using SoftplusCost = ReLU6Cost;
263 using SquareCost = ReLU6Cost;
264 using ExpCost = ReLU6Cost;
265 using LogCost = ReLU6Cost;
266 using CosCost = ReLU6Cost;
267 using ACosCost = ReLU6Cost;
268 using AbsCost = ReLU6Cost;
269 using TanCost = ReLU6Cost;
270 using SinCost = ReLU6Cost;
271 using SinhCost = ReLU6Cost;
272 using Log1pCost = ReLU6Cost;
273 using Expm1Cost = ReLU6Cost;
274 using CoshCost = ReLU6Cost;
275 using AtanhCost = ReLU6Cost;
276 using AtanCost = ReLU6Cost;
277 using AsinCost = ReLU6Cost;
278 using ErfCost = ReLU6Cost;
279 using ErfcCost = ReLU6Cost;
280 using ActivationInfoCost = ReLU6Cost;
281 using SelectCost = ReLU6Cost;
282 using XlogyCost = ReLU6Cost;
283 using ErfinvCost = ReLU6Cost;
284 
285 class TransposeCost : public CastCost {
286  public:
TransposeCost()287   TransposeCost() : CastCost() {}
288   ~TransposeCost() override = default;
289   // Taking account of input, not taking account of output
290   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
291 };
292 
293 class GeLUCost : public SqrtCost {
294  public:
GeLUCost()295   GeLUCost() : SqrtCost() {}
296   ~GeLUCost() override = default;
297   // Taking account of input and output
298   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
299 };
300 using FastGeLUCost = GeLUCost;
301 using BesselI0eCost = GeLUCost;
302 using BesselI1eCost = GeLUCost;
303 using L2NormalizeCost = GeLUCost;
304 using MaxPoolCost = GeLUCost;
305 
306 class SoftmaxCost : public OperatorCost {
307  public:
SoftmaxCost()308   SoftmaxCost() : OperatorCost() {}
309   ~SoftmaxCost() override = default;
310 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)311   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
312                      int64_t stage_id) const override {
313     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
314   }
315   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
316                             int64_t stage_id) const override;
317   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
318                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)319   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
320                             int64_t stage_id) const override {
321     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
322   }
323   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
324                                    int64_t stage_id) const override;
325   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
326                                     int64_t) const override;
327   // Taking account of output
328   void CalculateOutputInMemory() override;
329   // Not Taking account of input
330   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
331 };
332 
333 using CumSumCost = SoftmaxCost;
334 using CumProdCost = SoftmaxCost;
335 
336 class TileCost : public SoftmaxCost {
337  public:
TileCost()338   TileCost() : SoftmaxCost() {}
339   ~TileCost() override = default;
340   // Not taking account of output
341   void CalculateOutputInMemory() override;
342   // Taking account of input
343   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
344 };
345 
346 class PackCost : public SoftmaxCost {
347  public:
PackCost()348   PackCost() : SoftmaxCost() {}
349   ~PackCost() override = default;
350   // Not taking account of output
351   void CalculateOutputInMemory() override;
352   // Not taking account of input
353   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
354 };
355 
356 class BroadcastToCost : public SoftmaxCost {
357  public:
BroadcastToCost()358   BroadcastToCost() : SoftmaxCost() {}
359   ~BroadcastToCost() override = default;
360   // Not taking account of output
361   void CalculateOutputInMemory() override;
362   // Not Taking account of input
363   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
364 };
365 
366 class TmpIdentityCost : public OperatorCost {
367  public:
TmpIdentityCost()368   TmpIdentityCost() : OperatorCost() {}
369   ~TmpIdentityCost() override = default;
370 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)371   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
372                      int64_t stage_id) const override {
373     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
374   }
375   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
376                             int64_t stage_id) const override;
377   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
378                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)379   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
380                             int64_t stage_id) const override {
381     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
382   }
383   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
384                                    int64_t stage_id) const override;
385   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
386                                     int64_t stage_id) const override;
387   // Not taking account of output
388   void CalculateOutputInMemory() override;
389   // Not taking account of input
390   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
391 };
392 using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
393 
394 class BatchParallelCost : public OperatorCost {
395  public:
BatchParallelCost()396   BatchParallelCost() : OperatorCost() {}
397   ~BatchParallelCost() override = default;
398 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)399   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
400                      int64_t stage_id) const override {
401     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
402   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)403   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
404     return 0.0;
405   }
406   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
407                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)408   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
409                             int64_t stage_id) const override {
410     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
411   }
412   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
413                                    int64_t stage_id) const override;
414   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
415                                     int64_t stage_id) const override;
416   // Not taking account of output
417   void CalculateOutputInMemory() override;
418   // Taking account of input
419   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
420 };
421 
422 class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost {
423  public:
SparseSoftmaxCrossEntropyWithLogitsCost()424   SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {}
425   ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default;
426   // Taking account of output
427   void CalculateOutputInMemory() override;
428   // Not taking account of input
429   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
430 };
431 
432 class VirtualDatasetCost : public OperatorCost {
433  public:
VirtualDatasetCost()434   VirtualDatasetCost() : OperatorCost() {}
435   ~VirtualDatasetCost() override = default;
436 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)437   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
438                      int64_t stage_id) const override {
439     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
440   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)441   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
442     return 0.0;
443   }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)444   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
445     return 0.0;
446   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)447   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
448                             int64_t stage_id) const override {
449     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
450   }
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)451   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
452                                    int64_t) const override {
453     return 0.0;
454   }
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)455   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
456                                     int64_t) const override {
457     return 0.0;
458   }
459   // Not taking account of output
460   void CalculateOutputInMemory() override;
461   // Not taking account of input
462   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
463 };
464 
465 class GeneratorBaseCost : public OperatorCost {
466  public:
GeneratorBaseCost()467   GeneratorBaseCost() : OperatorCost() {}
468   ~GeneratorBaseCost() override = default;
469 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)470   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
471                      int64_t stage_id) const override {
472     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
473   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)474   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
475     return 0.0;
476   }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)477   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
478     return 0.0;
479   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)480   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
481                             int64_t stage_id) const override {
482     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
483   }
484   // Inputs vector is empty for generator ops.
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)485   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
486                                    int64_t) const override {
487     return 0.0;
488   }
489   // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)490   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
491                                     int64_t) const override {
492     return 0.0;
493   }
494 };
495 using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
496 
497 class PReLUCost : public OperatorCost {
498  public:
PReLUCost()499   PReLUCost() : OperatorCost() {}
500   ~PReLUCost() override = default;
501 
502   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)503   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
504                      int64_t stage_id) const override {
505     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
506   }
507   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
508                             int64_t stage_id) const override;
509   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
510                              int64_t stage_id) const override;
511 
512   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)513   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
514                             int64_t stage_id) const override {
515     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
516   }
517   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
518                                    int64_t stage_id) const override;
519   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
520                                     int64_t stage_id) const override;
521   // Not taking account of output
522   void CalculateOutputInMemory() override;
523   // Taking account of input
524   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
525 };
526 using PReLUCostPtr = std::shared_ptr<PReLUCost>;
527 
528 class OneHotCost : public OperatorCost {
529  public:
OneHotCost()530   OneHotCost() : OperatorCost() {}
531   ~OneHotCost() override = default;
532 
533   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)534   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
535                      int64_t stage_id) const override {
536     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
537   }
538   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
539                             int64_t stage_id) const override;
540   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
541                              int64_t stage_id) const override;
542 
543   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)544   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
545                             int64_t stage_id) const override {
546     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
547   }
548   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
549                                    int64_t stage_id) const override;
550   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
551                                     int64_t stage_id) const override;
552   // Not taking account of output
553   void CalculateOutputInMemory() override;
554   // Not taking account of input
555   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
556 };
557 using OneHotCostPtr = std::shared_ptr<OneHotCost>;
558 
559 class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
560  public:
SoftmaxCrossEntropyWithLogitsCost()561   SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {}
562   ~SoftmaxCrossEntropyWithLogitsCost() override = default;
563 
564   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)565   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
566                      int64_t stage_id) const override {
567     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
568   }
569   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
570                             int64_t stage_id) const override;
571   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
572                              int64_t stage_id) const override;
573 
574   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)575   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
576                             int64_t stage_id) const override {
577     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
578   }
579   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
580                                    int64_t stage_id) const override;
581   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
582                                     int64_t stage_id) const override;
583   // Taking account of output
584   void CalculateOutputInMemory() override;
585   // Not taking account of input
586   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
587 };
588 
589 class ReshapeCost : public OperatorCost {
590  public:
ReshapeCost()591   ReshapeCost() : OperatorCost() {}
592 
593   ~ReshapeCost() override = default;
594 
595   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)596   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
597                      int64_t stage_id) const override {
598     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
599   }
600 
601   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
602                             int64_t stage_id) const override;
603 
604   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
605                              int64_t stage_id) const override;
606 
607   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)608   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
609                             int64_t stage_id) const override {
610     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
611   }
612 
613   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
614                                    int64_t stage_id) const override;
615 
616   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
617                                     int64_t stage_id) const override;
618   // Not taking account of output
619   void CalculateOutputInMemory() override;
620   // Not taking account of input
621   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
622 };
623 using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
624 
625 class SubCost : public OperatorCost {
626  public:
SubCost()627   SubCost() : OperatorCost() {}
628   ~SubCost() override = default;
629 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)630   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
631                      int64_t stage_id) const override {
632     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
633   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)634   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
635     return 0.0;
636   }
637   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
638                              int64_t stage_id) const override;
639 
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)640   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
641                             int64_t stage_id) const override {
642     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
643   }
644   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
645                                    int64_t stage_id) const override;
646   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
647                                     int64_t stage_id) const override;
648   // Not taking account of output
649   void CalculateOutputInMemory() override;
650   // Not taking account of input
651   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
652 };
653 using TensorAddCost = SubCost;
654 using FloorDivCost = SubCost;
655 using AssignSubCost = SubCost;
656 using AssignAddCost = SubCost;
657 using LogicalAndCost = SubCost;
658 using LogicalOrCost = SubCost;
659 using BiasAddCost = SubCost;
660 using EqualCost = SubCost;
661 using ApproximateEqualCost = SubCost;
662 using NotEqualCost = SubCost;
663 using GreaterCost = SubCost;
664 using GreaterEqualCost = SubCost;
665 using LessCost = SubCost;
666 using LessEqualCost = SubCost;
667 using GatherNdCost = SubCost;
668 using BitwiseAndCost = SubCost;
669 using BitwiseOrCost = SubCost;
670 using BitwiseXorCost = SubCost;
671 using AddNCost = SubCost;
672 using InplaceAddCost = SubCost;
673 using InplaceSubCost = InplaceAddCost;
674 using InplaceUpdateCost = InplaceAddCost;
675 using MaskedFillCost = SubCost;
676 
677 class MulCost : public SubCost {
678  public:
MulCost()679   MulCost() : SubCost() {}
680   ~MulCost() override = default;
681   // Taking account of input, not taking account of output
682   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
683 };
684 
685 using MulNoNanCost = MulCost;
686 using GatherDCost = MulCost;
687 using LerpCost = MulCost;
688 using SquaredDifferenceCost = MulCost;
689 
690 class DivCost : public SubCost {
691  public:
DivCost()692   DivCost() : SubCost() {}
693   ~DivCost() override = default;
694   // Taking account of output
695   void CalculateOutputInMemory() override;
696   // Taking account of input
697   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
698 };
699 using ReadDivCost = DivCost;
700 using TruncateDivCost = DivCost;
701 using XdivyCost = DivCost;
702 using CdistCost = DivCost;
703 
704 class ModCost : public SubCost {
705  public:
ModCost()706   ModCost() : SubCost() {}
707   ~ModCost() override = default;
708   // Taking account of input, not taking account of output
709   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
710 };
711 using FloorModCost = ModCost;
712 using TruncateModCost = ModCost;
713 
714 class PowCost : public SubCost {
715  public:
PowCost()716   PowCost() : SubCost() {}
717   ~PowCost() override = default;
718   // Taking account of output
719   void CalculateOutputInMemory() override;
720   // Taking account of input
721   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
722 };
723 
724 class AssignCost : public SubCost {
725  public:
AssignCost()726   AssignCost() : SubCost() {}
727   ~AssignCost() override = default;
728   // Taking account of input, not taking account of output
729   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
730 };
731 
732 class SigmoidCrossEntropyWithLogitsCost : public SubCost {
733  public:
SigmoidCrossEntropyWithLogitsCost()734   SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
735   ~SigmoidCrossEntropyWithLogitsCost() override = default;
736   // Taking account of input, not taking account of output
737   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
738 };
739 
740 class Atan2Cost : public SubCost {
741  public:
Atan2Cost()742   Atan2Cost() : SubCost() {}
743   ~Atan2Cost() override = default;
744   // Taking account of input, not taking account of output
745   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
746 };
747 
748 class DivNoNanCost : public SubCost {
749  public:
DivNoNanCost()750   DivNoNanCost() : SubCost() {}
751   ~DivNoNanCost() override = default;
752   // Taking account of output
753   void CalculateOutputInMemory() override;
754   // Taking account of input
755   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
756 };
757 
758 class MaximumCost : public SubCost {
759  public:
MaximumCost()760   MaximumCost() : SubCost() {}
761   ~MaximumCost() override = default;
762   // Taking account of input, not taking account of output
763   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
764 };
765 using MinimumCost = MaximumCost;
766 using CumminCost = MaximumCost;
767 
768 class SliceCost : public CastCost {
769  public:
SliceCost()770   SliceCost() : CastCost() {}
771   ~SliceCost() override = default;
772   // Not taking account of output, taking account of input
773   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
774 };
775 
776 class StridedSliceCost : public CastCost {
777  public:
StridedSliceCost()778   StridedSliceCost() : CastCost() {}
779   ~StridedSliceCost() override = default;
780   // Not taking account of output, taking account of input
781   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
782 };
783 
784 class ReduceSumCost : public OperatorCost {
785  public:
ReduceSumCost()786   ReduceSumCost() : OperatorCost() {}
787   ~ReduceSumCost() override = default;
788 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)789   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
790                      int64_t stage_id) const override {
791     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
792   }
793   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
794                             int64_t stage_id) const override;
795   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
796                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)797   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
798                             int64_t stage_id) const override {
799     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
800   }
801   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
802                                    int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)803   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
804                                     int64_t) const override {
805     return 0.0;
806   }
set_cross_batch(bool cb)807   void set_cross_batch(bool cb) { cross_batch_ = cb; }
808   // Not taking account of output
809   void CalculateOutputInMemory() override;
810   // Taking account of input
811   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
812 
813  protected:
814   bool cross_batch_ = false;
815 };
816 using ReduceMethodCost = ReduceSumCost;
817 using ReduceProdCost = ReduceSumCost;
818 using SquareSumAllCost = ReduceSumCost;
819 using L2LossCost = ReduceSumCost;
820 using KLDivLossCost = ReduceSumCost;
821 
822 class ReduceMeanCost : public ReduceSumCost {
823  public:
ReduceMeanCost()824   ReduceMeanCost() : ReduceSumCost() {}
825   ~ReduceMeanCost() override = default;
826 
827   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
828                                    int64_t stage_id) const override;
829 };
830 
831 class ReduceMinCost : public ReduceSumCost {
832  public:
ReduceMinCost()833   ReduceMinCost() : ReduceSumCost() {}
834   ~ReduceMinCost() override = default;
835   // Taking account of output
836   void CalculateOutputInMemory() override;
837   // Taking account of input
838   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
839 };
840 using ReduceMaxCost = ReduceMinCost;
841 
842 class ArgMaxWithValueCost : public ReduceSumCost {
843  public:
ArgMaxWithValueCost()844   ArgMaxWithValueCost() : ReduceSumCost() {}
845   ~ArgMaxWithValueCost() override = default;
846   // Taking account of output
847   void CalculateOutputInMemory() override;
848   // Taking account of input
849   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
850 };
851 using ArgMinWithValueCost = ArgMaxWithValueCost;
852 using ArgmaxCost = ArgMaxWithValueCost;
853 
854 class GetNextCost : public OperatorCost {
855  public:
GetNextCost()856   GetNextCost() : OperatorCost() {}
857   ~GetNextCost() override = default;
858 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)859   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
860                      int64_t stage_id) const override {
861     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
862   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t stage_id)863   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
864                             int64_t stage_id) const override {
865     return 0.0;
866   }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t stage_id)867   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
868                              int64_t stage_id) const override {
869     return 0.0;
870   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)871   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
872                             int64_t stage_id) const override {
873     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
874   }
875   // Inputs vector is empty for generator ops.
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)876   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
877                                    int64_t) const override {
878     return 0.0;
879   }
880   // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)881   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
882                                     int64_t) const override {
883     return 0.0;
884   }
885   // Not taking account of output
886   void CalculateOutputInMemory() override;
887   // Not Taking account of input
888   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
889 };
890 using GetNextCostPtr = std::shared_ptr<GetNextCost>;
891 
892 class DSDMatmulCost : public OperatorCost {
893  public:
DSDMatmulCost()894   DSDMatmulCost() : OperatorCost() {}
895   ~DSDMatmulCost() override = default;
896 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)897   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
898                      int64_t stage_id) const override {
899     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
900   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)901   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
902     return 0.0;
903   }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)904   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
905     return 0.0;
906   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)907   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
908                             int64_t stage_id) const override {
909     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
910   }
911   // Inputs vector is empty for generator ops.
912   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
913                                    int64_t) const override;
914   // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)915   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
916                                     int64_t) const override {
917     return 0.0;
918   }
919   // Not taking account of output
920   void CalculateOutputInMemory() override;
921   // Not Taking account of input
922   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
923 };
924 using DSDMatmulCostPtr = std::shared_ptr<DSDMatmulCost>;
925 
926 // For memory cost, taking account of output, not taking account of input
927 class DropOutCost : public SqrtCost {
928  public:
DropOutCost()929   DropOutCost() : SqrtCost() {}
930   ~DropOutCost() override = default;
931 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)932   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
933                      int64_t stage_id) const override {
934     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
935   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)936   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
937     return 0.0;
938   }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)939   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
940     return 0.0;
941   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)942   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
943                             int64_t stage_id) const override {
944     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
945   }
946   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
947                                    int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)948   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
949                                     int64_t) const override {
950     return 0.0;
951   }
952 };
953 
954 class DropOutDoMaskCost : public DropOutCost {
955  public:
DropOutDoMaskCost()956   DropOutDoMaskCost() : DropOutCost() {}
957   ~DropOutDoMaskCost() override = default;
958   // Not taking account of output
959   void CalculateOutputInMemory() override;
960   // Taking account of input
961   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
962 };
963 
964 class UnsortedSegmentSumCost : public OperatorCost {
965  public:
UnsortedSegmentSumCost()966   UnsortedSegmentSumCost() : OperatorCost() {}
967   ~UnsortedSegmentSumCost() override = default;
968 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)969   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
970                      int64_t stage_id) const override {
971     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
972   }
973   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
974   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)975   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
976                             int64_t stage_id) const override {
977     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
978   }
979   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
980                                    int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)981   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
982                                     int64_t) const override {
983     return 0.0;
984   }
985   // Not taking account of output
986   void CalculateOutputInMemory() override;
987   // Taking account of input
988   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
989 };
990 using UnsortedSegmentProdCost = UnsortedSegmentSumCost;
991 
992 class UnsortedSegmentMinCost : public OperatorCost {
993  public:
UnsortedSegmentMinCost()994   UnsortedSegmentMinCost() : OperatorCost() {}
995   ~UnsortedSegmentMinCost() override = default;
996 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)997   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
998                      int64_t stage_id) const override {
999     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1000   }
1001   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
1002   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1003   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1004                             int64_t stage_id) const override {
1005     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1006   }
1007   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1008                                    int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1009   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1010                                     int64_t) const override {
1011     return 0.0;
1012   }
1013   // Taking account of output
1014   void CalculateOutputInMemory() override;
1015   // Taking account of input
1016   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1017 };
1018 using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;
1019 
1020 class LayerNormCost : public OperatorCost {
1021  public:
LayerNormCost()1022   LayerNormCost() : OperatorCost() {}
1023   ~LayerNormCost() override = default;
1024 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1025   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1026                      int64_t stage_id) const override {
1027     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1028   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1029   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
1030     return 0.0;
1031   }
1032   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1033   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1034                             int64_t stage_id) const override {
1035     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1036   }
1037   double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1038                                    int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1039   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1040                                     int64_t) const override {
1041     return 0.0;
1042   }
1043   // Taking account of output
1044   void CalculateOutputInMemory() override;
1045   // Taking account of input
1046   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1047 };
1048 using GroupNormCost = LayerNormCost;
1049 
1050 class UniqueCost : public OperatorCost {
1051  public:
UniqueCost()1052   UniqueCost() : OperatorCost() {}
1053   ~UniqueCost() override = default;
1054 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1055   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1056                      int64_t stage_id) const override {
1057     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1058   }
1059   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1060                             int64_t stage_id) const override;
1061   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1062                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1063   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1064                             int64_t stage_id) const override {
1065     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1066   }
1067   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1068                                    int64_t stage_id) const override;
1069   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1070                                     int64_t stage_id) const override;
1071   // Taking account of output
1072   void CalculateOutputInMemory() override;
1073   // Not Taking account of input
1074   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1075 };
1076 
1077 class UniformCandidateSamplerCost : public OperatorCost {
1078  public:
UniformCandidateSamplerCost()1079   UniformCandidateSamplerCost() : OperatorCost() {}
1080   ~UniformCandidateSamplerCost() override = default;
1081 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1082   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1083                      int64_t stage_id) const override {
1084     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1085   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1086   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
1087     return 0.0;
1088   }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1089   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
1090     return 0.0;
1091   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1092   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1093                             int64_t stage_id) const override {
1094     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1095   }
1096   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1097                                    int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1098   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1099                                     int64_t) const override {
1100     return 0.0;
1101   }
1102   // Not taking account of output
1103   void CalculateOutputInMemory() override;
1104   // Not Taking account of input
1105   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1106 };
1107 
1108 class GatherV2Cost : public OperatorCost {
1109  public:
GatherV2Cost()1110   GatherV2Cost() : OperatorCost() {}
1111   ~GatherV2Cost() override = default;
1112 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1113   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1114                      int64_t stage_id) const override {
1115     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1116   }
1117   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1118                             int64_t stage_id) const override;
1119   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1120                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1121   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1122                             int64_t stage_id) const override {
1123     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1124   }
1125   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1126                                    int64_t stage_id) const override;
1127   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1128                                     int64_t) const override;
1129   // Not taking account of output
1130   void CalculateOutputInMemory() override;
1131   // Taking account of input
1132   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1133 };
1134 
1135 class GatherCost : public GatherV2Cost {
1136  public:
GatherCost()1137   GatherCost() : GatherV2Cost(), axis_(0) {}
1138   ~GatherCost() override = default;
1139 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1140   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1141                      int64_t stage_id) const override {
1142     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1143   }
1144   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1145                             int64_t stage_id) const override;
1146   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1147                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1148   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1149                             int64_t stage_id) const override {
1150     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1151   }
1152   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1153                                    int64_t stage_id) const override;
1154   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1155                                     int64_t) const override;
set_axis(int64_t axis)1156   void set_axis(int64_t axis) { axis_ = axis; }
set_strategy(const Shape & strategy)1157   void set_strategy(const Shape &strategy) { strategy_ = strategy; }
1158 
1159  protected:
1160   int64_t axis_;
1161   Shape strategy_;
1162 };
1163 
1164 class MatmulDDSCost : public OperatorCost {
1165  public:
MatmulDDSCost()1166   MatmulDDSCost() : OperatorCost() {}
1167   ~MatmulDDSCost() override = default;
1168 
1169   // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1170   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1171                      int64_t stage_id) const override {
1172     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1173   }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1174   double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
1175     return 0.0;
1176   };
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1177   double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
1178     return 0.0;
1179   };
1180 
1181   // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1182   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1183                             int64_t stage_id) const override {
1184     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1185   }
1186   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1187                                    int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)1188   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1189                                     int64_t) const override {
1190     return 0.0;
1191   };
1192   // Not taking account of output
1193   void CalculateOutputInMemory() override;
1194   // Taking account of input
1195   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1196 };
1197 using MatmulDDSCostPtr = std::shared_ptr<MatmulDDSCost>;
1198 
1199 class ScatterMathOpsCost : public OperatorCost {
1200  public:
ScatterMathOpsCost()1201   ScatterMathOpsCost() : OperatorCost() {}
1202   ~ScatterMathOpsCost() override = default;
1203 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1204   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1205                      int64_t stage_id) const override {
1206     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1207   }
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1208   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1209                             int64_t stage_id) const override {
1210     return 0.0;
1211   }
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1212   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1213                              int64_t stage_id) const override {
1214     return 0.0;
1215   }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1216   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1217                             int64_t stage_id) const override {
1218     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1219   }
1220   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1221                                    int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1222   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1223                                     int64_t stage_id) const override {
1224     return 0.0;
1225   }
1226   // Not taking account of output
CalculateOutputInMemory()1227   void CalculateOutputInMemory() override { is_output_should_in_memory_ = false; }
1228   // Not Taking account of input
CalculateInputsInMemory(const std::map<size_t,bool> & prev_output_in_mem)1229   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override {
1230     is_inputs_should_in_memory_[0] = true;
1231   }
1232 
set_is_split_axis(bool is_split_axis)1233   void set_is_split_axis(bool is_split_axis) { is_split_axis_ = is_split_axis; }
set_coefficient(int32_t input_coefficient,int32_t indices_coefficient,int32_t updates_coefficient)1234   void set_coefficient(int32_t input_coefficient, int32_t indices_coefficient, int32_t updates_coefficient) {
1235     input_coefficient_ = input_coefficient;
1236     indices_coefficient_ = indices_coefficient;
1237     updates_coefficient_ = updates_coefficient;
1238   }
1239 
1240  protected:
1241   bool is_split_axis_ = false;
1242   int32_t input_coefficient_ = 1;
1243   int32_t indices_coefficient_ = 3;
1244   int32_t updates_coefficient_ = 2;
1245 };
1246 
1247 class ScatterNdOpsCost : public ScatterMathOpsCost {
1248  public:
ScatterNdOpsCost()1249   ScatterNdOpsCost() : ScatterMathOpsCost() {}
1250   ~ScatterNdOpsCost() override = default;
1251 };
1252 
1253 class TensorScatterOpsCost : public ScatterMathOpsCost {
1254  public:
TensorScatterOpsCost()1255   TensorScatterOpsCost() : ScatterMathOpsCost() {}
1256   ~TensorScatterOpsCost() override = default;
1257   double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1258                                     int64_t stage_id) const override;
1259   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1260                              int64_t stage_id) const override;
1261 };
1262 
1263 class CropAndResizeCost : public OperatorCost {
1264  public:
CropAndResizeCost()1265   CropAndResizeCost() : OperatorCost() {}
1266   ~CropAndResizeCost() override = default;
1267 
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1268   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1269                      int64_t stage_id) const override {
1270     return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1271   }
1272   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1273                             int64_t) const override;
1274   double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1275                              int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1276   double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1277                             int64_t stage_id) const override {
1278     return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1279   }
1280   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1281                                    int64_t) const override;
1282   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
1283                                     int64_t) const override;
1284   // Taking account for input
1285   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1286   // Not taking account of output
1287   void CalculateOutputInMemory() override;
1288 
set_strategy(const Shape & strategy)1289   void set_strategy(const Shape &strategy) { strategy_ = strategy; }
set_crop_size(const std::vector<int64_t> & crop_size)1290   void set_crop_size(const std::vector<int64_t> &crop_size) { crop_size_ = crop_size; }
1291 
1292  protected:
1293   Shape strategy_;
1294   std::vector<int64_t> crop_size_;
1295 
1296  private:
1297   static const size_t CROP_AND_RESIZE_COST_WEIGHT0 = 1;
1298   static const size_t CROP_AND_RESIZE_COST_WEIGHT1 = 1;
1299   static const size_t CROP_AND_RESIZE_COST_WEIGHT2 = 8;
1300   static const size_t CROP_AND_RESIZE_COST_WEIGHT3 = 2;
1301 };
1302 
1303 class ROIAlignCost : public CropAndResizeCost {
1304  public:
ROIAlignCost()1305   ROIAlignCost() : CropAndResizeCost() {}
1306   ~ROIAlignCost() override = default;
1307 
1308   double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1309                             int64_t) const override;
1310   double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
1311                                    int64_t) const override;
1312   double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
1313                                     int64_t) const override;
1314   // Taking account for input
1315   void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1316   // Taking account of output
1317   void CalculateOutputInMemory() override;
1318 
set_pooled_shape(const Shape & pooled_shape)1319   void set_pooled_shape(const Shape &pooled_shape) { pooled_shape_ = pooled_shape; }
1320 
1321  protected:
1322   Shape pooled_shape_;
1323 
1324  private:
1325   static const size_t ROI_ALIGN_COST_WEIGHT0 = 1;
1326   static const size_t ROI_ALIGN_COST_WEIGHT1 = 4;
1327   static const size_t ROI_ALIGN_COST_WEIGHT2 = 2;
1328 };
1329 }  // namespace parallel
1330 }  // namespace mindspore
1331 #endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
1332