1 /** 2 * Copyright 2019-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ 19 20 #include <string> 21 #include <memory> 22 #include <utility> 23 #include <vector> 24 #include <sstream> 25 #include <map> 26 27 #include "utils/hash_map.h" 28 #include "transform/graph_ir/transform_util.h" 29 #include "ir/anf.h" 30 #include "ir/primitive.h" 31 #include "ir/value.h" 32 #include "graph/operator_reg.h" 33 #include "ge/ge_api.h" 34 #include "graph/tensor.h" 35 #include "graph/types.h" 36 #include "mindapi/base/format.h" 37 38 namespace ge { 39 class CustomOperator : public Operator { 40 public: CustomOperator(const string & name,const string & type)41 CustomOperator(const string &name, const string &type) : Operator(name, type) {} 42 ~CustomOperator()43 ~CustomOperator() override{}; 44 CustomInputRegister(const string & name)45 void CustomInputRegister(const string &name) { Operator::InputRegister(name); } 46 CustomOutputRegister(const string & name)47 void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } 48 CustomRequiredAttrRegister(const string & name)49 void CustomRequiredAttrRegister(const string &name) { Operator::RequiredAttrRegister(name); } 50 CustomInferFuncRegister(const std::function<graphStatus (Operator &)> & func)51 void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) { 52 Operator::InferFuncRegister(func); 53 } 54 }; 55 } // namespace ge 56 57 namespace mindspore { 58 namespace transform { 59 using CusOperatorPtr = std::shared_ptr<::ge::CustomOperator>; 60 using CustomOperator = ::ge::CustomOperator; 61 using AttrFunc = std::function<void(OperatorPtr, ValuePtr)>; 62 using GetAttrFunc = std::function<void(OperatorPtr, ValuePtr *)>; 63 using OutputFunc = std::function<OutHandler(OperatorPtr)>; 64 using InputOpFunc = std::function<void(OperatorPtr, OperatorPtr)>; 65 using InputHandleFunc = std::function<void(OperatorPtr, OutHandler)>; 66 using CreateDynInputOpFunc = std::function<void(OperatorPtr, unsigned int)>; 67 using CreateDynInputOpByIndexFunc = std::function<void(OperatorPtr, unsigned int, size_t)>; 68 using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr)>; 69 using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>; 70 using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>; 71 using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>; 72 using UpdateDynOutputDescFunc = std::function<void(OperatorPtr, unsigned int, GeTensorDesc)>; 73 using SubGraphFunc = std::function<void(OperatorPtr, DfGraphPtr)>; 74 using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>; 75 76 using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>; 77 78 struct AttrDesc { 79 std::string name; 80 AttrFunc set_attr; 81 GetAttrFunc get_attr; 82 enum { 83 REQUIRED = 0, 84 OPTIONAL = 1, 85 DEFAULT = OPTIONAL, 86 } type = DEFAULT; 87 }; 88 89 struct InputDesc { 90 std::string name; 91 size_t index; 92 InputOpFunc set_op; 93 InputHandleFunc set_handle; 94 UpdateOutputDescFunc update_input_desc; 95 enum { 96 REQUIRED = 0, 97 OPTIONAL = 1, 98 DEFAULT = REQUIRED, 99 } type = DEFAULT; 100 std::vector<enum ::ge::DataType> supported_dtypes; 101 }; 102 103 struct DynInputDesc { 104 std::string name; 105 size_t index; 106 CreateDynInputOpFunc create_dyn_input; 107 CreateDynInputOpByIndexFunc create_dyn_input_by_index; 108 DynInputOpFunc set_op; 109 DynInputHandleFunc set_handle; 110 std::vector<enum ::ge::DataType> supported_dtypes; 111 }; 112 113 struct SubGraphDesc { 114 std::string name; 115 SubGraphFunc set_subgraph; 116 }; 117 118 struct DynSubGraphDesc { 119 std::string name; 120 CreateDynSubGraphFunc create_dyn_subgraph; 121 DynSubGraphFunc set_subgraph; 122 }; 123 124 struct OutputDesc { 125 std::string name; 126 size_t index; 127 UpdateOutputDescFunc update_out_desc; 128 std::vector<enum ::ge::DataType> supported_dtypes; 129 }; 130 131 struct DynOutputDesc { 132 std::string name; 133 size_t index; 134 CreateDynOutputOpFunc create_dyn_output; 135 UpdateDynOutputDescFunc update_dyn_output_desc; 136 std::vector<enum ::ge::DataType> supported_dtypes; 137 }; 138 139 class BaseOpAdapter { 140 public: ~BaseOpAdapter()141 virtual ~BaseOpAdapter() {} 142 virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; generate(const std::string & type)143 virtual OperatorPtr generate(const std::string &type) const { return std::make_shared<::ge::Operator>(type); } generateDynOutputOp(const AnfNodePtr & anf)144 virtual OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) { return nullptr; } setDynamicOutputNum(const OperatorPtr & op,size_t dyn_output_size)145 virtual void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) { return; } 146 virtual void setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) = 0; 147 virtual void setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) = 0; 148 virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; 149 virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; 150 virtual int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec, 151 bool use_create_byindex_func = false, size_t dyn_index = 0) = 0; 152 virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; 153 virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; 154 virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; 155 virtual int setAttr(const std::string &attrKey, const ValuePtr &attrValue) = 0; 156 virtual int setAttr(const uint32_t &input_idx, const ValuePtr &attrValue) = 0; 157 virtual int getAttr(const std::string &attrKey, ValuePtr *attrValue) = 0; 158 virtual int getAttr(const uint32_t &input_idx, ValuePtr *attrValue) = 0; 159 virtual mindspore::HashMap<std::string, ValuePtr> GetExtraAttr() = 0; 160 template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type> setAttr(const OperatorPtr & op,const std::string & attrKey,const std::shared_ptr<T> & attrValue)161 int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr<T> &attrValue) { 162 return setAttr(op, attrKey, MakeValue(attrValue)); 163 } 164 template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type> setAttr(const OperatorPtr & op,const std::string & attrKey,const T & attrValue)165 int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { 166 return setAttr(op, attrKey, MakeValue(attrValue)); 167 } 168 virtual std::string getOpType() = 0; 169 virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; 170 virtual std::vector<OutHandler> getOutputs(const OperatorPtr &op) = 0; 171 virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 172 const AnfNodePtr &node) = 0; 173 virtual const mindspore::HashMap<int, InputDesc> &getInputMap() = 0; 174 virtual const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() = 0; 175 virtual const mindspore::HashMap<std::string, AttrDesc> &getAttrMap() = 0; 176 virtual const mindspore::HashMap<std::string, std::string> &getAttrInputMap() = 0; 177 virtual const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() = 0; 178 virtual const std::map<int, OutputDesc> &getOutputMap() = 0; 179 virtual const mindspore::HashMap<int, DynOutputDesc> &getDynOutputMap() = 0; 180 virtual const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() = 0; 181 virtual const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() = 0; 182 virtual std::map<std::string, ValuePtr> GetNormalOpAttrList(const AnfNodePtr &node) = 0; 183 virtual std::map<std::string, ValuePtr> GetOpAttrList() = 0; 184 virtual bool IsDynInputOp(uint64_t index) = 0; 185 virtual bool IsDyOutputOp(uint64_t index) = 0; 186 virtual bool IsMultipleOutputOp(const AnfNodePtr &anf) = 0; 187 virtual bool GetDynamicShapeSupport() = 0; AddAttrToDrawGraph(const std::string & attr_str)188 void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } GetAttrsFromDrawGraph()189 const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; } clearAttrVect()190 void clearAttrVect() { attrs_vec_.clear(); } 191 192 private: 193 std::vector<std::string> attrs_vec_; 194 }; 195 196 using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>; 197 198 enum AttrType { 199 ATTR_INT = 0, 200 ATTR_FLOAT, 201 ATTR_DOUBLE, 202 ATTR_STRING, 203 ATTR_TENSOR, 204 ATTR_BOOL, 205 ATTR_LIST_INT, 206 ATTR_LIST_ANY_INT, 207 ATTR_ENUM 208 }; 209 210 struct GeEnum {}; 211 struct TFType {}; 212 struct GEType {}; 213 struct GEEnumToStr {}; 214 215 class GEDataFormat { 216 public: ConvertEnumToString(int64_t id)217 static std::string ConvertEnumToString(int64_t id) { 218 const auto &enum_string = FormatEnumToString(static_cast<mindspore::Format>(id)); 219 if (enum_string.empty()) { 220 MS_LOG(EXCEPTION) << "Invalid data format " << id; 221 } 222 return enum_string; 223 } 224 }; 225 226 class AscendQuantRoundMode { 227 public: ConvertEnumToString(int64_t id)228 static std::string ConvertEnumToString(int64_t id) { 229 static const std::vector<std::string> round_mode = {"round", "trunc", "floor", "ceil"}; 230 if (id < 0 || id >= static_cast<int64_t>(round_mode.size())) { 231 MS_LOG(EXCEPTION) << "Invalid AscendQuant round mode " << id; 232 return ""; 233 } 234 return round_mode[id]; 235 } 236 }; 237 238 class FASInputLayoutMode { 239 public: ConvertEnumToString(int64_t id)240 static std::string ConvertEnumToString(int64_t id) { 241 static const std::vector<std::string> input_layout_modes = {"BSH", "BNSD", "SBH", "BSND", "TND"}; 242 if (id < 0 || id >= static_cast<int64_t>(input_layout_modes.size())) { 243 MS_LOG(EXCEPTION) << "Invalid input layout mode " << id; 244 return ""; 245 } 246 return input_layout_modes[id]; 247 } 248 }; 249 250 class FFNActivationMode { 251 public: ConvertEnumToString(int64_t id)252 static std::string ConvertEnumToString(int64_t id) { 253 static const std::vector<std::string> activation_mode = { 254 "no_activation", "relu", "sigmoid", "relu6", "elu", "leaky_relu", "abs", "relu1", "softsign", 255 "softplus", "tanh", "selu", "hswish", "hsigmoid", "thresholdrelu", "linear", "hard_tanh", "sign", 256 "swish", "gelu", "glu", "unknown", "fastgelu", "silu", "geglu", "swiglu", "reglu"}; 257 if (id < 0 || id >= static_cast<int64_t>(activation_mode.size())) { 258 MS_LOG(EXCEPTION) << "Invalid moe ffn activation " << id; 259 return ""; 260 } 261 return activation_mode[id]; 262 } 263 }; 264 265 class ScatterReduceMode { 266 public: ConvertEnumToString(int64_t id)267 static std::string ConvertEnumToString(int64_t id) { 268 static const std::vector<std::string> reduce_mode = {"sum", "mean", "none", "update"}; 269 if (id < 0 || id >= static_cast<int64_t>(reduce_mode.size())) { 270 MS_LOG(EXCEPTION) << "Invalid reduce mode " << id; 271 return ""; 272 } 273 return reduce_mode[id]; 274 } 275 }; 276 277 class GEPadMod { 278 public: ConvertEnumToString(int64_t id)279 static std::string ConvertEnumToString(int64_t id) { 280 static const std::vector<std::string> pad_mods = {"PAD", "SAME", "VALID"}; 281 if (id < 0 || id >= static_cast<int64_t>(pad_mods.size())) { 282 MS_LOG(EXCEPTION) << "Invalid pad mod " << id; 283 return ""; 284 } 285 return pad_mods[id]; 286 } 287 }; 288 289 class GEReduction { 290 public: ConvertEnumToString(int64_t id)291 static std::string ConvertEnumToString(int64_t id) { 292 static const std::vector<std::string> reductions = {"sum", "mean", "none"}; 293 if (id < 0 || id >= static_cast<int64_t>(reductions.size())) { 294 MS_LOG(EXCEPTION) << "Invalid reduction " << id; 295 return ""; 296 } 297 return reductions[id]; 298 } 299 }; 300 301 class GECoordinateTransformMode { 302 public: ConvertEnumToString(int64_t id)303 static std::string ConvertEnumToString(int64_t id) { 304 static const std::vector<std::string> modes = {"asymmetric", "align_corners", "half_pixel", "crop_and_resize"}; 305 if (id < 0 || id >= static_cast<int64_t>(modes.size())) { 306 MS_LOG(EXCEPTION) << "Invalid CoordinateTransformMode " << id; 307 return ""; 308 } 309 return modes[id]; 310 } 311 }; 312 313 // declare Any type 314 template <typename T> 315 struct AnyTraits { 316 using type = T; 317 }; 318 319 template <> 320 struct AnyTraits<int> { 321 using type = int64_t; 322 }; 323 324 using ExtraAttr = mindspore::HashMap<std::string, ValuePtr>; 325 } // namespace transform 326 } // namespace mindspore 327 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ 328