1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 #include <utility> 23 #include "ops/primitive_c.h" 24 #include "backend/common/graph_kernel/model/node.h" 25 #include "ir/dtype/type.h" 26 #include "include/backend/visible.h" 27 28 namespace mindspore::graphkernel::inner { 29 #define CHECK_ATTR(attrs, attr_name) \ 30 do { \ 31 if (attrs.count(attr_name) == 0) { \ 32 MS_LOG(EXCEPTION) << "The attr [" << attr_name << "] does not exist in [" << #attrs << "]"; \ 33 } \ 34 } while (0) 35 36 class BACKEND_EXPORT PrimOp : public Node { 37 public: 38 enum class ComputeType : int { 39 VIRTUAL = 0, 40 RESHAPE = 1, 41 ELEMWISE = 2, 42 BROADCAST = 3, 43 REDUCE = 4, 44 OPAQUE = 5, 45 }; 46 PrimOp(const std::string & op,ComputeType compute)47 PrimOp(const std::string &op, ComputeType compute) 48 : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}), op_(op), compute_type_(compute) {} 49 ~PrimOp() = default; 50 51 NodeBaseList Infer(const NodePtrList &inputs, const DAttrs &attrs); 52 53 std::string ToString() const override; NodeType()54 NType NodeType() override { return NType::Primitive; } 55 op()56 const std::string &op() const { return op_; } compute_type()57 ComputeType compute_type() const { return compute_type_; } 58 // infer output value when all inputs are constant 59 virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs); 60 61 protected: 62 // Check node info before inference the shape/type/format. Check(const NodePtrList &,const DAttrs &)63 virtual void Check(const NodePtrList &, const DAttrs &) {} 64 65 // Infer format. assume all outputs have the same format. InferFormat(const NodePtrList & inputs,const DAttrs &)66 virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) { return inputs[0]->format; } 67 68 // Infer shape. returning an empty vector means using PrimitiveC's infer_shape function. 69 virtual std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &); 70 71 // Infer type. returning an empty vector means using PrimitiveC's infer_type function. 72 virtual std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &); 73 74 // calculate const inputs, used for InferValue 75 template <typename TM> 76 tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const DAttrs &) const; 77 78 // Gen PrimitiveC and abstract list to call PrimitiveC's inference function. 79 std::pair<PrimitivePtr, AbstractBasePtrList> GenPrimAndAbstract(const NodePtrList &inputs, const DAttrs &attrs) const; 80 81 // rectify abstract before calling PrimitiveC's inference function. RectifyAbstract(const PrimitivePtr &,AbstractBasePtrList *)82 virtual void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *) {} 83 84 std::string op_; 85 ComputeType compute_type_; 86 }; 87 using PrimOpPtr = std::shared_ptr<PrimOp>; 88 89 class ReshapeOp : public PrimOp { 90 public: ReshapeOp(const std::string & op)91 explicit ReshapeOp(const std::string &op) : PrimOp(op, ComputeType::RESHAPE) {} 92 ~ReshapeOp() = default; 93 NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override; 94 95 protected: InferFormat(const NodePtrList &,const DAttrs & attrs)96 DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override { 97 return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT 98 : GetValue<std::string>(attrs.find("format")->second); 99 } 100 }; 101 102 class ElemwiseOp : public PrimOp { 103 public: ElemwiseOp(const std::string & op)104 explicit ElemwiseOp(const std::string &op) : PrimOp(op, ComputeType::ELEMWISE) {} 105 ~ElemwiseOp() = default; 106 107 protected: 108 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 109 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) override; 110 }; 111 112 class BroadcastOp : public PrimOp { 113 public: BroadcastOp(const std::string & op)114 explicit BroadcastOp(const std::string &op) : PrimOp(op, ComputeType::BROADCAST) {} 115 ~BroadcastOp() = default; 116 }; 117 118 class TileOp : public BroadcastOp { 119 public: TileOp(const std::string & op)120 explicit TileOp(const std::string &op) : BroadcastOp(op) {} 121 ~TileOp() = default; 122 }; 123 124 class ReduceOp : public PrimOp { 125 public: ReduceOp(const std::string & op)126 explicit ReduceOp(const std::string &op) : PrimOp(op, ComputeType::REDUCE) {} 127 ~ReduceOp() = default; 128 129 protected: InferFormat(const NodePtrList &,const DAttrs &)130 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; 131 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 132 }; 133 134 class ArgReduceOp : public ReduceOp { 135 public: ArgReduceOp(const std::string & op)136 explicit ArgReduceOp(const std::string &op) : ReduceOp(op) {} 137 ~ArgReduceOp() = default; 138 139 protected: 140 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 141 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &attrs) override; 142 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 143 }; 144 145 class OpaqueOp : public PrimOp { 146 public: OpaqueOp(const std::string & op)147 explicit OpaqueOp(const std::string &op) : PrimOp(op, ComputeType::OPAQUE) {} 148 ~OpaqueOp() = default; 149 150 protected: 151 // for pclint warning: 1790 public base symbol of symbol has no non-destructor virtual functions DoNothing()152 virtual void DoNothing() {} 153 }; 154 155 class VirtualOp : public PrimOp { 156 public: VirtualOp(const std::string & op)157 explicit VirtualOp(const std::string &op) : PrimOp(op, ComputeType::VIRTUAL) {} 158 ~VirtualOp() = default; 159 }; 160 161 class TransposeOp : public OpaqueOp { 162 public: TransposeOp(const std::string & op)163 explicit TransposeOp(const std::string &op) : OpaqueOp(op) {} 164 ~TransposeOp() = default; 165 166 protected: 167 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; 168 }; 169 170 class OneHotOp : public OpaqueOp { 171 public: OneHotOp(const std::string & op)172 explicit OneHotOp(const std::string &op) : OpaqueOp(op) {} 173 ~OneHotOp() = default; 174 175 protected: 176 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 177 }; 178 179 class CumSumOp : public OpaqueOp { 180 public: CumSumOp(const std::string & op)181 explicit CumSumOp(const std::string &op) : OpaqueOp(op) {} 182 ~CumSumOp() = default; 183 184 protected: 185 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 186 }; 187 188 class LayoutTransformOp : public OpaqueOp { 189 public: LayoutTransformOp(const std::string & op)190 explicit LayoutTransformOp(const std::string &op) : OpaqueOp(op) {} 191 ~LayoutTransformOp() = default; 192 193 protected: 194 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs &)195 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } InferFormat(const NodePtrList &,const DAttrs & attrs)196 DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override { 197 return GetValue<std::string>(attrs.find("dst_format")->second); 198 } 199 }; 200 201 class ElemAnyOp : public OpaqueOp { 202 public: ElemAnyOp(const std::string & op)203 explicit ElemAnyOp(const std::string &op) : OpaqueOp(op) {} 204 ~ElemAnyOp() = default; 205 206 protected: InferShape(const NodePtrList &,const DAttrs & attrs)207 std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override { 208 auto iter = attrs.find("empty_shape"); 209 if (iter != attrs.end() && GetValue<bool>(iter->second) == true) { 210 return {{}}; 211 } 212 return {{1}}; 213 } InferType(const NodePtrList &,const DAttrs &)214 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } 215 }; 216 217 class ShapeOp : public OpaqueOp { 218 public: ShapeOp(const std::string & op)219 explicit ShapeOp(const std::string &op) : OpaqueOp(op) {} 220 ~ShapeOp() = default; 221 NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override; 222 223 protected: InferShape(const NodePtrList & inputs,const DAttrs &)224 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &) override { 225 return {{SizeToLong(inputs[0]->shape.size())}}; 226 } InferType(const NodePtrList &,const DAttrs &)227 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeInt32}; } InferFormat(const NodePtrList &,const DAttrs &)228 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; 229 }; 230 231 class ConstantOfShapeOp : public OpaqueOp { 232 public: ConstantOfShapeOp(const std::string & op)233 explicit ConstantOfShapeOp(const std::string &op) : OpaqueOp(op) {} 234 ~ConstantOfShapeOp() = default; 235 236 NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; 237 238 protected: 239 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList &,const DAttrs & attrs)240 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &attrs) override { 241 return {static_cast<TypeId>(GetValue<int64_t>(attrs.find("data_type")->second))}; 242 } InferFormat(const NodePtrList &,const DAttrs &)243 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } 244 }; 245 246 class PadAkgOp : public OpaqueOp { 247 public: PadAkgOp(const std::string & op)248 explicit PadAkgOp(const std::string &op) : OpaqueOp(op) {} 249 ~PadAkgOp() = default; 250 251 protected: 252 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs &)253 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } 254 }; 255 256 class UnPadAkgOp : public OpaqueOp { 257 public: UnPadAkgOp(const std::string & op)258 explicit UnPadAkgOp(const std::string &op) : OpaqueOp(op) {} 259 ~UnPadAkgOp() = default; 260 261 protected: 262 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs &)263 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } 264 }; 265 266 class Conv2dOp : public OpaqueOp { 267 public: Conv2dOp(const std::string & op)268 explicit Conv2dOp(const std::string &op) : OpaqueOp(op) {} 269 ~Conv2dOp() = default; 270 static bool HadPad(const ShapeVector &pad_list, const std::string &pad_mode); 271 272 protected: 273 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 274 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &attrs) override; 275 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; 276 }; 277 278 class GatherOp : public OpaqueOp { 279 public: GatherOp(const std::string & op)280 explicit GatherOp(const std::string &op) : OpaqueOp(op) {} 281 ~GatherOp() = default; 282 NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; 283 284 protected: 285 template <typename TM> 286 tensor::TensorPtr CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const; InferFormat(const NodePtrList &,const DAttrs &)287 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; 288 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 289 }; 290 291 class ConcatOp : public OpaqueOp { 292 public: ConcatOp(const std::string & op)293 explicit ConcatOp(const std::string &op) : OpaqueOp(op) {} 294 ~ConcatOp() = default; 295 NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; 296 297 protected: 298 template <typename TM> 299 tensor::TensorPtr CalcConcat(const NodePtrList &inputs, const DAttrs &attrs); InferFormat(const NodePtrList &,const DAttrs &)300 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; 301 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 302 }; 303 304 class CImagRealOp : public ElemwiseOp { 305 public: CImagRealOp(const std::string & op)306 explicit CImagRealOp(const std::string &op) : ElemwiseOp(op) {} 307 ~CImagRealOp() = default; 308 309 protected: Check(const NodePtrList & inputs,const DAttrs &)310 void Check(const NodePtrList &inputs, const DAttrs &) override { 311 if (inputs[0]->type != TypeId::kNumberTypeComplex64) { 312 MS_LOG(EXCEPTION) << op_ << "'s input[0] should be complex64, but got " << TypeIdToString(inputs[0]->type, true); 313 } 314 }; 315 InferShape(const NodePtrList & inputs,const DAttrs &)316 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; } InferType(const NodePtrList &,const DAttrs &)317 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } 318 }; 319 320 class Pool2DOp : public OpaqueOp { 321 public: Pool2DOp(const std::string & op)322 explicit Pool2DOp(const std::string &op) : OpaqueOp(op) {} 323 ~Pool2DOp() = default; 324 325 protected: 326 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs &)327 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } 328 }; 329 330 class ComplexOp : public ElemwiseOp { 331 public: ComplexOp(const std::string & op)332 explicit ComplexOp(const std::string &op) : ElemwiseOp(op) {} 333 ~ComplexOp() = default; 334 335 protected: 336 void Check(const NodePtrList &inputs, const DAttrs &) override; InferShape(const NodePtrList & inputs,const DAttrs &)337 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; } InferType(const NodePtrList &,const DAttrs &)338 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeComplex64}; } 339 }; 340 341 class StandardNormalOp : public OpaqueOp { 342 public: StandardNormalOp(const std::string & op)343 explicit StandardNormalOp(const std::string &op) : OpaqueOp(op) {} 344 ~StandardNormalOp() = default; 345 346 protected: 347 std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override; InferType(const NodePtrList &,const DAttrs &)348 std::vector<TypeId> InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } InferFormat(const NodePtrList &,const DAttrs &)349 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } 350 }; 351 352 class StridedSliceOp : public OpaqueOp { 353 public: StridedSliceOp(const std::string & op)354 explicit StridedSliceOp(const std::string &op) : OpaqueOp(op) {} 355 ~StridedSliceOp() = default; RectifyAbstract(const PrimitivePtr & p,AbstractBasePtrList * input_abstract_ptr)356 void RectifyAbstract(const PrimitivePtr &p, AbstractBasePtrList *input_abstract_ptr) override { 357 input_abstract_ptr->push_back(p->GetAttr("begin_mask")->ToAbstract()); 358 input_abstract_ptr->push_back(p->GetAttr("end_mask")->ToAbstract()); 359 input_abstract_ptr->push_back(p->GetAttr("ellipsis_mask")->ToAbstract()); 360 input_abstract_ptr->push_back(p->GetAttr("new_axis_mask")->ToAbstract()); 361 input_abstract_ptr->push_back(p->GetAttr("shrink_axis_mask")->ToAbstract()); 362 } 363 }; 364 365 class StridedSliceOnnxOp : public OpaqueOp { 366 public: StridedSliceOnnxOp(const std::string & op)367 explicit StridedSliceOnnxOp(const std::string &op) : OpaqueOp(op) {} 368 ~StridedSliceOnnxOp() = default; 369 NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; 370 371 protected: 372 template <typename TM> 373 tensor::TensorPtr CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const; InferShape(const NodePtrList &,const DAttrs & attrs)374 std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override { 375 return GetValue<std::vector<DShape>>(attrs.find("output_shape")->second); 376 } InferType(const NodePtrList & inputs,const DAttrs &)377 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } InferFormat(const NodePtrList &,const DAttrs &)378 DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } 379 }; 380 381 class MatMulOp : public OpaqueOp { 382 public: MatMulOp(const std::string & op)383 explicit MatMulOp(const std::string &op) : OpaqueOp(op) {} 384 ~MatMulOp() = default; 385 386 protected: 387 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 388 std::vector<DShape> InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 389 std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &attrs) override; 390 }; 391 392 class TupleGetItemOp : public VirtualOp { 393 public: 394 using VirtualOp::VirtualOp; 395 ~TupleGetItemOp() = default; 396 }; 397 398 class PagedAttentionOp : public OpaqueOp { 399 public: PagedAttentionOp(const std::string & op)400 explicit PagedAttentionOp(const std::string &op) : OpaqueOp(op) {} 401 ~PagedAttentionOp() = default; 402 403 protected: 404 void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; 405 }; 406 } // namespace mindspore::graphkernel::inner 407 #endif 408