• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
19 
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
24 
25 #include "ir/value.h"
26 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
27 #include "frontend/parallel/ops_info/operator_info.h"
28 #include "frontend/parallel/strategy.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 class ArithmeticBase : public OperatorInfo {
33  public:
ArithmeticBase(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)34   ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
35                  const PrimitiveAttrs &attrs, OperatorCostPtr cost)
36       : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
37   ~ArithmeticBase() override = default;
38   Status Init(const StrategyPtr &strategy) override;
39   Status InitForCostModel(const StrategyPtr &strategy) override;
40   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
41   Status SetCostUnderStrategy(const StrategyPtr &) override;
42   void ReComputeBatchSplitFlagList() override;
43 
44  protected:
GetAttrs()45   Status GetAttrs() override { return SUCCESS; }
46   Status CheckStrategy(const StrategyPtr &strategy) override;
InferForwardCommunication()47   Status InferForwardCommunication() override { return SUCCESS; }
48   Status InferDevMatrixShape() override;
49   Status InferTensorMap() override;
50   Shapes InferExpendShape();
51 };
52 
53 class SubInfo : public ArithmeticBase {
54  public:
SubInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)55   SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
56       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>()) {}
57   ~SubInfo() override = default;
58 };
59 
60 class AddInfo : public ArithmeticBase {
61  public:
AddInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)62   AddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
63       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
64   ~AddInfo() override = default;
65 };
66 
67 class MulInfo : public ArithmeticBase {
68  public:
MulInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)69   MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
70       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>()) {}
71   ~MulInfo() override = default;
72 };
73 
74 class DivInfo : public ArithmeticBase {
75  public:
DivInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)76   DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
77       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>()) {}
78   ~DivInfo() override = default;
79 };
80 
81 class ModInfo : public ArithmeticBase {
82  public:
ModInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)83   ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
84       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ModCost>()) {}
85   ~ModInfo() override = default;
86 };
87 
88 class RealDivInfo : public ArithmeticBase {
89  public:
RealDivInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)90   RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
91               const PrimitiveAttrs &attrs)
92       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReadDivCost>()) {}
93   ~RealDivInfo() override = default;
94 };
95 
96 class FloorDivInfo : public ArithmeticBase {
97  public:
FloorDivInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)98   FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
99                const PrimitiveAttrs &attrs)
100       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>()) {}
101   ~FloorDivInfo() override = default;
102 };
103 
104 class FloorModInfo : public ArithmeticBase {
105  public:
FloorModInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)106   FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
107                const PrimitiveAttrs &attrs)
108       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorModCost>()) {}
109   ~FloorModInfo() override = default;
110 };
111 
112 class PowInfo : public ArithmeticBase {
113  public:
PowInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)114   PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
115       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<PowCost>()) {}
116   ~PowInfo() override = default;
117 };
118 
119 class AssignSubInfo : public ArithmeticBase {
120  public:
AssignSubInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)121   AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
122                 const PrimitiveAttrs &attrs)
123       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignSubCost>()) {}
124   ~AssignSubInfo() override = default;
125 };
126 
127 class AssignInfo : public ArithmeticBase {
128  public:
AssignInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)129   AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
130              const PrimitiveAttrs &attrs)
131       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignCost>()) {}
132   ~AssignInfo() override = default;
133 };
134 
135 class AssignAddInfo : public ArithmeticBase {
136  public:
AssignAddInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)137   AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
138                 const PrimitiveAttrs &attrs)
139       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignAddCost>()) {}
140   ~AssignAddInfo() override = default;
141 };
142 
143 // All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label.
144 class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase {
145  public:
SigmoidCrossEntropyWithLogitsInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)146   SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
147                                     const PrimitiveAttrs &attrs)
148       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs,
149                        std::make_shared<SigmoidCrossEntropyWithLogitsCost>()) {}
150   ~SigmoidCrossEntropyWithLogitsInfo() override = default;
151 };
152 
153 class Atan2Info : public ArithmeticBase {
154  public:
Atan2Info(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)155   Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
156             const PrimitiveAttrs &attrs)
157       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<Atan2Cost>()) {}
158   ~Atan2Info() override = default;
159 };
160 
161 class DivNoNanInfo : public ArithmeticBase {
162  public:
DivNoNanInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)163   DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
164                const PrimitiveAttrs &attrs)
165       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivNoNanCost>()) {}
166   ~DivNoNanInfo() override = default;
167 };
168 
169 class LogicalAndInfo : public ArithmeticBase {
170  public:
LogicalAndInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)171   LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
172                  const PrimitiveAttrs &attrs)
173       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalAndCost>()) {}
174   ~LogicalAndInfo() override = default;
175 };
176 
177 class LogicalOrInfo : public ArithmeticBase {
178  public:
LogicalOrInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)179   LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
180                 const PrimitiveAttrs &attrs)
181       : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {}
182   ~LogicalOrInfo() override = default;
183 };
184 }  // namespace parallel
185 }  // namespace mindspore
186 
187 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
188