• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_
19 
20 #include <string>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 #include <sstream>
25 #include <map>
26 
27 #include "utils/hash_map.h"
28 #include "transform/graph_ir/transform_util.h"
29 #include "ir/anf.h"
30 #include "ir/primitive.h"
31 #include "ir/value.h"
32 #include "graph/operator_reg.h"
33 #include "ge/ge_api.h"
34 #include "graph/tensor.h"
35 #include "graph/types.h"
36 #include "mindapi/base/format.h"
37 
38 namespace ge {
39 class CustomOperator : public Operator {
40  public:
CustomOperator(const string & name,const string & type)41   CustomOperator(const string &name, const string &type) : Operator(name, type) {}
42 
~CustomOperator()43   ~CustomOperator() override{};
44 
CustomInputRegister(const string & name)45   void CustomInputRegister(const string &name) { Operator::InputRegister(name); }
46 
CustomOutputRegister(const string & name)47   void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); }
48 
CustomRequiredAttrRegister(const string & name)49   void CustomRequiredAttrRegister(const string &name) { Operator::RequiredAttrRegister(name); }
50 
CustomInferFuncRegister(const std::function<graphStatus (Operator &)> & func)51   void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
52     Operator::InferFuncRegister(func);
53   }
54 };
55 }  // namespace ge
56 
57 namespace mindspore {
58 namespace transform {
59 using CusOperatorPtr = std::shared_ptr<::ge::CustomOperator>;
60 using CustomOperator = ::ge::CustomOperator;
61 using AttrFunc = std::function<void(OperatorPtr, ValuePtr)>;
62 using GetAttrFunc = std::function<void(OperatorPtr, ValuePtr *)>;
63 using OutputFunc = std::function<OutHandler(OperatorPtr)>;
64 using InputOpFunc = std::function<void(OperatorPtr, OperatorPtr)>;
65 using InputHandleFunc = std::function<void(OperatorPtr, OutHandler)>;
66 using CreateDynInputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
67 using CreateDynInputOpByIndexFunc = std::function<void(OperatorPtr, unsigned int, size_t)>;
68 using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr)>;
69 using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
70 using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
71 using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
72 using UpdateDynOutputDescFunc = std::function<void(OperatorPtr, unsigned int, GeTensorDesc)>;
73 using SubGraphFunc = std::function<void(OperatorPtr, DfGraphPtr)>;
74 using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>;
75 
76 using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>;
77 
78 struct AttrDesc {
79   std::string name;
80   AttrFunc set_attr;
81   GetAttrFunc get_attr;
82   enum {
83     REQUIRED = 0,
84     OPTIONAL = 1,
85     DEFAULT = OPTIONAL,
86   } type = DEFAULT;
87 };
88 
89 struct InputDesc {
90   std::string name;
91   size_t index;
92   InputOpFunc set_op;
93   InputHandleFunc set_handle;
94   UpdateOutputDescFunc update_input_desc;
95   enum {
96     REQUIRED = 0,
97     OPTIONAL = 1,
98     DEFAULT = REQUIRED,
99   } type = DEFAULT;
100   std::vector<enum ::ge::DataType> supported_dtypes;
101 };
102 
103 struct DynInputDesc {
104   std::string name;
105   size_t index;
106   CreateDynInputOpFunc create_dyn_input;
107   CreateDynInputOpByIndexFunc create_dyn_input_by_index;
108   DynInputOpFunc set_op;
109   DynInputHandleFunc set_handle;
110   std::vector<enum ::ge::DataType> supported_dtypes;
111 };
112 
113 struct SubGraphDesc {
114   std::string name;
115   SubGraphFunc set_subgraph;
116 };
117 
118 struct DynSubGraphDesc {
119   std::string name;
120   CreateDynSubGraphFunc create_dyn_subgraph;
121   DynSubGraphFunc set_subgraph;
122 };
123 
124 struct OutputDesc {
125   std::string name;
126   size_t index;
127   UpdateOutputDescFunc update_out_desc;
128   std::vector<enum ::ge::DataType> supported_dtypes;
129 };
130 
131 struct DynOutputDesc {
132   std::string name;
133   size_t index;
134   CreateDynOutputOpFunc create_dyn_output;
135   UpdateDynOutputDescFunc update_dyn_output_desc;
136   std::vector<enum ::ge::DataType> supported_dtypes;
137 };
138 
139 class BaseOpAdapter {
140  public:
~BaseOpAdapter()141   virtual ~BaseOpAdapter() {}
142   virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
generate(const std::string & type)143   virtual OperatorPtr generate(const std::string &type) const { return std::make_shared<::ge::Operator>(type); }
generateDynOutputOp(const AnfNodePtr & anf)144   virtual OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) { return nullptr; }
setDynamicOutputNum(const OperatorPtr & op,size_t dyn_output_size)145   virtual void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) { return; }
146   virtual void setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) = 0;
147   virtual void setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) = 0;
148   virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
149   virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
150   virtual int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec,
151                        bool use_create_byindex_func = false, size_t dyn_index = 0) = 0;
152   virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0;
153   virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0;
154   virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0;
155   virtual int setAttr(const std::string &attrKey, const ValuePtr &attrValue) = 0;
156   virtual int setAttr(const uint32_t &input_idx, const ValuePtr &attrValue) = 0;
157   virtual int getAttr(const std::string &attrKey, ValuePtr *attrValue) = 0;
158   virtual int getAttr(const uint32_t &input_idx, ValuePtr *attrValue) = 0;
159   virtual mindspore::HashMap<std::string, ValuePtr> GetExtraAttr() = 0;
160   template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
setAttr(const OperatorPtr & op,const std::string & attrKey,const std::shared_ptr<T> & attrValue)161   int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr<T> &attrValue) {
162     return setAttr(op, attrKey, MakeValue(attrValue));
163   }
164   template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
setAttr(const OperatorPtr & op,const std::string & attrKey,const T & attrValue)165   int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) {
166     return setAttr(op, attrKey, MakeValue(attrValue));
167   }
168   virtual std::string getOpType() = 0;
169   virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0;
170   virtual std::vector<OutHandler> getOutputs(const OperatorPtr &op) = 0;
171   virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
172                                 const AnfNodePtr &node) = 0;
173   virtual const mindspore::HashMap<int, InputDesc> &getInputMap() = 0;
174   virtual const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() = 0;
175   virtual const mindspore::HashMap<std::string, AttrDesc> &getAttrMap() = 0;
176   virtual const mindspore::HashMap<std::string, std::string> &getAttrInputMap() = 0;
177   virtual const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() = 0;
178   virtual const std::map<int, OutputDesc> &getOutputMap() = 0;
179   virtual const mindspore::HashMap<int, DynOutputDesc> &getDynOutputMap() = 0;
180   virtual const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() = 0;
181   virtual const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
182   virtual std::map<std::string, ValuePtr> GetNormalOpAttrList(const AnfNodePtr &node) = 0;
183   virtual std::map<std::string, ValuePtr> GetOpAttrList() = 0;
184   virtual bool IsDynInputOp(uint64_t index) = 0;
185   virtual bool IsDyOutputOp(uint64_t index) = 0;
186   virtual bool IsMultipleOutputOp(const AnfNodePtr &anf) = 0;
187   virtual bool GetDynamicShapeSupport() = 0;
AddAttrToDrawGraph(const std::string & attr_str)188   void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
GetAttrsFromDrawGraph()189   const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
clearAttrVect()190   void clearAttrVect() { attrs_vec_.clear(); }
191 
192  private:
193   std::vector<std::string> attrs_vec_;
194 };
195 
196 using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;
197 
// Kinds of attribute value an adapter can handle. Values are spelled out explicitly because they are
// plain (unscoped) integers and may be compared or stored numerically.
enum AttrType {
  ATTR_INT = 0,
  ATTR_FLOAT = 1,
  ATTR_DOUBLE = 2,
  ATTR_STRING = 3,
  ATTR_TENSOR = 4,
  ATTR_BOOL = 5,
  ATTR_LIST_INT = 6,
  ATTR_LIST_ANY_INT = 7,
  ATTR_ENUM = 8
};
209 
// Empty marker types carrying no data.
// NOTE(review): presumably passed as tag arguments to select a conversion overload elsewhere — confirm.
struct GeEnum {
};
struct TFType {
};
struct GEType {
};
struct GEEnumToStr {
};
214 
215 class GEDataFormat {
216  public:
ConvertEnumToString(int64_t id)217   static std::string ConvertEnumToString(int64_t id) {
218     const auto &enum_string = FormatEnumToString(static_cast<mindspore::Format>(id));
219     if (enum_string.empty()) {
220       MS_LOG(EXCEPTION) << "Invalid data format " << id;
221     }
222     return enum_string;
223   }
224 };
225 
226 class AscendQuantRoundMode {
227  public:
ConvertEnumToString(int64_t id)228   static std::string ConvertEnumToString(int64_t id) {
229     static const std::vector<std::string> round_mode = {"round", "trunc", "floor", "ceil"};
230     if (id < 0 || id >= static_cast<int64_t>(round_mode.size())) {
231       MS_LOG(EXCEPTION) << "Invalid AscendQuant round mode " << id;
232       return "";
233     }
234     return round_mode[id];
235   }
236 };
237 
238 class FASInputLayoutMode {
239  public:
ConvertEnumToString(int64_t id)240   static std::string ConvertEnumToString(int64_t id) {
241     static const std::vector<std::string> input_layout_modes = {"BSH", "BNSD", "SBH", "BSND", "TND"};
242     if (id < 0 || id >= static_cast<int64_t>(input_layout_modes.size())) {
243       MS_LOG(EXCEPTION) << "Invalid input layout mode " << id;
244       return "";
245     }
246     return input_layout_modes[id];
247   }
248 };
249 
250 class FFNActivationMode {
251  public:
ConvertEnumToString(int64_t id)252   static std::string ConvertEnumToString(int64_t id) {
253     static const std::vector<std::string> activation_mode = {
254       "no_activation", "relu", "sigmoid", "relu6",   "elu",      "leaky_relu",    "abs",    "relu1",     "softsign",
255       "softplus",      "tanh", "selu",    "hswish",  "hsigmoid", "thresholdrelu", "linear", "hard_tanh", "sign",
256       "swish",         "gelu", "glu",     "unknown", "fastgelu", "silu",          "geglu",  "swiglu",    "reglu"};
257     if (id < 0 || id >= static_cast<int64_t>(activation_mode.size())) {
258       MS_LOG(EXCEPTION) << "Invalid moe ffn activation " << id;
259       return "";
260     }
261     return activation_mode[id];
262   }
263 };
264 
265 class ScatterReduceMode {
266  public:
ConvertEnumToString(int64_t id)267   static std::string ConvertEnumToString(int64_t id) {
268     static const std::vector<std::string> reduce_mode = {"sum", "mean", "none", "update"};
269     if (id < 0 || id >= static_cast<int64_t>(reduce_mode.size())) {
270       MS_LOG(EXCEPTION) << "Invalid reduce mode " << id;
271       return "";
272     }
273     return reduce_mode[id];
274   }
275 };
276 
277 class GEPadMod {
278  public:
ConvertEnumToString(int64_t id)279   static std::string ConvertEnumToString(int64_t id) {
280     static const std::vector<std::string> pad_mods = {"PAD", "SAME", "VALID"};
281     if (id < 0 || id >= static_cast<int64_t>(pad_mods.size())) {
282       MS_LOG(EXCEPTION) << "Invalid pad mod " << id;
283       return "";
284     }
285     return pad_mods[id];
286   }
287 };
288 
289 class GEReduction {
290  public:
ConvertEnumToString(int64_t id)291   static std::string ConvertEnumToString(int64_t id) {
292     static const std::vector<std::string> reductions = {"sum", "mean", "none"};
293     if (id < 0 || id >= static_cast<int64_t>(reductions.size())) {
294       MS_LOG(EXCEPTION) << "Invalid reduction " << id;
295       return "";
296     }
297     return reductions[id];
298   }
299 };
300 
301 class GECoordinateTransformMode {
302  public:
ConvertEnumToString(int64_t id)303   static std::string ConvertEnumToString(int64_t id) {
304     static const std::vector<std::string> modes = {"asymmetric", "align_corners", "half_pixel", "crop_and_resize"};
305     if (id < 0 || id >= static_cast<int64_t>(modes.size())) {
306       MS_LOG(EXCEPTION) << "Invalid CoordinateTransformMode " << id;
307       return "";
308     }
309     return modes[id];
310   }
311 };
312 
// Type mapping AnyTraits<T>::type: the identity for every T, except that `int` maps to int64_t.
// NOTE(review): presumably used to normalize attribute value types before conversion — confirm at call sites.
template <typename T>
struct AnyTraits { using type = T; };

// Specialization: plain int is widened to int64_t.
template <>
struct AnyTraits<int> { using type = int64_t; };
323 
324 using ExtraAttr = mindspore::HashMap<std::string, ValuePtr>;
325 }  // namespace transform
326 }  // namespace mindspore
327 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_
328