• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include <utility>
23 #include "ops/primitive_c.h"
24 #include "backend/common/graph_kernel/model/node.h"
25 #include "ir/dtype/type.h"
26 #include "include/backend/visible.h"
27 
28 namespace mindspore::graphkernel::inner {
// Reports through MS_LOG(EXCEPTION) when `attr_name` is not a key of the attribute container `attrs`.
// Both macro parameters are parenthesized where they are used as expressions, so arguments such as
// `*attrs_ptr` expand correctly. Note that `attr_name` is evaluated a second time on the failure path.
#define CHECK_ATTR(attrs, attr_name)                                                                \
  do {                                                                                              \
    if ((attrs).count(attr_name) == 0) {                                                            \
      MS_LOG(EXCEPTION) << "The attr [" << (attr_name) << "] does not exist in [" << #attrs << "]"; \
    }                                                                                               \
  } while (0)
35 
36 class BACKEND_EXPORT PrimOp : public Node {
37  public:
38   enum class ComputeType : int {
39     VIRTUAL = 0,
40     RESHAPE = 1,
41     ELEMWISE = 2,
42     BROADCAST = 3,
43     REDUCE = 4,
44     OPAQUE = 5,
45   };
46 
PrimOp(const std::string & op,ComputeType compute)47   PrimOp(const std::string &op, ComputeType compute)
48       : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}), op_(op), compute_type_(compute) {}
49   ~PrimOp() = default;
50 
51   NodeBaseList Infer(const NodePtrList &inputs, const DAttrs &attrs);
52 
53   std::string ToString() const override;
NodeType()54   NType NodeType() override { return NType::Primitive; }
55 
op()56   const std::string &op() const { return op_; }
compute_type()57   ComputeType compute_type() const { return compute_type_; }
58   // infer output value when all inputs are constant
59   virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs);
60 
61  protected:
62   // Check node info before inference the shape/type/format.
Check(const NodePtrList &,const DAttrs &)63   virtual void Check(const NodePtrList &, const DAttrs &) {}
64 
65   // Infer format. assume all outputs have the same format.
InferFormat(const NodePtrList & inputs,const DAttrs &)66   virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) { return inputs[0]->format; }
67 
68   // Infer shape. returning an empty vector means using PrimitiveC's infer_shape function.
69   virtual std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &);
70 
71   // Infer type. returning an empty vector means using PrimitiveC's infer_type function.
72   virtual std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &);
73 
74   // calculate const inputs, used for InferValue
75   template <typename TM>
76   tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const DAttrs &) const;
77 
78   // Gen PrimitiveC and abstract list to call PrimitiveC's inference function.
79   std::pair<PrimitivePtr, AbstractBasePtrList> GenPrimAndAbstract(const NodePtrList &inputs, const DAttrs &attrs) const;
80 
81   // rectify abstract before calling PrimitiveC's inference function.
RectifyAbstract(const PrimitivePtr &,AbstractBasePtrList *)82   virtual void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *) {}
83 
84   std::string op_;
85   ComputeType compute_type_;
86 };
87 using PrimOpPtr = std::shared_ptr<PrimOp>;
88 
89 class ReshapeOp : public PrimOp {
90  public:
ReshapeOp(const std::string & op)91   explicit ReshapeOp(const std::string &op) : PrimOp(op, ComputeType::RESHAPE) {}
92   ~ReshapeOp() = default;
93   NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override;
94 
95  protected:
InferFormat(const NodePtrList &,const DAttrs & attrs)96   DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override {
97     return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
98                                                : GetValue<std::string>(attrs.find("format")->second);
99   }
100 };
101 
102 class ElemwiseOp : public PrimOp {
103  public:
ElemwiseOp(const std::string & op)104   explicit ElemwiseOp(const std::string &op) : PrimOp(op, ComputeType::ELEMWISE) {}
105   ~ElemwiseOp() = default;
106 
107  protected:
108   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
109   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) override;
110 };
111 
112 class BroadcastOp : public PrimOp {
113  public:
BroadcastOp(const std::string & op)114   explicit BroadcastOp(const std::string &op) : PrimOp(op, ComputeType::BROADCAST) {}
115   ~BroadcastOp() = default;
116 };
117 
118 class TileOp : public BroadcastOp {
119  public:
TileOp(const std::string & op)120   explicit TileOp(const std::string &op) : BroadcastOp(op) {}
121   ~TileOp() = default;
122 };
123 
124 class ReduceOp : public PrimOp {
125  public:
ReduceOp(const std::string & op)126   explicit ReduceOp(const std::string &op) : PrimOp(op, ComputeType::REDUCE) {}
127   ~ReduceOp() = default;
128 
129  protected:
InferFormat(const NodePtrList &,const DAttrs &)130   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; };
131   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
132 };
133 
134 class ArgReduceOp : public ReduceOp {
135  public:
ArgReduceOp(const std::string & op)136   explicit ArgReduceOp(const std::string &op) : ReduceOp(op) {}
137   ~ArgReduceOp() = default;
138 
139  protected:
140   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
141   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &attrs) override;
142   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
143 };
144 
145 class OpaqueOp : public PrimOp {
146  public:
OpaqueOp(const std::string & op)147   explicit OpaqueOp(const std::string &op) : PrimOp(op, ComputeType::OPAQUE) {}
148   ~OpaqueOp() = default;
149 
150  protected:
151   // for pclint warning: 1790 public base symbol of symbol has no non-destructor virtual functions
DoNothing()152   virtual void DoNothing() {}
153 };
154 
155 class VirtualOp : public PrimOp {
156  public:
VirtualOp(const std::string & op)157   explicit VirtualOp(const std::string &op) : PrimOp(op, ComputeType::VIRTUAL) {}
158   ~VirtualOp() = default;
159 };
160 
161 class TransposeOp : public OpaqueOp {
162  public:
TransposeOp(const std::string & op)163   explicit TransposeOp(const std::string &op) : OpaqueOp(op) {}
164   ~TransposeOp() = default;
165 
166  protected:
167   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
168 };
169 
170 class OneHotOp : public OpaqueOp {
171  public:
OneHotOp(const std::string & op)172   explicit OneHotOp(const std::string &op) : OpaqueOp(op) {}
173   ~OneHotOp() = default;
174 
175  protected:
176   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
177 };
178 
179 class CumSumOp : public OpaqueOp {
180  public:
CumSumOp(const std::string & op)181   explicit CumSumOp(const std::string &op) : OpaqueOp(op) {}
182   ~CumSumOp() = default;
183 
184  protected:
185   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
186 };
187 
188 class LayoutTransformOp : public OpaqueOp {
189  public:
LayoutTransformOp(const std::string & op)190   explicit LayoutTransformOp(const std::string &op) : OpaqueOp(op) {}
191   ~LayoutTransformOp() = default;
192 
193  protected:
194   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs &)195   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
InferFormat(const NodePtrList &,const DAttrs & attrs)196   DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override {
197     return GetValue<std::string>(attrs.find("dst_format")->second);
198   }
199 };
200 
201 class ElemAnyOp : public OpaqueOp {
202  public:
ElemAnyOp(const std::string & op)203   explicit ElemAnyOp(const std::string &op) : OpaqueOp(op) {}
204   ~ElemAnyOp() = default;
205 
206  protected:
InferShape(const NodePtrList &,const DAttrs & attrs)207   std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override {
208     auto iter = attrs.find("empty_shape");
209     if (iter != attrs.end() && GetValue<bool>(iter->second) == true) {
210       return {{}};
211     }
212     return {{1}};
213   }
InferType(const NodePtrList &,const DAttrs &)214   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; }
215 };
216 
217 class ShapeOp : public OpaqueOp {
218  public:
ShapeOp(const std::string & op)219   explicit ShapeOp(const std::string &op) : OpaqueOp(op) {}
220   ~ShapeOp() = default;
221   NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override;
222 
223  protected:
InferShape(const NodePtrList & inputs,const DAttrs &)224   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &) override {
225     return {{SizeToLong(inputs[0]->shape.size())}};
226   }
InferType(const NodePtrList &,const DAttrs &)227   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeInt32}; }
InferFormat(const NodePtrList &,const DAttrs &)228   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; };
229 };
230 
231 class ConstantOfShapeOp : public OpaqueOp {
232  public:
ConstantOfShapeOp(const std::string & op)233   explicit ConstantOfShapeOp(const std::string &op) : OpaqueOp(op) {}
234   ~ConstantOfShapeOp() = default;
235 
236   NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override;
237 
238  protected:
239   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList &,const DAttrs & attrs)240   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &attrs) override {
241     return {static_cast<TypeId>(GetValue<int64_t>(attrs.find("data_type")->second))};
242   }
InferFormat(const NodePtrList &,const DAttrs &)243   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }
244 };
245 
246 class PadAkgOp : public OpaqueOp {
247  public:
PadAkgOp(const std::string & op)248   explicit PadAkgOp(const std::string &op) : OpaqueOp(op) {}
249   ~PadAkgOp() = default;
250 
251  protected:
252   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs &)253   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
254 };
255 
256 class UnPadAkgOp : public OpaqueOp {
257  public:
UnPadAkgOp(const std::string & op)258   explicit UnPadAkgOp(const std::string &op) : OpaqueOp(op) {}
259   ~UnPadAkgOp() = default;
260 
261  protected:
262   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs &)263   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
264 };
265 
266 class Conv2dOp : public OpaqueOp {
267  public:
Conv2dOp(const std::string & op)268   explicit Conv2dOp(const std::string &op) : OpaqueOp(op) {}
269   ~Conv2dOp() = default;
270   static bool HadPad(const ShapeVector &pad_list, const std::string &pad_mode);
271 
272  protected:
273   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
274   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
275   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
276 };
277 
278 class GatherOp : public OpaqueOp {
279  public:
GatherOp(const std::string & op)280   explicit GatherOp(const std::string &op) : OpaqueOp(op) {}
281   ~GatherOp() = default;
282   NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override;
283 
284  protected:
285   template <typename TM>
286   tensor::TensorPtr CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const;
InferFormat(const NodePtrList &,const DAttrs &)287   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; };
288   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
289 };
290 
291 class ConcatOp : public OpaqueOp {
292  public:
ConcatOp(const std::string & op)293   explicit ConcatOp(const std::string &op) : OpaqueOp(op) {}
294   ~ConcatOp() = default;
295   NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override;
296 
297  protected:
298   template <typename TM>
299   tensor::TensorPtr CalcConcat(const NodePtrList &inputs, const DAttrs &attrs);
InferFormat(const NodePtrList &,const DAttrs &)300   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; };
301   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
302 };
303 
304 class CImagRealOp : public ElemwiseOp {
305  public:
CImagRealOp(const std::string & op)306   explicit CImagRealOp(const std::string &op) : ElemwiseOp(op) {}
307   ~CImagRealOp() = default;
308 
309  protected:
Check(const NodePtrList & inputs,const DAttrs &)310   void Check(const NodePtrList &inputs, const DAttrs &) override {
311     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
312       MS_LOG(EXCEPTION) << op_ << "'s input[0] should be complex64, but got " << TypeIdToString(inputs[0]->type, true);
313     }
314   };
315 
InferShape(const NodePtrList & inputs,const DAttrs &)316   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; }
InferType(const NodePtrList &,const DAttrs &)317   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; }
318 };
319 
320 class Pool2DOp : public OpaqueOp {
321  public:
Pool2DOp(const std::string & op)322   explicit Pool2DOp(const std::string &op) : OpaqueOp(op) {}
323   ~Pool2DOp() = default;
324 
325  protected:
326   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
InferType(const NodePtrList & inputs,const DAttrs &)327   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
328 };
329 
330 class ComplexOp : public ElemwiseOp {
331  public:
ComplexOp(const std::string & op)332   explicit ComplexOp(const std::string &op) : ElemwiseOp(op) {}
333   ~ComplexOp() = default;
334 
335  protected:
336   void Check(const NodePtrList &inputs, const DAttrs &) override;
InferShape(const NodePtrList & inputs,const DAttrs &)337   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; }
InferType(const NodePtrList &,const DAttrs &)338   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeComplex64}; }
339 };
340 
341 class StandardNormalOp : public OpaqueOp {
342  public:
StandardNormalOp(const std::string & op)343   explicit StandardNormalOp(const std::string &op) : OpaqueOp(op) {}
344   ~StandardNormalOp() = default;
345 
346  protected:
347   std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override;
InferType(const NodePtrList &,const DAttrs &)348   std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; }
InferFormat(const NodePtrList &,const DAttrs &)349   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }
350 };
351 
352 class StridedSliceOp : public OpaqueOp {
353  public:
StridedSliceOp(const std::string & op)354   explicit StridedSliceOp(const std::string &op) : OpaqueOp(op) {}
355   ~StridedSliceOp() = default;
RectifyAbstract(const PrimitivePtr & p,AbstractBasePtrList * input_abstract_ptr)356   void RectifyAbstract(const PrimitivePtr &p, AbstractBasePtrList *input_abstract_ptr) override {
357     input_abstract_ptr->push_back(p->GetAttr("begin_mask")->ToAbstract());
358     input_abstract_ptr->push_back(p->GetAttr("end_mask")->ToAbstract());
359     input_abstract_ptr->push_back(p->GetAttr("ellipsis_mask")->ToAbstract());
360     input_abstract_ptr->push_back(p->GetAttr("new_axis_mask")->ToAbstract());
361     input_abstract_ptr->push_back(p->GetAttr("shrink_axis_mask")->ToAbstract());
362   }
363 };
364 
365 class StridedSliceOnnxOp : public OpaqueOp {
366  public:
StridedSliceOnnxOp(const std::string & op)367   explicit StridedSliceOnnxOp(const std::string &op) : OpaqueOp(op) {}
368   ~StridedSliceOnnxOp() = default;
369   NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override;
370 
371  protected:
372   template <typename TM>
373   tensor::TensorPtr CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const;
InferShape(const NodePtrList &,const DAttrs & attrs)374   std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override {
375     return GetValue<std::vector<DShape>>(attrs.find("output_shape")->second);
376   }
InferType(const NodePtrList & inputs,const DAttrs &)377   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
InferFormat(const NodePtrList &,const DAttrs &)378   DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }
379 };
380 
381 class MatMulOp : public OpaqueOp {
382  public:
MatMulOp(const std::string & op)383   explicit MatMulOp(const std::string &op) : OpaqueOp(op) {}
384   ~MatMulOp() = default;
385 
386  protected:
387   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
388   std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
389   std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
390 };
391 
392 class TupleGetItemOp : public VirtualOp {
393  public:
394   using VirtualOp::VirtualOp;
395   ~TupleGetItemOp() = default;
396 };
397 
398 class PagedAttentionOp : public OpaqueOp {
399  public:
PagedAttentionOp(const std::string & op)400   explicit PagedAttentionOp(const std::string &op) : OpaqueOp(op) {}
401   ~PagedAttentionOp() = default;
402 
403  protected:
404   void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override;
405 };
406 }  // namespace mindspore::graphkernel::inner
407 #endif
408