1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
18 #define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
19
20 #include <memory>
21 #include <vector>
22 #include <map>
23 #include "frontend/parallel/device_manager.h"
24 #include "frontend/parallel/tensor_layout/tensor_info.h"
25
26 namespace mindspore {
27 namespace parallel {
// Number of default entries that OperatorCost's constructor places in each of its per-input/per-output vectors.
#define MAXIMUM_INPUT_NUMBER 100
// Default per-element byte length that OperatorCost's constructor stores for every input and output slot.
#define DEFAULT_DATA_TYPE_LENGTH 4
#define DROPOUT_COST_RATE 1.125  // the DropoutGenMask needs 12.5% extra memory
// NOTE(review): the four weights below are not used in this part of the header; presumably they are consumed by a
// GatherV2 cost model defined elsewhere -- confirm against their users.
#define GATHERV2_COST_WEIGHT0 3
#define GATHERV2_COST_WEIGHT1 7
#define GATHERV2_COST_WEIGHT2 2
#define GATHERV2_COST_WEIGHT3 6

// Forward declaration so that the shared-pointer alias is available before the class body.
class OperatorCost;
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
38
// Multiplies all elements of 'vec' together and returns the product as a double.
// An empty vector yields 1 (the multiplicative identity).
// The vector is taken by const reference so that calling this helper never copies the caller's data.
template <typename T>
double ListProduct(const std::vector<T> &vec) {
  double result = 1;
  for (const auto &element : vec) {
    result *= static_cast<double>(element);
  }
  return result;
}
47 // NOTE: Currently, the returned value in each method is bytes of memory size, which is calculated by the number of
48 // entries timing the length of each entry's data type
49 class OperatorCost {
50 public:
OperatorCost()51 OperatorCost() {
52 // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
53 for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
54 is_parameter_.push_back(false);
55 is_parameter_involve_.push_back(false);
56 inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
57 outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
58 }
59 }
60 virtual ~OperatorCost() = default;
61
62 void set_is_parameter(const std::vector<bool> &is_parameter);
63 void set_is_parameter_involve(const std::vector<bool> &);
64 void set_output_parameter_involve(int64_t);
65 void set_output_critical(int64_t);
66 void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
inputs_type_lengths()67 std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
outputs_type_lengths()68 std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
69
70 // per device communication cost
71 virtual double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
72 int64_t stage_id) const = 0;
73 virtual double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
74 int64_t stage_id) const = 0;
75 virtual double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
76 int64_t stage_id) const = 0;
77 // per device computation cost
78 virtual double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
79 int64_t stage_id) const = 0;
80 virtual double GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
81 const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
82 virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
83 const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
84 virtual void CalculateOutputInMemory() = 0;
85 virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
is_output_in_memory()86 bool is_output_in_memory() const { return is_output_should_in_memory_; }
87 // per device PEAK memory cost in a training iteration
88 // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
89 // plus necessary inputs.
90 virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
91 // Contributing the input part for 'GetMemoryCost'
92 double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
93 // Contributing the output part for 'GetMemoryCost'
94 double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
95 // per device memory cost in a inference phase
96 double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
97
98 protected:
99 // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
100 // pre-operator that has parameters as input.
101 std::vector<bool> is_parameter_involve_;
102 int64_t output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
103 // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
104 std::vector<bool> is_parameter_;
105 // Whether the input should keep in memory in training phase. It depends on the operator and the operator's
106 // previous operators.
107 std::vector<bool> is_inputs_should_in_memory_;
108 // Whether the output should keep in memory in training phase. It depends on 'is_parameter_involve_' and the operator.
109 bool is_output_should_in_memory_ = false;
110 // For each input and output, the followings record the number of bytes of each element
111 std::vector<size_t> inputs_type_lengths_;
112 std::vector<size_t> outputs_type_lengths_;
113 // Whether the output is critical, which means that this output is included in calculating peak memory cost
114 // in the inference phase.
115 int64_t is_outputs_critical_ = -1;
116 };
117 using OperatorCostPtr = std::shared_ptr<OperatorCost>;
118
119 class MatMulCost : public OperatorCost {
120 public:
MatMulCost()121 MatMulCost() : OperatorCost() {}
122 ~MatMulCost() override = default;
123
124 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)125 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
126 int64_t stage_id) const override {
127 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
128 }
129 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
130 int64_t stage_id) const override;
131 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
132 int64_t stage_id) const override;
133
134 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)135 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
136 int64_t stage_id) const override {
137 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
138 }
139 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
140 int64_t stage_id) const override;
141 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
142 int64_t stage_id) const override;
143 void CalculateOutputInMemory() override;
144 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
145 };
146 using TensorDotCost = MatMulCost;
147
148 class CastCost : public OperatorCost {
149 public:
CastCost()150 CastCost() : OperatorCost() {}
151 ~CastCost() override = default;
152
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)153 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
154 int64_t stage_id) const override {
155 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
156 }
157 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
158 int64_t stage_id) const override;
159 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
160 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)161 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
162 int64_t stage_id) const override {
163 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
164 }
165 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
166 int64_t stage_id) const override;
167 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
168 int64_t stage_id) const override;
169 // Not taking account of output
170 void CalculateOutputInMemory() override;
171 // Not Taking account of input
172 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
173 };
174 using RepeatElementsCost = CastCost;
175 using NegCost = CastCost;
176 using ExpandDimsCost = CastCost;
177 using SqueezeCost = CastCost;
178 using ConcatCost = CastCost;
179 using LogicalNotCost = CastCost;
180 using SignCost = CastCost;
181 using FloorCost = CastCost;
182 using RoundCost = CastCost;
183 using CeilCost = CastCost;
184 using ZerosLikeCost = CastCost;
185 using OnesLikeCost = CastCost;
186 using RangeCost = CastCost;
187 using SplitCost = CastCost;
188 using ScatterUpdateCost = CastCost;
189 using ResizeBilinearCost = CastCost;
190 using UniformRealCost = CastCost;
191
192 class SqrtCost : public CastCost {
193 public:
SqrtCost()194 SqrtCost() : CastCost() {}
195 ~SqrtCost() override = default;
196 // Taking account of output, not taking accounting of input
197 void CalculateOutputInMemory() override;
198 };
199 using TanhCost = SqrtCost;
200 using EluCost = SqrtCost;
201 using ReLUCost = SqrtCost;
202 using SigmoidCost = SqrtCost;
203 using ReciprocalCost =
204 SqrtCost; // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
205 using InvCost = SqrtCost;
206 using RsqrtCost = SqrtCost;
207 using AsinhCost = SqrtCost;
208 using AcoshCost = SqrtCost;
209 using ReLUV2Cost = SqrtCost;
210 using TopKCost = SqrtCost;
211
212 class ReLU6Cost : public CastCost {
213 public:
ReLU6Cost()214 ReLU6Cost() : CastCost() {}
215 ~ReLU6Cost() override = default;
216 // Taking account of input, not taking account of output
217 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
218 };
219 using SoftsignCost = ReLU6Cost;
220 using SoftplusCost = ReLU6Cost;
221 using SquareCost = ReLU6Cost;
222 using ExpCost = ReLU6Cost;
223 using LogCost = ReLU6Cost;
224 using CosCost = ReLU6Cost;
225 using ACosCost = ReLU6Cost;
226 using AbsCost = ReLU6Cost;
227 using TanCost = ReLU6Cost;
228 using SinCost = ReLU6Cost;
229 using SinhCost = ReLU6Cost;
230 using Log1pCost = ReLU6Cost;
231 using Expm1Cost = ReLU6Cost;
232 using CoshCost = ReLU6Cost;
233 using AtanhCost = ReLU6Cost;
234 using AtanCost = ReLU6Cost;
235 using AsinCost = ReLU6Cost;
236 using ErfCost = ReLU6Cost;
237 using ErfcCost = ReLU6Cost;
238 using ActivationInfoCost = ReLU6Cost;
239 using SelectCost = ReLU6Cost;
240
241 class TransposeCost : public CastCost {
242 public:
TransposeCost()243 TransposeCost() : CastCost() {}
244 ~TransposeCost() override = default;
245 // Taking account of input, not taking account of output
246 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
247 };
248
249 class GeLUCost : public SqrtCost {
250 public:
GeLUCost()251 GeLUCost() : SqrtCost() {}
252 ~GeLUCost() override = default;
253 // Taking account of input and output
254 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
255 };
256 using FastGeLUCost = GeLUCost;
257 using BesselI0eCost = GeLUCost;
258 using BesselI1eCost = GeLUCost;
259 using L2NormalizeCost = GeLUCost;
260 using MaxPoolCost = GeLUCost;
261
262 class SoftmaxCost : public OperatorCost {
263 public:
SoftmaxCost()264 SoftmaxCost() : OperatorCost() {}
265 ~SoftmaxCost() override = default;
266
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)267 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
268 int64_t stage_id) const override {
269 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
270 }
271 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
272 int64_t stage_id) const override;
273 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
274 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)275 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
276 int64_t stage_id) const override {
277 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
278 }
279 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
280 int64_t stage_id) const override;
281 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
282 int64_t) const override;
283 // Taking account of output
284 void CalculateOutputInMemory() override;
285 // Not Taking account of input
286 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
287 };
288
289 class TileCost : public SoftmaxCost {
290 public:
TileCost()291 TileCost() : SoftmaxCost() {}
292 ~TileCost() override = default;
293 // Not taking account of output
294 void CalculateOutputInMemory() override;
295 // Taking account of input
296 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
297 };
298
299 class PackCost : public SoftmaxCost {
300 public:
PackCost()301 PackCost() : SoftmaxCost() {}
302 ~PackCost() override = default;
303 // Not taking account of output
304 void CalculateOutputInMemory() override;
305 // Not taking account of input
306 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
307 };
308
309 class BroadcastToCost : public SoftmaxCost {
310 public:
BroadcastToCost()311 BroadcastToCost() : SoftmaxCost() {}
312 ~BroadcastToCost() override = default;
313 // Not taking account of output
314 void CalculateOutputInMemory() override;
315 // Not Taking account of input
316 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
317 };
318
319 class TmpIdentityCost : public OperatorCost {
320 public:
TmpIdentityCost()321 TmpIdentityCost() : OperatorCost() {}
322 ~TmpIdentityCost() override = default;
323
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)324 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
325 int64_t stage_id) const override {
326 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
327 }
328 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
329 int64_t stage_id) const override;
330 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
331 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)332 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
333 int64_t stage_id) const override {
334 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
335 }
336 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
337 int64_t stage_id) const override;
338 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
339 int64_t stage_id) const override;
340 // Not taking account of output
341 void CalculateOutputInMemory() override;
342 // Not taking account of input
343 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
344 };
345 using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
346
347 class BatchParallelCost : public OperatorCost {
348 public:
BatchParallelCost()349 BatchParallelCost() : OperatorCost() {}
350 ~BatchParallelCost() override = default;
351
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)352 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
353 int64_t stage_id) const override {
354 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
355 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)356 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
357 return 0.0;
358 }
359 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)360 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
361 int64_t stage_id) const override {
362 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
363 }
364 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
365 int64_t stage_id) const override;
366 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
367 int64_t stage_id) const override;
368 // Not taking account of output
369 void CalculateOutputInMemory() override;
370 // Taking account of input
371 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
372 };
373
374 class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost {
375 public:
SparseSoftmaxCrossEntropyWithLogitsCost()376 SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {}
377 ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default;
378 // Taking account of output
379 void CalculateOutputInMemory() override;
380 // Not taking account of input
381 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
382 };
383
384 class VirtualDatasetCost : public OperatorCost {
385 public:
VirtualDatasetCost()386 VirtualDatasetCost() : OperatorCost() {}
387 ~VirtualDatasetCost() override = default;
388
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)389 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
390 int64_t stage_id) const override {
391 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
392 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)393 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
394 return 0.0;
395 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)396 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
397 return 0.0;
398 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)399 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
400 int64_t stage_id) const override {
401 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
402 }
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)403 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
404 int64_t) const override {
405 return 0.0;
406 }
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)407 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
408 int64_t) const override {
409 return 0.0;
410 }
411 // Not taking account of output
412 void CalculateOutputInMemory() override;
413 // Not taking account of input
414 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
415 };
416
417 class GeneratorBaseCost : public OperatorCost {
418 public:
GeneratorBaseCost()419 GeneratorBaseCost() : OperatorCost() {}
420 ~GeneratorBaseCost() override = default;
421
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)422 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
423 int64_t stage_id) const override {
424 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
425 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)426 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
427 return 0.0;
428 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)429 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
430 return 0.0;
431 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)432 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
433 int64_t stage_id) const override {
434 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
435 }
436 // Inputs vector is empty for generator ops.
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)437 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
438 int64_t) const override {
439 return 0.0;
440 }
441 // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)442 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
443 int64_t) const override {
444 return 0.0;
445 }
446 };
447 using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
448
449 class PReLUCost : public OperatorCost {
450 public:
PReLUCost()451 PReLUCost() : OperatorCost() {}
452 ~PReLUCost() override = default;
453
454 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)455 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
456 int64_t stage_id) const override {
457 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
458 }
459 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
460 int64_t stage_id) const override;
461 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
462 int64_t stage_id) const override;
463
464 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)465 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
466 int64_t stage_id) const override {
467 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
468 }
469 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
470 int64_t stage_id) const override;
471 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
472 int64_t stage_id) const override;
473 // Not taking account of output
474 void CalculateOutputInMemory() override;
475 // Taking account of input
476 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
477 };
478 using PReLUCostPtr = std::shared_ptr<PReLUCost>;
479
480 class OneHotCost : public OperatorCost {
481 public:
OneHotCost()482 OneHotCost() : OperatorCost() {}
483 ~OneHotCost() override = default;
484
485 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)486 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
487 int64_t stage_id) const override {
488 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
489 }
490 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
491 int64_t stage_id) const override;
492 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
493 int64_t stage_id) const override;
494
495 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)496 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
497 int64_t stage_id) const override {
498 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
499 }
500 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
501 int64_t stage_id) const override;
502 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
503 int64_t stage_id) const override;
504 // Not taking account of output
505 void CalculateOutputInMemory() override;
506 // Not taking account of input
507 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
508 };
509 using OneHotCostPtr = std::shared_ptr<OneHotCost>;
510
511 class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
512 public:
SoftmaxCrossEntropyWithLogitsCost()513 SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {}
514 ~SoftmaxCrossEntropyWithLogitsCost() override = default;
515
516 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)517 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
518 int64_t stage_id) const override {
519 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
520 }
521 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
522 int64_t stage_id) const override;
523 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
524 int64_t stage_id) const override;
525
526 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)527 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
528 int64_t stage_id) const override {
529 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
530 }
531 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
532 int64_t stage_id) const override;
533 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
534 int64_t stage_id) const override;
535 // Taking account of output
536 void CalculateOutputInMemory() override;
537 // Not taking account of input
538 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
539 };
540
541 class ReshapeCost : public OperatorCost {
542 public:
ReshapeCost()543 ReshapeCost() : OperatorCost() {}
544
545 ~ReshapeCost() override = default;
546
547 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)548 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
549 int64_t stage_id) const override {
550 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
551 }
552
553 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
554 int64_t stage_id) const override;
555
556 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
557 int64_t stage_id) const override;
558
559 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)560 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
561 int64_t stage_id) const override {
562 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
563 }
564
565 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
566 int64_t stage_id) const override;
567
568 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
569 int64_t stage_id) const override;
570 // Not taking account of output
571 void CalculateOutputInMemory() override;
572 // Not taking account of input
573 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
574 };
575 using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
576
577 class SubCost : public OperatorCost {
578 public:
SubCost()579 SubCost() : OperatorCost() {}
580 ~SubCost() override = default;
581
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)582 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
583 int64_t stage_id) const override {
584 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
585 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)586 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
587 return 0.0;
588 }
589 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
590
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)591 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
592 int64_t stage_id) const override {
593 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
594 }
595 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
596 int64_t stage_id) const override;
597 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
598 int64_t stage_id) const override;
599 // Not taking account of output
600 void CalculateOutputInMemory() override;
601 // Not taking account of input
602 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
603 };
604 using TensorAddCost = SubCost;
605 using FloorDivCost = SubCost;
606 using AssignSubCost = SubCost;
607 using AssignAddCost = SubCost;
608 using LogicalAndCost = SubCost;
609 using LogicalOrCost = SubCost;
610 using BiasAddCost = SubCost;
611 using EqualCost = SubCost;
612 using ApproximateEqualCost = SubCost;
613 using NotEqualCost = SubCost;
614 using GreaterCost = SubCost;
615 using GreaterEqualCost = SubCost;
616 using LessCost = SubCost;
617 using LessEqualCost = SubCost;
618 using GatherNdCost = SubCost;
619
620 class MulCost : public SubCost {
621 public:
MulCost()622 MulCost() : SubCost() {}
623 ~MulCost() override = default;
624 // Taking account of input, not taking account of output
625 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
626 };
627
628 using GatherDCost = MulCost;
629
630 class DivCost : public SubCost {
631 public:
DivCost()632 DivCost() : SubCost() {}
633 ~DivCost() override = default;
634 // Taking account of output
635 void CalculateOutputInMemory() override;
636 // Taking account of input
637 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
638 };
639 using ReadDivCost = DivCost;
640
641 class ModCost : public SubCost {
642 public:
ModCost()643 ModCost() : SubCost() {}
644 ~ModCost() override = default;
645 // Taking account of input, not taking account of output
646 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
647 };
648 using FloorModCost = ModCost;
649
650 class PowCost : public SubCost {
651 public:
PowCost()652 PowCost() : SubCost() {}
653 ~PowCost() override = default;
654 // Taking account of output
655 void CalculateOutputInMemory() override;
656 // Taking account of input
657 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
658 };
659
660 class AssignCost : public SubCost {
661 public:
AssignCost()662 AssignCost() : SubCost() {}
663 ~AssignCost() override = default;
664 // Taking account of input, not taking account of output
665 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
666 };
667
668 class SigmoidCrossEntropyWithLogitsCost : public SubCost {
669 public:
SigmoidCrossEntropyWithLogitsCost()670 SigmoidCrossEntropyWithLogitsCost() : SubCost() {}
671 ~SigmoidCrossEntropyWithLogitsCost() override = default;
672 // Taking account of input, not taking account of output
673 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
674 };
675
676 class Atan2Cost : public SubCost {
677 public:
Atan2Cost()678 Atan2Cost() : SubCost() {}
679 ~Atan2Cost() override = default;
680 // Taking account of input, not taking account of output
681 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
682 };
683
684 class DivNoNanCost : public SubCost {
685 public:
DivNoNanCost()686 DivNoNanCost() : SubCost() {}
687 ~DivNoNanCost() override = default;
688 // Taking account of output
689 void CalculateOutputInMemory() override;
690 // Taking account of input
691 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
692 };
693
694 class MaximumCost : public SubCost {
695 public:
MaximumCost()696 MaximumCost() : SubCost() {}
697 ~MaximumCost() override = default;
698 // Taking account of input, not taking account of output
699 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
700 };
701 using MinimumCost = MaximumCost;
702
703 class SliceCost : public CastCost {
704 public:
SliceCost()705 SliceCost() : CastCost() {}
706 ~SliceCost() override = default;
707 // Not taking account of output, taking account of input
708 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
709 };
710
711 class StridedSliceCost : public CastCost {
712 public:
StridedSliceCost()713 StridedSliceCost() : CastCost() {}
714 ~StridedSliceCost() override = default;
715 // Not taking account of output, taking account of input
716 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
717 };
718
719 class ReduceSumCost : public OperatorCost {
720 public:
ReduceSumCost()721 ReduceSumCost() : OperatorCost() {}
722 ~ReduceSumCost() override = default;
723
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)724 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
725 int64_t stage_id) const override {
726 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
727 }
728 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
729 int64_t stage_id) const override;
730 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
731 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)732 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
733 int64_t stage_id) const override {
734 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
735 }
736 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
737 int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)738 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
739 int64_t) const override {
740 return 0.0;
741 }
set_cross_batch(bool cb)742 void set_cross_batch(bool cb) { cross_batch_ = cb; }
743 // Not taking account of output
744 void CalculateOutputInMemory() override;
745 // Taking account of input
746 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
747
748 protected:
749 bool cross_batch_ = false;
750 };
751 using ReduceMethodCost = ReduceSumCost;
752
753 class ReduceMeanCost : public ReduceSumCost {
754 public:
ReduceMeanCost()755 ReduceMeanCost() : ReduceSumCost() {}
756 ~ReduceMeanCost() override = default;
757
758 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
759 int64_t stage_id) const override;
760 };
761
762 class ReduceMinCost : public ReduceSumCost {
763 public:
ReduceMinCost()764 ReduceMinCost() : ReduceSumCost() {}
765 ~ReduceMinCost() override = default;
766 // Taking account of output
767 void CalculateOutputInMemory() override;
768 // Taking account of input
769 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
770 };
771 using ReduceMaxCost = ReduceMinCost;
772
773 class ArgMaxWithValueCost : public ReduceSumCost {
774 public:
ArgMaxWithValueCost()775 ArgMaxWithValueCost() : ReduceSumCost() {}
776 ~ArgMaxWithValueCost() override = default;
777 // Taking account of output
778 void CalculateOutputInMemory() override;
779 // Taking account of input
780 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
781 };
782 using ArgMinWithValueCost = ArgMaxWithValueCost;
783
784 class GetNextCost : public OperatorCost {
785 public:
GetNextCost()786 GetNextCost() : OperatorCost() {}
787 ~GetNextCost() override = default;
788
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)789 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
790 int64_t stage_id) const override {
791 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
792 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)793 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
794 return 0.0;
795 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)796 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
797 return 0.0;
798 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)799 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
800 int64_t stage_id) const override {
801 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
802 }
803 // Inputs vector is empty for generator ops.
GetForwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)804 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
805 int64_t) const override {
806 return 0.0;
807 }
808 // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)809 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
810 int64_t) const override {
811 return 0.0;
812 }
813 // Not taking account of output
814 void CalculateOutputInMemory() override;
815 // Not Taking account of input
816 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
817 };
818 using GetNextCostPtr = std::shared_ptr<GetNextCost>;
819
820 class DSDMatmulCost : public OperatorCost {
821 public:
DSDMatmulCost()822 DSDMatmulCost() : OperatorCost() {}
823 ~DSDMatmulCost() override = default;
824
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)825 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
826 int64_t stage_id) const override {
827 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
828 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)829 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
830 return 0.0;
831 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)832 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
833 return 0.0;
834 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)835 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
836 int64_t stage_id) const override {
837 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
838 }
839 // Inputs vector is empty for generator ops.
840 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
841 int64_t) const override;
842 // Generator ops don't have backward steps.
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)843 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
844 int64_t) const override {
845 return 0.0;
846 }
847 // Not taking account of output
848 void CalculateOutputInMemory() override;
849 // Not Taking account of input
850 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
851 };
852 using DSDMatmulCostPtr = std::shared_ptr<DSDMatmulCost>;
853
854 // For memory cost, taking account of output, not taking account of input
855 class DropOutCost : public SqrtCost {
856 public:
DropOutCost()857 DropOutCost() : SqrtCost() {}
858 ~DropOutCost() override = default;
859
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)860 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
861 int64_t stage_id) const override {
862 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
863 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)864 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
865 return 0.0;
866 }
GetBackwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)867 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
868 return 0.0;
869 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)870 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
871 int64_t stage_id) const override {
872 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
873 }
874 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
875 int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)876 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
877 int64_t) const override {
878 return 0.0;
879 }
880 };
881
882 class DropOutDoMaskCost : public DropOutCost {
883 public:
DropOutDoMaskCost()884 DropOutDoMaskCost() : DropOutCost() {}
885 ~DropOutDoMaskCost() override = default;
886 // Not taking account of output
887 void CalculateOutputInMemory() override;
888 // Taking account of input
889 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
890 };
891
892 class UnsortedSegmentSumCost : public OperatorCost {
893 public:
UnsortedSegmentSumCost()894 UnsortedSegmentSumCost() : OperatorCost() {}
895 ~UnsortedSegmentSumCost() override = default;
896
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)897 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
898 int64_t stage_id) const override {
899 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
900 }
901 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
902 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)903 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
904 int64_t stage_id) const override {
905 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
906 }
907 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
908 int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)909 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
910 int64_t) const override {
911 return 0.0;
912 }
913 // Not taking account of output
914 void CalculateOutputInMemory() override;
915 // Taking account of input
916 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
917 };
918
919 class UnsortedSegmentMinCost : public OperatorCost {
920 public:
UnsortedSegmentMinCost()921 UnsortedSegmentMinCost() : OperatorCost() {}
922 ~UnsortedSegmentMinCost() override = default;
923
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)924 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
925 int64_t stage_id) const override {
926 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
927 }
928 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
929 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)930 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
931 int64_t stage_id) const override {
932 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
933 }
934 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
935 int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)936 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
937 int64_t) const override {
938 return 0.0;
939 }
940 // Taking account of output
941 void CalculateOutputInMemory() override;
942 // Taking account of input
943 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
944 };
945 using UnsortedSegmentMaxCost = UnsortedSegmentMinCost;
946
947 class LayerNormCost : public OperatorCost {
948 public:
LayerNormCost()949 LayerNormCost() : OperatorCost() {}
950 ~LayerNormCost() override = default;
951
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)952 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
953 int64_t stage_id) const override {
954 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
955 }
GetForwardCommCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)956 double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override {
957 return 0.0;
958 }
959 double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)960 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
961 int64_t stage_id) const override {
962 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
963 }
964 double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
965 int64_t) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> &,const std::vector<TensorInfo> &,int64_t)966 double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
967 int64_t) const override {
968 return 0.0;
969 }
970 // Taking account of output
971 void CalculateOutputInMemory() override;
972 // Taking account of input
973 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
974 };
975
976 class UniqueCost : public OperatorCost {
977 public:
UniqueCost()978 UniqueCost() : OperatorCost() {}
979 ~UniqueCost() override = default;
980
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)981 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
982 int64_t stage_id) const override {
983 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
984 }
985 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
986 int64_t stage_id) const override;
987 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
988 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)989 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
990 int64_t stage_id) const override {
991 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
992 }
993 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
994 int64_t stage_id) const override;
995 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
996 int64_t) const override;
997 // Taking account of output
998 void CalculateOutputInMemory() override;
999 // Not Taking account of input
1000 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1001 };
1002
1003 class UniformCandidateSamplerCost : public OperatorCost {
1004 public:
UniformCandidateSamplerCost()1005 UniformCandidateSamplerCost() : OperatorCost() {}
1006 ~UniformCandidateSamplerCost() override = default;
1007
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1008 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1009 int64_t stage_id) const override {
1010 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1011 }
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1012 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1013 int64_t stage_id) const override {
1014 return 0;
1015 }
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1016 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1017 int64_t stage_id) const override {
1018 return 0;
1019 }
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1020 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1021 int64_t stage_id) const override {
1022 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1023 }
1024 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1025 int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t)1026 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1027 int64_t) const override {
1028 return 0.0;
1029 }
1030 // Not taking account of output
1031 void CalculateOutputInMemory() override;
1032 // Not Taking account of input
1033 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1034 };
1035
1036 class GatherV2Cost : public OperatorCost {
1037 public:
GatherV2Cost()1038 GatherV2Cost() : OperatorCost() {}
1039 ~GatherV2Cost() override = default;
1040
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1041 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1042 int64_t stage_id) const override {
1043 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1044 }
1045 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1046 int64_t stage_id) const override;
1047 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1048 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1049 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1050 int64_t stage_id) const override {
1051 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1052 }
1053 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1054 int64_t stage_id) const override;
1055 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1056 int64_t) const override;
1057 // Not taking account of output
1058 void CalculateOutputInMemory() override;
1059 // Taking account of input
1060 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1061 };
1062
1063 class GatherV2PCost : public GatherV2Cost {
1064 public:
GatherV2PCost()1065 GatherV2PCost() : GatherV2Cost(), axis_(0) {}
1066 ~GatherV2PCost() override = default;
1067
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1068 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1069 int64_t stage_id) const override {
1070 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1071 }
1072 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1073 int64_t stage_id) const override;
1074 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1075 int64_t stage_id) const override;
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1076 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1077 int64_t stage_id) const override {
1078 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1079 }
1080 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1081 int64_t stage_id) const override;
1082 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1083 int64_t) const override;
set_axis(int64_t axis)1084 void set_axis(int64_t axis) { axis_ = axis; }
set_strategy(const Shape & strategy)1085 void set_strategy(const Shape &strategy) { strategy_ = strategy; }
1086
1087 protected:
1088 int64_t axis_;
1089 Shape strategy_;
1090 };
1091
1092 class MatmulDDSCost : public OperatorCost {
1093 public:
MatmulDDSCost()1094 MatmulDDSCost() : OperatorCost() {}
1095 ~MatmulDDSCost() override = default;
1096
1097 // per device communication cost
GetCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1098 double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1099 int64_t stage_id) const override {
1100 return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
1101 }
GetForwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1102 double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1103 int64_t stage_id) const override {
1104 return 0.0;
1105 };
GetBackwardCommCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1106 double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1107 int64_t stage_id) const override {
1108 return 0.0;
1109 };
1110
1111 // per device computation cost
GetComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1112 double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1113 int64_t stage_id) const override {
1114 return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
1115 }
1116 double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1117 int64_t stage_id) const override;
GetBackwardComputationCost(const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,int64_t stage_id)1118 double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
1119 int64_t stage_id) const override {
1120 return 0.0;
1121 };
1122 // Not taking account of output
1123 void CalculateOutputInMemory() override;
1124 // Taking account of input
1125 void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
1126 };
1127 using MatmulDDSCostPtr = std::shared_ptr<MatmulDDSCost>;
1128
1129 } // namespace parallel
1130 } // namespace mindspore
1131 #endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
1132