• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_
19 
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #include "transform/graph_ir/util.h"
28 #include "ir/anf.h"
29 #include "ir/primitive.h"
30 #include "ir/value.h"
31 #include "transform/graph_ir/types.h"
32 #ifdef ENABLE_GE
33 #ifdef OPEN_SOURCE
34 #include "graph/types.h"
35 #endif
36 #endif
37 
38 #include "graph/operator_reg.h"
39 #ifdef OPEN_SOURCE
40 #include "ge/client/ge_api.h"
41 #else
42 #include "external/ge/ge_api.h"
43 #endif
44 #include "graph/tensor.h"
45 
46 namespace ge {
47 class CustomOperator : public Operator {
48  public:
CustomOperator(const string & name,const string & type)49   CustomOperator(const string &name, const string &type) : Operator(name, type) {}
50 
~CustomOperator()51   ~CustomOperator() override{};
52 
CustomInputRegister(const string & name)53   void CustomInputRegister(const string &name) { Operator::InputRegister(name); }
54 
CustomOutputRegister(const string & name)55   void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); }
56 
CustomInferFuncRegister(const std::function<graphStatus (Operator &)> & func)57   void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
58     Operator::InferFuncRegister(func);
59   }
60 };
61 }  // namespace ge
62 
63 namespace mindspore {
64 namespace transform {
65 using CusOperatorPtr = std::shared_ptr<ge::CustomOperator>;
66 using CustomOperator = ge::CustomOperator;
67 
68 struct OutHandler {
69   OperatorPtr op;
70   std::string out;
71   AnfNodePtr node;
OutHandlerOutHandler72   OutHandler() : op(nullptr), out(""), node(nullptr) {}
73   OutHandler(const OperatorPtr &op, const std::string out, const AnfNodePtr &node = nullptr)
opOutHandler74       : op(op), out(out), node(node) {}
75 };
76 
// A control edge between two GE operators, stored as a plain aggregate.
// NOTE(review): presumably dest_op is ordered after src_op when the edge is
// materialised in the GE graph — confirm with the graph builder.
// Member order matters: callers may brace-initialise as {src, dest}.
struct ControlEdge {
  OperatorPtr src_op;   // operator at the source end of the edge
  OperatorPtr dest_op;  // operator at the destination end of the edge
};
81 
82 using AttrFunc = std::function<void(OperatorPtr, ValuePtr)>;
83 using OutputFunc = std::function<OutHandler(OperatorPtr)>;
84 using InputOpFunc = std::function<void(OperatorPtr, OperatorPtr)>;
85 using InputHandleFunc = std::function<void(OperatorPtr, OutHandler)>;
86 using CreateDynInputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
87 using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr)>;
88 using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>;
89 using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>;
90 using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>;
91 using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>;
92 using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>;
93 
// Descriptor tables for an op adapter. These are plain aggregates; member order
// is significant because instances may be brace-initialised positionally.

// How to set one attribute on a GE operator.
struct AttrDesc {
  std::string name;   // attribute name (presumably the GE-side name)
  AttrFunc set_attr;  // applies a ValuePtr to the operator
};

// How to connect one static input of a GE operator.
struct InputDesc {
  std::string name;                        // input name
  InputOpFunc set_op;                      // callback taking (op, input operator)
  InputHandleFunc set_handle;              // callback taking (op, OutHandler of the producer)
  UpdateOutputDescFunc update_input_desc;  // callback taking (op, GeTensorDesc) for this input
};

// How to create and connect a dynamic (variable-count) input of a GE operator.
struct DynInputDesc {
  std::string name;                       // dynamic input name
  CreateDynInputOpFunc create_dyn_input;  // (op, unsigned) — unsigned presumably the input count
  DynInputOpFunc set_op;                  // (op, index, input operator)
  DynInputHandleFunc set_handle;          // (op, index, OutHandler of the producer)
};

// How to create and attach dynamic subgraphs of a GE operator.
struct DynSubGraphDesc {
  std::string name;                           // subgraph slot name
  CreateDynSubGraphFunc create_dyn_subgraph;  // (op, unsigned) — unsigned presumably the subgraph count
  DynSubGraphFunc set_subgraph;               // (op, index, DfGraphPtr)
};

// How to update one output of a GE operator.
struct OutputDesc {
  std::string name;                      // output name
  UpdateOutputDescFunc update_out_desc;  // callback taking (op, GeTensorDesc) for this output
};

// How to create a dynamic (variable-count) output of a GE operator.
struct DynOutputDesc {
  std::string name;                         // dynamic output name
  CreateDynOutputOpFunc create_dyn_output;  // (op, unsigned) — unsigned presumably the output count
};
128 
129 class BaseOpAdapter {
130  public:
~BaseOpAdapter()131   virtual ~BaseOpAdapter() {}
132   virtual OperatorPtr generate(const AnfNodePtr &anf) = 0;
generate(const std::string & type)133   virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); }
134   virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0;
135   virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0;
136   virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0;
137   virtual int setInput(const OperatorPtr &op, int index,
138                        const std::shared_ptr<std::vector<OutHandler>> &handler_vec) = 0;
139   virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0;
140   virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0;
141   virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0;
142   virtual std::unordered_map<std::string, ValuePtr> GetExtraAttr() = 0;
143   template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
setAttr(const OperatorPtr & op,const std::string & attrKey,const std::shared_ptr<T> & attrValue)144   int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr<T> &attrValue) {
145     return setAttr(op, attrKey, MakeValue(attrValue));
146   }
147   template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
setAttr(const OperatorPtr & op,const std::string & attrKey,const T & attrValue)148   int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) {
149     return setAttr(op, attrKey, MakeValue(attrValue));
150   }
151   virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0;
152   virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
153                                 const AnfNodePtr &node) = 0;
154   virtual const std::unordered_map<int, InputDesc> &getInputMap() = 0;
155   virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0;
156   virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0;
157   virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0;
158   virtual const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() = 0;
AddAttrToDrawGraph(const std::string & attr_str)159   void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
GetAttrsFromDrawGraph()160   const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
clearAttrVect()161   void clearAttrVect() { attrs_vec_.clear(); }
162 
163  private:
164   std::vector<std::string> attrs_vec_;
165 };
166 
167 using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;
168 
// Kinds of operator attribute values. Explicit numbering spells out the
// values the enumerators have always had (consecutive from 0), since code
// may rely on them as integers.
enum AttrType {
  ATTR_INT = 0,
  ATTR_FLOAT = 1,
  ATTR_DOUBLE = 2,
  ATTR_STRING = 3,
  ATTR_TENSOR = 4,
  ATTR_BOOL = 5,
  ATTR_LIST_INT = 6,
  ATTR_LIST_ANY_INT = 7,
  ATTR_ENUM = 8
};
180 
// Empty marker types (presumably used as tags for overload dispatch).
struct GeEnum {
};
struct TFType {
};
struct GEType {
};

// AnyTraits<T>::type names the type used to carry a T.
// Every T maps to itself, except `int`, which is widened to int64_t
// by the explicit specialization below.
template <typename T>
struct AnyTraits {
  using type = T;
};

template <>
struct AnyTraits<int> {
  using type = int64_t;
};
195 
// Name -> value map; the same type BaseOpAdapter::GetExtraAttr returns.
// Keys are presumably attribute names — confirm with implementations.
using ExtraAttr = std::unordered_map<std::string, ValuePtr>;
197 }  // namespace transform
198 }  // namespace mindspore
199 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_
200