• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_
19 
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 #include <string>
24 #include <map>
25 #include "utils/hash_map.h"
26 #include "transform/graph_ir/op_adapter_util.h"
27 #include "transform/graph_ir/op_adapter_base.h"
28 #include "include/common/utils/utils.h"
29 #include "include/common/utils/anfalgo.h"
30 #include "ops/other_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/framework_ops.h"
33 #include "ops/op_utils.h"
34 namespace mindspore {
35 namespace transform {
36 class OpAdapterImpl {
37  public:
OpAdapterImpl(const mindspore::HashMap<int,InputDesc> & input_map,const mindspore::HashMap<int,DynInputDesc> & dyn_input_map,const std::map<int,OutputDesc> & output_map,const mindspore::HashMap<int,DynOutputDesc> & dyn_output_map,const mindspore::HashMap<int,SubGraphDesc> & subgraph_map,const mindspore::HashMap<int,DynSubGraphDesc> & dyn_subgraph_map,const mindspore::HashMap<std::string,AttrDesc> & attr_map,const mindspore::HashMap<std::string,int> & enum_map,const mindspore::HashMap<unsigned int,AttrDesc> & input_attr_map,const mindspore::HashMap<std::string,std::string> & attr_input_map,mindspore::HashMap<std::string,mindspore::HashMap<int,std::string>> * cus_input_map,mindspore::HashMap<std::string,std::map<int,std::string>> * cus_output_map,mindspore::HashMap<std::string,ValuePtr> * extra_attr,mindspore::HashMap<std::string,int> * name_counts,BaseOpAdapter * adpt)38   OpAdapterImpl(const mindspore::HashMap<int, InputDesc> &input_map,
39                 const mindspore::HashMap<int, DynInputDesc> &dyn_input_map, const std::map<int, OutputDesc> &output_map,
40                 const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map,
41                 const mindspore::HashMap<int, SubGraphDesc> &subgraph_map,
42                 const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map,
43                 const mindspore::HashMap<std::string, AttrDesc> &attr_map,
44                 const mindspore::HashMap<std::string, int> &enum_map,
45                 const mindspore::HashMap<unsigned int, AttrDesc> &input_attr_map,
46                 const mindspore::HashMap<std::string, std::string> &attr_input_map,
47                 mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> *cus_input_map,
48                 mindspore::HashMap<std::string, std::map<int, std::string>> *cus_output_map,
49                 mindspore::HashMap<std::string, ValuePtr> *extra_attr,
50                 mindspore::HashMap<std::string, int> *name_counts, BaseOpAdapter *adpt)
51       : input_map_(input_map),
52         dyn_input_map_(dyn_input_map),
53         output_map_(output_map),
54         dyn_output_map_(dyn_output_map),
55         subgraph_map_(subgraph_map),
56         dyn_subgraph_map_(dyn_subgraph_map),
57         attr_map_(attr_map),
58         enum_map_(enum_map),
59         input_attr_map_(input_attr_map),
60         attr_input_map_(attr_input_map),
61         cus_input_map_(cus_input_map),
62         cus_output_map_(cus_output_map),
63         extra_attr_(extra_attr),
64         name_counts_(name_counts),
65         adpt_(adpt) {
66     MS_EXCEPTION_IF_NULL(cus_input_map_);
67     MS_EXCEPTION_IF_NULL(cus_output_map_);
68     MS_EXCEPTION_IF_NULL(extra_attr_);
69     MS_EXCEPTION_IF_NULL(name_counts_);
70     MS_EXCEPTION_IF_NULL(adpt_);
71   }
~OpAdapterImpl()72   ~OpAdapterImpl() {}
73   bool IsCustomOp(const OperatorPtr &op) const;
74   std::string GetCustomOpType(const PrimitivePtr &prim) const;
75   Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
76   Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
77   OperatorPtr GenerateCustomOp(const AnfNodePtr anf);
78   Status SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs);
79   Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches);
80   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) const;
81   Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input);
82   int setInput(const OperatorPtr &op, int index, const OperatorPtr &input);
83   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) const;
84   Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle);
85   int setInput(const OperatorPtr &op, int index, const OutHandler &handle);
86   int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec,
87                bool use_create_byindex_func = false, size_t dyn_index = 0);
88   OutHandler getOutput(const OperatorPtr &op, int index);
89   std::vector<OutHandler> getOutputs(const OperatorPtr &op) const;
90   OutHandler getCustomOutput(const OperatorPtr &op, int index) const;
91   OutHandler getNormalOutput(const OperatorPtr &op, int index);
92   std::vector<OutHandler> getNormalOutputs(const OperatorPtr &op) const;
93   std::vector<OutHandler> getCustomOutputs(const OperatorPtr &op) const;
94   Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
95                                 const std::string &format);
96   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) const;
97   std::map<std::string, ValuePtr> GetNormalOpAttrList(const OperatorPtr &op, const AnfNodePtr &node) const;
98   std::map<std::string, ValuePtr> GetOpAttrList(const OperatorPtr &op) const;
99   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
100                                                  const std::string &format) const;
101   Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
102                                const std::string &format);
103   std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) const;
104   void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format);
105   void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) const;
106   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
107   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
108                         const AnfNodePtr &node);
109   int setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value);
110   int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) const;
111   int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim);
112   int SetNoFoldingOpAttr(const OperatorPtr &op, const PrimitivePtr &prim);
113   int setAttr(const OperatorPtr &op, const PrimitivePtr &prim);
114   int setAttr(const OperatorPtr &op, const AnfNodePtr &node);
115   int setAttr(const OperatorPtr &op, const uint32_t &input_idx, const ValuePtr &attr_value);
116   int getAttr(const OperatorPtr &op, const std::string &attr_key, ValuePtr *attr_value);
117   int getAttr(const OperatorPtr &op, uint32_t input_idx, ValuePtr *attr_value);
118 
119  private:
120   const mindspore::HashMap<int, InputDesc> &input_map_;
121   const mindspore::HashMap<int, DynInputDesc> &dyn_input_map_;
122   const std::map<int, OutputDesc> &output_map_;
123   const mindspore::HashMap<int, DynOutputDesc> &dyn_output_map_;
124   const mindspore::HashMap<int, SubGraphDesc> &subgraph_map_;
125   const mindspore::HashMap<int, DynSubGraphDesc> &dyn_subgraph_map_;
126   const mindspore::HashMap<std::string, AttrDesc> &attr_map_;
127   const mindspore::HashMap<std::string, int> &enum_map_;
128   // NOTE: The key of input_attr_map_ is anf node index, so index 0 is primitive value node
129   const mindspore::HashMap<unsigned int, AttrDesc> &input_attr_map_;
130   const mindspore::HashMap<std::string, std::string> &attr_input_map_;
131   mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> *const cus_input_map_;
132   mindspore::HashMap<std::string, std::map<int, std::string>> *const cus_output_map_;
133   mindspore::HashMap<std::string, ValuePtr> *const extra_attr_;
134   mindspore::HashMap<std::string, int> *const name_counts_;
135   BaseOpAdapter *const adpt_;
136 };
137 
138 template <typename T>
139 class OpAdapter : public BaseOpAdapter {
140  public:
141   using OpType = T;
OpAdapter(std::string op_type_obj)142   explicit OpAdapter(std::string op_type_obj)
143       : op_type_obj_(std::move(op_type_obj)),
144         impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_,
145                                               dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, attr_input_map_,
146                                               &cus_input_map_, &cus_output_map_, &extra_attr_, &name_counts_, this)) {
147     MS_EXCEPTION_IF_NULL(impl_);
148   }
OpAdapter(std::string op_type_obj,ExtraAttr extra_attr)149   explicit OpAdapter(std::string op_type_obj, ExtraAttr extra_attr)
150       : op_type_obj_(std::move(op_type_obj)),
151         extra_attr_(std::move(extra_attr)),
152         impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, subgraph_map_,
153                                               dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, attr_input_map_,
154                                               &cus_input_map_, &cus_output_map_, &extra_attr_, &name_counts_, this)) {
155     MS_EXCEPTION_IF_NULL(impl_);
156   }
157   ~OpAdapter() override = default;
158 
IsCustomOp(const OperatorPtr & op)159   bool IsCustomOp(const OperatorPtr &op) { return impl_->IsCustomOp(op); }
160 
GenerateCustomOpInputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)161   Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
162     return impl_->GenerateCustomOpInputMap(op, prim);
163   }
164 
GenerateCustomOpOutputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)165   Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
166     return impl_->GenerateCustomOpOutputMap(op, prim);
167   }
168 
169   // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs.
GenerateCustomOp(const AnfNodePtr anf)170   OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { return impl_->GenerateCustomOp(anf); }
171 
GenerateNormalOp(const AnfNodePtr & anf)172   OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) const {
173     OperatorPtr op = nullptr;
174     std::string op_name;
175     if (anf != nullptr && !anf->fullname_with_scope().empty()) {
176       op_name = anf->fullname_with_scope();
177     }
178     op = generate(op_name);
179 
180     // set dynamic output num if op use DYNAMIC_OUTPUT
181     if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) {
182       TypePtr type = anf->Type();
183       if (type == nullptr) {
184         MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!";
185       }
186       auto num = GetOutputSize(type);
187 
188       auto judge_node = anf;
189       if (common::AnfAlgo::CheckPrimitiveType(anf, prim::kPrimReturn)) {
190         auto cnode = anf->cast<CNodePtr>();
191         MS_EXCEPTION_IF_NULL(cnode);
192         auto input_node = cnode->inputs()[1];
193         if (common::AnfAlgo::CheckPrimitiveType(judge_node, prim::kPrimMakeTuple)) {
194           judge_node = input_node;
195         }
196         auto judge_cnode = judge_node->cast<CNodePtr>();
197         if (judge_cnode != nullptr) {
198           auto inputs = judge_cnode->inputs();
199           for (const auto &input : inputs) {
200             if (common::AnfAlgo::IsNoOuputNode(input)) {
201               --num;
202             }
203           }
204         }
205       }
206 
207       if (common::AnfAlgo::CheckPrimitiveType(judge_node, prim::kPrimMakeTuple)) {
208         auto cnode = judge_node->cast<CNodePtr>();
209         if (cnode != nullptr) {
210           auto inputs = cnode->inputs();
211           for (const auto &input : inputs) {
212             if (common::AnfAlgo::IsNoOuputNode(input)) {
213               --num;
214             }
215           }
216         }
217       }
218 
219       MS_LOG(INFO) << "create_dyn_output for node:" << anf->fullname_with_scope() << ", type:" << type->ToString()
220                    << ", num:" << num;
221       dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(num));
222     }
223     return op;
224   }
225 
GenerateDynamicOutputOp(const AnfNodePtr & anf)226   OperatorPtr GenerateDynamicOutputOp(const AnfNodePtr &anf) const {
227     OperatorPtr op = nullptr;
228     std::string op_name;
229     if (anf != nullptr && !anf->fullname_with_scope().empty()) {
230       op_name = anf->fullname_with_scope();
231     }
232     op = generate(op_name);
233     return op;
234   }
235 
setDynamicOutputNum(const OperatorPtr & op,size_t dyn_output_size)236   void setDynamicOutputNum(const OperatorPtr &op, size_t dyn_output_size) override {
237     // set dynamic output num if op use DYNAMIC_OUTPUT
238     if ((op != nullptr) && (!dyn_output_map_.empty())) {
239       MS_LOG(DEBUG) << "create_dyn_output for node:" << op->GetName() << ", num:" << dyn_output_size;
240       dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(dyn_output_size));
241     }
242   }
243 
generate(const AnfNodePtr & anf)244   OperatorPtr generate(const AnfNodePtr &anf) override {
245     OperatorPtr op = nullptr;
246     if (IsCustomCNode(anf)) {
247       op = GenerateCustomOp(anf);
248     } else {
249       op = GenerateNormalOp(anf);
250     }
251     if (op == nullptr) {
252       MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope();
253     }
254     return op;
255   }
256 
generate(const std::string & op_name)257   OperatorPtr generate(const std::string &op_name) const override {
258     std::string op_name_fix = op_name;
259     if (op_name_fix.empty()) {
260       // There are duplicate names in ANF graph, do not assign ANF node name to GE
261       // GE will generate unique name automatically
262       static int64_t idx = 0;
263       op_name_fix = op_type_obj_ + "_NULL_" + std::to_string(idx++);
264     }
265     if (!::ge::OperatorFactory::IsExistOp(op_type_obj_.c_str())) {
266       MS_LOG(ERROR) << "OperatorFactory is not exist, op type: " << op_type_obj_;
267       return std::make_shared<Operator>(Operator(op_name_fix, op_type_obj_));
268     }
269     auto op = ::ge::OperatorFactory::CreateOperator(op_name_fix, op_type_obj_);
270     return std::make_shared<Operator>(op);
271   }
272 
generateDynOutputOp(const AnfNodePtr & anf)273   OperatorPtr generateDynOutputOp(const AnfNodePtr &anf) override {
274     OperatorPtr op = nullptr;
275     op = GenerateDynamicOutputOp(anf);
276     if (op == nullptr) {
277       MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope();
278     }
279     return op;
280   }
281 
getOpType()282   std::string getOpType() override { return op_type_obj_; }
283 
GetStaticOpType()284   static std::string GetStaticOpType() { return op_type_; }
285 
getInputMap()286   const mindspore::HashMap<int, InputDesc> &getInputMap() override { return input_map_; }
getInputAttrMap()287   const mindspore::HashMap<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
getAttrMap()288   const mindspore::HashMap<std::string, AttrDesc> &getAttrMap() override { return attr_map_; }
getAttrInputMap()289   const mindspore::HashMap<std::string, std::string> &getAttrInputMap() override { return attr_input_map_; }
getDynInputMap()290   const mindspore::HashMap<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
getSubgraphMap()291   const mindspore::HashMap<int, SubGraphDesc> &getSubgraphMap() override { return subgraph_map_; }
getOutputMap()292   const std::map<int, OutputDesc> &getOutputMap() override { return output_map_; }
getDynOutputMap()293   const mindspore::HashMap<int, DynOutputDesc> &getDynOutputMap() override { return dyn_output_map_; }
getDynSubgraphMap()294   const mindspore::HashMap<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }
GetNormalOpAttrList(const AnfNodePtr & node)295   std::map<std::string, ValuePtr> GetNormalOpAttrList(const AnfNodePtr &node) override {
296     return impl_->GetNormalOpAttrList(getOp(), node);
297   }
GetOpAttrList()298   std::map<std::string, ValuePtr> GetOpAttrList() override { return impl_->GetOpAttrList(getOp()); }
IsDynInputOp(uint64_t index)299   bool IsDynInputOp(uint64_t index) override { return dyn_input_map_.find(index) != dyn_input_map_.end(); }
IsDyOutputOp(uint64_t index)300   bool IsDyOutputOp(uint64_t index) override { return dyn_output_map_.find(index) != dyn_output_map_.end(); }
IsMultipleOutputOp(const AnfNodePtr & anf)301   bool IsMultipleOutputOp(const AnfNodePtr &anf) override {
302     if (IsCustomCNode(anf)) {
303       // Custom op
304       auto node = anf->cast<CNodePtr>();
305       MS_EXCEPTION_IF_NULL(node);
306       auto prim = GetValueNode<PrimitivePtr>(node->inputs().at(0));
307       MS_EXCEPTION_IF_NULL(prim);
308       auto op_type = impl_->GetCustomOpType(prim);
309       if (cus_output_map_.find(op_type) != cus_output_map_.end()) {
310         return cus_output_map_[op_type].size() > 1;
311       }
312       return false;
313     }
314     // Normal op
315     return output_map_.size() > 1;
316   }
317 
SetOpSubgraphFunc(const OperatorPtr & op,std::shared_ptr<std::vector<DfGraph>> subgraphs)318   Status SetOpSubgraphFunc(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) {
319     return impl_->SetOpSubgraphFunc(op, subgraphs);
320   }
321 
setSubgraph(const OperatorPtr & op,std::shared_ptr<std::vector<DfGraph>> subgraphs)322   void setSubgraph(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) override {
323     (void)SetOpSubgraphFunc(op, subgraphs);
324   }
325 
SetOpSubgraphFunc(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<DfGraph>> & branches)326   Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) {
327     return impl_->SetOpSubgraphFunc(op, index, branches);
328   }
329 
setSubgraph(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<DfGraph>> & branches)330   void setSubgraph(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches) override {
331     (void)SetOpSubgraphFunc(op, index, branches);
332   }
333 
SetCustomOpInput(const CusOperatorPtr & op,int index,const OperatorPtr & input)334   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
335     return impl_->SetCustomOpInput(op, index, input);
336   }
337 
SetNormalOpInput(const OperatorPtr & op,int index,const OperatorPtr & input)338   Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) {
339     return impl_->SetNormalOpInput(op, index, input);
340   }
341 
setInput(const OperatorPtr & op,int index,const OperatorPtr & input)342   int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override {
343     return impl_->setInput(op, index, input);
344   }
345 
SetCustomOpInput(const CusOperatorPtr & op,int index,const OutHandler & handle)346   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) {
347     return impl_->SetCustomOpInput(op, index, handle);
348   }
349 
SetNormalOpInput(const OperatorPtr & op,int index,const OutHandler & handle)350   Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) {
351     return impl_->SetNormalOpInput(op, index, handle);
352   }
353 
setInput(const OperatorPtr & op,int index,const OutHandler & handle)354   int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override {
355     return impl_->setInput(op, index, handle);
356   }
357 
358   int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec,
359                bool use_create_byindex_func = false, size_t dyn_index = 0) override {
360     return impl_->setInput(op, index, handler_vec, use_create_byindex_func, dyn_index);
361   }
362 
getOutput(const OperatorPtr & op,int index)363   OutHandler getOutput(const OperatorPtr &op, int index) override { return impl_->getOutput(op, index); }
364 
getOutputs(const OperatorPtr & op)365   std::vector<OutHandler> getOutputs(const OperatorPtr &op) override { return impl_->getOutputs(op); }
366 
getCustomOutput(const OperatorPtr & op,int index)367   OutHandler getCustomOutput(const OperatorPtr &op, int index) { return impl_->getCustomOutput(op, index); }
368 
getNormalOutput(const OperatorPtr & op,int index)369   OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); }
370 
UpdateSingleOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)371   Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
372                                 const std::string &format) {
373     return impl_->UpdateSingleOutputDesc(op, shp, type, format);
374   }
375 
GetCustomOpOutputSize(const CusOperatorPtr & cus_op)376   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); }
377 
CreateOutputDesc(const abstract::ShapePtr & shape_ptr,const TypePtr & type,const std::string & format)378   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
379                                                  const std::string &format) {
380     return impl_->CreateOutputDesc(shape_ptr, type, format);
381   }
382 
UpdateMultiOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)383   Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
384                                const std::string &format) {
385     return impl_->UpdateMultiOutputDesc(op, shp, type, format);
386   }
387 
CreateNodeDesc(const AnfNodePtr & node,const std::string & format)388   std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
389     return impl_->CreateNodeDesc(node, format);
390   }
391 
UpdateNormalOpInputDesc(const OperatorPtr & op,const AnfNodePtr node,const std::string format)392   void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) {
393     return impl_->UpdateNormalOpInputDesc(op, node, format);
394   }
395 
UpdateCustomOpInputDesc(const CusOperatorPtr & op,const AnfNodePtr & node,const std::string format)396   void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) {
397     return impl_->UpdateCustomOpInputDesc(op, node, format);
398   }
399 
updateInputDesc(const OperatorPtr & op,const AnfNodePtr & node)400   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); }
401 
updateOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const AnfNodePtr & node)402   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
403                         const AnfNodePtr &node) override {
404     impl_->updateOutputDesc(op, shp, type, node);
405   }
406 
setAttr(const OperatorPtr & op,const std::string & attrKey,const ValuePtr & attrValue)407   int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override {
408     return impl_->setAttr(op, attrKey, attrValue);
409   }
410 
SetCustomOpAttr(const CusOperatorPtr & op,const PrimitivePtr & prim)411   int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetCustomOpAttr(op, prim); }
412 
SetNormalOpAttr(const OperatorPtr & op,const PrimitivePtr & prim)413   int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetNormalOpAttr(op, prim); }
414 
setAttr(const OperatorPtr & op,const PrimitivePtr & prim)415   int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { return impl_->setAttr(op, prim); }
416 
setAttr(const OperatorPtr & op,const AnfNodePtr & node)417   int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { return impl_->setAttr(op, node); }
418 
setAttr(const std::string & attr_key,const ValuePtr & attr_value)419   int setAttr(const std::string &attr_key, const ValuePtr &attr_value) override {
420     return impl_->setAttr(getOp(), attr_key, attr_value);
421   }
422 
setAttr(const uint32_t & input_idx,const ValuePtr & attr_value)423   int setAttr(const uint32_t &input_idx, const ValuePtr &attr_value) override {
424     return impl_->setAttr(getOp(), input_idx, attr_value);
425   }
426 
getAttr(const std::string & attr_key,ValuePtr * attr_value)427   int getAttr(const std::string &attr_key, ValuePtr *attr_value) override {
428     MS_EXCEPTION_IF_NULL(attr_value);
429     return impl_->getAttr(getOp(), attr_key, attr_value);
430   }
getAttr(const uint32_t & input_idx,ValuePtr * attr_value)431   int getAttr(const uint32_t &input_idx, ValuePtr *attr_value) override {
432     MS_EXCEPTION_IF_NULL(attr_value);
433     return impl_->getAttr(getOp(), input_idx, attr_value);
434   }
GetExtraAttr()435   mindspore::HashMap<std::string, ValuePtr> GetExtraAttr() override { return extra_attr_; }
GetDynamicShapeSupport()436   bool GetDynamicShapeSupport() override { return dynamic_shape_support_; }
437 
438  private:
439   template <typename S>
ConvertAny(const ValuePtr & value,const AnyTraits<S> &)440   static S ConvertAny(const ValuePtr &value, const AnyTraits<S> &) {
441     return ops::GetValueWithCheck<S>(value);
442   }
443 
444   template <typename S>
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<S>> &,size_t size,S default_val)445   static std::vector<S> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<S>> &, size_t size,
446                                    S default_val) {
447     auto v = ops::GetValueWithCheck<std::vector<S>>(value);
448     if (v.size() < size) {
449       v.insert(v.begin(), size - v.size(), default_val);
450     }
451     return v;
452   }
453 
454   // specialization for reverse bool
ConvertAny(const ValuePtr & value,const AnyTraits<bool> &,bool reverse)455   static bool ConvertAny(const ValuePtr &value, const AnyTraits<bool> &, bool reverse) {
456     return reverse != ops::GetValueWithCheck<bool>(value);
457   }
458 
459   template <typename P, typename Q>
ConvertAny(const ValuePtr & value,const AnyTraits<P> & traits_from,const AnyTraits<Q> & traits_to)460   static Q ConvertAny(const ValuePtr &value, const AnyTraits<P> &traits_from, const AnyTraits<Q> &traits_to) {
461     return ConvertAnyUtil(value, traits_from, traits_to);
462   }
463 
464   // specialization for tensor
ConvertAny(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> & traits)465   static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &traits) {
466     // To-DO the format may read from ME tensor
467     return ConvertAnyUtil(value, traits);
468   }
469 
470   // specialization for int
ConvertAny(const ValuePtr & value,const AnyTraits<int64_t>)471   static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<int64_t>) {
472     return ops::GetValueWithCheck<int64_t>(value);
473   }
474 
475   // specialization for float
ConvertAny(const ValuePtr & value,const AnyTraits<float>)476   static float ConvertAny(const ValuePtr &value, const AnyTraits<float>) { return GetCastFloatValue<float>(value); }
477 
478   // specialization for int or tuple broadcast to Vector
ConvertAny(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>> anyTraitsInt)479   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name,
480                                          const AnyTraits<std::vector<int64_t>> anyTraitsInt) {
481     return ConvertAnyUtil(value, name, anyTraitsInt);
482   }
483 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>)484   static std::vector<std::vector<int64_t>> ConvertAny(const ValuePtr &value,
485                                                       const AnyTraits<std::vector<std::vector<int64_t>>>) {
486     MS_EXCEPTION_IF_NULL(value);
487     MS_LOG(INFO) << "Value: " << value->type_name();
488     std::vector<std::vector<int64_t>> list;
489 
490     ValuePtrList valuelists;
491     if (value->isa<ValueTuple>()) {
492       auto vec = value->cast<ValueTuplePtr>();
493       MS_EXCEPTION_IF_NULL(vec);
494       valuelists = vec->value();
495     } else if (value->isa<ValueList>()) {
496       auto vec = value->cast<ValueListPtr>();
497       MS_EXCEPTION_IF_NULL(vec);
498       valuelists = vec->value();
499     } else {
500       MS_LOG(EXCEPTION) << "Value should be ValueTuple or ValueList, but got " << value->type_name();
501     }
502 
503     for (auto &it : valuelists) {
504       MS_EXCEPTION_IF_NULL(it);
505       std::vector<int64_t> sublist;
506       if (!it->isa<ValueTuple>()) {
507         if (it->type_name() != "ValueList") {
508           MS_LOG(EXCEPTION) << "It should be ValueTuple or ValueList, but got " << it->type_name();
509         }
510         auto sub_vector = it->cast<ValueListPtr>();
511         for (auto &item : sub_vector->value()) {
512           sublist.emplace_back(ops::GetValueWithCheck<int64_t>(item));
513         }
514       } else {
515         auto sub_vector = it->cast<ValueTuplePtr>();
516         for (auto &item : sub_vector->value()) {
517           sublist.emplace_back(ops::GetValueWithCheck<int64_t>(item));
518         }
519       }
520       list.emplace_back(sublist);
521     }
522     return list;
523   }
524 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>,const AnyTraits<std::vector<int64_t>>)525   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<std::vector<int64_t>>>,
526                                          const AnyTraits<std::vector<int64_t>>) {
527     MS_EXCEPTION_IF_NULL(value);
528     MS_LOG(DEBUG) << "Value: " << value->type_name();
529     if (!value->isa<ValueSequence>()) {
530       MS_LOG(EXCEPTION) << "Value should be ValueSequence, but got " << value->type_name();
531     }
532     auto vec = value->cast<ValueSequencePtr>();
533     std::vector<int64_t> list;
534     for (auto &it : vec->value()) {
535       MS_EXCEPTION_IF_NULL(it);
536       if (!it->isa<ValueSequence>()) {
537         MS_LOG(EXCEPTION) << "It should be ValueSequence, but got " << it->type_name();
538       }
539       auto sub_vector = it->cast<ValueSequencePtr>();
540       for (auto &item : sub_vector->value()) {
541         list.emplace_back(ops::GetValueWithCheck<int64_t>(item));
542       }
543     }
544     return list;
545   }
546 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,size_t index)547   static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, size_t index) {
548     MS_EXCEPTION_IF_NULL(value);
549     MS_LOG(DEBUG) << "Value: " << value->type_name();
550     if (!value->isa<ValueSequence>()) {
551       MS_LOG(EXCEPTION) << "Value should be ValueSequence, but got " << value->type_name();
552     }
553     std::vector<int64_t> list;
554     auto vec = value->cast<ValueSequencePtr>();
555     MS_EXCEPTION_IF_NULL(vec);
556     for (auto &it : vec->value()) {
557       list.emplace_back(GetCastIntegralValue<int64_t>(it));
558     }
559     if (index >= list.size()) {
560       MS_LOG(EXCEPTION) << "reg dyn_input_sizes index error, must less than " << list.size() << "but got " << index;
561     }
562     return list[index];
563   }
564 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::vector<int64_t>>)565   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>,
566                                          const AnyTraits<std::vector<int64_t>>) {
567     MS_EXCEPTION_IF_NULL(value);
568     MS_LOG(INFO) << "Value: " << value->type_name();
569     std::vector<int64_t> list;
570     if (value->isa<ValueSequence>()) {
571       auto vec = value->cast<ValueSequencePtr>();
572       MS_EXCEPTION_IF_NULL(vec);
573       for (auto &it : vec->value()) {
574         list.emplace_back(GetCastIntegralValue<int64_t>(it));
575       }
576       return list;
577     }
578     if (value->isa<Scalar>()) {
579       list.emplace_back(GetCastIntegralValue<int64_t>(value));
580       return list;
581     }
582     if (value->isa<MeTensor>()) {
583       auto tensor_ptr = value->cast<MeTensorPtr>();
584       MS_EXCEPTION_IF_NULL(tensor_ptr);
585       auto type = tensor_ptr->data_type();
586       std::vector<int64_t> v;
587       if (type == kNumberTypeInt64) {
588         int64_t *data = static_cast<int64_t *>(tensor_ptr->data_c());
589         auto size = tensor_ptr->Size() / sizeof(int64_t);
590         for (size_t i = 0; i < size; i++) {
591           (void)v.emplace_back(data[i]);
592         }
593         return v;
594       }
595       if (type == kNumberTypeInt32) {
596         int32_t *data = static_cast<int32_t *>(tensor_ptr->data_c());
597         auto size = tensor_ptr->Size() / sizeof(int32_t);
598         for (size_t i = 0; i < size; i++) {
599           (void)v.emplace_back(IntToLong(data[i]));
600         }
601         return v;
602       }
603     } else {
604       return ops::GetValueWithCheck<std::vector<int64_t>>(value);
605     }
606     MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name();
607   }
608 
  // Convert `value` to a std::string using the (vector<int64_t> -> string) traits pair.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>> anyTraitsVec,
                                const AnyTraits<std::string> anyTraitsStr) {
    return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr);
  }
