• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
19 
#include <ir/value.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
25 
26 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
27 #include "frontend/parallel/ops_info/operator_info.h"
28 #include "frontend/parallel/strategy.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 class ActivationBase : public OperatorInfo {
33  public:
ActivationBase(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)34   ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
35                  const PrimitiveAttrs &attrs, OperatorCostPtr cost)
36       : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
37   ~ActivationBase() override = default;
38 
39   Status Init(const StrategyPtr &strategy) override;
40   Status InitForCostModel(const StrategyPtr &strategy) override;
41 
42  protected:
43   Status InferMirrorOps() override;
44   Status InferForwardCommunication() override;
45   Status InferTensorMap() override;
46   Status InferDevMatrixShape() override;
47 };
48 
49 class Activation : public ActivationBase {
50  public:
Activation(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)51   Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
52              const PrimitiveAttrs &attrs, OperatorCostPtr cost)
53       : ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {}
54   ~Activation() override = default;
55   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
56   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
57 
58  protected:
59   Status CheckStrategy(const StrategyPtr &strategy) override;
60 };
61 
62 class ActivationInfo : public Activation {
63  public:
ActivationInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)64   ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
65                  const PrimitiveAttrs &attrs)
66       : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
67   ~ActivationInfo() override = default;
68 
69  protected:
70   Status GetAttrs() override;  // activation_type: relu, relu6, sigmoid
71 };
72 
73 class ActivationOther : public Activation {
74  public:
ActivationOther(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)75   ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
76                   const PrimitiveAttrs &attrs, OperatorCostPtr cost)
77       : Activation(name, inputs_shape, outputs_shape, attrs, cost) {}
78   ~ActivationOther() override = default;
79 
80  protected:
81   Status GetAttrs() override;
82 };
83 
84 class GeLUInfo : public ActivationOther {
85  public:
GeLUInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)86   GeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
87            const PrimitiveAttrs &attrs)
88       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
89   ~GeLUInfo() override = default;
90 };
91 
92 class FastGeLUInfo : public ActivationOther {
93  public:
FastGeLUInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)94   FastGeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
95                const PrimitiveAttrs &attrs)
96       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FastGeLUCost>()) {}
97   ~FastGeLUInfo() override = default;
98 };
99 
100 class TanhInfo : public ActivationOther {
101  public:
TanhInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)102   TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
103            const PrimitiveAttrs &attrs)
104       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanhCost>()) {}
105   ~TanhInfo() override = default;
106 };
107 
108 class Softmax : public ActivationBase {
109  public:
Softmax(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)110   explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
111                    const PrimitiveAttrs &attrs)
112       : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
113   ~Softmax() override = default;
114   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
115   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
116 
117  protected:
118   Status CheckStrategy(const StrategyPtr &strategy) override;
119   Status GetAttrs() override;
120 
121  private:
122   std::vector<int64_t> axis_;
123 };
124 
125 class SoftmaxInfo : public Softmax {
126  public:
SoftmaxInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)127   SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
128               const PrimitiveAttrs &attrs)
129       : Softmax(name, inputs_shape, outputs_shape, attrs) {}
130   ~SoftmaxInfo() override = default;
131 };
132 
133 class LogSoftmaxInfo : public Softmax {
134  public:
LogSoftmaxInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)135   LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
136                  const PrimitiveAttrs &attrs)
137       : Softmax(name, inputs_shape, outputs_shape, attrs) {}
138   ~LogSoftmaxInfo() override = default;
139 };
140 
141 class EluInfo : public ActivationOther {
142  public:
EluInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)143   EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
144       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<EluCost>()) {}
145   ~EluInfo() override = default;
146 };
147 
148 class ReLUInfo : public ActivationOther {
149  public:
ReLUInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)150   ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
151            const PrimitiveAttrs &attrs)
152       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
153   ~ReLUInfo() override = default;
154 };
155 
156 class RepeatElementsInfo : public ActivationOther {
157  public:
RepeatElementsInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)158   RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
159                      const PrimitiveAttrs &attrs)
160       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RepeatElementsCost>()) {}
161   ~RepeatElementsInfo() override = default;
162 };
163 
164 class ReLU6Info : public ActivationOther {
165  public:
ReLU6Info(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)166   ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
167             const PrimitiveAttrs &attrs)
168       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLU6Cost>()) {}
169   ~ReLU6Info() override = default;
170 };
171 
172 class SoftsignInfo : public ActivationOther {
173  public:
SoftsignInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)174   SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
175                const PrimitiveAttrs &attrs)
176       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftsignCost>()) {}
177   ~SoftsignInfo() override = default;
178 };
179 
180 class SoftplusInfo : public ActivationOther {
181  public:
SoftplusInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)182   SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
183                const PrimitiveAttrs &attrs)
184       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftplusCost>()) {}
185   ~SoftplusInfo() override = default;
186 };
187 
188 class CastInfo : public ActivationOther {
189  public:
CastInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)190   CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
191            const PrimitiveAttrs &attrs)
192       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CastCost>()) {}
193   ~CastInfo() override = default;
194 
195  protected:
196   Status InferMirrorOps() override;
197 };
198 
199 class SqrtInfo : public ActivationOther {
200  public:
SqrtInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)201   SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
202            const PrimitiveAttrs &attrs)
203       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqrtCost>()) {}
204   ~SqrtInfo() override = default;
205 };
206 
207 class NegInfo : public ActivationOther {
208  public:
NegInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)209   NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
210       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<NegCost>()) {}
211   ~NegInfo() override = default;
212 };
213 
214 class ExpandDimsInfo : public ActivationOther {
215  public:
ExpandDimsInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)216   ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
217                  const PrimitiveAttrs &attrs)
218       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpandDimsCost>()) {}
219   ~ExpandDimsInfo() override = default;
220 
221  protected:
222   Status GetAttrs() override;
223   Status InferTensorMap() override;
224   Status InferMirrorOps() override;
225   Status InferTensorStrategy();
226 
227  private:
228   int64_t positive_axis_ = -1;
229   Strategys inputs_strategy_;
230   Strategys outputs_strategy_;
231 };
232 
233 class SqueezeInfo : public ActivationOther {
234  public:
SqueezeInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)235   SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
236               const PrimitiveAttrs &attrs)
237       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqueezeCost>()) {}
238   ~SqueezeInfo() override = default;
239 
240  protected:
241   Status InferAxis(const ValueTuplePtr &value_tuple);
242   Status GetAttrs() override;
243   Status InferReplaceOps();
244   Status InferTensorMap() override;
245   Status Init(const StrategyPtr &strategy) override;
246 
247  private:
248   ValueTuplePtr axis_;
249 };
250 
251 class SquareInfo : public ActivationOther {
252  public:
SquareInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)253   SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
254              const PrimitiveAttrs &attrs)
255       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareCost>()) {}
256   ~SquareInfo() override = default;
257 };
258 
259 class SigmoidInfo : public ActivationOther {
260  public:
SigmoidInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)261   SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
262               const PrimitiveAttrs &attrs)
263       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SigmoidCost>()) {}
264   ~SigmoidInfo() override = default;
265 };
266 
267 class DropoutInfo : public ActivationOther {
268  public:
DropoutInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)269   DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
270               const PrimitiveAttrs &attrs)
271       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
272   ~DropoutInfo() override = default;
273   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
274   Status Init(const StrategyPtr &strategy) override;
275 
276  protected:
277   Status GetAttrs() override;
278   Status InferTensorMap() override;
279   Status InferReplaceOps();
280   Status InferAsLossDivisor() override;
281 
282  private:
283   int64_t seed0_ = 0;
284   int64_t seed1_ = 0;
get_seed()285   int64_t get_seed() {
286     static int64_t SEED_NUM;
287     return ++SEED_NUM;
288   }
289 };
290 }  // namespace parallel
291 }  // namespace mindspore
292 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
293