• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <string>
23 #include <unordered_map>
24 
25 #include "transform/graph_ir/op_adapter_util.h"
26 #include "utils/utils.h"
27 namespace mindspore {
28 namespace transform {
29 class OpAdapterImpl {
30  public:
OpAdapterImpl(const std::unordered_map<int,InputDesc> & input_map,const std::unordered_map<int,DynInputDesc> & dyn_input_map,const std::unordered_map<int,OutputDesc> & output_map,const std::unordered_map<int,DynOutputDesc> & dyn_output_map,const std::unordered_map<int,DynSubGraphDesc> & dyn_subgraph_map,const std::unordered_map<std::string,AttrDesc> & attr_map,const std::unordered_map<std::string,int> & enum_map,const std::unordered_map<unsigned int,AttrDesc> & input_attr_map,std::unordered_map<std::string,std::unordered_map<int,std::string>> * cus_input_map,std::unordered_map<std::string,std::unordered_map<int,std::string>> * cus_output_map,std::unordered_map<std::string,ValuePtr> * extra_attr,std::unordered_map<std::string,int> * name_counts,BaseOpAdapter * adpt)31   OpAdapterImpl(const std::unordered_map<int, InputDesc> &input_map,
32                 const std::unordered_map<int, DynInputDesc> &dyn_input_map,
33                 const std::unordered_map<int, OutputDesc> &output_map,
34                 const std::unordered_map<int, DynOutputDesc> &dyn_output_map,
35                 const std::unordered_map<int, DynSubGraphDesc> &dyn_subgraph_map,
36                 const std::unordered_map<std::string, AttrDesc> &attr_map,
37                 const std::unordered_map<std::string, int> &enum_map,
38                 const std::unordered_map<unsigned int, AttrDesc> &input_attr_map,
39                 std::unordered_map<std::string, std::unordered_map<int, std::string>> *cus_input_map,
40                 std::unordered_map<std::string, std::unordered_map<int, std::string>> *cus_output_map,
41                 std::unordered_map<std::string, ValuePtr> *extra_attr,
42                 std::unordered_map<std::string, int> *name_counts, BaseOpAdapter *adpt)
43       : input_map_(input_map),
44         dyn_input_map_(dyn_input_map),
45         output_map_(output_map),
46         dyn_output_map_(dyn_output_map),
47         dyn_subgraph_map_(dyn_subgraph_map),
48         attr_map_(attr_map),
49         enum_map_(enum_map),
50         input_attr_map_(input_attr_map),
51         cus_input_map_(cus_input_map),
52         cus_output_map_(cus_output_map),
53         extra_attr_(extra_attr),
54         name_counts_(name_counts),
55         adpt_(adpt) {
56     MS_EXCEPTION_IF_NULL(cus_input_map_);
57     MS_EXCEPTION_IF_NULL(cus_output_map_);
58     MS_EXCEPTION_IF_NULL(extra_attr_);
59     MS_EXCEPTION_IF_NULL(name_counts_);
60     MS_EXCEPTION_IF_NULL(adpt_);
61   }
~OpAdapterImpl()62   ~OpAdapterImpl() {}
63   bool IsCustomOp(const OperatorPtr &op);
64   Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
65   Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim);
66   OperatorPtr GenerateCustomOp(const AnfNodePtr anf);
67   Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches);
68   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input);
69   Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input);
70   int setInput(const OperatorPtr &op, int index, const OperatorPtr &input);
71   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle);
72   Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle);
73   int setInput(const OperatorPtr &op, int index, const OutHandler &handle);
74   int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec);
75   OutHandler getOutput(const OperatorPtr &op, int index);
76   OutHandler getCustomOutput(const OperatorPtr &op, int index);
77   OutHandler getNormalOutput(const OperatorPtr &op, int index);
78   Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
79                                 const std::string &format);
80   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op);
81   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
82                                                  const std::string &format);
83   Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
84                                const std::string &format);
85   std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format);
86   void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format);
87   void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format);
88   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
89   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
90                         const AnfNodePtr &node);
91   int setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value);
92   int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim);
93   int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim);
94   int setAttr(const OperatorPtr &op, const PrimitivePtr &prim);
95   int setAttr(const OperatorPtr &op, const AnfNodePtr &node);
96 
97  private:
98   const std::unordered_map<int, InputDesc> &input_map_;
99   const std::unordered_map<int, DynInputDesc> &dyn_input_map_;
100   const std::unordered_map<int, OutputDesc> &output_map_;
101   const std::unordered_map<int, DynOutputDesc> &dyn_output_map_;
102   const std::unordered_map<int, DynSubGraphDesc> &dyn_subgraph_map_;
103   const std::unordered_map<std::string, AttrDesc> &attr_map_;
104   const std::unordered_map<std::string, int> &enum_map_;
105   const std::unordered_map<unsigned int, AttrDesc> &input_attr_map_;
106   std::unordered_map<std::string, std::unordered_map<int, std::string>> *const cus_input_map_;
107   std::unordered_map<std::string, std::unordered_map<int, std::string>> *const cus_output_map_;
108   std::unordered_map<std::string, ValuePtr> *const extra_attr_;
109   std::unordered_map<std::string, int> *const name_counts_;
110   BaseOpAdapter *const adpt_;
111 };
112 
113 template <typename T>
114 class OpAdapter : public BaseOpAdapter {
115  public:
116   using OpType = T;
OpAdapter()117   OpAdapter()
118       : impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_,
119                                               dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_,
120                                               &cus_output_map_, &extra_attr_, &name_counts_, this)) {
121     MS_EXCEPTION_IF_NULL(impl_);
122   }
OpAdapter(const ExtraAttr & extra_attr)123   explicit OpAdapter(const ExtraAttr &extra_attr)
124       : extra_attr_(extra_attr),
125         impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_,
126                                               dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_,
127                                               &cus_output_map_, &extra_attr_, &name_counts_, this)) {
128     MS_EXCEPTION_IF_NULL(impl_);
129   }
~OpAdapter()130   ~OpAdapter() override {}
131 
IsCustomOp(const OperatorPtr & op)132   bool IsCustomOp(const OperatorPtr &op) { return impl_->IsCustomOp(op); }
133 
GenerateCustomOpInputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)134   Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
135     return impl_->GenerateCustomOpInputMap(op, prim);
136   }
137 
GenerateCustomOpOutputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)138   Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
139     return impl_->GenerateCustomOpOutputMap(op, prim);
140   }
141 
142   // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs.
GenerateCustomOp(const AnfNodePtr anf)143   OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { return impl_->GenerateCustomOp(anf); }
144 
GenerateNormalOp(const AnfNodePtr & anf)145   OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) {
146     OperatorPtr op = nullptr;
147     // There are duplicate names in ANF graph, do not assign ANF node name to GE
148     // GE will generate unique name automatically
149     if (anf != nullptr && anf->fullname_with_scope() != "") {
150       MS_LOG(DEBUG) << anf->fullname_with_scope();
151       op = std::make_shared<OpType>(anf->fullname_with_scope());
152     } else {
153       MS_LOG(DEBUG) << "no fullname_with_scope";
154       op = std::make_shared<OpType>();
155     }
156 
157     // set dynamic output num if op use DYNAMIC_OUTPUT
158     if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) {
159       TypePtr type = anf->Type();
160       if (type == nullptr) {
161         MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!";
162       }
163       size_t num = type->isa<Tuple>() ? (type->cast<std::shared_ptr<Tuple>>()->size()) : 1;
164       MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString()
165                    << ", num:" << num;
166       dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(num));
167     }
168     return op;
169   }
170 
generate(const AnfNodePtr & anf)171   OperatorPtr generate(const AnfNodePtr &anf) override {
172     OperatorPtr op = nullptr;
173     if (IsCustomCNode(anf)) {
174       op = GenerateCustomOp(anf);
175     } else {
176       op = GenerateNormalOp(anf);
177     }
178     if (op == nullptr) {
179       MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope();
180     }
181     return op;
182   }
183 
generate(const std::string & op_name)184   OperatorPtr generate(const std::string &op_name) override { return std::make_shared<OpType>(op_name); }
185 
getInputMap()186   const std::unordered_map<int, InputDesc> &getInputMap() override { return input_map_; }
getInputAttrMap()187   const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; }
getDynInputMap()188   const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; }
getOutputMap()189   const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; }
getDynSubgraphMap()190   const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; }
191 
SetOpSubgraphFunc(const OperatorPtr & op,int index,std::shared_ptr<std::vector<DfGraph>> branches)192   Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) {
193     return impl_->SetOpSubgraphFunc(op, index, branches);
194   }
195 
setSubgraph(const OperatorPtr & op,int index,std::shared_ptr<std::vector<DfGraph>> branches)196   int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override {
197     return static_cast<int>(SetOpSubgraphFunc(op, index, branches));
198   }
199 
SetCustomOpInput(const CusOperatorPtr & op,int index,const OperatorPtr & input)200   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) {
201     return impl_->SetCustomOpInput(op, index, input);
202   }
203 
SetNormalOpInput(const OperatorPtr & op,int index,const OperatorPtr & input)204   Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) {
205     return impl_->SetNormalOpInput(op, index, input);
206   }
207 
setInput(const OperatorPtr & op,int index,const OperatorPtr & input)208   int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override {
209     return impl_->setInput(op, index, input);
210   }
211 
SetCustomOpInput(const CusOperatorPtr & op,int index,const OutHandler & handle)212   Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) {
213     return impl_->SetCustomOpInput(op, index, handle);
214   }
215 
SetNormalOpInput(const OperatorPtr & op,int index,const OutHandler & handle)216   Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) {
217     return impl_->SetNormalOpInput(op, index, handle);
218   }
219 
setInput(const OperatorPtr & op,int index,const OutHandler & handle)220   int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override {
221     return impl_->setInput(op, index, handle);
222   }
223 
setInput(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<OutHandler>> & handler_vec)224   int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec) override {
225     return impl_->setInput(op, index, handler_vec);
226   }
227 
getOutput(const OperatorPtr & op,int index)228   OutHandler getOutput(const OperatorPtr &op, int index) override { return impl_->getOutput(op, index); }
229 
getCustomOutput(const OperatorPtr & op,int index)230   OutHandler getCustomOutput(const OperatorPtr &op, int index) { return impl_->getCustomOutput(op, index); }
231 
getNormalOutput(const OperatorPtr & op,int index)232   OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); }
233 
UpdateSingleOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)234   Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
235                                 const std::string &format) {
236     return impl_->UpdateSingleOutputDesc(op, shp, type, format);
237   }
238 
GetCustomOpOutputSize(const CusOperatorPtr & cus_op)239   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); }
240 
CreateOutputDesc(const abstract::ShapePtr & shape_ptr,const TypePtr & type,const std::string & format)241   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
242                                                  const std::string &format) {
243     return impl_->CreateOutputDesc(shape_ptr, type, format);
244   }
245 
UpdateMultiOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)246   Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
247                                const std::string &format) {
248     return impl_->UpdateMultiOutputDesc(op, shp, type, format);
249   }
250 
CreateNodeDesc(const AnfNodePtr & node,const std::string & format)251   std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
252     return impl_->CreateNodeDesc(node, format);
253   }
254 
UpdateNormalOpInputDesc(const OperatorPtr & op,const AnfNodePtr node,const std::string format)255   void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) {
256     return impl_->UpdateNormalOpInputDesc(op, node, format);
257   }
258 
UpdateCustomOpInputDesc(const CusOperatorPtr & op,const AnfNodePtr & node,const std::string format)259   void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) {
260     return impl_->UpdateCustomOpInputDesc(op, node, format);
261   }
262 
updateInputDesc(const OperatorPtr & op,const AnfNodePtr & node)263   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); }
264 
updateOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const AnfNodePtr & node)265   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
266                         const AnfNodePtr &node) override {
267     impl_->updateOutputDesc(op, shp, type, node);
268   }
269 
setAttr(const OperatorPtr & op,const std::string & attrKey,const ValuePtr & attrValue)270   int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override {
271     return impl_->setAttr(op, attrKey, attrValue);
272   }
273 
SetCustomOpAttr(const CusOperatorPtr & op,const PrimitivePtr & prim)274   int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetCustomOpAttr(op, prim); }
275 
SetNormalOpAttr(const OperatorPtr & op,const PrimitivePtr & prim)276   int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetNormalOpAttr(op, prim); }
277 
setAttr(const OperatorPtr & op,const PrimitivePtr & prim)278   int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { return impl_->setAttr(op, prim); }
279 
setAttr(const OperatorPtr & op,const AnfNodePtr & node)280   int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { return impl_->setAttr(op, node); }
281 
GetExtraAttr()282   std::unordered_map<std::string, ValuePtr> GetExtraAttr() override { return extra_attr_; }
283 
284  private:
285   template <typename S>
ConvertAny(const ValuePtr & value,const AnyTraits<S> &)286   static S ConvertAny(const ValuePtr &value, const AnyTraits<S> &) {
287     return GetValue<S>(value);
288   }
289 
290   // specialization for reverse bool
ConvertAny(const ValuePtr & value,const AnyTraits<bool> &,bool reverse)291   static bool ConvertAny(const ValuePtr &value, const AnyTraits<bool> &, bool reverse) {
292     return reverse != GetValue<bool>(value);
293   }
294 
295   template <typename P, typename Q>
ConvertAny(const ValuePtr & value,const AnyTraits<P> & traits_from,const AnyTraits<Q> & traits_to)296   static Q ConvertAny(const ValuePtr &value, const AnyTraits<P> &traits_from, const AnyTraits<Q> &traits_to) {
297     return ConvertAnyUtil(value, traits_from, traits_to);
298   }
299 
300   // specialization for tensor
ConvertAny(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> & traits)301   static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &traits) {
302     // To-DO the format may read from ME tensor
303     return ConvertAnyUtil(value, traits);
304   }
305 
306   // specialization for int
ConvertAny(const ValuePtr & value,const AnyTraits<int64_t>)307   static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<int64_t>) {
308     return static_cast<int64_t>(GetValue<int64_t>(value));
309   }
310 
311   // specialization for int or tuple broadcast to Vector
ConvertAny(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>> anyTraitsInt)312   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name,
313                                          const AnyTraits<std::vector<int64_t>> anyTraitsInt) {
314     return ConvertAnyUtil(value, name, anyTraitsInt);
315   }
316 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>)317   static std::vector<std::vector<int64_t>> ConvertAny(const ValuePtr &value,
318                                                       const AnyTraits<std::vector<std::vector<int64_t>>>) {
319     MS_EXCEPTION_IF_NULL(value);
320     MS_LOG(INFO) << "Value: " << value->type_name();
321     std::vector<std::vector<int64_t>> list;
322     if (!value->isa<ValueTuple>()) {
323       MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name();
324     }
325     auto vec = value->cast<ValueTuplePtr>();
326     MS_EXCEPTION_IF_NULL(vec);
327     for (auto &it : vec->value()) {
328       MS_EXCEPTION_IF_NULL(it);
329       if (!it->isa<ValueTuple>()) {
330         MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name();
331       }
332       auto sub_vector = it->cast<ValueTuplePtr>();
333       std::vector<int64_t> sublist;
334       for (auto &item : sub_vector->value()) {
335         sublist.push_back(static_cast<int64_t>(GetValue<int64_t>(item)));
336       }
337       list.push_back(sublist);
338     }
339     return list;
340   }
341 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>,const AnyTraits<std::vector<int64_t>>)342   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<std::vector<int64_t>>>,
343                                          const AnyTraits<std::vector<int64_t>>) {
344     MS_EXCEPTION_IF_NULL(value);
345     MS_LOG(DEBUG) << "Value: " << value->type_name();
346     if (!value->isa<ValueList>()) {
347       MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name();
348     }
349     auto vec = value->cast<ValueListPtr>();
350     std::vector<int64_t> list;
351     for (auto &it : vec->value()) {
352       MS_EXCEPTION_IF_NULL(it);
353       if (!it->isa<ValueList>()) {
354         MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name();
355       }
356       auto sub_vector = it->cast<ValueListPtr>();
357       for (auto &item : sub_vector->value()) {
358         list.push_back(static_cast<int64_t>(GetValue<int64_t>(item)));
359       }
360     }
361     return list;
362   }
363 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::vector<int64_t>>)364   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>,
365                                          const AnyTraits<std::vector<int64_t>>) {
366     MS_EXCEPTION_IF_NULL(value);
367     MS_LOG(INFO) << "Value: " << value->type_name();
368     std::vector<int64_t> list;
369     if (value->isa<ValueSequeue>()) {
370       auto vec = value->cast<ValueSequeuePtr>();
371       MS_EXCEPTION_IF_NULL(vec);
372       for (auto &it : vec->value()) {
373         list.push_back(static_cast<int64_t>(GetValue<int64_t>(it)));
374       }
375       return list;
376     }
377     if (value->isa<Scalar>()) {
378       list.push_back(static_cast<int64_t>(GetValue<int64_t>(value)));
379       return list;
380     }
381     MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name();
382   }
383 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<std::string> anyTraitsStr)384   static std::string ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>> anyTraitsVec,
385                                 const AnyTraits<std::string> anyTraitsStr) {
386     return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr);
387   }
388 
ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<float>> anyTraitsVec,const AnyTraits<float> anyTraitsFlo)389   static std::vector<float> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<float>> anyTraitsVec,
390                                        const AnyTraits<float> anyTraitsFlo) {
391     return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo);
392   }
393 
ConvertAny(const ValuePtr & value,const std::string & format,const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<int64_t> anyTraitsInt)394   static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &format,
395                                          const AnyTraits<std::vector<int64_t>> anyTraitsVec,
396                                          const AnyTraits<int64_t> anyTraitsInt) {
397     return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt);
398   }
399 
400   // convert value list for value tuple to vector
401   template <typename P, typename Q>
ConvertAny(const ValuePtr & value,const AnyTraits<P> & anyTraitsP,const AnyTraits<std::vector<Q>> anyTraitsQ)402   static std::vector<Q> ConvertAny(const ValuePtr &value, const AnyTraits<P> &anyTraitsP,
403                                    const AnyTraits<std::vector<Q>> anyTraitsQ) {
404     return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ);
405   }
406 
ConvertAny(const ValuePtr & value,const AnyTraits<GeEnum>)407   static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<GeEnum>) {
408     auto name = GetValue<std::string>(value);
409     auto it = enum_map_.find(name);
410     int v = 0;
411     if (it != enum_map_.end()) {
412       v = it->second;
413     }
414     return v;
415   }
416 
ConvertAny(const ValuePtr & value,const AnyTraits<GEType> anyTraitsGE)417   static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits<GEType> anyTraitsGE) {
418     return ConvertAnyUtil(value, anyTraitsGE);
419   }
420 
421   // convert any value to tensor
ConvertAny(const ValuePtr & value,const AnyTraits<AnyValue> anyTraitsValue)422   static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<AnyValue> anyTraitsValue) {
423     return ConvertAnyUtil(value, anyTraitsValue);
424   }
425 
426   static const std::unordered_map<int, InputDesc> input_map_;
427   static const std::unordered_map<int, DynInputDesc> dyn_input_map_;
428   static const std::unordered_map<int, OutputDesc> output_map_;
429   static const std::unordered_map<int, DynOutputDesc> dyn_output_map_;
430   static const std::unordered_map<int, DynSubGraphDesc> dyn_subgraph_map_;
431   static const std::unordered_map<std::string, AttrDesc> attr_map_;
432   static const std::unordered_map<std::string, int> enum_map_;
433   // convert input from anf graph to Attr in Operators
434   static const std::unordered_map<unsigned int, AttrDesc> input_attr_map_;
435   static std::unordered_map<std::string, std::unordered_map<int, std::string>> cus_input_map_;
436   static std::unordered_map<std::string, std::unordered_map<int, std::string>> cus_output_map_;
437   std::unordered_map<std::string, ValuePtr> extra_attr_;
438   std::unordered_map<std::string, int> name_counts_;
439   const std::shared_ptr<OpAdapterImpl> impl_;
440 };
441 
442 template <typename T>
443 const std::unordered_map<int, InputDesc> OpAdapter<T>::input_map_;
444 template <typename T>
445 const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
446 template <typename T>
447 const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
448 template <typename T>
449 const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
450 template <typename T>
451 const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
452 template <typename T>
453 const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
454 template <typename T>
455 const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;
456 template <typename T>
457 const std::unordered_map<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_;
458 template <typename T>
459 std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<T>::cus_input_map_;
460 template <typename T>
461 std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<T>::cus_output_map_;
462 
463 // specialization for method
464 }  // namespace transform
465 }  // namespace mindspore
466 
467 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_
468