613 
  // Convert `value` to a std::vector<float> using the (vector<float>, float) traits pair.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::vector<float> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<float>> anyTraitsVec,
                                       const AnyTraits<float> anyTraitsFlo) {
    return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo);
  }
618 
  // Convert `value` to a std::vector<int64_t>, passing `format` through to ConvertAnyUtil.
  // Pure forwarder: how `format` influences the result is defined by the matching ConvertAnyUtil overload.
  static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &format,
                                         const AnyTraits<std::vector<int64_t>> anyTraitsVec,
                                         const AnyTraits<int64_t> anyTraitsInt) {
    return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt);
  }
624 
  // Convert a value list or value tuple to a std::vector<Q>.
  // Generic forwarder for any (AnyTraits<P>, AnyTraits<std::vector<Q>>) pair; the conversion itself is implemented
  // by the matching ConvertAnyUtil overload.
  template <typename P, typename Q>
  static std::vector<Q> ConvertAny(const ValuePtr &value, const AnyTraits<P> &anyTraitsP,
                                   const AnyTraits<std::vector<Q>> anyTraitsQ) {
    return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ);
  }
631 
ConvertAny(const ValuePtr & value,const AnyTraits<GeEnum>)632   static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<GeEnum>) {
633     auto name = GetValue<std::string>(value);
634     auto it = enum_map_.find(name);
635     int v = 0;
636     if (it != enum_map_.end()) {
637       v = it->second;
638     }
639     return v;
640   }
641 
  // Convert `value` to a GeDataType using the GEType traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits<GEType> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
