• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
18 
19 #include <memory>
20 #include <algorithm>
21 #include <sstream>
22 #include <string>
23 #include <unordered_map>
24 #include <functional>
25 
26 #include "backend/optimizer/graph_kernel/model/node.h"
27 #include "backend/kernel_compiler/common_utils.h"
28 #include "ir/dtype/type.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace graphkernel {
// Throws (via MS_LOG(EXCEPTION)) when `attr_name` is not a key of the map-like `attrs`.
// The error message names the missing attribute and the source text of the `attrs` expression.
// Both arguments are parenthesized so that expressions such as `*attrs_ptr` or a conditional
// `attr_name` expand correctly.
#define CHECK_ATTR(attrs, attr_name)                                                                \
  do {                                                                                              \
    if ((attrs).count(attr_name) == 0) {                                                            \
      MS_LOG(EXCEPTION) << "The attr [" << (attr_name) << "] does not exist in [" << #attrs << "]"; \
    }                                                                                               \
  } while (0)
39 
40 class PrimOp : public Node {
41  public:
42   enum ComputeType {
43     RESHAPE,
44     ELEMWISE,
45     BROADCAST,
46     REDUCE,
47     OPAQUE,
48   };
49 
PrimOp(const std::string & op,const std::string & node_name,ComputeType compute)50   PrimOp(const std::string &op, const std::string &node_name, ComputeType compute)
51       : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {}
52   ~PrimOp() = default;
53 
54   virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs);
55   virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
56 
57   void Dump(std::ostringstream &os) const override;
NodeType()58   NType NodeType() override { return NType::Primitive; }
59 
op()60   const std::string &op() const { return op_; }
compute_type()61   ComputeType compute_type() const { return compute_type_; }
62 
63  protected:
64   virtual void Check(const NodePtrList &inputs, const DAttrs &attrs);
CheckShape(const NodePtrList & inputs,const DAttrs & attrs)65   virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {}
66   virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs);
67   virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs);
68 
InferShape(const NodePtrList & inputs,const DAttrs & attrs)69   virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
InferType(const NodePtrList & inputs,const DAttrs & attrs)70   virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)71   virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
72 
73   std::string op_;
74   ComputeType compute_type_;
75 };
76 using PrimOpPtr = std::shared_ptr<PrimOp>;
77 
78 class ElemwiseOp : public PrimOp {
79  public:
ElemwiseOp(const std::string & op,const std::string & node_name)80   ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {}
81   ~ElemwiseOp() = default;
82 
83   NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
84 
85  protected:
86   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
87   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
88 };
89 
90 class CastOp : public ElemwiseOp {
91  public:
CastOp(const std::string & op,const std::string & node_name)92   CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {}
93   ~CastOp() = default;
94 
95  protected:
96   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
97 };
98 
99 class InplaceAssignOp : public ElemwiseOp {
100  public:
InplaceAssignOp(const std::string & op,const std::string & node_name)101   InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {}
102   ~InplaceAssignOp() = default;
103 
104  protected:
InferShape(const NodePtrList & inputs,const DAttrs & attrs)105   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; }
InferType(const NodePtrList & inputs,const DAttrs & attrs)106   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->type; }
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)107   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; }
108 };
109 
110 class SelectOp : public ElemwiseOp {
111  public:
SelectOp(const std::string & op,const std::string & node_name)112   SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {}
113   ~SelectOp() = default;
114 
115  protected:
116   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs & attrs)117   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; }
118 };
119 
120 class CompareOp : public ElemwiseOp {
121  public:
CompareOp(const std::string & op,const std::string & node_name)122   CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {}
123   ~CompareOp() = default;
124 
125  protected:
InferType(const NodePtrList & inputs,const DAttrs & attrs)126   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; }
127 };
128 
129 class LessOp : public CompareOp {
130  public:
LessOp(const std::string & op,const std::string & node_name)131   LessOp(const std::string &op, const std::string &node_name) : CompareOp("Less", node_name) {}
132   ~LessOp() = default;
133 };
134 
135 class EqualOp : public CompareOp {
136  public:
EqualOp(const std::string & op,const std::string & node_name)137   EqualOp(const std::string &op, const std::string &node_name) : CompareOp("Equal", node_name) {}
138   ~EqualOp() = default;
139 };
140 
141 class LessEqualOp : public CompareOp {
142  public:
LessEqualOp(const std::string & op,const std::string & node_name)143   LessEqualOp(const std::string &op, const std::string &node_name) : CompareOp("LessEqual", node_name) {}
144   ~LessEqualOp() = default;
145 };
146 
147 class GreaterOp : public CompareOp {
148  public:
GreaterOp(const std::string & op,const std::string & node_name)149   GreaterOp(const std::string &op, const std::string &node_name) : CompareOp("Greater", node_name) {}
150   ~GreaterOp() = default;
151 };
152 
153 class GreaterEqualOp : public CompareOp {
154  public:
GreaterEqualOp(const std::string & op,const std::string & node_name)155   GreaterEqualOp(const std::string &op, const std::string &node_name) : CompareOp("GreaterEqual", node_name) {}
156   ~GreaterEqualOp() = default;
157 };
158 
159 class ReshapeOp : public PrimOp {
160  public:
ReshapeOp(const std::string & op,const std::string & node_name)161   ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
162   ~ReshapeOp() = default;
163 
164  protected:
165   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)166   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override {
167     return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
168                                                : GetValue<std::string>(attrs.find("format")->second);
169   }
170 };
171 
172 class BroadcastToOp : public PrimOp {
173  public:
BroadcastToOp(const std::string & op,const std::string & node_name)174   BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
175   ~BroadcastToOp() = default;
176 
177  protected:
178   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
179 };
180 
181 class ReduceOp : public PrimOp {
182  public:
ReduceOp(const std::string & op,const std::string & node_name)183   ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
184   ~ReduceOp() = default;
185 
186  protected:
187   void Check(const NodePtrList &inputs, const DAttrs &attrs) override;
188   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)189   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; };
190 };
191 
192 class OpaqueOp : public PrimOp {
193  public:
OpaqueOp(const std::string & op,const std::string & node_name)194   OpaqueOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, OPAQUE) {}
195   ~OpaqueOp() = default;
196 };
197 
198 class Conv2dOp : public OpaqueOp {
199  public:
Conv2dOp(const std::string & op,const std::string & node_name)200   Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {}
201   ~Conv2dOp() = default;
202 
203  protected:
204   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
205   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
206 };
207 
208 class TransposeOp : public OpaqueOp {
209  public:
TransposeOp(const std::string & op,const std::string & node_name)210   TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {}
211   ~TransposeOp() = default;
212 
213  protected:
214   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
215   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
216 };
217 
218 class MatMulOp : public OpaqueOp {
219  public:
MatMulOp(const std::string & op,const std::string & node_name)220   MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {}
221   ~MatMulOp() = default;
222 
223  protected:
224   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
225   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
226 };
227 
228 class PadAkgOp : public OpaqueOp {
229  public:
PadAkgOp(const std::string & op,const std::string & node_name)230   PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {}
231   ~PadAkgOp() = default;
232 
233  protected:
234   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
235 };
236 
237 class UnPadAkgOp : public OpaqueOp {
238  public:
UnPadAkgOp(const std::string & op,const std::string & node_name)239   UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {}
240   ~UnPadAkgOp() = default;
241 
242  protected:
243   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
244 };
245 
246 class CImagOp : public ElemwiseOp {
247  public:
CImagOp(const std::string & op,const std::string & node_name)248   CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {}
249   ~CImagOp() = default;
250 
251  protected:
CheckType(const NodePtrList & inputs,const DAttrs & attrs)252   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
253     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
254       MS_LOG(EXCEPTION) << "CImag's input[0] should be complex64";
255     }
256   };
257 
InferType(const NodePtrList & inputs,const DAttrs & attrs)258   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
259 };
260 
261 class CRealOp : public ElemwiseOp {
262  public:
CRealOp(const std::string & op,const std::string & node_name)263   CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {}
264   ~CRealOp() = default;
265 
266  protected:
CheckType(const NodePtrList & inputs,const DAttrs & attrs)267   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
268     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
269       MS_LOG(EXCEPTION) << "CReal's input[0] should be complex64";
270     }
271   };
272 
InferType(const NodePtrList & inputs,const DAttrs & attrs)273   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
274 };
275 
276 class ComplexOp : public ElemwiseOp {
277  public:
ComplexOp(const std::string & op,const std::string & node_name)278   ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {}
279   ~ComplexOp() = default;
280 
281  protected:
282   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs & attrs)283   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
284 };
285 
286 class StandardNormalOp : public OpaqueOp {
287  public:
StandardNormalOp(const std::string & op,const std::string & node_name)288   StandardNormalOp(const std::string &op, const std::string &node_name) : OpaqueOp("StandardNormal", node_name) {}
289   ~StandardNormalOp() = default;
290 
291  protected:
292   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs & attrs)293   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
InferFormat(const NodePtrList & inputs,const DAttrs & attrs)294   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }
295 };
296 }  // namespace graphkernel
297 }  // namespace opt
298 }  // namespace mindspore
299 #endif
300