1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ 18 19 #include <memory> 20 #include <algorithm> 21 #include <sstream> 22 #include <string> 23 #include <unordered_map> 24 #include <functional> 25 26 #include "backend/optimizer/graph_kernel/model/node.h" 27 #include "backend/kernel_compiler/common_utils.h" 28 #include "ir/dtype/type.h" 29 30 namespace mindspore { 31 namespace opt { 32 namespace graphkernel { 33 #define CHECK_ATTR(attrs, attr_name) \ 34 do { \ 35 if (attrs.count(attr_name) == 0) { \ 36 MS_LOG(EXCEPTION) << "The attr [" << attr_name << "] does not exist in [" << #attrs << "]"; \ 37 } \ 38 } while (0) 39 40 class PrimOp : public Node { 41 public: 42 enum ComputeType { 43 RESHAPE, 44 ELEMWISE, 45 BROADCAST, 46 REDUCE, 47 OPAQUE, 48 }; 49 PrimOp(const std::string & op,const std::string & node_name,ComputeType compute)50 PrimOp(const std::string &op, const std::string &node_name, ComputeType compute) 51 : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {} 52 ~PrimOp() = default; 53 54 virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs); 55 virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op); 56 57 void Dump(std::ostringstream &os) const override; NodeType()58 NType NodeType() override { return NType::Primitive; } 59 op()60 const std::string &op() const { return op_; } compute_type()61 ComputeType compute_type() const { return compute_type_; } 62 63 protected: 64 virtual void Check(const NodePtrList &inputs, const DAttrs &attrs); CheckShape(const NodePtrList & inputs,const DAttrs & attrs)65 virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {} 66 virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs); 67 virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs); 68 InferShape(const NodePtrList & inputs,const DAttrs & attrs)69 virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; } InferType(const NodePtrList & inputs,const DAttrs & attrs)70 virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; } InferFormat(const NodePtrList & inputs,const DAttrs & attrs)71 virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; } 72 73 std::string op_; 74 ComputeType compute_type_; 75 }; 76 using PrimOpPtr = std::shared_ptr<PrimOp>; 77 78 class ElemwiseOp : public PrimOp { 79 public: ElemwiseOp(const std::string & op,const std::string & node_name)80 ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {} 81 ~ElemwiseOp() = default; 82 83 NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override; 84 85 protected: 86 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 87 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; 88 }; 89 90 class CastOp : public ElemwiseOp { 91 public: CastOp(const std::string & op,const std::string & node_name)92 CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {} 93 ~CastOp() = default; 94 95 protected: 96 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; 97 }; 98 99 class InplaceAssignOp : public ElemwiseOp { 100 public: InplaceAssignOp(const std::string & op,const std::string & node_name)101 InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {} 102 ~InplaceAssignOp() = default; 103 104 protected: InferShape(const NodePtrList & inputs,const DAttrs & attrs)105 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; } InferType(const NodePtrList & inputs,const DAttrs & attrs)106 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->type; } InferFormat(const NodePtrList & inputs,const DAttrs & attrs)107 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; } 108 }; 109 110 class SelectOp : public ElemwiseOp { 111 public: SelectOp(const std::string & op,const std::string & node_name)112 SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {} 113 ~SelectOp() = default; 114 115 protected: 116 void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs & attrs)117 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; } 118 }; 119 120 class CompareOp : public ElemwiseOp { 121 public: CompareOp(const std::string & op,const std::string & node_name)122 CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {} 123 ~CompareOp() = default; 124 125 protected: InferType(const NodePtrList & inputs,const DAttrs & attrs)126 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; } 127 }; 128 129 class LessOp : public CompareOp { 130 public: LessOp(const std::string & op,const std::string & node_name)131 LessOp(const std::string &op, const std::string &node_name) : CompareOp("Less", node_name) {} 132 ~LessOp() = default; 133 }; 134 135 class EqualOp : public CompareOp { 136 public: EqualOp(const std::string & op,const std::string & node_name)137 EqualOp(const std::string &op, const std::string &node_name) : CompareOp("Equal", node_name) {} 138 ~EqualOp() = default; 139 }; 140 141 class LessEqualOp : public CompareOp { 142 public: LessEqualOp(const std::string & op,const std::string & node_name)143 LessEqualOp(const std::string &op, const std::string &node_name) : CompareOp("LessEqual", node_name) {} 144 ~LessEqualOp() = default; 145 }; 146 147 class GreaterOp : public CompareOp { 148 public: GreaterOp(const std::string & op,const std::string & node_name)149 GreaterOp(const std::string &op, const std::string &node_name) : CompareOp("Greater", node_name) {} 150 ~GreaterOp() = default; 151 }; 152 153 class GreaterEqualOp : public CompareOp { 154 public: GreaterEqualOp(const std::string & op,const std::string & node_name)155 GreaterEqualOp(const std::string &op, const std::string &node_name) : CompareOp("GreaterEqual", node_name) {} 156 ~GreaterEqualOp() = default; 157 }; 158 159 class ReshapeOp : public PrimOp { 160 public: ReshapeOp(const std::string & op,const std::string & node_name)161 ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {} 162 ~ReshapeOp() = default; 163 164 protected: 165 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferFormat(const NodePtrList & inputs,const DAttrs & attrs)166 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { 167 return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT 168 : GetValue<std::string>(attrs.find("format")->second); 169 } 170 }; 171 172 class BroadcastToOp : public PrimOp { 173 public: BroadcastToOp(const std::string & op,const std::string & node_name)174 BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {} 175 ~BroadcastToOp() = default; 176 177 protected: 178 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 179 }; 180 181 class ReduceOp : public PrimOp { 182 public: ReduceOp(const std::string & op,const std::string & node_name)183 ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {} 184 ~ReduceOp() = default; 185 186 protected: 187 void Check(const NodePtrList &inputs, const DAttrs &attrs) override; 188 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferFormat(const NodePtrList & inputs,const DAttrs & attrs)189 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }; 190 }; 191 192 class OpaqueOp : public PrimOp { 193 public: OpaqueOp(const std::string & op,const std::string & node_name)194 OpaqueOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, OPAQUE) {} 195 ~OpaqueOp() = default; 196 }; 197 198 class Conv2dOp : public OpaqueOp { 199 public: Conv2dOp(const std::string & op,const std::string & node_name)200 Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {} 201 ~Conv2dOp() = default; 202 203 protected: 204 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 205 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; 206 }; 207 208 class TransposeOp : public OpaqueOp { 209 public: TransposeOp(const std::string & op,const std::string & node_name)210 TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {} 211 ~TransposeOp() = default; 212 213 protected: 214 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 215 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; 216 }; 217 218 class MatMulOp : public OpaqueOp { 219 public: MatMulOp(const std::string & op,const std::string & node_name)220 MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {} 221 ~MatMulOp() = default; 222 223 protected: 224 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 225 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; 226 }; 227 228 class PadAkgOp : public OpaqueOp { 229 public: PadAkgOp(const std::string & op,const std::string & node_name)230 PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {} 231 ~PadAkgOp() = default; 232 233 protected: 234 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 235 }; 236 237 class UnPadAkgOp : public OpaqueOp { 238 public: UnPadAkgOp(const std::string & op,const std::string & node_name)239 UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {} 240 ~UnPadAkgOp() = default; 241 242 protected: 243 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; 244 }; 245 246 class CImagOp : public ElemwiseOp { 247 public: CImagOp(const std::string & op,const std::string & node_name)248 CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {} 249 ~CImagOp() = default; 250 251 protected: CheckType(const NodePtrList & inputs,const DAttrs & attrs)252 void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { 253 if (inputs[0]->type != TypeId::kNumberTypeComplex64) { 254 MS_LOG(EXCEPTION) << "CImag's input[0] should be complex64"; 255 } 256 }; 257 InferType(const NodePtrList & inputs,const DAttrs & attrs)258 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; } 259 }; 260 261 class CRealOp : public ElemwiseOp { 262 public: CRealOp(const std::string & op,const std::string & node_name)263 CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {} 264 ~CRealOp() = default; 265 266 protected: CheckType(const NodePtrList & inputs,const DAttrs & attrs)267 void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { 268 if (inputs[0]->type != TypeId::kNumberTypeComplex64) { 269 MS_LOG(EXCEPTION) << "CReal's input[0] should be complex64"; 270 } 271 }; 272 InferType(const NodePtrList & inputs,const DAttrs & attrs)273 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; } 274 }; 275 276 class ComplexOp : public ElemwiseOp { 277 public: ComplexOp(const std::string & op,const std::string & node_name)278 ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {} 279 ~ComplexOp() = default; 280 281 protected: 282 void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs & attrs)283 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; } 284 }; 285 286 class StandardNormalOp : public OpaqueOp { 287 public: StandardNormalOp(const std::string & op,const std::string & node_name)288 StandardNormalOp(const std::string &op, const std::string &node_name) : OpaqueOp("StandardNormal", node_name) {} 289 ~StandardNormalOp() = default; 290 291 protected: 292 DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; InferType(const NodePtrList & inputs,const DAttrs & attrs)293 TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; } InferFormat(const NodePtrList & inputs,const DAttrs & attrs)294 DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; } 295 }; 296 } // namespace graphkernel 297 } // namespace opt 298 } // namespace mindspore 299 #endif 300