645 
  // Convert `value` with the GEType overload of ConvertAnyUtil and return the result cast to int64_t.
  // `anyTraitsInt` only selects this overload; it is not used in the body.
  static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<GEType> anyTraitsGE,
                            const AnyTraits<int64_t> anyTraitsInt) {
    return static_cast<int64_t>(ConvertAnyUtil(value, anyTraitsGE));
  }
650 
  // Convert `value` to a std::vector<GeDataType> using the vector<GEType> traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::vector<GeDataType> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<GEType>> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
654 
  // Convert `value` to a std::string using the GEDataFormat traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEDataFormat> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
658 
  // Convert `value` to a std::string using the GEPadMod traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEPadMod> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
662 
  // Convert `value` to a std::string using the GEReduction traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEReduction> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
666 
  // Convert `value` to a std::string using the AscendQuantRoundMode traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<AscendQuantRoundMode> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
670 
  // Convert `value` to a std::string using the FASInputLayoutMode traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<FASInputLayoutMode> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
674 
  // Convert `value` to a std::string using the FFNActivationMode traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<FFNActivationMode> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
678 
  // Convert `value` to a std::string using the ScatterReduceMode traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<ScatterReduceMode> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
682 
  // Convert `value` to a std::string using the GECoordinateTransformMode traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GECoordinateTransformMode> anyTraitsGE) {
    return ConvertAnyUtil(value, anyTraitsGE);
  }
