1 /** 2 * Copyright 2019-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 #include <string> 24 #include <map> 25 #include "utils/hash_map.h" 26 #include "transform/graph_ir/op_adapter_util.h" 27 #include "transform/graph_ir/op_adapter_base.h" 28 #include "include/common/utils/utils.h" 29 #include "include/common/utils/anfalgo.h" 30 #include "ops/other_ops.h" 31 #include "ops/sequence_ops.h" 32 #include "ops/framework_ops.h" 33 #include "ops/op_utils.h" 34 namespace mindspore { 35 namespace transform { 36 class OpAdapterImpl { 37 public: OpAdapterImpl(const mindspore::HashMap<int,InputDesc> & input_map,const mindspore::HashMap<int,DynInputDesc> & dyn_input_map,const std::map<int,OutputDesc> & output_map,const mindspore::HashMap<int,DynOutputDesc> & dyn_output_map,const mindspore::HashMap<int,SubGraphDesc> & subgraph_map,const mindspore::HashMap<int,DynSubGraphDesc> & dyn_subgraph_map,const mindspore::HashMap<std::string,AttrDesc> & attr_map,const mindspore::HashMap<std::string,int> & enum_map,const mindspore::HashMap<unsigned int,AttrDesc> & input_attr_map,const mindspore::HashMap<std::string,std::string> & attr_input_map,mindspore::HashMap<std::string,mindspore::HashMap<int,std::string>> * 
cus_input_map,mindspore::HashMap<std::string,std::map<int,std::string>> * cus_output_map,mindspore::HashMap<std::string,ValuePtr> * extra_attr,mindspore::HashMap<std::string,int> * name_counts,BaseOpAdapter * adpt)38 OpAdapterImpl(const mindspore::HashMap<int, InputDesc> &input_map, 39 const mindspore::HashMap<int, DynInputDesc> &dyn_input_map, const std::map<int, OutputDesc> &output_map, 40 const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map, 41 const mindspore::HashMap<int, SubGraphDesc> &subgraph_map, 42 const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map, 43 const mindspore::HashMap<std::string, AttrDesc> &attr_map, 44 const mindspore::HashMap<std::string, int> &enum_map, 45 const mindspore::HashMap<unsigned int, AttrDesc> &input_attr_map, 46 const mindspore::HashMap<std::string, std::string> &attr_input_map, 47 mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> *cus_input_map, 48 mindspore::HashMap<std::string, std::map<int, std::string>> *cus_output_map, 49 mindspore::HashMap<std::string, ValuePtr> *extra_attr, 50 mindspore::HashMap<std::string, int> *name_counts, BaseOpAdapter *adpt) 51 : input_map_(input_map), 52 dyn_input_map_(dyn_input_map), 53 output_map_(output_map), 54 dyn_output_map_(dyn_output_map), 55 subgraph_map_(subgraph_map), 56 dyn_subgraph_map_(dyn_subgraph_map), 57 attr_map_(attr_map), 58 enum_map_(enum_map), 59 input_attr_map_(input_attr_map), 60 attr_input_map_(attr_input_map), 61 cus_input_map_(cus_input_map), 62 cus_output_map_(cus_output_map), 63 extra_attr_(extra_attr), 64 name_counts_(name_counts), 65 adpt_(adpt) { 66 MS_EXCEPTION_IF_NULL(cus_input_map_); 67 MS_EXCEPTION_IF_NULL(cus_output_map_); 68 MS_EXCEPTION_IF_NULL(extra_attr_); 69 MS_EXCEPTION_IF_NULL(name_counts_); 70 MS_EXCEPTION_IF_NULL(adpt_); 71 } ~OpAdapterImpl()72 ~OpAdapterImpl() {} 73 bool IsCustomOp(const OperatorPtr &op) const; 74 std::string GetCustomOpType(const PrimitivePtr &prim) const; 75 Status 
GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim); 76 Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim); 77 OperatorPtr GenerateCustomOp(const AnfNodePtr anf); 78 Status SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs); 79 Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches); 80 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) const; 81 Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input); 82 int setInput(const OperatorPtr &op, int index, const OperatorPtr &input); 83 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) const; 84 Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle); 85 int setInput(const OperatorPtr &op, int index, const OutHandler &handle); 86 int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec, 87 bool use_create_byindex_func = false, size_t dyn_index = 0); 88 OutHandler getOutput(const OperatorPtr &op, int index); 89 std::vector<OutHandler> getOutputs(const OperatorPtr &op) const; 90 OutHandler getCustomOutput(const OperatorPtr &op, int index) const; 91 OutHandler getNormalOutput(const OperatorPtr &op, int index); 92 std::vector<OutHandler> getNormalOutputs(const OperatorPtr &op) const; 93 std::vector<OutHandler> getCustomOutputs(const OperatorPtr &op) const; 94 Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 95 const std::string &format); 96 size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) const; 97 std::map<std::string, ValuePtr> GetNormalOpAttrList(const OperatorPtr &op, const AnfNodePtr &node) const; 98 std::map<std::string, ValuePtr> GetOpAttrList(const OperatorPtr &op) const; 99 std::shared_ptr<GeTensorDesc> 
CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, 100 const std::string &format) const; 101 Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 102 const std::string &format); 103 std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) const; 104 void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format); 105 void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) const; 106 void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node); 107 void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 108 const AnfNodePtr &node); 109 int setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value); 110 int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) const; 111 int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim); 112 int SetNoFoldingOpAttr(const OperatorPtr &op, const PrimitivePtr &prim); 113 int setAttr(const OperatorPtr &op, const PrimitivePtr &prim); 114 int setAttr(const OperatorPtr &op, const AnfNodePtr &node); 115 int setAttr(const OperatorPtr &op, const uint32_t &input_idx, const ValuePtr &attr_value); 116 int getAttr(const OperatorPtr &op, const std::string &attr_key, ValuePtr *attr_value); 117 int getAttr(const OperatorPtr &op, uint32_t input_idx, ValuePtr *attr_value); 118 119 private: 120 const mindspore::HashMap<int, InputDesc> &input_map_; 121 const mindspore::HashMap<int, DynInputDesc> &dyn_input_map_; 122 const std::map<int, OutputDesc> &output_map_; 123 const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map_; 124 const mindspore::HashMap<int, SubGraphDesc> &subgraph_map_; 125 const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map_; 126 const mindspore::HashMap<std::string, AttrDesc> &attr_map_; 127 const 
mindspore::HashMap<std::string, int> &enum_map_; 128 // NOTE: The key of input_attr_map_ is anf node index, so index 0 is primitive value node 129 const mindspore::HashMap<unsigned int, AttrDesc> &input_attr_map_; 130 const mindspore::HashMap<std::string, std::string> &attr_input_map_; 131 mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> *const cus_input_map_; 132 mindspore::HashMap<std::string, std::map<int, std::string>> *const cus_output_map_; 133 mindspore::HashMap<std::string, ValuePtr> *const extra_attr_; 134 mindspore::HashMap<std::string, int> *const name_counts_; 135 BaseOpAdapter *const adpt_; 136 }; 137 138 template <typename T> 139 class OpAdapter : public BaseOpAdapter { 140 public: 141 using OpType = T; OpAdapter(std::string op_type_obj)142 explicit OpAdapter(std::string op_type_obj) 143 : op_type_obj_(std::move(op_type_obj)), 144 impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_, 145 dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, attr_input_map_, 146 &cus_input_map_, &cus_output_map_, &extra_attr_, &name_counts_, this)) { 147 MS_EXCEPTION_IF_NULL(impl_); 148 } OpAdapter(std::string op_type_obj,ExtraAttr extra_attr)149 explicit OpAdapter(std::string op_type_obj, ExtraAttr extra_attr) 150 : op_type_obj_(std::move(op_type_obj)), 151 extra_attr_(std::move(extra_attr)), 152 impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_, 153 dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, attr_input_map_, 154 &cus_input_map_, &cus_output_map_, &extra_attr_, &name_counts_, this)) { 155 MS_EXCEPTION_IF_NULL(impl_); 156 } 157 ~OpAdapter() override = default; 158 IsCustomOp(const OperatorPtr & op)159 bool IsCustomOp(const OperatorPtr &op) { return impl_->IsCustomOp(op); } 160 GenerateCustomOpInputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)161 Status GenerateCustomOpInputMap(const CusOperatorPtr &op, 
const PrimitivePtr &prim) { 162 return impl_->GenerateCustomOpInputMap(op, prim); 163 } 164 GenerateCustomOpOutputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)165 Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { 166 return impl_->GenerateCustomOpOutputMap(op, prim); 167 } 168 169 // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs. GenerateCustomOp(const AnfNodePtr anf)170 OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { return impl_->GenerateCustomOp(anf); } 171 GenerateNormalOp(const AnfNodePtr & anf)172 OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) const { 173 OperatorPtr op = nullptr; 174 std::string op_name; 175 if (anf != nullptr && !anf->fullname_with_scope().empty()) { 176 op_name = anf->fullname_with_scope(); 177 } 178 op = generate(op_name); 179 180 // set dynamic output num if op use DYNAMIC_OUTPUT 181 if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) { 182 TypePtr type = anf->Type(); 183 if (type == nullptr) { 184 MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!"; 185 } 186 auto num = GetOutputSize(type); 187 188 auto judge_node = anf; 189 if (common::AnfAlgo::CheckPrimitiveType(anf, prim::kPrimReturn)) { 190 auto cnode = anf->cast<CNodePtr>(); 191 MS_EXCEPTION_IF_NULL(cnode); 192 auto input_node = cnode->inputs()[1]; 193 if (common::AnfAlgo::CheckPrimitiveType(judge_node, prim::kPrimMakeTuple)) { 194 judge_node = input_node; 195 } 196 auto judge_cnode = judge_node->cast<CNodePtr>(); 197 if (judge_cnode != nullptr) { 198 auto inputs = judge_cnode->inputs(); 199 for (const auto &input : inputs) { 200 if (common::AnfAlgo::IsNoOuputNode(input)) { 201 --num; 202 } 203 } 204 } 205 } 206 207 if (common::AnfAlgo::CheckPrimitiveType(judge_node, prim::kPrimMakeTuple)) { 208 auto cnode = judge_node->cast<CNodePtr>(); 209 if (cnode != nullptr) { 210 auto inputs = cnode->inputs(); 211 for (const auto &input : inputs) { 212 if 
(common::AnfAlgo::IsNoOuputNode(input)) { 213 --num; 214 } 215 } 216 } 217 } 218 219 MS_LOG(INFO) << "create_dyn_output for node:" << anf->fullname_with_scope() << ", type:" << type->ToString() 220 << ", num:" << num; 221 dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(num)); 222 } 223 return op; 224 } 225 GenerateDynamicOutputOp(const AnfNodePtr & anf)226 OperatorPtr GenerateDynamicOutputOp(const AnfNodePtr &anf) const { 227 OperatorPtr op = nullptr; 228 std::string op_name; 229 if (anf != nullptr && !anf->fullname_with_scope().empty()) { 230 op_name = anf->fullname_with_scope(); 231 } 232 op = generate(op_name); 233 return op; 234 } 235 setDynamicOutputNum(const OperatorPtr & op,size_t dyn_output_size)236 void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) override { 237 // set dynamic output num if op use DYNAMIC_OUTPUT 238 if ((op != nullptr) && (!dyn_output_map_.empty())) { 239 MS_LOG(DEBUG) << "create_dyn_output for node:" << op->GetName() << ", num:" << dyn_output_size; 240 dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(dyn_output_size)); 241 } 242 } 243 generate(const AnfNodePtr & anf)244 OperatorPtr generate(const AnfNodePtr &anf) override { 245 OperatorPtr op = nullptr; 246 if (IsCustomCNode(anf)) { 247 op = GenerateCustomOp(anf); 248 } else { 249 op = GenerateNormalOp(anf); 250 } 251 if (op == nullptr) { 252 MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope(); 253 } 254 return op; 255 } 256 generate(const std::string & op_name)257 OperatorPtr generate(const std::string &op_name) const override { 258 std::string op_name_fix = op_name; 259 if (op_name_fix.empty()) { 260 // There are duplicate names in ANF graph, do not assign ANF node name to GE 261 // GE will generate unique name automatically 262 static int64_t idx = 0; 263 op_name_fix = op_type_obj_ + "_NULL_" + std::to_string(idx++); 264 } 265 if 
(!::ge::OperatorFactory::IsExistOp(op_type_obj_.c_str())) { 266 MS_LOG(ERROR) << "OperatorFactory is not exist, op type: " << op_type_obj_; 267 return std::make_shared<Operator>(Operator(op_name_fix, op_type_obj_)); 268 } 269 auto op = ::ge::OperatorFactory::CreateOperator(op_name_fix, op_type_obj_); 270 return std::make_shared<Operator>(op); 271 } 272 generateDynOutputOp(const AnfNodePtr & anf)273 OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) override { 274 OperatorPtr op = nullptr; 275 op = GenerateDynamicOutputOp(anf); 276 if (op == nullptr) { 277 MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope(); 278 } 279 return op; 280 } 281 getOpType()282 std::string getOpType() override { return op_type_obj_; } 283 GetStaticOpType()284 static std::string GetStaticOpType() { return op_type_; } 285 getInputMap()286 const mindspore::HashMap<int, InputDesc> &getInputMap() override { return input_map_; } getInputAttrMap()287 const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; } getAttrMap()288 const mindspore::HashMap<std::string, AttrDesc> &getAttrMap() override { return attr_map_; } getAttrInputMap()289 const mindspore::HashMap<std::string, std::string> &getAttrInputMap() override { return attr_input_map_; } getDynInputMap()290 const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; } getSubgraphMap()291 const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() override { return subgraph_map_; } getOutputMap()292 const std::map<int, OutputDesc> &getOutputMap() override { return output_map_; } getDynOutputMap()293 const mindspore::HashMap<int, DynOutputDesc> &getDynOutputMap() override { return dyn_output_map_; } getDynSubgraphMap()294 const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; } GetNormalOpAttrList(const AnfNodePtr & node)295 std::map<std::string, ValuePtr> GetNormalOpAttrList(const 
AnfNodePtr &node) override { 296 return impl_->GetNormalOpAttrList(getOp(), node); 297 } GetOpAttrList()298 std::map<std::string, ValuePtr> GetOpAttrList() override { return impl_->GetOpAttrList(getOp()); } IsDynInputOp(uint64_t index)299 bool IsDynInputOp(uint64_t index) override { return dyn_input_map_.find(index) != dyn_input_map_.end(); } IsDyOutputOp(uint64_t index)300 bool IsDyOutputOp(uint64_t index) override { return dyn_output_map_.find(index) != dyn_output_map_.end(); } IsMultipleOutputOp(const AnfNodePtr & anf)301 bool IsMultipleOutputOp(const AnfNodePtr &anf) override { 302 if (IsCustomCNode(anf)) { 303 // Custom op 304 auto node = anf->cast<CNodePtr>(); 305 MS_EXCEPTION_IF_NULL(node); 306 auto prim = GetValueNode<PrimitivePtr>(node->inputs().at(0)); 307 MS_EXCEPTION_IF_NULL(prim); 308 auto op_type = impl_->GetCustomOpType(prim); 309 if (cus_output_map_.find(op_type) != cus_output_map_.end()) { 310 return cus_output_map_[op_type].size() > 1; 311 } 312 return false; 313 } 314 // Normal op 315 return output_map_.size() > 1; 316 } 317 SetOpSubgraphFunc(const OperatorPtr & op,std::shared_ptr<std::vector<DfGraph>> subgraphs)318 Status SetOpSubgraphFunc(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) { 319 return impl_->SetOpSubgraphFunc(op, subgraphs); 320 } 321 setSubgraph(const OperatorPtr & op,std::shared_ptr<std::vector<DfGraph>> subgraphs)322 void setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) override { 323 (void)SetOpSubgraphFunc(op, subgraphs); 324 } 325 SetOpSubgraphFunc(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<DfGraph>> & branches)326 Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) { 327 return impl_->SetOpSubgraphFunc(op, index, branches); 328 } 329 setSubgraph(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<DfGraph>> & branches)330 void setSubgraph(const OperatorPtr &op, int 
index, const std::shared_ptr<std::vector<DfGraph>> &branches) override { 331 (void)SetOpSubgraphFunc(op, index, branches); 332 } 333 SetCustomOpInput(const CusOperatorPtr & op,int index,const OperatorPtr & input)334 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { 335 return impl_->SetCustomOpInput(op, index, input); 336 } 337 SetNormalOpInput(const OperatorPtr & op,int index,const OperatorPtr & input)338 Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { 339 return impl_->SetNormalOpInput(op, index, input); 340 } 341 setInput(const OperatorPtr & op,int index,const OperatorPtr & input)342 int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { 343 return impl_->setInput(op, index, input); 344 } 345 SetCustomOpInput(const CusOperatorPtr & op,int index,const OutHandler & handle)346 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { 347 return impl_->SetCustomOpInput(op, index, handle); 348 } 349 SetNormalOpInput(const OperatorPtr & op,int index,const OutHandler & handle)350 Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { 351 return impl_->SetNormalOpInput(op, index, handle); 352 } 353 setInput(const OperatorPtr & op,int index,const OutHandler & handle)354 int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { 355 return impl_->setInput(op, index, handle); 356 } 357 358 int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec, 359 bool use_create_byindex_func = false, size_t dyn_index = 0) override { 360 return impl_->setInput(op, index, handler_vec, use_create_byindex_func, dyn_index); 361 } 362 getOutput(const OperatorPtr & op,int index)363 OutHandler getOutput(const OperatorPtr &op, int index) override { return impl_->getOutput(op, index); } 364 getOutputs(const OperatorPtr & op)365 std::vector<OutHandler> 
getOutputs(const OperatorPtr &op) override { return impl_->getOutputs(op); } 366 getCustomOutput(const OperatorPtr & op,int index)367 OutHandler getCustomOutput(const OperatorPtr &op, int index) { return impl_->getCustomOutput(op, index); } 368 getNormalOutput(const OperatorPtr & op,int index)369 OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); } 370 UpdateSingleOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)371 Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 372 const std::string &format) { 373 return impl_->UpdateSingleOutputDesc(op, shp, type, format); 374 } 375 GetCustomOpOutputSize(const CusOperatorPtr & cus_op)376 size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); } 377 CreateOutputDesc(const abstract::ShapePtr & shape_ptr,const TypePtr & type,const std::string & format)378 std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, 379 const std::string &format) { 380 return impl_->CreateOutputDesc(shape_ptr, type, format); 381 } 382 UpdateMultiOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)383 Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 384 const std::string &format) { 385 return impl_->UpdateMultiOutputDesc(op, shp, type, format); 386 } 387 CreateNodeDesc(const AnfNodePtr & node,const std::string & format)388 std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) { 389 return impl_->CreateNodeDesc(node, format); 390 } 391 UpdateNormalOpInputDesc(const OperatorPtr & op,const AnfNodePtr node,const std::string format)392 void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const 
std::string format) { 393 return impl_->UpdateNormalOpInputDesc(op, node, format); 394 } 395 UpdateCustomOpInputDesc(const CusOperatorPtr & op,const AnfNodePtr & node,const std::string format)396 void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) { 397 return impl_->UpdateCustomOpInputDesc(op, node, format); 398 } 399 updateInputDesc(const OperatorPtr & op,const AnfNodePtr & node)400 void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); } 401 updateOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const AnfNodePtr & node)402 void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 403 const AnfNodePtr &node) override { 404 impl_->updateOutputDesc(op, shp, type, node); 405 } 406 setAttr(const OperatorPtr & op,const std::string & attrKey,const ValuePtr & attrValue)407 int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { 408 return impl_->setAttr(op, attrKey, attrValue); 409 } 410 SetCustomOpAttr(const CusOperatorPtr & op,const PrimitivePtr & prim)411 int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetCustomOpAttr(op, prim); } 412 SetNormalOpAttr(const OperatorPtr & op,const PrimitivePtr & prim)413 int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetNormalOpAttr(op, prim); } 414 setAttr(const OperatorPtr & op,const PrimitivePtr & prim)415 int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { return impl_->setAttr(op, prim); } 416 setAttr(const OperatorPtr & op,const AnfNodePtr & node)417 int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { return impl_->setAttr(op, node); } 418 setAttr(const std::string & attr_key,const ValuePtr & attr_value)419 int setAttr(const std::string &attr_key, const ValuePtr &attr_value) override { 420 
return impl_->setAttr(getOp(), attr_key, attr_value); 421 } 422 setAttr(const uint32_t & input_idx,const ValuePtr & attr_value)423 int setAttr(const uint32_t &input_idx, const ValuePtr &attr_value) override { 424 return impl_->setAttr(getOp(), input_idx, attr_value); 425 } 426 getAttr(const std::string & attr_key,ValuePtr * attr_value)427 int getAttr(const std::string &attr_key, ValuePtr *attr_value) override { 428 MS_EXCEPTION_IF_NULL(attr_value); 429 return impl_->getAttr(getOp(), attr_key, attr_value); 430 } getAttr(const uint32_t & input_idx,ValuePtr * attr_value)431 int getAttr(const uint32_t &input_idx, ValuePtr *attr_value) override { 432 MS_EXCEPTION_IF_NULL(attr_value); 433 return impl_->getAttr(getOp(), input_idx, attr_value); 434 } GetExtraAttr()435 mindspore::HashMap<std::string, ValuePtr> GetExtraAttr() override { return extra_attr_; } GetDynamicShapeSupport()436 bool GetDynamicShapeSupport() override { return dynamic_shape_support_; } 437 438 private: 439 template <typename S> ConvertAny(const ValuePtr & value,const AnyTraits<S> &)440 static S ConvertAny(const ValuePtr &value, const AnyTraits<S> &) { 441 return ops::GetValueWithCheck<S>(value); 442 } 443 444 template <typename S> ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<S>> &,size_t size,S default_val)445 static std::vector<S> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<S>> &, size_t size, 446 S default_val) { 447 auto v = ops::GetValueWithCheck<std::vector<S>>(value); 448 if (v.size() < size) { 449 v.insert(v.begin(), size - v.size(), default_val); 450 } 451 return v; 452 } 453 454 // specialization for reverse bool ConvertAny(const ValuePtr & value,const AnyTraits<bool> &,bool reverse)455 static bool ConvertAny(const ValuePtr &value, const AnyTraits<bool> &, bool reverse) { 456 return reverse != ops::GetValueWithCheck<bool>(value); 457 } 458 459 template <typename P, typename Q> ConvertAny(const ValuePtr & value,const AnyTraits<P> & traits_from,const 
AnyTraits<Q> & traits_to)460 static Q ConvertAny(const ValuePtr &value, const AnyTraits<P> &traits_from, const AnyTraits<Q> &traits_to) { 461 return ConvertAnyUtil(value, traits_from, traits_to); 462 } 463 464 // specialization for tensor ConvertAny(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> & traits)465 static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &traits) { 466 // To-DO the format may read from ME tensor 467 return ConvertAnyUtil(value, traits); 468 } 469 470 // specialization for int ConvertAny(const ValuePtr & value,const AnyTraits<int64_t>)471 static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<int64_t>) { 472 return ops::GetValueWithCheck<int64_t>(value); 473 } 474 475 // specialization for float ConvertAny(const ValuePtr & value,const AnyTraits<float>)476 static float ConvertAny(const ValuePtr &value, const AnyTraits<float>) { return GetCastFloatValue<float>(value); } 477 478 // specialization for int or tuple broadcast to Vector ConvertAny(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>> anyTraitsInt)479 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name, 480 const AnyTraits<std::vector<int64_t>> anyTraitsInt) { 481 return ConvertAnyUtil(value, name, anyTraitsInt); 482 } 483 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>)484 static std::vector<std::vector<int64_t>> ConvertAny(const ValuePtr &value, 485 const AnyTraits<std::vector<std::vector<int64_t>>>) { 486 MS_EXCEPTION_IF_NULL(value); 487 MS_LOG(INFO) << "Value: " << value->type_name(); 488 std::vector<std::vector<int64_t>> list; 489 490 ValuePtrList valuelists; 491 if (value->isa<ValueTuple>()) { 492 auto vec = value->cast<ValueTuplePtr>(); 493 MS_EXCEPTION_IF_NULL(vec); 494 valuelists = vec->value(); 495 } else if (value->isa<ValueList>()) { 496 auto vec = value->cast<ValueListPtr>(); 497 
MS_EXCEPTION_IF_NULL(vec); 498 valuelists = vec->value(); 499 } else { 500 MS_LOG(EXCEPTION) << "Value should be ValueTuple or ValueList, but got " << value->type_name(); 501 } 502 503 for (auto &it : valuelists) { 504 MS_EXCEPTION_IF_NULL(it); 505 std::vector<int64_t> sublist; 506 if (!it->isa<ValueTuple>()) { 507 if (it->type_name() != "ValueList") { 508 MS_LOG(EXCEPTION) << "It should be ValueTuple or ValueList, but got " << it->type_name(); 509 } 510 auto sub_vector = it->cast<ValueListPtr>(); 511 for (auto &item : sub_vector->value()) { 512 sublist.emplace_back(ops::GetValueWithCheck<int64_t>(item)); 513 } 514 } else { 515 auto sub_vector = it->cast<ValueTuplePtr>(); 516 for (auto &item : sub_vector->value()) { 517 sublist.emplace_back(ops::GetValueWithCheck<int64_t>(item)); 518 } 519 } 520 list.emplace_back(sublist); 521 } 522 return list; 523 } 524 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>,const AnyTraits<std::vector<int64_t>>)525 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<std::vector<int64_t>>>, 526 const AnyTraits<std::vector<int64_t>>) { 527 MS_EXCEPTION_IF_NULL(value); 528 MS_LOG(DEBUG) << "Value: " << value->type_name(); 529 if (!value->isa<ValueSequence>()) { 530 MS_LOG(EXCEPTION) << "Value should be ValueSequence, but got " << value->type_name(); 531 } 532 auto vec = value->cast<ValueSequencePtr>(); 533 std::vector<int64_t> list; 534 for (auto &it : vec->value()) { 535 MS_EXCEPTION_IF_NULL(it); 536 if (!it->isa<ValueSequence>()) { 537 MS_LOG(EXCEPTION) << "It should be ValueSequence, but got " << it->type_name(); 538 } 539 auto sub_vector = it->cast<ValueSequencePtr>(); 540 for (auto &item : sub_vector->value()) { 541 list.emplace_back(ops::GetValueWithCheck<int64_t>(item)); 542 } 543 } 544 return list; 545 } 546 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,size_t index)547 static int64_t ConvertAny(const ValuePtr &value, const 
AnyTraits<std::vector<int64_t>>, size_t index) { 548 MS_EXCEPTION_IF_NULL(value); 549 MS_LOG(DEBUG) << "Value: " << value->type_name(); 550 if (!value->isa<ValueSequence>()) { 551 MS_LOG(EXCEPTION) << "Value should be ValueSequence, but got " << value->type_name(); 552 } 553 std::vector<int64_t> list; 554 auto vec = value->cast<ValueSequencePtr>(); 555 MS_EXCEPTION_IF_NULL(vec); 556 for (auto &it : vec->value()) { 557 list.emplace_back(GetCastIntegralValue<int64_t>(it)); 558 } 559 if (index >= list.size()) { 560 MS_LOG(EXCEPTION) << "reg dyn_input_sizes index error, must less than " << list.size() << "but got " << index; 561 } 562 return list[index]; 563 } 564 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::vector<int64_t>>)565 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, 566 const AnyTraits<std::vector<int64_t>>) { 567 MS_EXCEPTION_IF_NULL(value); 568 MS_LOG(INFO) << "Value: " << value->type_name(); 569 std::vector<int64_t> list; 570 if (value->isa<ValueSequence>()) { 571 auto vec = value->cast<ValueSequencePtr>(); 572 MS_EXCEPTION_IF_NULL(vec); 573 for (auto &it : vec->value()) { 574 list.emplace_back(GetCastIntegralValue<int64_t>(it)); 575 } 576 return list; 577 } 578 if (value->isa<Scalar>()) { 579 list.emplace_back(GetCastIntegralValue<int64_t>(value)); 580 return list; 581 } 582 if (value->isa<MeTensor>()) { 583 auto tensor_ptr = value->cast<MeTensorPtr>(); 584 MS_EXCEPTION_IF_NULL(tensor_ptr); 585 auto type = tensor_ptr->data_type(); 586 std::vector<int64_t> v; 587 if (type == kNumberTypeInt64) { 588 int64_t *data = static_cast<int64_t *>(tensor_ptr->data_c()); 589 auto size = tensor_ptr->Size() / sizeof(int64_t); 590 for (size_t i = 0; i < size; i++) { 591 (void)v.emplace_back(data[i]); 592 } 593 return v; 594 } 595 if (type == kNumberTypeInt32) { 596 int32_t *data = static_cast<int32_t *>(tensor_ptr->data_c()); 597 auto size = tensor_ptr->Size() / 
sizeof(int32_t);
        for (size_t i = 0; i < size; i++) {
          (void)v.emplace_back(IntToLong(data[i]));
        }
        return v;
      }
    } else {
      return ops::GetValueWithCheck<std::vector<int64_t>>(value);
    }
    // Reached only when none of the branches above returned a value.
    MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name();
  }

  // Converts `value` to std::string; the work is delegated to ConvertAnyUtil.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>> anyTraitsVec,
                                const AnyTraits<std::string> anyTraitsStr) {
    return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr);
  }

  // Converts `value` to std::vector<float>; the work is delegated to ConvertAnyUtil.
  static std::vector<float> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<float>> anyTraitsVec,
                                       const AnyTraits<float> anyTraitsFlo) {
    return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo);
  }

  // Converts `value` to std::vector<int64_t>, forwarding `format` to ConvertAnyUtil.
  // NOTE(review): presumably `format` is a data-format string (e.g. NCHW) that affects the result -- confirm in
  // ConvertAnyUtil.
  static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &format,
                                         const AnyTraits<std::vector<int64_t>> anyTraitsVec,
                                         const AnyTraits<int64_t> anyTraitsInt) {
    return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt);
  }

  // convert value list for value tuple to vector
  template <typename P, typename Q>
  static std::vector<Q> ConvertAny(const ValuePtr &value, const AnyTraits<P> &anyTraitsP,
                                   const AnyTraits<std::vector<Q>> anyTraitsQ) {
    return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ);
  }

  // Reads `value` as a string name and looks it up in `enum_map_`.
  // Returns the mapped integer, or 0 when the name is not in the map.
  // NOTE(review): an unknown name is not reported; confirm that 0 is an acceptable default for every enum attr.
  static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<GeEnum>) {
    auto name = GetValue<std::string>(value);
    auto it = enum_map_.find(name);
    int v = 0;
    if (it != enum_map_.end()) {
      v = it->second;
    }
    return v;
  }

  // Converts `value` to a GeDataType via ConvertAnyUtil.
  static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits<GEType> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }

  // Same conversion as the GEType overload above, but the resulting GeDataType is cast to int64_t.
  static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<GEType> anyTraitsGE,
                            const AnyTraits<int64_t> anyTraitsInt) {
    return static_cast<int64_t>(ConvertAnyUtil(value, anyTraitsGE));
  }

  // Converts `value` to a list of GeDataType via ConvertAnyUtil.
  static std::vector<GeDataType> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<GEType>> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }

  // The overloads below each convert `value` to std::string via ConvertAnyUtil; the AnyTraits tag picks the mapping.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEDataFormat> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }

  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEPadMod> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }

  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEReduction> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }

  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<AscendQuantRoundMode> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
anyTraitsGE)671 static std::string ConvertAny(const ValuePtr &value, const AnyTraits<FASInputLayoutMode> anyTraitsGE) { 672 return ConvertAnyUtil(value, anyTraitsGE); 673 } 674 ConvertAny(const ValuePtr & value,const AnyTraits<FFNActivationMode> anyTraitsGE)675 static std::string ConvertAny(const ValuePtr &value, const AnyTraits<FFNActivationMode> anyTraitsGE) { 676 return ConvertAnyUtil(value, anyTraitsGE); 677 } 678 ConvertAny(const ValuePtr & value,const AnyTraits<ScatterReduceMode> anyTraitsGE)679 static std::string ConvertAny(const ValuePtr &value, const AnyTraits<ScatterReduceMode> anyTraitsGE) { 680 return ConvertAnyUtil(value, anyTraitsGE); 681 } 682 ConvertAny(const ValuePtr & value,const AnyTraits<GECoordinateTransformMode> anyTraitsGE)683 static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GECoordinateTransformMode> anyTraitsGE) { 684 return ConvertAnyUtil(value, anyTraitsGE); 685 } 686 ConvertAny(const ValuePtr & value,const AnyTraits<GEEnumToStr> enum_str,const std::vector<std::string> & enum_string)687 static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEEnumToStr> enum_str, 688 const std::vector<std::string> &enum_string) { 689 return ConvertAnyUtil(value, enum_str, enum_string); 690 } 691 692 // convert any value to tensor ConvertAny(const ValuePtr & value,const AnyTraits<ValueAny> anyTraitsValue)693 static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<ValueAny> anyTraitsValue) { 694 return ConvertAnyUtil(value, anyTraitsValue); 695 } 696 GetOutputSize(const TypePtr & type)697 size_t GetOutputSize(const TypePtr &type) const { 698 // NOTE: sparse tensor is subclass of tuple, the inheritance relationship is 699 // AbstractTuple 700 // +-- AbstractSparseTensor 701 // +--- AbstractCOOTensor composed of (indices, values, num_row, num_col) 702 // `--- AbstractCSRTensor composed of (index_ptr, indices, values, num_row, num_col) 703 constexpr size_t kCOOTensorOutputSize = 4; 704 constexpr size_t 
kCSRTensorOutputSize = 5; 705 if (!type->isa<Tuple>()) { 706 if (type->isa<COOTensorType>()) { 707 return kCOOTensorOutputSize; 708 } 709 if (type->isa<CSRTensorType>()) { 710 return kCSRTensorOutputSize; 711 } 712 return (type->isa<MonadType>() || type->isa<TypeNone>() || type->isa<TypeNull>()) ? 0 : 1; 713 } 714 size_t output_size = 0; 715 auto tuple_type = type->cast<std::shared_ptr<Tuple>>(); 716 MS_EXCEPTION_IF_NULL(tuple_type); 717 auto elements = tuple_type->elements(); 718 for (const auto &element : elements) { 719 if (element->isa<MonadType>() || element->isa<TypeNone>() || element->isa<TypeNull>()) { 720 continue; 721 } 722 output_size = output_size + GetOutputSize(element); 723 } 724 return output_size; 725 } 726 getOp()727 static OperatorPtr getOp() { 728 if (op_ == nullptr) { 729 if (!::ge::OperatorFactory::IsExistOp(op_type_)) { 730 MS_LOG(EXCEPTION) << "OperatorFactory is not exist, op type: " << op_type_; 731 } 732 auto op = ::ge::OperatorFactory::CreateOperator("", op_type_); 733 op_ = std::make_shared<Operator>(op); 734 } 735 return op_; 736 } 737 738 // func list used to get ge attr type 739 template <typename S> GetAttrType(const AnyTraits<S> &)740 static S GetAttrType(const AnyTraits<S> &) { 741 S ret{}; 742 return ret; 743 } 744 745 template <typename S> GetAttrType(const AnyTraits<std::vector<S>> &,size_t size,S default_val)746 static std::vector<S> GetAttrType(const AnyTraits<std::vector<S>> &, size_t size, S default_val) { 747 std::vector<S> ret{}; 748 return ret; 749 } 750 751 // specialization for reverse bool GetAttrType(const AnyTraits<bool> &,bool reverse)752 static bool GetAttrType(const AnyTraits<bool> &, bool reverse) { 753 bool ret = false; 754 return ret; 755 } 756 757 template <typename P, typename Q> GetAttrType(const AnyTraits<P> & traits_from,const AnyTraits<Q> & traits_to)758 static Q GetAttrType(const AnyTraits<P> &traits_from, const AnyTraits<Q> &traits_to) { 759 Q ret{}; 760 return ret; 761 } 762 763 // specialization for 
tensor GetAttrType(const AnyTraits<mindspore::tensor::Tensor> & traits)764 static GeTensor GetAttrType(const AnyTraits<mindspore::tensor::Tensor> &traits) { 765 GeTensor ret{}; 766 return ret; 767 } 768 769 // specialization for int GetAttrType(const AnyTraits<int64_t>)770 static int64_t GetAttrType(const AnyTraits<int64_t>) { 771 int64_t ret{1}; 772 return ret; 773 } 774 775 // specialization for float GetAttrType(const AnyTraits<float>)776 static float GetAttrType(const AnyTraits<float>) { 777 float ret{1.0}; 778 return ret; 779 } 780 GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>)781 static std::vector<std::vector<int64_t>> GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>) { 782 std::vector<std::vector<int64_t>> ret{}; 783 return ret; 784 } 785 GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>,const AnyTraits<std::vector<int64_t>>)786 static std::vector<int64_t> GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>, 787 const AnyTraits<std::vector<int64_t>>) { 788 std::vector<int64_t> ret{}; 789 return ret; 790 } 791 GetAttrType(const AnyTraits<std::vector<int64_t>>,size_t index)792 static int64_t GetAttrType(const AnyTraits<std::vector<int64_t>>, size_t index) { 793 int64_t ret{1}; 794 return ret; 795 } 796 GetAttrType(const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::vector<int64_t>>)797 static std::vector<int64_t> GetAttrType(const AnyTraits<std::vector<int64_t>>, 798 const AnyTraits<std::vector<int64_t>>) { 799 std::vector<int64_t> ret{}; 800 return ret; 801 } 802 GetAttrType(const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<std::string> anyTraitsStr)803 static std::string GetAttrType(const AnyTraits<std::vector<int64_t>> anyTraitsVec, 804 const AnyTraits<std::string> anyTraitsStr) { 805 std::string ret{}; 806 return ret; 807 } 808 GetAttrType(const AnyTraits<std::vector<float>> anyTraitsVec,const AnyTraits<float> anyTraitsFlo)809 static std::vector<float> GetAttrType(const 
AnyTraits<std::vector<float>> anyTraitsVec, 810 const AnyTraits<float> anyTraitsFlo) { 811 std::vector<float> ret{}; 812 return ret; 813 } 814 GetAttrType(const std::string & format,const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<int64_t> anyTraitsInt)815 static std::vector<int64_t> GetAttrType(const std::string &format, const AnyTraits<std::vector<int64_t>> anyTraitsVec, 816 const AnyTraits<int64_t> anyTraitsInt) { 817 std::vector<int64_t> ret{}; 818 return ret; 819 } 820 821 // convert value list for value tuple to vector 822 template <typename P, typename Q> GetAttrType(const AnyTraits<P> & anyTraitsP,const AnyTraits<std::vector<Q>> anyTraitsQ)823 static std::vector<Q> GetAttrType(const AnyTraits<P> &anyTraitsP, const AnyTraits<std::vector<Q>> anyTraitsQ) { 824 std::vector<Q> ret{}; 825 return ret; 826 } 827 GetAttrType(const AnyTraits<GeEnum>)828 static int64_t GetAttrType(const AnyTraits<GeEnum>) { 829 int64_t ret{1}; 830 return ret; 831 } 832 GetAttrType(const AnyTraits<GEType> anyTraitsGE)833 static GeDataType GetAttrType(const AnyTraits<GEType> anyTraitsGE) { 834 GeDataType ret{}; 835 return ret; 836 } 837 GetAttrType(const AnyTraits<GEType> anyTraitsGE,const AnyTraits<int64_t> anyTraitsInt)838 static int64_t GetAttrType(const AnyTraits<GEType> anyTraitsGE, const AnyTraits<int64_t> anyTraitsInt) { 839 int64_t ret{1}; 840 return ret; 841 } 842 GetAttrType(const AnyTraits<std::vector<GEType>> anyTraitsGE)843 static std::vector<GeDataType> GetAttrType(const AnyTraits<std::vector<GEType>> anyTraitsGE) { 844 std::vector<GeDataType> ret{}; 845 return ret; 846 } 847 GetAttrType(const AnyTraits<GEDataFormat> anyTraitsGE)848 static std::string GetAttrType(const AnyTraits<GEDataFormat> anyTraitsGE) { 849 std::string ret{}; 850 return ret; 851 } 852 GetAttrType(const AnyTraits<AscendQuantRoundMode> anyTraitsGE)853 static std::string GetAttrType(const AnyTraits<AscendQuantRoundMode> anyTraitsGE) { 854 std::string ret{}; 855 return ret; 856 } 857 
GetAttrType(const AnyTraits<FASInputLayoutMode> anyTraitsGE)858 static std::string GetAttrType(const AnyTraits<FASInputLayoutMode> anyTraitsGE) { 859 std::string ret{}; 860 return ret; 861 } 862 GetAttrType(const AnyTraits<FFNActivationMode> anyTraitsGE)863 static std::string GetAttrType(const AnyTraits<FFNActivationMode> anyTraitsGE) { 864 std::string ret{}; 865 return ret; 866 } 867 GetAttrType(const AnyTraits<ScatterReduceMode> anyTraitsGE)868 static std::string GetAttrType(const AnyTraits<ScatterReduceMode> anyTraitsGE) { 869 std::string ret{}; 870 return ret; 871 } 872 GetAttrType(const AnyTraits<GEPadMod> anyTraitsGE)873 static std::string GetAttrType(const AnyTraits<GEPadMod> anyTraitsGE) { 874 std::string ret{}; 875 return ret; 876 } 877 GetAttrType(const AnyTraits<GEReduction> anyTraitsGE)878 static std::string GetAttrType(const AnyTraits<GEReduction> anyTraitsGE) { 879 std::string ret{}; 880 return ret; 881 } 882 GetAttrType(const AnyTraits<GECoordinateTransformMode> anyTraitsGE)883 static std::string GetAttrType(const AnyTraits<GECoordinateTransformMode> anyTraitsGE) { 884 std::string ret{}; 885 return ret; 886 } 887 GetAttrType(const AnyTraits<GEEnumToStr> enum_str,const std::vector<std::string> & enum_string)888 static std::string GetAttrType(const AnyTraits<GEEnumToStr> enum_str, const std::vector<std::string> &enum_string) { 889 std::string ret{}; 890 return ret; 891 } 892 893 // convert any value to tensor GetAttrType(const AnyTraits<ValueAny> anyTraitsValue)894 static GeTensor GetAttrType(const AnyTraits<ValueAny> anyTraitsValue) { 895 GeTensor ret{}; 896 return ret; 897 } 898 899 static const mindspore::HashMap<int, InputDesc> input_map_; 900 static const mindspore::HashMap<int, DynInputDesc> dyn_input_map_; 901 // note: To keep the outputs in order, the 'output_map_' and 'cus_output_map_' must be std::map instead of Hashmap. 
902 static const std::map<int, OutputDesc> output_map_; 903 static const mindspore::HashMap<int, DynOutputDesc> dyn_output_map_; 904 static const mindspore::HashMap<int, SubGraphDesc> subgraph_map_; 905 static const mindspore::HashMap<int, DynSubGraphDesc> dyn_subgraph_map_; 906 static const mindspore::HashMap<std::string, AttrDesc> attr_map_; 907 static const mindspore::HashMap<std::string, int> enum_map_; 908 // convert input from anf graph to Attr in Operators 909 static const mindspore::HashMap<unsigned int, AttrDesc> input_attr_map_; 910 static const mindspore::HashMap<std::string, std::string> attr_input_map_; 911 static const bool dynamic_shape_support_; 912 static mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> cus_input_map_; 913 static mindspore::HashMap<std::string, std::map<int, std::string>> cus_output_map_; 914 static const char op_type_[]; 915 std::string op_type_obj_; 916 mindspore::HashMap<std::string, ValuePtr> extra_attr_; 917 mindspore::HashMap<std::string, int> name_counts_; 918 const std::shared_ptr<OpAdapterImpl> impl_; 919 // cache the Operator to avoid memory leak caused by 'std::make_shared<OpType>()' 920 inline static OperatorPtr op_ = nullptr; 921 }; // namespace transform 922 923 template <typename T> 924 const mindspore::HashMap<int, InputDesc> OpAdapter<T>::input_map_; 925 template <typename T> 926 const mindspore::HashMap<int, DynInputDesc> OpAdapter<T>::dyn_input_map_; 927 template <typename T> 928 const std::map<int, OutputDesc> OpAdapter<T>::output_map_; 929 template <typename T> 930 const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; 931 template <typename T> 932 const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_; 933 template <typename T> 934 const mindspore::HashMap<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_; 935 template <typename T> 936 const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_; 937 template <typename T> 938 const 
mindspore::HashMap<std::string, int> OpAdapter<T>::enum_map_; 939 template <typename T> 940 const mindspore::HashMap<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_; 941 template <typename T> 942 const mindspore::HashMap<std::string, std::string> OpAdapter<T>::attr_input_map_; 943 template <typename T> 944 mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> OpAdapter<T>::cus_input_map_; 945 template <typename T> 946 mindspore::HashMap<std::string, std::map<int, std::string>> OpAdapter<T>::cus_output_map_; 947 template <typename T> 948 const bool OpAdapter<T>::dynamic_shape_support_{true}; 949 template <typename T> 950 const char OpAdapter<T>::op_type_[]{""}; 951 952 // specialization for method 953 } // namespace transform 954 } // namespace mindspore 955 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ 956