1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ 19 20 #include <unordered_map> 21 #include <string> 22 #include <memory> 23 #include <utility> 24 #include <vector> 25 #include <sstream> 26 27 #include "transform/graph_ir/util.h" 28 #include "ir/anf.h" 29 #include "ir/primitive.h" 30 #include "ir/value.h" 31 #include "transform/graph_ir/types.h" 32 #ifdef ENABLE_GE 33 #ifdef OPEN_SOURCE 34 #include "graph/types.h" 35 #endif 36 #endif 37 38 #include "graph/operator_reg.h" 39 #ifdef OPEN_SOURCE 40 #include "ge/client/ge_api.h" 41 #else 42 #include "external/ge/ge_api.h" 43 #endif 44 #include "graph/tensor.h" 45 46 namespace ge { 47 class CustomOperator : public Operator { 48 public: CustomOperator(const string & name,const string & type)49 CustomOperator(const string &name, const string &type) : Operator(name, type) {} 50 ~CustomOperator()51 ~CustomOperator() override{}; 52 CustomInputRegister(const string & name)53 void CustomInputRegister(const string &name) { Operator::InputRegister(name); } 54 CustomOutputRegister(const string & name)55 void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } 56 CustomInferFuncRegister(const std::function<graphStatus (Operator &)> & func)57 void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) { 58 Operator::InferFuncRegister(func); 59 } 60 }; 61 } // namespace ge 62 63 namespace mindspore { 64 namespace transform { 65 using CusOperatorPtr = std::shared_ptr<ge::CustomOperator>; 66 using CustomOperator = ge::CustomOperator; 67 68 struct OutHandler { 69 OperatorPtr op; 70 std::string out; 71 AnfNodePtr node; OutHandlerOutHandler72 OutHandler() : op(nullptr), out(""), node(nullptr) {} 73 OutHandler(const OperatorPtr &op, const std::string out, const AnfNodePtr &node = nullptr) opOutHandler74 : op(op), out(out), node(node) {} 75 }; 76 77 struct ControlEdge { 78 OperatorPtr src_op; 79 OperatorPtr dest_op; 80 }; 81 82 using AttrFunc = std::function<void(OperatorPtr, ValuePtr)>; 83 using OutputFunc = std::function<OutHandler(OperatorPtr)>; 84 using InputOpFunc = std::function<void(OperatorPtr, OperatorPtr)>; 85 using InputHandleFunc = std::function<void(OperatorPtr, OutHandler)>; 86 using CreateDynInputOpFunc = std::function<void(OperatorPtr, unsigned int)>; 87 using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr)>; 88 using DynInputHandleFunc = std::function<void(OperatorPtr, unsigned int, OutHandler)>; 89 using UpdateOutputDescFunc = std::function<void(OperatorPtr, GeTensorDesc)>; 90 using CreateDynOutputOpFunc = std::function<void(OperatorPtr, unsigned int)>; 91 using CreateDynSubGraphFunc = std::function<void(OperatorPtr, unsigned int)>; 92 using DynSubGraphFunc = std::function<void(OperatorPtr, unsigned int, DfGraphPtr)>; 93 94 struct AttrDesc { 95 std::string name; 96 AttrFunc set_attr; 97 }; 98 99 struct InputDesc { 100 std::string name; 101 InputOpFunc set_op; 102 InputHandleFunc set_handle; 103 UpdateOutputDescFunc update_input_desc; 104 }; 105 106 struct DynInputDesc { 107 std::string name; 108 CreateDynInputOpFunc create_dyn_input; 109 DynInputOpFunc set_op; 110 DynInputHandleFunc set_handle; 111 }; 112 113 struct DynSubGraphDesc { 114 std::string name; 115 CreateDynSubGraphFunc create_dyn_subgraph; 116 DynSubGraphFunc set_subgraph; 117 }; 118 119 struct OutputDesc { 120 std::string name; 121 UpdateOutputDescFunc update_out_desc; 122 }; 123 124 struct DynOutputDesc { 125 std::string name; 126 CreateDynOutputOpFunc create_dyn_output; 127 }; 128 129 class BaseOpAdapter { 130 public: ~BaseOpAdapter()131 virtual ~BaseOpAdapter() {} 132 virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; generate(const std::string & type)133 virtual OperatorPtr generate(const std::string &type) { return std::make_shared<ge::Operator>(type); } 134 virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) = 0; 135 virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; 136 virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; 137 virtual int setInput(const OperatorPtr &op, int index, 138 const std::shared_ptr<std::vector<OutHandler>> &handler_vec) = 0; 139 virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; 140 virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; 141 virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; 142 virtual std::unordered_map<std::string, ValuePtr> GetExtraAttr() = 0; 143 template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type> setAttr(const OperatorPtr & op,const std::string & attrKey,const std::shared_ptr<T> & attrValue)144 int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr<T> &attrValue) { 145 return setAttr(op, attrKey, MakeValue(attrValue)); 146 } 147 template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type> setAttr(const OperatorPtr & op,const std::string & attrKey,const T & attrValue)148 int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { 149 return setAttr(op, attrKey, MakeValue(attrValue)); 150 } 151 virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; 152 virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 153 const AnfNodePtr &node) = 0; 154 virtual const std::unordered_map<int, InputDesc> &getInputMap() = 0; 155 virtual const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() = 0; 156 virtual const std::unordered_map<int, DynInputDesc> &getDynInputMap() = 0; 157 virtual const std::unordered_map<int, OutputDesc> &getOutputMap() = 0; 158 virtual const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() = 0; AddAttrToDrawGraph(const std::string & attr_str)159 void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } GetAttrsFromDrawGraph()160 const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; } clearAttrVect()161 void clearAttrVect() { attrs_vec_.clear(); } 162 163 private: 164 std::vector<std::string> attrs_vec_; 165 }; 166 167 using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>; 168 169 enum AttrType { 170 ATTR_INT = 0, 171 ATTR_FLOAT, 172 ATTR_DOUBLE, 173 ATTR_STRING, 174 ATTR_TENSOR, 175 ATTR_BOOL, 176 ATTR_LIST_INT, 177 ATTR_LIST_ANY_INT, 178 ATTR_ENUM 179 }; 180 181 struct GeEnum {}; 182 struct TFType {}; 183 struct GEType {}; 184 185 // declare Any type 186 template <typename T> 187 struct AnyTraits { 188 using type = T; 189 }; 190 191 template <> 192 struct AnyTraits<int> { 193 using type = int64_t; 194 }; 195 196 using ExtraAttr = std::unordered_map<std::string, ValuePtr>; 197 } // namespace transform 198 } // namespace mindspore 199 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_BASE_H_ 200