686 
  // Convert `value` to a std::string using the GEEnumToStr traits and the caller-supplied `enum_string` table.
  // Pure forwarder: how `enum_string` is consulted is defined by the matching ConvertAnyUtil overload.
  static std::string ConvertAny(const ValuePtr &value, const AnyTraits<GEEnumToStr> enum_str,
                                const std::vector<std::string> &enum_string) {
    return ConvertAnyUtil(value, enum_str, enum_string);
  }
691 
  // Convert any value to a GeTensor using the ValueAny traits.
  // Pure forwarder: the conversion itself is implemented by the matching ConvertAnyUtil overload.
  static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<ValueAny> anyTraitsValue) {
    return ConvertAnyUtil(value, anyTraitsValue);
  }
696 
GetOutputSize(const TypePtr & type)697   size_t GetOutputSize(const TypePtr &type) const {
698     // NOTE: sparse tensor is subclass of tuple, the inheritance relationship is
699     //  AbstractTuple
700     //  +-- AbstractSparseTensor
701     //      +--- AbstractCOOTensor    composed of (indices, values, num_row, num_col)
702     //      `--- AbstractCSRTensor    composed of (index_ptr, indices, values, num_row, num_col)
703     constexpr size_t kCOOTensorOutputSize = 4;
704     constexpr size_t kCSRTensorOutputSize = 5;
705     if (!type->isa<Tuple>()) {
706       if (type->isa<COOTensorType>()) {
707         return kCOOTensorOutputSize;
708       }
709       if (type->isa<CSRTensorType>()) {
710         return kCSRTensorOutputSize;
711       }
712       return (type->isa<MonadType>() || type->isa<TypeNone>() || type->isa<TypeNull>()) ? 0 : 1;
713     }
714     size_t output_size = 0;
715     auto tuple_type = type->cast<std::shared_ptr<Tuple>>();
716     MS_EXCEPTION_IF_NULL(tuple_type);
717     auto elements = tuple_type->elements();
718     for (const auto &element : elements) {
719       if (element->isa<MonadType>() || element->isa<TypeNone>() || element->isa<TypeNull>()) {
720         continue;
721       }
722       output_size = output_size + GetOutputSize(element);
723     }
724     return output_size;
725   }
726 
getOp()727   static OperatorPtr getOp() {
728     if (op_ == nullptr) {
729       if (!::ge::OperatorFactory::IsExistOp(op_type_)) {
730         MS_LOG(EXCEPTION) << "OperatorFactory is not exist, op type: " << op_type_;
731       }
732       auto op = ::ge::OperatorFactory::CreateOperator("", op_type_);
733       op_ = std::make_shared<Operator>(op);
734     }
735     return op_;
736   }
737 
738   // func list used to get ge attr type
739   template <typename S>
GetAttrType(const AnyTraits<S> &)740   static S GetAttrType(const AnyTraits<S> &) {
741     S ret{};
742     return ret;
743   }
744 
745   template <typename S>
GetAttrType(const AnyTraits<std::vector<S>> &,size_t size,S default_val)746   static std::vector<S> GetAttrType(const AnyTraits<std::vector<S>> &, size_t size, S default_val) {
747     std::vector<S> ret{};
748     return ret;
749   }
750 
751   // specialization for reverse bool
GetAttrType(const AnyTraits<bool> &,bool reverse)752   static bool GetAttrType(const AnyTraits<bool> &, bool reverse) {
753     bool ret = false;
754     return ret;
755   }
756 
757   template <typename P, typename Q>
GetAttrType(const AnyTraits<P> & traits_from,const AnyTraits<Q> & traits_to)758   static Q GetAttrType(const AnyTraits<P> &traits_from, const AnyTraits<Q> &traits_to) {
759     Q ret{};
760     return ret;
761   }
762 
763   // specialization for tensor
GetAttrType(const AnyTraits<mindspore::tensor::Tensor> & traits)764   static GeTensor GetAttrType(const AnyTraits<mindspore::tensor::Tensor> &traits) {
765     GeTensor ret{};
766     return ret;
767   }
768 
769   // specialization for int
GetAttrType(const AnyTraits<int64_t>)770   static int64_t GetAttrType(const AnyTraits<int64_t>) {
771     int64_t ret{1};
772     return ret;
773   }
774 
775   // specialization for float
GetAttrType(const AnyTraits<float>)776   static float GetAttrType(const AnyTraits<float>) {
777     float ret{1.0};
778     return ret;
779   }
780 
GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>)781   static std::vector<std::vector<int64_t>> GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>) {
782     std::vector<std::vector<int64_t>> ret{};
783     return ret;
784   }
785 
GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>,const AnyTraits<std::vector<int64_t>>)786   static std::vector<int64_t> GetAttrType(const AnyTraits<std::vector<std::vector<int64_t>>>,
787                                           const AnyTraits<std::vector<int64_t>>) {
788     std::vector<int64_t> ret{};
789     return ret;
790   }
791 
GetAttrType(const AnyTraits<std::vector<int64_t>>,size_t index)792   static int64_t GetAttrType(const AnyTraits<std::vector<int64_t>>, size_t index) {
793     int64_t ret{1};
794     return ret;
795   }
796 
GetAttrType(const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::vector<int64_t>>)797   static std::vector<int64_t> GetAttrType(const AnyTraits<std::vector<int64_t>>,
798                                           const AnyTraits<std::vector<int64_t>>) {
799     std::vector<int64_t> ret{};
800     return ret;
801   }
802 
GetAttrType(const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<std::string> anyTraitsStr)803   static std::string GetAttrType(const AnyTraits<std::vector<int64_t>> anyTraitsVec,
804                                  const AnyTraits<std::string> anyTraitsStr) {
805     std::string ret{};
806     return ret;
807   }
808 
GetAttrType(const AnyTraits<std::vector<float>> anyTraitsVec,const AnyTraits<float> anyTraitsFlo)809   static std::vector<float> GetAttrType(const AnyTraits<std::vector<float>> anyTraitsVec,
810                                         const AnyTraits<float> anyTraitsFlo) {
811     std::vector<float> ret{};
812     return ret;
813   }
814 
GetAttrType(const std::string & format,const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<int64_t> anyTraitsInt)815   static std::vector<int64_t> GetAttrType(const std::string &format, const AnyTraits<std::vector<int64_t>> anyTraitsVec,
816                                           const AnyTraits<int64_t> anyTraitsInt) {
817     std::vector<int64_t> ret{};
818     return ret;
819   }
820 
821   // convert value list for value tuple to vector
822   template <typename P, typename Q>
GetAttrType(const AnyTraits<P> & anyTraitsP,const AnyTraits<std::vector<Q>> anyTraitsQ)823   static std::vector<Q> GetAttrType(const AnyTraits<P> &anyTraitsP, const AnyTraits<std::vector<Q>> anyTraitsQ) {
824     std::vector<Q> ret{};
825     return ret;
826   }
827 
GetAttrType(const AnyTraits<GeEnum>)828   static int64_t GetAttrType(const AnyTraits<GeEnum>) {
829     int64_t ret{1};
830     return ret;
831   }
832 
GetAttrType(const AnyTraits<GEType> anyTraitsGE)833   static GeDataType GetAttrType(const AnyTraits<GEType> anyTraitsGE) {
834     GeDataType ret{};
835     return ret;
836   }
837 
GetAttrType(const AnyTraits<GEType> anyTraitsGE,const AnyTraits<int64_t> anyTraitsInt)838   static int64_t GetAttrType(const AnyTraits<GEType> anyTraitsGE, const AnyTraits<int64_t> anyTraitsInt) {
839     int64_t ret{1};
840     return ret;
841   }
842 
GetAttrType(const AnyTraits<std::vector<GEType>> anyTraitsGE)843   static std::vector<GeDataType> GetAttrType(const AnyTraits<std::vector<GEType>> anyTraitsGE) {
844     std::vector<GeDataType> ret{};
845     return ret;
846   }
847 
GetAttrType(const AnyTraits<GEDataFormat> anyTraitsGE)848   static std::string GetAttrType(const AnyTraits<GEDataFormat> anyTraitsGE) {
849     std::string ret{};
850     return ret;
851   }
852 
GetAttrType(const AnyTraits<AscendQuantRoundMode> anyTraitsGE)853   static std::string GetAttrType(const AnyTraits<AscendQuantRoundMode> anyTraitsGE) {
854     std::string ret{};
855     return ret;
856   }
857 
GetAttrType(const AnyTraits<FASInputLayoutMode> anyTraitsGE)858   static std::string GetAttrType(const AnyTraits<FASInputLayoutMode> anyTraitsGE) {
859     std::string ret{};
860     return ret;
861   }
862 
GetAttrType(const AnyTraits<FFNActivationMode> anyTraitsGE)863   static std::string GetAttrType(const AnyTraits<FFNActivationMode> anyTraitsGE) {
864     std::string ret{};
865     return ret;
866   }
867 
GetAttrType(const AnyTraits<ScatterReduceMode> anyTraitsGE)868   static std::string GetAttrType(const AnyTraits<ScatterReduceMode> anyTraitsGE) {
869     std::string ret{};
870     return ret;
871   }
872 
GetAttrType(const AnyTraits<GEPadMod> anyTraitsGE)873   static std::string GetAttrType(const AnyTraits<GEPadMod> anyTraitsGE) {
874     std::string ret{};
875     return ret;
876   }
877 
GetAttrType(const AnyTraits<GEReduction> anyTraitsGE)878   static std::string GetAttrType(const AnyTraits<GEReduction> anyTraitsGE) {
879     std::string ret{};
880     return ret;
881   }
882 
GetAttrType(const AnyTraits<GECoordinateTransformMode> anyTraitsGE)883   static std::string GetAttrType(const AnyTraits<GECoordinateTransformMode> anyTraitsGE) {
884     std::string ret{};
885     return ret;
886   }
887 
GetAttrType(const AnyTraits<GEEnumToStr> enum_str,const std::vector<std::string> & enum_string)888   static std::string GetAttrType(const AnyTraits<GEEnumToStr> enum_str, const std::vector<std::string> &enum_string) {
889     std::string ret{};
890     return ret;
891   }
892 
893   // convert any value to tensor
GetAttrType(const AnyTraits<ValueAny> anyTraitsValue)894   static GeTensor GetAttrType(const AnyTraits<ValueAny> anyTraitsValue) {
895     GeTensor ret{};
896     return ret;
897   }
898 
  // Static description tables shared by all instances of one OpAdapter<T> instantiation.
  // NOTE(review): the int keys are presumably input/output/subgraph indices — confirm against the code that
  // fills these maps.
  static const mindspore::HashMap<int, InputDesc> input_map_;
  static const mindspore::HashMap<int, DynInputDesc> dyn_input_map_;
  // note: To keep the outputs in order, the 'output_map_' and 'cus_output_map_' must be std::map instead of Hashmap.
  static const std::map<int, OutputDesc> output_map_;
  static const mindspore::HashMap<int, DynOutputDesc> dyn_output_map_;
  static const mindspore::HashMap<int, SubGraphDesc> subgraph_map_;
  static const mindspore::HashMap<int, DynSubGraphDesc> dyn_subgraph_map_;
  static const mindspore::HashMap<std::string, AttrDesc> attr_map_;
  // Enum name -> integer code; consulted by the GeEnum overload of ConvertAny.
  static const mindspore::HashMap<std::string, int> enum_map_;
  // convert input from anf graph to Attr in Operators
  static const mindspore::HashMap<unsigned int, AttrDesc> input_attr_map_;
  static const mindspore::HashMap<std::string, std::string> attr_input_map_;
  static const bool dynamic_shape_support_;
  // Mutable (non-const) static tables for custom operators, keyed by string.
  static mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> cus_input_map_;
  static mindspore::HashMap<std::string, std::map<int, std::string>> cus_output_map_;
  // Operator type name; passed to ::ge::OperatorFactory by getOp().
  static const char op_type_[];
  // Per-instance state.
  std::string op_type_obj_;
  mindspore::HashMap<std::string, ValuePtr> extra_attr_;
  mindspore::HashMap<std::string, int> name_counts_;
  const std::shared_ptr<OpAdapterImpl> impl_;
  // cache the Operator to avoid memory leak caused by 'std::make_shared<OpType>()'
  inline static OperatorPtr op_ = nullptr;
};  // class OpAdapter
922 
// Out-of-class definitions of the static data members of the OpAdapter<T> primary template.
// With these definitions every container is default-constructed (empty), dynamic_shape_support_ is true and
// op_type_ is the empty string.
// NOTE(review): per-operator contents are presumably supplied by explicit specializations elsewhere — confirm.
template <typename T>
const mindspore::HashMap<int, InputDesc> OpAdapter<T>::input_map_;
template <typename T>
const mindspore::HashMap<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
template <typename T>
const std::map<int, OutputDesc> OpAdapter<T>::output_map_;
template <typename T>
const mindspore::HashMap<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
template <typename T>
const mindspore::HashMap<int, SubGraphDesc> OpAdapter<T>::subgraph_map_;
template <typename T>
const mindspore::HashMap<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
template <typename T>
const mindspore::HashMap<std::string, AttrDesc> OpAdapter<T>::attr_map_;
template <typename T>
const mindspore::HashMap<std::string, int> OpAdapter<T>::enum_map_;
template <typename T>
const mindspore::HashMap<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_;
template <typename T>
const mindspore::HashMap<std::string, std::string> OpAdapter<T>::attr_input_map_;
template <typename T>
mindspore::HashMap<std::string, mindspore::HashMap<int, std::string>> OpAdapter<T>::cus_input_map_;
template <typename T>
mindspore::HashMap<std::string, std::map<int, std::string>> OpAdapter<T>::cus_output_map_;
// Default for the primary template: dynamic shape is reported as supported.
template <typename T>
const bool OpAdapter<T>::dynamic_shape_support_{true};
// Default for the primary template: empty operator type name.
template <typename T>
const char OpAdapter<T>::op_type_[]{""};
951 
952 // specialization for method
953 }  // namespace transform
954 }  // namespace mindspore
955 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_
956