/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
#define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_

#include <memory>
#include <vector>
#include <map>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"

namespace mindspore {
namespace parallel {
#define MAXIMUM_INPUT_NUMBER 100
#define DEFAULT_DATA_TYPE_LENGTH 4
#define DROPOUT_COST_RATE 1.125  // the DropoutGenMask needs an extra 12.5% memory
#define GATHERV2_COST_WEIGHT0 3
#define GATHERV2_COST_WEIGHT1 7
#define GATHERV2_COST_WEIGHT2 2
#define GATHERV2_COST_WEIGHT3 6

class OperatorCost;
using OperatorCostPtr = std::shared_ptr<OperatorCost>;

template <typename T>
double ListProduct(std::vector<T> vec) {
  double result = 1;
  for (size_t i = 0; i < vec.size(); ++i) {
    result *= vec[i];
  }
  return result;
}
// NOTE: Currently, the value returned by each method is the memory size in bytes, calculated as the number of
// entries multiplied by the length of each entry's data type.
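// Illustrative example (hypothetical shape and type): a float32 tensor slice of shape {32, 64, 128} would be
// counted as ListProduct(shape) * type_length = 32 * 64 * 128 * 4 = 1048576 bytes.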
class OperatorCost {
 public:
  OperatorCost() {
    // These defaults are only used when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked.
    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
      is_parameter_.push_back(false);
      is_parameter_involve_.push_back(false);
      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
    }
  }
  virtual ~OperatorCost() = default;

  void set_is_parameter(const std::vector<bool> &is_parameter);
  void set_is_parameter_involve(const std::vector<bool> &);
  void set_output_parameter_involve(int64_t);
  void set_output_critical(int64_t);
  void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
  std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
  std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }

  // per device communication cost
  virtual double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const = 0;
  virtual double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const = 0;
  virtual double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                     int64_t stage_id) const = 0;
  // per device computation cost
  virtual double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const = 0;
  virtual double GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                           const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
  virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
                                            const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
  virtual void CalculateOutputInMemory() = 0;
  virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
  bool is_output_in_memory() const { return is_output_should_in_memory_; }
  // per device PEAK memory cost in a training iteration
  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
  // plus necessary inputs.
  virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
  // Contributing the input part for 'GetMemoryCost'
  double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
  // Contributing the output part for 'GetMemoryCost'
  double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
  // per device memory cost in the inference phase
  double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
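  // Illustrative note (hypothetical figures): the peak-memory contribution is composed of the output part (when the
  // output is parameter-involved) plus the needed input part, e.g. a 1 MiB output slice plus a 0.5 MiB retained
  // input slice would contribute roughly 1.5 MiB to the per-device peak memory in training.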

 protected:
  // For each input in 'inputs_', the bool is true if the corresponding input is a parameter, or is an output of a
  // preceding operator that has parameters as input.
  std::vector<bool> is_parameter_involve_;
  int64_t output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
  // For each input in 'inputs_', a bool variable indicates whether the corresponding input is a parameter.
  std::vector<bool> is_parameter_;
  // Whether each input should be kept in memory during the training phase. This depends on the operator and the
  // operator's previous operators.
  std::vector<bool> is_inputs_should_in_memory_;
  // Whether the output should be kept in memory during the training phase. It depends on 'is_parameter_involve_' and the operator.
  bool is_output_should_in_memory_ = false;
  // For each input and output, the following record the number of bytes of each element.
  std::vector<size_t> inputs_type_lengths_;
  std::vector<size_t> outputs_type_lengths_;
  // Whether the output is critical, which means that this output is included in calculating peak memory cost
  // in the inference phase.
  int64_t is_outputs_critical_ = -1;
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;

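// A hedged usage sketch (illustrative only; 'inputs', 'outputs', and 'stage_id' are hypothetical values produced
// elsewhere by the cost model's caller):
//   OperatorCostPtr cost = std::make_shared<MatMulCost>();
//   cost->SetInputAndOutputTypeLength({4, 4}, {4});  // float32 operands and output
//   cost->set_is_parameter({false, true});           // e.g. the second operand is a weight parameter
//   double comm = cost->GetCommCost(inputs, outputs, stage_id);
//   double comp = cost->GetComputationCost(inputs, outputs, stage_id);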
class MatMulCost : public OperatorCost {
 public:
  MatMulCost() : OperatorCost() {}
  ~MatMulCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  void CalculateOutputInMemory() override;
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using TensorDotCost = MatMulCost;

class CastCost : public OperatorCost {
 public:
  CastCost() : OperatorCost() {}
  ~CastCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using RepeatElementsCost = CastCost;
using NegCost = CastCost;
using ExpandDimsCost = CastCost;
using SqueezeCost = CastCost;
using ConcatCost = CastCost;
using LogicalNotCost = CastCost;
using SignCost = CastCost;
using FloorCost = CastCost;
using RoundCost = CastCost;
using CeilCost = CastCost;
using ZerosLikeCost = CastCost;
using OnesLikeCost = CastCost;
using RangeCost = CastCost;
using SplitCost = CastCost;
using ScatterUpdateCost = CastCost;
using ResizeBilinearCost = CastCost;
using UniformRealCost = CastCost;

class SqrtCost : public CastCost {
 public:
  SqrtCost() : CastCost() {}
  ~SqrtCost() override = default;
  // Taking account of output, not taking account of input
  void CalculateOutputInMemory() override;
};
using TanhCost = SqrtCost;
using EluCost = SqrtCost;
using ReLUCost = SqrtCost;
using SigmoidCost = SqrtCost;
using ReciprocalCost =
  SqrtCost;  // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
using InvCost = SqrtCost;
using RsqrtCost = SqrtCost;
using AsinhCost = SqrtCost;
using AcoshCost = SqrtCost;
using ReLUV2Cost = SqrtCost;
using TopKCost = SqrtCost;

class ReLU6Cost : public CastCost {
 public:
  ReLU6Cost() : CastCost() {}
  ~ReLU6Cost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftsignCost = ReLU6Cost;
using SoftplusCost = ReLU6Cost;
using SquareCost = ReLU6Cost;
using ExpCost = ReLU6Cost;
using LogCost = ReLU6Cost;
using CosCost = ReLU6Cost;
using ACosCost = ReLU6Cost;
using AbsCost = ReLU6Cost;
using TanCost = ReLU6Cost;
using SinCost = ReLU6Cost;
using SinhCost = ReLU6Cost;
using Log1pCost = ReLU6Cost;
using Expm1Cost = ReLU6Cost;
using CoshCost = ReLU6Cost;
using AtanhCost = ReLU6Cost;
using AtanCost = ReLU6Cost;
using AsinCost = ReLU6Cost;
using ErfCost = ReLU6Cost;
using ErfcCost = ReLU6Cost;
using ActivationInfoCost = ReLU6Cost;
using SelectCost = ReLU6Cost;

class TransposeCost : public CastCost {
 public:
  TransposeCost() : CastCost() {}
  ~TransposeCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GeLUCost : public SqrtCost {
 public:
  GeLUCost() : SqrtCost() {}
  ~GeLUCost() override = default;
  // Taking account of input and output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using FastGeLUCost = GeLUCost;
using BesselI0eCost = GeLUCost;
using BesselI1eCost = GeLUCost;
using L2NormalizeCost = GeLUCost;
using MaxPoolCost = GeLUCost;

class SoftmaxCost : public OperatorCost {
 public:
  SoftmaxCost() : OperatorCost() {}
  ~SoftmaxCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class TileCost : public SoftmaxCost {
 public:
  TileCost() : SoftmaxCost() {}
  ~TileCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class PackCost : public SoftmaxCost {
 public:
  PackCost() : SoftmaxCost() {}
  ~PackCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class BroadcastToCost : public SoftmaxCost {
 public:
  BroadcastToCost() : SoftmaxCost() {}
  ~BroadcastToCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class TmpIdentityCost : public OperatorCost {
 public:
  TmpIdentityCost() : OperatorCost() {}
  ~TmpIdentityCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;

class BatchParallelCost : public OperatorCost {
 public:
  BatchParallelCost() : OperatorCost() {}
  ~BatchParallelCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost {
 public:
  SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {}
  ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class VirtualDatasetCost : public OperatorCost {
 public:
  VirtualDatasetCost() : OperatorCost() {}
  ~VirtualDatasetCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override {
    return 0.0;
  }
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GeneratorBaseCost : public OperatorCost {
 public:
  GeneratorBaseCost() : OperatorCost() {}
  ~GeneratorBaseCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  // Inputs vector is empty for generator ops.
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override {
    return 0.0;
  }
  // Generator ops don't have backward steps.
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
};
using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;

class PReLUCost : public OperatorCost {
 public:
  PReLUCost() : OperatorCost() {}
  ~PReLUCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using PReLUCostPtr = std::shared_ptr<PReLUCost>;

class OneHotCost : public OperatorCost {
 public:
  OneHotCost() : OperatorCost() {}
  ~OneHotCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using OneHotCostPtr = std::shared_ptr<OneHotCost>;

class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
 public:
  SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {}
  ~SoftmaxCrossEntropyWithLogitsCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class ReshapeCost : public OperatorCost {
 public:
  ReshapeCost() : OperatorCost() {}

  ~ReshapeCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }

  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;

  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }

  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;

  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;

class SubCost : public OperatorCost {
 public:
  SubCost() : OperatorCost() {}
  ~SubCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;

  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using TensorAddCost = SubCost;
using FloorDivCost = SubCost;
using AssignSubCost = SubCost;
using AssignAddCost = SubCost;
using LogicalAndCost = SubCost;
using LogicalOrCost = SubCost;
using BiasAddCost = SubCost;
using EqualCost = SubCost;
using ApproximateEqualCost = SubCost;
using NotEqualCost = SubCost;
using GreaterCost = SubCost;
using GreaterEqualCost = SubCost;
using LessCost = SubCost;
using LessEqualCost = SubCost;
using GatherNdCost = SubCost;

class MulCost : public SubCost {
 public:
  MulCost() : SubCost() {}
  ~MulCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

using GatherDCost = MulCost;

class DivCost : public SubCost {
 public:
  DivCost() : SubCost() {}
  ~DivCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReadDivCost = DivCost;

class ModCost : public SubCost {
 public:
  ModCost() : SubCost() {}
  ~ModCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using FloorModCost = ModCost;

class PowCost : public SubCost {
 public:
  PowCost() : SubCost() {}
  ~PowCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class AssignCost : public SubCost {
 public:
  AssignCost() : SubCost() {}
  ~AssignCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class SigmoidCrossEntropyWithLogitsCost : public SubCost {
 public:
  SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
  ~SigmoidCrossEntropyWithLogitsCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class Atan2Cost : public SubCost {
 public:
  Atan2Cost() : SubCost() {}
  ~Atan2Cost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class DivNoNanCost : public SubCost {
 public:
  DivNoNanCost() : SubCost() {}
  ~DivNoNanCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class MaximumCost : public SubCost {
 public:
  MaximumCost() : SubCost() {}
  ~MaximumCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MinimumCost = MaximumCost;

class SliceCost : public CastCost {
 public:
  SliceCost() : CastCost() {}
  ~SliceCost() override = default;
  // Not taking account of output, taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class StridedSliceCost : public CastCost {
 public:
  StridedSliceCost() : CastCost() {}
  ~StridedSliceCost() override = default;
  // Not taking account of output, taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class ReduceSumCost : public OperatorCost {
 public:
  ReduceSumCost() : OperatorCost() {}
  ~ReduceSumCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  void set_cross_batch(bool cb) { cross_batch_ = cb; }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;

 protected:
  bool cross_batch_ = false;
};
using ReduceMethodCost = ReduceSumCost;

class ReduceMeanCost : public ReduceSumCost {
 public:
  ReduceMeanCost() : ReduceSumCost() {}
  ~ReduceMeanCost() override = default;

  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
};

class ReduceMinCost : public ReduceSumCost {
 public:
  ReduceMinCost() : ReduceSumCost() {}
  ~ReduceMinCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ReduceMaxCost = ReduceMinCost;

class ArgMaxWithValueCost : public ReduceSumCost {
 public:
  ArgMaxWithValueCost() : ReduceSumCost() {}
  ~ArgMaxWithValueCost() override = default;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ArgMinWithValueCost = ArgMaxWithValueCost;

class GetNextCost : public OperatorCost {
 public:
  GetNextCost() : OperatorCost() {}
  ~GetNextCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  // Inputs vector is empty for generator ops.
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override {
    return 0.0;
  }
  // Generator ops don't have backward steps.
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using GetNextCostPtr = std::shared_ptr<GetNextCost>;

class DSDMatmulCost : public OperatorCost {
 public:
  DSDMatmulCost() : OperatorCost() {}
  ~DSDMatmulCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using DSDMatmulCostPtr = std::shared_ptr<DSDMatmulCost>;

// For memory cost, taking account of output, not taking account of input
class DropOutCost : public SqrtCost {
 public:
  DropOutCost() : SqrtCost() {}
  ~DropOutCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
};

class DropOutDoMaskCost : public DropOutCost {
 public:
  DropOutDoMaskCost() : DropOutCost() {}
  ~DropOutDoMaskCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class UnsortedSegmentSumCost : public OperatorCost {
 public:
  UnsortedSegmentSumCost() : OperatorCost() {}
  ~UnsortedSegmentSumCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class UnsortedSegmentMinCost : public OperatorCost {
 public:
  UnsortedSegmentMinCost() : OperatorCost() {}
  ~UnsortedSegmentMinCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;

class LayerNormCost : public OperatorCost {
 public:
  LayerNormCost() : OperatorCost() {}
  ~LayerNormCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                   int64_t) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                    int64_t) const override {
    return 0.0;
  }
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class UniqueCost : public OperatorCost {
 public:
  UniqueCost() : OperatorCost() {}
  ~UniqueCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class UniformCandidateSamplerCost : public OperatorCost {
 public:
  UniformCandidateSamplerCost() : OperatorCost() {}
  ~UniformCandidateSamplerCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override {
    return 0.0;
  }
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GatherV2Cost : public OperatorCost {
 public:
  GatherV2Cost() : OperatorCost() {}
  ~GatherV2Cost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};

class GatherV2PCost : public GatherV2Cost {
 public:
  GatherV2PCost() : GatherV2Cost(), axis_(0) {}
  ~GatherV2PCost() override = default;

  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override;
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override;
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  void set_axis(int64_t axis) { axis_ = axis; }
  void set_strategy(const Shape &strategy) { strategy_ = strategy; }

 protected:
  int64_t axis_;
  Shape strategy_;
};

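// A minimal usage sketch (hypothetical caller code, not part of this header), assuming the
// strategy-search pass has already built the TensorInfo lists for the operator's inputs and
// outputs. It only illustrates that a GatherV2PCost is given the gather axis and the sharding
// strategy before its per-device costs are queried; the concrete values below are assumptions:
//
//   auto cost = std::make_shared<GatherV2PCost>();
//   cost->set_axis(0);                               // gather along dimension 0
//   cost->set_strategy({8, 1});                      // assumed device arrangement of input 0
//   cost->SetInputAndOutputTypeLength({4, 4}, {4});  // 4-byte inputs and output
//   double total = cost->GetCommCost(inputs, outputs, stage_id) +
//                  cost->GetComputationCost(inputs, outputs, stage_id);
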
class MatmulDDSCost : public OperatorCost {
 public:
  MatmulDDSCost() : OperatorCost() {}
  ~MatmulDDSCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return 0.0;
  }
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override {
    return 0.0;
  }

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override {
    return 0.0;
  }
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MatmulDDSCostPtr = std::shared_ptr<MatmulDDSCost>;

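// Summary note (not operator-specific): every cost class in this header decomposes its
// per-device cost in the same way, as shown by the inline bodies above:
//   GetCommCost(...)        = GetForwardCommCost(...)        + GetBackwardCommCost(...)
//   GetComputationCost(...) = GetForwardComputationCost(...) + GetBackwardComputationCost(...)
// Only the forward/backward pieces declared here without a body are given operator-specific
// definitions elsewhere.
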
}  // namespace parallel
}  // namespace mindspore
#endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_